1#[derive(Debug, Clone, PartialEq)]
9pub enum ParsedStatement {
10 Select(SelectQuery),
12 Insert(InsertQuery),
14 Delete(DeleteQuery),
16 Unsupported(String),
18}
19
/// Filters and paging extracted from a `SELECT` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct SelectQuery {
    /// Value of an `agent_id = '...'` condition, when one is present.
    pub agent_id: Option<String>,
    /// Pattern of a `content LIKE '...'` condition with surrounding `%`
    /// wildcards removed; `None` when absent or empty after stripping.
    pub query_text: Option<String>,
    /// Maximum number of rows; the parser uses 50 when no `LIMIT` is given.
    pub limit: usize,
    /// Number of rows to skip; the parser uses 0 when no `OFFSET` is given.
    pub offset: usize,
}
32
/// Field values extracted from an `INSERT ... VALUES (...)` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct InsertQuery {
    /// Text of the `content` column; the parser never yields an empty value here.
    pub content: String,
    /// Value of the `agent_id` column, if supplied.
    pub agent_id: Option<String>,
    /// Value of the `importance` column, or `None` if absent or not a valid `f32`.
    pub importance: Option<f32>,
    /// Value of the `memory_type` column, if supplied.
    pub memory_type: Option<String>,
    /// Not populated by the SQL parser; always empty in parsed results.
    pub tags: Vec<String>,
}
42
/// Conditions extracted from a `DELETE` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct DeleteQuery {
    /// Value of an `id = '...'` condition, when one is present.
    pub memory_id: Option<String>,
    /// Value of an `agent_id = '...'` condition, when one is present.
    pub agent_id: Option<String>,
}
51
52pub fn parse_sql(sql: &str) -> ParsedStatement {
59 let trimmed = sql.trim().trim_end_matches(';');
60 let upper = trimmed.to_uppercase();
61
62 if upper.starts_with("SELECT") {
63 parse_select(trimmed)
64 } else if upper.starts_with("INSERT") {
65 parse_insert(trimmed)
66 } else if upper.starts_with("DELETE") {
67 parse_delete(trimmed)
68 } else {
69 ParsedStatement::Unsupported(trimmed.to_string())
70 }
71}
72
73fn parse_select(sql: &str) -> ParsedStatement {
74 let upper = sql.to_uppercase();
75 let mut query = SelectQuery {
76 agent_id: None,
77 query_text: None,
78 limit: 50,
79 offset: 0,
80 };
81
82 if let Some(pos) = upper.find("LIMIT") {
84 let after = &sql[pos + 5..].trim();
85 if let Some(num_str) = after.split_whitespace().next()
86 && let Ok(n) = num_str.parse::<usize>()
87 {
88 query.limit = n;
89 }
90 }
91
92 if let Some(pos) = upper.find("OFFSET") {
94 let after = &sql[pos + 6..].trim();
95 if let Some(num_str) = after.split_whitespace().next()
96 && let Ok(n) = num_str.parse::<usize>()
97 {
98 query.offset = n;
99 }
100 }
101
102 if let Some(agent_id) = extract_string_condition(&upper, sql, "AGENT_ID") {
104 query.agent_id = Some(agent_id);
105 }
106
107 if let Some(pos) = upper.find("CONTENT LIKE") {
109 let after = &sql[pos + 12..].trim();
110 if let Some(value) = extract_quoted_value(after) {
111 let clean = value.trim_matches('%').to_string();
113 if !clean.is_empty() {
114 query.query_text = Some(clean);
115 }
116 }
117 }
118
119 ParsedStatement::Select(query)
120}
121
122fn parse_insert(sql: &str) -> ParsedStatement {
123 let upper = sql.to_uppercase();
125
126 let cols_start = match upper.find('(') {
127 Some(p) => p,
128 None => return ParsedStatement::Unsupported(sql.to_string()),
129 };
130 let cols_end = match upper[cols_start..].find(')') {
131 Some(p) => cols_start + p,
132 None => return ParsedStatement::Unsupported(sql.to_string()),
133 };
134
135 let values_marker = match upper[cols_end..].find("VALUES") {
136 Some(p) => cols_end + p,
137 None => return ParsedStatement::Unsupported(sql.to_string()),
138 };
139
140 let vals_start = match upper[values_marker..].find('(') {
141 Some(p) => values_marker + p,
142 None => return ParsedStatement::Unsupported(sql.to_string()),
143 };
144 let vals_end = match sql[vals_start..].rfind(')') {
145 Some(p) => vals_start + p,
146 None => return ParsedStatement::Unsupported(sql.to_string()),
147 };
148
149 let columns: Vec<String> = sql[cols_start + 1..cols_end]
150 .split(',')
151 .map(|c| c.trim().to_uppercase())
152 .collect();
153
154 let values: Vec<String> = split_sql_values(&sql[vals_start + 1..vals_end]);
155
156 let mut insert = InsertQuery {
157 content: String::new(),
158 agent_id: None,
159 importance: None,
160 memory_type: None,
161 tags: vec![],
162 };
163
164 for (i, col) in columns.iter().enumerate() {
165 if i >= values.len() {
166 break;
167 }
168 let val = unquote(&values[i]);
169 match col.as_str() {
170 "CONTENT" => insert.content = val,
171 "AGENT_ID" => insert.agent_id = Some(val),
172 "IMPORTANCE" => insert.importance = val.parse().ok(),
173 "MEMORY_TYPE" => insert.memory_type = Some(val),
174 _ => {}
175 }
176 }
177
178 if insert.content.is_empty() {
179 return ParsedStatement::Unsupported(sql.to_string());
180 }
181
182 ParsedStatement::Insert(insert)
183}
184
185fn parse_delete(sql: &str) -> ParsedStatement {
186 let upper = sql.to_uppercase();
187 let mut delete = DeleteQuery {
188 memory_id: None,
189 agent_id: None,
190 };
191
192 if let Some(id) = extract_string_condition(&upper, sql, "ID") {
193 delete.memory_id = Some(id);
194 }
195 if let Some(agent_id) = extract_string_condition(&upper, sql, "AGENT_ID") {
196 delete.agent_id = Some(agent_id);
197 }
198
199 ParsedStatement::Delete(delete)
200}
201
/// Finds a `column = '<literal>'` condition in `original` and returns the literal.
///
/// Matching rules for the column name:
/// - case-insensitive;
/// - only as a whole identifier, so `ID` does not match inside `AGENT_ID`;
/// - only outside single-quoted string literals;
/// - any amount of whitespace is allowed around `=`.
///
/// Occurrences not followed by `=` (a select-list entry, `!=`, ...) are
/// skipped. Returns `None` if no such occurrence exists. If the first
/// `column =` found is not followed by a quoted literal, also returns `None`.
///
/// `_upper` is accepted for signature compatibility but not used. Offsets are
/// computed on `original` directly, because an uppercased copy made with
/// Unicode `to_uppercase` may differ from `original` in byte length.
fn extract_string_condition(_upper: &str, original: &str, column: &str) -> Option<String> {
    let bytes = original.as_bytes();
    let needle = column.as_bytes();
    if needle.is_empty() {
        return None;
    }
    let is_ident = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let mut in_quote = false;

    for (i, &b) in bytes.iter().enumerate() {
        if b == b'\'' {
            // `''` escapes toggle twice and so leave the state unchanged.
            in_quote = !in_quote;
            continue;
        }
        if in_quote || (i > 0 && is_ident(bytes[i - 1])) {
            continue;
        }
        let end = i + needle.len();
        let is_match = matches!(bytes.get(i..end), Some(w) if w.eq_ignore_ascii_case(needle))
            && !matches!(bytes.get(end), Some(&next) if is_ident(next));
        if !is_match {
            continue;
        }
        // `end` follows a fully matched `column`, so it is a char boundary.
        if let Some(rhs) = original[end..].trim_start().strip_prefix('=') {
            return extract_quoted_value(rhs);
        }
    }
    None
}

