1use std::sync::Arc;
10
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use mnemo_core::query::MnemoEngine;
15
16use crate::PgWireConfig;
17use crate::parser::{self, ParsedStatement};
18
19pub async fn handle_connection(
21 mut stream: TcpStream,
22 engine: Arc<MnemoEngine>,
23 config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25 let startup_len_raw = stream.read_i32().await?;
27 let startup_len = usize::try_from(startup_len_raw)
28 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29 if !(8..=10240).contains(&startup_len) {
30 return Err("invalid startup message length".into());
31 }
32
33 let mut startup_buf = vec![0u8; startup_len - 4];
34 stream.read_exact(&mut startup_buf).await?;
35
36 let protocol_version = i32::from_be_bytes([
37 startup_buf[0],
38 startup_buf[1],
39 startup_buf[2],
40 startup_buf[3],
41 ]);
42
43 if protocol_version == 80877103 {
45 stream.write_all(b"N").await?;
46 let startup_len_raw = stream.read_i32().await?;
48 let startup_len = usize::try_from(startup_len_raw)
49 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50 if !(8..=10240).contains(&startup_len) {
51 return Err("invalid startup message length after SSL".into());
52 }
53 startup_buf = vec![0u8; startup_len - 4];
54 stream.read_exact(&mut startup_buf).await?;
55 }
56
57 if let Some(ref expected_password) = config.password {
59 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61 stream.flush().await?;
62
63 let pw_type = stream.read_u8().await?;
65 if pw_type != b'p' {
66 send_error(&mut stream, "expected password message").await?;
67 return Err("expected password message".into());
68 }
69 let pw_len_raw = stream.read_i32().await?;
70 let pw_len = usize::try_from(pw_len_raw)
71 .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72 if !(5..=10240).contains(&pw_len) {
73 return Err("invalid password message length".into());
74 }
75 let mut pw_buf = vec![0u8; pw_len - 4];
76 stream.read_exact(&mut pw_buf).await?;
77 let client_password = String::from_utf8_lossy(&pw_buf)
78 .trim_end_matches('\0')
79 .to_string();
80
81 if client_password != *expected_password {
82 send_error(&mut stream, "password authentication failed").await?;
83 return Err("authentication failed".into());
84 }
85 }
86
87 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90 send_parameter_status(&mut stream, "server_version", "16.0").await?;
92 send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93 send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94 send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96 send_ready_for_query(&mut stream).await?;
98
99 while let Ok(msg_type) = stream.read_u8().await {
101 let msg_len_raw = stream.read_i32().await?;
102 let msg_len = usize::try_from(msg_len_raw)
103 .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104 if !(4..=1_048_576).contains(&msg_len) {
105 break;
106 }
107
108 let mut msg_buf = vec![0u8; msg_len - 4];
109 if !msg_buf.is_empty() {
110 stream.read_exact(&mut msg_buf).await?;
111 }
112
113 match msg_type {
114 b'Q' => {
115 let sql = String::from_utf8_lossy(&msg_buf)
117 .trim_end_matches('\0')
118 .to_string();
119
120 tracing::debug!("pgwire query: {sql}");
121
122 match handle_query(&sql, &engine, config).await {
123 Ok(response) => {
124 send_query_response(&mut stream, &response).await?;
125 }
126 Err(e) => {
127 send_error(&mut stream, &e.to_string()).await?;
128 }
129 }
130
131 send_ready_for_query(&mut stream).await?;
132 }
133 b'X' => {
134 tracing::debug!("pgwire client terminated");
136 break;
137 }
138 _ => {
139 send_error(
141 &mut stream,
142 &format!("unsupported message type: {}", msg_type as char),
143 )
144 .await?;
145 send_ready_for_query(&mut stream).await?;
146 }
147 }
148 }
149
150 Ok(())
151}
152
/// Text-format result of one SQL statement, ready to be encoded on the wire.
///
/// `columns` and each inner `rows` vector are positionally aligned. When
/// `columns` is empty, only `command_tag` is sent to the client.
#[derive(Debug, Clone, PartialEq, Eq)]
struct QueryResponse {
    /// Column names, in output order.
    columns: Vec<String>,
    /// Row values rendered as text, one inner vector per row.
    rows: Vec<Vec<String>>,
    /// CommandComplete tag, e.g. `SELECT 3`, `INSERT 0 1`, `DELETE 1`.
    command_tag: String,
}
159
160async fn handle_query(
161 sql: &str,
162 engine: &MnemoEngine,
163 config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165 let stmt = parser::parse_sql(sql);
166
167 match stmt {
168 ParsedStatement::Select(q) => {
169 let agent_id = q
170 .agent_id
171 .unwrap_or_else(|| config.default_agent_id.clone());
172
173 let request = mnemo_core::query::recall::RecallRequest {
174 agent_id: Some(agent_id),
175 query: q.query_text.unwrap_or_default(),
176 limit: Some(q.limit),
177 memory_type: None,
178 memory_types: None,
179 scope: None,
180 strategy: Some("exact".to_string()),
181 min_importance: None,
182 tags: None,
183 org_id: None,
184 temporal_range: None,
185 recency_half_life_hours: None,
186 hybrid_weights: None,
187 rrf_k: None,
188 as_of: None,
189 explain: None,
190 with_provenance: None,
191 };
192
193 let response = engine.recall(request).await?;
194
195 let columns = vec![
196 "id".to_string(),
197 "agent_id".to_string(),
198 "content".to_string(),
199 "memory_type".to_string(),
200 "importance".to_string(),
201 "created_at".to_string(),
202 ];
203
204 let rows: Vec<Vec<String>> = response
205 .memories
206 .iter()
207 .skip(q.offset)
208 .map(|m| {
209 vec![
210 m.id.to_string(),
211 m.agent_id.clone(),
212 m.content.clone(),
213 m.memory_type.to_string(),
214 m.importance.to_string(),
215 m.created_at.clone(),
216 ]
217 })
218 .collect();
219
220 let count = rows.len();
221 Ok(QueryResponse {
222 columns,
223 rows,
224 command_tag: format!("SELECT {count}"),
225 })
226 }
227
228 ParsedStatement::Insert(q) => {
229 let agent_id = q
230 .agent_id
231 .unwrap_or_else(|| config.default_agent_id.clone());
232
233 let request = mnemo_core::query::remember::RememberRequest {
234 content: q.content,
235 agent_id: Some(agent_id),
236 memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
237 scope: None,
238 importance: q.importance,
239 tags: if q.tags.is_empty() {
240 None
241 } else {
242 Some(q.tags)
243 },
244 metadata: None,
245 source_type: None,
246 source_id: None,
247 org_id: None,
248 thread_id: None,
249 ttl_seconds: None,
250 related_to: None,
251 decay_rate: None,
252 created_by: None,
253 };
254
255 let response = engine.remember(request).await?;
256
257 Ok(QueryResponse {
258 columns: vec!["id".to_string(), "content_hash".to_string()],
259 rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
260 command_tag: "INSERT 0 1".to_string(),
261 })
262 }
263
264 ParsedStatement::Delete(q) => {
265 if let Some(memory_id_str) = q.memory_id {
266 let memory_id: uuid::Uuid = memory_id_str
267 .parse()
268 .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
269
270 let agent_id = q
271 .agent_id
272 .unwrap_or_else(|| config.default_agent_id.clone());
273
274 let request = mnemo_core::query::forget::ForgetRequest {
275 memory_ids: vec![memory_id],
276 agent_id: Some(agent_id),
277 strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
278 criteria: None,
279 };
280
281 let response = engine.forget(request).await?;
282 let count = response.forgotten.len();
283
284 Ok(QueryResponse {
285 columns: vec![],
286 rows: vec![],
287 command_tag: format!("DELETE {count}"),
288 })
289 } else {
290 Err("DELETE requires WHERE id = '...' clause".into())
291 }
292 }
293
294 ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
295 }
296}
297
298fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
299 match s.to_lowercase().as_str() {
300 "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
301 "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
302 "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
303 "working" => Some(mnemo_core::model::memory::MemoryType::Working),
304 _ => None,
305 }
306}
307
308async fn send_parameter_status(
309 stream: &mut TcpStream,
310 name: &str,
311 value: &str,
312) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
313 let mut buf = Vec::new();
314 buf.push(b'S'); let name_bytes = name.as_bytes();
317 let value_bytes = value.as_bytes();
318 let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
319 buf.extend_from_slice(&(len as i32).to_be_bytes());
320 buf.extend_from_slice(name_bytes);
321 buf.push(0);
322 buf.extend_from_slice(value_bytes);
323 buf.push(0);
324
325 stream.write_all(&buf).await?;
326 Ok(())
327}
328
329async fn send_ready_for_query(
330 stream: &mut TcpStream,
331) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
332 stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
334 Ok(())
335}
336
337async fn send_error(
338 stream: &mut TcpStream,
339 message: &str,
340) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
341 let mut buf = Vec::new();
342 buf.push(b'E'); let mut fields = Vec::new();
345 fields.push(b'S');
347 fields.extend_from_slice(b"ERROR\0");
348 fields.push(b'C');
350 fields.extend_from_slice(b"42000\0");
351 fields.push(b'M');
353 fields.extend_from_slice(message.as_bytes());
354 fields.push(0);
355 fields.push(0);
357
358 let len = 4 + fields.len();
359 buf.extend_from_slice(&(len as i32).to_be_bytes());
360 buf.extend_from_slice(&fields);
361
362 stream.write_all(&buf).await?;
363 Ok(())
364}
365
366async fn send_query_response(
367 stream: &mut TcpStream,
368 response: &QueryResponse,
369) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
370 if !response.columns.is_empty() {
371 let mut desc_buf = Vec::new();
373 desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
374
375 for col in &response.columns {
376 desc_buf.extend_from_slice(col.as_bytes());
377 desc_buf.push(0); desc_buf.extend_from_slice(&0i32.to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); desc_buf.extend_from_slice(&25i32.to_be_bytes()); desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); }
385
386 let mut msg = Vec::new();
387 msg.push(b'T'); let len = 4 + desc_buf.len();
389 msg.extend_from_slice(&(len as i32).to_be_bytes());
390 msg.extend_from_slice(&desc_buf);
391 stream.write_all(&msg).await?;
392
393 for row in &response.rows {
395 let mut row_buf = Vec::new();
396 row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
397
398 for val in row {
399 let bytes = val.as_bytes();
400 row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
401 row_buf.extend_from_slice(bytes);
402 }
403
404 let mut msg = Vec::new();
405 msg.push(b'D'); let len = 4 + row_buf.len();
407 msg.extend_from_slice(&(len as i32).to_be_bytes());
408 msg.extend_from_slice(&row_buf);
409 stream.write_all(&msg).await?;
410 }
411 }
412
413 let tag = response.command_tag.as_bytes();
415 let mut msg = Vec::new();
416 msg.push(b'C'); let len = 4 + tag.len() + 1;
418 msg.extend_from_slice(&(len as i32).to_be_bytes());
419 msg.extend_from_slice(tag);
420 msg.push(0);
421 stream.write_all(&msg).await?;
422
423 Ok(())
424}