Skip to main content

patch_prolog_core/
builtins.rs

1use crate::term::{StringInterner, Term};
2use crate::unify::Substitution;
3
4/// Check if a goal is a built-in predicate.
5pub fn is_builtin(goal: &Term, interner: &StringInterner) -> bool {
6    match goal {
7        Term::Atom(id) => {
8            let name = interner.resolve(*id);
9            matches!(name, "true" | "fail" | "false" | "!" | "nl")
10        }
11        Term::Compound { functor, args } => {
12            let name = interner.resolve(*functor);
13            match (name, args.len()) {
14                ("=", 2) | ("\\=", 2) | ("unify_with_occurs_check", 2) | ("is", 2) => true,
15                ("<", 2) | (">", 2) | ("=<", 2) | (">=", 2) => true,
16                ("=:=", 2) | ("=\\=", 2) => true,
17                ("\\+", 1) => true,
18                // Type-checking predicates
19                ("var", 1) | ("nonvar", 1) | ("atom", 1) | ("number", 1) => true,
20                ("integer", 1) | ("float", 1) | ("compound", 1) | ("is_list", 1) => true,
21                // Control flow
22                (";", 2) | ("->", 2) | (",", 2) => true,
23                // Solution collection
24                ("findall", 3) => true,
25                // Meta-call
26                ("once", 1) | ("call", 1) => true,
27                // Atom/string predicates
28                ("atom_length", 2) | ("atom_concat", 3) | ("atom_chars", 2) => true,
29                // I/O predicates
30                ("write", 1) | ("writeln", 1) => true,
31                // Term ordering
32                ("compare", 3) => true,
33                ("@<", 2) | ("@>", 2) | ("@=<", 2) | ("@>=", 2) => true,
34                // Term introspection
35                ("functor", 3) | ("arg", 3) | ("=..", 2) => true,
36                // Integer enumeration
37                ("between", 3) => true,
38                // Term copying
39                ("copy_term", 2) => true,
40                // Peano arithmetic
41                ("succ", 2) | ("plus", 3) => true,
42                // List sorting
43                ("msort", 2) | ("sort", 2) => true,
44                // Number/string conversion
45                ("number_chars", 2) | ("number_codes", 2) => true,
46                _ => false,
47            }
48        }
49        _ => false,
50    }
51}
52
53/// Result of executing a builtin.
54#[derive(Debug)]
55pub enum BuiltinResult {
56    /// The builtin succeeded (substitution may have been modified).
57    Success,
58    /// The builtin failed.
59    Failure,
60    /// Cut: succeed and signal cut to the solver.
61    Cut,
62    /// Negation-as-failure: the solver needs to try the inner goal.
63    NegationAsFailure(Term),
64    /// Disjunction: try left, then right on backtracking.
65    Disjunction(Term, Term),
66    /// If-then-else: ;(->(Cond, Then), Else)
67    IfThenElse(Term, Term, Term),
68    /// If-then (no else): ->(Cond, Then)
69    IfThen(Term, Term),
70    /// Conjunction: ','(A, B) — flatten into goal list.
71    Conjunction(Term, Term),
72    /// findall/3: Template, Goal, Result list.
73    FindAll(Term, Term, Term),
74    /// once/1: solve goal, take first solution only.
75    Once(Term),
76    /// call/1: execute a term as a goal.
77    Call(Term),
78    /// atom_length/2: atom, length
79    AtomLength(Term, Term),
80    /// atom_concat/3: atom1, atom2, result
81    AtomConcat(Term, Term, Term),
82    /// atom_chars/2: atom, char list
83    AtomChars(Term, Term),
84    /// write/1: write term to stdout (no newline).
85    Write(Term),
86    /// writeln/1: write term to stdout with newline.
87    Writeln(Term),
88    /// nl/0: write newline to stdout.
89    Nl,
90    /// compare/3: Order, Term1, Term2 — standard term ordering.
91    Compare(Term, Term, Term),
92    /// functor/3: Term, Name, Arity.
93    Functor(Term, Term, Term),
94    /// arg/3: N, Term, Arg.
95    Arg(Term, Term, Term),
96    /// =../2: Term, List (univ).
97    Univ(Term, Term),
98    /// between/3: Low, High, X — integer enumeration.
99    Between(Term, Term, Term),
100    /// copy_term/2: Original, Copy.
101    CopyTerm(Term, Term),
102    /// succ/2: X, S — successor relation.
103    Succ(Term, Term),
104    /// plus/3: X, Y, Z — addition relation.
105    Plus(Term, Term, Term),
106    /// msort/2: List, Sorted.
107    MSort(Term, Term),
108    /// sort/2: List, Sorted.
109    Sort(Term, Term),
110    /// number_chars/2: Number, Chars.
111    NumberChars(Term, Term),
112    /// number_codes/2: Number, Codes.
113    NumberCodes(Term, Term),
114}
115
116/// Execute a built-in predicate.
117pub fn exec_builtin(
118    goal: &Term,
119    subst: &mut Substitution,
120    interner: &StringInterner,
121) -> Result<BuiltinResult, String> {
122    match goal {
123        Term::Atom(id) => {
124            let name = interner.resolve(*id);
125            match name {
126                "true" => Ok(BuiltinResult::Success),
127                "fail" | "false" => Ok(BuiltinResult::Failure),
128                "!" => Ok(BuiltinResult::Cut),
129                "nl" => Ok(BuiltinResult::Nl),
130                _ => Err(format!("Unknown builtin atom: {}", name)),
131            }
132        }
133        Term::Compound { functor, args } => {
134            let name = interner.resolve(*functor);
135            match (name, args.len()) {
136                ("=", 2) => {
137                    if subst.unify(&args[0], &args[1]) {
138                        Ok(BuiltinResult::Success)
139                    } else {
140                        Ok(BuiltinResult::Failure)
141                    }
142                }
143                ("unify_with_occurs_check", 2) => {
144                    if subst.unify_with_occurs_check(&args[0], &args[1]) {
145                        Ok(BuiltinResult::Success)
146                    } else {
147                        Ok(BuiltinResult::Failure)
148                    }
149                }
150                ("\\=", 2) => {
151                    let mark = subst.trail_mark();
152                    if subst.unify(&args[0], &args[1]) {
153                        subst.undo_to(mark);
154                        Ok(BuiltinResult::Failure)
155                    } else {
156                        subst.undo_to(mark);
157                        Ok(BuiltinResult::Success)
158                    }
159                }
160                ("is", 2) => {
161                    let result = eval_arith(&args[1], subst, interner)?;
162                    let result_term = arith_to_term(result);
163                    if subst.unify(&args[0], &result_term) {
164                        Ok(BuiltinResult::Success)
165                    } else {
166                        Ok(BuiltinResult::Failure)
167                    }
168                }
169                ("<", 2) => {
170                    let l = eval_arith(&args[0], subst, interner)?;
171                    let r = eval_arith(&args[1], subst, interner)?;
172                    if arith_lt(&l, &r) {
173                        Ok(BuiltinResult::Success)
174                    } else {
175                        Ok(BuiltinResult::Failure)
176                    }
177                }
178                (">", 2) => {
179                    let l = eval_arith(&args[0], subst, interner)?;
180                    let r = eval_arith(&args[1], subst, interner)?;
181                    if arith_gt(&l, &r) {
182                        Ok(BuiltinResult::Success)
183                    } else {
184                        Ok(BuiltinResult::Failure)
185                    }
186                }
187                ("=<", 2) => {
188                    let l = eval_arith(&args[0], subst, interner)?;
189                    let r = eval_arith(&args[1], subst, interner)?;
190                    if !arith_gt(&l, &r) {
191                        Ok(BuiltinResult::Success)
192                    } else {
193                        Ok(BuiltinResult::Failure)
194                    }
195                }
196                (">=", 2) => {
197                    let l = eval_arith(&args[0], subst, interner)?;
198                    let r = eval_arith(&args[1], subst, interner)?;
199                    if !arith_lt(&l, &r) {
200                        Ok(BuiltinResult::Success)
201                    } else {
202                        Ok(BuiltinResult::Failure)
203                    }
204                }
205                ("=:=", 2) => {
206                    let l = eval_arith(&args[0], subst, interner)?;
207                    let r = eval_arith(&args[1], subst, interner)?;
208                    if arith_eq(&l, &r) {
209                        Ok(BuiltinResult::Success)
210                    } else {
211                        Ok(BuiltinResult::Failure)
212                    }
213                }
214                ("=\\=", 2) => {
215                    let l = eval_arith(&args[0], subst, interner)?;
216                    let r = eval_arith(&args[1], subst, interner)?;
217                    if !arith_eq(&l, &r) {
218                        Ok(BuiltinResult::Success)
219                    } else {
220                        Ok(BuiltinResult::Failure)
221                    }
222                }
223                ("\\+", 1) => Ok(BuiltinResult::NegationAsFailure(args[0].clone())),
224                // Type-checking predicates
225                ("var", 1) => {
226                    let walked = subst.walk(&args[0]);
227                    if matches!(walked, Term::Var(_)) {
228                        Ok(BuiltinResult::Success)
229                    } else {
230                        Ok(BuiltinResult::Failure)
231                    }
232                }
233                ("nonvar", 1) => {
234                    let walked = subst.walk(&args[0]);
235                    if matches!(walked, Term::Var(_)) {
236                        Ok(BuiltinResult::Failure)
237                    } else {
238                        Ok(BuiltinResult::Success)
239                    }
240                }
241                ("atom", 1) => {
242                    let walked = subst.walk(&args[0]);
243                    if matches!(walked, Term::Atom(_)) {
244                        Ok(BuiltinResult::Success)
245                    } else {
246                        Ok(BuiltinResult::Failure)
247                    }
248                }
249                ("number", 1) => {
250                    let walked = subst.walk(&args[0]);
251                    if matches!(walked, Term::Integer(_) | Term::Float(_)) {
252                        Ok(BuiltinResult::Success)
253                    } else {
254                        Ok(BuiltinResult::Failure)
255                    }
256                }
257                ("integer", 1) => {
258                    let walked = subst.walk(&args[0]);
259                    if matches!(walked, Term::Integer(_)) {
260                        Ok(BuiltinResult::Success)
261                    } else {
262                        Ok(BuiltinResult::Failure)
263                    }
264                }
265                ("float", 1) => {
266                    let walked = subst.walk(&args[0]);
267                    if matches!(walked, Term::Float(_)) {
268                        Ok(BuiltinResult::Success)
269                    } else {
270                        Ok(BuiltinResult::Failure)
271                    }
272                }
273                ("compound", 1) => {
274                    let walked = subst.walk(&args[0]);
275                    if matches!(walked, Term::Compound { .. } | Term::List { .. }) {
276                        Ok(BuiltinResult::Success)
277                    } else {
278                        Ok(BuiltinResult::Failure)
279                    }
280                }
281                ("is_list", 1) => {
282                    let walked = subst.apply(&args[0]);
283                    if is_proper_list(&walked, interner) {
284                        Ok(BuiltinResult::Success)
285                    } else {
286                        Ok(BuiltinResult::Failure)
287                    }
288                }
289                // Control flow
290                (";", 2) => {
291                    // Check if left arg is ->(Cond, Then) => if-then-else
292                    let left = subst.walk(&args[0]);
293                    if let Term::Compound {
294                        functor,
295                        args: inner_args,
296                    } = &left
297                    {
298                        if interner.resolve(*functor) == "->" && inner_args.len() == 2 {
299                            return Ok(BuiltinResult::IfThenElse(
300                                inner_args[0].clone(),
301                                inner_args[1].clone(),
302                                args[1].clone(),
303                            ));
304                        }
305                    }
306                    // Plain disjunction
307                    Ok(BuiltinResult::Disjunction(args[0].clone(), args[1].clone()))
308                }
309                ("->", 2) => Ok(BuiltinResult::IfThen(args[0].clone(), args[1].clone())),
310                (",", 2) => Ok(BuiltinResult::Conjunction(args[0].clone(), args[1].clone())),
311                ("findall", 3) => Ok(BuiltinResult::FindAll(
312                    args[0].clone(),
313                    args[1].clone(),
314                    args[2].clone(),
315                )),
316                ("once", 1) => Ok(BuiltinResult::Once(args[0].clone())),
317                ("call", 1) => Ok(BuiltinResult::Call(args[0].clone())),
318                // Atom/string predicates
319                ("atom_length", 2) => {
320                    Ok(BuiltinResult::AtomLength(args[0].clone(), args[1].clone()))
321                }
322                ("atom_concat", 3) => Ok(BuiltinResult::AtomConcat(
323                    args[0].clone(),
324                    args[1].clone(),
325                    args[2].clone(),
326                )),
327                ("atom_chars", 2) => Ok(BuiltinResult::AtomChars(args[0].clone(), args[1].clone())),
328                // I/O
329                ("write", 1) => Ok(BuiltinResult::Write(args[0].clone())),
330                ("writeln", 1) => Ok(BuiltinResult::Writeln(args[0].clone())),
331                // Term ordering
332                ("compare", 3) => Ok(BuiltinResult::Compare(
333                    args[0].clone(),
334                    args[1].clone(),
335                    args[2].clone(),
336                )),
337                ("@<", 2) => {
338                    let cmp =
339                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
340                    if cmp == std::cmp::Ordering::Less {
341                        Ok(BuiltinResult::Success)
342                    } else {
343                        Ok(BuiltinResult::Failure)
344                    }
345                }
346                ("@>", 2) => {
347                    let cmp =
348                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
349                    if cmp == std::cmp::Ordering::Greater {
350                        Ok(BuiltinResult::Success)
351                    } else {
352                        Ok(BuiltinResult::Failure)
353                    }
354                }
355                ("@=<", 2) => {
356                    let cmp =
357                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
358                    if cmp != std::cmp::Ordering::Greater {
359                        Ok(BuiltinResult::Success)
360                    } else {
361                        Ok(BuiltinResult::Failure)
362                    }
363                }
364                ("@>=", 2) => {
365                    let cmp =
366                        term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
367                    if cmp != std::cmp::Ordering::Less {
368                        Ok(BuiltinResult::Success)
369                    } else {
370                        Ok(BuiltinResult::Failure)
371                    }
372                }
373                // Term introspection
374                ("functor", 3) => Ok(BuiltinResult::Functor(
375                    args[0].clone(),
376                    args[1].clone(),
377                    args[2].clone(),
378                )),
379                ("arg", 3) => Ok(BuiltinResult::Arg(
380                    args[0].clone(),
381                    args[1].clone(),
382                    args[2].clone(),
383                )),
384                ("=..", 2) => Ok(BuiltinResult::Univ(args[0].clone(), args[1].clone())),
385                // Integer enumeration
386                ("between", 3) => Ok(BuiltinResult::Between(
387                    args[0].clone(),
388                    args[1].clone(),
389                    args[2].clone(),
390                )),
391                // Term copying
392                ("copy_term", 2) => Ok(BuiltinResult::CopyTerm(args[0].clone(), args[1].clone())),
393                // Peano arithmetic
394                ("succ", 2) => Ok(BuiltinResult::Succ(args[0].clone(), args[1].clone())),
395                ("plus", 3) => Ok(BuiltinResult::Plus(
396                    args[0].clone(),
397                    args[1].clone(),
398                    args[2].clone(),
399                )),
400                // List sorting
401                ("msort", 2) => Ok(BuiltinResult::MSort(args[0].clone(), args[1].clone())),
402                ("sort", 2) => Ok(BuiltinResult::Sort(args[0].clone(), args[1].clone())),
403                // Number/string conversion
404                ("number_chars", 2) => {
405                    Ok(BuiltinResult::NumberChars(args[0].clone(), args[1].clone()))
406                }
407                ("number_codes", 2) => {
408                    Ok(BuiltinResult::NumberCodes(args[0].clone(), args[1].clone()))
409                }
410                _ => Err(format!("Unknown builtin: {}/{}", name, args.len())),
411            }
412        }
413        _ => Err(format!("Cannot execute as builtin: {:?}", goal)),
414    }
415}
416
/// Result of evaluating an arithmetic expression.
///
/// Integers and floats are kept apart so integer arithmetic stays exact.
/// Mixed int/float operations convert the integer operand to `f64`.
#[derive(Debug, Clone)]
enum ArithVal {
    /// Exact 64-bit integer value.
    Int(i64),
    /// Double-precision floating-point value.
    Float(f64),
}
423
424fn arith_to_term(val: ArithVal) -> Term {
425    match val {
426        ArithVal::Int(n) => Term::Integer(n),
427        ArithVal::Float(f) => Term::Float(f),
428    }
429}
430
431fn arith_lt(a: &ArithVal, b: &ArithVal) -> bool {
432    match (a, b) {
433        (ArithVal::Int(a), ArithVal::Int(b)) => a < b,
434        (ArithVal::Float(a), ArithVal::Float(b)) => a < b,
435        (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) < *b,
436        (ArithVal::Float(a), ArithVal::Int(b)) => *a < (*b as f64),
437    }
438}
439
440fn arith_gt(a: &ArithVal, b: &ArithVal) -> bool {
441    arith_lt(b, a)
442}
443
444fn arith_eq(a: &ArithVal, b: &ArithVal) -> bool {
445    match (a, b) {
446        (ArithVal::Int(a), ArithVal::Int(b)) => a == b,
447        (ArithVal::Float(a), ArithVal::Float(b)) => a == b,
448        (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) == *b,
449        (ArithVal::Float(a), ArithVal::Int(b)) => *a == (*b as f64),
450    }
451}
452
453/// Evaluate an arithmetic expression.
454fn eval_arith(
455    term: &Term,
456    subst: &Substitution,
457    interner: &StringInterner,
458) -> Result<ArithVal, String> {
459    let term = subst.walk(term);
460    match &term {
461        Term::Integer(n) => Ok(ArithVal::Int(*n)),
462        Term::Float(f) => Ok(ArithVal::Float(*f)),
463        Term::Var(id) => Err(format!("Arithmetic error: unbound variable _{}", id)),
464        Term::Compound { functor, args } => {
465            let name = interner.resolve(*functor);
466            match (name, args.len()) {
467                ("+", 2) => {
468                    let l = eval_arith(&args[0], subst, interner)?;
469                    let r = eval_arith(&args[1], subst, interner)?;
470                    arith_add(&l, &r)
471                }
472                ("-", 2) => {
473                    let l = eval_arith(&args[0], subst, interner)?;
474                    let r = eval_arith(&args[1], subst, interner)?;
475                    arith_sub(&l, &r)
476                }
477                ("*", 2) => {
478                    let l = eval_arith(&args[0], subst, interner)?;
479                    let r = eval_arith(&args[1], subst, interner)?;
480                    arith_mul(&l, &r)
481                }
482                ("/", 2) => {
483                    let l = eval_arith(&args[0], subst, interner)?;
484                    let r = eval_arith(&args[1], subst, interner)?;
485                    arith_div(&l, &r)
486                }
487                ("//", 2) => {
488                    let l = eval_arith(&args[0], subst, interner)?;
489                    let r = eval_arith(&args[1], subst, interner)?;
490                    arith_int_div(&l, &r)
491                }
492                ("mod", 2) => {
493                    let l = eval_arith(&args[0], subst, interner)?;
494                    let r = eval_arith(&args[1], subst, interner)?;
495                    arith_mod(&l, &r)
496                }
497                ("rem", 2) => {
498                    let l = eval_arith(&args[0], subst, interner)?;
499                    let r = eval_arith(&args[1], subst, interner)?;
500                    arith_rem(&l, &r)
501                }
502                ("-", 1) => {
503                    let v = eval_arith(&args[0], subst, interner)?;
504                    arith_neg(&v)
505                }
506                ("abs", 1) => {
507                    let v = eval_arith(&args[0], subst, interner)?;
508                    arith_abs(&v)
509                }
510                ("sign", 1) => {
511                    let v = eval_arith(&args[0], subst, interner)?;
512                    Ok(arith_sign(&v))
513                }
514                ("max", 2) => {
515                    let l = eval_arith(&args[0], subst, interner)?;
516                    let r = eval_arith(&args[1], subst, interner)?;
517                    Ok(arith_max(&l, &r))
518                }
519                ("min", 2) => {
520                    let l = eval_arith(&args[0], subst, interner)?;
521                    let r = eval_arith(&args[1], subst, interner)?;
522                    Ok(arith_min(&l, &r))
523                }
524                _ => Err(format!(
525                    "Unknown arithmetic operator: {}/{}",
526                    name,
527                    args.len()
528                )),
529            }
530        }
531        _ => Err(format!("Cannot evaluate as arithmetic: {:?}", term)),
532    }
533}
534
535/// Check a float result for NaN or Infinity, returning an error if detected.
536fn check_float(f: f64) -> Result<ArithVal, String> {
537    if f.is_nan() {
538        Err("Arithmetic error: NaN result".to_string())
539    } else if f.is_infinite() {
540        Err("Arithmetic error: Infinity result".to_string())
541    } else {
542        Ok(ArithVal::Float(f))
543    }
544}
545
546fn arith_add(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
547    match (a, b) {
548        (ArithVal::Int(a), ArithVal::Int(b)) => a
549            .checked_add(*b)
550            .map(ArithVal::Int)
551            .ok_or_else(|| "Arithmetic error: integer overflow in addition".to_string()),
552        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a + b),
553        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 + b),
554        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a + *b as f64),
555    }
556}
557
558fn arith_sub(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
559    match (a, b) {
560        (ArithVal::Int(a), ArithVal::Int(b)) => a
561            .checked_sub(*b)
562            .map(ArithVal::Int)
563            .ok_or_else(|| "Arithmetic error: integer overflow in subtraction".to_string()),
564        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a - b),
565        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 - b),
566        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a - *b as f64),
567    }
568}
569
570fn arith_mul(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
571    match (a, b) {
572        (ArithVal::Int(a), ArithVal::Int(b)) => a
573            .checked_mul(*b)
574            .map(ArithVal::Int)
575            .ok_or_else(|| "Arithmetic error: integer overflow in multiplication".to_string()),
576        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a * b),
577        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 * b),
578        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a * *b as f64),
579    }
580}
581
582fn arith_div(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
583    match (a, b) {
584        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
585        (ArithVal::Int(a), ArithVal::Int(b)) => a
586            .checked_div(*b)
587            .map(ArithVal::Int)
588            .ok_or_else(|| "Arithmetic error: integer overflow in division".to_string()),
589        (_, ArithVal::Float(b)) if *b == 0.0 => Err("Division by zero".to_string()),
590        (ArithVal::Float(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
591        (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a / b),
592        (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 / b),
593        (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a / *b as f64),
594    }
595}
596
597fn arith_mod(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
598    match (a, b) {
599        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Modulo by zero".to_string()),
600        (ArithVal::Int(_), ArithVal::Int(i64::MIN)) => {
601            Err("Arithmetic error: integer overflow in mod".to_string())
602        }
603        (ArithVal::Int(a), ArithVal::Int(b)) => {
604            // ISO Prolog mod: result has the sign of the divisor
605            // X mod Y = X - floor(X/Y) * Y
606            // b.abs() is safe here because we excluded i64::MIN above
607            let r = a.rem_euclid(b.abs());
608            if *b < 0 && r != 0 {
609                Ok(ArithVal::Int(r - b.abs()))
610            } else {
611                Ok(ArithVal::Int(r))
612            }
613        }
614        _ => Err("mod requires integer arguments".to_string()),
615    }
616}
617
618/// ISO `//` — truncating integer division (integers only)
619fn arith_int_div(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
620    match (a, b) {
621        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
622        (ArithVal::Int(a), ArithVal::Int(b)) => a
623            .checked_div(*b)
624            .map(ArithVal::Int)
625            .ok_or_else(|| "Arithmetic error: integer overflow in division".to_string()),
626        _ => Err("// requires integer arguments".to_string()),
627    }
628}
629
630/// ISO `rem` — truncating remainder (sign follows dividend)
631fn arith_rem(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
632    match (a, b) {
633        (ArithVal::Int(_), ArithVal::Int(0)) => Err("Remainder by zero".to_string()),
634        (ArithVal::Int(a), ArithVal::Int(b)) => a
635            .checked_rem(*b)
636            .map(ArithVal::Int)
637            .ok_or_else(|| "Arithmetic error: integer overflow in rem".to_string()),
638        _ => Err("rem requires integer arguments".to_string()),
639    }
640}
641
642fn arith_neg(a: &ArithVal) -> Result<ArithVal, String> {
643    match a {
644        ArithVal::Int(n) => n
645            .checked_neg()
646            .map(ArithVal::Int)
647            .ok_or_else(|| "Arithmetic error: integer overflow in negation".to_string()),
648        ArithVal::Float(f) => check_float(-f),
649    }
650}
651
652fn arith_abs(a: &ArithVal) -> Result<ArithVal, String> {
653    match a {
654        ArithVal::Int(n) => n
655            .checked_abs()
656            .map(ArithVal::Int)
657            .ok_or_else(|| "Arithmetic error: integer overflow in abs".to_string()),
658        ArithVal::Float(f) => check_float(f.abs()),
659    }
660}
661
662fn arith_sign(a: &ArithVal) -> ArithVal {
663    match a {
664        ArithVal::Int(n) => ArithVal::Int(n.signum()),
665        ArithVal::Float(f) => ArithVal::Float(f.signum()),
666    }
667}
668
669fn arith_max(a: &ArithVal, b: &ArithVal) -> ArithVal {
670    if arith_lt(a, b) {
671        b.clone()
672    } else {
673        a.clone()
674    }
675}
676
677fn arith_min(a: &ArithVal, b: &ArithVal) -> ArithVal {
678    if arith_lt(a, b) {
679        a.clone()
680    } else {
681        b.clone()
682    }
683}
684
685/// Standard order of terms (ISO Prolog):
686/// Variables < Numbers < Atoms < Compound terms
687/// Within numbers: by value. Within atoms: alphabetical.
688/// Within compounds: by arity, then functor name, then arguments left-to-right.
689pub fn term_compare(a: &Term, b: &Term, interner: &StringInterner) -> std::cmp::Ordering {
690    use std::cmp::Ordering;
691    fn type_rank(t: &Term) -> u8 {
692        match t {
693            Term::Var(_) => 0,
694            Term::Float(_) => 1,
695            Term::Integer(_) => 1,
696            Term::Atom(_) => 2,
697            Term::List { .. } => 3,
698            Term::Compound { .. } => 3,
699        }
700    }
701
702    let ra = type_rank(a);
703    let rb = type_rank(b);
704    if ra != rb {
705        return ra.cmp(&rb);
706    }
707
708    match (a, b) {
709        (Term::Var(a), Term::Var(b)) => a.cmp(b),
710        (Term::Integer(a), Term::Integer(b)) => a.cmp(b),
711        (Term::Float(a), Term::Float(b)) => {
712            // NaN sorts after all other floats (deterministic total order)
713            a.partial_cmp(b)
714                .unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
715                    (true, true) => Ordering::Equal,
716                    (true, false) => Ordering::Greater,
717                    (false, true) => Ordering::Less,
718                    (false, false) => unreachable!(),
719                })
720        }
721        (Term::Integer(a), Term::Float(b)) => {
722            // NaN sorts after everything; ISO: float < integer when same value
723            if b.is_nan() {
724                return Ordering::Less;
725            }
726            let cmp = (*a as f64).partial_cmp(b).unwrap_or(Ordering::Less);
727            if cmp == Ordering::Equal {
728                Ordering::Greater // integer > float for same value (ISO 8.4.2.1)
729            } else {
730                cmp
731            }
732        }
733        (Term::Float(a), Term::Integer(b)) => {
734            // NaN sorts after everything; ISO: float < integer when same value
735            if a.is_nan() {
736                return Ordering::Greater;
737            }
738            let cmp = a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Greater);
739            if cmp == Ordering::Equal {
740                Ordering::Less // float < integer for same value (ISO 8.4.2.1)
741            } else {
742                cmp
743            }
744        }
745        (Term::Atom(a), Term::Atom(b)) => interner.resolve(*a).cmp(interner.resolve(*b)),
746        (
747            Term::Compound {
748                functor: f1,
749                args: a1,
750            },
751            Term::Compound {
752                functor: f2,
753                args: a2,
754            },
755        ) => {
756            // Compare by arity first, then functor name, then args
757            a1.len()
758                .cmp(&a2.len())
759                .then_with(|| interner.resolve(*f1).cmp(interner.resolve(*f2)))
760                .then_with(|| {
761                    for (x, y) in a1.iter().zip(a2.iter()) {
762                        let c = term_compare(x, y, interner);
763                        if c != Ordering::Equal {
764                            return c;
765                        }
766                    }
767                    Ordering::Equal
768                })
769        }
770        (Term::List { .. }, Term::List { .. }) => {
771            // Iterative list comparison to avoid stack overflow on long lists
772            let mut cur_a = a;
773            let mut cur_b = b;
774            loop {
775                match (cur_a, cur_b) {
776                    (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
777                        let c = term_compare(h1, h2, interner);
778                        if c != Ordering::Equal {
779                            return c;
780                        }
781                        cur_a = t1;
782                        cur_b = t2;
783                    }
784                    _ => return term_compare(cur_a, cur_b, interner),
785                }
786            }
787        }
788        // List vs Compound: lists are .(H,T) which is arity 2
789        (
790            Term::List { head: h, tail: t },
791            Term::Compound {
792                functor: f2,
793                args: a2,
794            },
795        ) => {
796            // List is ./2; compare arity, then functor ".", then args
797            2usize
798                .cmp(&a2.len())
799                .then_with(|| ".".cmp(interner.resolve(*f2)))
800                .then_with(|| {
801                    if a2.len() >= 1 {
802                        let c = term_compare(h, &a2[0], interner);
803                        if c != Ordering::Equal {
804                            return c;
805                        }
806                    }
807                    if a2.len() >= 2 {
808                        return term_compare(t, &a2[1], interner);
809                    }
810                    Ordering::Equal
811                })
812        }
813        (
814            Term::Compound {
815                functor: f1,
816                args: a1,
817            },
818            Term::List { head: h, tail: t },
819        ) => a1
820            .len()
821            .cmp(&2usize)
822            .then_with(|| interner.resolve(*f1).cmp("."))
823            .then_with(|| {
824                if a1.len() >= 1 {
825                    let c = term_compare(&a1[0], h, interner);
826                    if c != Ordering::Equal {
827                        return c;
828                    }
829                }
830                if a1.len() >= 2 {
831                    return term_compare(&a1[1], t, interner);
832                }
833                Ordering::Equal
834            }),
835        _ => unreachable!("term_compare: unhandled Term variant"),
836    }
837}
838
839/// Collect list elements from a term. Returns None if not a proper list.
840pub fn collect_list(term: &Term, interner: &StringInterner) -> Option<Vec<Term>> {
841    let mut elements = Vec::new();
842    let mut current = term;
843    loop {
844        match current {
845            Term::Atom(id) if interner.resolve(*id) == "[]" => return Some(elements),
846            Term::List { head, tail } => {
847                elements.push(head.as_ref().clone());
848                current = tail;
849            }
850            _ => return None,
851        }
852    }
853}
854
855/// Build a list term from elements.
856pub fn build_list(elements: Vec<Term>, interner: &StringInterner) -> Term {
857    let nil_id = interner.lookup("[]").expect("[] must be interned");
858    let mut list = Term::Atom(nil_id);
859    for elem in elements.into_iter().rev() {
860        list = Term::List {
861            head: Box::new(elem),
862            tail: Box::new(list),
863        };
864    }
865    list
866}
867
868/// Check if a term is a proper list (ends with []).
869fn is_proper_list(term: &Term, interner: &StringInterner) -> bool {
870    let mut current = term;
871    loop {
872        match current {
873            Term::Atom(id) => return interner.resolve(*id) == "[]",
874            Term::List { tail, .. } => current = tail,
875            _ => return false,
876        }
877    }
878}
879
880/// Helper: check if a goal atom name matches a known builtin name.
881pub fn builtin_atom_names() -> &'static [&'static str] {
882    &["true", "fail", "false", "!", "nl"]
883}
884
885pub fn builtin_functor_names() -> &'static [(&'static str, usize)] {
886    &[
887        ("=", 2),
888        ("\\=", 2),
889        ("is", 2),
890        ("<", 2),
891        (">", 2),
892        ("=<", 2),
893        (">=", 2),
894        ("=:=", 2),
895        ("=\\=", 2),
896        ("\\+", 1),
897        ("var", 1),
898        ("nonvar", 1),
899        ("atom", 1),
900        ("number", 1),
901        ("integer", 1),
902        ("float", 1),
903        ("compound", 1),
904        ("is_list", 1),
905        (";", 2),
906        ("->", 2),
907        (",", 2),
908        ("findall", 3),
909        ("once", 1),
910        ("call", 1),
911        ("atom_length", 2),
912        ("atom_concat", 3),
913        ("atom_chars", 2),
914        ("write", 1),
915        ("writeln", 1),
916        ("compare", 3),
917        ("@<", 2),
918        ("@>", 2),
919        ("@=<", 2),
920        ("@>=", 2),
921        ("functor", 3),
922        ("arg", 3),
923        ("=..", 2),
924        ("between", 3),
925        ("copy_term", 2),
926        ("succ", 2),
927        ("plus", 3),
928        ("msort", 2),
929        ("sort", 2),
930        ("number_chars", 2),
931        ("number_codes", 2),
932    ]
933}
934
935#[cfg(test)]
936mod tests {
937    use super::*;
938    use crate::parser::Parser;
939
940    fn setup() -> StringInterner {
941        let mut i = StringInterner::new();
942        // Pre-intern common atoms
943        i.intern("true");
944        i.intern("fail");
945        i.intern("!");
946        i.intern("=");
947        i.intern("\\=");
948        i.intern("is");
949        i.intern("<");
950        i.intern(">");
951        i.intern("=<");
952        i.intern(">=");
953        i.intern("=:=");
954        i.intern("=\\=");
955        i.intern("\\+");
956        i.intern("+");
957        i.intern("-");
958        i.intern("*");
959        i.intern("/");
960        i.intern("mod");
961        i.intern("//");
962        i.intern("rem");
963        i
964    }
965
966    #[test]
967    fn test_is_builtin() {
968        let interner = setup();
969        let true_id = interner.lookup("true").unwrap();
970        assert!(is_builtin(&Term::Atom(true_id), &interner));
971
972        let eq_id = interner.lookup("=").unwrap();
973        let goal = Term::Compound {
974            functor: eq_id,
975            args: vec![Term::Var(0), Term::Atom(0)],
976        };
977        assert!(is_builtin(&goal, &interner));
978    }
979
980    #[test]
981    fn test_exec_true() {
982        let interner = setup();
983        let true_id = interner.lookup("true").unwrap();
984        let mut subst = Substitution::new();
985        let result = exec_builtin(&Term::Atom(true_id), &mut subst, &interner).unwrap();
986        assert!(matches!(result, BuiltinResult::Success));
987    }
988
989    #[test]
990    fn test_exec_fail() {
991        let interner = setup();
992        let fail_id = interner.lookup("fail").unwrap();
993        let mut subst = Substitution::new();
994        let result = exec_builtin(&Term::Atom(fail_id), &mut subst, &interner).unwrap();
995        assert!(matches!(result, BuiltinResult::Failure));
996    }
997
998    #[test]
999    fn test_exec_unify() {
1000        let interner = setup();
1001        let eq_id = interner.lookup("=").unwrap();
1002        let mut subst = Substitution::new();
1003        let goal = Term::Compound {
1004            functor: eq_id,
1005            args: vec![Term::Var(0), Term::Integer(42)],
1006        };
1007        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1008        assert!(matches!(result, BuiltinResult::Success));
1009        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(42));
1010    }
1011
1012    #[test]
1013    fn test_exec_not_unify() {
1014        let interner = setup();
1015        let neq_id = interner.lookup("\\=").unwrap();
1016        let mut subst = Substitution::new();
1017        // 1 \= 2 should succeed
1018        let goal = Term::Compound {
1019            functor: neq_id,
1020            args: vec![Term::Integer(1), Term::Integer(2)],
1021        };
1022        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1023        assert!(matches!(result, BuiltinResult::Success));
1024
1025        // 1 \= 1 should fail
1026        let goal = Term::Compound {
1027            functor: neq_id,
1028            args: vec![Term::Integer(1), Term::Integer(1)],
1029        };
1030        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1031        assert!(matches!(result, BuiltinResult::Failure));
1032    }
1033
1034    #[test]
1035    fn test_exec_is_arithmetic() {
1036        let mut interner = setup();
1037        let goals = Parser::parse_query("X is 2 + 3 * 4", &mut interner).unwrap();
1038        let mut subst = Substitution::new();
1039        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1040        assert!(matches!(result, BuiltinResult::Success));
1041        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(14));
1042    }
1043
1044    #[test]
1045    fn test_exec_comparison() {
1046        let interner = setup();
1047        let lt_id = interner.lookup("<").unwrap();
1048        let mut subst = Substitution::new();
1049
1050        // 1 < 2 should succeed
1051        let goal = Term::Compound {
1052            functor: lt_id,
1053            args: vec![Term::Integer(1), Term::Integer(2)],
1054        };
1055        assert!(matches!(
1056            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1057            BuiltinResult::Success
1058        ));
1059
1060        // 2 < 1 should fail
1061        let goal = Term::Compound {
1062            functor: lt_id,
1063            args: vec![Term::Integer(2), Term::Integer(1)],
1064        };
1065        assert!(matches!(
1066            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1067            BuiltinResult::Failure
1068        ));
1069    }
1070
1071    #[test]
1072    fn test_exec_cut() {
1073        let interner = setup();
1074        let cut_id = interner.lookup("!").unwrap();
1075        let mut subst = Substitution::new();
1076        let result = exec_builtin(&Term::Atom(cut_id), &mut subst, &interner).unwrap();
1077        assert!(matches!(result, BuiltinResult::Cut));
1078    }
1079
1080    #[test]
1081    fn test_type_checking_var() {
1082        let mut interner = setup();
1083        interner.intern("var");
1084        let var_id = interner.lookup("var").unwrap();
1085        let mut subst = Substitution::new();
1086        // var(X) where X is unbound should succeed
1087        let goal = Term::Compound {
1088            functor: var_id,
1089            args: vec![Term::Var(0)],
1090        };
1091        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1092        assert!(matches!(result, BuiltinResult::Success));
1093
1094        // var(42) should fail
1095        let goal = Term::Compound {
1096            functor: var_id,
1097            args: vec![Term::Integer(42)],
1098        };
1099        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1100        assert!(matches!(result, BuiltinResult::Failure));
1101    }
1102
1103    #[test]
1104    fn test_type_checking_atom() {
1105        let mut interner = setup();
1106        interner.intern("atom");
1107        let atom_id = interner.lookup("atom").unwrap();
1108        let mut subst = Substitution::new();
1109        let hello = interner.intern("hello");
1110        // atom(hello) should succeed
1111        let goal = Term::Compound {
1112            functor: atom_id,
1113            args: vec![Term::Atom(hello)],
1114        };
1115        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1116        assert!(matches!(result, BuiltinResult::Success));
1117
1118        // atom(42) should fail
1119        let goal = Term::Compound {
1120            functor: atom_id,
1121            args: vec![Term::Integer(42)],
1122        };
1123        let result = exec_builtin(&goal, &mut subst, &interner).unwrap();
1124        assert!(matches!(result, BuiltinResult::Failure));
1125    }
1126
1127    #[test]
1128    fn test_type_checking_integer() {
1129        let mut interner = setup();
1130        interner.intern("integer");
1131        let int_id = interner.lookup("integer").unwrap();
1132        let mut subst = Substitution::new();
1133        let goal = Term::Compound {
1134            functor: int_id,
1135            args: vec![Term::Integer(42)],
1136        };
1137        assert!(matches!(
1138            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1139            BuiltinResult::Success
1140        ));
1141
1142        let goal = Term::Compound {
1143            functor: int_id,
1144            args: vec![Term::Float(3.14)],
1145        };
1146        assert!(matches!(
1147            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1148            BuiltinResult::Failure
1149        ));
1150    }
1151
1152    #[test]
1153    fn test_type_checking_number() {
1154        let mut interner = setup();
1155        interner.intern("number");
1156        let num_id = interner.lookup("number").unwrap();
1157        let mut subst = Substitution::new();
1158        // number(42) should succeed
1159        let goal = Term::Compound {
1160            functor: num_id,
1161            args: vec![Term::Integer(42)],
1162        };
1163        assert!(matches!(
1164            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1165            BuiltinResult::Success
1166        ));
1167        // number(3.14) should succeed
1168        let goal = Term::Compound {
1169            functor: num_id,
1170            args: vec![Term::Float(3.14)],
1171        };
1172        assert!(matches!(
1173            exec_builtin(&goal, &mut subst, &interner).unwrap(),
1174            BuiltinResult::Success
1175        ));
1176    }
1177
1178    #[test]
1179    fn test_exec_mod() {
1180        let mut interner = setup();
1181        let goals = Parser::parse_query("X is 10 mod 3", &mut interner).unwrap();
1182        let mut subst = Substitution::new();
1183        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1184        assert!(matches!(result, BuiltinResult::Success));
1185        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(1));
1186    }
1187
1188    #[test]
1189    fn test_mod_i64_min_divisor() {
1190        // arith_mod with i64::MIN divisor should error, not panic from .abs()
1191        let result = arith_mod(&ArithVal::Int(5), &ArithVal::Int(i64::MIN));
1192        assert!(result.is_err());
1193        assert!(result.unwrap_err().contains("overflow"));
1194    }
1195
1196    #[test]
1197    fn test_mod_i64_min_dividend_neg1() {
1198        // i64::MIN mod -1 should be 0 (rem_euclid handles this correctly)
1199        let result = arith_mod(&ArithVal::Int(i64::MIN), &ArithVal::Int(-1));
1200        match result {
1201            Ok(ArithVal::Int(0)) => {}
1202            other => panic!("Expected Ok(Int(0)), got {:?}", other),
1203        }
1204    }
1205
1206    #[test]
1207    fn test_integer_overflow_add() {
1208        let mut interner = setup();
1209        let query_str = format!("X is {} + 1", i64::MAX);
1210        let goals = Parser::parse_query(&query_str, &mut interner).unwrap();
1211        let mut subst = Substitution::new();
1212        let result = exec_builtin(&goals[0], &mut subst, &interner);
1213        assert!(result.is_err());
1214        assert!(result.unwrap_err().contains("overflow"));
1215    }
1216
1217    #[test]
1218    fn test_integer_overflow_mul() {
1219        let mut interner = setup();
1220        let query_str = format!("X is {} * 2", i64::MAX);
1221        let goals = Parser::parse_query(&query_str, &mut interner).unwrap();
1222        let mut subst = Substitution::new();
1223        let result = exec_builtin(&goals[0], &mut subst, &interner);
1224        assert!(result.is_err());
1225        assert!(result.unwrap_err().contains("overflow"));
1226    }
1227
1228    #[test]
1229    fn test_arith_abs() {
1230        let mut interner = setup();
1231        let goals = Parser::parse_query("X is abs(-5)", &mut interner).unwrap();
1232        let mut subst = Substitution::new();
1233        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1234        assert!(matches!(result, BuiltinResult::Success));
1235        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(5));
1236    }
1237
1238    #[test]
1239    fn test_arith_abs_positive() {
1240        let mut interner = setup();
1241        let goals = Parser::parse_query("X is abs(3)", &mut interner).unwrap();
1242        let mut subst = Substitution::new();
1243        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1244        assert!(matches!(result, BuiltinResult::Success));
1245        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(3));
1246    }
1247
1248    #[test]
1249    fn test_arith_sign() {
1250        let mut interner = setup();
1251        let goals = Parser::parse_query("X is sign(-42)", &mut interner).unwrap();
1252        let mut subst = Substitution::new();
1253        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1254        assert!(matches!(result, BuiltinResult::Success));
1255        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(-1));
1256    }
1257
1258    #[test]
1259    fn test_arith_sign_zero() {
1260        let mut interner = setup();
1261        let goals = Parser::parse_query("X is sign(0)", &mut interner).unwrap();
1262        let mut subst = Substitution::new();
1263        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1264        assert!(matches!(result, BuiltinResult::Success));
1265        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(0));
1266    }
1267
1268    #[test]
1269    fn test_arith_max() {
1270        let mut interner = setup();
1271        let goals = Parser::parse_query("X is max(3, 7)", &mut interner).unwrap();
1272        let mut subst = Substitution::new();
1273        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1274        assert!(matches!(result, BuiltinResult::Success));
1275        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(7));
1276    }
1277
1278    #[test]
1279    fn test_arith_min() {
1280        let mut interner = setup();
1281        let goals = Parser::parse_query("X is min(3, 7)", &mut interner).unwrap();
1282        let mut subst = Substitution::new();
1283        let result = exec_builtin(&goals[0], &mut subst, &interner).unwrap();
1284        assert!(matches!(result, BuiltinResult::Success));
1285        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(3));
1286    }
1287}