Skip to main content

oxilean_std/logic_programming/
functions.rs

1//! Functions for Prolog-style logic programming: unification, SLD resolution, parsing.
2
3use std::collections::HashMap;
4
5use super::types::{
6    functor_arity, LpClause, LpDatabase, LpTerm, Query, ResolutionResult, SolveConfig, Substitution,
7};
8
9// ── Variable renaming ─────────────────────────────────────────────────────────
10
11/// Rename all variables in a clause with a unique suffix to avoid name collisions.
12fn rename_clause(clause: &LpClause, stamp: usize) -> LpClause {
13    let suffix = format!("_{stamp}");
14    LpClause {
15        head: rename_term(&clause.head, &suffix),
16        body: clause
17            .body
18            .iter()
19            .map(|t| rename_term(t, &suffix))
20            .collect(),
21    }
22}
23
24fn rename_term(t: &LpTerm, suffix: &str) -> LpTerm {
25    match t {
26        LpTerm::Var(v) => LpTerm::Var(format!("{v}{suffix}")),
27        LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => t.clone(),
28        LpTerm::Compound { functor, args } => LpTerm::Compound {
29            functor: functor.clone(),
30            args: args.iter().map(|a| rename_term(a, suffix)).collect(),
31        },
32        LpTerm::List(items, tail) => LpTerm::List(
33            items.iter().map(|a| rename_term(a, suffix)).collect(),
34            tail.as_ref().map(|t| Box::new(rename_term(t, suffix))),
35        ),
36    }
37}
38
39// ── Occurs check ──────────────────────────────────────────────────────────────
40
41/// Check whether variable `var` occurs in `term` under substitution `subst`.
42///
43/// Used to detect circular bindings (e.g., X = f(X)).
44pub fn occurs_check(var: &str, term: &LpTerm, subst: &Substitution) -> bool {
45    match term {
46        LpTerm::Var(v) => {
47            if v == var {
48                return true;
49            }
50            match subst.lookup(v) {
51                Some(t) => occurs_check(var, &t.clone(), subst),
52                None => false,
53            }
54        }
55        LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => false,
56        LpTerm::Compound { args, .. } => args.iter().any(|a| occurs_check(var, a, subst)),
57        LpTerm::List(items, tail) => {
58            items.iter().any(|a| occurs_check(var, a, subst))
59                || tail.as_ref().map_or(false, |t| occurs_check(var, t, subst))
60        }
61    }
62}
63
64// ── Apply substitution ────────────────────────────────────────────────────────
65
66/// Apply a substitution to a term, recursively dereferencing variables.
67pub fn apply_subst(term: &LpTerm, subst: &Substitution) -> LpTerm {
68    match term {
69        LpTerm::Var(v) => match subst.lookup(v) {
70            None => term.clone(),
71            Some(t) => {
72                let t2 = t.clone();
73                // Avoid infinite recursion on self-referential vars
74                if t2 == LpTerm::Var(v.clone()) {
75                    t2
76                } else {
77                    apply_subst(&t2, subst)
78                }
79            }
80        },
81        LpTerm::Atom(_) | LpTerm::Integer(_) | LpTerm::Float(_) => term.clone(),
82        LpTerm::Compound { functor, args } => LpTerm::Compound {
83            functor: functor.clone(),
84            args: args.iter().map(|a| apply_subst(a, subst)).collect(),
85        },
86        LpTerm::List(items, tail) => LpTerm::List(
87            items.iter().map(|a| apply_subst(a, subst)).collect(),
88            tail.as_ref().map(|t| Box::new(apply_subst(t, subst))),
89        ),
90    }
91}
92
93// ── Unification ───────────────────────────────────────────────────────────────
94
95/// Compute the most general unifier of `t1` and `t2` under `subst`.
96///
97/// Returns `Some(new_subst)` on success, `None` on failure.
98pub fn unify(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
99    let t1 = apply_subst(t1, subst);
100    let t2 = apply_subst(t2, subst);
101    unify_walked(&t1, &t2, subst)
102}
103
104fn unify_walked(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
105    match (t1, t2) {
106        // Two identical atoms/ints/floats
107        (LpTerm::Atom(a), LpTerm::Atom(b)) if a == b => Some(subst.clone()),
108        (LpTerm::Integer(a), LpTerm::Integer(b)) if a == b => Some(subst.clone()),
109        (LpTerm::Float(a), LpTerm::Float(b)) if a == b => Some(subst.clone()),
110
111        // Variable on left
112        (LpTerm::Var(v), t) => {
113            let t = apply_subst(t, subst);
114            if let LpTerm::Var(v2) = &t {
115                if v == v2 {
116                    return Some(subst.clone());
117                }
118            }
119            // No occurs check by default; check only if needed
120            let mut new_subst = subst.clone();
121            new_subst.bind(v.clone(), t);
122            Some(new_subst)
123        }
124
125        // Variable on right
126        (t, LpTerm::Var(v)) => {
127            let t = apply_subst(t, subst);
128            let mut new_subst = subst.clone();
129            new_subst.bind(v.clone(), t);
130            Some(new_subst)
131        }
132
133        // Compound terms
134        (
135            LpTerm::Compound {
136                functor: f1,
137                args: a1,
138            },
139            LpTerm::Compound {
140                functor: f2,
141                args: a2,
142            },
143        ) => {
144            if f1 != f2 || a1.len() != a2.len() {
145                return None;
146            }
147            let mut s = subst.clone();
148            for (x, y) in a1.iter().zip(a2.iter()) {
149                s = unify(x, y, &s)?;
150            }
151            Some(s)
152        }
153
154        // Lists
155        (LpTerm::List(items1, tail1), LpTerm::List(items2, tail2)) => {
156            unify_lists(items1, tail1.as_deref(), items2, tail2.as_deref(), subst)
157        }
158
159        // Atom as nil vs empty list
160        (LpTerm::Atom(a), LpTerm::List(items, None)) if a == "[]" && items.is_empty() => {
161            Some(subst.clone())
162        }
163        (LpTerm::List(items, None), LpTerm::Atom(a)) if a == "[]" && items.is_empty() => {
164            Some(subst.clone())
165        }
166
167        _ => None,
168    }
169}
170
171fn unify_lists(
172    items1: &[LpTerm],
173    tail1: Option<&LpTerm>,
174    items2: &[LpTerm],
175    tail2: Option<&LpTerm>,
176    subst: &Substitution,
177) -> Option<Substitution> {
178    match (items1, items2) {
179        ([], []) => {
180            // Both exhausted — unify the tails
181            match (tail1, tail2) {
182                (None, None) => Some(subst.clone()),
183                (Some(t1), Some(t2)) => unify(t1, t2, subst),
184                (None, Some(t)) => unify(&LpTerm::atom("[]"), t, subst),
185                (Some(t), None) => unify(t, &LpTerm::atom("[]"), subst),
186            }
187        }
188        ([], _) => {
189            // items1 exhausted; items2 has remaining — tail1 must unify with remaining list
190            let rest = LpTerm::List(items2.to_vec(), tail2.cloned().map(|t| Box::new(t.clone())));
191            match tail1 {
192                None => None, // proper list vs longer list
193                Some(t) => unify(t, &rest, subst),
194            }
195        }
196        (_, []) => {
197            // items2 exhausted; items1 has remaining — tail2 must unify with remaining list
198            let rest = LpTerm::List(items1.to_vec(), tail1.cloned().map(|t| Box::new(t.clone())));
199            match tail2 {
200                None => None,
201                Some(t) => unify(&rest, t, subst),
202            }
203        }
204        ([h1, rest1 @ ..], [h2, rest2 @ ..]) => {
205            let s = unify(h1, h2, subst)?;
206            unify_lists(rest1, tail1, rest2, tail2, &s)
207        }
208    }
209}
210
211// ── Unification with occurs check ─────────────────────────────────────────────
212
213/// Unify with the occurs check enabled (sound but slower).
214pub fn unify_with_occurs_check(
215    t1: &LpTerm,
216    t2: &LpTerm,
217    subst: &Substitution,
218) -> Option<Substitution> {
219    let t1w = apply_subst(t1, subst);
220    let t2w = apply_subst(t2, subst);
221    unify_oc_walked(&t1w, &t2w, subst)
222}
223
224fn unify_oc_walked(t1: &LpTerm, t2: &LpTerm, subst: &Substitution) -> Option<Substitution> {
225    match (t1, t2) {
226        (LpTerm::Atom(a), LpTerm::Atom(b)) if a == b => Some(subst.clone()),
227        (LpTerm::Integer(a), LpTerm::Integer(b)) if a == b => Some(subst.clone()),
228        (LpTerm::Float(a), LpTerm::Float(b)) if a == b => Some(subst.clone()),
229
230        (LpTerm::Var(v), t) => {
231            let t = apply_subst(t, subst);
232            if let LpTerm::Var(v2) = &t {
233                if v == v2 {
234                    return Some(subst.clone());
235                }
236            }
237            if occurs_check(v, &t, subst) {
238                return None;
239            }
240            let mut s = subst.clone();
241            s.bind(v.clone(), t);
242            Some(s)
243        }
244
245        (t, LpTerm::Var(v)) => {
246            let t = apply_subst(t, subst);
247            if occurs_check(v, &t, subst) {
248                return None;
249            }
250            let mut s = subst.clone();
251            s.bind(v.clone(), t);
252            Some(s)
253        }
254
255        (
256            LpTerm::Compound {
257                functor: f1,
258                args: a1,
259            },
260            LpTerm::Compound {
261                functor: f2,
262                args: a2,
263            },
264        ) => {
265            if f1 != f2 || a1.len() != a2.len() {
266                return None;
267            }
268            let mut s = subst.clone();
269            for (x, y) in a1.iter().zip(a2.iter()) {
270                s = unify_with_occurs_check(x, y, &s)?;
271            }
272            Some(s)
273        }
274
275        (LpTerm::List(i1, t1), LpTerm::List(i2, t2)) => {
276            unify_lists(i1, t1.as_deref(), i2, t2.as_deref(), subst)
277        }
278
279        (LpTerm::Atom(a), LpTerm::List(items, None)) if a == "[]" && items.is_empty() => {
280            Some(subst.clone())
281        }
282        (LpTerm::List(items, None), LpTerm::Atom(a)) if a == "[]" && items.is_empty() => {
283            Some(subst.clone())
284        }
285
286        _ => None,
287    }
288}
289
290// ── SLD Resolution ────────────────────────────────────────────────────────────
291
292/// Collect all solutions to a query using SLD resolution.
293pub fn resolve(query: &Query, db: &LpDatabase, cfg: &SolveConfig) -> Vec<Substitution> {
294    let mut results = Vec::new();
295    let mut counter = 0usize;
296    sld_resolve(
297        query.goals.clone(),
298        &Substitution::new(),
299        db,
300        cfg,
301        0,
302        &mut counter,
303        &mut results,
304    );
305    results
306}
307
308/// Solve the query, returning the first solution or failure.
309pub fn solve_one(query: &Query, db: &LpDatabase, cfg: &SolveConfig) -> ResolutionResult {
310    let mut results = Vec::new();
311    let mut counter = 0usize;
312    let one_cfg = SolveConfig {
313        max_solutions: 1,
314        ..cfg.clone()
315    };
316    sld_resolve(
317        query.goals.clone(),
318        &Substitution::new(),
319        db,
320        &one_cfg,
321        0,
322        &mut counter,
323        &mut results,
324    );
325    match results.into_iter().next() {
326        Some(s) => ResolutionResult::Success(s),
327        None => ResolutionResult::Failure,
328    }
329}
330
331fn sld_resolve(
332    goals: Vec<LpTerm>,
333    subst: &Substitution,
334    db: &LpDatabase,
335    cfg: &SolveConfig,
336    depth: usize,
337    counter: &mut usize,
338    results: &mut Vec<Substitution>,
339) {
340    if results.len() >= cfg.max_solutions {
341        return;
342    }
343    if depth > cfg.max_depth {
344        return;
345    }
346
347    if goals.is_empty() {
348        results.push(subst.clone());
349        return;
350    }
351
352    let goal = apply_subst(&goals[0], subst);
353    let rest_goals = goals[1..].to_vec();
354
355    // Built-in predicates
356    if handle_builtin(&goal, subst, db, cfg, depth, counter, results, &rest_goals) {
357        return;
358    }
359
360    // User-defined predicates
361    let matching: Vec<LpClause> = db.matching_clauses(&goal).into_iter().cloned().collect();
362
363    for clause in &matching {
364        if results.len() >= cfg.max_solutions {
365            break;
366        }
367        *counter += 1;
368        let renamed = rename_clause(clause, *counter);
369        let unifier = if cfg.occurs_check {
370            unify_with_occurs_check(&goal, &renamed.head, subst)
371        } else {
372            unify(&goal, &renamed.head, subst)
373        };
374        if let Some(new_subst) = unifier {
375            let mut new_goals = renamed.body.clone();
376            new_goals.extend(rest_goals.clone());
377            sld_resolve(new_goals, &new_subst, db, cfg, depth + 1, counter, results);
378        }
379    }
380}
381
382/// Handle built-in predicates. Returns true if the goal was handled (even if it failed).
383fn handle_builtin(
384    goal: &LpTerm,
385    subst: &Substitution,
386    db: &LpDatabase,
387    cfg: &SolveConfig,
388    depth: usize,
389    counter: &mut usize,
390    results: &mut Vec<Substitution>,
391    rest_goals: &[LpTerm],
392) -> bool {
393    match goal {
394        // true/0 — always succeeds
395        LpTerm::Atom(a) if a == "true" => {
396            sld_resolve(
397                rest_goals.to_vec(),
398                subst,
399                db,
400                cfg,
401                depth + 1,
402                counter,
403                results,
404            );
405            true
406        }
407        // fail/0 — always fails
408        LpTerm::Atom(a) if a == "fail" || a == "false" => true,
409        // =(X, Y) — unification
410        LpTerm::Compound { functor, args } if functor == "=" && args.len() == 2 => {
411            if let Some(s) = unify(&args[0], &args[1], subst) {
412                sld_resolve(
413                    rest_goals.to_vec(),
414                    &s,
415                    db,
416                    cfg,
417                    depth + 1,
418                    counter,
419                    results,
420                );
421            }
422            true
423        }
424        // \=(X, Y) — negation of unification
425        LpTerm::Compound { functor, args } if functor == "\\=" && args.len() == 2 => {
426            if unify(&args[0], &args[1], subst).is_none() {
427                sld_resolve(
428                    rest_goals.to_vec(),
429                    subst,
430                    db,
431                    cfg,
432                    depth + 1,
433                    counter,
434                    results,
435                );
436            }
437            true
438        }
439        // is/2 — arithmetic evaluation (limited)
440        LpTerm::Compound { functor, args } if functor == "is" && args.len() == 2 => {
441            if let Some(val) = eval_arith(&args[1], subst) {
442                if let Some(s) = unify(&args[0], &val, subst) {
443                    sld_resolve(
444                        rest_goals.to_vec(),
445                        &s,
446                        db,
447                        cfg,
448                        depth + 1,
449                        counter,
450                        results,
451                    );
452                }
453            }
454            true
455        }
456        // =:=/2 — arithmetic equality
457        LpTerm::Compound { functor, args } if functor == "=:=" && args.len() == 2 => {
458            let v1 = eval_arith(&args[0], subst);
459            let v2 = eval_arith(&args[1], subst);
460            if v1 == v2 && v1.is_some() {
461                sld_resolve(
462                    rest_goals.to_vec(),
463                    subst,
464                    db,
465                    cfg,
466                    depth + 1,
467                    counter,
468                    results,
469                );
470            }
471            true
472        }
473        // </2, >/2, =</2, >=/2 — arithmetic comparison
474        LpTerm::Compound { functor, args }
475            if (functor == "<" || functor == ">" || functor == "=<" || functor == ">=")
476                && args.len() == 2 =>
477        {
478            let v1 = eval_arith_f64(&args[0], subst);
479            let v2 = eval_arith_f64(&args[1], subst);
480            let ok = match (v1, v2) {
481                (Some(a), Some(b)) => match functor.as_str() {
482                    "<" => a < b,
483                    ">" => a > b,
484                    "=<" => a <= b,
485                    ">=" => a >= b,
486                    _ => false,
487                },
488                _ => false,
489            };
490            if ok {
491                sld_resolve(
492                    rest_goals.to_vec(),
493                    subst,
494                    db,
495                    cfg,
496                    depth + 1,
497                    counter,
498                    results,
499                );
500            }
501            true
502        }
503        // not/1, \+/1 — negation as failure
504        LpTerm::Compound { functor, args }
505            if (functor == "not" || functor == "\\+") && args.len() == 1 =>
506        {
507            let inner_q = Query::single(args[0].clone());
508            let inner_cfg = SolveConfig {
509                max_solutions: 1,
510                ..cfg.clone()
511            };
512            let inner_results = resolve(&inner_q, db, &inner_cfg);
513            if inner_results.is_empty() {
514                sld_resolve(
515                    rest_goals.to_vec(),
516                    subst,
517                    db,
518                    cfg,
519                    depth + 1,
520                    counter,
521                    results,
522                );
523            }
524            true
525        }
526        // call/1 — meta-call
527        LpTerm::Compound { functor, args } if functor == "call" && args.len() == 1 => {
528            let new_goal = apply_subst(&args[0], subst);
529            let mut new_goals = vec![new_goal];
530            new_goals.extend(rest_goals.to_vec());
531            sld_resolve(new_goals, subst, db, cfg, depth + 1, counter, results);
532            true
533        }
534        _ => false,
535    }
536}
537
538/// Evaluate an arithmetic expression to an integer term.
539fn eval_arith(t: &LpTerm, subst: &Substitution) -> Option<LpTerm> {
540    let t = apply_subst(t, subst);
541    match &t {
542        LpTerm::Integer(n) => Some(LpTerm::Integer(*n)),
543        LpTerm::Float(f) => Some(LpTerm::Float(*f)),
544        LpTerm::Compound { functor, args } if args.len() == 2 => {
545            let a = eval_arith_f64(&args[0], subst)?;
546            let b = eval_arith_f64(&args[1], subst)?;
547            let result = match functor.as_str() {
548                "+" => a + b,
549                "-" => a - b,
550                "*" => a * b,
551                "/" => {
552                    if b == 0.0 {
553                        return None;
554                    }
555                    a / b
556                }
557                "mod" => {
558                    if b == 0.0 {
559                        return None;
560                    }
561                    a % b
562                }
563                "**" | "^" => a.powf(b),
564                _ => return None,
565            };
566            // Return integer if both operands were integers and result is whole
567            if result.fract() == 0.0 && functor != "/" {
568                Some(LpTerm::Integer(result as i64))
569            } else {
570                Some(LpTerm::Float(result))
571            }
572        }
573        LpTerm::Compound { functor, args } if args.len() == 1 => {
574            let a = eval_arith_f64(&args[0], subst)?;
575            let result = match functor.as_str() {
576                "abs" => a.abs(),
577                "sqrt" => a.sqrt(),
578                "floor" => a.floor(),
579                "ceiling" => a.ceil(),
580                "round" => a.round(),
581                "-" => -a,
582                _ => return None,
583            };
584            if result.fract() == 0.0 {
585                Some(LpTerm::Integer(result as i64))
586            } else {
587                Some(LpTerm::Float(result))
588            }
589        }
590        _ => None,
591    }
592}
593
594fn eval_arith_f64(t: &LpTerm, subst: &Substitution) -> Option<f64> {
595    match eval_arith(t, subst)? {
596        LpTerm::Integer(n) => Some(n as f64),
597        LpTerm::Float(f) => Some(f),
598        _ => None,
599    }
600}
601
602// ── LpDatabase methods (query_all) ────────────────────────────────────────────
603
604impl LpDatabase {
605    /// Collect all solutions to a single-goal query.
606    pub fn query_all(&self, goal: LpTerm, cfg: &SolveConfig) -> Vec<Substitution> {
607        let q = Query::single(goal);
608        resolve(&q, self, cfg)
609    }
610}
611
612// ── Term pretty-printing ──────────────────────────────────────────────────────
613
614/// Pretty-print a term to a Prolog-style string.
615pub fn term_to_string(t: &LpTerm) -> String {
616    match t {
617        LpTerm::Atom(s) => {
618            // Quote atoms that need it
619            if needs_quoting(s) {
620                format!("'{}'", s.replace('\'', "\\'"))
621            } else {
622                s.clone()
623            }
624        }
625        LpTerm::Var(v) => v.clone(),
626        LpTerm::Integer(n) => n.to_string(),
627        LpTerm::Float(f) => format!("{f}"),
628        LpTerm::Compound { functor, args } => {
629            let args_str: Vec<String> = args.iter().map(term_to_string).collect();
630            format!("{}({})", functor, args_str.join(","))
631        }
632        LpTerm::List(items, tail) => {
633            let items_str: Vec<String> = items.iter().map(term_to_string).collect();
634            let body = items_str.join(",");
635            match tail {
636                None => format!("[{body}]"),
637                Some(t) => format!("[{body}|{}]", term_to_string(t)),
638            }
639        }
640    }
641}
642
/// Return `true` if the atom text `s` must be written between single quotes to read back
/// as the same atom.
///
/// These forms are written bare:
/// - a lowercase ASCII letter followed by letters, digits or `_` (e.g. `foo_Bar1`);
/// - a run of symbol characters (e.g. `=<`, `+`);
/// - the solo atoms `[]`, `{}`, `!` and `;`.
///
/// Everything else is quoted. That includes the empty atom, text starting with an
/// uppercase letter, `_` or a digit, and text containing spaces or quotes.
fn needs_quoting(s: &str) -> bool {
    const SYMBOL_CHARS: &str = "+-*/\\^<>=~:.?@#&";

    let first = match s.chars().next() {
        Some(c) => c,
        None => return true,
    };
    if first.is_ascii_lowercase() && s.chars().all(|c| c.is_alphanumeric() || c == '_') {
        return false;
    }
    if s.chars().all(|c| SYMBOL_CHARS.contains(c)) {
        return false;
    }
    // `,` and `|` are deliberately not in this list: bare, they would be read as the
    // argument and list-tail separators.
    !matches!(s, "[]" | "{}" | "!" | ";")
}
664
665// ── Simple term parser ────────────────────────────────────────────────────────
666
667/// Parse a simple Prolog term from a string.
668///
669/// Supports atoms, variables, integers, floats, compound terms, and lists.
670/// Does not support full operator syntax.
671pub fn parse_term(s: &str) -> Option<LpTerm> {
672    let s = s.trim();
673    if s.is_empty() {
674        return None;
675    }
676    parse_term_inner(s)
677}
678
679fn parse_term_inner(s: &str) -> Option<LpTerm> {
680    let s = s.trim();
681
682    // List
683    if s.starts_with('[') && s.ends_with(']') {
684        return parse_list(&s[1..s.len() - 1]);
685    }
686
687    // Quoted atom
688    if s.starts_with('\'') && s.ends_with('\'') && s.len() >= 2 {
689        return Some(LpTerm::Atom(s[1..s.len() - 1].replace("\\'", "'")));
690    }
691
692    // Check for compound: find the outermost '('
693    if let Some(paren_pos) = find_outer_paren(s) {
694        let functor = s[..paren_pos].trim().to_string();
695        let args_str = &s[paren_pos + 1..s.len() - 1];
696        let args = split_args(args_str)
697            .into_iter()
698            .map(|a| parse_term_inner(a.trim()))
699            .collect::<Option<Vec<_>>>()?;
700        return Some(LpTerm::Compound { functor, args });
701    }
702
703    // Integer
704    if let Ok(n) = s.parse::<i64>() {
705        return Some(LpTerm::Integer(n));
706    }
707
708    // Float
709    if let Ok(f) = s.parse::<f64>() {
710        return Some(LpTerm::Float(f));
711    }
712
713    // Variable: starts with uppercase or '_'
714    let first = s.chars().next()?;
715    if first.is_uppercase() || first == '_' {
716        return Some(LpTerm::Var(s.to_string()));
717    }
718
719    // Atom
720    Some(LpTerm::Atom(s.to_string()))
721}
722
723fn parse_list(inner: &str) -> Option<LpTerm> {
724    let inner = inner.trim();
725    if inner.is_empty() {
726        return Some(LpTerm::atom("[]"));
727    }
728
729    // Find '|' at depth 0
730    let mut depth = 0i32;
731    let mut bar_pos = None;
732    let bytes = inner.as_bytes();
733    for (i, &b) in bytes.iter().enumerate() {
734        match b {
735            b'(' | b'[' => depth += 1,
736            b')' | b']' => depth -= 1,
737            b'|' if depth == 0 => {
738                bar_pos = Some(i);
739                break;
740            }
741            _ => {}
742        }
743    }
744
745    if let Some(pos) = bar_pos {
746        let items_str = &inner[..pos];
747        let tail_str = inner[pos + 1..].trim();
748        let items = split_args(items_str)
749            .into_iter()
750            .map(|a| parse_term_inner(a.trim()))
751            .collect::<Option<Vec<_>>>()?;
752        let tail = parse_term_inner(tail_str)?;
753        Some(LpTerm::List(items, Some(Box::new(tail))))
754    } else {
755        let items = split_args(inner)
756            .into_iter()
757            .map(|a| parse_term_inner(a.trim()))
758            .collect::<Option<Vec<_>>>()?;
759        Some(LpTerm::list(items))
760    }
761}
762
/// Locate the `(` that opens the argument list of a compound term such as `f(...)`.
///
/// Scans left to right, tracking parenthesis nesting (a stray `)` lowers the level), and
/// stops at the first `(` found at level zero. That `(`'s byte offset is returned when it
/// is not the first character, so a functor precedes it, and `s` ends with `)`. Otherwise
/// the result is `None`.
fn find_outer_paren(s: &str) -> Option<usize> {
    let mut nesting = 0i32;
    for (idx, ch) in s.char_indices() {
        if ch == ')' {
            nesting -= 1;
        } else if ch == '(' {
            if nesting == 0 {
                let has_functor_and_close = idx > 0 && s.ends_with(')');
                return if has_functor_and_close { Some(idx) } else { None };
            }
            nesting += 1;
        }
    }
    None
}
782
/// Split `s` at every comma that is not nested inside `(...)` or `[...]`.
///
/// Pieces are returned untrimmed. The final piece is dropped when it is blank, so `""`
/// yields no pieces and `"a,"` yields just `"a"`; earlier empty pieces are kept.
fn split_args(s: &str) -> Vec<&str> {
    let mut nesting = 0i32;
    let cut_points: Vec<usize> = s
        .bytes()
        .enumerate()
        .filter_map(|(idx, byte)| {
            match byte {
                b'(' | b'[' => nesting += 1,
                b')' | b']' => nesting -= 1,
                b',' if nesting == 0 => return Some(idx),
                _ => {}
            }
            None
        })
        .collect();

    let mut pieces = Vec::with_capacity(cut_points.len() + 1);
    let mut piece_start = 0;
    for cut in cut_points {
        pieces.push(&s[piece_start..cut]);
        piece_start = cut + 1;
    }
    let last_piece = &s[piece_start..];
    if !last_piece.trim().is_empty() {
        pieces.push(last_piece);
    }
    pieces
}
808
809/// Parse a Horn clause from a string: `head :- b1, b2.` or `head.`
810pub fn parse_clause(s: &str) -> Option<LpClause> {
811    let s = s.trim().trim_end_matches('.');
812    if let Some(pos) = s.find(":-") {
813        let head_str = s[..pos].trim();
814        let body_str = s[pos + 2..].trim();
815        let head = parse_term(head_str)?;
816        let body = split_args(body_str)
817            .into_iter()
818            .map(|a| parse_term(a.trim()))
819            .collect::<Option<Vec<_>>>()?;
820        Some(LpClause::rule(head, body))
821    } else {
822        let head = parse_term(s)?;
823        Some(LpClause::fact(head))
824    }
825}
826
827// ── Classic Prolog library predicates ────────────────────────────────────────
828
829/// Populate a database with classic Prolog predicates: member/2, append/3, reverse/3, length/2, last/2.
830pub fn load_standard_predicates(db: &mut LpDatabase) {
831    // member(X, [X|_]).
832    db.add_fact(LpTerm::compound(
833        "member",
834        vec![
835            LpTerm::var("X"),
836            LpTerm::list_with_tail(vec![LpTerm::var("X")], LpTerm::var("_T")),
837        ],
838    ));
839    // member(X, [_|T]) :- member(X, T).
840    db.add_rule(
841        LpTerm::compound(
842            "member",
843            vec![
844                LpTerm::var("X"),
845                LpTerm::list_with_tail(vec![LpTerm::var("_H")], LpTerm::var("T")),
846            ],
847        ),
848        vec![LpTerm::compound(
849            "member",
850            vec![LpTerm::var("X"), LpTerm::var("T")],
851        )],
852    );
853
854    // append([], L, L).
855    db.add_fact(LpTerm::compound(
856        "append",
857        vec![LpTerm::atom("[]"), LpTerm::var("L"), LpTerm::var("L")],
858    ));
859    // append([H|T], L, [H|R]) :- append(T, L, R).
860    db.add_rule(
861        LpTerm::compound(
862            "append",
863            vec![
864                LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("T")),
865                LpTerm::var("L"),
866                LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("R")),
867            ],
868        ),
869        vec![LpTerm::compound(
870            "append",
871            vec![LpTerm::var("T"), LpTerm::var("L"), LpTerm::var("R")],
872        )],
873    );
874
875    // reverse([], Acc, Acc).
876    db.add_fact(LpTerm::compound(
877        "reverse_acc",
878        vec![LpTerm::atom("[]"), LpTerm::var("Acc"), LpTerm::var("Acc")],
879    ));
880    // reverse_acc([H|T], Acc, Rev) :- reverse_acc(T, [H|Acc], Rev).
881    db.add_rule(
882        LpTerm::compound(
883            "reverse_acc",
884            vec![
885                LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("T")),
886                LpTerm::var("Acc"),
887                LpTerm::var("Rev"),
888            ],
889        ),
890        vec![LpTerm::compound(
891            "reverse_acc",
892            vec![
893                LpTerm::var("T"),
894                LpTerm::list_with_tail(vec![LpTerm::var("H")], LpTerm::var("Acc")),
895                LpTerm::var("Rev"),
896            ],
897        )],
898    );
899    // reverse(L, R) :- reverse_acc(L, [], R).
900    db.add_rule(
901        LpTerm::compound("reverse", vec![LpTerm::var("L"), LpTerm::var("R")]),
902        vec![LpTerm::compound(
903            "reverse_acc",
904            vec![LpTerm::var("L"), LpTerm::atom("[]"), LpTerm::var("R")],
905        )],
906    );
907
908    // length([], 0).
909    db.add_fact(LpTerm::compound(
910        "length",
911        vec![LpTerm::atom("[]"), LpTerm::Integer(0)],
912    ));
913    // length([_|T], N) :- length(T, N1), N is N1 + 1.
914    db.add_rule(
915        LpTerm::compound(
916            "length",
917            vec![
918                LpTerm::list_with_tail(vec![LpTerm::var("_H2")], LpTerm::var("T2")),
919                LpTerm::var("N"),
920            ],
921        ),
922        vec![
923            LpTerm::compound("length", vec![LpTerm::var("T2"), LpTerm::var("N1")]),
924            LpTerm::compound(
925                "is",
926                vec![
927                    LpTerm::var("N"),
928                    LpTerm::compound("+", vec![LpTerm::var("N1"), LpTerm::Integer(1)]),
929                ],
930            ),
931        ],
932    );
933
934    // last([X], X).
935    db.add_fact(LpTerm::compound(
936        "last",
937        vec![LpTerm::list(vec![LpTerm::var("X")]), LpTerm::var("X")],
938    ));
939    // last([_|T], X) :- last(T, X).
940    db.add_rule(
941        LpTerm::compound(
942            "last",
943            vec![
944                LpTerm::list_with_tail(vec![LpTerm::var("_HL")], LpTerm::var("TL")),
945                LpTerm::var("XL"),
946            ],
947        ),
948        vec![LpTerm::compound(
949            "last",
950            vec![LpTerm::var("TL"), LpTerm::var("XL")],
951        )],
952    );
953
954    // nat/1 — natural number generator (bounded by depth)
955    // nat(0).
956    db.add_fact(LpTerm::compound("nat", vec![LpTerm::Integer(0)]));
957    // nat(N) :- nat(N1), N is N1 + 1.
958    db.add_rule(
959        LpTerm::compound("nat", vec![LpTerm::var("N")]),
960        vec![
961            LpTerm::compound("nat", vec![LpTerm::var("N1")]),
962            LpTerm::compound(
963                "is",
964                vec![
965                    LpTerm::var("N"),
966                    LpTerm::compound("+", vec![LpTerm::var("N1"), LpTerm::Integer(1)]),
967                ],
968            ),
969        ],
970    );
971}
972
973// ─────────────────────────────────────────────────────────────────────────────
974// Tests
975// ─────────────────────────────────────────────────────────────────────────────
976
#[cfg(test)]
mod tests {
    //! Unit tests for unification, the occurs check, SLD resolution over the
    //! standard predicate library, built-in goals, pretty-printing, and parsing.

