Skip to main content

mentedb_query/
parser.rs

1//! Hand-written recursive descent parser for MQL.
2
3use mentedb_core::edge::EdgeType;
4use mentedb_core::error::{MenteError, MenteResult};
5use mentedb_core::memory::MemoryType;
6use uuid::Uuid;
7
8use crate::ast::*;
9use crate::lexer::{Token, TokenKind};
10
/// Recursive descent parser over a borrowed slice of MQL tokens.
///
/// The parser only reads the tokens; its sole mutable state is the cursor.
pub struct Parser<'a> {
    /// Token stream being parsed, borrowed from the caller.
    tokens: &'a [Token],
    /// Index of the next token to consume; starts at 0.
    pos: usize,
}
15
16impl<'a> Parser<'a> {
17    pub fn new(tokens: &'a [Token]) -> Self {
18        Self { tokens, pos: 0 }
19    }
20
21    pub fn parse(tokens: &[Token]) -> MenteResult<Statement> {
22        let mut parser = Parser::new(tokens);
23        parser.parse_statement()
24    }
25
26    fn peek(&self) -> &Token {
27        &self.tokens[self.pos.min(self.tokens.len() - 1)]
28    }
29
30    fn advance(&mut self) -> &Token {
31        let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
32        if self.pos < self.tokens.len() {
33            self.pos += 1;
34        }
35        tok
36    }
37
38    fn expect(&mut self, kind: TokenKind) -> MenteResult<&Token> {
39        let tok = self.peek();
40        if tok.kind != kind {
41            return Err(MenteError::Query(format!(
42                "expected {:?}, found {:?} ('{}') at position {}",
43                kind, tok.kind, tok.lexeme, tok.position
44            )));
45        }
46        Ok(self.advance())
47    }
48
49    fn at(&self, kind: TokenKind) -> bool {
50        self.peek().kind == kind
51    }
52
53    fn parse_statement(&mut self) -> MenteResult<Statement> {
54        match self.peek().kind {
55            TokenKind::Recall => self.parse_recall(),
56            TokenKind::Relate => self.parse_relate(),
57            TokenKind::Forget => self.parse_forget(),
58            TokenKind::Consolidate => self.parse_consolidate(),
59            TokenKind::Traverse => self.parse_traverse(),
60            _ => Err(MenteError::Query(format!(
61                "expected statement keyword, found {:?} at position {}",
62                self.peek().kind,
63                self.peek().position
64            ))),
65        }
66    }
67
68    fn parse_recall(&mut self) -> MenteResult<Statement> {
69        self.advance(); // RECALL
70
71        // Optional "memories" keyword
72        if self.at(TokenKind::Memories) {
73            self.advance();
74        }
75
76        let mut near = None;
77        let mut filters = Vec::new();
78        let mut limit = None;
79        let mut order_by = None;
80
81        // NEAR [vector]
82        if self.at(TokenKind::Near) {
83            self.advance();
84            near = Some(self.parse_vector()?);
85        }
86
87        // WHERE clause
88        if self.at(TokenKind::Where) {
89            self.advance();
90            filters = self.parse_filters()?;
91        }
92
93        // ORDER BY field
94        if self.at(TokenKind::OrderBy) {
95            self.advance();
96            // consume optional "BY"
97            if self.at(TokenKind::By) {
98                self.advance();
99            }
100            let field = self.parse_field()?;
101            let descending = false; // default ascending
102            order_by = Some(OrderBy { field, descending });
103        }
104
105        // LIMIT n
106        if self.at(TokenKind::Limit) {
107            self.advance();
108            let tok = self.advance();
109            let n: usize = tok
110                .lexeme
111                .parse()
112                .map_err(|_| MenteError::Query(format!("invalid limit value: {}", tok.lexeme)))?;
113            limit = Some(n);
114        }
115
116        Ok(Statement::Recall(RecallStatement {
117            filters,
118            near,
119            limit,
120            order_by,
121        }))
122    }
123
124    fn parse_relate(&mut self) -> MenteResult<Statement> {
125        self.advance(); // RELATE
126
127        let source = self.parse_uuid()?;
128        self.expect(TokenKind::Arrow)?;
129        let target = self.parse_uuid()?;
130        self.expect(TokenKind::As)?;
131        let edge_type = self.parse_edge_type()?;
132
133        let mut weight = None;
134        if self.at(TokenKind::With) {
135            self.advance();
136            // expect "weight = <float>"
137            self.expect(TokenKind::Identifier)?; // "weight"
138            self.expect(TokenKind::Eq)?;
139            let tok = self.advance();
140            let w: f32 = tok
141                .lexeme
142                .parse()
143                .map_err(|_| MenteError::Query(format!("invalid weight value: {}", tok.lexeme)))?;
144            weight = Some(w);
145        }
146
147        Ok(Statement::Relate(RelateStatement {
148            source,
149            target,
150            edge_type,
151            weight,
152        }))
153    }
154
155    fn parse_forget(&mut self) -> MenteResult<Statement> {
156        self.advance(); // FORGET
157        let target = self.parse_uuid()?;
158        Ok(Statement::Forget(ForgetStatement { target }))
159    }
160
161    fn parse_consolidate(&mut self) -> MenteResult<Statement> {
162        self.advance(); // CONSOLIDATE
163        let mut filters = Vec::new();
164        if self.at(TokenKind::Where) {
165            self.advance();
166            filters = self.parse_filters()?;
167        }
168        Ok(Statement::Consolidate(ConsolidateStatement { filters }))
169    }
170
171    fn parse_traverse(&mut self) -> MenteResult<Statement> {
172        self.advance(); // TRAVERSE
173        let start = self.parse_uuid()?;
174
175        self.expect(TokenKind::Depth)?;
176        let tok = self.advance();
177        let depth: usize = tok
178            .lexeme
179            .parse()
180            .map_err(|_| MenteError::Query(format!("invalid depth value: {}", tok.lexeme)))?;
181
182        let mut edge_filter = None;
183        if self.at(TokenKind::Where) {
184            self.advance();
185            // edge_type = <type>
186            self.expect(TokenKind::EdgeType)?;
187            self.expect(TokenKind::Eq)?;
188            let et = self.parse_edge_type()?;
189            edge_filter = Some(vec![et]);
190        }
191
192        Ok(Statement::Traverse(TraverseStatement {
193            start,
194            depth,
195            edge_filter,
196        }))
197    }
198
199    fn parse_filters(&mut self) -> MenteResult<Vec<Filter>> {
200        let mut filters = vec![self.parse_filter()?];
201        while self.at(TokenKind::And) {
202            self.advance();
203            filters.push(self.parse_filter()?);
204        }
205        Ok(filters)
206    }
207
208    fn parse_filter(&mut self) -> MenteResult<Filter> {
209        let field = self.parse_field()?;
210        let op = self.parse_operator()?;
211        let value = self.parse_value(&field)?;
212        Ok(Filter { field, op, value })
213    }
214
215    fn parse_field(&mut self) -> MenteResult<Field> {
216        let tok = self.advance();
217        match tok.kind {
218            TokenKind::Identifier if tok.lexeme.eq_ignore_ascii_case("content") => {
219                Ok(Field::Content)
220            }
221            TokenKind::Type => Ok(Field::Type),
222            TokenKind::Tag => Ok(Field::Tag),
223            TokenKind::Agent => Ok(Field::Agent),
224            TokenKind::Space => Ok(Field::Space),
225            TokenKind::Salience => Ok(Field::Salience),
226            TokenKind::Confidence => Ok(Field::Confidence),
227            TokenKind::Created => Ok(Field::Created),
228            TokenKind::Accessed => Ok(Field::Accessed),
229            _ => Err(MenteError::Query(format!(
230                "expected field name, found '{}' at position {}",
231                tok.lexeme, tok.position
232            ))),
233        }
234    }
235
236    fn parse_operator(&mut self) -> MenteResult<Operator> {
237        let tok = self.advance();
238        match tok.kind {
239            TokenKind::Eq => Ok(Operator::Eq),
240            TokenKind::Neq => Ok(Operator::Neq),
241            TokenKind::Gt => Ok(Operator::Gt),
242            TokenKind::Lt => Ok(Operator::Lt),
243            TokenKind::Gte => Ok(Operator::Gte),
244            TokenKind::Lte => Ok(Operator::Lte),
245            TokenKind::SimilarTo => Ok(Operator::SimilarTo),
246            _ => Err(MenteError::Query(format!(
247                "expected operator, found '{}' at position {}",
248                tok.lexeme, tok.position
249            ))),
250        }
251    }
252
253    fn parse_value(&mut self, field: &Field) -> MenteResult<Value> {
254        // For Type field, parse as MemoryType
255        if *field == Field::Type {
256            return self.parse_memory_type_value();
257        }
258
259        let tok = self.advance();
260        match tok.kind {
261            TokenKind::StringLit => {
262                // Strip surrounding quotes
263                let inner = tok.lexeme[1..tok.lexeme.len() - 1].to_string();
264                // Check if this looks like a UUID inside quotes
265                if let Ok(uuid) = inner.parse::<Uuid>() {
266                    return Ok(Value::Uuid(uuid));
267                }
268                Ok(Value::Text(inner))
269            }
270            TokenKind::IntegerLit => {
271                let n: i64 = tok
272                    .lexeme
273                    .parse()
274                    .map_err(|_| MenteError::Query(format!("invalid integer: {}", tok.lexeme)))?;
275                Ok(Value::Integer(n))
276            }
277            TokenKind::FloatLit => {
278                let n: f64 = tok
279                    .lexeme
280                    .parse()
281                    .map_err(|_| MenteError::Query(format!("invalid float: {}", tok.lexeme)))?;
282                Ok(Value::Number(n))
283            }
284            TokenKind::UuidLit => {
285                let uuid: Uuid = tok
286                    .lexeme
287                    .parse()
288                    .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme)))?;
289                Ok(Value::Uuid(uuid))
290            }
291            TokenKind::Identifier => {
292                let lower = tok.lexeme.to_lowercase();
293                match lower.as_str() {
294                    "true" => Ok(Value::Bool(true)),
295                    "false" => Ok(Value::Bool(false)),
296                    _ => Ok(Value::Text(tok.lexeme.clone())),
297                }
298            }
299            TokenKind::LBracket => {
300                // put back and parse as vector
301                self.pos -= 1;
302                let v = self.parse_vector()?;
303                Ok(Value::Vector(v))
304            }
305            _ => Err(MenteError::Query(format!(
306                "expected value, found '{}' at position {}",
307                tok.lexeme, tok.position
308            ))),
309        }
310    }
311
312    fn parse_memory_type_value(&mut self) -> MenteResult<Value> {
313        let tok = self.advance();
314        let name = match tok.kind {
315            TokenKind::Identifier | TokenKind::StringLit => {
316                if tok.kind == TokenKind::StringLit {
317                    tok.lexeme[1..tok.lexeme.len() - 1].to_string()
318                } else {
319                    tok.lexeme.clone()
320                }
321            }
322            _ => {
323                return Err(MenteError::Query(format!(
324                    "expected memory type, found '{}' at position {}",
325                    tok.lexeme, tok.position
326                )));
327            }
328        };
329
330        let mt = match name.to_lowercase().as_str() {
331            "episodic" => MemoryType::Episodic,
332            "semantic" => MemoryType::Semantic,
333            "procedural" => MemoryType::Procedural,
334            "antipattern" | "anti_pattern" => MemoryType::AntiPattern,
335            "reasoning" => MemoryType::Reasoning,
336            "correction" => MemoryType::Correction,
337            _ => {
338                return Err(MenteError::Query(format!("unknown memory type: {}", name)));
339            }
340        };
341        Ok(Value::MemoryType(mt))
342    }
343
344    fn parse_edge_type(&mut self) -> MenteResult<EdgeType> {
345        let tok = self.advance();
346        let name = match tok.kind {
347            TokenKind::Identifier | TokenKind::StringLit => {
348                if tok.kind == TokenKind::StringLit {
349                    tok.lexeme[1..tok.lexeme.len() - 1].to_string()
350                } else {
351                    tok.lexeme.clone()
352                }
353            }
354            _ => {
355                return Err(MenteError::Query(format!(
356                    "expected edge type, found '{}' at position {}",
357                    tok.lexeme, tok.position
358                )));
359            }
360        };
361
362        match name.to_lowercase().as_str() {
363            "caused" => Ok(EdgeType::Caused),
364            "before" => Ok(EdgeType::Before),
365            "related" => Ok(EdgeType::Related),
366            "contradicts" => Ok(EdgeType::Contradicts),
367            "supports" => Ok(EdgeType::Supports),
368            "supersedes" => Ok(EdgeType::Supersedes),
369            "derived" => Ok(EdgeType::Derived),
370            "partof" | "part_of" => Ok(EdgeType::PartOf),
371            _ => Err(MenteError::Query(format!("unknown edge type: {}", name))),
372        }
373    }
374
375    fn parse_uuid(&mut self) -> MenteResult<Uuid> {
376        let tok = self.advance();
377        match tok.kind {
378            TokenKind::UuidLit => tok
379                .lexeme
380                .parse()
381                .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme))),
382            TokenKind::StringLit => {
383                let inner = &tok.lexeme[1..tok.lexeme.len() - 1];
384                inner.parse().map_err(|_| {
385                    MenteError::Query(format!("invalid UUID in string: {}", tok.lexeme))
386                })
387            }
388            _ => Err(MenteError::Query(format!(
389                "expected UUID, found '{}' at position {}",
390                tok.lexeme, tok.position
391            ))),
392        }
393    }
394
395    fn parse_vector(&mut self) -> MenteResult<Vec<f32>> {
396        self.expect(TokenKind::LBracket)?;
397        let mut values = Vec::new();
398        if !self.at(TokenKind::RBracket) {
399            let tok = self.advance();
400            let v: f32 = tok.lexeme.parse().map_err(|_| {
401                MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
402            })?;
403            values.push(v);
404            while self.at(TokenKind::Comma) {
405                self.advance();
406                let tok = self.advance();
407                let v: f32 = tok.lexeme.parse().map_err(|_| {
408                    MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
409                })?;
410                values.push(v);
411            }
412        }
413        self.expect(TokenKind::RBracket)?;
414        Ok(values)
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;

    /// Tokenizes and parses `query`, panicking if either step fails.
    fn parse_query(query: &str) -> Statement {
        let tokens = tokenize(query).unwrap();
        Parser::parse(&tokens).unwrap()
    }

    #[test]
    fn test_parse_recall_with_type_filter() {
        let recall = match parse_query("RECALL memories WHERE type = episodic LIMIT 5") {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.filters.len(), 1);
        assert_eq!(recall.filters[0].field, Field::Type);
        assert_eq!(
            recall.filters[0].value,
            Value::MemoryType(MemoryType::Episodic)
        );
        assert_eq!(recall.limit, Some(5));
    }

    #[test]
    fn test_parse_recall_similar_to() {
        let query = r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#;
        let recall = match parse_query(query) {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.filters.len(), 1);
        assert_eq!(recall.filters[0].op, Operator::SimilarTo);
        assert_eq!(recall.limit, Some(10));
    }

    #[test]
    fn test_parse_recall_near() {
        let recall = match parse_query("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 10") {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.near, Some(vec![0.1, 0.2, 0.3]));
        assert_eq!(recall.limit, Some(10));
    }

    #[test]
    fn test_parse_relate() {
        let query = "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.9";
        let relate = match parse_query(query) {
            Statement::Relate(r) => r,
            _ => panic!("expected Relate"),
        };
        assert_eq!(relate.edge_type, EdgeType::Caused);
        assert_eq!(relate.weight, Some(0.9));
    }

    #[test]
    fn test_parse_forget() {
        let forget = match parse_query("FORGET 550e8400-e29b-41d4-a716-446655440000") {
            Statement::Forget(f) => f,
            _ => panic!("expected Forget"),
        };
        let expected = "550e8400-e29b-41d4-a716-446655440000"
            .parse::<Uuid>()
            .unwrap();
        assert_eq!(forget.target, expected);
    }

    #[test]
    fn test_parse_consolidate() {
        let query = r#"CONSOLIDATE WHERE type = episodic AND accessed < "2024-01-01""#;
        let consolidate = match parse_query(query) {
            Statement::Consolidate(c) => c,
            _ => panic!("expected Consolidate"),
        };
        assert_eq!(consolidate.filters.len(), 2);
    }

    #[test]
    fn test_parse_traverse() {
        let query =
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused";
        let traverse = match parse_query(query) {
            Statement::Traverse(t) => t,
            _ => panic!("expected Traverse"),
        };
        assert_eq!(traverse.depth, 3);
        assert_eq!(traverse.edge_filter, Some(vec![EdgeType::Caused]));
    }
}