1use crate::term::{Clause, StringInterner, Term, VarId};
2use crate::tokenizer::{Token, TokenKind, Tokenizer};
3use fnv::FnvHashMap;
4
5pub struct Parser<'a> {
8 tokens: Vec<Token>,
9 pos: usize,
10 interner: &'a mut StringInterner,
11 var_map: FnvHashMap<String, VarId>,
12 next_var: VarId,
13}
14
15impl<'a> Parser<'a> {
16 pub fn parse_program(
18 input: &str,
19 interner: &mut StringInterner,
20 ) -> Result<Vec<Clause>, String> {
21 let tokens = Tokenizer::tokenize(input)?;
22 let mut parser = Parser {
23 tokens,
24 pos: 0,
25 interner,
26 var_map: FnvHashMap::default(),
27 next_var: 0,
28 };
29 let mut clauses = Vec::new();
30 while !parser.at_eof() {
31 parser.reset_vars();
32 let clause = parser.parse_clause()?;
33 clauses.push(clause);
34 }
35 Ok(clauses)
36 }
37
38 pub fn parse_query(input: &str, interner: &mut StringInterner) -> Result<Vec<Term>, String> {
41 let tokens = Tokenizer::tokenize(input)?;
42 let mut parser = Parser {
43 tokens,
44 pos: 0,
45 interner,
46 var_map: FnvHashMap::default(),
47 next_var: 0,
48 };
49 if parser.current_kind() == Some(&TokenKind::QueryOp) {
51 parser.advance();
52 }
53 let goals = parser.parse_goal_list()?;
54 if parser.current_kind() == Some(&TokenKind::Dot) {
56 parser.advance();
57 }
58 Ok(goals)
59 }
60
61 fn reset_vars(&mut self) {
62 self.var_map.clear();
63 self.next_var = 0;
64 }
65
66 fn current(&self) -> Option<&Token> {
67 self.tokens.get(self.pos)
68 }
69
70 fn current_kind(&self) -> Option<&TokenKind> {
71 self.current().map(|t| &t.kind)
72 }
73
74 fn at_eof(&self) -> bool {
75 matches!(self.current_kind(), None | Some(TokenKind::Eof))
76 }
77
78 fn advance(&mut self) -> &Token {
79 let tok = &self.tokens[self.pos];
80 self.pos += 1;
81 tok
82 }
83
84 fn expect(&mut self, kind: &TokenKind) -> Result<(), String> {
85 match self.current() {
86 Some(tok) if &tok.kind == kind => {
87 self.advance();
88 Ok(())
89 }
90 Some(tok) => Err(format!(
91 "Expected {:?}, got {:?} at line {} col {}",
92 kind, tok.kind, tok.line, tok.col
93 )),
94 None => Err(format!("Expected {:?}, got end of input", kind)),
95 }
96 }
97
98 fn parse_clause(&mut self) -> Result<Clause, String> {
99 let head = self.parse_term()?;
100 match self.current_kind() {
101 Some(TokenKind::Dot) => {
102 self.advance();
103 Ok(Clause { head, body: vec![] })
104 }
105 Some(TokenKind::Neck) => {
106 self.advance();
107 let body = self.parse_goal_list()?;
108 self.expect(&TokenKind::Dot)?;
109 Ok(Clause { head, body })
110 }
111 Some(tok) => {
112 let tok = tok.clone();
113 Err(format!(
114 "Expected '.' or ':-', got {:?} at line {} col {}",
115 tok,
116 self.current().unwrap().line,
117 self.current().unwrap().col
118 ))
119 }
120 None => Err("Unexpected end of input in clause".to_string()),
121 }
122 }
123
124 fn parse_goal_list(&mut self) -> Result<Vec<Term>, String> {
125 let body = self.parse_goal_disjunction()?;
128 Ok(vec![body])
129 }
130
131 fn parse_goal_disjunction(&mut self) -> Result<Term, String> {
133 let left = self.parse_goal_conjunction()?;
134 if self.current_kind() == Some(&TokenKind::Semicolon) {
135 self.advance();
136 let right = self.parse_goal_disjunction()?;
137 let functor = self.interner.intern(";");
138 Ok(Term::Compound {
139 functor,
140 args: vec![left, right],
141 })
142 } else {
143 Ok(left)
144 }
145 }
146
147 fn parse_goal_conjunction(&mut self) -> Result<Term, String> {
149 let first = self.parse_term()?;
150 if self.current_kind() == Some(&TokenKind::Comma) {
151 let mut goals = vec![first];
152 while self.current_kind() == Some(&TokenKind::Comma) {
153 self.advance();
154 goals.push(self.parse_term()?);
155 }
156 let comma = self.interner.intern(",");
158 let mut result = goals.pop().unwrap();
159 while let Some(g) = goals.pop() {
160 result = Term::Compound {
161 functor: comma,
162 args: vec![g, result],
163 };
164 }
165 Ok(result)
166 } else {
167 Ok(first)
168 }
169 }
170
171 fn parse_term(&mut self) -> Result<Term, String> {
173 self.parse_expr_700()
174 }
175
176 fn parse_expr_700(&mut self) -> Result<Term, String> {
178 let left = self.parse_expr_500()?;
179 if let Some(op) = self.match_op_700() {
180 let right = self.parse_expr_500()?;
181 Ok(self.build_binop(&op, left, right))
182 } else {
183 Ok(left)
184 }
185 }
186
187 fn match_op_700(&mut self) -> Option<String> {
188 let op = match self.current_kind()? {
189 TokenKind::Is => "is",
190 TokenKind::Equals => "=",
191 TokenKind::NotEquals => "\\=",
192 TokenKind::Lt => "<",
193 TokenKind::Gt => ">",
194 TokenKind::Lte => "=<",
195 TokenKind::Gte => ">=",
196 TokenKind::ArithEq => "=:=",
197 TokenKind::ArithNeq => "=\\=",
198 TokenKind::Atom(s)
199 if s == "@<" || s == "@>" || s == "@=<" || s == "@>=" || s == "=.." =>
200 {
201 let op = s.clone();
202 self.advance();
203 return Some(op);
204 }
205 _ => return None,
206 };
207 self.advance();
208 Some(op.to_string())
209 }
210
211 fn parse_expr_500(&mut self) -> Result<Term, String> {
213 let mut left = self.parse_expr_400()?;
214 loop {
215 let op = match self.current_kind() {
216 Some(TokenKind::Plus) => "+",
217 Some(TokenKind::Minus) => "-",
218 _ => break,
219 };
220 let op = op.to_string();
221 self.advance();
222 let right = self.parse_expr_400()?;
223 left = self.build_binop(&op, left, right);
224 }
225 Ok(left)
226 }
227
228 fn parse_expr_400(&mut self) -> Result<Term, String> {
230 let mut left = self.parse_primary()?;
231 loop {
232 let op = match self.current_kind() {
233 Some(TokenKind::Star) => "*",
234 Some(TokenKind::Slash) => "/",
235 Some(TokenKind::IntDiv) => "//",
236 Some(TokenKind::Mod) => "mod",
237 Some(TokenKind::Rem) => "rem",
238 _ => break,
239 };
240 let op = op.to_string();
241 self.advance();
242 let right = self.parse_primary()?;
243 left = self.build_binop(&op, left, right);
244 }
245 Ok(left)
246 }
247
248 fn build_binop(&mut self, op: &str, left: Term, right: Term) -> Term {
249 let functor = self.interner.intern(op);
250 Term::Compound {
251 functor,
252 args: vec![left, right],
253 }
254 }
255
256 fn parse_primary(&mut self) -> Result<Term, String> {
257 match self.current_kind().cloned() {
258 Some(TokenKind::Integer(n)) => {
259 self.advance();
260 Ok(Term::Integer(n))
261 }
262 Some(TokenKind::Float(f)) => {
263 self.advance();
264 Ok(Term::Float(f))
265 }
266 Some(TokenKind::Variable(ref name)) => {
267 let name = name.clone();
268 self.advance();
269 if name == "_" {
270 let id = self.next_var;
272 self.next_var += 1;
273 Ok(Term::Var(id))
274 } else if let Some(&id) = self.var_map.get(&name) {
275 Ok(Term::Var(id))
276 } else {
277 let id = self.next_var;
278 self.next_var += 1;
279 self.var_map.insert(name, id);
280 Ok(Term::Var(id))
281 }
282 }
283 Some(TokenKind::Atom(ref name)) => {
284 let name = name.clone();
285 self.advance();
286 if self.current_kind() == Some(&TokenKind::LParen) {
288 self.advance(); let args = self.parse_arg_list()?;
290 self.expect(&TokenKind::RParen)?;
291 let functor = self.interner.intern(&name);
292 Ok(Term::Compound { functor, args })
293 } else {
294 let id = self.interner.intern(&name);
295 Ok(Term::Atom(id))
296 }
297 }
298 Some(TokenKind::LParen) => {
299 self.advance();
300 let term = self.parse_paren_body()?;
301 self.expect(&TokenKind::RParen)?;
302 Ok(term)
303 }
304 Some(TokenKind::Minus) => {
305 self.advance();
306 let operand = self.parse_primary()?;
307 match operand {
309 Term::Integer(n) => Ok(Term::Integer(-n)),
310 Term::Float(f) => Ok(Term::Float(-f)),
311 _ => {
312 let functor = self.interner.intern("-");
313 Ok(Term::Compound {
314 functor,
315 args: vec![operand],
316 })
317 }
318 }
319 }
320 Some(TokenKind::LBracket) => {
321 self.advance(); self.parse_list_body()
323 }
324 Some(TokenKind::Cut) => {
325 self.advance();
326 let id = self.interner.intern("!");
327 Ok(Term::Atom(id))
328 }
329 Some(TokenKind::Not) => {
330 self.advance();
332 let goal = self.parse_term()?;
333 let functor = self.interner.intern("\\+");
334 Ok(Term::Compound {
335 functor,
336 args: vec![goal],
337 })
338 }
339 Some(ref tok) => {
340 let msg = format!(
341 "Unexpected token {:?} at line {} col {}",
342 tok,
343 self.current().unwrap().line,
344 self.current().unwrap().col
345 );
346 Err(msg)
347 }
348 None => Err("Unexpected end of input".to_string()),
349 }
350 }
351
352 fn parse_paren_body(&mut self) -> Result<Term, String> {
355 let first = self.parse_paren_comma_list()?;
356
357 if self.current_kind() == Some(&TokenKind::Arrow) {
358 self.advance();
360 let then = self.parse_paren_comma_list()?;
361 let arrow_functor = self.interner.intern("->");
362 let if_then = Term::Compound {
363 functor: arrow_functor,
364 args: vec![first, then],
365 };
366 if self.current_kind() == Some(&TokenKind::Semicolon) {
367 self.advance();
368 let else_branch = self.parse_paren_body()?;
369 let semi_functor = self.interner.intern(";");
370 Ok(Term::Compound {
371 functor: semi_functor,
372 args: vec![if_then, else_branch],
373 })
374 } else {
375 Ok(if_then)
376 }
377 } else if self.current_kind() == Some(&TokenKind::Semicolon) {
378 self.advance();
380 let right = self.parse_paren_body()?;
381 let functor = self.interner.intern(";");
382 Ok(Term::Compound {
383 functor,
384 args: vec![first, right],
385 })
386 } else {
387 Ok(first)
388 }
389 }
390
391 fn parse_paren_comma_list(&mut self) -> Result<Term, String> {
393 let first = self.parse_term()?;
394 if self.current_kind() == Some(&TokenKind::Comma) {
395 self.advance();
398 let rest = self.parse_paren_comma_list()?;
399 let functor = self.interner.intern(",");
400 Ok(Term::Compound {
401 functor,
402 args: vec![first, rest],
403 })
404 } else {
405 Ok(first)
406 }
407 }
408
409 fn parse_arg_list(&mut self) -> Result<Vec<Term>, String> {
410 let mut args = vec![self.parse_term()?];
411 while self.current_kind() == Some(&TokenKind::Comma) {
412 self.advance();
413 args.push(self.parse_term()?);
414 }
415 Ok(args)
416 }
417
418 fn parse_list_body(&mut self) -> Result<Term, String> {
419 if self.current_kind() == Some(&TokenKind::RBracket) {
421 self.advance();
422 let nil = self.interner.intern("[]");
423 return Ok(Term::Atom(nil));
424 }
425
426 let first = self.parse_term()?;
427 self.parse_list_tail(first)
428 }
429
430 fn parse_list_tail(&mut self, head: Term) -> Result<Term, String> {
431 match self.current_kind() {
432 Some(TokenKind::Comma) => {
433 self.advance();
434 let next_head = self.parse_term()?;
435 let tail = self.parse_list_tail(next_head)?;
436 Ok(Term::List {
437 head: Box::new(head),
438 tail: Box::new(tail),
439 })
440 }
441 Some(TokenKind::Pipe) => {
442 self.advance();
443 let tail = self.parse_term()?;
444 self.expect(&TokenKind::RBracket)?;
445 Ok(Term::List {
446 head: Box::new(head),
447 tail: Box::new(tail),
448 })
449 }
450 Some(TokenKind::RBracket) => {
451 self.advance();
452 let nil = self.interner.intern("[]");
453 Ok(Term::List {
454 head: Box::new(head),
455 tail: Box::new(Term::Atom(nil)),
456 })
457 }
458 _ => Err("Expected ',', '|', or ']' in list".to_string()),
459 }
460 }
461
462 pub fn var_names(&self) -> &FnvHashMap<String, VarId> {
464 &self.var_map
465 }
466
467 pub fn parse_query_with_vars(
469 input: &str,
470 interner: &mut StringInterner,
471 ) -> Result<(Vec<Term>, FnvHashMap<String, VarId>), String> {
472 let tokens = Tokenizer::tokenize(input)?;
473 let mut parser = Parser {
474 tokens,
475 pos: 0,
476 interner,
477 var_map: FnvHashMap::default(),
478 next_var: 0,
479 };
480 if parser.current_kind() == Some(&TokenKind::QueryOp) {
481 parser.advance();
482 }
483 let goals = parser.parse_goal_list()?;
484 if parser.current_kind() == Some(&TokenKind::Dot) {
485 parser.advance();
486 }
487 let vars = parser.var_map;
488 Ok((goals, vars))
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a query that must consist of exactly one goal term.
    fn parse_term(input: &str) -> (Term, StringInterner) {
        let mut interner = StringInterner::new();
        let goals = Parser::parse_query(input, &mut interner).unwrap();
        assert_eq!(goals.len(), 1);
        (goals.into_iter().next().unwrap(), interner)
    }

    /// Parses `input` as a program.
    fn parse_clauses(input: &str) -> (Vec<Clause>, StringInterner) {
        let mut interner = StringInterner::new();
        let clauses = Parser::parse_program(input, &mut interner).unwrap();
        (clauses, interner)
    }

    /// Name of an atom, or functor name of a compound; panics on anything else.
    fn name_of(term: &Term, interner: &StringInterner) -> String {
        match term {
            Term::Atom(id) => interner.resolve(*id).to_string(),
            Term::Compound { functor, .. } => interner.resolve(*functor).to_string(),
            other => panic!("Expected atom or compound, got {:?}", other),
        }
    }

    /// Arguments of a compound; panics on anything else.
    fn args_of(term: &Term) -> &[Term] {
        match term {
            Term::Compound { args, .. } => args.as_slice(),
            other => panic!("Expected compound, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_atom() {
        let (term, interner) = parse_term("hello");
        match term {
            Term::Atom(id) => assert_eq!(interner.resolve(id), "hello"),
            _ => panic!("Expected atom"),
        }
    }

    #[test]
    fn test_parse_integer() {
        let (term, _) = parse_term("42");
        assert_eq!(term, Term::Integer(42));
    }

    #[test]
    fn test_parse_float() {
        // 2.5 is exactly representable and avoids clippy's `approx_constant`
        // lint, which rejects literals like 3.14 that approximate PI.
        let (term, _) = parse_term("2.5");
        assert_eq!(term, Term::Float(2.5));
    }

    #[test]
    fn test_parse_variable() {
        let (term, _) = parse_term("X");
        match term {
            Term::Var(_) => {}
            _ => panic!("Expected variable"),
        }
    }

    #[test]
    fn test_parse_compound() {
        let (term, interner) = parse_term("parent(tom, mary)");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "parent");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_parse_nested_compound() {
        let (term, interner) = parse_term("outer(inner(deep(hello)))");
        assert_eq!(name_of(&term, &interner), "outer");
        let inner = &args_of(&term)[0];
        assert_eq!(name_of(inner, &interner), "inner");
        let deep = &args_of(inner)[0];
        assert_eq!(name_of(deep, &interner), "deep");
        match &args_of(deep)[0] {
            Term::Atom(id) => assert_eq!(interner.resolve(*id), "hello"),
            other => panic!("Expected atom, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_fact() {
        let (clauses, interner) = parse_clauses("likes(mary, food).");
        assert_eq!(clauses.len(), 1);
        assert!(clauses[0].body.is_empty());
        match &clauses[0].head {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(*functor), "likes");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_parse_rule() {
        let (clauses, interner) = parse_clauses("happy(X) :- likes(X, food).");
        assert_eq!(clauses.len(), 1);
        assert_eq!(clauses[0].body.len(), 1);
        match &clauses[0].head {
            Term::Compound { functor, .. } => {
                assert_eq!(interner.resolve(*functor), "happy");
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_variable_scoping() {
        // Every structural expectation panics on mismatch. The previous
        // nested `if let`s let the test pass silently on an unexpected shape.
        let (clauses, _) = parse_clauses("foo(X, Y) :- bar(X, Y).");
        let clause = &clauses[0];
        let head_args = args_of(&clause.head);
        let body_args = args_of(&clause.body[0]);
        match (&head_args[0], &head_args[1], &body_args[0], &body_args[1]) {
            (Term::Var(hx), Term::Var(hy), Term::Var(bx), Term::Var(by)) => {
                assert_eq!(hx, bx, "X in head and body should be same var");
                assert_eq!(hy, by, "Y in head and body should be same var");
                assert_ne!(hx, hy, "X and Y should be different vars");
            }
            other => panic!("Expected four variables, got {:?}", other),
        }
    }

    #[test]
    fn test_anonymous_variables_are_distinct() {
        let (term, _) = parse_term("foo(X, _, _, X)");
        let args = args_of(&term);
        assert_eq!(args.len(), 4);
        assert!(args.iter().all(|a| matches!(a, Term::Var(_))));
        assert_eq!(args[0], args[3], "repeated X must share one variable");
        assert_ne!(args[1], args[2], "each _ must be a fresh variable");
        assert_ne!(args[0], args[1], "_ must not alias a named variable");
    }

    #[test]
    fn test_variable_numbering_restarts_per_clause() {
        let (clauses, _) = parse_clauses("a(X). b(Y).");
        assert_eq!(clauses.len(), 2);
        assert_eq!(args_of(&clauses[0].head)[0], args_of(&clauses[1].head)[0]);
    }

    #[test]
    fn test_operator_precedence() {
        let (term, interner) = parse_term("2 + 3 * 4");
        assert_eq!(name_of(&term, &interner), "+");
        let args = args_of(&term);
        assert_eq!(args[0], Term::Integer(2));
        assert_eq!(name_of(&args[1], &interner), "*");
        let product = args_of(&args[1]);
        assert_eq!(product[0], Term::Integer(3));
        assert_eq!(product[1], Term::Integer(4));
    }

    #[test]
    fn test_parenthesized_expr() {
        let (term, interner) = parse_term("(2 + 3) * 4");
        assert_eq!(name_of(&term, &interner), "*");
        let args = args_of(&term);
        assert_eq!(name_of(&args[0], &interner), "+");
        let sum = args_of(&args[0]);
        assert_eq!(sum[0], Term::Integer(2));
        assert_eq!(sum[1], Term::Integer(3));
        assert_eq!(args[1], Term::Integer(4));
    }

    #[test]
    fn test_is_expression() {
        let (term, interner) = parse_term("X is 2 + 3");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "is");
                assert!(matches!(args[0], Term::Var(_)));
                assert_eq!(name_of(&args[1], &interner), "+");
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_unary_minus() {
        let (term, _) = parse_term("- 5");
        assert_eq!(term, Term::Integer(-5));
    }

    #[test]
    fn test_unary_minus_on_variable() {
        let (term, interner) = parse_term("- X");
        assert_eq!(name_of(&term, &interner), "-");
        let args = args_of(&term);
        assert_eq!(args.len(), 1);
        assert!(matches!(args[0], Term::Var(_)));
    }

    #[test]
    fn test_empty_list() {
        let (term, interner) = parse_term("[]");
        match term {
            Term::Atom(id) => assert_eq!(interner.resolve(id), "[]"),
            _ => panic!("Expected empty list atom"),
        }
    }

    #[test]
    fn test_simple_list() {
        let (term, interner) = parse_term("[1, 2, 3]");
        let mut cell = &term;
        for expected in [1, 2, 3] {
            match cell {
                Term::List { head, tail } => {
                    assert_eq!(**head, Term::Integer(expected));
                    cell = tail.as_ref();
                }
                other => panic!("Expected list, got {:?}", other),
            }
        }
        match cell {
            Term::Atom(id) => assert_eq!(interner.resolve(*id), "[]"),
            other => panic!("Expected nil, got {:?}", other),
        }
    }

    #[test]
    fn test_head_tail_list() {
        let (term, _) = parse_term("[H | T]");
        match term {
            Term::List { head, tail } => {
                assert!(matches!(*head, Term::Var(_)));
                assert!(matches!(*tail, Term::Var(_)));
            }
            _ => panic!("Expected list"),
        }
    }

    #[test]
    fn test_list_with_explicit_tail() {
        let (term, _) = parse_term("[1, 2 | T]");
        match term {
            Term::List { head, tail } => {
                assert_eq!(*head, Term::Integer(1));
                match *tail {
                    Term::List { head, tail } => {
                        assert_eq!(*head, Term::Integer(2));
                        assert!(matches!(*tail, Term::Var(_)));
                    }
                    other => panic!("Expected list cell, got {:?}", other),
                }
            }
            other => panic!("Expected list, got {:?}", other),
        }
    }

    #[test]
    fn test_multiple_clauses() {
        let (clauses, _) = parse_clauses("a. b. c.");
        assert_eq!(clauses.len(), 3);
    }

    #[test]
    fn test_parse_error() {
        let mut interner = StringInterner::new();
        let result = Parser::parse_program("invalid(((.", &mut interner);
        assert!(result.is_err());
    }

    #[test]
    fn test_clause_without_terminator_is_error() {
        let mut interner = StringInterner::new();
        assert!(Parser::parse_program("foo(X)", &mut interner).is_err());
    }

    #[test]
    fn test_unterminated_list_is_error() {
        let mut interner = StringInterner::new();
        assert!(Parser::parse_program("foo([1, 2).", &mut interner).is_err());
    }

    #[test]
    fn test_comparison_operators() {
        let (term, interner) = parse_term("X > 100");
        match term {
            Term::Compound { functor, .. } => {
                assert_eq!(interner.resolve(functor), ">");
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_cut() {
        let (clauses, interner) = parse_clauses("max(X, Y, X) :- X >= Y, !.");
        assert_eq!(clauses[0].body.len(), 1);
        match &clauses[0].body[0] {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(*functor), ",");
                assert_eq!(args.len(), 2);
                match &args[1] {
                    Term::Atom(id) => assert_eq!(interner.resolve(*id), "!"),
                    _ => panic!("Expected cut atom"),
                }
            }
            _ => panic!("Expected conjunction"),
        }
    }

    #[test]
    fn test_conjunction_is_right_nested() {
        let (term, interner) = parse_term("a, b, c");
        assert_eq!(name_of(&term, &interner), ",");
        let args = args_of(&term);
        assert_eq!(name_of(&args[0], &interner), "a");
        assert_eq!(name_of(&args[1], &interner), ",");
        let inner = args_of(&args[1]);
        assert_eq!(name_of(&inner[0], &interner), "b");
        assert_eq!(name_of(&inner[1], &interner), "c");
    }

    #[test]
    fn test_disjunction_binds_looser_than_conjunction() {
        let (term, interner) = parse_term("a ; b, c");
        assert_eq!(name_of(&term, &interner), ";");
        let args = args_of(&term);
        assert_eq!(name_of(&args[0], &interner), "a");
        assert_eq!(name_of(&args[1], &interner), ",");
    }

    #[test]
    fn test_parenthesized_if_then_else() {
        let (term, interner) = parse_term("(a -> b ; c)");
        assert_eq!(name_of(&term, &interner), ";");
        let args = args_of(&term);
        assert_eq!(name_of(&args[0], &interner), "->");
        let cond_then = args_of(&args[0]);
        assert_eq!(name_of(&cond_then[0], &interner), "a");
        assert_eq!(name_of(&cond_then[1], &interner), "b");
        assert_eq!(name_of(&args[1], &interner), "c");
    }

    #[test]
    fn test_query_prefix_and_dot_are_optional() {
        let mut interner = StringInterner::new();
        let goals = Parser::parse_query("?- foo(X).", &mut interner).unwrap();
        assert_eq!(goals.len(), 1);
        assert_eq!(name_of(&goals[0], &interner), "foo");
    }

    #[test]
    fn test_query_with_vars_reports_named_variables() {
        let mut interner = StringInterner::new();
        let (goals, vars) =
            Parser::parse_query_with_vars("foo(X, Y, _)", &mut interner).unwrap();
        assert_eq!(goals.len(), 1);
        assert_eq!(vars.len(), 2);
        assert!(vars.contains_key("X"));
        assert!(vars.contains_key("Y"));
    }

    #[test]
    fn test_negation() {
        let (term, interner) = parse_term("\\+ foo(X)");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "\\+");
                assert_eq!(args.len(), 1);
            }
            _ => panic!("Expected compound"),
        }
    }
}
799}