1use std::sync::Arc;
10
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13
14use mnemo_core::query::MnemoEngine;
15
16use crate::PgWireConfig;
17use crate::parser::{self, ParsedStatement};
18
19pub async fn handle_connection(
21 mut stream: TcpStream,
22 engine: Arc<MnemoEngine>,
23 config: &PgWireConfig,
24) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
25 let startup_len_raw = stream.read_i32().await?;
27 let startup_len = usize::try_from(startup_len_raw)
28 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
29 if !(8..=10240).contains(&startup_len) {
30 return Err("invalid startup message length".into());
31 }
32
33 let mut startup_buf = vec![0u8; startup_len - 4];
34 stream.read_exact(&mut startup_buf).await?;
35
36 let protocol_version = i32::from_be_bytes([
37 startup_buf[0],
38 startup_buf[1],
39 startup_buf[2],
40 startup_buf[3],
41 ]);
42
43 if protocol_version == 80877103 {
45 stream.write_all(b"N").await?;
46 let startup_len_raw = stream.read_i32().await?;
48 let startup_len = usize::try_from(startup_len_raw)
49 .map_err(|_| format!("negative startup message length: {startup_len_raw}"))?;
50 if !(8..=10240).contains(&startup_len) {
51 return Err("invalid startup message length after SSL".into());
52 }
53 startup_buf = vec![0u8; startup_len - 4];
54 stream.read_exact(&mut startup_buf).await?;
55 }
56
57 if let Some(ref expected_password) = config.password {
59 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 3]).await?;
61 stream.flush().await?;
62
63 let pw_type = stream.read_u8().await?;
65 if pw_type != b'p' {
66 send_error(&mut stream, "expected password message").await?;
67 return Err("expected password message".into());
68 }
69 let pw_len_raw = stream.read_i32().await?;
70 let pw_len = usize::try_from(pw_len_raw)
71 .map_err(|_| format!("negative password message length: {pw_len_raw}"))?;
72 if !(5..=10240).contains(&pw_len) {
73 return Err("invalid password message length".into());
74 }
75 let mut pw_buf = vec![0u8; pw_len - 4];
76 stream.read_exact(&mut pw_buf).await?;
77 let client_password = String::from_utf8_lossy(&pw_buf)
78 .trim_end_matches('\0')
79 .to_string();
80
81 if client_password != *expected_password {
82 send_error(&mut stream, "password authentication failed").await?;
83 return Err("authentication failed".into());
84 }
85 }
86
87 stream.write_all(&[b'R', 0, 0, 0, 8, 0, 0, 0, 0]).await?;
89
90 send_parameter_status(&mut stream, "server_version", "16.0").await?;
92 send_parameter_status(&mut stream, "server_encoding", "UTF8").await?;
93 send_parameter_status(&mut stream, "client_encoding", "UTF8").await?;
94 send_parameter_status(&mut stream, "application_name", "mnemo-pgwire").await?;
95
96 send_ready_for_query(&mut stream).await?;
98
99 while let Ok(msg_type) = stream.read_u8().await {
101 let msg_len_raw = stream.read_i32().await?;
102 let msg_len = usize::try_from(msg_len_raw)
103 .map_err(|_| format!("negative message length: {msg_len_raw}"))?;
104 if !(4..=1_048_576).contains(&msg_len) {
105 break;
106 }
107
108 let mut msg_buf = vec![0u8; msg_len - 4];
109 if !msg_buf.is_empty() {
110 stream.read_exact(&mut msg_buf).await?;
111 }
112
113 match msg_type {
114 b'Q' => {
115 let sql = String::from_utf8_lossy(&msg_buf)
117 .trim_end_matches('\0')
118 .to_string();
119
120 tracing::debug!("pgwire query: {sql}");
121
122 match handle_query(&sql, &engine, config).await {
123 Ok(response) => {
124 send_query_response(&mut stream, &response).await?;
125 }
126 Err(e) => {
127 send_error(&mut stream, &e.to_string()).await?;
128 }
129 }
130
131 send_ready_for_query(&mut stream).await?;
132 }
133 b'X' => {
134 tracing::debug!("pgwire client terminated");
136 break;
137 }
138 _ => {
139 send_error(
141 &mut stream,
142 &format!("unsupported message type: {}", msg_type as char),
143 )
144 .await?;
145 send_ready_for_query(&mut stream).await?;
146 }
147 }
148 }
149
150 Ok(())
151}
152
/// Text-rendered outcome of one SQL statement, ready to be framed as
/// RowDescription / DataRow / CommandComplete messages.
struct QueryResponse {
    /// Column names in output order. When this is empty, no RowDescription
    /// or DataRow messages are emitted.
    columns: Vec<String>,
    /// One entry per result row; each value is the column's text form.
    rows: Vec<Vec<String>>,
    /// CommandComplete tag such as `SELECT 3` or `INSERT 0 1`.
    command_tag: String,
}
159
160async fn handle_query(
161 sql: &str,
162 engine: &MnemoEngine,
163 config: &PgWireConfig,
164) -> Result<QueryResponse, Box<dyn std::error::Error + Send + Sync>> {
165 let stmt = parser::parse_sql(sql);
166
167 match stmt {
168 ParsedStatement::Select(q) => {
169 let agent_id = q
170 .agent_id
171 .unwrap_or_else(|| config.default_agent_id.clone());
172
173 let request = mnemo_core::query::recall::RecallRequest {
174 agent_id: Some(agent_id),
175 query: q.query_text.unwrap_or_default(),
176 limit: Some(q.limit),
177 memory_type: None,
178 memory_types: None,
179 scope: None,
180 strategy: Some("exact".to_string()),
181 min_importance: None,
182 tags: None,
183 org_id: None,
184 temporal_range: None,
185 recency_half_life_hours: None,
186 hybrid_weights: None,
187 rrf_k: None,
188 as_of: None,
189 explain: None,
190 };
191
192 let response = engine.recall(request).await?;
193
194 let columns = vec![
195 "id".to_string(),
196 "agent_id".to_string(),
197 "content".to_string(),
198 "memory_type".to_string(),
199 "importance".to_string(),
200 "created_at".to_string(),
201 ];
202
203 let rows: Vec<Vec<String>> = response
204 .memories
205 .iter()
206 .skip(q.offset)
207 .map(|m| {
208 vec![
209 m.id.to_string(),
210 m.agent_id.clone(),
211 m.content.clone(),
212 m.memory_type.to_string(),
213 m.importance.to_string(),
214 m.created_at.clone(),
215 ]
216 })
217 .collect();
218
219 let count = rows.len();
220 Ok(QueryResponse {
221 columns,
222 rows,
223 command_tag: format!("SELECT {count}"),
224 })
225 }
226
227 ParsedStatement::Insert(q) => {
228 let agent_id = q
229 .agent_id
230 .unwrap_or_else(|| config.default_agent_id.clone());
231
232 let request = mnemo_core::query::remember::RememberRequest {
233 content: q.content,
234 agent_id: Some(agent_id),
235 memory_type: q.memory_type.as_deref().and_then(parse_memory_type),
236 scope: None,
237 importance: q.importance,
238 tags: if q.tags.is_empty() {
239 None
240 } else {
241 Some(q.tags)
242 },
243 metadata: None,
244 source_type: None,
245 source_id: None,
246 org_id: None,
247 thread_id: None,
248 ttl_seconds: None,
249 related_to: None,
250 decay_rate: None,
251 created_by: None,
252 };
253
254 let response = engine.remember(request).await?;
255
256 Ok(QueryResponse {
257 columns: vec!["id".to_string(), "content_hash".to_string()],
258 rows: vec![vec![response.id.to_string(), response.content_hash.clone()]],
259 command_tag: "INSERT 0 1".to_string(),
260 })
261 }
262
263 ParsedStatement::Delete(q) => {
264 if let Some(memory_id_str) = q.memory_id {
265 let memory_id: uuid::Uuid = memory_id_str
266 .parse()
267 .map_err(|e| format!("invalid UUID in DELETE WHERE id = '...': {e}"))?;
268
269 let agent_id = q
270 .agent_id
271 .unwrap_or_else(|| config.default_agent_id.clone());
272
273 let request = mnemo_core::query::forget::ForgetRequest {
274 memory_ids: vec![memory_id],
275 agent_id: Some(agent_id),
276 strategy: Some(mnemo_core::query::forget::ForgetStrategy::SoftDelete),
277 criteria: None,
278 };
279
280 let response = engine.forget(request).await?;
281 let count = response.forgotten.len();
282
283 Ok(QueryResponse {
284 columns: vec![],
285 rows: vec![],
286 command_tag: format!("DELETE {count}"),
287 })
288 } else {
289 Err("DELETE requires WHERE id = '...' clause".into())
290 }
291 }
292
293 ParsedStatement::Unsupported(s) => Err(format!("unsupported SQL: {s}").into()),
294 }
295}
296
297fn parse_memory_type(s: &str) -> Option<mnemo_core::model::memory::MemoryType> {
298 match s.to_lowercase().as_str() {
299 "episodic" => Some(mnemo_core::model::memory::MemoryType::Episodic),
300 "semantic" => Some(mnemo_core::model::memory::MemoryType::Semantic),
301 "procedural" => Some(mnemo_core::model::memory::MemoryType::Procedural),
302 "working" => Some(mnemo_core::model::memory::MemoryType::Working),
303 _ => None,
304 }
305}
306
307async fn send_parameter_status(
308 stream: &mut TcpStream,
309 name: &str,
310 value: &str,
311) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
312 let mut buf = Vec::new();
313 buf.push(b'S'); let name_bytes = name.as_bytes();
316 let value_bytes = value.as_bytes();
317 let len = 4 + name_bytes.len() + 1 + value_bytes.len() + 1;
318 buf.extend_from_slice(&(len as i32).to_be_bytes());
319 buf.extend_from_slice(name_bytes);
320 buf.push(0);
321 buf.extend_from_slice(value_bytes);
322 buf.push(0);
323
324 stream.write_all(&buf).await?;
325 Ok(())
326}
327
328async fn send_ready_for_query(
329 stream: &mut TcpStream,
330) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
331 stream.write_all(&[b'Z', 0, 0, 0, 5, b'I']).await?;
333 Ok(())
334}
335
336async fn send_error(
337 stream: &mut TcpStream,
338 message: &str,
339) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
340 let mut buf = Vec::new();
341 buf.push(b'E'); let mut fields = Vec::new();
344 fields.push(b'S');
346 fields.extend_from_slice(b"ERROR\0");
347 fields.push(b'C');
349 fields.extend_from_slice(b"42000\0");
350 fields.push(b'M');
352 fields.extend_from_slice(message.as_bytes());
353 fields.push(0);
354 fields.push(0);
356
357 let len = 4 + fields.len();
358 buf.extend_from_slice(&(len as i32).to_be_bytes());
359 buf.extend_from_slice(&fields);
360
361 stream.write_all(&buf).await?;
362 Ok(())
363}
364
365async fn send_query_response(
366 stream: &mut TcpStream,
367 response: &QueryResponse,
368) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
369 if !response.columns.is_empty() {
370 let mut desc_buf = Vec::new();
372 desc_buf.extend_from_slice(&(response.columns.len() as i16).to_be_bytes());
373
374 for col in &response.columns {
375 desc_buf.extend_from_slice(col.as_bytes());
376 desc_buf.push(0); desc_buf.extend_from_slice(&0i32.to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); desc_buf.extend_from_slice(&25i32.to_be_bytes()); desc_buf.extend_from_slice(&(-1i16).to_be_bytes()); desc_buf.extend_from_slice(&(-1i32).to_be_bytes()); desc_buf.extend_from_slice(&0i16.to_be_bytes()); }
384
385 let mut msg = Vec::new();
386 msg.push(b'T'); let len = 4 + desc_buf.len();
388 msg.extend_from_slice(&(len as i32).to_be_bytes());
389 msg.extend_from_slice(&desc_buf);
390 stream.write_all(&msg).await?;
391
392 for row in &response.rows {
394 let mut row_buf = Vec::new();
395 row_buf.extend_from_slice(&(row.len() as i16).to_be_bytes());
396
397 for val in row {
398 let bytes = val.as_bytes();
399 row_buf.extend_from_slice(&(bytes.len() as i32).to_be_bytes());
400 row_buf.extend_from_slice(bytes);
401 }
402
403 let mut msg = Vec::new();
404 msg.push(b'D'); let len = 4 + row_buf.len();
406 msg.extend_from_slice(&(len as i32).to_be_bytes());
407 msg.extend_from_slice(&row_buf);
408 stream.write_all(&msg).await?;
409 }
410 }
411
412 let tag = response.command_tag.as_bytes();
414 let mut msg = Vec::new();
415 msg.push(b'C'); let len = 4 + tag.len() + 1;
417 msg.extend_from_slice(&(len as i32).to_be_bytes());
418 msg.extend_from_slice(tag);
419 msg.push(0);
420 stream.write_all(&msg).await?;
421
422 Ok(())
423}