1use mentedb_core::edge::EdgeType;
4use mentedb_core::error::{MenteError, MenteResult};
5use mentedb_core::memory::MemoryType;
6use uuid::Uuid;
7
8use crate::ast::*;
9use crate::lexer::{Token, TokenKind};
10
11pub struct Parser<'a> {
12 tokens: &'a [Token],
13 pos: usize,
14}
15
16impl<'a> Parser<'a> {
17 pub fn new(tokens: &'a [Token]) -> Self {
18 Self { tokens, pos: 0 }
19 }
20
21 pub fn parse(tokens: &[Token]) -> MenteResult<Statement> {
22 let mut parser = Parser::new(tokens);
23 parser.parse_statement()
24 }
25
26 fn peek(&self) -> &Token {
27 &self.tokens[self.pos.min(self.tokens.len() - 1)]
28 }
29
30 fn advance(&mut self) -> &Token {
31 let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
32 if self.pos < self.tokens.len() {
33 self.pos += 1;
34 }
35 tok
36 }
37
38 fn expect(&mut self, kind: TokenKind) -> MenteResult<&Token> {
39 let tok = self.peek();
40 if tok.kind != kind {
41 return Err(MenteError::Query(format!(
42 "expected {:?}, found {:?} ('{}') at position {}",
43 kind, tok.kind, tok.lexeme, tok.position
44 )));
45 }
46 Ok(self.advance())
47 }
48
49 fn at(&self, kind: TokenKind) -> bool {
50 self.peek().kind == kind
51 }
52
53 fn parse_statement(&mut self) -> MenteResult<Statement> {
54 match self.peek().kind {
55 TokenKind::Recall => self.parse_recall(),
56 TokenKind::Relate => self.parse_relate(),
57 TokenKind::Forget => self.parse_forget(),
58 TokenKind::Consolidate => self.parse_consolidate(),
59 TokenKind::Traverse => self.parse_traverse(),
60 _ => Err(MenteError::Query(format!(
61 "expected statement keyword, found {:?} at position {}",
62 self.peek().kind,
63 self.peek().position
64 ))),
65 }
66 }
67
68 fn parse_recall(&mut self) -> MenteResult<Statement> {
69 self.advance(); if self.at(TokenKind::Memories) {
73 self.advance();
74 }
75
76 let mut near = None;
77 let mut filters = Vec::new();
78 let mut limit = None;
79 let mut order_by = None;
80
81 if self.at(TokenKind::Near) {
83 self.advance();
84 near = Some(self.parse_vector()?);
85 }
86
87 if self.at(TokenKind::Where) {
89 self.advance();
90 filters = self.parse_filters()?;
91 }
92
93 if self.at(TokenKind::OrderBy) {
95 self.advance();
96 if self.at(TokenKind::By) {
98 self.advance();
99 }
100 let field = self.parse_field()?;
101 let descending = false; order_by = Some(OrderBy { field, descending });
103 }
104
105 if self.at(TokenKind::Limit) {
107 self.advance();
108 let tok = self.advance();
109 let n: usize = tok
110 .lexeme
111 .parse()
112 .map_err(|_| MenteError::Query(format!("invalid limit value: {}", tok.lexeme)))?;
113 limit = Some(n);
114 }
115
116 Ok(Statement::Recall(RecallStatement {
117 filters,
118 near,
119 limit,
120 order_by,
121 }))
122 }
123
124 fn parse_relate(&mut self) -> MenteResult<Statement> {
125 self.advance(); let source = self.parse_uuid()?;
128 self.expect(TokenKind::Arrow)?;
129 let target = self.parse_uuid()?;
130 self.expect(TokenKind::As)?;
131 let edge_type = self.parse_edge_type()?;
132
133 let mut weight = None;
134 if self.at(TokenKind::With) {
135 self.advance();
136 self.expect(TokenKind::Identifier)?; self.expect(TokenKind::Eq)?;
139 let tok = self.advance();
140 let w: f32 = tok
141 .lexeme
142 .parse()
143 .map_err(|_| MenteError::Query(format!("invalid weight value: {}", tok.lexeme)))?;
144 weight = Some(w);
145 }
146
147 Ok(Statement::Relate(RelateStatement {
148 source,
149 target,
150 edge_type,
151 weight,
152 }))
153 }
154
155 fn parse_forget(&mut self) -> MenteResult<Statement> {
156 self.advance(); let target = self.parse_uuid()?;
158 Ok(Statement::Forget(ForgetStatement { target }))
159 }
160
161 fn parse_consolidate(&mut self) -> MenteResult<Statement> {
162 self.advance(); let mut filters = Vec::new();
164 if self.at(TokenKind::Where) {
165 self.advance();
166 filters = self.parse_filters()?;
167 }
168 Ok(Statement::Consolidate(ConsolidateStatement { filters }))
169 }
170
171 fn parse_traverse(&mut self) -> MenteResult<Statement> {
172 self.advance(); let start = self.parse_uuid()?;
174
175 self.expect(TokenKind::Depth)?;
176 let tok = self.advance();
177 let depth: usize = tok
178 .lexeme
179 .parse()
180 .map_err(|_| MenteError::Query(format!("invalid depth value: {}", tok.lexeme)))?;
181
182 let mut edge_filter = None;
183 if self.at(TokenKind::Where) {
184 self.advance();
185 self.expect(TokenKind::EdgeType)?;
187 self.expect(TokenKind::Eq)?;
188 let et = self.parse_edge_type()?;
189 edge_filter = Some(vec![et]);
190 }
191
192 Ok(Statement::Traverse(TraverseStatement {
193 start,
194 depth,
195 edge_filter,
196 }))
197 }
198
199 fn parse_filters(&mut self) -> MenteResult<Vec<Filter>> {
200 let mut filters = vec![self.parse_filter()?];
201 while self.at(TokenKind::And) {
202 self.advance();
203 filters.push(self.parse_filter()?);
204 }
205 Ok(filters)
206 }
207
208 fn parse_filter(&mut self) -> MenteResult<Filter> {
209 let field = self.parse_field()?;
210 let op = self.parse_operator()?;
211 let value = self.parse_value(&field)?;
212 Ok(Filter { field, op, value })
213 }
214
215 fn parse_field(&mut self) -> MenteResult<Field> {
216 let tok = self.advance();
217 match tok.kind {
218 TokenKind::Identifier if tok.lexeme.eq_ignore_ascii_case("content") => {
219 Ok(Field::Content)
220 }
221 TokenKind::Type => Ok(Field::Type),
222 TokenKind::Tag => Ok(Field::Tag),
223 TokenKind::Agent => Ok(Field::Agent),
224 TokenKind::Space => Ok(Field::Space),
225 TokenKind::Salience => Ok(Field::Salience),
226 TokenKind::Confidence => Ok(Field::Confidence),
227 TokenKind::Created => Ok(Field::Created),
228 TokenKind::Accessed => Ok(Field::Accessed),
229 _ => Err(MenteError::Query(format!(
230 "expected field name, found '{}' at position {}",
231 tok.lexeme, tok.position
232 ))),
233 }
234 }
235
236 fn parse_operator(&mut self) -> MenteResult<Operator> {
237 let tok = self.advance();
238 match tok.kind {
239 TokenKind::Eq => Ok(Operator::Eq),
240 TokenKind::Neq => Ok(Operator::Neq),
241 TokenKind::Gt => Ok(Operator::Gt),
242 TokenKind::Lt => Ok(Operator::Lt),
243 TokenKind::Gte => Ok(Operator::Gte),
244 TokenKind::Lte => Ok(Operator::Lte),
245 TokenKind::SimilarTo => Ok(Operator::SimilarTo),
246 _ => Err(MenteError::Query(format!(
247 "expected operator, found '{}' at position {}",
248 tok.lexeme, tok.position
249 ))),
250 }
251 }
252
253 fn parse_value(&mut self, field: &Field) -> MenteResult<Value> {
254 if *field == Field::Type {
256 return self.parse_memory_type_value();
257 }
258
259 let tok = self.advance();
260 match tok.kind {
261 TokenKind::StringLit => {
262 let inner = tok.lexeme[1..tok.lexeme.len() - 1].to_string();
264 if let Ok(uuid) = inner.parse::<Uuid>() {
266 return Ok(Value::Uuid(uuid));
267 }
268 Ok(Value::Text(inner))
269 }
270 TokenKind::IntegerLit => {
271 let n: i64 = tok
272 .lexeme
273 .parse()
274 .map_err(|_| MenteError::Query(format!("invalid integer: {}", tok.lexeme)))?;
275 Ok(Value::Integer(n))
276 }
277 TokenKind::FloatLit => {
278 let n: f64 = tok
279 .lexeme
280 .parse()
281 .map_err(|_| MenteError::Query(format!("invalid float: {}", tok.lexeme)))?;
282 Ok(Value::Number(n))
283 }
284 TokenKind::UuidLit => {
285 let uuid: Uuid = tok
286 .lexeme
287 .parse()
288 .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme)))?;
289 Ok(Value::Uuid(uuid))
290 }
291 TokenKind::Identifier => {
292 let lower = tok.lexeme.to_lowercase();
293 match lower.as_str() {
294 "true" => Ok(Value::Bool(true)),
295 "false" => Ok(Value::Bool(false)),
296 _ => Ok(Value::Text(tok.lexeme.clone())),
297 }
298 }
299 TokenKind::LBracket => {
300 self.pos -= 1;
302 let v = self.parse_vector()?;
303 Ok(Value::Vector(v))
304 }
305 _ => Err(MenteError::Query(format!(
306 "expected value, found '{}' at position {}",
307 tok.lexeme, tok.position
308 ))),
309 }
310 }
311
312 fn parse_memory_type_value(&mut self) -> MenteResult<Value> {
313 let tok = self.advance();
314 let name = match tok.kind {
315 TokenKind::Identifier | TokenKind::StringLit => {
316 if tok.kind == TokenKind::StringLit {
317 tok.lexeme[1..tok.lexeme.len() - 1].to_string()
318 } else {
319 tok.lexeme.clone()
320 }
321 }
322 _ => {
323 return Err(MenteError::Query(format!(
324 "expected memory type, found '{}' at position {}",
325 tok.lexeme, tok.position
326 )));
327 }
328 };
329
330 let mt = match name.to_lowercase().as_str() {
331 "episodic" => MemoryType::Episodic,
332 "semantic" => MemoryType::Semantic,
333 "procedural" => MemoryType::Procedural,
334 "antipattern" | "anti_pattern" => MemoryType::AntiPattern,
335 "reasoning" => MemoryType::Reasoning,
336 "correction" => MemoryType::Correction,
337 _ => {
338 return Err(MenteError::Query(format!("unknown memory type: {}", name)));
339 }
340 };
341 Ok(Value::MemoryType(mt))
342 }
343
344 fn parse_edge_type(&mut self) -> MenteResult<EdgeType> {
345 let tok = self.advance();
346 let name = match tok.kind {
347 TokenKind::Identifier | TokenKind::StringLit => {
348 if tok.kind == TokenKind::StringLit {
349 tok.lexeme[1..tok.lexeme.len() - 1].to_string()
350 } else {
351 tok.lexeme.clone()
352 }
353 }
354 _ => {
355 return Err(MenteError::Query(format!(
356 "expected edge type, found '{}' at position {}",
357 tok.lexeme, tok.position
358 )));
359 }
360 };
361
362 match name.to_lowercase().as_str() {
363 "caused" => Ok(EdgeType::Caused),
364 "before" => Ok(EdgeType::Before),
365 "related" => Ok(EdgeType::Related),
366 "contradicts" => Ok(EdgeType::Contradicts),
367 "supports" => Ok(EdgeType::Supports),
368 "supersedes" => Ok(EdgeType::Supersedes),
369 "derived" => Ok(EdgeType::Derived),
370 "partof" | "part_of" => Ok(EdgeType::PartOf),
371 _ => Err(MenteError::Query(format!("unknown edge type: {}", name))),
372 }
373 }
374
375 fn parse_uuid(&mut self) -> MenteResult<Uuid> {
376 let tok = self.advance();
377 match tok.kind {
378 TokenKind::UuidLit => tok
379 .lexeme
380 .parse()
381 .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme))),
382 TokenKind::StringLit => {
383 let inner = &tok.lexeme[1..tok.lexeme.len() - 1];
384 inner.parse().map_err(|_| {
385 MenteError::Query(format!("invalid UUID in string: {}", tok.lexeme))
386 })
387 }
388 _ => Err(MenteError::Query(format!(
389 "expected UUID, found '{}' at position {}",
390 tok.lexeme, tok.position
391 ))),
392 }
393 }
394
395 fn parse_vector(&mut self) -> MenteResult<Vec<f32>> {
396 self.expect(TokenKind::LBracket)?;
397 let mut values = Vec::new();
398 if !self.at(TokenKind::RBracket) {
399 let tok = self.advance();
400 let v: f32 = tok.lexeme.parse().map_err(|_| {
401 MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
402 })?;
403 values.push(v);
404 while self.at(TokenKind::Comma) {
405 self.advance();
406 let tok = self.advance();
407 let v: f32 = tok.lexeme.parse().map_err(|_| {
408 MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
409 })?;
410 values.push(v);
411 }
412 }
413 self.expect(TokenKind::RBracket)?;
414 Ok(values)
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;

    /// Tokenizes and parses `query`, panicking if either stage fails.
    fn parse_query(query: &str) -> Statement {
        let tokens = tokenize(query).unwrap();
        Parser::parse(&tokens).unwrap()
    }

    #[test]
    fn test_parse_recall_with_type_filter() {
        let recall = match parse_query("RECALL memories WHERE type = episodic LIMIT 5") {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.filters.len(), 1);
        assert_eq!(recall.filters[0].field, Field::Type);
        assert_eq!(
            recall.filters[0].value,
            Value::MemoryType(MemoryType::Episodic)
        );
        assert_eq!(recall.limit, Some(5));
    }

    #[test]
    fn test_parse_recall_similar_to() {
        let query = r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#;
        let recall = match parse_query(query) {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.filters.len(), 1);
        assert_eq!(recall.filters[0].op, Operator::SimilarTo);
        assert_eq!(recall.limit, Some(10));
    }

    #[test]
    fn test_parse_recall_near() {
        let recall = match parse_query("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 10") {
            Statement::Recall(r) => r,
            _ => panic!("expected Recall"),
        };
        assert_eq!(recall.near, Some(vec![0.1, 0.2, 0.3]));
        assert_eq!(recall.limit, Some(10));
    }

    #[test]
    fn test_parse_relate() {
        let query = "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.9";
        let relate = match parse_query(query) {
            Statement::Relate(r) => r,
            _ => panic!("expected Relate"),
        };
        assert_eq!(relate.edge_type, EdgeType::Caused);
        assert_eq!(relate.weight, Some(0.9));
    }

    #[test]
    fn test_parse_forget() {
        let forget = match parse_query("FORGET 550e8400-e29b-41d4-a716-446655440000") {
            Statement::Forget(f) => f,
            _ => panic!("expected Forget"),
        };
        let expected = "550e8400-e29b-41d4-a716-446655440000"
            .parse::<Uuid>()
            .unwrap();
        assert_eq!(forget.target, expected);
    }

    #[test]
    fn test_parse_consolidate() {
        let query = r#"CONSOLIDATE WHERE type = episodic AND accessed < "2024-01-01""#;
        let consolidate = match parse_query(query) {
            Statement::Consolidate(c) => c,
            _ => panic!("expected Consolidate"),
        };
        assert_eq!(consolidate.filters.len(), 2);
    }

    #[test]
    fn test_parse_traverse() {
        let query =
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused";
        let traverse = match parse_query(query) {
            Statement::Traverse(t) => t,
            _ => panic!("expected Traverse"),
        };
        assert_eq!(traverse.depth, 3);
        assert_eq!(traverse.edge_filter, Some(vec![EdgeType::Caused]));
    }
}