Skip to main content

patch_prolog_core/
builtins.rs

1use crate::term::{StringInterner, Term};
2use crate::unify::Substitution;
3
4/// Check if a goal is a built-in predicate.
5pub fn is_builtin(goal: &Term, interner: &StringInterner) -> bool {
6    match goal {
7        Term::Atom(id) => {
8            let name = interner.resolve(*id);
9            matches!(name, "true" | "fail" | "false" | "!" | "nl")
10        }
11        Term::Compound { functor, args } => {
12            let name = interner.resolve(*functor);
13            match (name, args.len()) {
14                ("=", 2) | ("\\=", 2) | ("is", 2) => true,
15                ("<", 2) | (">", 2) | ("=<", 2) | (">=", 2) => true,
16                ("=:=", 2) | ("=\\=", 2) => true,
17                ("\\+", 1) => true,
18                // Type-checking predicates
19                ("var", 1) | ("nonvar", 1) | ("atom", 1) | ("number", 1) => true,
20                ("integer", 1) | ("float", 1) | ("compound", 1) | ("is_list", 1) => true,
21                // Control flow
22                (";", 2) | ("->", 2) | (",", 2) => true,
23                // Solution collection
24                ("findall", 3) => true,
25                // Meta-call
26                ("once", 1) | ("call", 1) => true,
27                // Atom/string predicates
28                ("atom_length", 2) | ("atom_concat", 3) | ("atom_chars", 2) => true,
29                // I/O predicates
30                ("write", 1) | ("writeln", 1) => true,
31                // Term ordering
32                ("compare", 3) => true,
33                ("@<", 2) | ("@>", 2) | ("@=<", 2) | ("@>=", 2) => true,
34                // Term introspection
35                ("functor", 3) | ("arg", 3) | ("=..", 2) => true,
36                // Integer enumeration
37                ("between", 3) => true,
38                // Term copying
39                ("copy_term", 2) => true,
40                // Peano arithmetic
41                ("succ", 2) | ("plus", 3) => true,
42                // List sorting
43                ("msort", 2) | ("sort", 2) => true,
44                // Number/string conversion
45                ("number_chars", 2) | ("number_codes", 2) => true,
46                _ => false,
47            }
48        }
49        _ => false,
50    }
51}
52
53/// Result of executing a builtin.
54#[derive(Debug)]
55pub enum BuiltinResult {
56    /// The builtin succeeded (substitution may have been modified).
57    Success,
58    /// The builtin failed.
59    Failure,
60    /// Cut: succeed and signal cut to the solver.
61    Cut,
62    /// Negation-as-failure: the solver needs to try the inner goal.
63    NegationAsFailure(Term),
64    /// Disjunction: try left, then right on backtracking.
65    Disjunction(Term, Term),
66    /// If-then-else: ;(->(Cond, Then), Else)
67    IfThenElse(Term, Term, Term),
68    /// If-then (no else): ->(Cond, Then)
69    IfThen(Term, Term),
70    /// Conjunction: ','(A, B) — flatten into goal list.
71    Conjunction(Term, Term),
72    /// findall/3: Template, Goal, Result list.
73    FindAll(Term, Term, Term),
74    /// once/1: solve goal, take first solution only.
75    Once(Term),
76    /// call/1: execute a term as a goal.
77    Call(Term),
78    /// atom_length/2: atom, length
79    AtomLength(Term, Term),
80    /// atom_concat/3: atom1, atom2, result
81    AtomConcat(Term, Term, Term),
82    /// atom_chars/2: atom, char list
83    AtomChars(Term, Term),
84    /// write/1: write term to stdout (no newline).
85    Write(Term),
86    /// writeln/1: write term to stdout with newline.
87    Writeln(Term),
88    /// nl/0: write newline to stdout.
89    Nl,
90    /// compare/3: Order, Term1, Term2 — standard term ordering.
91    Compare(Term, Term, Term),
92    /// functor/3: Term, Name, Arity.
93    Functor(Term, Term, Term),
94    /// arg/3: N, Term, Arg.
95    Arg(Term, Term, Term),
96    /// =../2: Term, List (univ).
97    Univ(Term, Term),
98    /// between/3: Low, High, X — integer enumeration.
99    Between(Term, Term, Term),
100    /// copy_term/2: Original, Copy.
101    CopyTerm(Term, Term),
102    /// succ/2: X, S — successor relation.
103    Succ(Term, Term),
104    /// plus/3: X, Y, Z — addition relation.
105    Plus(Term, Term, Term),
106    /// msort/2: List, Sorted.
107    MSort(Term, Term),
108    /// sort/2: List, Sorted.
109    Sort(Term, Term),
110    /// number_chars/2: Number, Chars.
111    NumberChars(Term, Term),
112    /// number_codes/2: Number, Codes.
113    NumberCodes(Term, Term),
114}
115
116/// Execute a built-in predicate.
117pub fn exec_builtin(
118    goal: &Term,
119    subst: &mut Substitution,
120    interner: &StringInterner,
121) -> Result<BuiltinResult, String> {
122    match goal {
123        Term::Atom(id) => {
124            let name = interner.resolve(*id);
125            match name {
126                "true" => Ok(BuiltinResult::Success),
127                "fail" | "false" => Ok(BuiltinResult::Failure),
128                "!" => Ok(BuiltinResult::Cut),
129                "nl" => Ok(BuiltinResult::Nl),
130                _ => Err(format!("Unknown builtin atom: {}", name)),
131            }
132        }
133        Term::Compound { functor, args } => {
134            let name = interner.resolve(*functor);
135            match (name, args.len()) {
136                ("=", 2) => {
137                    if subst.unify(&args[0], &args[1]) {
138                        Ok(BuiltinResult::Success)
139                    } else {
140                        Ok(BuiltinResult::Failure)
141                    }
142                }
143                ("\\=", 2) => {
144                    let mark = subst.trail_mark();
145                    if subst.unify(&args[0], &args[1]) {
146                        subst.undo_to(mark);
147                        Ok(BuiltinResult::Failure)
148                    } else {
149                        subst.undo_to(mark);
150                        Ok(BuiltinResult::Success)
151                    }
152                }
153                ("is", 2) => {
154                    let result = eval_arith(&args[1], subst, interner)?;
155                    let result_term = arith_to_term(result);
156                    if subst.unify(&args[0], &result_term) {
157                        Ok(BuiltinResult::Success)
158                    } else {
159                        Ok(BuiltinResult::Failure)
160                    }
161                }
162                ("<", 2) => {
163                    let l = eval_arith(&args[0], subst, interner)?;
164                    let r = eval_arith(&args[1], subst, interner)?;
165                    if arith_lt(&l, &r) {
166                        Ok(BuiltinResult::Success)
167                    } else {
168                        Ok(BuiltinResult::Failure)
169                    }
170                }
171                (">", 2) => {
172                    let l = eval_arith(&args[0], subst, interner)?;
173                    let r = eval_arith(&args[1], subst, interner)?;
174                    if arith_gt(&l, &r) {
175                        Ok(BuiltinResult::Success)
176                    } else {
177                        Ok(BuiltinResult::Failure)
178                    }
179                }
180                ("=<", 2) => {
181                    let l = eval_arith(&args[0], subst, interner)?;
182                    let r = eval_arith(&args[1], subst, interner)?;
183                    if !arith_gt(&l, &r) {
184                        Ok(BuiltinResult::Success)
185                    } else {
186                        Ok(BuiltinResult::Failure)
187                    }
188                }
189                (">=", 2) => {
190                    let l = eval_arith(&args[0], subst, interner)?;
191                    let r = eval_arith(&args[1], subst, interner)?;
192                    if !arith_lt(&l, &r) {
193                        Ok(BuiltinResult::Success)
194                    } else {
195                        Ok(BuiltinResult::Failure)
196                    }
197                }
198                ("=:=", 2) => {
199                    let l = eval_arith(&args[0], subst, interner)?;
200                    let r = eval_arith(&args[1], subst, interner)?;
201                    if arith_eq(&l, &r) {
202                        Ok(BuiltinResult::Success)
203                    } else {
204                        Ok(BuiltinResult::Failure)
205                    }
206                }
207                ("=\\=", 2) => {
208                    let l = eval_arith(&args[0], subst, interner)?;
209                    let r = eval_arith(&args[1], subst, interner)?;
210                    if !arith_eq(&l, &r) {
211                        Ok(BuiltinResult::Success)
212                    } else {
213                        Ok(BuiltinResult::Failure)
214                    }
215                }
216                ("\\+", 1) => Ok(BuiltinResult::NegationAsFailure(args[0].clone())),
217                // Type-checking predicates
218                ("var", 1) => {
219                    let walked = subst.walk(&args[0]);
220                    if matches!(walked, Term::Var(_)) {
221                        Ok(BuiltinResult::Success)
222                    } else {
223                        Ok(BuiltinResult::Failure)
224                    }
225                }
226                ("nonvar", 1) => {
227                    let walked = subst.walk(&args[0]);
228                    if matches!(walked, Term::Var(_)) {
229                        Ok(BuiltinResult::Failure)
230                    } else {
231                        Ok(BuiltinResult::Success)
232                    }
233                }
234                ("atom", 1) => {
235                    let walked = subst.walk(&args[0]);
236                    if matches!(walked, Term::Atom(_)) {
237                        Ok(BuiltinResult::Success)
238                    } else {
239                        Ok(BuiltinResult::Failure)
240                    }
241                }
242                ("number", 1) => {
243                    let walked = subst.walk(&args[0]);
244                    if matches!(walked, Term::Integer(_) | Term::Float(_)) {
245                        Ok(BuiltinResult::Success)
246                    } else {
247                        Ok(BuiltinResult::Failure)
248                    }
249                }
250                ("integer", 1) => {
251                    let walked = subst.walk(&args[0]);
252                    if matches!(walked, Term::Integer(_)) {
253                        Ok(BuiltinResult::Success)
254                    } else {
255                        Ok(BuiltinResult::Failure)
256                    }
257                }
258                ("float", 1) => {
259                    let walked = subst.walk(&args[0]);
260                    if matches!(walked, Term::Float(_)) {
261                        Ok(BuiltinResult::Success)
262                    } else {
263                        Ok(BuiltinResult::Failure)
264                    }
265                }
266                ("compound", 1) => {
267                    let walked = subst.walk(&args[0]);
268                    if matches!(walked, Term::Compound { .. } | Term::List { .. }) {
269                        Ok(BuiltinResult::Success)
270                    } else {
271                        Ok(BuiltinResult::Failure)
272                    }
273                }
274                ("is_list", 1) => {
275                    let walked = subst.apply(&args[0]);
276                    if is_proper_list(&walked, interner) {
277                        Ok(BuiltinResult::Success)
278                    } else {
279                        Ok(BuiltinResult::Failure)
280                    }
281                }
282                // Control flow
283                (";", 2) => {
284                    // Check if left arg is ->(Cond, Then) => if-then-else
285                    let left = subst.walk(&args[0]);
286                    if let Term::Compound {
287                        functor,
288                        args: inner_args,
289                    } = &left
290                    {
291                        if interner.resolve(*functor) == "->" && inner_args.len() == 2 {
292                            return Ok(BuiltinResult::IfThenElse(
293                                inner_args[0].clone(),
294                                inner_args[1].clone(),
295                                args[1].clone(),
296                            ));
297                        }
298                    }
299                    // Plain disjunction
300                    Ok(BuiltinResult::Disjunction(args[0].clone(), args[1].clone()))
301                }
302                ("->", 2) => Ok(BuiltinResult::IfThen(args[0].clone(), args[1].clone())),
303                (",", 2) => Ok(BuiltinResult::Conjunction(args[0].clone(), args[1].clone())),
304                ("findall", 3) => Ok(BuiltinResult::FindAll(
305                    args[0].clone(),
306                    args[1].clone(),
307                    args[2].clone(),
308                )),
309                ("once", 1) => Ok(BuiltinResult::Once(args[0].clone())),
310                ("call", 1) => Ok(BuiltinResult::Call(args[0].clone())),
311                // Atom/string predicates
312                ("atom_length", 2) => {
313                    Ok(BuiltinResult::AtomLength(args[0].clone(), args[1].clone()))
314                }
315                ("atom_concat", 3) => Ok(BuiltinResult::AtomConcat(
316                    args[0].clone(),
317                    args[1].clone(),
318                    args[2].clone(),
319                )),
320                ("atom_chars", 2) => Ok(BuiltinResult::AtomChars(args[0].clone(), args[1].clone())),
321                // I/O
322                ("write", 1) => Ok(BuiltinResult::Write(args[0].clone())),
323                ("writeln", 1) => Ok(BuiltinResult::Writeln(args[0].clone())),
324                // Term ordering
325                ("compare", 3) => Ok(BuiltinResult::Compare(
326                    args[0].clone(),
327                    args[1].clone(),
328                    args[2].clone(),
329                )),
330                ("@<", 2) => {
331                    let cmp =
332                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
333                    if cmp == std::cmp::Ordering::Less {
334                        Ok(BuiltinResult::Success)
335                    } else {
336                        Ok(BuiltinResult::Failure)
337                    }
338                }
339                ("@>", 2) => {
340                    let cmp =
341                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
342                    if cmp == std::cmp::Ordering::Greater {
343                        Ok(BuiltinResult::Success)
344                    } else {
345                        Ok(BuiltinResult::Failure)
346                    }
347                }
348                ("@=<", 2) => {
349                    let cmp =
350                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
351                    if cmp != std::cmp::Ordering::Greater {
352                        Ok(BuiltinResult::Success)
353                    } else {
354                        Ok(BuiltinResult::Failure)
355                    }
356                }
357                ("@>=", 2) => {
358                    let cmp =
359                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
360                    if cmp != std::cmp::Ordering::Less {
361                        Ok(BuiltinResult::Success)
362                    } else {
363                        Ok(BuiltinResult::Failure)
364                    }
365                }
366                // Term introspection
367                ("functor", 3) => Ok(BuiltinResult::Functor(
368                    args[0].clone(),
369                    args[1].clone(),
370                    args[2].clone(),
371                )),
372                ("arg", 3) => Ok(BuiltinResult::Arg(
373                    args[0].clone(),
374                    args[1].clone(),
375                    args[2].clone(),
376                )),
377                ("=..", 2) => Ok(BuiltinResult::Univ(args[0].clone(), args[1].clone())),
378                // Integer enumeration
379                ("between", 3) => Ok(BuiltinResult::Between(
380                    args[0].clone(),
381                    args[1].clone(),
382                    args[2].clone(),
383                )),
384                // Term copying
385                ("copy_term", 2) => Ok(BuiltinResult::CopyTerm(args[0].clone(), args[1].clone())),
386                // Peano arithmetic
387                ("succ", 2) => Ok(BuiltinResult::Succ(args[0].clone(), args[1].clone())),
388                ("plus", 3) => Ok(BuiltinResult::Plus(
389                    args[0].clone(),
390                    args[1].clone(),
391                    args[2].clone(),
392                )),
393                // List sorting
394                ("msort", 2) => Ok(BuiltinResult::MSort(args[0].clone(), args[1].clone())),
395                ("sort", 2) => Ok(BuiltinResult::Sort(args[0].clone(), args[1].clone())),
396                // Number/string conversion
397                ("number_chars", 2) => {
398                    Ok(BuiltinResult::NumberChars(args[0].clone(), args[1].clone()))
399                }
400                ("number_codes", 2) => {
401                    Ok(BuiltinResult::NumberCodes(args[0].clone(), args[1].clone()))
402                }
403                _ => Err(format!("Unknown builtin: {}/{}", name, args.len())),
404            }
405        }
406        _ => Err(format!("Cannot execute as builtin: {:?}", goal)),
407    }
408}
409
/// Numeric result of evaluating an arithmetic expression.
#[derive(Clone, Debug)]
enum ArithVal {
    /// An integer value.
    Int(i64),
    /// A floating-point value.
    Float(f64),
}
416
417fn arith_to_term(val: ArithVal) -> Term {
418    match val {
419        ArithVal::Int(n) => Term::Integer(n),
420        ArithVal::Float(f) => Term::Float(f),
421    }
422}
423
424fn arith_lt(a: &ArithVal, b: &ArithVal) -> bool {
425    match (a, b) {
426        (ArithVal::Int(a), ArithVal::Int(b)) => a < b,
427        (ArithVal::Float(a), ArithVal::Float(b)) => a < b,
428        (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) < *b,
429        (ArithVal::Float(a), ArithVal::Int(b)) => *a < (*b as f64),
430    }
431}
432
433fn arith_gt(a: &ArithVal, b: &ArithVal) -> bool {
434    arith_lt(b, a)
435}
436
437fn arith_eq(a: &ArithVal, b: &ArithVal) -> bool {
438    match (a, b) {
439        (ArithVal::Int(a), ArithVal::Int(b)) => a == b,
440        (ArithVal::Float(a), ArithVal::Float(b)) => a == b,
441        (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) == *b,
442        (ArithVal::Float(a), ArithVal::Int(b)) => *a == (*b as f64),
443    }
444}
445
446/// Evaluate an arithmetic expression.
447fn eval_arith(
448    term: &Term,
449    subst: &Substitution,
450    interner: &StringInterner,
451) -> Result<ArithVal, String> {
452    let term = subst.walk(term);
453    match &term {
454        Term::Integer(n) => Ok(ArithVal::Int(*n)),
455        Term::Float(f) => Ok(ArithVal::Float(*f)),
456        Term::Var(id) => Err(format!("Arithmetic error: unbound variable _{}", id)),
457        Term::Compound { functor, args } => {
458            let name = interner.resolve(*functor);
459            match (name, args.len()) {
460                ("+", 2) => {
461                    let l = eval_arith(&args[0], subst, interner)?;
462                    let r = eval_arith(&args[1], subst, interner)?;
463                    arith_add(&l, &r)
464                }
465                ("-", 2) => {
466                    let l = eval_arith(&args[0], subst, interner)?;
467                    let r = eval_arith(&args[1], subst, interner)?;
468                    arith_sub(&l, &r)
469                }
470                ("*", 2) => {
471                    let l = eval_arith(&args[0], subst, interner)?;
472                    let r = eval_arith(&args[1], subst, interner)?;
473                    arith_mul(&l, &r)
474                }
475                ("/", 2) => {
476                    let l = eval_arith(&args[0], subst, interner)?;
477                    let r = eval_arith(&args[1], subst, interner)?;
478                    arith_div(&l, &r)
479                }
480                ("mod", 2) => {
481                    let l = eval_arith(&args[0], subst, interner)?;
482                    let r = eval_arith(&args[1], subst, interner)?;
483                    arith_mod(&l, &r)
484                }
485                ("-", 1) => {
486                    let v = eval_arith(&args[0], subst, interner)?;
487                    arith_neg(&v)
488                }
489                ("abs", 1) => {
490                    let v = eval_arith(&args[0], subst, interner)?;
491                    arith_abs(&v)
492                }
493                ("sign", 1) => {
494                    let v = eval_arith(&args[0], subst, interner)?;
495                    Ok(arith_sign(&v))
496                }
497                ("max", 2) => {
498                    let l = eval_arith(&args[0], subst, interner)?;
499                    let r = eval_arith(&args[1], subst, interner)?;
500                    Ok(arith_max(&l, &r))
501                }
502                ("min", 2) => {
503                    let l = eval_arith(&args[0], subst, interner)?;
504                    let r = eval_arith(&args[1], subst, interner)?;
505                    Ok(arith_min(&l, &r))
506                }
507                _ => Err(format!(
508                    "Unknown arithmetic operator: {}/{}",
509                    name,
510                    args.len()
511                )),
512            }
513        }
514        _ => Err(format!("Cannot evaluate as arithmetic: {:?}", term)),
515    }
516}
517
518/// Check a float result for NaN or Infinity, returning an error if detected.
519fn check_float(f: f64) -> Result<ArithVal, String> {
520    if f.is_nan() {
521        Err("Arithmetic error: NaN result".to_string())
522    } else if f.is_infinite() {
523        Err("Arithmetic error: Infinity result".to_string())
524    } else {
525        Ok(ArithVal::Float(f))
526    }
527}
528
529fn arith_add(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
530    match (a, b) {
531        (ArithVal::Int(a), ArithVal::Int(b)) => a
532            .checked_add(*b)
533            .map(ArithVal::Int)
534            .ok_or_else(|| "Arithmetic error: integer overflow in addition".to_string()),
535        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a + b),
536        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 + b),
537        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a + *b as f64),
538    }
539}
540
541fn arith_sub(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
542    match (a, b) {
543        (ArithVal::Int(a), ArithVal::Int(b)) => a
544            .checked_sub(*b)
545            .map(ArithVal::Int)
546            .ok_or_else(|| "Arithmetic error: integer overflow in subtraction".to_string()),
547        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a - b),
548        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 - b),
549        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a - *b as f64),
550    }
551}
552
553fn arith_mul(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
554    match (a, b) {
555        (ArithVal::Int(a), ArithVal::Int(b)) => a
556            .checked_mul(*b)
557            .map(ArithVal::Int)
558            .ok_or_else(|| "Arithmetic error: integer overflow in multiplication".to_string()),
559        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a * b),
560        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 * b),
561        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a * *b as f64),
562    }
563}
564
565fn arith_div(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
566    match (a, b) {
567        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
568        (ArithVal::Int(a), ArithVal::Int(b)) => a
569            .checked_div(*b)
570            .map(ArithVal::Int)
571            .ok_or_else(|| "Arithmetic error: integer overflow in division".to_string()),
572        (_, ArithVal::Float(b)) if *b == 0.0 => Err("Division by zero".to_string()),
573        (ArithVal::Float(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
574        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a / b),
575        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 / b),
576        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a / *b as f64),
577    }
578}
579
580fn arith_mod(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
581    match (a, b) {
582        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Modulo by zero".to_string()),
583        (ArithVal::Int(_), ArithVal::Int(i64::MIN)) => {
584            Err("Arithmetic error: integer overflow in mod".to_string())
585        }
586        (ArithVal::Int(a), ArithVal::Int(b)) => {
587            // ISO Prolog mod: result has the sign of the divisor
588            // X mod Y = X - floor(X/Y) * Y
589            // b.abs() is safe here because we excluded i64::MIN above
590            let r = a.rem_euclid(b.abs());
591            if *b < 0 && r != 0 {
592                Ok(ArithVal::Int(r - b.abs()))
593            } else {
594                Ok(ArithVal::Int(r))
595            }
596        }
597        _ => Err("mod requires integer arguments".to_string()),
598    }
599}
600
601fn arith_neg(a: &ArithVal) -> Result<ArithVal, String> {
602    match a {
603        ArithVal::Int(n) => n
604            .checked_neg()
605            .map(ArithVal::Int)
606            .ok_or_else(|| "Arithmetic error: integer overflow in negation".to_string()),
607        ArithVal::Float(f) => check_float(-f),
608    }
609}
610
611fn arith_abs(a: &ArithVal) -> Result<ArithVal, String> {
612    match a {
613        ArithVal::Int(n) => n
614            .checked_abs()
615            .map(ArithVal::Int)
616            .ok_or_else(|| "Arithmetic error: integer overflow in abs".to_string()),
617        ArithVal::Float(f) => check_float(f.abs()),
618    }
619}
620
621fn arith_sign(a: &ArithVal) -> ArithVal {
622    match a {
623        ArithVal::Int(n) => ArithVal::Int(n.signum()),
624        ArithVal::Float(f) => ArithVal::Float(f.signum()),
625    }
626}
627
628fn arith_max(a: &ArithVal, b: &ArithVal) -> ArithVal {
629    if arith_lt(a, b) {
630        b.clone()
631    } else {
632        a.clone()
633    }
634}
635
636fn arith_min(a: &ArithVal, b: &ArithVal) -> ArithVal {
637    if arith_lt(a, b) {
638        a.clone()
639    } else {
640        b.clone()
641    }
642}
643
644/// Standard order of terms (ISO Prolog):
645/// Variables < Numbers < Atoms < Compound terms
646/// Within numbers: by value. Within atoms: alphabetical.
647/// Within compounds: by arity, then functor name, then arguments left-to-right.
648pub fn term_compare(a: &Term, b: &Term, interner: &StringInterner) -> std::cmp::Ordering {
649    use std::cmp::Ordering;
650    fn type_rank(t: &Term) -> u8 {
651        match t {
652            Term::Var(_) => 0,
653            Term::Float(_) => 1,
654            Term::Integer(_) => 1,
655            Term::Atom(_) => 2,
656            Term::List { .. } => 3,
657            Term::Compound { .. } => 3,
658        }
659    }
660
661    let ra = type_rank(a);
662    let rb = type_rank(b);
663    if ra != rb {
664        return ra.cmp(&rb);
665    }
666
667    match (a, b) {
668        (Term::Var(a), Term::Var(b)) => a.cmp(b),
669        (Term::Integer(a), Term::Integer(b)) => a.cmp(b),
670        (Term::Float(a), Term::Float(b)) => {
671            // NaN sorts after all other floats (deterministic total order)
672            a.partial_cmp(b)
673                .unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
674                    (true, true) => Ordering::Equal,
675                    (true, false) => Ordering::Greater,
676                    (false, true) => Ordering::Less,
677                    (false, false) => unreachable!(),
678                })
679        }
680        (Term::Integer(a), Term::Float(b)) => {
681            // NaN sorts after everything; ISO: float < integer when same value
682            if b.is_nan() {
683                return Ordering::Less;
684            }
685            let cmp = (*a as f64).partial_cmp(b).unwrap_or(Ordering::Less);
686            if cmp == Ordering::Equal {
687                Ordering::Greater // integer > float for same value (ISO 8.4.2.1)
688            } else {
689                cmp
690            }
691        }
692        (Term::Float(a), Term::Integer(b)) => {
693            // NaN sorts after everything; ISO: float < integer when same value
694            if a.is_nan() {
695                return Ordering::Greater;
696            }
697            let cmp = a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Greater);
698            if cmp == Ordering::Equal {
699                Ordering::Less // float < integer for same value (ISO 8.4.2.1)
700            } else {
701                cmp
702            }
703        }
704        (Term::Atom(a), Term::Atom(b)) => interner.resolve(*a).cmp(interner.resolve(*b)),
705        (
706            Term::Compound {
707                functor: f1,
708                args: a1,
709            },
710            Term::Compound {
711                functor: f2,
712                args: a2,
713            },
714        ) => {
715            // Compare by arity first, then functor name, then args
716            a1.len()
717                .cmp(&a2.len())
718                .then_with(|| interner.resolve(*f1).cmp(interner.resolve(*f2)))
719                .then_with(|| {
720                    for (x, y) in a1.iter().zip(a2.iter()) {
721                        let c = term_compare(x, y, interner);
722                        if c != Ordering::Equal {
723                            return c;
724                        }
725                    }
726                    Ordering::Equal
727                })
728        }
729        (Term::List { .. }, Term::List { .. }) => {
730            // Iterative list comparison to avoid stack overflow on long lists
731            let mut cur_a = a;
732            let mut cur_b = b;
733            loop {
734                match (cur_a, cur_b) {
735                    (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
736                        let c = term_compare(h1, h2, interner);
737                        if c != Ordering::Equal {
738                            return c;
739                        }
740                        cur_a = t1;
741                        cur_b = t2;
742                    }
743                    _ => return term_compare(cur_a, cur_b, interner),
744                }
745            }
746        }
747        // List vs Compound: lists are .(H,T) which is arity 2
748        (
749            Term::List { head: h, tail: t },
750            Term::Compound {
751                functor: f2,
752                args: a2,
753            },
754        ) => {
755            // List is ./2; compare arity, then functor ".", then args
756            2usize
757                .cmp(&a2.len())
758                .then_with(|| ".".cmp(interner.resolve(*f2)))
759                .then_with(|| {
760                    if a2.len() >= 1 {
761                        let c = term_compare(h, &a2[0], interner);
762                        if c != Ordering::Equal {
763                            return c;
764                        }
765                    }
766                    if a2.len() >= 2 {
767                        return term_compare(t, &a2[1], interner);
768                    }
769                    Ordering::Equal
770                })
771        }
772        (
773            Term::Compound {
774                functor: f1,
775                args: a1,
776            },
777            Term::List { head: h, tail: t },
778        ) => a1
779            .len()
780            .cmp(&2usize)
781            .then_with(|| interner.resolve(*f1).cmp("."))
782            .then_with(|| {
783                if a1.len() >= 1 {
784                    let c = term_compare(&a1[0], h, interner);
785                    if c != Ordering::Equal {
786                        return c;
787                    }
788                }
789                if a1.len() >= 2 {
790                    return term_compare(&a1[1], t, interner);
791                }
792                Ordering::Equal
793            }),
794        _ => unreachable!("term_compare: unhandled Term variant"),
795    }
796}
797
798/// Collect list elements from a term. Returns None if not a proper list.
799pub fn collect_list(term: &Term, interner: &StringInterner) -> Option<Vec<Term>> {
800    let mut elements = Vec::new();
801    let mut current = term;
802    loop {
803        match current {
804            Term::Atom(id) if interner.resolve(*id) == "[]" => return Some(elements),
805            Term::List { head, tail } => {
806                elements.push(head.as_ref().clone());
807                current = tail;
808            }
809            _ => return None,
810        }
811    }
812}
813
814/// Build a list term from elements.
815pub fn build_list(elements: Vec<Term>, interner: &StringInterner) -> Term {
816    let nil_id = interner.lookup("[]").expect("[] must be interned");
817    let mut list = Term::Atom(nil_id);
818    for elem in elements.into_iter().rev() {
819        list = Term::List {
820            head: Box::new(elem),
821            tail: Box::new(list),
822        };
823    }
824    list
825}
826
827/// Check if a term is a proper list (ends with []).
828fn is_proper_list(term: &Term, interner: &StringInterner) -> bool {
829    let mut current = term;
830    loop {
831        match current {
832            Term::Atom(id) => return interner.resolve(*id) == "[]",
833            Term::List { tail, .. } => current = tail,
834            _ => return false,
835        }
836    }
837}
838
/// Names of the bare-atom (zero-arity) goals that are treated as builtins.
///
/// This function does not check anything. It only returns the static name
/// table, and callers do the membership test themselves (e.g. with
/// `.contains(&name)`).
///
/// The list must stay in sync with the atom arm of [`is_builtin`]. That arm
/// currently matches exactly `true`, `fail`, `false`, `!` and `nl`.
pub fn builtin_atom_names() -> &'static [&'static str] {
    &["true", "fail", "false", "!", "nl"]
}
843
/// Name/arity pairs of every compound goal that is treated as a builtin.
///
/// This list is intended to mirror the compound arm of [`is_builtin`].
/// When a builtin is added or removed, update both places. Entries are
/// returned in a fixed order, grouped by category.
pub fn builtin_functor_names() -> &'static [(&'static str, usize)] {
    const TABLE: &[(&str, usize)] = &[
        // Unification and arithmetic evaluation/comparison
        ("=", 2),
        ("\\=", 2),
        ("is", 2),
        ("<", 2),
        (">", 2),
        ("=<", 2),
        (">=", 2),
        ("=:=", 2),
        ("=\\=", 2),
        // Negation as failure
        ("\\+", 1),
        // Type checking
        ("var", 1),
        ("nonvar", 1),
        ("atom", 1),
        ("number", 1),
        ("integer", 1),
        ("float", 1),
        ("compound", 1),
        ("is_list", 1),
        // Control flow
        (";", 2),
        ("->", 2),
        (",", 2),
        // Solution collection
        ("findall", 3),
        // Meta-call
        ("once", 1),
        ("call", 1),
        // Atom/string predicates
        ("atom_length", 2),
        ("atom_concat", 3),
        ("atom_chars", 2),
        // I/O
        ("write", 1),
        ("writeln", 1),
        // Term ordering
        ("compare", 3),
        ("@<", 2),
        ("@>", 2),
        ("@=<", 2),
        ("@>=", 2),
        // Term introspection
        ("functor", 3),
        ("arg", 3),
        ("=..", 2),
        // Integer enumeration
        ("between", 3),
        // Term copying
        ("copy_term", 2),
        // Peano arithmetic
        ("succ", 2),
        ("plus", 3),
        // Sorting
        ("msort", 2),
        ("sort", 2),
        // Number <-> text conversion
        ("number_chars", 2),
        ("number_codes", 2),
    ];
    TABLE
}
893
#[cfg(test)]
mod tests {
    // Unit tests for builtin recognition (`is_builtin`), builtin execution
    // (`exec_builtin`), and the integer `mod` helper (`arith_mod`).
    use super::*;
    use crate::parser::Parser;

