1use std::collections::HashSet;
12
13use crate::error::{Error, Result};
14use crate::triple::Triple;
15use crate::{Database, QueryCriteria};
16
/// Lexical tokens of the pattern-query language.
#[derive(Debug, PartialEq, Clone)]
pub enum Token {
    /// The `MATCH` keyword.
    Match,
    /// The `RETURN` keyword.
    Return,
    /// The `WHERE` keyword.
    Where,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `[`
    LBracket,
    /// `]`
    RBracket,
    /// `->`
    Arrow,
    /// A `-` that is not the start of `->`.
    Dash,
    /// `=`
    Equal,
    /// `*`
    Star,
    /// `..`
    Range,
    /// An unsigned decimal integer.
    Number(u64),
    /// A bare word that is not a keyword.
    Identifier(String),
    /// The text between a pair of single quotes (quotes excluded).
    StringLiteral(String),
    /// `:`
    Colon,
    /// `,`
    Comma,
}
37
/// One position (subject, predicate or object) of a parsed match pattern.
#[derive(Debug, Clone)]
pub enum QueryPart {
    /// A named variable, e.g. the `a` in `(a)`.
    Variable(String),
    /// A fixed value written after a colon, e.g. the `KNOWS` in `[:KNOWS]`.
    Literal(String),
    /// An empty position such as `()` or `[]`.
    Anonymous,
}
44
/// Inclusive hop-count bounds of a relationship pattern.
#[derive(Debug, Clone, Copy)]
pub struct PathLength {
    /// Smallest number of hops that counts as a match.
    pub min: usize,
    /// Largest number of hops that counts as a match.
    pub max: usize,
}

impl PathLength {
    /// Bounds describing exactly one hop.
    pub fn single() -> Self {
        PathLength { min: 1, max: 1 }
    }

    /// Reports whether the bounds describe exactly one hop.
    pub fn is_single(&self) -> bool {
        matches!(*self, PathLength { min: 1, max: 1 })
    }
}

