1use std::sync::Arc;
10
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use mnemo_core::query::MnemoEngine;
15
16use crate::PgWireConfig;
17use crate::parser::{self, ParsedStatement};
18
19pub async fn handle_connection(
21 mut stream: TcpStream,
22 engine: Arc<MnemoEngine>,
23 config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25 let startup_len_raw = stream.read_i32().await?;
27 let startup_len = usize::try_from(startup_len_raw)
28 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29 if !(8..=10240).contains(&startup_len) {
30 return Err("invalid startup message length".into());
31 }
32
33 let mut startup_buf = vec![0u8; startup_len - 4];
34 stream.read_exact(&mut startup_buf).await?;
35
36 let protocol_version = i32::from_be_bytes([
37 startup_buf[0],
38 startup_buf[1],
39 startup_buf[2],
40 startup_buf[3],
41 ]);
42
43 if protocol_version == 80877103 {
45 stream.write_all(b"N").await?;
46 let startup_len_raw = stream.read_i32().await?;
48 let startup_len = usize::try_from(startup_len_raw)
49 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50 if !(8..=10240).contains(&startup_len) {
51 return Err("invalid startup message length after SSL".into());
52 }
53 startup_buf = vec![0u8; startup_len - 4];
54 stream.read_exact(&mut startup_buf).await?;
55 }
56
57 if let Some(ref expected_password) = config.password {
59 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61 stream.flush().await?;
62
63 let pw_type = stream.read_u8().await?;
65 if pw_type != b'p' {
66 send_error(&mut stream, "expected password message").await?;
67 return Err("expected password message".into());
68 }
69 let pw_len_raw = stream.read_i32().await?;
70 let pw_len = usize::try_from(pw_len_raw)
71 .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72 if !(5..=10240).contains(&pw_len) {
73 return Err("invalid password message length".into());
74 }
75 let mut pw_buf = vec![0u8; pw_len - 4];
76 stream.read_exact(&mut pw_buf).await?;
77 let client_password = String::from_utf8_lossy(&pw_buf)
78 .trim_end_matches('\0')
79 .to_string();
80
81 if client_password != *expected_password {
82 send_error(&mut stream, "password authentication failed").await?;
83 return Err("authentication failed".into());
84 }
85 }
86
87 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90 send_parameter_status(&mut stream, "server_version", "16.0").await?;
92 send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93 send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94 send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96 send_ready_for_query(&mut stream).await?;
98
99 while let Ok(msg_type) = stream.read_u8().await {
101 let msg_len_raw = stream.read_i32().await?;
102 let msg_len = usize::try_from(msg_len_raw)
103 .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104 if !(4..=1_048_576).contains(&msg_len) {
105 break;
106 }
107
108 let mut msg_buf = vec![0u8; msg_len - 4];
109 if !msg_buf.is_empty() {
110 stream.read_exact(&mut msg_buf).await?;
111 }
112
113 match msg_type {
114 b'Q' => {
115 let sql = String::from_utf8_lossy(&msg_buf)
117 .trim_end_matches('\0')
118 .to_string();
119
120 tracing::debug!("pgwire query: {sql}");
121
122 match handle_query(&sql, &engine, config).await {
123 Ok(response) => {
124 send_query_response(&mut stream, &response).await?;
125 }
126 Err(e) => {
127 send_error(&mut stream, &e.to_string()).await?;
128 }
129 }
130
131 send_ready_for_query(&mut stream).await?;
132 }
133 b'X' => {
134 tracing::debug!("pgwire client terminated");
136 break;
137 }
138 _ => {
139 send_error(
141 &mut stream,
142 &format!("unsupported message type: {}", msg_type as char),
143 )
144 .await?;
145 send_ready_for_query(&mut stream).await?;
146 }
147 }
148 }
149
150 Ok(())
151}
152
/// Result of executing one SQL statement, ready to be encoded as wire messages.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueryResponse {
    /// Column names for the `RowDescription`; empty when no rows are returned.
    columns: Vec<String>,
    /// Row values in text form, one inner `Vec` per row, ordered like `columns`.
    rows: Vec<Vec<String>>,
    /// Tag sent in `CommandComplete`, e.g. `SELECT 3` or `INSERT 0 1`.
    command_tag: String,
}
159
160async fn handle_query(
161 sql: &str,
162 engine: &MnemoEngine,
163 config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165 let stmt = parser::parse_sql(sql);
166
167 match stmt {
168 ParsedStatement::Select(q) => {
169 let agent_id = q
170 .agent_id
171 .unwrap_or_else(|| config.default_agent_id.clone());
172
173 let request = mnemo_core::query::recall::RecallRequest {
174 agent_id: Some(agent_id),
175 query: q.query_text.unwrap_or_default(),
176 limit: Some(q.limit),
177 memory_type: None,
178 memory_types: None,
179 scope: None,
180 strategy: Some("exact".to_string()),
181 min_importance: None,
182 tags: None,
183 org_id: None,
184 temporal_range: None,
185 recency_half_life_hours: None,
186 hybrid_weights: None,
187 rrf_k: None,
188 as_of: None,
189 explain: None,
190 with_provenance: None,
191 mode: None,
192 };
193
194 let response = engine.recall(request).await?;
195
196 let columns = vec![
197 "id".to_string(),
198 "agent_id".to_string(),
199 "content".to_string(),
200 "memory_type".to_string(),
201 "importance".to_string(),
202 "created_at".to_string(),
203 ];
204
205 let rows: Vec<Vec<String>> = response
206 .memories
207 .iter()
208 .skip(q.offset)
209 .map(|m| {
210 vec![
211 m.id.to_string(),
212 m.agent_id.clone(),
213 m.content.clone(),
214 m.memory_type.to_string(),
215 m.importance.to_string(),
216 m.created_at.clone(),
217 ]
218 })
219 .collect();
220
221 let count = rows.len();
222 Ok(QueryResponse {
223 columns,
224 rows,
225 command_tag: format!("SELECT {count}"),
226 })
227 }
228
229 ParsedStatement::Insert(q) => {
230 let agent_id = q
231 .agent_id
232 .unwrap_or_else(|| config.default_agent_id.clone());
233
234 let request = mnemo_core::query::remember::RememberRequest {
235 content: q.content,
236 agent_id: Some(agent_id),
237 memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
238 scope: None,
239 importance: q.importance,
240 tags: if q.tags.is_empty() {
241 None
242 } else {
243 Some(q.tags)
244 },
245 metadata: None,
246 source_type: None,
247 source_id: None,
248 org_id: None,
249 thread_id: None,
250 ttl_seconds: None,
251 related_to: None,
252 decay_rate: None,
253 created_by: None,
254 };
255
256 let response = engine.remember(request).await?;
257
258 Ok(QueryResponse {
259 columns: vec!["id".to_string(), "content_hash".to_string()],
260 rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
261 command_tag: "INSERT 0 1".to_string(),
262 })
263 }
264
265 ParsedStatement::Delete(q) => {
266 if let Some(memory_id_str) = q.memory_id {
267 let memory_id: uuid::Uuid = memory_id_str
268 .parse()
269 .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
270
271 let agent_id = q
272 .agent_id
273 .unwrap_or_else(|| config.default_agent_id.clone());
274
275 let request = mnemo_core::query::forget::ForgetRequest {
276 memory_ids: vec![memory_id],
277 agent_id: Some(agent_id),
278 strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
279 criteria: None,
280 };
281
282 let response = engine.forget(request).await?;
283 let count = response.forgotten.len();
284
285 Ok(QueryResponse {
286 columns: vec![],
287 rows: vec![],
288 command_tag: format!("DELETE {count}"),
289 })
290 } else {
291 Err("DELETE requires WHERE id = '...' clause".into())
292 }
293 }
294
295 ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
296 }
297}
298
299fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
300 match s.to_lowercase().as_str() {
301 "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
302 "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
303 "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
304 "working" => Some(mnemo_core::model::memory::MemoryType::Working),
305 _ => None,
306 }
307}
308
309async fn send_parameter_status(
310 stream: &mut TcpStream,
311 name: &str,
312 value: &str,
313) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
314 let mut buf = Vec::new();
315 buf.push(b'S'); let name_bytes = name.as_bytes();
318 let value_bytes = value.as_bytes();
319 let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
320 buf.extend_from_slice(&(len as i32).to_be_bytes());
321 buf.extend_from_slice(name_bytes);
322 buf.push(0);
323 buf.extend_from_slice(value_bytes);
324 buf.push(0);
325
326 stream.write_all(&buf).await?;
327 Ok(())
328}
329
330async fn send_ready_for_query(
331 stream: &mut TcpStream,
332) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
333 stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
335 Ok(())
336}
337
338async fn send_error(
339 stream: &mut TcpStream,
340 message: &str,
341) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
342 let mut buf = Vec::new();
343 buf.push(b'E'); let mut fields = Vec::new();
346 fields.push(b'S');
348 fields.extend_from_slice(b"ERROR\0");
349 fields.push(b'C');
351 fields.extend_from_slice(b"42000\0");
352 fields.push(b'M');
354 fields.extend_from_slice(message.as_bytes());
355 fields.push(0);
356 fields.push(0);
358
359 let len = 4 + fields.len();
360 buf.extend_from_slice(&(len as i32).to_be_bytes());
361 buf.extend_from_slice(&fields);
362
363 stream.write_all(&buf).await?;
364 Ok(())
365}
366
367async fn send_query_response(
368 stream: &mut TcpStream,
369 response: &QueryResponse,
370) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
371 if !response.columns.is_empty() {
372 let mut desc_buf = Vec::new();
374 desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
375
376 for col in &response.columns {
377 desc_buf.extend_from_slice(col.as_bytes());
378 desc_buf.push(0); desc_buf.extend_from_slice(&0i32.to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); desc_buf.extend_from_slice(&25i32.to_be_bytes()); desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); }
386
387 let mut msg = Vec::new();
388 msg.push(b'T'); let len = 4 + desc_buf.len();
390 msg.extend_from_slice(&(len as i32).to_be_bytes());
391 msg.extend_from_slice(&desc_buf);
392 stream.write_all(&msg).await?;
393
394 for row in &response.rows {
396 let mut row_buf = Vec::new();
397 row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
398
399 for val in row {
400 let bytes = val.as_bytes();
401 row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
402 row_buf.extend_from_slice(bytes);
403 }
404
405 let mut msg = Vec::new();
406 msg.push(b'D'); let len = 4 + row_buf.len();
408 msg.extend_from_slice(&(len as i32).to_be_bytes());
409 msg.extend_from_slice(&row_buf);
410 stream.write_all(&msg).await?;
411 }
412 }
413
414 let tag = response.command_tag.as_bytes();
416 let mut msg = Vec::new();
417 msg.push(b'C'); let len = 4 + tag.len() + 1;
419 msg.extend_from_slice(&(len as i32).to_be_bytes());
420 msg.extend_from_slice(tag);
421 msg.push(0);
422 stream.write_all(&msg).await?;
423
424 Ok(())
425}