/// Returns the contents of the single-quoted SQL string literal at the start
/// of `s` (after leading whitespace), with `''` unescaped to `'`.
///
/// Returns `None` if `s` does not start with `'` or the literal is
/// unterminated. Anything after the closing quote is ignored.
fn extract_quoted_value(s: &str) -> Option<String> {
    let body = s.trim_start().strip_prefix('\'')?;
    let mut value = String::new();
    let mut chars = body.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch != '\'' {
            value.push(ch);
        } else if chars.peek() == Some(&'\'') {
            // `''` inside a literal is an escaped single quote.
            chars.next();
            value.push('\'');
        } else {
            return Some(value);
        }
    }
    None
}
222
/// Splits the inside of a `VALUES (...)` tuple on top-level commas.
///
/// Commas inside single- or double-quoted literals do not split, matching the
/// quote styles that `unquote` strips. Each piece is trimmed and returned with
/// its quotes still attached. A trailing empty piece is dropped; empty pieces
/// between commas are kept as `""`.
fn split_sql_values(s: &str) -> Vec<String> {
    let mut values = Vec::new();
    let mut current = String::new();
    // Quote character of the literal currently open, if any. A doubled quote
    // (`''`) closes and reopens the literal, so escapes need no special case.
    let mut open_quote: Option<char> = None;

    for ch in s.chars() {
        match (open_quote, ch) {
            (None, '\'' | '"') => open_quote = Some(ch),
            (Some(q), c) if c == q => open_quote = None,
            (None, ',') => {
                values.push(current.trim().to_string());
                current.clear();
                continue;
            }
            _ => {}
        }
        current.push(ch);
    }

    let last = current.trim();
    if !last.is_empty() {
        values.push(last.to_string());
    }
    values
}

