// mnemo_pgwire: server.rs

//! PostgreSQL wire protocol connection handler.
//!
//! Implements the subset of the PostgreSQL wire protocol needed for
//! simple query execution. Handles startup (declining SSL/GSSAPI
//! encryption), authentication (trust mode, or a cleartext password when
//! one is configured), and the simple query flow.
//!
//! Reference: <https://www.postgresql.org/docs/current/protocol.html>

use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

use mnemo_core::query::MnemoEngine;

use crate::PgWireConfig;
use crate::parser::{self, ParsedStatement};

19/// Handle a single PostgreSQL wire protocol connection.
20pub async fn handle_connection(
21    mut stream: TcpStream,
22    engine: Arc<MnemoEngine>,
23    config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25    // Phase 1: Startup message
26    let startup_len_raw = stream.read_i32().await?;
27    let startup_len = usize::try_from(startup_len_raw)
28        .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29    if !(8..=10240).contains(&startup_len) {
30        return Err("invalid startup message length".into());
31    }
32
33    let mut startup_buf = vec![0u8; startup_len - 4];
34    stream.read_exact(&mut startup_buf).await?;
35
36    let protocol_version = i32::from_be_bytes([
37        startup_buf[0],
38        startup_buf[1],
39        startup_buf[2],
40        startup_buf[3],
41    ]);
42
43    // SSL request (80877103) — respond with 'N' (no SSL)
44    if protocol_version == 80877103 {
45        stream.write_all(b"N").await?;
46        // Client will retry with normal startup
47        let startup_len_raw = stream.read_i32().await?;
48        let startup_len = usize::try_from(startup_len_raw)
49            .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50        if !(8..=10240).contains(&startup_len) {
51            return Err("invalid startup message length after SSL".into());
52        }
53        startup_buf = vec![0u8; startup_len - 4];
54        stream.read_exact(&mut startup_buf).await?;
55    }
56
57    // Phase 2: Authentication
58    if let Some(ref expected_password) = config.password {
59        // Send AuthenticationCleartextPassword (type 3)
60        stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61        stream.flush().await?;
62
63        // Read password message (type 'p')
64        let pw_type = stream.read_u8().await?;
65        if pw_type != b'p' {
66            send_error(&mut stream, "expected password message").await?;
67            return Err("expected password message".into());
68        }
69        let pw_len_raw = stream.read_i32().await?;
70        let pw_len = usize::try_from(pw_len_raw)
71            .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72        if !(5..=10240).contains(&pw_len) {
73            return Err("invalid password message length".into());
74        }
75        let mut pw_buf = vec![0u8; pw_len - 4];
76        stream.read_exact(&mut pw_buf).await?;
77        let client_password = String::from_utf8_lossy(&pw_buf)
78            .trim_end_matches('\0')
79            .to_string();
80
81        if client_password != *expected_password {
82            send_error(&mut stream, "password authentication failed").await?;
83            return Err("authentication failed".into());
84        }
85    }
86
87    // Send AuthenticationOk
88    stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90    // Send ParameterStatus messages
91    send_parameter_status(&mut stream, "server_version", "16.0").await?;
92    send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93    send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94    send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96    // Send ReadyForQuery
97    send_ready_for_query(&mut stream).await?;
98
99    // Phase 3: Query loop
100    while let Ok(msg_type) = stream.read_u8().await {
101        let msg_len_raw = stream.read_i32().await?;
102        let msg_len = usize::try_from(msg_len_raw)
103            .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104        if !(4..=1_048_576).contains(&msg_len) {
105            break;
106        }
107
108        let mut msg_buf = vec![0u8; msg_len - 4];
109        if !msg_buf.is_empty() {
110            stream.read_exact(&mut msg_buf).await?;
111        }
112
113        match msg_type {
114            b'Q' => {
115                // Simple Query
116                let sql = String::from_utf8_lossy(&msg_buf)
117                    .trim_end_matches('\0')
118                    .to_string();
119
120                tracing::debug!("pgwire query: {sql}");
121
122                match handle_query(&sql, &engine, config).await {
123                    Ok(response) => {
124                        send_query_response(&mut stream, &response).await?;
125                    }
126                    Err(e) => {
127                        send_error(&mut stream, &e.to_string()).await?;
128                    }
129                }
130
131                send_ready_for_query(&mut stream).await?;
132            }
133            b'X' => {
134                // Terminate
135                tracing::debug!("pgwire client terminated");
136                break;
137            }
138            _ => {
139                // Unsupported message type — send error and continue
140                send_error(
141                    &mut stream,
142                    &format!("unsupported message type: {}", msg_type as char),
143                )
144                .await?;
145                send_ready_for_query(&mut stream).await?;
146            }
147        }
148    }
149
150    Ok(())
151}
152
/// Result of one executed statement, ready to be encoded for the client.
struct QueryResponse {
    /// Column names announced in `RowDescription`. When empty, no
    /// `RowDescription`/`DataRow` messages are written for this response.
    columns: Vec<String>,
    /// Row values already rendered as text, one inner `Vec` per row.
    rows: Vec<Vec<String>>,
    /// Tag sent in `CommandComplete`, e.g. `SELECT 3` or `INSERT 0 1`.
    command_tag: String,
}

