use uuid::Uuid;

use mentedb_core::edge::EdgeType;
use mentedb_core::error::{MenteError, MenteResult};
use mentedb_core::memory::MemoryType;
use mentedb_core::types::MemoryId;

use crate::ast::*;
use crate::lexer::{Token, TokenKind};
11
12pub struct Parser<'a> {
13 tokens: &'a [Token],
14 pos: usize,
15}
16
17impl<'a> Parser<'a> {
18 pub fn new(tokens: &'a [Token]) -> Self {
19 Self { tokens, pos: 0 }
20 }
21
22 pub fn parse(tokens: &[Token]) -> MenteResult<Statement> {
23 let mut parser = Parser::new(tokens);
24 parser.parse_statement()
25 }
26
27 fn peek(&self) -> &Token {
28 &self.tokens[self.pos.min(self.tokens.len() - 1)]
29 }
30
31 fn advance(&mut self) -> &Token {
32 let tok = &self.tokens[self.pos.min(self.tokens.len() - 1)];
33 if self.pos < self.tokens.len() {
34 self.pos += 1;
35 }
36 tok
37 }
38
39 fn expect(&mut self, kind: TokenKind) -> MenteResult<&Token> {
40 let tok = self.peek();
41 if tok.kind != kind {
42 return Err(MenteError::Query(format!(
43 "expected {:?}, found {:?} ('{}') at position {}",
44 kind, tok.kind, tok.lexeme, tok.position
45 )));
46 }
47 Ok(self.advance())
48 }
49
50 fn at(&self, kind: TokenKind) -> bool {
51 self.peek().kind == kind
52 }
53
54 fn parse_statement(&mut self) -> MenteResult<Statement> {
55 match self.peek().kind {
56 TokenKind::Recall => self.parse_recall(),
57 TokenKind::Relate => self.parse_relate(),
58 TokenKind::Forget => self.parse_forget(),
59 TokenKind::Consolidate => self.parse_consolidate(),
60 TokenKind::Traverse => self.parse_traverse(),
61 _ => Err(MenteError::Query(format!(
62 "expected statement keyword, found {:?} at position {}",
63 self.peek().kind,
64 self.peek().position
65 ))),
66 }
67 }
68
69 fn parse_recall(&mut self) -> MenteResult<Statement> {
70 self.advance(); if self.at(TokenKind::Memories) {
74 self.advance();
75 }
76
77 let mut near = None;
78 let mut filters = Vec::new();
79 let mut limit = None;
80 let mut order_by = None;
81
82 if self.at(TokenKind::Near) {
84 self.advance();
85 near = Some(self.parse_vector()?);
86 }
87
88 if self.at(TokenKind::Where) {
90 self.advance();
91 filters = self.parse_filters()?;
92 }
93
94 if self.at(TokenKind::OrderBy) {
96 self.advance();
97 if self.at(TokenKind::By) {
99 self.advance();
100 }
101 let field = self.parse_field()?;
102 let descending = false; order_by = Some(OrderBy { field, descending });
104 }
105
106 if self.at(TokenKind::Limit) {
108 self.advance();
109 let tok = self.advance();
110 let n: usize = tok
111 .lexeme
112 .parse()
113 .map_err(|_| MenteError::Query(format!("invalid limit value: {}", tok.lexeme)))?;
114 limit = Some(n);
115 }
116
117 Ok(Statement::Recall(RecallStatement {
118 filters,
119 near,
120 limit,
121 order_by,
122 }))
123 }
124
125 fn parse_relate(&mut self) -> MenteResult<Statement> {
126 self.advance(); let source = self.parse_uuid()?;
129 self.expect(TokenKind::Arrow)?;
130 let target = self.parse_uuid()?;
131 self.expect(TokenKind::As)?;
132 let edge_type = self.parse_edge_type()?;
133
134 let mut weight = None;
135 if self.at(TokenKind::With) {
136 self.advance();
137 self.expect(TokenKind::Identifier)?; self.expect(TokenKind::Eq)?;
140 let tok = self.advance();
141 let w: f32 = tok
142 .lexeme
143 .parse()
144 .map_err(|_| MenteError::Query(format!("invalid weight value: {}", tok.lexeme)))?;
145 weight = Some(w);
146 }
147
148 Ok(Statement::Relate(RelateStatement {
149 source,
150 target,
151 edge_type,
152 weight,
153 }))
154 }
155
156 fn parse_forget(&mut self) -> MenteResult<Statement> {
157 self.advance(); let target = self.parse_uuid()?;
159 Ok(Statement::Forget(ForgetStatement { target }))
160 }
161
162 fn parse_consolidate(&mut self) -> MenteResult<Statement> {
163 self.advance(); let mut filters = Vec::new();
165 if self.at(TokenKind::Where) {
166 self.advance();
167 filters = self.parse_filters()?;
168 }
169 Ok(Statement::Consolidate(ConsolidateStatement { filters }))
170 }
171
172 fn parse_traverse(&mut self) -> MenteResult<Statement> {
173 self.advance(); let start = self.parse_uuid()?;
175
176 self.expect(TokenKind::Depth)?;
177 let tok = self.advance();
178 let depth: usize = tok
179 .lexeme
180 .parse()
181 .map_err(|_| MenteError::Query(format!("invalid depth value: {}", tok.lexeme)))?;
182
183 let mut edge_filter = None;
184 if self.at(TokenKind::Where) {
185 self.advance();
186 self.expect(TokenKind::EdgeType)?;
188 self.expect(TokenKind::Eq)?;
189 let et = self.parse_edge_type()?;
190 edge_filter = Some(vec![et]);
191 }
192
193 Ok(Statement::Traverse(TraverseStatement {
194 start,
195 depth,
196 edge_filter,
197 }))
198 }
199
200 fn parse_filters(&mut self) -> MenteResult<Vec<Filter>> {
201 let mut filters = vec![self.parse_filter()?];
202 while self.at(TokenKind::And) {
203 self.advance();
204 filters.push(self.parse_filter()?);
205 }
206 Ok(filters)
207 }
208
209 fn parse_filter(&mut self) -> MenteResult<Filter> {
210 let field = self.parse_field()?;
211 let op = self.parse_operator()?;
212 let value = self.parse_value(&field)?;
213 Ok(Filter { field, op, value })
214 }
215
216 fn parse_field(&mut self) -> MenteResult<Field> {
217 let tok = self.advance();
218 match tok.kind {
219 TokenKind::Identifier if tok.lexeme.eq_ignore_ascii_case("content") => {
220 Ok(Field::Content)
221 }
222 TokenKind::Type => Ok(Field::Type),
223 TokenKind::Tag => Ok(Field::Tag),
224 TokenKind::Agent => Ok(Field::Agent),
225 TokenKind::Space => Ok(Field::Space),
226 TokenKind::Salience => Ok(Field::Salience),
227 TokenKind::Confidence => Ok(Field::Confidence),
228 TokenKind::Created => Ok(Field::Created),
229 TokenKind::Accessed => Ok(Field::Accessed),
230 _ => Err(MenteError::Query(format!(
231 "expected field name, found '{}' at position {}",
232 tok.lexeme, tok.position
233 ))),
234 }
235 }
236
237 fn parse_operator(&mut self) -> MenteResult<Operator> {
238 let tok = self.advance();
239 match tok.kind {
240 TokenKind::Eq => Ok(Operator::Eq),
241 TokenKind::Neq => Ok(Operator::Neq),
242 TokenKind::Gt => Ok(Operator::Gt),
243 TokenKind::Lt => Ok(Operator::Lt),
244 TokenKind::Gte => Ok(Operator::Gte),
245 TokenKind::Lte => Ok(Operator::Lte),
246 TokenKind::SimilarTo => Ok(Operator::SimilarTo),
247 _ => Err(MenteError::Query(format!(
248 "expected operator, found '{}' at position {}",
249 tok.lexeme, tok.position
250 ))),
251 }
252 }
253
254 fn parse_value(&mut self, field: &Field) -> MenteResult<Value> {
255 if *field == Field::Type {
257 return self.parse_memory_type_value();
258 }
259
260 let tok = self.advance();
261 match tok.kind {
262 TokenKind::StringLit => {
263 let inner = tok.lexeme[1..tok.lexeme.len() - 1].to_string();
265 if let Ok(uuid) = inner.parse::<MemoryId>() {
267 return Ok(Value::Uuid(uuid.into()));
268 }
269 Ok(Value::Text(inner))
270 }
271 TokenKind::IntegerLit => {
272 let n: i64 = tok
273 .lexeme
274 .parse()
275 .map_err(|_| MenteError::Query(format!("invalid integer: {}", tok.lexeme)))?;
276 Ok(Value::Integer(n))
277 }
278 TokenKind::FloatLit => {
279 let n: f64 = tok
280 .lexeme
281 .parse()
282 .map_err(|_| MenteError::Query(format!("invalid float: {}", tok.lexeme)))?;
283 Ok(Value::Number(n))
284 }
285 TokenKind::UuidLit => {
286 let uuid: Uuid = tok
287 .lexeme
288 .parse()
289 .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme)))?;
290 Ok(Value::Uuid(uuid))
291 }
292 TokenKind::Identifier => {
293 let lower = tok.lexeme.to_lowercase();
294 match lower.as_str() {
295 "true" => Ok(Value::Bool(true)),
296 "false" => Ok(Value::Bool(false)),
297 _ => Ok(Value::Text(tok.lexeme.clone())),
298 }
299 }
300 TokenKind::LBracket => {
301 self.pos -= 1;
303 let v = self.parse_vector()?;
304 Ok(Value::Vector(v))
305 }
306 _ => Err(MenteError::Query(format!(
307 "expected value, found '{}' at position {}",
308 tok.lexeme, tok.position
309 ))),
310 }
311 }
312
313 fn parse_memory_type_value(&mut self) -> MenteResult<Value> {
314 let tok = self.advance();
315 let name = match tok.kind {
316 TokenKind::Identifier | TokenKind::StringLit => {
317 if tok.kind == TokenKind::StringLit {
318 tok.lexeme[1..tok.lexeme.len() - 1].to_string()
319 } else {
320 tok.lexeme.clone()
321 }
322 }
323 _ => {
324 return Err(MenteError::Query(format!(
325 "expected memory type, found '{}' at position {}",
326 tok.lexeme, tok.position
327 )));
328 }
329 };
330
331 let mt = match name.to_lowercase().as_str() {
332 "episodic" => MemoryType::Episodic,
333 "semantic" => MemoryType::Semantic,
334 "procedural" => MemoryType::Procedural,
335 "antipattern" | "anti_pattern" => MemoryType::AntiPattern,
336 "reasoning" => MemoryType::Reasoning,
337 "correction" => MemoryType::Correction,
338 _ => {
339 return Err(MenteError::Query(format!("unknown memory type: {}", name)));
340 }
341 };
342 Ok(Value::MemoryType(mt))
343 }
344
345 fn parse_edge_type(&mut self) -> MenteResult<EdgeType> {
346 let tok = self.advance();
347 let name = match tok.kind {
348 TokenKind::Identifier | TokenKind::StringLit => {
349 if tok.kind == TokenKind::StringLit {
350 tok.lexeme[1..tok.lexeme.len() - 1].to_string()
351 } else {
352 tok.lexeme.clone()
353 }
354 }
355 _ => {
356 return Err(MenteError::Query(format!(
357 "expected edge type, found '{}' at position {}",
358 tok.lexeme, tok.position
359 )));
360 }
361 };
362
363 match name.to_lowercase().as_str() {
364 "caused" => Ok(EdgeType::Caused),
365 "before" => Ok(EdgeType::Before),
366 "related" => Ok(EdgeType::Related),
367 "contradicts" => Ok(EdgeType::Contradicts),
368 "supports" => Ok(EdgeType::Supports),
369 "supersedes" => Ok(EdgeType::Supersedes),
370 "derived" => Ok(EdgeType::Derived),
371 "partof" | "part_of" => Ok(EdgeType::PartOf),
372 _ => Err(MenteError::Query(format!("unknown edge type: {}", name))),
373 }
374 }
375
376 fn parse_uuid(&mut self) -> MenteResult<MemoryId> {
377 let tok = self.advance();
378 match tok.kind {
379 TokenKind::UuidLit => tok
380 .lexeme
381 .parse()
382 .map_err(|_| MenteError::Query(format!("invalid UUID: {}", tok.lexeme))),
383 TokenKind::StringLit => {
384 let inner = &tok.lexeme[1..tok.lexeme.len() - 1];
385 inner.parse().map_err(|_| {
386 MenteError::Query(format!("invalid UUID in string: {}", tok.lexeme))
387 })
388 }
389 _ => Err(MenteError::Query(format!(
390 "expected UUID, found '{}' at position {}",
391 tok.lexeme, tok.position
392 ))),
393 }
394 }
395
396 fn parse_vector(&mut self) -> MenteResult<Vec<f32>> {
397 self.expect(TokenKind::LBracket)?;
398 let mut values = Vec::new();
399 if !self.at(TokenKind::RBracket) {
400 let tok = self.advance();
401 let v: f32 = tok.lexeme.parse().map_err(|_| {
402 MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
403 })?;
404 values.push(v);
405 while self.at(TokenKind::Comma) {
406 self.advance();
407 let tok = self.advance();
408 let v: f32 = tok.lexeme.parse().map_err(|_| {
409 MenteError::Query(format!("invalid float in vector: {}", tok.lexeme))
410 })?;
411 values.push(v);
412 }
413 }
414 self.expect(TokenKind::RBracket)?;
415 Ok(values)
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::tokenize;

    /// Runs `input` through the lexer and parser, panicking on any error.
    fn parse_query(input: &str) -> Statement {
        let tokens = tokenize(input).unwrap();
        Parser::parse(&tokens).unwrap()
    }

    #[test]
    fn test_parse_recall_with_type_filter() {
        match parse_query("RECALL memories WHERE type = episodic LIMIT 5") {
            Statement::Recall(r) => {
                assert_eq!(r.filters.len(), 1);
                assert_eq!(r.filters[0].field, Field::Type);
                assert_eq!(r.filters[0].value, Value::MemoryType(MemoryType::Episodic));
                assert_eq!(r.limit, Some(5));
            }
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn test_parse_recall_similar_to() {
        match parse_query(r#"RECALL memories WHERE content ~> "database migration" LIMIT 10"#) {
            Statement::Recall(r) => {
                assert_eq!(r.filters.len(), 1);
                assert_eq!(r.filters[0].op, Operator::SimilarTo);
                assert_eq!(r.limit, Some(10));
            }
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn test_parse_recall_near() {
        match parse_query("RECALL memories NEAR [0.1, 0.2, 0.3] LIMIT 10") {
            Statement::Recall(r) => {
                assert_eq!(r.near, Some(vec![0.1, 0.2, 0.3]));
                assert_eq!(r.limit, Some(10));
            }
            _ => panic!("expected Recall"),
        }
    }

    #[test]
    fn test_parse_relate() {
        let stmt = parse_query(
            "RELATE 550e8400-e29b-41d4-a716-446655440000 -> 660e8400-e29b-41d4-a716-446655440000 AS caused WITH weight = 0.9",
        );
        match stmt {
            Statement::Relate(r) => {
                assert_eq!(r.edge_type, EdgeType::Caused);
                assert_eq!(r.weight, Some(0.9));
            }
            _ => panic!("expected Relate"),
        }
    }

    #[test]
    fn test_parse_forget() {
        match parse_query("FORGET 550e8400-e29b-41d4-a716-446655440000") {
            Statement::Forget(f) => {
                let expected: MemoryId = "550e8400-e29b-41d4-a716-446655440000".parse().unwrap();
                assert_eq!(f.target, expected);
            }
            _ => panic!("expected Forget"),
        }
    }

    #[test]
    fn test_parse_consolidate() {
        match parse_query(r#"CONSOLIDATE WHERE type = episodic AND accessed < "2024-01-01""#) {
            Statement::Consolidate(c) => {
                assert_eq!(c.filters.len(), 2);
            }
            _ => panic!("expected Consolidate"),
        }
    }

    #[test]
    fn test_parse_traverse() {
        match parse_query(
            "TRAVERSE 550e8400-e29b-41d4-a716-446655440000 DEPTH 3 WHERE edge_type = caused",
        ) {
            Statement::Traverse(t) => {
                assert_eq!(t.depth, 3);
                assert_eq!(t.edge_filter, Some(vec![EdgeType::Caused]));
            }
            _ => panic!("expected Traverse"),
        }
    }
}