Skip to main content

mnemo_pgwire/
parser.rs

1//! Minimal SQL parser for pgwire queries.
2//!
3//! Parses a limited SQL subset and maps to Mnemo operations.
4//! This is not a full SQL parser — it handles the common patterns
5//! that clients will use to interact with the memories table.
6
7/// Parsed SQL statement mapped to a Mnemo operation.
8#[derive(Debug, Clone, PartialEq)]
9pub enum ParsedStatement {
10    /// SELECT query on the memories table.
11    Select(SelectQuery),
12    /// INSERT into the memories table.
13    Insert(InsertQuery),
14    /// DELETE from the memories table.
15    Delete(DeleteQuery),
16    /// Unrecognized or unsupported statement.
17    Unsupported(String),
18}
19
/// Filters and paging extracted from a `SELECT` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct SelectQuery {
    /// Value of a `WHERE agent_id = '...'` condition, if one was found.
    pub agent_id: Option<String>,
    /// Search text taken from `WHERE content LIKE '%...%'` (free-text query),
    /// if one was found.
    pub query_text: Option<String>,
    /// Value of the `LIMIT` clause.
    pub limit: usize,
    /// Value of the `OFFSET` clause.
    pub offset: usize,
}
32
/// Column values extracted from an `INSERT` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct InsertQuery {
    /// Value supplied for the `content` column.
    pub content: String,
    /// Value supplied for the `agent_id` column, if present.
    pub agent_id: Option<String>,
    /// Value supplied for the `importance` column, if present.
    pub importance: Option<f32>,
    /// Value supplied for the `memory_type` column, if present.
    pub memory_type: Option<String>,
    /// Tags attached to the memory.
    pub tags: Vec<String>,
}
42
/// Conditions extracted from a `DELETE` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct DeleteQuery {
    /// Value of a `WHERE id = '...'` condition, if one was found.
    pub memory_id: Option<String>,
    /// Value of a `WHERE agent_id = '...'` condition, if one was found.
    pub agent_id: Option<String>,
}
51
52/// Parse a SQL string into a `ParsedStatement`.
53///
54/// Supports:
55/// - `SELECT * FROM memories [WHERE ...] [LIMIT n] [OFFSET n]`
56/// - `INSERT INTO memories (col, ...) VALUES (val, ...)`
57/// - `DELETE FROM memories WHERE id = '...'`
58pub fn parse_sql(sql: &str) -> ParsedStatement {
59    let trimmed = sql.trim().trim_end_matches(';');
60    let upper = trimmed.to_uppercase();
61
62    if upper.starts_with("SELECT") {
63        parse_select(trimmed)
64    } else if upper.starts_with("INSERT") {
65        parse_insert(trimmed)
66    } else if upper.starts_with("DELETE") {
67        parse_delete(trimmed)
68    } else {
69        ParsedStatement::Unsupported(trimmed.to_string())
70    }
71}
72
73fn parse_select(sql: &str) -> ParsedStatement {
74    let upper = sql.to_uppercase();
75    let mut query = SelectQuery {
76        agent_id: None,
77        query_text: None,
78        limit: 50,
79        offset: 0,
80    };
81
82    // Extract LIMIT
83    if let Some(pos) = upper.find("LIMIT") {
84        let after = &sql[pos + 5..].trim();
85        if let Some(num_str) = after.split_whitespace().next()
86            && let Ok(n) = num_str.parse::<usize>()
87        {
88            query.limit = n;
89        }
90    }
91
92    // Extract OFFSET
93    if let Some(pos) = upper.find("OFFSET") {
94        let after = &sql[pos + 6..].trim();
95        if let Some(num_str) = after.split_whitespace().next()
96            && let Ok(n) = num_str.parse::<usize>()
97        {
98            query.offset = n;
99        }
100    }
101
102    // Extract WHERE agent_id = '...'
103    if let Some(agent_id) = extract_string_condition(&upper, sql, "AGENT_ID") {
104        query.agent_id = Some(agent_id);
105    }
106
107    // Extract WHERE content LIKE '%...%'
108    if let Some(pos) = upper.find("CONTENT LIKE") {
109        let after = &sql[pos + 12..].trim();
110        if let Some(value) = extract_quoted_value(after) {
111            // Strip % wildcards
112            let clean = value.trim_matches('%').to_string();
113            if !clean.is_empty() {
114                query.query_text = Some(clean);
115            }
116        }
117    }
118
119    ParsedStatement::Select(query)
120}
121
122fn parse_insert(sql: &str) -> ParsedStatement {
123    // Extract column names and values from INSERT INTO memories (cols) VALUES (vals)
124    let upper = sql.to_uppercase();
125
126    let cols_start = match upper.find('(') {
127        Some(p) => p,
128        None => return ParsedStatement::Unsupported(sql.to_string()),
129    };
130    let cols_end = match upper[cols_start..].find(')') {
131        Some(p) => cols_start + p,
132        None => return ParsedStatement::Unsupported(sql.to_string()),
133    };
134
135    let values_marker = match upper[cols_end..].find("VALUES") {
136        Some(p) => cols_end + p,
137        None => return ParsedStatement::Unsupported(sql.to_string()),
138    };
139
140    let vals_start = match upper[values_marker..].find('(') {
141        Some(p) => values_marker + p,
142        None => return ParsedStatement::Unsupported(sql.to_string()),
143    };
144    let vals_end = match sql[vals_start..].rfind(')') {
145        Some(p) => vals_start + p,
146        None => return ParsedStatement::Unsupported(sql.to_string()),
147    };
148
149    let columns: Vec<String> = sql[cols_start + 1..cols_end]
150        .split(',')
151        .map(|c| c.trim().to_uppercase())
152        .collect();
153
154    let values: Vec<String> = split_sql_values(&sql[vals_start + 1..vals_end]);
155
156    let mut insert = InsertQuery {
157        content: String::new(),
158        agent_id: None,
159        importance: None,
160        memory_type: None,
161        tags: vec![],
162    };
163
164    for (i, col) in columns.iter().enumerate() {
165        if i >= values.len() {
166            break;
167        }
168        let val = unquote(&values[i]);
169        match col.as_str() {
170            "CONTENT" => insert.content = val,
171            "AGENT_ID" => insert.agent_id = Some(val),
172            "IMPORTANCE" => insert.importance = val.parse().ok(),
173            "MEMORY_TYPE" => insert.memory_type = Some(val),
174            _ => {}
175        }
176    }
177
178    if insert.content.is_empty() {
179        return ParsedStatement::Unsupported(sql.to_string());
180    }
181
182    ParsedStatement::Insert(insert)
183}
184
185fn parse_delete(sql: &str) -> ParsedStatement {
186    let upper = sql.to_uppercase();
187    let mut delete = DeleteQuery {
188        memory_id: None,
189        agent_id: None,
190    };
191
192    if let Some(id) = extract_string_condition(&upper, sql, "ID") {
193        delete.memory_id = Some(id);
194    }
195    if let Some(agent_id) = extract_string_condition(&upper, sql, "AGENT_ID") {
196        delete.agent_id = Some(agent_id);
197    }
198
199    ParsedStatement::Delete(delete)
200}
201
/// Extract the string literal from a `column = 'value'` condition.
///
/// * `upper` — upper-cased copy of `original`, expected to be byte-aligned
///   with it (i.e. produced by ASCII-only upper-casing).
/// * `original` — the statement text the value is read from, so the value
///   keeps its original case.
/// * `column` — upper-case column name to look for.
///
/// The column name must appear as a whole identifier: an occurrence that is
/// preceded by a letter, digit or `_` is skipped, so looking for `ID` does
/// not match the tail of `AGENT_ID`. Any amount of whitespace (including
/// none) may separate the column name from `=`.
///
/// Returns `None` when no such condition exists, when the right-hand side is
/// not a single-quoted literal, or when `upper` and `original` turn out not
/// to be aligned. Occurrences inside string literals are not excluded.
fn extract_string_condition(upper: &str, original: &str, column: &str) -> Option<String> {
    if column.is_empty() {
        // An empty needle matches everywhere and would never advance.
        return None;
    }

    let mut cursor = 0;
    while let Some(found) = upper[cursor..].find(column) {
        let start = cursor + found;
        let end = start + column.len();
        cursor = end;

        // Reject matches that are merely the suffix of a longer identifier.
        let inside_identifier = upper[..start]
            .chars()
            .next_back()
            .map_or(false, |c| c.is_alphanumeric() || c == '_');
        if inside_identifier {
            continue;
        }

        let tail = upper[end..].trim_start();
        if !tail.starts_with('=') {
            continue;
        }

        // Byte offset just past the `=`; `get` (rather than indexing) turns a
        // misaligned `upper`/`original` pair into `None` instead of a panic.
        let value_at = upper.len() - tail.len() + 1;
        return original.get(value_at..).and_then(extract_quoted_value);
    }
    None
}

