Skip to main content

mnemo_pgwire/
server.rs

//! PostgreSQL wire protocol connection handler.
//!
//! Implements the subset of the PostgreSQL wire protocol needed for
//! simple query execution. Handles startup, authentication (trust mode,
//! or cleartext password when a password is configured), and the simple
//! query flow.
//!
//! Reference: <https://www.postgresql.org/docs/current/protocol.html>
8
9use std::sync::Arc;
10
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use mnemo_core::query::MnemoEngine;
15
16use crate::PgWireConfig;
17use crate::parser::{self, ParsedStatement};
18
19/// Handle a single PostgreSQL wire protocol connection.
20pub async fn handle_connection(
21    mut stream: TcpStream,
22    engine: Arc<MnemoEngine>,
23    config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25    // Phase 1: Startup message
26    let startup_len_raw = stream.read_i32().await?;
27    let startup_len = usize::try_from(startup_len_raw)
28        .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29    if !(8..=10240).contains(&startup_len) {
30        return Err("invalid startup message length".into());
31    }
32
33    let mut startup_buf = vec![0u8; startup_len - 4];
34    stream.read_exact(&mut startup_buf).await?;
35
36    let protocol_version = i32::from_be_bytes([
37        startup_buf[0],
38        startup_buf[1],
39        startup_buf[2],
40        startup_buf[3],
41    ]);
42
43    // SSL request (80877103) — respond with 'N' (no SSL)
44    if protocol_version == 80877103 {
45        stream.write_all(b"N").await?;
46        // Client will retry with normal startup
47        let startup_len_raw = stream.read_i32().await?;
48        let startup_len = usize::try_from(startup_len_raw)
49            .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50        if !(8..=10240).contains(&startup_len) {
51            return Err("invalid startup message length after SSL".into());
52        }
53        startup_buf = vec![0u8; startup_len - 4];
54        stream.read_exact(&mut startup_buf).await?;
55    }
56
57    // Phase 2: Authentication
58    if let Some(ref expected_password) = config.password {
59        // Send AuthenticationCleartextPassword (type 3)
60        stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61        stream.flush().await?;
62
63        // Read password message (type 'p')
64        let pw_type = stream.read_u8().await?;
65        if pw_type != b'p' {
66            send_error(&mut stream, "expected password message").await?;
67            return Err("expected password message".into());
68        }
69        let pw_len_raw = stream.read_i32().await?;
70        let pw_len = usize::try_from(pw_len_raw)
71            .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72        if !(5..=10240).contains(&pw_len) {
73            return Err("invalid password message length".into());
74        }
75        let mut pw_buf = vec![0u8; pw_len - 4];
76        stream.read_exact(&mut pw_buf).await?;
77        let client_password = String::from_utf8_lossy(&pw_buf)
78            .trim_end_matches('\0')
79            .to_string();
80
81        if client_password != *expected_password {
82            send_error(&mut stream, "password authentication failed").await?;
83            return Err("authentication failed".into());
84        }
85    }
86
87    // Send AuthenticationOk
88    stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90    // Send ParameterStatus messages
91    send_parameter_status(&mut stream, "server_version", "16.0").await?;
92    send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93    send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94    send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96    // Send ReadyForQuery
97    send_ready_for_query(&mut stream).await?;
98
99    // Phase 3: Query loop
100    while let Ok(msg_type) = stream.read_u8().await {
101        let msg_len_raw = stream.read_i32().await?;
102        let msg_len = usize::try_from(msg_len_raw)
103            .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104        if !(4..=1_048_576).contains(&msg_len) {
105            break;
106        }
107
108        let mut msg_buf = vec![0u8; msg_len - 4];
109        if !msg_buf.is_empty() {
110            stream.read_exact(&mut msg_buf).await?;
111        }
112
113        match msg_type {
114            b'Q' => {
115                // Simple Query
116                let sql = String::from_utf8_lossy(&msg_buf)
117                    .trim_end_matches('\0')
118                    .to_string();
119
120                tracing::debug!("pgwire query: {sql}");
121
122                match handle_query(&sql, &engine, config).await {
123                    Ok(response) => {
124                        send_query_response(&mut stream, &response).await?;
125                    }
126                    Err(e) => {
127                        send_error(&mut stream, &e.to_string()).await?;
128                    }
129                }
130
131                send_ready_for_query(&mut stream).await?;
132            }
133            b'X' => {
134                // Terminate
135                tracing::debug!("pgwire client terminated");
136                break;
137            }
138            _ => {
139                // Unsupported message type — send error and continue
140                send_error(
141                    &mut stream,
142                    &format!("unsupported message type: {}", msg_type as char),
143                )
144                .await?;
145                send_ready_for_query(&mut stream).await?;
146            }
147        }
148    }
149
150    Ok(())
151}
152
/// Result of one executed statement, in the shape the wire encoder needs.
struct QueryResponse {
    /// Column names for the RowDescription message; empty when the
    /// statement produces no result set.
    columns: Vec<String>,
    /// Result rows; each inner vector holds one text value per column.
    rows: Vec<Vec<String>>,
    /// Tag sent in CommandComplete, e.g. `SELECT 3` or `INSERT 0 1`.
    command_tag: String,
}
159
160async fn handle_query(
161    sql: &str,
162    engine: &MnemoEngine,
163    config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165    let stmt = parser::parse_sql(sql);
166
167    match stmt {
168        ParsedStatement::Select(q) => {
169            let agent_id = q
170                .agent_id
171                .unwrap_or_else(|| config.default_agent_id.clone());
172
173            let request = mnemo_core::query::recall::RecallRequest {
174                agent_id: Some(agent_id),
175                query: q.query_text.unwrap_or_default(),
176                limit: Some(q.limit),
177                memory_type: None,
178                memory_types: None,
179                scope: None,
180                strategy: Some("exact".to_string()),
181                min_importance: None,
182                tags: None,
183                org_id: None,
184                temporal_range: None,
185                recency_half_life_hours: None,
186                hybrid_weights: None,
187                rrf_k: None,
188                as_of: None,
189                explain: None,
190                with_provenance: None,
191            };
192
193            let response = engine.recall(request).await?;
194
195            let columns = vec![
196                "id".to_string(),
197                "agent_id".to_string(),
198                "content".to_string(),
199                "memory_type".to_string(),
200                "importance".to_string(),
201                "created_at".to_string(),
202            ];
203
204            let rows: Vec<Vec<String>> = response
205                .memories
206                .iter()
207                .skip(q.offset)
208                .map(|m| {
209                    vec![
210                        m.id.to_string(),
211                        m.agent_id.clone(),
212                        m.content.clone(),
213                        m.memory_type.to_string(),
214                        m.importance.to_string(),
215                        m.created_at.clone(),
216                    ]
217                })
218                .collect();
219
220            let count = rows.len();
221            Ok(QueryResponse {
222                columns,
223                rows,
224                command_tag: format!("SELECT {count}"),
225            })
226        }
227
228        ParsedStatement::Insert(q) => {
229            let agent_id = q
230                .agent_id
231                .unwrap_or_else(|| config.default_agent_id.clone());
232
233            let request = mnemo_core::query::remember::RememberRequest {
234                content: q.content,
235                agent_id: Some(agent_id),
236                memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
237                scope: None,
238                importance: q.importance,
239                tags: if q.tags.is_empty() {
240                    None
241                } else {
242                    Some(q.tags)
243                },
244                metadata: None,
245                source_type: None,
246                source_id: None,
247                org_id: None,
248                thread_id: None,
249                ttl_seconds: None,
250                related_to: None,
251                decay_rate: None,
252                created_by: None,
253            };
254
255            let response = engine.remember(request).await?;
256
257            Ok(QueryResponse {
258                columns: vec!["id".to_string(), "content_hash".to_string()],
259                rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
260                command_tag: "INSERT 0 1".to_string(),
261            })
262        }
263
264        ParsedStatement::Delete(q) => {
265            if let Some(memory_id_str) = q.memory_id {
266                let memory_id: uuid::Uuid = memory_id_str
267                    .parse()
268                    .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
269
270                let agent_id = q
271                    .agent_id
272                    .unwrap_or_else(|| config.default_agent_id.clone());
273
274                let request = mnemo_core::query::forget::ForgetRequest {
275                    memory_ids: vec![memory_id],
276                    agent_id: Some(agent_id),
277                    strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
278                    criteria: None,
279                };
280
281                let response = engine.forget(request).await?;
282                let count = response.forgotten.len();
283
284                Ok(QueryResponse {
285                    columns: vec![],
286                    rows: vec![],
287                    command_tag: format!("DELETE {count}"),
288                })
289            } else {
290                Err("DELETE requires WHERE id = '...' clause".into())
291            }
292        }
293
294        ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
295    }
296}
297
298fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
299    match s.to_lowercase().as_str() {
300        "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
301        "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
302        "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
303        "working" => Some(mnemo_core::model::memory::MemoryType::Working),
304        _ => None,
305    }
306}
307
308async fn send_parameter_status(
309    stream: &mut TcpStream,
310    name: &str,
311    value: &str,
312) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
313    let mut buf = Vec::new();
314    buf.push(b'S'); // ParameterStatus type
315
316    let name_bytes = name.as_bytes();
317    let value_bytes = value.as_bytes();
318    let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
319    buf.extend_from_slice(&(len as i32).to_be_bytes());
320    buf.extend_from_slice(name_bytes);
321    buf.push(0);
322    buf.extend_from_slice(value_bytes);
323    buf.push(0);
324
325    stream.write_all(&buf).await?;
326    Ok(())
327}
328
329async fn send_ready_for_query(
330    stream: &mut TcpStream,
331) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
332    // ReadyForQuery: type 'Z', length 5, transaction status 'I' (idle)
333    stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
334    Ok(())
335}
336
337async fn send_error(
338    stream: &mut TcpStream,
339    message: &str,
340) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
341    let mut buf = Vec::new();
342    buf.push(b'E'); // ErrorResponse type
343
344    let mut fields = Vec::new();
345    // Severity
346    fields.push(b'S');
347    fields.extend_from_slice(b"ERROR\0");
348    // SQLSTATE (42000 = syntax error)
349    fields.push(b'C');
350    fields.extend_from_slice(b"42000\0");
351    // Message
352    fields.push(b'M');
353    fields.extend_from_slice(message.as_bytes());
354    fields.push(0);
355    // Terminator
356    fields.push(0);
357
358    let len = 4 + fields.len();
359    buf.extend_from_slice(&(len as i32).to_be_bytes());
360    buf.extend_from_slice(&fields);
361
362    stream.write_all(&buf).await?;
363    Ok(())
364}
365
366async fn send_query_response(
367    stream: &mut TcpStream,
368    response: &QueryResponse,
369) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
370    if !response.columns.is_empty() {
371        // RowDescription
372        let mut desc_buf = Vec::new();
373        desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
374
375        for col in &response.columns {
376            desc_buf.extend_from_slice(col.as_bytes());
377            desc_buf.push(0); // null terminator
378            desc_buf.extend_from_slice(&0i32.to_be_bytes()); // table OID
379            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // column attr number
380            desc_buf.extend_from_slice(&25i32.to_be_bytes()); // type OID (text = 25)
381            desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); // type size (-1 = variable)
382            desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); // type modifier
383            desc_buf.extend_from_slice(&0i16.to_be_bytes()); // format code (text = 0)
384        }
385
386        let mut msg = Vec::new();
387        msg.push(b'T'); // RowDescription type
388        let len = 4 + desc_buf.len();
389        msg.extend_from_slice(&(len as i32).to_be_bytes());
390        msg.extend_from_slice(&desc_buf);
391        stream.write_all(&msg).await?;
392
393        // DataRow for each row
394        for row in &response.rows {
395            let mut row_buf = Vec::new();
396            row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
397
398            for val in row {
399                let bytes = val.as_bytes();
400                row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
401                row_buf.extend_from_slice(bytes);
402            }
403
404            let mut msg = Vec::new();
405            msg.push(b'D'); // DataRow type
406            let len = 4 + row_buf.len();
407            msg.extend_from_slice(&(len as i32).to_be_bytes());
408            msg.extend_from_slice(&row_buf);
409            stream.write_all(&msg).await?;
410        }
411    }
412
413    // CommandComplete
414    let tag = response.command_tag.as_bytes();
415    let mut msg = Vec::new();
416    msg.push(b'C'); // CommandComplete type
417    let len = 4 + tag.len() + 1;
418    msg.extend_from_slice(&(len as i32).to_be_bytes());
419    msg.extend_from_slice(tag);
420    msg.push(0);
421    stream.write_all(&msg).await?;
422
423    Ok(())
424}