1use std::collections::HashMap;
4
5use super::types::{
6 functor_arity, LpClause, LpDatabase, LpTerm, Query, ResolutionResult, SolveConfig, Substitution,
7};
8
9fn rename_clause(clause: &LpClause, stamp: usize) -> LpClause {
13 let suffix = format!("_{stamp}");
14 LpClause {
15 head: rename_term(&clause.head, &suffix),
16 body: clause
17 .body
18 .iter()
19 .map(|t| rename_term(t, &suffix))
20 .collect(),
21 }
22}
23
24fn rename_term(t: &LpTerm, suffix: &str) -> LpTerm {
25 match t {
26 LpTerm::Var(v) => LpTerm::Var(format!("{v}{suffix}")),
27 LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => t.clone(),
28 LpTerm::Compound { functor, args } => LpTerm::Compound {
29 functor: functor.clone(),
30 args: args.iter().map(|a| rename_term(a, suffix)).collect(),
31 },
32 LpTerm::List(items, tail) => LpTerm::List(
33 items.iter().map(|a| rename_term(a, suffix)).collect(),
34 tail.as_ref().map(|t| Box::new(rename_term(t, suffix))),
35 ),
36 }
37}
38
39pub fn occurs_check(var: &str, term: &LpTerm, subst: &Substitution) -> bool {
45 match term {
46 LpTerm::Var(v) => {
47 if v == var {
48 return true;
49 }
50 match subst.lookup(v) {
51 Some(t) => occurs_check(var, &t.clone(), subst),
52 None => false,
53 }
54 }
55 LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => false,
56 LpTerm::Compound { args, .. } => args.iter().any(|a| occurs_check(var, a, subst)),
57 LpTerm::List(items, tail) => {
58 items.iter().any(|a| occurs_check(var, a, subst))
59 || tail.as_ref().map_or(false, |t| occurs_check(var, t, subst))
60 }
61 }
62}
63
64pub fn apply_subst(term: &LpTerm, subst: &Substitution) -> LpTerm {
68 match term {
69 LpTerm::Var(v) => match subst.lookup(v) {
70 None => term.clone(),
71 Some(t) => {
72 let t2 = t.clone();
73 if t2 == LpTerm::Var(v.clone()) {
75 t2
76 } else {
77 apply_subst(&t2, subst)
78 }
79 }
80 },
81 LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => term.clone(),
82 LpTerm::Compound { functor, args } => LpTerm::Compound {
83 functor: functor.clone(),
84 args: args.iter().map(|a| apply_subst(a, subst)).collect(),
85 },
86 LpTerm::List(items, tail) => LpTerm::List(
87 items.iter().map(|a| apply_subst(a, subst)).collect(),
88 tail.as_ref().map(|t| Box::new(apply_subst(t, subst))),
89 ),
90 }
91}
92
93pub fn unify(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
99 let t1 = apply_subst(t1, subst);
100 let t2 = apply_subst(t2, subst);
101 unify_walked(&t1, &t2, subst)
102}
103
104fn unify_walked(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
105 match (t1, t2) {
106 (LpTerm::Atom(a), LpTerm::Atom(b)) if a == b => Some(subst.clone()),
108 (LpTerm::Integer(a), LpTerm::Integer(b)) if a == b => Some(subst.clone()),
109 (LpTerm::Float(a), LpTerm::Float(b)) if a == b => Some(subst.clone()),
110
111 (LpTerm::Var(v), t) => {
113 let t = apply_subst(t, subst);
114 if let LpTerm::Var(v2) = &t {
115 if v == v2 {
116 return Some(subst.clone());
117 }
118 }
119 let mut new_subst = subst.clone();
121 new_subst.bind(v.clone(), t);
122 Some(new_subst)
123 }
124
125 (t, LpTerm::Var(v)) => {
127 let t = apply_subst(t, subst);
128 let mut new_subst = subst.clone();
129 new_subst.bind(v.clone(), t);
130 Some(new_subst)
131 }
132
133 (
135 LpTerm::Compound {
136 functor: f1,
137 args: a1,
138 },
139 LpTerm::Compound {
140 functor: f2,
141 args: a2,
142 },
143 ) => {
144 if f1 != f2 || a1.len() != a2.len() {
145 return None;
146 }
147 let mut s = subst.clone();
148 for (x, y) in a1.iter().zip(a2.iter()) {
149 s = unify(x, y, &s)?;
150 }
151 Some(s)
152 }
153
154 (LpTerm::List(items1, tail1), LpTerm::List(items2, tail2)) => {
156 unify_lists(items1, tail1.as_deref(), items2, tail2.as_deref(), subst)
157 }
158
159 (LpTerm::Atom(a), LpTerm::List(items, None)) if a == "[]" && items.is_empty() => {
161 Some(subst.clone())
162 }
163 (LpTerm::List(items, None), LpTerm::Atom(a)) if a == "[]" && items.is_empty() => {
164 Some(subst.clone())
165 }
166
167 _ => None,
168 }
169}
170
171fn unify_lists(
172 items1: &[LpTerm],
173 tail1: Option<&LpTerm>,
174 items2: &[LpTerm],
175 tail2: Option<&LpTerm>,
176 subst: &Substitution,
177) -> Option<Substitution> {
178 match (items1, items2) {
179 ([], []) => {
180 match (tail1, tail2) {
182 (None, None) => Some(subst.clone()),
183 (Some(t1), Some(t2)) => unify(t1, t2, subst),
184 (None, Some(t)) => unify(&LpTerm::atom("[]"), t, subst),
185 (Some(t), None) => unify(t, &LpTerm::atom("[]"), subst),
186 }
187 }
188 ([], _) => {
189 let rest = LpTerm::List(items2.to_vec(), tail2.cloned().map(|t| Box::new(t.clone())));
191 match tail1 {
192 None => None, Some(t) => unify(t, &rest, subst),
194 }
195 }
196 (_, []) => {
197 let rest = LpTerm::List(items1.to_vec(), tail1.cloned().map(|t| Box::new(t.clone())));
199 match tail2 {
200 None => None,
201 Some(t) => unify(&rest, t, subst),
202 }
203 }
204 ([h1, rest1 @ ..], [h2, rest2 @ ..]) => {
205 let s = unify(h1, h2, subst)?;
206 unify_lists(rest1, tail1, rest2, tail2, &s)
207 }
208 }
209}
210
211pub fn unify_with_occurs_check(
215 t1: &LpTerm,
216 t2: &LpTerm,
217 subst: &Substitution,
218) -> Option<Substitution> {
219 let t1w = apply_subst(t1, subst);
220 let t2w = apply_subst(t2, subst);
221 unify_oc_walked(&t1w, &t2w, subst)
222}
223
224fn unify_oc_walked(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
225 match (t1, t2) {
226 (LpTerm::Atom(a), LpTerm::Atom(b)) if a == b => Some(subst.clone()),
227 (LpTerm::Integer(a), LpTerm::Integer(b)) if a == b => Some(subst.clone()),
228 (LpTerm::Float(a), LpTerm::Float(b)) if a == b => Some(subst.clone()),
229
230 (LpTerm::Var(v), t) => {
231 let t = apply_subst(t, subst);
232 if let LpTerm::Var(v2) = &t {
233 if v == v2 {
234 return Some(subst.clone());
235 }
236 }
237 if occurs_check(v, &t, subst) {
238 return None;
239 }
240 let mut s = subst.clone();
241 s.bind(v.clone(), t);
242 Some(s)
243 }
244
245 (t, LpTerm::Var(v)) => {
246 let t = apply_subst(t, subst);
247 if occurs_check(v, &t, subst) {
248 return None;
249 }
250 let mut s = subst.clone();
251 s.bind(v.clone(), t);
252 Some(s)
253 }
254
255 (
256 LpTerm::Compound {
257 functor: f1,
258 args: a1,
259 },
260 LpTerm::Compound {
261 functor: f2,
262 args: a2,
263 },
264 ) => {
265 if f1 != f2 || a1.len() != a2.len() {
266 return None;
267 }
268 let mut s = subst.clone();
269 for (x, y) in a1.iter().zip(a2.iter()) {
270 s = unify_with_occurs_check(x, y, &s)?;
271 }
272 Some(s)
273 }
274
275 (LpTerm::List(i1, t1), LpTerm::List(i2, t2)) => {
276 unify_lists(i1, t1.as_deref(), i2, t2.as_deref(), subst)
277 }
278
279 (LpTerm::Atom(a), LpTerm::List(items, None)) if a == "[]" && items.is_empty() => {
280 Some(subst.clone())
281 }
282 (LpTerm::List(items, None), LpTerm::Atom(a)) if a == "[]" && items.is_empty() => {
283 Some(subst.clone())
284 }
285
286 _ => None,
287 }
288}
289
290pub fn resolve(query: &Query, db: &LpDatabase, cfg: &SolveConfig) -> Vec<Substitution> {
294 let mut results = Vec::new();
295 let mut counter = 0usize;
296 sld_resolve(
297 query.goals.clone(),
298 &Substitution::new(),
299 db,
300 cfg,
301 0,
302 &mut counter,
303 &mut results,
304 );
305 results
306}
307
308pub fn solve_one(query: &Query, db: &LpDatabase, cfg: &SolveConfig) -> ResolutionResult {
310 let mut results = Vec::new();
311 let mut counter = 0usize;
312 let one_cfg = SolveConfig {
313 max_solutions: 1,
314 ..cfg.clone()
315 };
316 sld_resolve(
317 query.goals.clone(),
318 &Substitution::new(),
319 db,
320 &one_cfg,
321 0,
322 &mut counter,
323 &mut results,
324 );
325 match results.into_iter().next() {
326 Some(s) => ResolutionResult::Success(s),
327 None => ResolutionResult::Failure,
328 }
329}
330
331fn sld_resolve(
332 goals: Vec<LpTerm>,
333 subst: &Substitution,
334 db: &LpDatabase,
335 cfg: &SolveConfig,
336 depth: usize,
337 counter: &mut usize,
338 results: &mut Vec<Substitution>,
339) {
340 if results.len() >= cfg.max_solutions {
341 return;
342 }
343 if depth > cfg.max_depth {
344 return;
345 }
346
347 if goals.is_empty() {
348 results.push(subst.clone());
349 return;
350 }
351
352 let goal = apply_subst(&goals[0], subst);
353 let rest_goals = goals[1..].to_vec();
354
355 if handle_builtin(&goal, subst, db, cfg, depth, counter, results, &rest_goals) {
357 return;
358 }
359
360 let matching: Vec<LpClause> = db.matching_clauses(&goal).into_iter().cloned().collect();
362
363 for clause in &matching {
364 if results.len() >= cfg.max_solutions {
365 break;
366 }
367 *counter += 1;
368 let renamed = rename_clause(clause, *counter);
369 let unifier = if cfg.occurs_check {
370 unify_with_occurs_check(&goal, &renamed.head, subst)
371 } else {
372 unify(&goal, &renamed.head, subst)
373 };
374 if let Some(new_subst) = unifier {
375 let mut new_goals = renamed.body.clone();
376 new_goals.extend(rest_goals.clone());
377 sld_resolve(new_goals, &new_subst, db, cfg, depth + 1, counter, results);
378 }
379 }
380}
381
382fn handle_builtin(
384 goal: &LpTerm,
385 subst: &Substitution,
386 db: &LpDatabase,
387 cfg: &SolveConfig,
388 depth: usize,
389 counter: &mut usize,
390 results: &mut Vec<Substitution>,
391 rest_goals: &[LpTerm],
392) -> bool {
393 match goal {
394 LpTerm::Atom(a) if a == "true" => {
396 sld_resolve(
397 rest_goals.to_vec(),
398 subst,
399 db,
400 cfg,
401 depth + 1,
402 counter,
403 results,
404 );
405 true
406 }
407 LpTerm::Atom(a) if a == "fail" || a == "false" => true,
409 LpTerm::Compound { functor, args } if functor == "=" && args.len() == 2 => {
411 if let Some(s) = unify(&args[0], &args[1], subst) {
412 sld_resolve(
413 rest_goals.to_vec(),
414 &s,
415 db,
416 cfg,
417 depth + 1,
418 counter,
419 results,
420 );
421 }
422 true
423 }
424 LpTerm::Compound { functor, args } if functor == "\\=" && args.len() == 2 => {
426 if unify(&args[0], &args[1], subst).is_none() {
427 sld_resolve(
428 rest_goals.to_vec(),
429 subst,
430 db,
431 cfg,
432 depth + 1,
433 counter,
434 results,
435 );
436 }
437 true
438 }
439 LpTerm::Compound { functor, args } if functor == "is" && args.len() == 2 => {
441 if let Some(val) = eval_arith(&args[1], subst) {
442 if let Some(s) = unify(&args[0], &val, subst) {
443 sld_resolve(
444 rest_goals.to_vec(),
445 &s,
446 db,
447 cfg,
448 depth + 1,
449 counter,
450 results,
451 );
452 }
453 }
454 true
455 }
456 LpTerm::Compound { functor, args } if functor == "=:=" && args.len() == 2 => {
458 let v1 = eval_arith(&args[0], subst);
459 let v2 = eval_arith(&args[1], subst);
460 if v1 == v2 && v1.is_some() {
461 sld_resolve(
462 rest_goals.to_vec(),
463 subst,
464 db,
465 cfg,
466 depth + 1,
467 counter,
468 results,
469 );
470 }
471 true
472 }
473 LpTerm::Compound { functor, args }
475 if (functor == "<" || functor == ">" || functor == "=<" || functor == ">=")
476 && args.len() == 2 =>
477 {
478 let v1 = eval_arith_f64(&args[0], subst);
479 let v2 = eval_arith_f64(&args[1], subst);
480 let ok = match (v1, v2) {
481 (Some(a), Some(b)) => match functor.as_str() {
482 "<" => a < b,
483 ">" => a > b,
484 "=<" => a <= b,
485 ">=" => a >= b,
486 _ => false,
487 },
488 _ => false,
489 };
490 if ok {
491 sld_resolve(
492 rest_goals.to_vec(),
493 subst,
494 db,
495 cfg,
496 depth + 1,
497 counter,
498 results,
499 );
500 }
501 true
502 }
503 LpTerm::Compound { functor, args }
505 if (functor == "not" || functor == "\\+") && args.len() == 1 =>
506 {
507 let inner_q = Query::single(args[0].clone());
508 let inner_cfg = SolveConfig {
509 max_solutions: 1,
510 ..cfg.clone()
511 };
512 let inner_results = resolve(&inner_q, db, &inner_cfg);
513 if inner_results.is_empty() {
514 sld_resolve(
515 rest_goals.to_vec(),
516 subst,
517 db,
518 cfg,
519 depth + 1,
520 counter,
521 results,
522 );
523 }
524 true
525 }
526 LpTerm::Compound { functor, args } if functor == "call" && args.len() == 1 => {
528 let new_goal = apply_subst(&args[0], subst);
529 let mut new_goals = vec![new_goal];
530 new_goals.extend(rest_goals.to_vec());
531 sld_resolve(new_goals, subst, db, cfg, depth + 1, counter, results);
532 true
533 }
534 _ => false,
535 }
536}
537
538fn eval_arith(t: &LpTerm, subst: &Substitution) -> Option<LpTerm> {
540 let t = apply_subst(t, subst);
541 match &t {
542 LpTerm::Integer(n) => Some(LpTerm::Integer(*n)),
543 LpTerm::Float(f) => Some(LpTerm::Float(*f)),
544 LpTerm::Compound { functor, args } if args.len() == 2 => {
545 let a = eval_arith_f64(&args[0], subst)?;
546 let b = eval_arith_f64(&args[1], subst)?;
547 let result = match functor.as_str() {
548 "+" => a + b,
549 "-" => a - b,
550 "*" => a * b,
551 "/" => {
552 if b == 0.0 {
553 return None;
554 }
555 a / b
556 }
557 "mod" => {
558 if b == 0.0 {
559 return None;
560 }
561 a % b
562 }
563 "**" | "^" => a.powf(b),
564 _ => return None,
565 };
566 if result.fract() == 0.0 && functor != "/" {
568 Some(LpTerm::Integer(result as i64))
569 } else {
570 Some(LpTerm::Float(result))
571 }
572 }
573 LpTerm::Compound { functor, args } if args.len() == 1 => {
574 let a = eval_arith_f64(&args[0], subst)?;
575 let result = match functor.as_str() {
576 "abs" => a.abs(),
577 "sqrt" => a.sqrt(),
578 "floor" => a.floor(),
579 "ceiling" => a.ceil(),
580 "round" => a.round(),
581 "-" => -a,
582 _ => return None,
583 };
584 if result.fract() == 0.0 {
585 Some(LpTerm::Integer(result as i64))
586 } else {
587 Some(LpTerm::Float(result))
588 }
589 }
590 _ => None,
591 }
592}
593
594fn eval_arith_f64(t: &LpTerm, subst: &Substitution) -> Option<f64> {
595 match eval_arith(t, subst)? {
596 LpTerm::Integer(n) => Some(n as f64),
597 LpTerm::Float(f) => Some(f),
598 _ => None,
599 }
600}
601
602impl LpDatabase {
605 pub fn query_all(&self, goal: LpTerm, cfg: &SolveConfig) -> Vec<Substitution> {
607 let q = Query::single(goal);
608 resolve(&q, self, cfg)
609 }
610}
611
612pub fn term_to_string(t: &LpTerm) -> String {
616 match t {
617 LpTerm::Atom(s) => {
618 if needs_quoting(s) {
620 format!("'{}'", s.replace('\'', "\\'"))
621 } else {
622 s.clone()
623 }
624 }
625 LpTerm::Var(v) => v.clone(),
626 LpTerm::Integer(n) => n.to_string(),
627 LpTerm::Float(f) => format!("{f}"),
628 LpTerm::Compound { functor, args } => {
629 let args_str: Vec<String> = args.iter().map(term_to_string).collect();
630 format!("{}({})", functor, args_str.join(","))
631 }
632 LpTerm::List(items, tail) => {
633 let items_str: Vec<String> = items.iter().map(term_to_string).collect();
634 let body = items_str.join(",");
635 match tail {
636 None => format!("[{body}]"),
637 Some(t) => format!("[{body}|{}]", term_to_string(t)),
638 }
639 }
640 }
641}
642
/// Returns `true` if atom `s` must be single-quoted under standard Prolog atom
/// syntax.
///
/// Atoms that print unquoted:
/// - identifiers starting with a lowercase ASCII letter followed by
///   alphanumerics or `_` (`foo`, `a_1`);
/// - runs of symbol characters (`+`, `=<`);
/// - the solo atoms `[]`, `{}`, `!` and `;`.
///
/// Everything else needs quotes, e.g. the empty atom, `Foo`, `_x`, `12`,
/// `hello world`, `a-b`, `,` and `|`.
fn needs_quoting(s: &str) -> bool {
    const SYMBOL_CHARS: &str = "+-*/\\^<>=~:.?@#&";

    let first = match s.chars().next() {
        Some(c) => c,
        None => return true,
    };
    let is_identifier =
        first.is_ascii_lowercase() && s.chars().all(|c| c.is_alphanumeric() || c == '_');
    let is_symbolic = s.chars().all(|c| SYMBOL_CHARS.contains(c));
    let is_solo = matches!(s, "[]" | "{}" | "!" | ";");
    !(is_identifier || is_symbolic || is_solo)
}
664
665pub fn parse_term(s: &str) -> Option<LpTerm> {
672 let s = s.trim();
673 if s.is_empty() {
674 return None;
675 }
676 parse_term_inner(s)
677}
678
679fn parse_term_inner(s: &str) -> Option<LpTerm> {
680 let s = s.trim();
681
682 if s.starts_with('[') && s.ends_with(']') {
684 return parse_list(&s[1..s.len() - 1]);
685 }
686
687 if s.starts_with('\'') && s.ends_with('\'') && s.len() >= 2 {
689 return Some(LpTerm::Atom(s[1..s.len() - 1].replace("\\'", "'")));
690 }
691
692 if let Some(paren_pos) = find_outer_paren(s) {
694 let functor = s[..paren_pos].trim().to_string();
695 let args_str = &s[paren_pos + 1..s.len() - 1];
696 let args = split_args(args_str)
697 .into_iter()
698 .map(|a| parse_term_inner(a.trim()))
699 .collect::<Option<Vec<_>>>()?;
700 return Some(LpTerm::Compound { functor, args });
701 }
702
703 if let Ok(n) = s.parse::<i64>() {
705 return Some(LpTerm::Integer(n));
706 }
707
708 if let Ok(f) = s.parse::<f64>() {
710 return Some(LpTerm::Float(f));
711 }
712
713 let first = s.chars().next()?;
715 if first.is_uppercase() || first == '_' {
716 return Some(LpTerm::Var(s.to_string()));
717 }
718
719 Some(LpTerm::Atom(s.to_string()))
721}
722
723fn parse_list(inner: &str) -> Option<LpTerm> {
724 let inner = inner.trim();
725 if inner.is_empty() {
726 return Some(LpTerm::atom("[]"));
727 }
728
729 let mut depth = 0i32;
731 let mut bar_pos = None;
732 let bytes = inner.as_bytes();
733 for (i, &b) in bytes.iter().enumerate() {
734 match b {
735 b'(' | b'[' => depth += 1,
736 b')' | b']' => depth -= 1,
737 b'|' if depth == 0 => {
738 bar_pos = Some(i);
739 break;
740 }
741 _ => {}
742 }
743 }
744
745 if let Some(pos) = bar_pos {
746 let items_str = &inner[..pos];
747 let tail_str = inner[pos + 1..].trim();
748 let items = split_args(items_str)
749 .into_iter()
750 .map(|a| parse_term_inner(a.trim()))
751 .collect::<Option<Vec<_>>>()?;
752 let tail = parse_term_inner(tail_str)?;
753 Some(LpTerm::List(items, Some(Box::new(tail))))
754 } else {
755 let items = split_args(inner)
756 .into_iter()
757 .map(|a| parse_term_inner(a.trim()))
758 .collect::<Option<Vec<_>>>()?;
759 Some(LpTerm::list(items))
760 }
761}
762
/// Finds the opening parenthesis of a compound term written as `f(...)`.
///
/// Scans for the first `(` seen while the running parenthesis balance is zero.
/// Returns its byte index if it is preceded by at least one character and `s`
/// ends with `)`, and `None` otherwise.
fn find_outer_paren(s: &str) -> Option<usize> {
    let mut balance = 0i32;
    for (idx, ch) in s.char_indices() {
        if ch == ')' {
            balance -= 1;
        } else if ch == '(' {
            if balance != 0 {
                balance += 1;
                continue;
            }
            let has_functor = idx > 0;
            return if has_functor && s.ends_with(')') {
                Some(idx)
            } else {
                None
            };
        }
    }
    None
}
782
/// Splits `s` on commas that are not nested inside `()` or `[]`.
///
/// Pieces are returned untrimmed. Every piece that ends at a top-level comma
/// is kept, even an empty one. The final piece is dropped when it is blank, so
/// `""` yields no pieces and a trailing comma adds no empty argument.
fn split_args(s: &str) -> Vec<&str> {
    let mut pieces: Vec<&str> = Vec::new();
    let mut nesting = 0i32;
    let mut piece_start = 0usize;
    for (pos, ch) in s.char_indices() {
        if ch == '(' || ch == '[' {
            nesting += 1;
        } else if ch == ')' || ch == ']' {
            nesting -= 1;
        } else if ch == ',' && nesting == 0 {
            pieces.push(&s[piece_start..pos]);
            piece_start = pos + 1;
        }
    }
    let remainder = &s[piece_start..];
    if !remainder.trim().is_empty() {
        pieces.push(remainder);
    }
    pieces
}
808
809pub fn parse_clause(s: &str) -> Option<LpClause> {
811 let s = s.trim().trim_end_matches('.');
812 if let Some(pos) = s.find(":-") {
813 let head_str = s[..pos].trim();
814 let body_str = s[pos + 2..].trim();
815 let head = parse_term(head_str)?;
816 let body = split_args(body_str)
817 .into_iter()
818 .map(|a| parse_term(a.trim()))
819 .collect::<Option<Vec<_>>>()?;
820 Some(LpClause::rule(head, body))
821 } else {
822 let head = parse_term(s)?;
823 Some(LpClause::fact(head))
824 }
825}
826
827pub fn load_standard_predicates(db: &mut LpDatabase) {
831 db.add_fact(LpTerm::compound(
833 "member",
834 vec![
835 LpTerm::var("X"),
836 LpTerm::list_with_tail(vec![LpTerm::var("X")], LpTerm::var("_T")),
837 ],
838 ));
839 db.add_rule(
841 LpTerm::compound(
842 "member",
843 vec![
844 LpTerm::var("X"),
845 LpTerm::list_with_tail(vec![LpTerm::var("_H")], LpTerm::var("T")),
846 ],
847 ),
848 vec![LpTerm::compound(
849 "member",
850 vec![LpTerm::var("X"), LpTerm::var("T")],
851 )],
852 );
853
854 db.add_fact(LpTerm::compound(
856 "append",
857 vec![LpTerm::atom("[]"), LpTerm::var("L"), LpTerm::var("L")],
858 ));
859 db.add_rule(
861 LpTerm::compound(
862 "append",
863 vec![
864 LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("T")),
865 LpTerm::var("L"),
866 LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("R")),
867 ],
868 ),
869 vec![LpTerm::compound(
870 "append",
871 vec![LpTerm::var("T"), LpTerm::var("L"), LpTerm::var("R")],
872 )],
873 );
874
875 db.add_fact(LpTerm::compound(
877 "reverse_acc",
878 vec![LpTerm::atom("[]"), LpTerm::var("Acc"), LpTerm::var("Acc")],
879 ));
880 db.add_rule(
882 LpTerm::compound(
883 "reverse_acc",
884 vec![
885 LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("T")),
886 LpTerm::var("Acc"),
887 LpTerm::var("Rev"),
888 ],
889 ),
890 vec![LpTerm::compound(
891 "reverse_acc",
892 vec![
893 LpTerm::var("T"),
894 LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("Acc")),
895 LpTerm::var("Rev"),
896 ],
897 )],
898 );
899 db.add_rule(
901 LpTerm::compound("reverse", vec![LpTerm::var("L"), LpTerm::var("R")]),
902 vec![LpTerm::compound(
903 "reverse_acc",
904 vec![LpTerm::var("L"), LpTerm::atom("[]"), LpTerm::var("R")],
905 )],
906 );
907
908 db.add_fact(LpTerm::compound(
910 "length",
911 vec![LpTerm::atom("[]"), LpTerm::Integer(0)],
912 ));
913 db.add_rule(
915 LpTerm::compound(
916 "length",
917 vec![
918 LpTerm::list_with_tail(vec![LpTerm::var("_H2")], LpTerm::var("T2")),
919 LpTerm::var("N"),
920 ],
921 ),
922 vec![
923 LpTerm::compound("length", vec![LpTerm::var("T2"), LpTerm::var("N1")]),
924 LpTerm::compound(
925 "is",
926 vec![
927 LpTerm::var("N"),
928 LpTerm::compound("+", vec![LpTerm::var("N1"), LpTerm::Integer(1)]),
929 ],
930 ),
931 ],
932 );
933
934 db.add_fact(LpTerm::compound(
936 "last",
937 vec![LpTerm::list(vec![LpTerm::var("X")]), LpTerm::var("X")],
938 ));
939 db.add_rule(
941 LpTerm::compound(
942 "last",
943 vec![
944 LpTerm::list_with_tail(vec![LpTerm::var("_HL")], LpTerm::var("TL")),
945 LpTerm::var("XL"),
946 ],
947 ),
948 vec![LpTerm::compound(
949 "last",
950 vec![LpTerm::var("TL"), LpTerm::var("XL")],
951 )],
952 );
953
954 db.add_fact(LpTerm::compound("nat", vec![LpTerm::Integer(0)]));
957 db.add_rule(
959 LpTerm::compound("nat", vec![LpTerm::var("N")]),
960 vec![
961 LpTerm::compound("nat", vec![LpTerm::var("N1")]),
962 LpTerm::compound(
963 "is",
964 vec![
965 LpTerm::var("N"),
966 LpTerm::compound("+", vec![LpTerm::var("N1"), LpTerm::Integer(1)]),
967 ],
968 ),
969 ],
970 );
971}
972
#[cfg(test)]
mod tests {
    use super::*;