160async fn handle_query(
161    sql: &str,
162    engine: &MnemoEngine,
163    config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165    let stmt = parser::parse_sql(sql);
166
167    match stmt {
168        ParsedStatement::Select(q) => {
169            let agent_id = q
170                .agent_id
171                .unwrap_or_else(|| config.default_agent_id.clone());
172
173            let request = mnemo_core::query::recall::RecallRequest {
174                agent_id: Some(agent_id),
175                query: q.query_text.unwrap_or_default(),
176                limit: Some(q.limit),
177                memory_type: None,
178                memory_types: None,
179                scope: None,
180                strategy: Some("exact".to_string()),
181                min_importance: None,
182                tags: None,
183                org_id: None,
184                temporal_range: None,
185                recency_half_life_hours: None,
186                hybrid_weights: None,
187                rrf_k: None,
188                as_of: None,
189                explain: None,
190            };
191
192            let response = engine.recall(request).await?;
193
194            let columns = vec![
195                "id".to_string(),
196                "agent_id".to_string(),
197                "content".to_string(),
198                "memory_type".to_string(),
199                "importance".to_string(),
200                "created_at".to_string(),
201            ];
202
203            let rows: Vec<Vec<String>> = response
204                .memories
205                .iter()
206                .skip(q.offset)
207                .map(|m| {
208                    vec![
209                        m.id.to_string(),
210                        m.agent_id.clone(),
211                        m.content.clone(),
212                        m.memory_type.to_string(),
213                        m.importance.to_string(),
214                        m.created_at.clone(),
215                    ]
216                })
217                .collect();
218
219            let count = rows.len();
220            Ok(QueryResponse {
221                columns,
222                rows,
223                command_tag: format!("SELECT {count}"),
224            })
225        }
226
227        ParsedStatement::Insert(q) => {
228            let agent_id = q
229                .agent_id
230                .unwrap_or_else(|| config.default_agent_id.clone());
231
232            let request = mnemo_core::query::remember::RememberRequest {
233                content: q.content,
234                agent_id: Some(agent_id),
235                memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
236                scope: None,
237                importance: q.importance,
238                tags: if q.tags.is_empty() {
239                    None
240                } else {
241                    Some(q.tags)
242                },
243                metadata: None,
244                source_type: None,
245                source_id: None,
246                org_id: None,
247                thread_id: None,
248                ttl_seconds: None,
249                related_to: None,
250                decay_rate: None,
251                created_by: None,
252            };
253
254            let response = engine.remember(request).await?;
255
256            Ok(QueryResponse {
257                columns: vec!["id".to_string(), "content_hash".to_string()],
258                rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
259                command_tag: "INSERT 0 1".to_string(),
260            })
261        }
262
263        ParsedStatement::Delete(q) => {
264            if let Some(memory_id_str) = q.memory_id {
265                let memory_id: uuid::Uuid = memory_id_str
266                    .parse()
267                    .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
268
269                let agent_id = q
270                    .agent_id
271                    .unwrap_or_else(|| config.default_agent_id.clone());
272
273                let request = mnemo_core::query::forget::ForgetRequest {
274                    memory_ids: vec![memory_id],
275                    agent_id: Some(agent_id),
276                    strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
277                    criteria: None,
278                };
279
280                let response = engine.forget(request).await?;
281                let count = response.forgotten.len();
282
283                Ok(QueryResponse {
284                    columns: vec![],
285                    rows: vec![],
286                    command_tag: format!("DELETE {count}"),
287                })
288            } else {
289                Err("DELETE requires WHERE id = '...' clause".into())
290            }
291        }
292
293        ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
294    }
295}
296
297fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
298    match s.to_lowercase().as_str() {
299        "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
300        "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
301        "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
302        "working" => Some(mnemo_core::model::memory::MemoryType::Working),
303        _ => None,
304    }
305}
306
307async fn send_parameter_status(
308    stream: &mut TcpStream,
309    name: &str,
310    value: &str,
311) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
312    let mut buf = Vec::new();
313    buf.push(b'S'); // ParameterStatus type
314
315    let name_bytes = name.as_bytes();
316    let value_bytes = value.as_bytes();
317    let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
318    buf.extend_from_slice(&(len as i32).to_be_bytes());
319    buf.extend_from_slice(name_bytes);
320    buf.push(0);
321    buf.extend_from_slice(value_bytes);
322    buf.push(0);
323
324    stream.write_all(&buf).await?;
325    Ok(())
326}
327
328async fn send_ready_for_query(
329    stream: &mut TcpStream,
330) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
331    // ReadyForQuery: type 'Z', length 5, transaction status 'I' (idle)
332    stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
333    Ok(())
334}
335
336async fn send_error(
337    stream: &mut TcpStream,
338    message: &str,
339) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
340    let mut buf = Vec::new();
341    buf.push(b'E'); // ErrorResponse type
342
343    let mut fields = Vec::new();
344    // Severity
345    fields.push(b'S');
346    fields.extend_from_slice(b"ERROR\0");
347    // SQLSTATE (42000 = syntax error)
348    fields.push(b'C');
349    fields.extend_from_slice(b"42000\0");
350    // Message
351    fields.push(b'M');
352    fields.extend_from_slice(message.as_bytes());
353    fields.push(0);
354    // Terminator
355    fields.push(0);
356
357    let len = 4 + fields.len();
358    buf.extend_from_slice(&(len as i32).to_be_bytes());
359    buf.extend_from_slice(&fields);
360
361    stream.write_all(&buf).await?;
362    Ok(())
363}
364
365async fn send_query_response(
366    stream: &mut TcpStream,
367    response: &QueryResponse,
368) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
369    if !response.columns.is_empty() {
370        // RowDescription
371        let mut desc_buf = Vec::new();
372        desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
373
374        for col in &response.columns {
375            desc_buf.extend_from_slice(col.as_bytes());
376            desc_buf.push(0); // null terminator
377            desc_buf.extend_from_slice(&0i32.to_be_bytes()); // table OID
378            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // column attr number
379            desc_buf.extend_from_slice(&25i32.to_be_bytes()); // type OID (text = 25)
380            desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); // type size (-1 = variable)
381            desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
382            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // format code (text = 0)
383        }
384
385        let mut msg = Vec::new();
386        msg.push(b'T'); // RowDescription type
387        let len = 4 + desc_buf.len();
388        msg.extend_from_slice(&(len as i32).to_be_bytes());
389        msg.extend_from_slice(&desc_buf);
390        stream.write_all(&msg).await?;
391
392        // DataRow for each row
393        for row in &response.rows {
394            let mut row_buf = Vec::new();
395            row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
396
397            for val in row {
398                let bytes = val.as_bytes();
399                row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
400                row_buf.extend_from_slice(bytes);
401            }
402
403            let mut msg = Vec::new();
404            msg.push(b'D'); // DataRow type
405            let len = 4 + row_buf.len();
406            msg.extend_from_slice(&(len as i32).to_be_bytes());
407            msg.extend_from_slice(&row_buf);
408            stream.write_all(&msg).await?;
409        }
410    }
411
412    // CommandComplete
413    let tag = response.command_tag.as_bytes();
414    let mut msg = Vec::new();
415    msg.push(b'C'); // CommandComplete type
416    let len = 4 + tag.len() + 1;
417    msg.extend_from_slice(&(len as i32).to_be_bytes());
418    msg.extend_from_slice(tag);
419    msg.push(0);
420    stream.write_all(&msg).await?;
421
422    Ok(())
423}