1use crate::term::{Clause, StringInterner, Term, VarId};
2use crate::tokenizer::{Token, TokenKind, Tokenizer};
3use fnv::FnvHashMap;
4
5pub struct Parser<'a> {
8 tokens: Vec<Token>,
9 pos: usize,
10 interner: &'a mut StringInterner,
11 var_map: FnvHashMap<String, VarId>,
12 next_var: VarId,
13}
14
15impl<'a> Parser<'a> {
16 pub fn parse_program(
18 input: &str,
19 interner: &mut StringInterner,
20 ) -> Result<Vec<Clause>, String> {
21 let tokens = Tokenizer::tokenize(input)?;
22 let mut parser = Parser {
23 tokens,
24 pos: 0,
25 interner,
26 var_map: FnvHashMap::default(),
27 next_var: 0,
28 };
29 let mut clauses = Vec::new();
30 while !parser.at_eof() {
31 parser.reset_vars();
32 let clause = parser.parse_clause()?;
33 clauses.push(clause);
34 }
35 Ok(clauses)
36 }
37
38 pub fn parse_query(input: &str, interner: &mut StringInterner) -> Result<Vec<Term>, String> {
41 let tokens = Tokenizer::tokenize(input)?;
42 let mut parser = Parser {
43 tokens,
44 pos: 0,
45 interner,
46 var_map: FnvHashMap::default(),
47 next_var: 0,
48 };
49 if parser.current_kind() == Some(&TokenKind::QueryOp) {
51 parser.advance();
52 }
53 let goals = parser.parse_goal_list()?;
54 if parser.current_kind() == Some(&TokenKind::Dot) {
56 parser.advance();
57 }
58 Ok(goals)
59 }
60
61 fn reset_vars(&mut self) {
62 self.var_map.clear();
63 self.next_var = 0;
64 }
65
66 fn current(&self) -> Option<&Token> {
67 self.tokens.get(self.pos)
68 }
69
70 fn current_kind(&self) -> Option<&TokenKind> {
71 self.current().map(|t| &t.kind)
72 }
73
74 fn at_eof(&self) -> bool {
75 matches!(self.current_kind(), None | Some(TokenKind::Eof))
76 }
77
78 fn advance(&mut self) -> &Token {
79 let tok = &self.tokens[self.pos];
80 self.pos += 1;
81 tok
82 }
83
84 fn expect(&mut self, kind: &TokenKind) -> Result<(), String> {
85 match self.current() {
86 Some(tok) if &tok.kind == kind => {
87 self.advance();
88 Ok(())
89 }
90 Some(tok) => Err(format!(
91 "Expected {:?}, got {:?} at line {} col {}",
92 kind, tok.kind, tok.line, tok.col
93 )),
94 None => Err(format!("Expected {:?}, got end of input", kind)),
95 }
96 }
97
98 fn parse_clause(&mut self) -> Result<Clause, String> {
99 let head = self.parse_term()?;
100 match self.current_kind() {
101 Some(TokenKind::Dot) => {
102 self.advance();
103 Ok(Clause { head, body: vec![] })
104 }
105 Some(TokenKind::Neck) => {
106 self.advance();
107 let body = self.parse_goal_list()?;
108 self.expect(&TokenKind::Dot)?;
109 Ok(Clause { head, body })
110 }
111 Some(tok) => {
112 let tok = tok.clone();
113 Err(format!(
114 "Expected '.' or ':-', got {:?} at line {} col {}",
115 tok,
116 self.current().unwrap().line,
117 self.current().unwrap().col
118 ))
119 }
120 None => Err("Unexpected end of input in clause".to_string()),
121 }
122 }
123
124 fn parse_goal_list(&mut self) -> Result<Vec<Term>, String> {
125 let mut goals = vec![self.parse_goal_disjunction()?];
126 while self.current_kind() == Some(&TokenKind::Comma) {
127 self.advance();
128 goals.push(self.parse_goal_disjunction()?);
129 }
130 Ok(goals)
131 }
132
133 fn parse_goal_disjunction(&mut self) -> Result<Term, String> {
135 let left = self.parse_term()?;
136 if self.current_kind() == Some(&TokenKind::Semicolon) {
137 self.advance();
138 let right = self.parse_goal_disjunction()?;
139 let functor = self.interner.intern(";");
140 Ok(Term::Compound {
141 functor,
142 args: vec![left, right],
143 })
144 } else {
145 Ok(left)
146 }
147 }
148
149 fn parse_term(&mut self) -> Result<Term, String> {
151 self.parse_expr_700()
152 }
153
154 fn parse_expr_700(&mut self) -> Result<Term, String> {
156 let left = self.parse_expr_500()?;
157 if let Some(op) = self.match_op_700() {
158 let right = self.parse_expr_500()?;
159 Ok(self.build_binop(&op, left, right))
160 } else {
161 Ok(left)
162 }
163 }
164
165 fn match_op_700(&mut self) -> Option<String> {
166 let op = match self.current_kind()? {
167 TokenKind::Is => "is",
168 TokenKind::Equals => "=",
169 TokenKind::NotEquals => "\\=",
170 TokenKind::Lt => "<",
171 TokenKind::Gt => ">",
172 TokenKind::Lte => "=<",
173 TokenKind::Gte => ">=",
174 TokenKind::ArithEq => "=:=",
175 TokenKind::ArithNeq => "=\\=",
176 TokenKind::Atom(s)
177 if s == "@<" || s == "@>" || s == "@=<" || s == "@>=" || s == "=.." =>
178 {
179 let op = s.clone();
180 self.advance();
181 return Some(op);
182 }
183 _ => return None,
184 };
185 self.advance();
186 Some(op.to_string())
187 }
188
189 fn parse_expr_500(&mut self) -> Result<Term, String> {
191 let mut left = self.parse_expr_400()?;
192 loop {
193 let op = match self.current_kind() {
194 Some(TokenKind::Plus) => "+",
195 Some(TokenKind::Minus) => "-",
196 _ => break,
197 };
198 let op = op.to_string();
199 self.advance();
200 let right = self.parse_expr_400()?;
201 left = self.build_binop(&op, left, right);
202 }
203 Ok(left)
204 }
205
206 fn parse_expr_400(&mut self) -> Result<Term, String> {
208 let mut left = self.parse_primary()?;
209 loop {
210 let op = match self.current_kind() {
211 Some(TokenKind::Star) => "*",
212 Some(TokenKind::Slash) => "/",
213 Some(TokenKind::Mod) => "mod",
214 _ => break,
215 };
216 let op = op.to_string();
217 self.advance();
218 let right = self.parse_primary()?;
219 left = self.build_binop(&op, left, right);
220 }
221 Ok(left)
222 }
223
224 fn build_binop(&mut self, op: &str, left: Term, right: Term) -> Term {
225 let functor = self.interner.intern(op);
226 Term::Compound {
227 functor,
228 args: vec![left, right],
229 }
230 }
231
232 fn parse_primary(&mut self) -> Result<Term, String> {
233 match self.current_kind().cloned() {
234 Some(TokenKind::Integer(n)) => {
235 self.advance();
236 Ok(Term::Integer(n))
237 }
238 Some(TokenKind::Float(f)) => {
239 self.advance();
240 Ok(Term::Float(f))
241 }
242 Some(TokenKind::Variable(ref name)) => {
243 let name = name.clone();
244 self.advance();
245 if name == "_" {
246 let id = self.next_var;
248 self.next_var += 1;
249 Ok(Term::Var(id))
250 } else if let Some(&id) = self.var_map.get(&name) {
251 Ok(Term::Var(id))
252 } else {
253 let id = self.next_var;
254 self.next_var += 1;
255 self.var_map.insert(name, id);
256 Ok(Term::Var(id))
257 }
258 }
259 Some(TokenKind::Atom(ref name)) => {
260 let name = name.clone();
261 self.advance();
262 if self.current_kind() == Some(&TokenKind::LParen) {
264 self.advance(); let args = self.parse_arg_list()?;
266 self.expect(&TokenKind::RParen)?;
267 let functor = self.interner.intern(&name);
268 Ok(Term::Compound { functor, args })
269 } else {
270 let id = self.interner.intern(&name);
271 Ok(Term::Atom(id))
272 }
273 }
274 Some(TokenKind::LParen) => {
275 self.advance();
276 let term = self.parse_paren_body()?;
277 self.expect(&TokenKind::RParen)?;
278 Ok(term)
279 }
280 Some(TokenKind::Minus) => {
281 self.advance();
282 let operand = self.parse_primary()?;
283 match operand {
285 Term::Integer(n) => Ok(Term::Integer(-n)),
286 Term::Float(f) => Ok(Term::Float(-f)),
287 _ => {
288 let functor = self.interner.intern("-");
289 Ok(Term::Compound {
290 functor,
291 args: vec![operand],
292 })
293 }
294 }
295 }
296 Some(TokenKind::LBracket) => {
297 self.advance(); self.parse_list_body()
299 }
300 Some(TokenKind::Cut) => {
301 self.advance();
302 let id = self.interner.intern("!");
303 Ok(Term::Atom(id))
304 }
305 Some(TokenKind::Not) => {
306 self.advance();
308 let goal = self.parse_term()?;
309 let functor = self.interner.intern("\\+");
310 Ok(Term::Compound {
311 functor,
312 args: vec![goal],
313 })
314 }
315 Some(ref tok) => {
316 let msg = format!(
317 "Unexpected token {:?} at line {} col {}",
318 tok,
319 self.current().unwrap().line,
320 self.current().unwrap().col
321 );
322 Err(msg)
323 }
324 None => Err("Unexpected end of input".to_string()),
325 }
326 }
327
328 fn parse_paren_body(&mut self) -> Result<Term, String> {
331 let first = self.parse_paren_comma_list()?;
332
333 if self.current_kind() == Some(&TokenKind::Arrow) {
334 self.advance();
336 let then = self.parse_paren_comma_list()?;
337 let arrow_functor = self.interner.intern("->");
338 let if_then = Term::Compound {
339 functor: arrow_functor,
340 args: vec![first, then],
341 };
342 if self.current_kind() == Some(&TokenKind::Semicolon) {
343 self.advance();
344 let else_branch = self.parse_paren_body()?;
345 let semi_functor = self.interner.intern(";");
346 Ok(Term::Compound {
347 functor: semi_functor,
348 args: vec![if_then, else_branch],
349 })
350 } else {
351 Ok(if_then)
352 }
353 } else if self.current_kind() == Some(&TokenKind::Semicolon) {
354 self.advance();
356 let right = self.parse_paren_body()?;
357 let functor = self.interner.intern(";");
358 Ok(Term::Compound {
359 functor,
360 args: vec![first, right],
361 })
362 } else {
363 Ok(first)
364 }
365 }
366
367 fn parse_paren_comma_list(&mut self) -> Result<Term, String> {
369 let first = self.parse_term()?;
370 if self.current_kind() == Some(&TokenKind::Comma) {
371 self.advance();
374 let rest = self.parse_paren_comma_list()?;
375 let functor = self.interner.intern(",");
376 Ok(Term::Compound {
377 functor,
378 args: vec![first, rest],
379 })
380 } else {
381 Ok(first)
382 }
383 }
384
385 fn parse_arg_list(&mut self) -> Result<Vec<Term>, String> {
386 let mut args = vec![self.parse_term()?];
387 while self.current_kind() == Some(&TokenKind::Comma) {
388 self.advance();
389 args.push(self.parse_term()?);
390 }
391 Ok(args)
392 }
393
394 fn parse_list_body(&mut self) -> Result<Term, String> {
395 if self.current_kind() == Some(&TokenKind::RBracket) {
397 self.advance();
398 let nil = self.interner.intern("[]");
399 return Ok(Term::Atom(nil));
400 }
401
402 let first = self.parse_term()?;
403 self.parse_list_tail(first)
404 }
405
406 fn parse_list_tail(&mut self, head: Term) -> Result<Term, String> {
407 match self.current_kind() {
408 Some(TokenKind::Comma) => {
409 self.advance();
410 let next_head = self.parse_term()?;
411 let tail = self.parse_list_tail(next_head)?;
412 Ok(Term::List {
413 head: Box::new(head),
414 tail: Box::new(tail),
415 })
416 }
417 Some(TokenKind::Pipe) => {
418 self.advance();
419 let tail = self.parse_term()?;
420 self.expect(&TokenKind::RBracket)?;
421 Ok(Term::List {
422 head: Box::new(head),
423 tail: Box::new(tail),
424 })
425 }
426 Some(TokenKind::RBracket) => {
427 self.advance();
428 let nil = self.interner.intern("[]");
429 Ok(Term::List {
430 head: Box::new(head),
431 tail: Box::new(Term::Atom(nil)),
432 })
433 }
434 _ => Err("Expected ',', '|', or ']' in list".to_string()),
435 }
436 }
437
438 pub fn var_names(&self) -> &FnvHashMap<String, VarId> {
440 &self.var_map
441 }
442
443 pub fn parse_query_with_vars(
445 input: &str,
446 interner: &mut StringInterner,
447 ) -> Result<(Vec<Term>, FnvHashMap<String, VarId>), String> {
448 let tokens = Tokenizer::tokenize(input)?;
449 let mut parser = Parser {
450 tokens,
451 pos: 0,
452 interner,
453 var_map: FnvHashMap::default(),
454 next_var: 0,
455 };
456 if parser.current_kind() == Some(&TokenKind::QueryOp) {
457 parser.advance();
458 }
459 let goals = parser.parse_goal_list()?;
460 if parser.current_kind() == Some(&TokenKind::Dot) {
461 parser.advance();
462 }
463 let vars = parser.var_map;
464 Ok((goals, vars))
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a query, asserts that it has exactly one goal, and returns
    /// that goal with the interner used.
    fn parse_term(input: &str) -> (Term, StringInterner) {
        let mut interner = StringInterner::new();
        let goals = Parser::parse_query(input, &mut interner).unwrap();
        assert_eq!(goals.len(), 1);
        (goals.into_iter().next().unwrap(), interner)
    }

    /// Parses `input` as a program and returns its clauses with the interner used.
    fn parse_clauses(input: &str) -> (Vec<Clause>, StringInterner) {
        let mut interner = StringInterner::new();
        let clauses = Parser::parse_program(input, &mut interner).unwrap();
        (clauses, interner)
    }

    #[test]
    fn test_parse_atom() {
        let (term, interner) = parse_term("hello");
        match term {
            Term::Atom(id) => assert_eq!(interner.resolve(id), "hello"),
            _ => panic!("Expected atom"),
        }
    }

    #[test]
    fn test_parse_integer() {
        let (term, _) = parse_term("42");
        assert_eq!(term, Term::Integer(42));
    }

    #[test]
    fn test_parse_float() {
        // 2.5 rather than 3.14: the latter trips clippy's deny-by-default
        // `approx_constant` lint.
        let (term, _) = parse_term("2.5");
        assert_eq!(term, Term::Float(2.5));
    }

    #[test]
    fn test_parse_variable() {
        let (term, _) = parse_term("X");
        match term {
            Term::Var(_) => {}
            _ => panic!("Expected variable"),
        }
    }

    #[test]
    fn test_parse_compound() {
        let (term, interner) = parse_term("parent(tom, mary)");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "parent");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_parse_nested_compound() {
        let (term, interner) = parse_term("outer(inner(deep(hello)))");
        match term {
            Term::Compound { functor, ref args } => {
                assert_eq!(interner.resolve(functor), "outer");
                match &args[0] {
                    Term::Compound { functor, ref args } => {
                        assert_eq!(interner.resolve(*functor), "inner");
                        match &args[0] {
                            Term::Compound { functor, ref args } => {
                                assert_eq!(interner.resolve(*functor), "deep");
                                match &args[0] {
                                    Term::Atom(id) => assert_eq!(interner.resolve(*id), "hello"),
                                    _ => panic!("Expected atom"),
                                }
                            }
                            _ => panic!("Expected compound"),
                        }
                    }
                    _ => panic!("Expected compound"),
                }
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_parse_fact() {
        let (clauses, interner) = parse_clauses("likes(mary, food).");
        assert_eq!(clauses.len(), 1);
        assert!(clauses[0].body.is_empty());
        match &clauses[0].head {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(*functor), "likes");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_parse_rule() {
        let (clauses, interner) = parse_clauses("happy(X) :- likes(X, food).");
        assert_eq!(clauses.len(), 1);
        assert_eq!(clauses[0].body.len(), 1);
        match &clauses[0].head {
            Term::Compound { functor, .. } => {
                assert_eq!(interner.resolve(*functor), "happy");
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_variable_scoping() {
        let (clauses, _) = parse_clauses("foo(X, Y) :- bar(X, Y).");
        let clause = &clauses[0];
        // Every shape check panics on mismatch. The previous nested `if let` chain
        // let this test pass vacuously whenever the parse had an unexpected shape.
        let head_args = match &clause.head {
            Term::Compound { args, .. } => args,
            other => panic!("Expected compound head, got {:?}", other),
        };
        let body_args = match &clause.body[0] {
            Term::Compound { args, .. } => args,
            other => panic!("Expected compound body goal, got {:?}", other),
        };
        let var = |t: &Term| match t {
            Term::Var(id) => *id,
            other => panic!("Expected variable, got {:?}", other),
        };
        let (hx, hy) = (var(&head_args[0]), var(&head_args[1]));
        let (bx, by) = (var(&body_args[0]), var(&body_args[1]));
        assert_eq!(hx, bx, "X in head and body should be same var");
        assert_eq!(hy, by, "Y in head and body should be same var");
        assert_ne!(hx, hy, "X and Y should be different vars");
    }

    #[test]
    fn test_variable_ids_restart_per_clause() {
        let (clauses, _) = parse_clauses("p(X). q(Y).");
        assert_eq!(clauses.len(), 2);
        for clause in &clauses {
            match &clause.head {
                Term::Compound { args, .. } => assert_eq!(args[0], Term::Var(0)),
                other => panic!("Expected compound head, got {:?}", other),
            }
        }
    }

    #[test]
    fn test_anonymous_variables_are_distinct() {
        let (term, _) = parse_term("foo(_, _)");
        match term {
            Term::Compound { ref args, .. } => match (&args[0], &args[1]) {
                (Term::Var(a), Term::Var(b)) => {
                    assert_ne!(a, b, "each `_` should be a fresh variable")
                }
                other => panic!("Expected two variables, got {:?}", other),
            },
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_operator_precedence() {
        let (term, interner) = parse_term("2 + 3 * 4");
        match term {
            Term::Compound { functor, ref args } => {
                assert_eq!(interner.resolve(functor), "+");
                assert_eq!(args[0], Term::Integer(2));
                match &args[1] {
                    Term::Compound { functor, ref args } => {
                        assert_eq!(interner.resolve(*functor), "*");
                        assert_eq!(args[0], Term::Integer(3));
                        assert_eq!(args[1], Term::Integer(4));
                    }
                    _ => panic!("Expected compound for 3*4"),
                }
            }
            _ => panic!("Expected compound for addition"),
        }
    }

    #[test]
    fn test_parenthesized_expr() {
        let (term, interner) = parse_term("(2 + 3) * 4");
        match term {
            Term::Compound { functor, ref args } => {
                assert_eq!(interner.resolve(functor), "*");
                match &args[0] {
                    Term::Compound { functor, ref args } => {
                        assert_eq!(interner.resolve(*functor), "+");
                        assert_eq!(args[0], Term::Integer(2));
                        assert_eq!(args[1], Term::Integer(3));
                    }
                    _ => panic!("Expected compound for addition"),
                }
                assert_eq!(args[1], Term::Integer(4));
            }
            _ => panic!("Expected compound for multiplication"),
        }
    }

    #[test]
    fn test_is_expression() {
        let (term, interner) = parse_term("X is 2 + 3");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "is");
                assert!(matches!(args[0], Term::Var(_)));
                match &args[1] {
                    Term::Compound { functor, .. } => {
                        assert_eq!(interner.resolve(*functor), "+");
                    }
                    _ => panic!("Expected compound"),
                }
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_unary_minus() {
        let (term, _) = parse_term("- 5");
        assert_eq!(term, Term::Integer(-5));
    }

    #[test]
    fn test_empty_list() {
        let (term, interner) = parse_term("[]");
        match term {
            Term::Atom(id) => assert_eq!(interner.resolve(id), "[]"),
            _ => panic!("Expected empty list atom"),
        }
    }

    #[test]
    fn test_simple_list() {
        let (term, interner) = parse_term("[1, 2, 3]");
        match term {
            Term::List { ref head, ref tail } => {
                assert_eq!(**head, Term::Integer(1));
                match tail.as_ref() {
                    Term::List { ref head, ref tail } => {
                        assert_eq!(**head, Term::Integer(2));
                        match tail.as_ref() {
                            Term::List { ref head, ref tail } => {
                                assert_eq!(**head, Term::Integer(3));
                                match tail.as_ref() {
                                    Term::Atom(id) => assert_eq!(interner.resolve(*id), "[]"),
                                    _ => panic!("Expected nil"),
                                }
                            }
                            _ => panic!("Expected list"),
                        }
                    }
                    _ => panic!("Expected list"),
                }
            }
            _ => panic!("Expected list, got {:?}", term),
        }
    }

    #[test]
    fn test_head_tail_list() {
        let (term, _) = parse_term("[H | T]");
        match term {
            Term::List { head, tail } => {
                assert!(matches!(*head, Term::Var(_)));
                assert!(matches!(*tail, Term::Var(_)));
            }
            _ => panic!("Expected list"),
        }
    }

    #[test]
    fn test_multiple_clauses() {
        let (clauses, _) = parse_clauses("a. b. c.");
        assert_eq!(clauses.len(), 3);
    }

    #[test]
    fn test_parse_error() {
        let mut interner = StringInterner::new();
        let result = Parser::parse_program("invalid(((.", &mut interner);
        assert!(result.is_err());
    }

    #[test]
    fn test_comparison_operators() {
        let (term, interner) = parse_term("X > 100");
        match term {
            Term::Compound { functor, .. } => {
                assert_eq!(interner.resolve(functor), ">");
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_cut() {
        let (clauses, interner) = parse_clauses("max(X, Y, X) :- X >= Y, !.");
        assert_eq!(clauses[0].body.len(), 2);
        match &clauses[0].body[1] {
            Term::Atom(id) => assert_eq!(interner.resolve(*id), "!"),
            _ => panic!("Expected cut atom"),
        }
    }

    #[test]
    fn test_negation() {
        let (term, interner) = parse_term("\\+ foo(X)");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "\\+");
                assert_eq!(args.len(), 1);
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_top_level_disjunction() {
        let (term, interner) = parse_term("a ; b");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), ";");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("Expected disjunction"),
        }
    }

    #[test]
    fn test_parenthesized_if_then_else() {
        let (term, interner) = parse_term("(a -> b ; c)");
        match term {
            Term::Compound { functor, ref args } => {
                assert_eq!(interner.resolve(functor), ";");
                match &args[0] {
                    Term::Compound { functor, args } => {
                        assert_eq!(interner.resolve(*functor), "->");
                        assert_eq!(args.len(), 2);
                    }
                    other => panic!("Expected if-then, got {:?}", other),
                }
            }
            _ => panic!("Expected if-then-else"),
        }
    }

    #[test]
    fn test_query_prefix_and_terminator() {
        let (term, interner) = parse_term("?- foo(X).");
        match term {
            Term::Compound { functor, args } => {
                assert_eq!(interner.resolve(functor), "foo");
                assert_eq!(args.len(), 1);
            }
            _ => panic!("Expected compound"),
        }
    }
}