    fn empty_subst() -> Substitution {
        Substitution::new()
    }

    /// Flattens a (possibly tailed) list term into its elements; `[]` adds nothing.
    fn flatten_list(t: &LpTerm) -> Vec<LpTerm> {
        let mut result = Vec::new();
        flatten_list_into(t, &mut result);
        result
    }

    fn flatten_list_into(t: &LpTerm, out: &mut Vec<LpTerm>) {
        match t {
            LpTerm::Atom(a) if a == "[]" => {}
            LpTerm::List(items, tail) => {
                for item in items {
                    out.push(item.clone());
                }
                if let Some(tl) = tail {
                    flatten_list_into(tl, out);
                }
            }
            _ => out.push(t.clone()),
        }
    }

    fn default_cfg() -> SolveConfig {
        SolveConfig::default()
    }

    fn std_db() -> LpDatabase {
        let mut db = LpDatabase::new();
        load_standard_predicates(&mut db);
        db
    }

    // ---- unification -------------------------------------------------------

    #[test]
    fn test_unify_atoms_equal() {
        let s = unify(&LpTerm::atom("foo"), &LpTerm::atom("foo"), &empty_subst());
        assert!(s.is_some());
    }

    #[test]
    fn test_unify_atoms_different() {
        let s = unify(&LpTerm::atom("foo"), &LpTerm::atom("bar"), &empty_subst());
        assert!(s.is_none());
    }