/// The default path length is a single hop.
impl Default for PathLength {
    fn default() -> Self {
        PathLength::single()
    }
}
66
67#[derive(Debug)]
68pub struct ParsedQuery {
69 pub subject: QueryPart,
70 pub predicate: QueryPart,
71 pub object: QueryPart,
72 pub where_clause: Option<(String, String)>, pub return_var: String,
74 pub path_len: PathLength,
75}
76
/// On-demand scanner over a borrowed query string.
pub struct Lexer<'a> {
    /// The complete query text.
    input: &'a str,
    /// Byte offset into `input` of the next unread character.
    pos: usize,
}
81
82impl<'a> Lexer<'a> {
83 pub fn new(input: &'a str) -> Self {
84 Self { input, pos: 0 }
85 }
86
87 pub fn next_token(&mut self) -> Option<Result<Token>> {
88 self.skip_whitespace();
89
90 if self.pos >= self.input.len() {
91 return None;
92 }
93
94 let c = self.input[self.pos..].chars().next()?;
95
96 match c {
97 '(' => {
98 self.pos += 1;
99 Some(Ok(Token::LParen))
100 }
101 ')' => {
102 self.pos += 1;
103 Some(Ok(Token::RParen))
104 }
105 '[' => {
106 self.pos += 1;
107 Some(Ok(Token::LBracket))
108 }
109 ']' => {
110 self.pos += 1;
111 Some(Ok(Token::RBracket))
112 }
113 ':' => {
114 self.pos += 1;
115 Some(Ok(Token::Colon))
116 }
117 ',' => {
118 self.pos += 1;
119 Some(Ok(Token::Comma))
120 }
121 '=' => {
122 self.pos += 1;
123 Some(Ok(Token::Equal))
124 }
125 '*' => {
126 self.pos += 1;
127 Some(Ok(Token::Star))
128 }
129 '.' => {
130 if self.input[self.pos..].starts_with("..") {
131 self.pos += 2;
132 Some(Ok(Token::Range))
133 } else {
134 self.pos += 1;
135 Some(Err(Error::Other("Unexpected '.'".to_string())))
136 }
137 }
138 '-' => {
139 if self.input[self.pos..].starts_with("->") {
140 self.pos += 2;
141 Some(Ok(Token::Arrow))
142 } else {
143 self.pos += 1;
144 Some(Ok(Token::Dash))
145 }
146 }
147 '\'' => self.read_string_literal(),
148 _ if c.is_ascii_digit() => self.read_number(),
149 _ if c.is_alphabetic() || c == '_' => self.read_identifier(),
150 _ => {
151 self.pos += 1;
152 Some(Err(Error::Other(format!("Unexpected character: {}", c))))
153 }
154 }
155 }
156
157 fn skip_whitespace(&mut self) {
158 while let Some(c) = self.input[self.pos..].chars().next() {
159 if !c.is_whitespace() {
160 break;
161 }
162 self.pos += c.len_utf8();
163 }
164 }
165
166 fn read_string_literal(&mut self) -> Option<Result<Token>> {
167 self.pos += 1; let start = self.pos;
169 while let Some(c) = self.input[self.pos..].chars().next() {
170 if c == '\'' {
171 let s = &self.input[start..self.pos];
172 self.pos += 1; return Some(Ok(Token::StringLiteral(s.to_string())));
174 }
175 self.pos += c.len_utf8();
176 }
177 Some(Err(Error::Other("Unterminated string literal".to_string())))
178 }
179
180 fn read_identifier(&mut self) -> Option<Result<Token>> {
181 let start = self.pos;
182 while let Some(c) = self.input[self.pos..].chars().next() {
183 if !c.is_alphanumeric() && c != '_' {
184 break;
185 }
186 self.pos += c.len_utf8();
187 }
188 let s = &self.input[start..self.pos];
189 match s.to_uppercase().as_str() {
190 "MATCH" => Some(Ok(Token::Match)),
191 "RETURN" => Some(Ok(Token::Return)),
192 "WHERE" => Some(Ok(Token::Where)),
193 _ => Some(Ok(Token::Identifier(s.to_string()))),
194 }
195 }
196
197 fn read_number(&mut self) -> Option<Result<Token>> {
198 let start = self.pos;
199 while let Some(c) = self.input[self.pos..].chars().next() {
200 if !c.is_ascii_digit() {
201 break;
202 }
203 self.pos += c.len_utf8();
204 }
205 let s = &self.input[start..self.pos];
206 match s.parse::<u64>() {
207 Ok(v) => Some(Ok(Token::Number(v))),
208 Err(_) => Some(Err(Error::Other(format!("Invalid number: {}", s)))),
209 }
210 }
211}
212
213pub struct Parser<'a> {
214 lexer: Lexer<'a>,
215 current_token: Option<Token>,
216}
217
218impl<'a> Parser<'a> {
219 pub fn new(input: &'a str) -> Self {
220 let mut lexer = Lexer::new(input);
221 let current_token = lexer.next_token().and_then(|r| r.ok());
222 Self {
223 lexer,
224 current_token,
225 }
226 }
227
228 fn advance(&mut self) -> Result<()> {
229 match self.lexer.next_token() {
230 Some(Ok(t)) => self.current_token = Some(t),
231 Some(Err(e)) => return Err(e),
232 None => self.current_token = None,
233 }
234 Ok(())
235 }
236
237 pub fn parse(&mut self) -> Result<ParsedQuery> {
239 if self.current_token != Some(Token::Match) {
241 return Err(Error::Other("Expected MATCH".to_string()));
242 }
243 self.advance()?;
244
245 let subject = self.parse_node()?;
247
248 if self.current_token != Some(Token::Dash) {
250 return Err(Error::Other("Expected -".to_string()));
251 }
252 self.advance()?;
253
254 let (predicate, path_len) = self.parse_rel()?;
256
257 if self.current_token != Some(Token::Arrow) {
259 return Err(Error::Other("Expected ->".to_string()));
260 }
261 self.advance()?;
262
263 let object = self.parse_node()?;
265
266 let mut where_clause: Option<(String, String)> = None;
268
269 if self.current_token == Some(Token::Where) {
270 self.advance()?;
271
272 let var_name = if let Some(Token::Identifier(var)) = &self.current_token {
273 var.clone()
274 } else {
275 return Err(Error::Other("Expected variable in WHERE".to_string()));
276 };
277 self.advance()?;
278
279 if self.current_token != Some(Token::Equal) {
280 return Err(Error::Other("Expected = in WHERE".to_string()));
281 }
282 self.advance()?;
283
284 let val = if let Some(Token::StringLiteral(val)) = &self.current_token {
285 val.clone()
286 } else {
287 return Err(Error::Other("Expected string literal in WHERE".to_string()));
288 };
289 self.advance()?;
290
291 where_clause = Some((var_name, val));
292 }
293
294 if self.current_token != Some(Token::Return) {
296 return Err(Error::Other("Expected RETURN".to_string()));
297 }
298 self.advance()?;
299
300 let return_var = if let Some(Token::Identifier(var)) = &self.current_token {
301 var.clone()
302 } else {
303 return Err(Error::Other("Expected return variable".to_string()));
304 };
305
306 Ok(ParsedQuery {
307 subject,
308 predicate,
309 object,
310 where_clause,
311 return_var,
312 path_len,
313 })
314 }
315
316 fn parse_node(&mut self) -> Result<QueryPart> {
317 if self.current_token != Some(Token::LParen) {
318 return Err(Error::Other("Expected (".to_string()));
319 }
320 self.advance()?;
321
322 let part = match &self.current_token {
327 Some(Token::Colon) => {
328 self.advance()?;
329 if let Some(Token::Identifier(s)) = &self.current_token {
330 let val = s.clone();
331 self.advance()?;
332 QueryPart::Literal(val)
333 } else {
334 return Err(Error::Other("Expected identifier after :".to_string()));
335 }
336 }
337 Some(Token::Identifier(s)) => {
338 let val = s.clone();
339 self.advance()?;
340 if self.current_token == Some(Token::Colon) {
344 self.advance()?;
345 if let Some(Token::Identifier(_)) = &self.current_token {
347 self.advance()?;
348 }
349 }
350 QueryPart::Variable(val)
351 }
352 _ => QueryPart::Anonymous,
353 };
354
355 if self.current_token != Some(Token::RParen) {
356 return Err(Error::Other("Expected )".to_string()));
357 }
358 self.advance()?;
359 Ok(part)
360 }
361
362 fn parse_rel(&mut self) -> Result<(QueryPart, PathLength)> {
363 if self.current_token != Some(Token::LBracket) {
364 return Err(Error::Other("Expected [".to_string()));
365 }
366 self.advance()?;
367
368 let part = match &self.current_token {
373 Some(Token::Colon) => {
374 self.advance()?;
375 if let Some(Token::Identifier(s)) = &self.current_token {
376 let val = s.clone();
377 self.advance()?;
378 QueryPart::Literal(val)
379 } else {
380 return Err(Error::Other("Expected identifier after :".to_string()));
381 }
382 }
383 Some(Token::Identifier(s)) => {
384 let val = s.clone();
385 self.advance()?;
386 if self.current_token == Some(Token::Colon) {
387 self.advance()?;
388 if let Some(Token::Identifier(_)) = &self.current_token {
389 self.advance()?;
390 }
391 }
392 QueryPart::Variable(val)
393 }
394 _ => QueryPart::Anonymous,
395 };
396
397 let mut path_len = PathLength::single();
398 if self.current_token == Some(Token::Star) {
399 self.advance()?;
400 path_len = self.parse_path_length()?;
401 }
402
403 if self.current_token != Some(Token::RBracket) {
404 return Err(Error::Other("Expected ]".to_string()));
405 }
406 self.advance()?;
407 Ok((part, path_len))
408 }
409
410 fn parse_path_length(&mut self) -> Result<PathLength> {
411 let start = self.parse_number("Expected hop length after *")?;
412 if self.current_token == Some(Token::Range) {
413 self.advance()?;
414 let end = self.parse_number("Expected hop range upper bound")?;
415 if end < start {
416 return Err(Error::Other(
417 "Invalid hop range: upper bound < lower".to_string(),
418 ));
419 }
420 Ok(PathLength {
421 min: start,
422 max: end,
423 })
424 } else {
425 Ok(PathLength {
426 min: start,
427 max: start,
428 })
429 }
430 }
431
432 fn parse_number(&mut self, err: &str) -> Result<usize> {
433 if let Some(Token::Number(n)) = &self.current_token {
434 let value =
435 usize::try_from(*n).map_err(|_| Error::Other("Number too large".to_string()))?;
436 self.advance()?;
437 Ok(value)
438 } else {
439 Err(Error::Other(err.to_string()))
440 }
441 }
442}
443
444pub fn execute(db: &Database, query: &str) -> Result<Vec<Triple>> {
445 let mut parser = Parser::new(query);
446 let parsed = parser.parse()?;
447
448 let mut s_criteria = None;
449 let mut p_criteria = None;
450 let mut o_criteria = None;
451
452 match &parsed.subject {
458 QueryPart::Literal(val) => {
459 s_criteria = db.resolve_id(val)?;
460 if s_criteria.is_none() {
461 return Ok(vec![]);
462 } }
464 QueryPart::Variable(name) => {
465 if let Some((w_var, w_val)) = &parsed.where_clause
466 && w_var == name
467 {
468 s_criteria = db.resolve_id(w_val)?;
469 if s_criteria.is_none() {
470 return Ok(vec![]);
471 } }
473 }
474 QueryPart::Anonymous => {}
475 }
476
477 match &parsed.predicate {
479 QueryPart::Literal(val) => {
480 p_criteria = db.resolve_id(val)?;
481 if p_criteria.is_none() {
482 return Ok(vec![]);
483 }
484 }
485 QueryPart::Variable(name) => {
486 if let Some((w_var, w_val)) = &parsed.where_clause
487 && w_var == name
488 {
489 p_criteria = db.resolve_id(w_val)?;
490 if p_criteria.is_none() {
491 return Ok(vec![]);
492 }
493 }
494 }
495 QueryPart::Anonymous => {}
496 }
497
498 match &parsed.object {
500 QueryPart::Literal(val) => {
501 o_criteria = db.resolve_id(val)?;
502 if o_criteria.is_none() {
503 return Ok(vec![]);
504 }
505 }
506 QueryPart::Variable(name) => {
507 if let Some((w_var, w_val)) = &parsed.where_clause
508 && w_var == name
509 {
510 o_criteria = db.resolve_id(w_val)?;
511 if o_criteria.is_none() {
512 return Ok(vec![]);
513 }
514 }
515 }
516 QueryPart::Anonymous => {}
517 }
518
519 let criteria = QueryCriteria {
520 subject_id: s_criteria,
521 predicate_id: p_criteria,
522 object_id: o_criteria,
523 };
524
525 if parsed.path_len.is_single() {
526 Ok(db.query(criteria).collect())
527 } else {
528 execute_variable_path(db, criteria, parsed.path_len)
529 }
530}
531
532fn execute_variable_path(
533 db: &Database,
534 criteria: QueryCriteria,
535 path_len: PathLength,
536) -> Result<Vec<Triple>> {
537 if path_len.min == 0 {
538 return Err(Error::Other("Hop length must be >= 1".to_string()));
539 }
540 let predicate_id = criteria.predicate_id.ok_or_else(|| {
541 Error::Other("Variable length paths require a predicate literal".to_string())
542 })?;
543
544 let mut frontier: Vec<u64> = if let Some(subject) = criteria.subject_id {
545 vec![subject]
546 } else {
547 db.query(QueryCriteria {
548 subject_id: None,
549 predicate_id: Some(predicate_id),
550 object_id: None,
551 })
552 .map(|t| t.subject_id)
553 .collect()
554 };
555 frontier.sort_unstable();
556 frontier.dedup();
557
558 let mut results = Vec::new();
559 let mut depth = 1;
560
561 while depth <= path_len.max && !frontier.is_empty() {
562 let mut next_frontier = Vec::new();
563 let mut seen_next = HashSet::new();
564
565 for subject in &frontier {
566 let triples = db.query(QueryCriteria {
567 subject_id: Some(*subject),
568 predicate_id: Some(predicate_id),
569 object_id: None,
570 });
571 for triple in triples {
572 let target = triple.object_id;
573 if depth >= path_len.min
574 && criteria.object_id.is_none_or(|expected| expected == target)
575 {
576 results.push(triple);
577 }
578 if depth < path_len.max && seen_next.insert(target) {
579 next_frontier.push(target);
580 }
581 }
582 }
583
584 frontier = next_frontier;
585 depth += 1;
586 }
587
588 Ok(results)
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Database, Fact, Options};
    use tempfile::tempdir;

    /// The lexer yields the expected token sequence for a simple pattern.
    #[test]
    fn test_lexer() {
        let input = "MATCH (a)-[:KNOWS]->(b)";
        let mut lexer = Lexer::new(input);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Match);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LParen);
        assert_eq!(
            lexer.next_token().unwrap().unwrap(),
            Token::Identifier("a".to_string())
        );
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RParen);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Dash);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LBracket);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Colon);
        assert_eq!(
            lexer.next_token().unwrap().unwrap(),
            Token::Identifier("KNOWS".to_string())
        );
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RBracket);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Arrow);
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LParen);
        assert_eq!(
            lexer.next_token().unwrap().unwrap(),
            Token::Identifier("b".to_string())
        );
        assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RParen);
    }

    /// All three positions parse as variables.
    #[test]
    fn test_parser_simple() {
        let input = "MATCH (s)-[p]->(o) RETURN o";
        let mut parser = Parser::new(input);
        let q = parser.parse().unwrap();

        // `matches!` only yields a bool; it has to be asserted to check anything.
        assert!(matches!(&q.subject, QueryPart::Variable(s) if s == "s"));
        assert!(matches!(&q.predicate, QueryPart::Variable(p) if p == "p"));
        assert!(matches!(&q.object, QueryPart::Variable(o) if o == "o"));
        assert!(q.where_clause.is_none());
        assert_eq!(q.return_var, "o");
        assert!(q.path_len.is_single());
    }

    /// Literals, a variable and a WHERE clause in one query.
    #[test]
    fn test_parser_mixed() {
        let input = "MATCH (:Person)-[:KNOWS]->(o) WHERE o = 'Alice' RETURN o";
        let mut parser = Parser::new(input);
        let q = parser.parse().unwrap();

        assert!(matches!(&q.subject, QueryPart::Literal(s) if s == "Person"));
        assert!(matches!(&q.predicate, QueryPart::Literal(p) if p == "KNOWS"));
        assert!(matches!(&q.object, QueryPart::Variable(o) if o == "o"));

        assert_eq!(
            q.where_clause.unwrap(),
            ("o".to_string(), "Alice".to_string())
        );
        assert!(q.path_len.is_single());
    }

    /// Empty parentheses and brackets parse as anonymous parts.
    #[test]
    fn test_parser_anonymous() {
        let input = "MATCH ()-[]->() RETURN x";
        let mut parser = Parser::new(input);
        let q = parser.parse().unwrap();

        assert!(matches!(&q.subject, QueryPart::Anonymous));
        assert!(matches!(&q.predicate, QueryPart::Anonymous));
        assert!(matches!(&q.object, QueryPart::Anonymous));
        assert!(q.path_len.is_single());
    }

    /// A `*min..max` suffix sets the hop bounds.
    #[test]
    fn test_parser_multi_hop() {
        let input = "MATCH (a)-[:KNOWS*1..5]->(b) RETURN b";
        let mut parser = Parser::new(input);
        let q = parser.parse().unwrap();
        assert_eq!(q.path_len.min, 1);
        assert_eq!(q.path_len.max, 5);
    }

    #[test]
    fn test_execute_query() {
        let tmp = tempdir().unwrap();
        let mut db = Database::open(Options::new(tmp.path())).unwrap();

        db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
        db.add_fact(Fact::new("Bob", "KNOWS", "Charlie")).unwrap();
        db.add_fact(Fact::new("Alice", "LIKES", "Coffee")).unwrap();

        // Only checks that the query executes without error.
        let _res = db
            .execute_query("MATCH (a)-[:KNOWS]->(b) WHERE a = 'Alice' RETURN b")
            .unwrap();

        let res = db
            .execute_query("MATCH (a)-[:LIKES]->(b) RETURN a")
            .unwrap();
        if res.is_empty() {
            assert_eq!(res.len(), 0);
        } else {
            let a_id = match res[0].get("a").expect("missing a") {
                crate::query::executor::Value::Node(id) => *id,
                _ => panic!("a should be node id"),
            };
            assert!(db.resolve_str(a_id).unwrap().is_some());
        }

        let res = db.execute_query("MATCH (:Bob)-[]->(b) RETURN b").unwrap();
        if res.is_empty() {
            assert_eq!(res.len(), 0);
        } else {
            let b_id = match res[0].get("b").expect("missing b") {
                crate::query::executor::Value::Node(id) => *id,
                _ => panic!("b should be node id"),
            };
            assert!(db.resolve_str(b_id).unwrap().is_some());
        }

        let res = db.execute_query("MATCH ()-[]->() RETURN x").unwrap();
        assert_eq!(res.len(), 3);

        let res = db
            .execute_query("MATCH (a)-[]->() WHERE a = 'Nobody' RETURN a")
            .unwrap();
        assert_eq!(res.len(), 0);
    }

    #[test]
    fn test_execute_optional_match() {
        let tmp = tempdir().unwrap();
        let mut db = Database::open(Options::new(tmp.path())).unwrap();

        db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
        db.add_fact(Fact::new("Charlie", "LIKES", "IceCream"))
            .unwrap();

        let res = db
            .execute_query(
                "MATCH (a)-[:KNOWS]->(b) OPTIONAL MATCH (b)-[:LIKES]->(c) RETURN a, b, c",
            )
            .unwrap();

        assert_eq!(res.len(), 1);
        assert!(matches!(
            res[0].get("c"),
            Some(crate::query::executor::Value::Null)
        ));
    }

    #[test]
    fn test_execute_multi_hop() {
        let tmp = tempdir().unwrap();
        let mut db = Database::open(Options::new(tmp.path())).unwrap();

        db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
        db.add_fact(Fact::new("Bob", "KNOWS", "Charlie")).unwrap();
        db.add_fact(Fact::new("Charlie", "KNOWS", "Dylan")).unwrap();
        db.add_fact(Fact::new("Alice", "type", "Person")).unwrap();

        let res = db
            .execute_query("MATCH (start:Person)-[:KNOWS*1..2]->(dst) RETURN dst")
            .unwrap();
        let mut ends: Vec<String> = res
            .iter()
            .filter_map(|row| match row.get("dst") {
                Some(crate::query::executor::Value::Node(id)) => db.resolve_str(*id).unwrap(),
                _ => None,
            })
            .collect();
        ends.sort();
        ends.dedup();
        assert_eq!(ends, vec!["Bob".to_string(), "Charlie".to_string()]);

        let err = db
            .execute_query("MATCH (a)-[p*1..2]->(b) RETURN b")
            .unwrap_err();
        assert!(matches!(err, crate::Error::NotImplemented(_)));
    }
}