    /// Build an interner pre-populated with the control atoms and operator
    /// names that the tests below fetch via `lookup(...).unwrap()`.
    ///
    /// NOTE(review): `[]`, `false` and `nl` are not interned here. A test that
    /// needs them (e.g. anything reaching `build_list`, which expects `[]`)
    /// must intern them itself.
    fn setup() -> StringInterner {
        let mut i = StringInterner::new();
        // Pre-intern common atoms
        i.intern("true");
        i.intern("fail");
        i.intern("!");
        i.intern("=");
        i.intern("\\=");
        i.intern("is");
        i.intern("<");
        i.intern(">");
        i.intern("=<");
        i.intern(">=");
        i.intern("=:=");
        i.intern("=\\=");
        i.intern("\\+");
        i.intern("+");
        i.intern("-");
        i.intern("*");
        i.intern("/");
        i.intern("mod");
        i
    }

    // `is_builtin` recognizes both a bare control atom and a compound goal.
    #[test]
    fn test_is_builtin() {
        let interner = setup();
        let true_id = interner.lookup("true").unwrap();
        assert!(is_builtin(&Term::Atom(true_id), &interner));

        // For compounds only the functor name and arity are inspected, so the
        // argument values (Var(0), Atom(0)) are arbitrary placeholders.
        let eq_id = interner.lookup("=").unwrap();
        let goal = Term::Compound {
            functor: eq_id,
            args: vec![Term::Var(0), Term::Atom(0)],
        };
        assert!(is_builtin(&goal, &interner));
    }