    #[test]
    fn test_unify_var_atom() {
        let s = unify(&LpTerm::var("X"), &LpTerm::atom("hello"), &empty_subst());
        assert!(s.is_some());
        let s = s.unwrap();
        assert_eq!(s.lookup("X"), Some(&LpTerm::atom("hello")));
    }

    #[test]
    fn test_unify_compound() {
        let t1 = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::Integer(1)]);
        let t2 = LpTerm::compound("f", vec![LpTerm::atom("a"), LpTerm::Integer(1)]);
        let s = unify(&t1, &t2, &empty_subst());
        assert!(s.is_some());
        let s = s.unwrap();
        assert_eq!(s.lookup("X"), Some(&LpTerm::atom("a")));
    }

    #[test]
    fn test_unify_compound_arity_mismatch() {
        let t1 = LpTerm::compound("f", vec![LpTerm::var("X")]);
        let t2 = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::var("Y")]);
        assert!(unify(&t1, &t2, &empty_subst()).is_none());
    }

    #[test]
    fn test_unify_list() {
        let t1 = LpTerm::list(vec![LpTerm::var("X"), LpTerm::Integer(2)]);
        let t2 = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        let s = unify(&t1, &t2, &empty_subst());
        assert!(s.is_some());
        let s = s.unwrap();
        assert_eq!(apply_subst(&LpTerm::var("X"), &s), LpTerm::Integer(1));
    }

    #[test]
    fn test_unify_list_different_length() {
        let t1 = LpTerm::list(vec![LpTerm::Integer(1)]);
        let t2 = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        assert!(unify(&t1, &t2, &empty_subst()).is_none());
    }

    #[test]
    fn test_unify_integers() {
        let s = unify(&LpTerm::Integer(42), &LpTerm::Integer(42), &empty_subst());
        assert!(s.is_some());
        let s = unify(&LpTerm::Integer(1), &LpTerm::Integer(2), &empty_subst());
        assert!(s.is_none());
    }

    // ---- substitution ------------------------------------------------------

    #[test]
    fn test_apply_subst_var() {
        let mut s = Substitution::new();
        s.bind("X", LpTerm::atom("hello"));
        assert_eq!(apply_subst(&LpTerm::var("X"), &s), LpTerm::atom("hello"));
    }

    #[test]
    fn test_apply_subst_compound() {
        let mut s = Substitution::new();
        s.bind("X", LpTerm::Integer(5));
        let t = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::Integer(1)]);
        let result = apply_subst(&t, &s);
        assert_eq!(
            result,
            LpTerm::compound("f", vec![LpTerm::Integer(5), LpTerm::Integer(1)])
        );
    }

    // ---- occurs check ------------------------------------------------------

    #[test]
    fn test_occurs_check_direct() {
        let s = empty_subst();
        assert!(occurs_check("X", &LpTerm::var("X"), &s));
    }

    #[test]
    fn test_occurs_check_in_compound() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("X")]);
        assert!(occurs_check("X", &t, &s));
    }

    #[test]
    fn test_occurs_check_not_present() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("Y")]);
        assert!(!occurs_check("X", &t, &s));
    }

    #[test]
    fn test_occurs_check_prevents_circular() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("X")]);
        let result = unify_with_occurs_check(&LpTerm::var("X"), &t, &s);
        assert!(result.is_none());
    }

    // ---- member/2 ----------------------------------------------------------

    #[test]
    fn test_member_first() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(1),
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(!results.is_empty(), "member(1, [1,2,3]) should succeed");
    }

    #[test]
    fn test_member_middle() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(2),
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_member_not_found() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(99),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    #[test]
    fn test_member_enumerate() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::var("X"),
                LpTerm::list(vec![
                    LpTerm::atom("a"),
                    LpTerm::atom("b"),
                    LpTerm::atom("c"),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 3, "Should enumerate all 3 members");
    }

    // ---- append/3 ----------------------------------------------------------

    #[test]
    fn test_append_concrete() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "append",
            vec![
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
                LpTerm::list(vec![LpTerm::Integer(3)]),
                LpTerm::var("R"),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let r = apply_subst(&LpTerm::var("R"), &results[0]);
        let flat = flatten_list(&r);
        assert_eq!(
            flat,
            vec![LpTerm::Integer(1), LpTerm::Integer(2), LpTerm::Integer(3)]
        );
    }

    #[test]
    fn test_append_split() {
        let db = std_db();
        // append(X, Y, [1,2]) has three splits: []+[1,2], [1]+[2], [1,2]+[].
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "append",
            vec![
                LpTerm::var("X"),
                LpTerm::var("Y"),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 3);
    }

    // ---- reverse/2 ---------------------------------------------------------

    #[test]
    fn test_reverse() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "reverse",
            vec![
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
                LpTerm::var("R"),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let r = apply_subst(&LpTerm::var("R"), &results[0]);
        let flat = flatten_list(&r);
        assert_eq!(
            flat,
            vec![LpTerm::Integer(3), LpTerm::Integer(2), LpTerm::Integer(1)]
        );
    }

    // ---- built-ins ---------------------------------------------------------

    #[test]
    fn test_builtin_true() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::atom("true"));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_builtin_fail() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::atom("fail"));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    #[test]
    fn test_builtin_unify() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "=",
            vec![LpTerm::var("X"), LpTerm::Integer(42)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(42));
    }

    #[test]
    fn test_builtin_is_add() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "is",
            vec![
                LpTerm::var("X"),
                LpTerm::compound("+", vec![LpTerm::Integer(3), LpTerm::Integer(4)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(7));
    }

    #[test]
    fn test_builtin_is_mul() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "is",
            vec![
                LpTerm::var("X"),
                LpTerm::compound("*", vec![LpTerm::Integer(6), LpTerm::Integer(7)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(42));
    }

    #[test]
    fn test_builtin_less_than() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "<",
            vec![LpTerm::Integer(3), LpTerm::Integer(5)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_builtin_less_than_false() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "<",
            vec![LpTerm::Integer(5), LpTerm::Integer(3)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    // ---- negation as failure -----------------------------------------------

    #[test]
    fn test_negation_as_failure() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        // \+ fail succeeds.
        let q = Query::single(LpTerm::compound("\\+", vec![LpTerm::atom("fail")]));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_negation_as_failure_fail() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        // \+ true fails.
        let q = Query::single(LpTerm::compound("\\+", vec![LpTerm::atom("true")]));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    // ---- solve_one ---------------------------------------------------------

    #[test]
    fn test_solve_one_success() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(1),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        match solve_one(&q, &db, &cfg) {
            ResolutionResult::Success(_) => {}
            _ => panic!("Expected success"),
        }
    }

    #[test]
    fn test_solve_one_failure() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(99),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        match solve_one(&q, &db, &cfg) {
            ResolutionResult::Failure => {}
            _ => panic!("Expected failure"),
        }
    }

    // ---- printing ----------------------------------------------------------

    #[test]
    fn test_term_to_string_atom() {
        assert_eq!(term_to_string(&LpTerm::atom("hello")), "hello");
    }

    #[test]
    fn test_term_to_string_var() {
        assert_eq!(term_to_string(&LpTerm::var("X")), "X");
    }

    #[test]
    fn test_term_to_string_integer() {
        assert_eq!(term_to_string(&LpTerm::Integer(42)), "42");
    }

    #[test]
    fn test_term_to_string_compound() {
        let t = LpTerm::compound("f", vec![LpTerm::Integer(1), LpTerm::atom("a")]);
        assert_eq!(term_to_string(&t), "f(1,a)");
    }

    #[test]
    fn test_term_to_string_list() {
        let t = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        assert_eq!(term_to_string(&t), "[1,2]");
    }

    // ---- parsing -----------------------------------------------------------

    #[test]
    fn test_parse_term_atom() {
        assert_eq!(parse_term("foo"), Some(LpTerm::atom("foo")));
    }

    #[test]
    fn test_parse_term_var() {
        assert_eq!(parse_term("X"), Some(LpTerm::var("X")));
    }

    #[test]
    fn test_parse_term_integer() {
        assert_eq!(parse_term("42"), Some(LpTerm::Integer(42)));
    }

    #[test]
    fn test_parse_term_compound() {
        let t = parse_term("f(a,b)");
        assert_eq!(
            t,
            Some(LpTerm::compound(
                "f",
                vec![LpTerm::atom("a"), LpTerm::atom("b")]
            ))
        );
    }

    #[test]
    fn test_parse_list_empty() {
        assert_eq!(parse_term("[]"), Some(LpTerm::atom("[]")));
    }

    #[test]
    fn test_parse_list_items() {
        let t = parse_term("[1,2,3]");
        assert_eq!(
            t,
            Some(LpTerm::list(vec![
                LpTerm::Integer(1),
                LpTerm::Integer(2),
                LpTerm::Integer(3)
            ]))
        );
    }

    #[test]
    fn test_parse_clause_fact() {
        let c = parse_clause("foo(a).");
        assert!(c.is_some());
        let c = c.unwrap();
        assert!(c.is_fact());
    }

    #[test]
    fn test_parse_clause_fact_with_list_tail() {
        // A head-only clause is a fact even when it contains a `[H|T]` list.
        let c = parse_clause("member(X,[X|_]).");
        assert!(c.is_some());
        let c = c.unwrap();
        assert!(c.is_fact());
    }

    #[test]
    fn test_parse_clause_rule() {
        let c = parse_clause("grandparent(X,Z) :- parent(X,Y), parent(Y,Z).");
        assert!(c.is_some());
        let c = c.unwrap();
        assert!(!c.is_fact());
        assert_eq!(
            c.head,
            LpTerm::compound("grandparent", vec![LpTerm::var("X"), LpTerm::var("Z")])
        );
        assert_eq!(c.body.len(), 2);
    }

    // ---- LpDatabase::query_all ---------------------------------------------

    #[test]
    fn test_query_all() {
        let db = std_db();
        let cfg = default_cfg();
        let goal = LpTerm::compound(
            "member",
            vec![
                LpTerm::var("X"),
                LpTerm::list(vec![LpTerm::atom("a"), LpTerm::atom("b")]),
            ],
        );
        let results = db.query_all(goal, &cfg);
        assert_eq!(results.len(), 2);
    }
}