// mentedb_query/parser.rs
1//! Hand-written recursive descent parser for MQL.
2
3use mentedb_core::edge::EdgeType;
4use mentedb_core::error::{MenteError, MenteResult};
5use mentedb_core::memory::MemoryType;
6use uuid::Uuid;
7
8use crate::ast::*;
9use crate::lexer::{Token, TokenKind};
10use mentedb_core::types::MemoryId;
11
/// Recursive descent parser over a pre-lexed MQL token stream.
///
/// Construct with [`Parser::new`] or use the one-shot [`Parser::parse`].
pub struct Parser<'a> {
    // Token stream produced by the lexer; borrowed, never mutated.
    tokens: &'a [Token],
    // Index of the next token to consume. Reads clamp this to the last
    // token, so the stream presumably ends in an EOF token — TODO confirm
    // against the lexer.
    pos: usize,
}
16
17impl<'a> Parser<'a> {
18    pub fn new(tokens: &'a [Token]) -> Self {
19        Self { tokens, pos: 0 }
20    }
21
22    pub fn parse(tokens: &[Token]) -> MenteResult<Statement> {
23        let mut parser = Parser::new(tokens);
24        parser.parse_statement()
25    }
26
27    fn peek(&self) -> &Token {
28        &self.tokens[self.pos.min(self.tokens.len() - 1)]
29    }
30
31    fn advance(&mut self) -> &Token {
32        let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
33        if self.pos < self.tokens.len() {
34            self.pos += 1;
35        }
36        tok
37    }
38
39    fn expect(&mut self, kind: TokenKind) -> MenteResult<&Token> {
40        let tok = self.peek();
41        if tok.kind != kind {
42            return Err(MenteError::Query(format!(
43                "expected {:?}, found {:?} ('{}') at position {}",
44                kind, tok.kind, tok.lexeme, tok.position
45            )));
46        }
47        Ok(self.advance())
48    }
49
50    fn at(&self, kind: TokenKind) -> bool {
51        self.peek().kind == kind
52    }
53
54    fn parse_statement(&mut self) -> MenteResult<Statement> {
55        match self.peek().kind {
56            TokenKind::Recall => self.parse_recall(),
57            TokenKind::Relate => self.parse_relate(),
58            TokenKind::Forget => self.parse_forget(),
59            TokenKind::Consolidate => self.parse_consolidate(),
60            TokenKind::Traverse => self.parse_traverse(),
61            _ => Err(MenteError::Query(format!(
62                "expected statement keyword, found {:?} at position {}",
63                self.peek().kind,
64                self.peek().position
65            ))),
66        }
67    }
68
69    fn parse_recall(&mut self) -> MenteResult<Statement> {
70        self.advance(); // RECALL
71
72        // Optional "memories" keyword
73        if self.at(TokenKind::Memories) {
74            self.advance();
75        }
76
77        let mut near = None;
78        let mut filters = Vec::new();
79        let mut limit = None;
80        let mut order_by = None;
81
82        // NEAR [vector]
83        if self.at(TokenKind::Near) {
84            self.advance();
85            near = Some(self.parse_vector()?);
86        }
87
88        // WHERE clause
89        if self.at(TokenKind::Where) {
90            self.advance();
91            filters = self.parse_filters()?;
92        }
93
94        // ORDER BY field
95        if self.at(TokenKind::OrderBy) {
96            self.advance();
97            // consume optional "BY"
98            if self.at(TokenKind::By) {
99                self.advance();
100            }
101            let field = self.parse_field()?;
102            let descending = false; // default ascending
103            order_by = Some(OrderBy { field, descending });
104        }
105
106        // LIMIT n
107        if self.at(TokenKind::Limit) {
108            self.advance();
109            let tok = self.advance();
110            let n: usize = tok
111                .lexeme
112                .parse()
113                .map_err(|_| MenteError::Query(format!("invalid limit value: {}", tok.lexeme)))?;
114            limit = Some(n);
115        }
116
117        Ok(Statement::Recall(RecallStatement {
118            filters,
119            near,
120            limit,
121            order_by,
122        }))
123    }
124
125    fn parse_relate(&mut self) -> MenteResult<Statement> {
126        self.advance(); // RELATE
127
128        let source = self.parse_uuid()?;
129        self.expect(TokenKind::Arrow)?;
130        let target = self.parse_uuid()?;
131        self.expect(TokenKind::As)?;
132        let edge_type = self.parse_edge_type()?;
133
134        let mut weight = None;
135        if self.at(TokenKind::With) {
136            self.advance();
137            // expect "weight = <float>"
138            self.expect(TokenKind::Identifier)?; // "weight"
139            self.expect(TokenKind::Eq)?;
140            let tok = self.advance();
141            let w: f32 = tok
142                .lexeme
143                .parse()
144                .map_err(|_| MenteError::Query(format!("invalid weight value: {}", tok.lexeme)))?;
145            weight = Some(w);
146        }
147
148        Ok(Statement::Relate(RelateStatement {
149            source,
150            target,
151            edge_type,
152            weight,
153        }))
154    }
155
156    fn parse_forget(&mut self) -> MenteResult<Statement> {
157        self.advance(); // FORGET
158        let target = self.parse_uuid()?;
159        Ok(Statement::Forget(ForgetStatement { target }))
160    }
161
162    fn parse_consolidate(&mut self) -> MenteResult<Statement> {
163        self.advance(); // CONSOLIDATE
164        let mut filters = Vec::new();
165        if self.at(TokenKind::Where) {
166            self.advance();
167            filters = self.parse_filters()?;
168        }
169        Ok(Statement::Consolidate(ConsolidateStatement { filters }))
170    }
171
172    fn parse_traverse(&mut self) -> MenteResult<Statement> {
173        self.advance(); // TRAVERSE
174        let start = self.parse_uuid()?;
175
176        self.expect(TokenKind::Depth)?;
177        let tok = self.advance();
178        let depth: usize = tok
179            .lexeme
180            .parse()
181            .map_err(|_| MenteError::Query(format!("invalid depth value: {}", tok.lexeme)))?;
182
183        let mut edge_filter = None;
184        if self.at(TokenKind::Where) {
185            self.advance();
186            // edge_type = <type>
187            self.expect(TokenKind::EdgeType)?;
188            self.expect(TokenKind::Eq)?;
189            let et = self.parse_edge_type()?;
190            edge_filter = Some(vec![et]);
191        }
192
193        Ok(Statement::Traverse(TraverseStatement {
194            start,
195            depth,
196            edge_filter,
197        }))
198    }
199
200    fn parse_filters(&mut self) -> MenteResult<Vec<Filter>> {
201        let mut filters = vec![self.parse_filter()?];
202        while self.at(TokenKind::And) {
203            self.advance();
204            filters.push(self.parse_filter()?);
205        }
206        Ok(filters)
207    }
208
209    fn parse_filter(&mut self) -> MenteResult<Filter> {
210        let field = self.parse_field()?;
211        let op = self.parse_operator()?;
212        let value = self.parse_value(&field)?;
213        Ok(Filter { field, op, value })
214    }
215
216    fn parse_field(&mut self) -> MenteResult<Field> {
217        let tok = self.advance();
218        match tok.kind {
219            TokenKind::Identifier if tok.lexeme.eq_ignore_ascii_case("content") => {
220                Ok(Field::Content)
221            }
222            TokenKind::Type => Ok(Field::Type),
223            TokenKind::Tag => Ok(Field::Tag),
224            TokenKind::Agent => Ok(Field::Agent),
225            TokenKind::Space => Ok(Field::Space),
226            TokenKind::Salience => Ok(Field::Salience),
227            TokenKind::Confidence => Ok(Field::Confidence),
228            TokenKind::Created => Ok(Field::Created),
229            TokenKind::Accessed => Ok(Field::Accessed),
230            _ => Err(MenteError::Query(format!(
231                "expected field name, found '{}' at position {}",
232                tok.lexeme, tok.position
233            ))),
234        }
235    }
236
237    fn parse_operator(&mut self) -> MenteResult<Operator> {
238        let tok = self.advance();
239        match tok.kind {
240            TokenKind::Eq => Ok(Operator::Eq),
241            TokenKind::Neq => Ok(Operator::Neq),
242            TokenKind::Gt => Ok(Operator::Gt),
243            TokenKind::Lt => Ok(Operator::Lt),
244            TokenKind::Gte => Ok(Operator::Gte),
245            TokenKind::Lte => Ok(Operator::Lte),
246            TokenKind::SimilarTo => Ok(Operator::SimilarTo),
247            _ => Err(MenteError::Query(format!(
248                "expected operator, found '{}' at position {}",
249                tok.lexeme, tok.position
250            ))),
251        }
252    }
253
254    fn parse_value(&mut self, field: &Field) -> MenteResult<Value> {
255        // For Type field, parse as MemoryType
256        if *field == Field::Type {
257            return self.parse_memory_type_value();
258        }
259
260        let tok = self.advance();
261        match tok.kind {
262            TokenKind::StringLit => {
263                // Strip surrounding quotes
264                let inner = tok.lexeme[1..tok.lexeme.len() - 1].to_string();
265                // Check if this looks like a UUID inside quotes
266                if let Ok(uuid) = inner.parse::<MemoryId>() {
267                    return Ok(Value::Uuid(uuid.into()));
268                }
269                Ok(Value::Text(inner))
270            }
271            TokenKind::IntegerLit => {
272                let n: i64 = tok
273                    .lexeme
274                    .parse()
275                    .map_err(|_| MenteError::Query(format!("invalid integer: {}", tok.lexeme)))?;
276                Ok(Value::Integer(n))
277            }
278            TokenKind::FloatLit => {
279                let n: f64 = tok
280                    .lexeme
281                    .parse()
282                    .map_err(|_| MenteError::Query(format!("invalid float: {}", tok.lexeme)))?;
283                Ok(Value::Number(n))
284            }
285            TokenKind::UuidLit => {
286                let uuid: Uuid = tok
287                    .lexeme
288                    .parse()
289                    .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme)))?;
290                Ok(Value::Uuid(uuid))
291            }
292            TokenKind::Identifier => {
293                let lower = tok.lexeme.to_lowercase();
294                match lower.as_str() {
295                    "true" => Ok(Value::Bool(true)),
296                    "false" => Ok(Value::Bool(false)),
297                    _ => Ok(Value::Text(tok.lexeme.clone())),
298                }
299            }
300            TokenKind::LBracket => {
301                // put back and parse as vector
302                self.pos -= 1;
303                let v = self.parse_vector()?;
304                Ok(Value::Vector(v))
305            }
306            _ => Err(MenteError::Query(format!(
307                "expected value, found '{}' at position {}",
308                tok.lexeme, tok.position
309            ))),
310        }
311    }
312
313    fn parse_memory_type_value(&mut self) -> MenteResult<Value> {
314        let tok = self.advance();
315        let name = match tok.kind {
316            TokenKind::Identifier | TokenKind::StringLit => {
317                if tok.kind == TokenKind::StringLit {
318                    tok.lexeme[1..tok.lexeme.len() - 1].to_string()
319                } else {
320                    tok.lexeme.clone()
321                }
322            }
323            _ => {
324                return Err(MenteError::Query(format!(
325                    "expected memory type, found '{}' at position {}",
326                    tok.lexeme, tok.position
327                )));
328            }
329        };
330
331        let mt = match name.to_lowercase().as_str() {
332            "episodic" => MemoryType::Episodic,
333            "semantic" => MemoryType::Semantic,
334            "procedural" => MemoryType::Procedural,
335            "antipattern" | "anti_pattern" => MemoryType::AntiPattern,
336            "reasoning" => MemoryType::Reasoning,
337            "correction" => MemoryType::Correction,
338            _ => {
339                return Err(MenteError::Query(format!("unknown memory type: {}", name)));
340            }
341        };
342        Ok(Value::MemoryType(mt))
343    }
344
345    fn parse_edge_type(&mut self) -> MenteResult<EdgeType> {
346        let tok = self.advance();
347        let name = match tok.kind {
348            TokenKind::Identifier | TokenKind::StringLit => {
349                if tok.kind == TokenKind::StringLit {
350                    tok.lexeme[1..tok.lexeme.len() - 1].to_string()
351                } else {
352                    tok.lexeme.clone()
353                }
354            }
355            _ => {
356                return Err(MenteError::Query(format!(
357                    "expected edge type, found '{}' at position {}",
358                    tok.lexeme, tok.position
359                )));
360            }
361        };
362
363        match name.to_lowercase().as_str() {
364            "caused" => Ok(EdgeType::Caused),
365            "before" => Ok(EdgeType::Before),
366            "related" => Ok(EdgeType::Related),
367            "contradicts" => Ok(EdgeType::Contradicts),
368            "supports" => Ok(EdgeType::Supports),
369            "supersedes" => Ok(EdgeType::Supersedes),
370            "derived" => Ok(EdgeType::Derived),
371            "partof" | "part_of" => Ok(EdgeType::PartOf),
372            _ => Err(MenteError::Query(format!("unknown edge type: {}", name))),
373        }
374    }
375
376    fn parse_uuid(&mut self) -> MenteResult<MemoryId> {
377        let tok = self.advance();
378        match tok.kind {
379            TokenKind::UuidLit => tok
380                .lexeme
381                .parse()
382                .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme))),
383            TokenKind::StringLit => {
384                let inner = &tok.lexeme[1..tok.lexeme.len() - 1];
385                inner.parse().map_err(|_| {
386                    MenteError::Query(format!("invalid UUID in string: {}", tok.lexeme))
387                })
388            }
389            _ => Err(MenteError::Query(format!(
390                "expected UUID, found '{}' at position {}",
391                tok.lexeme, tok.position
392            ))),
393        }
394    }
395
396    fn parse_vector(&mut self) -> MenteResult<Vec<f32>> {
397        self.expect(TokenKind::LBracket)?;
398        let mut values = Vec::new();
399        if !self.at(TokenKind::RBracket) {
400            let tok = self.advance();
401            let v: f32 = tok.lexeme.parse().map_err(|_| {
402                MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
403            })?;
404            values.push(v);
405            while self.at(TokenKind::Comma) {
406                self.advance();
407                let tok = self.advance();
408                let v: f32 = tok.lexeme.parse().map_err(|_| {
409                    MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
410                })?;
411                values.push(v);
412            }
413        }
414        self.expect(TokenKind::RBracket)?;
415        Ok(values)
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;

    /// Lexes and parses `src`, panicking on any lexer or parser failure.
    fn parse_str(src: &str) -> Statement {
        let tokens = tokenize(src).unwrap();
        Parser::parse(&tokens).unwrap()
    }

    #[test]
    fn test_parse_recall_with_type_filter() {
        let stmt = parse_str("RECALL memories WHERE type = episodic LIMIT 5");
        let Statement::Recall(r) = stmt else {
            panic!("expected Recall")
        };
        assert_eq!(r.filters.len(), 1);
        assert_eq!(r.filters[0].field, Field::Type);
        assert_eq!(r.filters[0].value, Value::MemoryType(MemoryType::Episodic));
        assert_eq!(r.limit, Some(5));
    }

    #[test]
    fn test_parse_recall_similar_to() {
        let stmt = parse_str(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#);
        let Statement::Recall(r) = stmt else {
            panic!("expected Recall")
        };
        assert_eq!(r.filters.len(), 1);
        assert_eq!(r.filters[0].op, Operator::SimilarTo);
        assert_eq!(r.limit, Some(10));
    }

    #[test]
    fn test_parse_recall_near() {
        let stmt = parse_str("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 10");
        let Statement::Recall(r) = stmt else {
            panic!("expected Recall")
        };
        assert_eq!(r.near, Some(vec![0.1, 0.2, 0.3]));
        assert_eq!(r.limit, Some(10));
    }

    #[test]
    fn test_parse_relate() {
        let stmt = parse_str(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.9"
        );
        let Statement::Relate(r) = stmt else {
            panic!("expected Relate")
        };
        assert_eq!(r.edge_type, EdgeType::Caused);
        assert_eq!(r.weight, Some(0.9));
    }

    #[test]
    fn test_parse_forget() {
        let stmt = parse_str("FORGET 550e8400-e29b-41d4-a716-446655440000");
        let Statement::Forget(f) = stmt else {
            panic!("expected Forget")
        };
        let expected: MemoryId = "550e8400-e29b-41d4-a716-446655440000".parse().unwrap();
        assert_eq!(f.target, expected);
    }

    #[test]
    fn test_parse_consolidate() {
        let stmt = parse_str(r#"CONSOLIDATE WHERE type = episodic AND accessed < "2024-01-01""#);
        let Statement::Consolidate(c) = stmt else {
            panic!("expected Consolidate")
        };
        assert_eq!(c.filters.len(), 2);
    }

    #[test]
    fn test_parse_traverse() {
        let stmt = parse_str(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        );
        let Statement::Traverse(t) = stmt else {
            panic!("expected Traverse")
        };
        assert_eq!(t.depth, 3);
        assert_eq!(t.edge_filter, Some(vec![EdgeType::Caused]));
    }
}