    use super::*;

    /// A fresh, empty substitution for unification tests.
    fn empty_subst() -> Substitution {
        Substitution::new()
    }

    /// Flatten a (possibly nested) list representation into a Vec of elements.
    ///
    /// The `[]` atom contributes nothing; a `List(items, tail)` contributes its
    /// items followed by the flattened tail; any other term (e.g. an unbound
    /// tail variable) is pushed as a single element.
    fn flatten_list(t: &LpTerm) -> Vec<LpTerm> {
        let mut result = Vec::new();
        flatten_list_into(t, &mut result);
        result
    }

    /// Recursive worker for [`flatten_list`]; appends elements to `out`.
    fn flatten_list_into(t: &LpTerm, out: &mut Vec<LpTerm>) {
        match t {
            LpTerm::Atom(a) if a == "[]" => {}
            LpTerm::List(items, tail) => {
                out.extend(items.iter().cloned());
                if let Some(tl) = tail {
                    flatten_list_into(tl, out);
                }
            }
            _ => out.push(t.clone()),
        }
    }

    /// Default solver configuration shared by the resolution tests.
    fn default_cfg() -> SolveConfig {
        SolveConfig::default()
    }

    /// A database preloaded with `load_standard_predicates`.
    fn std_db() -> LpDatabase {
        let mut db = LpDatabase::new();
        load_standard_predicates(&mut db);
        db
    }

