Skip to main content

mnemo_pgwire/
server.rs

1//! PostgreSQL wire protocol connection handler.
2//!
//! Implements the subset of the PostgreSQL wire protocol needed for
//! simple query execution. Handles startup, authentication (trust mode, or
//! cleartext password when `PgWireConfig::password` is set), and the simple
//! query flow.
6//!
7//! Reference: <https://www.postgresql.org/docs/current/protocol.html>
8
9use std::sync::Arc;
10
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use mnemo_core::query::MnemoEngine;
15
16use crate::PgWireConfig;
17use crate::parser::{self, ParsedStatement};
18
19/// Handle a single PostgreSQL wire protocol connection.
20pub async fn handle_connection(
21    mut stream: TcpStream,
22    engine: Arc<MnemoEngine>,
23    config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25    // Phase 1: Startup message
26    let startup_len_raw = stream.read_i32().await?;
27    let startup_len = usize::try_from(startup_len_raw)
28        .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29    if !(8..=10240).contains(&startup_len) {
30        return Err("invalid startup message length".into());
31    }
32
33    let mut startup_buf = vec![0u8; startup_len - 4];
34    stream.read_exact(&mut startup_buf).await?;
35
36    let protocol_version = i32::from_be_bytes([
37        startup_buf[0],
38        startup_buf[1],
39        startup_buf[2],
40        startup_buf[3],
41    ]);
42
43    // SSL request (80877103) — respond with 'N' (no SSL)
44    if protocol_version == 80877103 {
45        stream.write_all(b"N").await?;
46        // Client will retry with normal startup
47        let startup_len_raw = stream.read_i32().await?;
48        let startup_len = usize::try_from(startup_len_raw)
49            .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50        if !(8..=10240).contains(&startup_len) {
51            return Err("invalid startup message length after SSL".into());
52        }
53        startup_buf = vec![0u8; startup_len - 4];
54        stream.read_exact(&mut startup_buf).await?;
55    }
56
57    // Phase 2: Authentication
58    if let Some(ref expected_password) = config.password {
59        // Send AuthenticationCleartextPassword (type 3)
60        stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61        stream.flush().await?;
62
63        // Read password message (type 'p')
64        let pw_type = stream.read_u8().await?;
65        if pw_type != b'p' {
66            send_error(&mut stream, "expected password message").await?;
67            return Err("expected password message".into());
68        }
69        let pw_len_raw = stream.read_i32().await?;
70        let pw_len = usize::try_from(pw_len_raw)
71            .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72        if !(5..=10240).contains(&pw_len) {
73            return Err("invalid password message length".into());
74        }
75        let mut pw_buf = vec![0u8; pw_len - 4];
76        stream.read_exact(&mut pw_buf).await?;
77        let client_password = String::from_utf8_lossy(&pw_buf)
78            .trim_end_matches('\0')
79            .to_string();
80
81        if client_password != *expected_password {
82            send_error(&mut stream, "password authentication failed").await?;
83            return Err("authentication failed".into());
84        }
85    }
86
87    // Send AuthenticationOk
88    stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90    // Send ParameterStatus messages
91    send_parameter_status(&mut stream, "server_version", "16.0").await?;
92    send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93    send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94    send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96    // Send ReadyForQuery
97    send_ready_for_query(&mut stream).await?;
98
99    // Phase 3: Query loop
100    while let Ok(msg_type) = stream.read_u8().await {
101        let msg_len_raw = stream.read_i32().await?;
102        let msg_len = usize::try_from(msg_len_raw)
103            .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104        if !(4..=1_048_576).contains(&msg_len) {
105            break;
106        }
107
108        let mut msg_buf = vec![0u8; msg_len - 4];
109        if !msg_buf.is_empty() {
110            stream.read_exact(&mut msg_buf).await?;
111        }
112
113        match msg_type {
114            b'Q' => {
115                // Simple Query
116                let sql = String::from_utf8_lossy(&msg_buf)
117                    .trim_end_matches('\0')
118                    .to_string();
119
120                tracing::debug!("pgwire query: {sql}");
121
122                match handle_query(&sql, &engine, config).await {
123                    Ok(response) => {
124                        send_query_response(&mut stream, &response).await?;
125                    }
126                    Err(e) => {
127                        send_error(&mut stream, &e.to_string()).await?;
128                    }
129                }
130
131                send_ready_for_query(&mut stream).await?;
132            }
133            b'X' => {
134                // Terminate
135                tracing::debug!("pgwire client terminated");
136                break;
137            }
138            _ => {
139                // Unsupported message type — send error and continue
140                send_error(
141                    &mut stream,
142                    &format!("unsupported message type: {}", msg_type as char),
143                )
144                .await?;
145                send_ready_for_query(&mut stream).await?;
146            }
147        }
148    }
149
150    Ok(())
151}
152
/// Result of executing one statement, ready to be encoded on the wire.
///
/// Each entry of `rows` is expected to hold one text value per entry of
/// `columns`. `command_tag` is the `CommandComplete` tag, e.g. `SELECT 3`
/// or `INSERT 0 1`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueryResponse {
    /// Column names for `RowDescription`; empty when the statement returns no
    /// result set (e.g. `DELETE`), in which case no rows are sent.
    columns: Vec<String>,
    /// Row values in text format, one inner `Vec` per `DataRow`.
    rows: Vec<Vec<String>>,
    /// Tag sent in `CommandComplete`.
    command_tag: String,
}
159
160async fn handle_query(
161    sql: &str,
162    engine: &MnemoEngine,
163    config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165    let stmt = parser::parse_sql(sql);
166
167    match stmt {
168        ParsedStatement::Select(q) => {
169            let agent_id = q
170                .agent_id
171                .unwrap_or_else(|| config.default_agent_id.clone());
172
173            let request = mnemo_core::query::recall::RecallRequest {
174                agent_id: Some(agent_id),
175                query: q.query_text.unwrap_or_default(),
176                limit: Some(q.limit),
177                memory_type: None,
178                memory_types: None,
179                scope: None,
180                strategy: Some("exact".to_string()),
181                min_importance: None,
182                tags: None,
183                org_id: None,
184                temporal_range: None,
185                recency_half_life_hours: None,
186                hybrid_weights: None,
187                rrf_k: None,
188                as_of: None,
189                explain: None,
190                with_provenance: None,
191                mode: None,
192            };
193
194            let response = engine.recall(request).await?;
195
196            let columns = vec![
197                "id".to_string(),
198                "agent_id".to_string(),
199                "content".to_string(),
200                "memory_type".to_string(),
201                "importance".to_string(),
202                "created_at".to_string(),
203            ];
204
205            let rows: Vec<Vec<String>> = response
206                .memories
207                .iter()
208                .skip(q.offset)
209                .map(|m| {
210                    vec![
211                        m.id.to_string(),
212                        m.agent_id.clone(),
213                        m.content.clone(),
214                        m.memory_type.to_string(),
215                        m.importance.to_string(),
216                        m.created_at.clone(),
217                    ]
218                })
219                .collect();
220
221            let count = rows.len();
222            Ok(QueryResponse {
223                columns,
224                rows,
225                command_tag: format!("SELECT {count}"),
226            })
227        }
228
229        ParsedStatement::Insert(q) => {
230            let agent_id = q
231                .agent_id
232                .unwrap_or_else(|| config.default_agent_id.clone());
233
234            let request = mnemo_core::query::remember::RememberRequest {
235                content: q.content,
236                agent_id: Some(agent_id),
237                memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
238                scope: None,
239                importance: q.importance,
240                tags: if q.tags.is_empty() {
241                    None
242                } else {
243                    Some(q.tags)
244                },
245                metadata: None,
246                source_type: None,
247                source_id: None,
248                org_id: None,
249                thread_id: None,
250                ttl_seconds: None,
251                related_to: None,
252                decay_rate: None,
253                created_by: None,
254            };
255
256            let response = engine.remember(request).await?;
257
258            Ok(QueryResponse {
259                columns: vec!["id".to_string(), "content_hash".to_string()],
260                rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
261                command_tag: "INSERT 0 1".to_string(),
262            })
263        }
264
265        ParsedStatement::Delete(q) => {
266            if let Some(memory_id_str) = q.memory_id {
267                let memory_id: uuid::Uuid = memory_id_str
268                    .parse()
269                    .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
270
271                let agent_id = q
272                    .agent_id
273                    .unwrap_or_else(|| config.default_agent_id.clone());
274
275                let request = mnemo_core::query::forget::ForgetRequest {
276                    memory_ids: vec![memory_id],
277                    agent_id: Some(agent_id),
278                    strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
279                    criteria: None,
280                };
281
282                let response = engine.forget(request).await?;
283                let count = response.forgotten.len();
284
285                Ok(QueryResponse {
286                    columns: vec![],
287                    rows: vec![],
288                    command_tag: format!("DELETE {count}"),
289                })
290            } else {
291                Err("DELETE requires WHERE id = '...' clause".into())
292            }
293        }
294
295        ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
296    }
297}
298
299fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
300    match s.to_lowercase().as_str() {
301        "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
302        "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
303        "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
304        "working" => Some(mnemo_core::model::memory::MemoryType::Working),
305        _ => None,
306    }
307}
308
309async fn send_parameter_status(
310    stream: &mut TcpStream,
311    name: &str,
312    value: &str,
313) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
314    let mut buf = Vec::new();
315    buf.push(b'S'); // ParameterStatus type
316
317    let name_bytes = name.as_bytes();
318    let value_bytes = value.as_bytes();
319    let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
320    buf.extend_from_slice(&(len as i32).to_be_bytes());
321    buf.extend_from_slice(name_bytes);
322    buf.push(0);
323    buf.extend_from_slice(value_bytes);
324    buf.push(0);
325
326    stream.write_all(&buf).await?;
327    Ok(())
328}
329
330async fn send_ready_for_query(
331    stream: &mut TcpStream,
332) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
333    // ReadyForQuery: type 'Z', length 5, transaction status 'I' (idle)
334    stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
335    Ok(())
336}
337
338async fn send_error(
339    stream: &mut TcpStream,
340    message: &str,
341) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
342    let mut buf = Vec::new();
343    buf.push(b'E'); // ErrorResponse type
344
345    let mut fields = Vec::new();
346    // Severity
347    fields.push(b'S');
348    fields.extend_from_slice(b"ERROR\0");
349    // SQLSTATE (42000 = syntax error)
350    fields.push(b'C');
351    fields.extend_from_slice(b"42000\0");
352    // Message
353    fields.push(b'M');
354    fields.extend_from_slice(message.as_bytes());
355    fields.push(0);
356    // Terminator
357    fields.push(0);
358
359    let len = 4 + fields.len();
360    buf.extend_from_slice(&(len as i32).to_be_bytes());
361    buf.extend_from_slice(&fields);
362
363    stream.write_all(&buf).await?;
364    Ok(())
365}
366
367async fn send_query_response(
368    stream: &mut TcpStream,
369    response: &QueryResponse,
370) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
371    if !response.columns.is_empty() {
372        // RowDescription
373        let mut desc_buf = Vec::new();
374        desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
375
376        for col in &response.columns {
377            desc_buf.extend_from_slice(col.as_bytes());
378            desc_buf.push(0); // null terminator
379            desc_buf.extend_from_slice(&0i32.to_be_bytes()); // table OID
380            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // column attr number
381            desc_buf.extend_from_slice(&25i32.to_be_bytes()); // type OID (text = 25)
382            desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); // type size (-1 = variable)
383            desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
384            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // format code (text = 0)
385        }
386
387        let mut msg = Vec::new();
388        msg.push(b'T'); // RowDescription type
389        let len = 4 + desc_buf.len();
390        msg.extend_from_slice(&(len as i32).to_be_bytes());
391        msg.extend_from_slice(&desc_buf);
392        stream.write_all(&msg).await?;
393
394        // DataRow for each row
395        for row in &response.rows {
396            let mut row_buf = Vec::new();
397            row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
398
399            for val in row {
400                let bytes = val.as_bytes();
401                row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
402                row_buf.extend_from_slice(bytes);
403            }
404
405            let mut msg = Vec::new();
406            msg.push(b'D'); // DataRow type
407            let len = 4 + row_buf.len();
408            msg.extend_from_slice(&(len as i32).to_be_bytes());
409            msg.extend_from_slice(&row_buf);
410            stream.write_all(&msg).await?;
411        }
412    }
413
414    // CommandComplete
415    let tag = response.command_tag.as_bytes();
416    let mut msg = Vec::new();
417    msg.push(b'C'); // CommandComplete type
418    let len = 4 + tag.len() + 1;
419    msg.extend_from_slice(&(len as i32).to_be_bytes());
420    msg.extend_from_slice(tag);
421    msg.push(0);
422    stream.write_all(&msg).await?;
423
424    Ok(())
425}