/// Extract a single-quoted string value from the start of `s`.
///
/// Leading whitespace is ignored. Inside the literal a doubled quote (`''`)
/// is decoded as one quote character, following SQL string-literal escaping.
/// Returns `None` if `s` does not start with a quote or the literal is never
/// closed.
fn extract_quoted_value(s: &str) -> Option<String> {
    let body = s.trim().strip_prefix('\'')?;

    let mut value = String::new();
    let mut chars = body.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch != '\'' {
            value.push(ch);
        } else if chars.peek() == Some(&'\'') {
            // Escaped quote: consume the second `'` and keep one.
            chars.next();
            value.push('\'');
        } else {
            // Closing quote.
            return Some(value);
        }
    }
    None
}
222
/// Break the inside of a `VALUES (...)` list into its comma-separated items.
///
/// Commas inside single-quoted strings do not split. Each item is returned
/// trimmed, with its quotes still attached. A trailing item that is empty
/// after trimming is dropped; empty items elsewhere are kept.
fn split_sql_values(s: &str) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut buffer = String::new();
    let mut quoted = false;

    for c in s.chars() {
        if c == ',' && !quoted {
            // Separator outside a string literal: finish the current item.
            pieces.push(buffer.trim().to_string());
            buffer.clear();
            continue;
        }
        if c == '\'' {
            // Every quote flips the "inside a literal" state.
            quoted = !quoted;
        }
        buffer.push(c);
    }

    let tail = buffer.trim();
    if !tail.is_empty() {
        pieces.push(tail.to_string());
    }
    pieces
}
253
/// Remove surrounding quotes from a value string.
///
/// The input is trimmed first. If it is wrapped in a matching pair of single
/// quotes or of double quotes, the pair is removed and doubled quote
/// characters of that kind inside (`''` / `""`) are decoded to a single one,
/// following SQL escaping. Anything else — including a lone quote character,
/// which has no closing partner — is returned trimmed but otherwise as-is.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();

    // `strip_prefix` followed by `strip_suffix` needs two distinct quote
    // characters, so a one-character input such as `'` is never treated as
    // quoted (slicing it as `[1..len - 1]` would panic).
    let inner_of = |quote: char| {
        trimmed
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
    };

    if let Some(inner) = inner_of('\'') {
        inner.replace("''", "'")
    } else if let Some(inner) = inner_of('"') {
        inner.replace("\"\"", "\"")
    } else {
        trimmed.to_string()
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` and unwrap the SELECT payload; any other variant fails the test.
    fn select_of(sql: &str) -> SelectQuery {
        match parse_sql(sql) {
            ParsedStatement::Select(q) => q,
            other => panic!("Expected Select, got {:?}", other),
        }
    }

    /// Parse `sql` and unwrap the INSERT payload; any other variant fails the test.
    fn insert_of(sql: &str) -> InsertQuery {
        match parse_sql(sql) {
            ParsedStatement::Insert(q) => q,
            other => panic!("Expected Insert, got {:?}", other),
        }
    }

    /// Parse `sql` and unwrap the DELETE payload; any other variant fails the test.
    fn delete_of(sql: &str) -> DeleteQuery {
        match parse_sql(sql) {
            ParsedStatement::Delete(q) => q,
            other => panic!("Expected Delete, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_select_basic() {
        let q = select_of("SELECT * FROM memories LIMIT 10");
        assert_eq!(q.limit, 10);
        assert_eq!(q.offset, 0);
        assert!(q.agent_id.is_none());
    }

    #[test]
    fn test_parse_select_with_where() {
        let q = select_of("SELECT * FROM memories WHERE agent_id = 'bot-1' LIMIT 5");
        assert_eq!(q.agent_id.as_deref(), Some("bot-1"));
        assert_eq!(q.limit, 5);
    }

    #[test]
    fn test_parse_select_with_like() {
        let q = select_of("SELECT * FROM memories WHERE content LIKE '%hello%' LIMIT 20");
        assert_eq!(q.query_text.as_deref(), Some("hello"));
        assert_eq!(q.limit, 20);
    }

    #[test]
    fn test_parse_insert() {
        let q = insert_of("INSERT INTO memories (content, importance) VALUES ('test memory', 0.8)");
        assert_eq!(q.content, "test memory");
        assert_eq!(q.importance, Some(0.8));
    }

    #[test]
    fn test_parse_insert_with_agent() {
        let q = insert_of(
            "INSERT INTO memories (content, agent_id, memory_type) VALUES ('data', 'agent-1', 'episodic')",
        );
        assert_eq!(q.content, "data");
        assert_eq!(q.agent_id.as_deref(), Some("agent-1"));
        assert_eq!(q.memory_type.as_deref(), Some("episodic"));
    }

    #[test]
    fn test_parse_delete() {
        let q = delete_of("DELETE FROM memories WHERE id = '550e8400-e29b-41d4-a716-446655440000'");
        assert_eq!(
            q.memory_id.as_deref(),
            Some("550e8400-e29b-41d4-a716-446655440000")
        );
    }

    #[test]
    fn test_parse_unsupported() {
        assert!(matches!(
            parse_sql("DROP TABLE memories"),
            ParsedStatement::Unsupported(_)
        ));
    }

    #[test]
    fn test_parse_select_with_offset() {
        let q = select_of("SELECT * FROM memories LIMIT 10 OFFSET 20");
        assert_eq!(q.limit, 10);
        assert_eq!(q.offset, 20);
    }
}