    // ── Unification tests ────────────────────────────────────────────────────

    #[test]
    fn test_unify_atoms_equal() {
        let s = unify(&LpTerm::atom("foo"), &LpTerm::atom("foo"), &empty_subst());
        assert!(s.is_some());
    }

    #[test]
    fn test_unify_atoms_different() {
        let s = unify(&LpTerm::atom("foo"), &LpTerm::atom("bar"), &empty_subst());
        assert!(s.is_none());
    }

    #[test]
    fn test_unify_var_atom() {
        let s = unify(&LpTerm::var("X"), &LpTerm::atom("hello"), &empty_subst())
            .expect("a variable should unify with an atom");
        assert_eq!(s.lookup("X"), Some(&LpTerm::atom("hello")));
    }

    #[test]
    fn test_unify_compound() {
        let t1 = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::Integer(1)]);
        let t2 = LpTerm::compound("f", vec![LpTerm::atom("a"), LpTerm::Integer(1)]);
        let s = unify(&t1, &t2, &empty_subst()).expect("f(X,1) should unify with f(a,1)");
        assert_eq!(s.lookup("X"), Some(&LpTerm::atom("a")));
    }

    #[test]
    fn test_unify_compound_arity_mismatch() {
        let t1 = LpTerm::compound("f", vec![LpTerm::var("X")]);
        let t2 = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::var("Y")]);
        assert!(unify(&t1, &t2, &empty_subst()).is_none());
    }

    #[test]
    fn test_unify_list() {
        let t1 = LpTerm::list(vec![LpTerm::var("X"), LpTerm::Integer(2)]);
        let t2 = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        let s = unify(&t1, &t2, &empty_subst()).expect("[X,2] should unify with [1,2]");
        assert_eq!(apply_subst(&LpTerm::var("X"), &s), LpTerm::Integer(1));
    }

    #[test]
    fn test_unify_list_different_length() {
        let t1 = LpTerm::list(vec![LpTerm::Integer(1)]);
        let t2 = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        assert!(unify(&t1, &t2, &empty_subst()).is_none());
    }

    #[test]
    fn test_unify_integers() {
        let s = unify(&LpTerm::Integer(42), &LpTerm::Integer(42), &empty_subst());
        assert!(s.is_some());
        let s = unify(&LpTerm::Integer(1), &LpTerm::Integer(2), &empty_subst());
        assert!(s.is_none());
    }

    // ── Apply substitution tests ─────────────────────────────────────────────

    #[test]
    fn test_apply_subst_var() {
        let mut s = Substitution::new();
        s.bind("X", LpTerm::atom("hello"));
        assert_eq!(apply_subst(&LpTerm::var("X"), &s), LpTerm::atom("hello"));
    }

    #[test]
    fn test_apply_subst_compound() {
        let mut s = Substitution::new();
        s.bind("X", LpTerm::Integer(5));
        let t = LpTerm::compound("f", vec![LpTerm::var("X"), LpTerm::Integer(1)]);
        let result = apply_subst(&t, &s);
        assert_eq!(
            result,
            LpTerm::compound("f", vec![LpTerm::Integer(5), LpTerm::Integer(1)])
        );
    }

    // ── Occurs check tests ───────────────────────────────────────────────────

    #[test]
    fn test_occurs_check_direct() {
        let s = empty_subst();
        assert!(occurs_check("X", &LpTerm::var("X"), &s));
    }

    #[test]
    fn test_occurs_check_in_compound() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("X")]);
        assert!(occurs_check("X", &t, &s));
    }

    #[test]
    fn test_occurs_check_not_present() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("Y")]);
        assert!(!occurs_check("X", &t, &s));
    }

    #[test]
    fn test_occurs_check_prevents_circular() {
        let s = empty_subst();
        let t = LpTerm::compound("f", vec![LpTerm::var("X")]);
        let result = unify_with_occurs_check(&LpTerm::var("X"), &t, &s);
        assert!(result.is_none());
    }

    // ── Resolution: member/2 ─────────────────────────────────────────────────

    #[test]
    fn test_member_first() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(1),
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(!results.is_empty(), "member(1, [1,2,3]) should succeed");
    }

    #[test]
    fn test_member_middle() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(2),
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(!results.is_empty());
    }

    #[test]
    fn test_member_not_found() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(99),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    #[test]
    fn test_member_enumerate() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::var("X"),
                LpTerm::list(vec![
                    LpTerm::atom("a"),
                    LpTerm::atom("b"),
                    LpTerm::atom("c"),
                ]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 3, "Should enumerate all 3 members");

        // Counting alone would accept three copies of the same answer; check
        // that every list element is actually produced as a binding for X.
        let bound: Vec<LpTerm> = results
            .iter()
            .map(|s| apply_subst(&LpTerm::var("X"), s))
            .collect();
        for expected in ["a", "b", "c"] {
            assert!(
                bound.contains(&LpTerm::atom(expected)),
                "member/2 should yield X = {expected}, got {bound:?}"
            );
        }
    }

    // ── Resolution: append/3 ─────────────────────────────────────────────────

    #[test]
    fn test_append_concrete() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "append",
            vec![
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
                LpTerm::list(vec![LpTerm::Integer(3)]),
                LpTerm::var("R"),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let r = apply_subst(&LpTerm::var("R"), &results[0]);
        let flat = flatten_list(&r);
        assert_eq!(
            flat,
            vec![LpTerm::Integer(1), LpTerm::Integer(2), LpTerm::Integer(3)]
        );
    }

    #[test]
    fn test_append_split() {
        // append(X, Y, [1,2]) — find all splits
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "append",
            vec![
                LpTerm::var("X"),
                LpTerm::var("Y"),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        // Should find: ([], [1,2]), ([1], [2]), ([1,2], [])
        assert_eq!(results.len(), 3);

        // Each answer must be a genuine split: X ++ Y == [1,2], and the three
        // answers must be the three distinct prefixes (lengths 0, 1, 2).
        let mut prefix_lens = Vec::with_capacity(results.len());
        for s in &results {
            let x = flatten_list(&apply_subst(&LpTerm::var("X"), s));
            let y = flatten_list(&apply_subst(&LpTerm::var("Y"), s));
            prefix_lens.push(x.len());
            let joined: Vec<LpTerm> = x.into_iter().chain(y).collect();
            assert_eq!(joined, vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        }
        prefix_lens.sort_unstable();
        assert_eq!(prefix_lens, vec![0, 1, 2]);
    }

    // ── Resolution: reverse/2 ────────────────────────────────────────────────

    #[test]
    fn test_reverse() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "reverse",
            vec![
                LpTerm::list(vec![
                    LpTerm::Integer(1),
                    LpTerm::Integer(2),
                    LpTerm::Integer(3),
                ]),
                LpTerm::var("R"),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let r = apply_subst(&LpTerm::var("R"), &results[0]);
        let flat = flatten_list(&r);
        assert_eq!(
            flat,
            vec![LpTerm::Integer(3), LpTerm::Integer(2), LpTerm::Integer(1)]
        );
    }

    // ── Built-in: true/fail ──────────────────────────────────────────────────

    #[test]
    fn test_builtin_true() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::atom("true"));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_builtin_fail() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::atom("fail"));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    // ── Built-in: =/2 unification ────────────────────────────────────────────

    #[test]
    fn test_builtin_unify() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "=",
            vec![LpTerm::var("X"), LpTerm::Integer(42)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(42));
    }

    // ── Built-in: is/2 arithmetic ────────────────────────────────────────────

    #[test]
    fn test_builtin_is_add() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "is",
            vec![
                LpTerm::var("X"),
                LpTerm::compound("+", vec![LpTerm::Integer(3), LpTerm::Integer(4)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(7));
    }

    #[test]
    fn test_builtin_is_mul() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "is",
            vec![
                LpTerm::var("X"),
                LpTerm::compound("*", vec![LpTerm::Integer(6), LpTerm::Integer(7)]),
            ],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
        let val = apply_subst(&LpTerm::var("X"), &results[0]);
        assert_eq!(val, LpTerm::Integer(42));
    }

    // ── Built-in: comparison ─────────────────────────────────────────────────

    #[test]
    fn test_builtin_less_than() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "<",
            vec![LpTerm::Integer(3), LpTerm::Integer(5)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_builtin_less_than_false() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "<",
            vec![LpTerm::Integer(5), LpTerm::Integer(3)],
        ));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    // ── Built-in: \+/1 negation ──────────────────────────────────────────────

    #[test]
    fn test_negation_as_failure() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        // \+(fail) should succeed
        let q = Query::single(LpTerm::compound("\\+", vec![LpTerm::atom("fail")]));
        let results = resolve(&q, &db, &cfg);
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_negation_as_failure_fail() {
        let db = LpDatabase::new();
        let cfg = default_cfg();
        // \+(true) should fail
        let q = Query::single(LpTerm::compound("\\+", vec![LpTerm::atom("true")]));
        let results = resolve(&q, &db, &cfg);
        assert!(results.is_empty());
    }

    // ── solve_one ────────────────────────────────────────────────────────────

    #[test]
    fn test_solve_one_success() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(1),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        match solve_one(&q, &db, &cfg) {
            ResolutionResult::Success(_) => {}
            _ => panic!("Expected success"),
        }
    }

    #[test]
    fn test_solve_one_failure() {
        let db = std_db();
        let cfg = default_cfg();
        let q = Query::single(LpTerm::compound(
            "member",
            vec![
                LpTerm::Integer(99),
                LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]),
            ],
        ));
        match solve_one(&q, &db, &cfg) {
            ResolutionResult::Failure => {}
            _ => panic!("Expected failure"),
        }
    }

    // ── Term pretty-printing ─────────────────────────────────────────────────

    #[test]
    fn test_term_to_string_atom() {
        assert_eq!(term_to_string(&LpTerm::atom("hello")), "hello");
    }

    #[test]
    fn test_term_to_string_var() {
        assert_eq!(term_to_string(&LpTerm::var("X")), "X");
    }

    #[test]
    fn test_term_to_string_integer() {
        assert_eq!(term_to_string(&LpTerm::Integer(42)), "42");
    }

    #[test]
    fn test_term_to_string_compound() {
        let t = LpTerm::compound("f", vec![LpTerm::Integer(1), LpTerm::atom("a")]);
        assert_eq!(term_to_string(&t), "f(1,a)");
    }

    #[test]
    fn test_term_to_string_list() {
        let t = LpTerm::list(vec![LpTerm::Integer(1), LpTerm::Integer(2)]);
        assert_eq!(term_to_string(&t), "[1,2]");
    }

    // ── Simple parser ────────────────────────────────────────────────────────

    #[test]
    fn test_parse_term_atom() {
        assert_eq!(parse_term("foo"), Some(LpTerm::atom("foo")));
    }

    #[test]
    fn test_parse_term_var() {
        assert_eq!(parse_term("X"), Some(LpTerm::var("X")));
    }

    #[test]
    fn test_parse_term_integer() {
        assert_eq!(parse_term("42"), Some(LpTerm::Integer(42)));
    }

    #[test]
    fn test_parse_term_compound() {
        let t = parse_term("f(a,b)");
        assert_eq!(
            t,
            Some(LpTerm::compound(
                "f",
                vec![LpTerm::atom("a"), LpTerm::atom("b")]
            ))
        );
    }

    #[test]
    fn test_parse_list_empty() {
        assert_eq!(parse_term("[]"), Some(LpTerm::atom("[]")));
    }

    #[test]
    fn test_parse_list_items() {
        let t = parse_term("[1,2,3]");
        assert_eq!(
            t,
            Some(LpTerm::list(vec![
                LpTerm::Integer(1),
                LpTerm::Integer(2),
                LpTerm::Integer(3)
            ]))
        );
    }

    #[test]
    fn test_parse_clause_fact() {
        let c = parse_clause("foo(a).").expect("`foo(a).` should parse");
        assert!(c.is_fact());
    }

    #[test]
    fn test_parse_clause_list_pattern_fact() {
        // `member(X,[X|_]).` contains no `:-`, so it must parse as a fact whose
        // head is the compound member/2 (with a list-pattern second argument).
        let c = parse_clause("member(X,[X|_]).").expect("`member(X,[X|_]).` should parse");
        assert!(c.is_fact());
        assert!(
            term_to_string(&c.head).starts_with("member("),
            "head should be a member/… compound"
        );
        match &c.head {
            LpTerm::Compound { args, .. } => {
                assert_eq!(args.len(), 2, "head should be member/2");
            }
            other => panic!("expected a compound member/2 head, got {other:?}"),
        }
    }

    // ── query_all convenience ────────────────────────────────────────────────

    #[test]
    fn test_query_all() {
        let db = std_db();
        let cfg = default_cfg();
        let goal = LpTerm::compound(
            "member",
            vec![
                LpTerm::var("X"),
                LpTerm::list(vec![LpTerm::atom("a"), LpTerm::atom("b")]),
            ],
        );
        let results = db.query_all(goal, &cfg);
        assert_eq!(results.len(), 2);
    }
}