/// Strips one pair of matching surrounding quotes (`'...'` or `"..."`) from a
/// trimmed SQL value and unescapes doubled quotes of that kind (`''` -> `'`).
///
/// Values that are not fully quoted, including a lone quote character, are
/// returned trimmed but otherwise unchanged.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    let mut chars = trimmed.chars();
    // `next_back` returns None for a single-character string, so a lone quote
    // falls through to the unquoted branch instead of slicing `[1..0]`.
    match (chars.next(), chars.next_back()) {
        (Some(open @ ('\'' | '"')), Some(close)) if open == close => {
            let inner = &trimmed[1..trimmed.len() - 1];
            inner.replace(&format!("{open}{open}"), &open.to_string())
        }
        _ => trimmed.to_string(),
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    // Parse `sql` and unwrap the expected variant, panicking with the actual
    // statement otherwise.
    fn expect_select(sql: &str) -> SelectQuery {
        match parse_sql(sql) {
            ParsedStatement::Select(q) => q,
            other => panic!("Expected Select, got {:?}", other),
        }
    }

    fn expect_insert(sql: &str) -> InsertQuery {
        match parse_sql(sql) {
            ParsedStatement::Insert(q) => q,
            other => panic!("Expected Insert, got {:?}", other),
        }
    }

    fn expect_delete(sql: &str) -> DeleteQuery {
        match parse_sql(sql) {
            ParsedStatement::Delete(q) => q,
            other => panic!("Expected Delete, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_select_basic() {
        let q = expect_select("SELECT * FROM memories LIMIT 10");
        assert_eq!((q.limit, q.offset), (10, 0));
        assert_eq!(q.agent_id, None);
    }

    #[test]
    fn test_parse_select_with_where() {
        let q = expect_select("SELECT * FROM memories WHERE agent_id = 'bot-1' LIMIT 5");
        assert_eq!(q.agent_id.as_deref(), Some("bot-1"));
        assert_eq!(q.limit, 5);
    }

    #[test]
    fn test_parse_select_with_like() {
        let q = expect_select("SELECT * FROM memories WHERE content LIKE '%hello%' LIMIT 20");
        assert_eq!(q.query_text.as_deref(), Some("hello"));
        assert_eq!(q.limit, 20);
    }

    #[test]
    fn test_parse_insert() {
        let q = expect_insert(
            "INSERT INTO memories (content, importance) VALUES ('test memory', 0.8)",
        );
        assert_eq!(q.content, "test memory");
        assert_eq!(q.importance, Some(0.8));
    }

    #[test]
    fn test_parse_insert_with_agent() {
        let q = expect_insert(
            "INSERT INTO memories (content, agent_id, memory_type) VALUES ('data', 'agent-1', 'episodic')",
        );
        assert_eq!(q.content, "data");
        assert_eq!(q.agent_id.as_deref(), Some("agent-1"));
        assert_eq!(q.memory_type.as_deref(), Some("episodic"));
    }

    #[test]
    fn test_parse_delete() {
        let q = expect_delete(
            "DELETE FROM memories WHERE id = '550e8400-e29b-41d4-a716-446655440000'",
        );
        assert_eq!(
            q.memory_id.as_deref(),
            Some("550e8400-e29b-41d4-a716-446655440000")
        );
    }

    #[test]
    fn test_parse_unsupported() {
        let parsed = parse_sql("DROP TABLE memories");
        assert!(matches!(parsed, ParsedStatement::Unsupported(_)));
    }

    #[test]
    fn test_parse_select_with_offset() {
        let q = expect_select("SELECT * FROM memories LIMIT 10 OFFSET 20");
        assert_eq!((q.limit, q.offset), (10, 20));
    }
}