    // `true` executes to BuiltinResult::Success.
    #[test]
    fn test_exec_true() {
        let interner = setup();
        let true_id = interner.lookup("true").unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&Term::Atom(true_id), &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
    }

    // `fail` executes to BuiltinResult::Failure (not an Err).
    #[test]
    fn test_exec_fail() {
        let interner = setup();
        let fail_id = interner.lookup("fail").unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&Term::Atom(fail_id), &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Failure));
    }

    // `X = 42` succeeds and leaves Var(0) bound to 42 in the substitution.
    #[test]
    fn test_exec_unify() {
        let interner = setup();
        let eq_id = interner.lookup("=").unwrap();
        let mut subst = Substitution::new();
        let goal = Term::Compound {
            functor: eq_id,
            args: vec![Term::Var(0), Term::Integer(42)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(42));
    }

    // `\=` succeeds exactly when its two arguments do not unify.
    #[test]
    fn test_exec_not_unify() {
        let interner = setup();
        let neq_id = interner.lookup("\\=").unwrap();
        let mut subst = Substitution::new();
        // 1 \= 2 should succeed
        let goal = Term::Compound {
            functor: neq_id,
            args: vec![Term::Integer(1), Term::Integer(2)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));

        // 1 \= 1 should fail
        let goal = Term::Compound {
            functor: neq_id,
            args: vec![Term::Integer(1), Term::Integer(1)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Failure));
    }

    // `is/2` must respect operator precedence: 2 + 3 * 4 evaluates to 14.
    // Assumes the parser maps the query variable X to Var(0) (TODO confirm
    // against Parser::parse_query).
    #[test]
    fn test_exec_is_arithmetic() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is 2 + 3 * 4", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(14));
    }

    // `<` on integers: a true comparison is Success and a false one is
    // Failure. Neither case is an error.
    #[test]
    fn test_exec_comparison() {
        let interner = setup();
        let lt_id = interner.lookup("<").unwrap();
        let mut subst = Substitution::new();

        // 1 < 2 should succeed
        let goal = Term::Compound {
            functor: lt_id,
            args: vec![Term::Integer(1), Term::Integer(2)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));

        // 2 < 1 should fail
        let goal = Term::Compound {
            functor: lt_id,
            args: vec![Term::Integer(2), Term::Integer(1)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Failure
        ));
    }

    // `!` is not resolved by exec_builtin itself. It is reported back as
    // BuiltinResult::Cut, presumably so the solver can prune choice points.
    #[test]
    fn test_exec_cut() {
        let interner = setup();
        let cut_id = interner.lookup("!").unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&Term::Atom(cut_id), &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Cut));
    }

    // `var/1`: true for an unbound variable, false for a number.
    // "var" is not interned by setup(), so it is interned here first.
    #[test]
    fn test_type_checking_var() {
        let mut interner = setup();
        interner.intern("var");
        let var_id = interner.lookup("var").unwrap();
        let mut subst = Substitution::new();
        // var(X) where X is unbound should succeed
        let goal = Term::Compound {
            functor: var_id,
            args: vec![Term::Var(0)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));

        // var(42) should fail
        let goal = Term::Compound {
            functor: var_id,
            args: vec![Term::Integer(42)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Failure));
    }

    // `atom/1`: true for an atom, false for a number.
    #[test]
    fn test_type_checking_atom() {
        let mut interner = setup();
        interner.intern("atom");
        let atom_id = interner.lookup("atom").unwrap();
        let mut subst = Substitution::new();
        let hello = interner.intern("hello");
        // atom(hello) should succeed
        let goal = Term::Compound {
            functor: atom_id,
            args: vec![Term::Atom(hello)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));

        // atom(42) should fail
        let goal = Term::Compound {
            functor: atom_id,
            args: vec![Term::Integer(42)],
        };
        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Failure));
    }

    // `integer/1`: accepts 42 but rejects the float 3.14.
    #[test]
    fn test_type_checking_integer() {
        let mut interner = setup();
        interner.intern("integer");
        let int_id = interner.lookup("integer").unwrap();
        let mut subst = Substitution::new();
        let goal = Term::Compound {
            functor: int_id,
            args: vec![Term::Integer(42)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));

        let goal = Term::Compound {
            functor: int_id,
            args: vec![Term::Float(3.14)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Failure
        ));
    }

    // `number/1`: accepts both integers and floats.
    #[test]
    fn test_type_checking_number() {
        let mut interner = setup();
        interner.intern("number");
        let num_id = interner.lookup("number").unwrap();
        let mut subst = Substitution::new();
        // number(42) should succeed
        let goal = Term::Compound {
            functor: num_id,
            args: vec![Term::Integer(42)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));
        // number(3.14) should succeed
        let goal = Term::Compound {
            functor: num_id,
            args: vec![Term::Float(3.14)],
        };
        assert!(matches!(
            exec_builtin(&goal, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));
    }

    // `mod` through `is/2`: 10 mod 3 = 1 (X assumed to be Var(0), as above).
    #[test]
    fn test_exec_mod() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is 10 mod 3", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(1));
    }

    // Direct edge-case tests of `arith_mod` at the i64 boundary.
    #[test]
    fn test_mod_i64_min_divisor() {
        // arith_mod with i64::MIN divisor should error, not panic from .abs()
        let result = arith_mod(&ArithVal::Int(5), &ArithVal::Int(i64::MIN));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("overflow"));
    }

    #[test]
    fn test_mod_i64_min_dividend_neg1() {
        // i64::MIN mod -1 should be 0 (rem_euclid handles this correctly)
        let result = arith_mod(&ArithVal::Int(i64::MIN), &ArithVal::Int(-1));
        match result {
            Ok(ArithVal::Int(0)) => {}
            other => panic!("Expected Ok(Int(0)), got {:?}", other),
        }
    }

    // Integer overflow in `is/2` must surface as an Err mentioning "overflow"
    // (not a wrapped value or a panic).
    #[test]
    fn test_integer_overflow_add() {
        let mut interner = setup();
        let query_str = format!("X is {} + 1", i64::MAX);
        let goals = Parser::parse_query(&query_str, &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("overflow"));
    }

    #[test]
    fn test_integer_overflow_mul() {
        let mut interner = setup();
        let query_str = format!("X is {} * 2", i64::MAX);
        let goals = Parser::parse_query(&query_str, &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("overflow"));
    }

    // Arithmetic functions abs/1, sign/1, max/2 and min/2 on integer inputs.
    #[test]
    fn test_arith_abs() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is abs(-5)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(5));
    }

    #[test]
    fn test_arith_abs_positive() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is abs(3)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(3));
    }

    #[test]
    fn test_arith_sign() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is sign(-42)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(-1));
    }

    #[test]
    fn test_arith_sign_zero() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is sign(0)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(0));
    }

    #[test]
    fn test_arith_max() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is max(3, 7)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(7));
    }

    #[test]
    fn test_arith_min() {
        let mut interner = setup();
        let goals = Parser::parse_query("X is min(3, 7)", &mut interner).unwrap();
        let mut subst = Substitution::new();
        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
        assert!(matches!(result, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(3));
    }
}