1use crate::term::{StringInterner, Term};
2use crate::unify::Substitution;
3
4pub fn is_builtin(goal: &Term, interner: &StringInterner) -> bool {
6 match goal {
7 Term::Atom(id) => {
8 let name = interner.resolve(*id);
9 matches!(name, "true" | "fail" | "false" | "!" | "nl")
10 }
11 Term::Compound { functor, args } => {
12 let name = interner.resolve(*functor);
13 match (name, args.len()) {
14 ("=", 2) | ("\\=", 2) | ("unify_with_occurs_check", 2) | ("is", 2) => true,
15 ("<", 2) | (">", 2) | ("=<", 2) | (">=", 2) => true,
16 ("=:=", 2) | ("=\\=", 2) => true,
17 ("\\+", 1) => true,
18 ("var", 1) | ("nonvar", 1) | ("atom", 1) | ("number", 1) => true,
20 ("integer", 1) | ("float", 1) | ("compound", 1) | ("is_list", 1) => true,
21 (";", 2) | ("->", 2) | (",", 2) => true,
23 ("findall", 3) => true,
25 ("once", 1) | ("call", 1) => true,
27 ("atom_length", 2) | ("atom_concat", 3) | ("atom_chars", 2) => true,
29 ("write", 1) | ("writeln", 1) => true,
31 ("compare", 3) => true,
33 ("@<", 2) | ("@>", 2) | ("@=<", 2) | ("@>=", 2) => true,
34 ("functor", 3) | ("arg", 3) | ("=..", 2) => true,
36 ("between", 3) => true,
38 ("copy_term", 2) => true,
40 ("succ", 2) | ("plus", 3) => true,
42 ("msort", 2) | ("sort", 2) => true,
44 ("number_chars", 2) | ("number_codes", 2) => true,
46 _ => false,
47 }
48 }
49 _ => false,
50 }
51}
52
/// Outcome of running a builtin through `exec_builtin`.
///
/// `Success` and `Failure` are final answers computed by `exec_builtin`
/// itself. Every other variant merely repackages the goal: its `Term` payloads
/// are clones of the goal's arguments, in argument order, and no further work
/// is done by `exec_builtin` (presumably the caller carries the operation out).
#[derive(Debug)]
pub enum BuiltinResult {
    /// The builtin was decided and holds.
    Success,
    /// The builtin was decided and does not hold.
    Failure,
    /// The `!` atom.
    Cut,
    /// `\+ Goal`, carrying `Goal`.
    NegationAsFailure(Term),
    /// `Left ; Right` where `Left` is not an `->`/2 term.
    Disjunction(Term, Term),
    /// `(Cond -> Then) ; Else`, carrying `Cond`, `Then`, `Else`.
    IfThenElse(Term, Term, Term),
    /// `Cond -> Then` that is not the left side of a `;`/2.
    IfThen(Term, Term),
    /// `Left , Right`.
    Conjunction(Term, Term),
    /// `findall/3` with its three arguments.
    FindAll(Term, Term, Term),
    /// `once/1` with its goal argument.
    Once(Term),
    /// `call/1` with its goal argument.
    Call(Term),
    /// `atom_length/2` with its two arguments.
    AtomLength(Term, Term),
    /// `atom_concat/3` with its three arguments.
    AtomConcat(Term, Term, Term),
    /// `atom_chars/2` with its two arguments.
    AtomChars(Term, Term),
    /// `write/1` with the term to print.
    Write(Term),
    /// `writeln/1` with the term to print.
    Writeln(Term),
    /// The `nl` atom.
    Nl,
    /// `compare/3` with its three arguments.
    Compare(Term, Term, Term),
    /// `functor/3` with its three arguments.
    Functor(Term, Term, Term),
    /// `arg/3` with its three arguments.
    Arg(Term, Term, Term),
    /// `=../2` with its two arguments.
    Univ(Term, Term),
    /// `between/3` with its three arguments.
    Between(Term, Term, Term),
    /// `copy_term/2` with its two arguments.
    CopyTerm(Term, Term),
    /// `succ/2` with its two arguments.
    Succ(Term, Term),
    /// `plus/3` with its three arguments.
    Plus(Term, Term, Term),
    /// `msort/2` with its two arguments.
    MSort(Term, Term),
    /// `sort/2` with its two arguments.
    Sort(Term, Term),
    /// `number_chars/2` with its two arguments.
    NumberChars(Term, Term),
    /// `number_codes/2` with its two arguments.
    NumberCodes(Term, Term),
}
115
116pub fn exec_builtin(
118 goal: &Term,
119 subst: &mut Substitution,
120 interner: &StringInterner,
121) -> Result<BuiltinResult, String> {
122 match goal {
123 Term::Atom(id) => {
124 let name = interner.resolve(*id);
125 match name {
126 "true" => Ok(BuiltinResult::Success),
127 "fail" | "false" => Ok(BuiltinResult::Failure),
128 "!" => Ok(BuiltinResult::Cut),
129 "nl" => Ok(BuiltinResult::Nl),
130 _ => Err(format!("Unknown builtin atom: {}", name)),
131 }
132 }
133 Term::Compound { functor, args } => {
134 let name = interner.resolve(*functor);
135 match (name, args.len()) {
136 ("=", 2) => {
137 if subst.unify(&args[0], &args[1]) {
138 Ok(BuiltinResult::Success)
139 } else {
140 Ok(BuiltinResult::Failure)
141 }
142 }
143 ("unify_with_occurs_check", 2) => {
144 if subst.unify_with_occurs_check(&args[0], &args[1]) {
145 Ok(BuiltinResult::Success)
146 } else {
147 Ok(BuiltinResult::Failure)
148 }
149 }
150 ("\\=", 2) => {
151 let mark = subst.trail_mark();
152 if subst.unify(&args[0], &args[1]) {
153 subst.undo_to(mark);
154 Ok(BuiltinResult::Failure)
155 } else {
156 subst.undo_to(mark);
157 Ok(BuiltinResult::Success)
158 }
159 }
160 ("is", 2) => {
161 let result = eval_arith(&args[1], subst, interner)?;
162 let result_term = arith_to_term(result);
163 if subst.unify(&args[0], &result_term) {
164 Ok(BuiltinResult::Success)
165 } else {
166 Ok(BuiltinResult::Failure)
167 }
168 }
169 ("<", 2) => {
170 let l = eval_arith(&args[0], subst, interner)?;
171 let r = eval_arith(&args[1], subst, interner)?;
172 if arith_lt(&l, &r) {
173 Ok(BuiltinResult::Success)
174 } else {
175 Ok(BuiltinResult::Failure)
176 }
177 }
178 (">", 2) => {
179 let l = eval_arith(&args[0], subst, interner)?;
180 let r = eval_arith(&args[1], subst, interner)?;
181 if arith_gt(&l, &r) {
182 Ok(BuiltinResult::Success)
183 } else {
184 Ok(BuiltinResult::Failure)
185 }
186 }
187 ("=<", 2) => {
188 let l = eval_arith(&args[0], subst, interner)?;
189 let r = eval_arith(&args[1], subst, interner)?;
190 if !arith_gt(&l, &r) {
191 Ok(BuiltinResult::Success)
192 } else {
193 Ok(BuiltinResult::Failure)
194 }
195 }
196 (">=", 2) => {
197 let l = eval_arith(&args[0], subst, interner)?;
198 let r = eval_arith(&args[1], subst, interner)?;
199 if !arith_lt(&l, &r) {
200 Ok(BuiltinResult::Success)
201 } else {
202 Ok(BuiltinResult::Failure)
203 }
204 }
205 ("=:=", 2) => {
206 let l = eval_arith(&args[0], subst, interner)?;
207 let r = eval_arith(&args[1], subst, interner)?;
208 if arith_eq(&l, &r) {
209 Ok(BuiltinResult::Success)
210 } else {
211 Ok(BuiltinResult::Failure)
212 }
213 }
214 ("=\\=", 2) => {
215 let l = eval_arith(&args[0], subst, interner)?;
216 let r = eval_arith(&args[1], subst, interner)?;
217 if !arith_eq(&l, &r) {
218 Ok(BuiltinResult::Success)
219 } else {
220 Ok(BuiltinResult::Failure)
221 }
222 }
223 ("\\+", 1) => Ok(BuiltinResult::NegationAsFailure(args[0].clone())),
224 ("var", 1) => {
226 let walked = subst.walk(&args[0]);
227 if matches!(walked, Term::Var(_)) {
228 Ok(BuiltinResult::Success)
229 } else {
230 Ok(BuiltinResult::Failure)
231 }
232 }
233 ("nonvar", 1) => {
234 let walked = subst.walk(&args[0]);
235 if matches!(walked, Term::Var(_)) {
236 Ok(BuiltinResult::Failure)
237 } else {
238 Ok(BuiltinResult::Success)
239 }
240 }
241 ("atom", 1) => {
242 let walked = subst.walk(&args[0]);
243 if matches!(walked, Term::Atom(_)) {
244 Ok(BuiltinResult::Success)
245 } else {
246 Ok(BuiltinResult::Failure)
247 }
248 }
249 ("number", 1) => {
250 let walked = subst.walk(&args[0]);
251 if matches!(walked, Term::Integer(_) | Term::Float(_)) {
252 Ok(BuiltinResult::Success)
253 } else {
254 Ok(BuiltinResult::Failure)
255 }
256 }
257 ("integer", 1) => {
258 let walked = subst.walk(&args[0]);
259 if matches!(walked, Term::Integer(_)) {
260 Ok(BuiltinResult::Success)
261 } else {
262 Ok(BuiltinResult::Failure)
263 }
264 }
265 ("float", 1) => {
266 let walked = subst.walk(&args[0]);
267 if matches!(walked, Term::Float(_)) {
268 Ok(BuiltinResult::Success)
269 } else {
270 Ok(BuiltinResult::Failure)
271 }
272 }
273 ("compound", 1) => {
274 let walked = subst.walk(&args[0]);
275 if matches!(walked, Term::Compound { .. } | Term::List { .. }) {
276 Ok(BuiltinResult::Success)
277 } else {
278 Ok(BuiltinResult::Failure)
279 }
280 }
281 ("is_list", 1) => {
282 let walked = subst.apply(&args[0]);
283 if is_proper_list(&walked, interner) {
284 Ok(BuiltinResult::Success)
285 } else {
286 Ok(BuiltinResult::Failure)
287 }
288 }
289 (";", 2) => {
291 let left = subst.walk(&args[0]);
293 if let Term::Compound {
294 functor,
295 args: inner_args,
296 } = &left
297 {
298 if interner.resolve(*functor) == "->" && inner_args.len() == 2 {
299 return Ok(BuiltinResult::IfThenElse(
300 inner_args[0].clone(),
301 inner_args[1].clone(),
302 args[1].clone(),
303 ));
304 }
305 }
306 Ok(BuiltinResult::Disjunction(args[0].clone(), args[1].clone()))
308 }
309 ("->", 2) => Ok(BuiltinResult::IfThen(args[0].clone(), args[1].clone())),
310 (",", 2) => Ok(BuiltinResult::Conjunction(args[0].clone(), args[1].clone())),
311 ("findall", 3) => Ok(BuiltinResult::FindAll(
312 args[0].clone(),
313 args[1].clone(),
314 args[2].clone(),
315 )),
316 ("once", 1) => Ok(BuiltinResult::Once(args[0].clone())),
317 ("call", 1) => Ok(BuiltinResult::Call(args[0].clone())),
318 ("atom_length", 2) => {
320 Ok(BuiltinResult::AtomLength(args[0].clone(), args[1].clone()))
321 }
322 ("atom_concat", 3) => Ok(BuiltinResult::AtomConcat(
323 args[0].clone(),
324 args[1].clone(),
325 args[2].clone(),
326 )),
327 ("atom_chars", 2) => Ok(BuiltinResult::AtomChars(args[0].clone(), args[1].clone())),
328 ("write", 1) => Ok(BuiltinResult::Write(args[0].clone())),
330 ("writeln", 1) => Ok(BuiltinResult::Writeln(args[0].clone())),
331 ("compare", 3) => Ok(BuiltinResult::Compare(
333 args[0].clone(),
334 args[1].clone(),
335 args[2].clone(),
336 )),
337 ("@<", 2) => {
338 let cmp =
339 term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
340 if cmp == std::cmp::Ordering::Less {
341 Ok(BuiltinResult::Success)
342 } else {
343 Ok(BuiltinResult::Failure)
344 }
345 }
346 ("@>", 2) => {
347 let cmp =
348 term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
349 if cmp == std::cmp::Ordering::Greater {
350 Ok(BuiltinResult::Success)
351 } else {
352 Ok(BuiltinResult::Failure)
353 }
354 }
355 ("@=<", 2) => {
356 let cmp =
357 term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
358 if cmp != std::cmp::Ordering::Greater {
359 Ok(BuiltinResult::Success)
360 } else {
361 Ok(BuiltinResult::Failure)
362 }
363 }
364 ("@>=", 2) => {
365 let cmp =
366 term_compare(&subst.apply(&args[0]), &subst.apply(&args[1]), interner);
367 if cmp != std::cmp::Ordering::Less {
368 Ok(BuiltinResult::Success)
369 } else {
370 Ok(BuiltinResult::Failure)
371 }
372 }
373 ("functor", 3) => Ok(BuiltinResult::Functor(
375 args[0].clone(),
376 args[1].clone(),
377 args[2].clone(),
378 )),
379 ("arg", 3) => Ok(BuiltinResult::Arg(
380 args[0].clone(),
381 args[1].clone(),
382 args[2].clone(),
383 )),
384 ("=..", 2) => Ok(BuiltinResult::Univ(args[0].clone(), args[1].clone())),
385 ("between", 3) => Ok(BuiltinResult::Between(
387 args[0].clone(),
388 args[1].clone(),
389 args[2].clone(),
390 )),
391 ("copy_term", 2) => Ok(BuiltinResult::CopyTerm(args[0].clone(), args[1].clone())),
393 ("succ", 2) => Ok(BuiltinResult::Succ(args[0].clone(), args[1].clone())),
395 ("plus", 3) => Ok(BuiltinResult::Plus(
396 args[0].clone(),
397 args[1].clone(),
398 args[2].clone(),
399 )),
400 ("msort", 2) => Ok(BuiltinResult::MSort(args[0].clone(), args[1].clone())),
402 ("sort", 2) => Ok(BuiltinResult::Sort(args[0].clone(), args[1].clone())),
403 ("number_chars", 2) => {
405 Ok(BuiltinResult::NumberChars(args[0].clone(), args[1].clone()))
406 }
407 ("number_codes", 2) => {
408 Ok(BuiltinResult::NumberCodes(args[0].clone(), args[1].clone()))
409 }
410 _ => Err(format!("Unknown builtin: {}/{}", name, args.len())),
411 }
412 }
413 _ => Err(format!("Cannot execute as builtin: {:?}", goal)),
414 }
415}
416
/// A number produced while evaluating an arithmetic expression.
#[derive(Debug, Clone)]
enum ArithVal {
    /// Signed 64-bit integer value.
    Int(i64),
    /// 64-bit floating point value.
    Float(f64),
}
423
424fn arith_to_term(val: ArithVal) -> Term {
425 match val {
426 ArithVal::Int(n) => Term::Integer(n),
427 ArithVal::Float(f) => Term::Float(f),
428 }
429}
430
431fn arith_lt(a: &ArithVal, b: &ArithVal) -> bool {
432 match (a, b) {
433 (ArithVal::Int(a), ArithVal::Int(b)) => a < b,
434 (ArithVal::Float(a), ArithVal::Float(b)) => a < b,
435 (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) < *b,
436 (ArithVal::Float(a), ArithVal::Int(b)) => *a < (*b as f64),
437 }
438}
439
440fn arith_gt(a: &ArithVal, b: &ArithVal) -> bool {
441 arith_lt(b, a)
442}
443
444fn arith_eq(a: &ArithVal, b: &ArithVal) -> bool {
445 match (a, b) {
446 (ArithVal::Int(a), ArithVal::Int(b)) => a == b,
447 (ArithVal::Float(a), ArithVal::Float(b)) => a == b,
448 (ArithVal::Int(a), ArithVal::Float(b)) => (*a as f64) == *b,
449 (ArithVal::Float(a), ArithVal::Int(b)) => *a == (*b as f64),
450 }
451}
452
453fn eval_arith(
455 term: &Term,
456 subst: &Substitution,
457 interner: &StringInterner,
458) -> Result<ArithVal, String> {
459 let term = subst.walk(term);
460 match &term {
461 Term::Integer(n) => Ok(ArithVal::Int(*n)),
462 Term::Float(f) => Ok(ArithVal::Float(*f)),
463 Term::Var(id) => Err(format!("Arithmetic error: unbound variable _{}", id)),
464 Term::Compound { functor, args } => {
465 let name = interner.resolve(*functor);
466 match (name, args.len()) {
467 ("+", 2) => {
468 let l = eval_arith(&args[0], subst, interner)?;
469 let r = eval_arith(&args[1], subst, interner)?;
470 arith_add(&l, &r)
471 }
472 ("-", 2) => {
473 let l = eval_arith(&args[0], subst, interner)?;
474 let r = eval_arith(&args[1], subst, interner)?;
475 arith_sub(&l, &r)
476 }
477 ("*", 2) => {
478 let l = eval_arith(&args[0], subst, interner)?;
479 let r = eval_arith(&args[1], subst, interner)?;
480 arith_mul(&l, &r)
481 }
482 ("/", 2) => {
483 let l = eval_arith(&args[0], subst, interner)?;
484 let r = eval_arith(&args[1], subst, interner)?;
485 arith_div(&l, &r)
486 }
487 ("//", 2) => {
488 let l = eval_arith(&args[0], subst, interner)?;
489 let r = eval_arith(&args[1], subst, interner)?;
490 arith_int_div(&l, &r)
491 }
492 ("mod", 2) => {
493 let l = eval_arith(&args[0], subst, interner)?;
494 let r = eval_arith(&args[1], subst, interner)?;
495 arith_mod(&l, &r)
496 }
497 ("rem", 2) => {
498 let l = eval_arith(&args[0], subst, interner)?;
499 let r = eval_arith(&args[1], subst, interner)?;
500 arith_rem(&l, &r)
501 }
502 ("-", 1) => {
503 let v = eval_arith(&args[0], subst, interner)?;
504 arith_neg(&v)
505 }
506 ("abs", 1) => {
507 let v = eval_arith(&args[0], subst, interner)?;
508 arith_abs(&v)
509 }
510 ("sign", 1) => {
511 let v = eval_arith(&args[0], subst, interner)?;
512 Ok(arith_sign(&v))
513 }
514 ("max", 2) => {
515 let l = eval_arith(&args[0], subst, interner)?;
516 let r = eval_arith(&args[1], subst, interner)?;
517 Ok(arith_max(&l, &r))
518 }
519 ("min", 2) => {
520 let l = eval_arith(&args[0], subst, interner)?;
521 let r = eval_arith(&args[1], subst, interner)?;
522 Ok(arith_min(&l, &r))
523 }
524 _ => Err(format!(
525 "Unknown arithmetic operator: {}/{}",
526 name,
527 args.len()
528 )),
529 }
530 }
531 _ => Err(format!("Cannot evaluate as arithmetic: {:?}", term)),
532 }
533}
534
535fn check_float(f: f64) -> Result<ArithVal, String> {
537 if f.is_nan() {
538 Err("Arithmetic error: NaN result".to_string())
539 } else if f.is_infinite() {
540 Err("Arithmetic error: Infinity result".to_string())
541 } else {
542 Ok(ArithVal::Float(f))
543 }
544}
545
546fn arith_add(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
547 match (a, b) {
548 (ArithVal::Int(a), ArithVal::Int(b)) => a
549 .checked_add(*b)
550 .map(ArithVal::Int)
551 .ok_or_else(|| "Arithmetic error: integer overflow in addition".to_string()),
552 (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a + b),
553 (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 + b),
554 (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a + *b as f64),
555 }
556}
557
558fn arith_sub(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
559 match (a, b) {
560 (ArithVal::Int(a), ArithVal::Int(b)) => a
561 .checked_sub(*b)
562 .map(ArithVal::Int)
563 .ok_or_else(|| "Arithmetic error: integer overflow in subtraction".to_string()),
564 (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a - b),
565 (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 - b),
566 (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a - *b as f64),
567 }
568}
569
570fn arith_mul(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
571 match (a, b) {
572 (ArithVal::Int(a), ArithVal::Int(b)) => a
573 .checked_mul(*b)
574 .map(ArithVal::Int)
575 .ok_or_else(|| "Arithmetic error: integer overflow in multiplication".to_string()),
576 (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a * b),
577 (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 * b),
578 (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a * *b as f64),
579 }
580}
581
582fn arith_div(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
583 match (a, b) {
584 (ArithVal::Int(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
585 (ArithVal::Int(a), ArithVal::Int(b)) => a
586 .checked_div(*b)
587 .map(ArithVal::Int)
588 .ok_or_else(|| "Arithmetic error: integer overflow in division".to_string()),
589 (_, ArithVal::Float(b)) if *b == 0.0 => Err("Division by zero".to_string()),
590 (ArithVal::Float(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
591 (ArithVal::Float(a), ArithVal::Float(b)) => check_float(a / b),
592 (ArithVal::Int(a), ArithVal::Float(b)) => check_float(*a as f64 / b),
593 (ArithVal::Float(a), ArithVal::Int(b)) => check_float(a / *b as f64),
594 }
595}
596
597fn arith_mod(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
598 match (a, b) {
599 (ArithVal::Int(_), ArithVal::Int(0)) => Err("Modulo by zero".to_string()),
600 (ArithVal::Int(_), ArithVal::Int(i64::MIN)) => {
601 Err("Arithmetic error: integer overflow in mod".to_string())
602 }
603 (ArithVal::Int(a), ArithVal::Int(b)) => {
604 let r = a.rem_euclid(b.abs());
608 if *b < 0 && r != 0 {
609 Ok(ArithVal::Int(r - b.abs()))
610 } else {
611 Ok(ArithVal::Int(r))
612 }
613 }
614 _ => Err("mod requires integer arguments".to_string()),
615 }
616}
617
618fn arith_int_div(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
620 match (a, b) {
621 (ArithVal::Int(_), ArithVal::Int(0)) => Err("Division by zero".to_string()),
622 (ArithVal::Int(a), ArithVal::Int(b)) => a
623 .checked_div(*b)
624 .map(ArithVal::Int)
625 .ok_or_else(|| "Arithmetic error: integer overflow in division".to_string()),
626 _ => Err("// requires integer arguments".to_string()),
627 }
628}
629
630fn arith_rem(a: &ArithVal, b: &ArithVal) -> Result<ArithVal, String> {
632 match (a, b) {
633 (ArithVal::Int(_), ArithVal::Int(0)) => Err("Remainder by zero".to_string()),
634 (ArithVal::Int(a), ArithVal::Int(b)) => a
635 .checked_rem(*b)
636 .map(ArithVal::Int)
637 .ok_or_else(|| "Arithmetic error: integer overflow in rem".to_string()),
638 _ => Err("rem requires integer arguments".to_string()),
639 }
640}
641
642fn arith_neg(a: &ArithVal) -> Result<ArithVal, String> {
643 match a {
644 ArithVal::Int(n) => n
645 .checked_neg()
646 .map(ArithVal::Int)
647 .ok_or_else(|| "Arithmetic error: integer overflow in negation".to_string()),
648 ArithVal::Float(f) => check_float(-f),
649 }
650}
651
652fn arith_abs(a: &ArithVal) -> Result<ArithVal, String> {
653 match a {
654 ArithVal::Int(n) => n
655 .checked_abs()
656 .map(ArithVal::Int)
657 .ok_or_else(|| "Arithmetic error: integer overflow in abs".to_string()),
658 ArithVal::Float(f) => check_float(f.abs()),
659 }
660}
661
662fn arith_sign(a: &ArithVal) -> ArithVal {
663 match a {
664 ArithVal::Int(n) => ArithVal::Int(n.signum()),
665 ArithVal::Float(f) => ArithVal::Float(f.signum()),
666 }
667}
668
669fn arith_max(a: &ArithVal, b: &ArithVal) -> ArithVal {
670 if arith_lt(a, b) {
671 b.clone()
672 } else {
673 a.clone()
674 }
675}
676
677fn arith_min(a: &ArithVal, b: &ArithVal) -> ArithVal {
678 if arith_lt(a, b) {
679 a.clone()
680 } else {
681 b.clone()
682 }
683}
684
685pub fn term_compare(a: &Term, b: &Term, interner: &StringInterner) -> std::cmp::Ordering {
690 use std::cmp::Ordering;
691 fn type_rank(t: &Term) -> u8 {
692 match t {
693 Term::Var(_) => 0,
694 Term::Float(_) => 1,
695 Term::Integer(_) => 1,
696 Term::Atom(_) => 2,
697 Term::List { .. } => 3,
698 Term::Compound { .. } => 3,
699 }
700 }
701
702 let ra = type_rank(a);
703 let rb = type_rank(b);
704 if ra != rb {
705 return ra.cmp(&rb);
706 }
707
708 match (a, b) {
709 (Term::Var(a), Term::Var(b)) => a.cmp(b),
710 (Term::Integer(a), Term::Integer(b)) => a.cmp(b),
711 (Term::Float(a), Term::Float(b)) => {
712 a.partial_cmp(b)
714 .unwrap_or_else(|| match (a.is_nan(), b.is_nan()) {
715 (true, true) => Ordering::Equal,
716 (true, false) => Ordering::Greater,
717 (false, true) => Ordering::Less,
718 (false, false) => unreachable!(),
719 })
720 }
721 (Term::Integer(a), Term::Float(b)) => {
722 if b.is_nan() {
724 return Ordering::Less;
725 }
726 let cmp = (*a as f64).partial_cmp(b).unwrap_or(Ordering::Less);
727 if cmp == Ordering::Equal {
728 Ordering::Greater } else {
730 cmp
731 }
732 }
733 (Term::Float(a), Term::Integer(b)) => {
734 if a.is_nan() {
736 return Ordering::Greater;
737 }
738 let cmp = a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Greater);
739 if cmp == Ordering::Equal {
740 Ordering::Less } else {
742 cmp
743 }
744 }
745 (Term::Atom(a), Term::Atom(b)) => interner.resolve(*a).cmp(interner.resolve(*b)),
746 (
747 Term::Compound {
748 functor: f1,
749 args: a1,
750 },
751 Term::Compound {
752 functor: f2,
753 args: a2,
754 },
755 ) => {
756 a1.len()
758 .cmp(&a2.len())
759 .then_with(|| interner.resolve(*f1).cmp(interner.resolve(*f2)))
760 .then_with(|| {
761 for (x, y) in a1.iter().zip(a2.iter()) {
762 let c = term_compare(x, y, interner);
763 if c != Ordering::Equal {
764 return c;
765 }
766 }
767 Ordering::Equal
768 })
769 }
770 (Term::List { .. }, Term::List { .. }) => {
771 let mut cur_a = a;
773 let mut cur_b = b;
774 loop {
775 match (cur_a, cur_b) {
776 (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
777 let c = term_compare(h1, h2, interner);
778 if c != Ordering::Equal {
779 return c;
780 }
781 cur_a = t1;
782 cur_b = t2;
783 }
784 _ => return term_compare(cur_a, cur_b, interner),
785 }
786 }
787 }
788 (
790 Term::List { head: h, tail: t },
791 Term::Compound {
792 functor: f2,
793 args: a2,
794 },
795 ) => {
796 2usize
798 .cmp(&a2.len())
799 .then_with(|| ".".cmp(interner.resolve(*f2)))
800 .then_with(|| {
801 if a2.len() >= 1 {
802 let c = term_compare(h, &a2[0], interner);
803 if c != Ordering::Equal {
804 return c;
805 }
806 }
807 if a2.len() >= 2 {
808 return term_compare(t, &a2[1], interner);
809 }
810 Ordering::Equal
811 })
812 }
813 (
814 Term::Compound {
815 functor: f1,
816 args: a1,
817 },
818 Term::List { head: h, tail: t },
819 ) => a1
820 .len()
821 .cmp(&2usize)
822 .then_with(|| interner.resolve(*f1).cmp("."))
823 .then_with(|| {
824 if a1.len() >= 1 {
825 let c = term_compare(&a1[0], h, interner);
826 if c != Ordering::Equal {
827 return c;
828 }
829 }
830 if a1.len() >= 2 {
831 return term_compare(&a1[1], t, interner);
832 }
833 Ordering::Equal
834 }),
835 _ => unreachable!("term_compare: unhandled Term variant"),
836 }
837}
838
839pub fn collect_list(term: &Term, interner: &StringInterner) -> Option<Vec<Term>> {
841 let mut elements = Vec::new();
842 let mut current = term;
843 loop {
844 match current {
845 Term::Atom(id) if interner.resolve(*id) == "[]" => return Some(elements),
846 Term::List { head, tail } => {
847 elements.push(head.as_ref().clone());
848 current = tail;
849 }
850 _ => return None,
851 }
852 }
853}
854
855pub fn build_list(elements: Vec<Term>, interner: &StringInterner) -> Term {
857 let nil_id = interner.lookup("[]").expect("[] must be interned");
858 let mut list = Term::Atom(nil_id);
859 for elem in elements.into_iter().rev() {
860 list = Term::List {
861 head: Box::new(elem),
862 tail: Box::new(list),
863 };
864 }
865 list
866}
867
868fn is_proper_list(term: &Term, interner: &StringInterner) -> bool {
870 let mut current = term;
871 loop {
872 match current {
873 Term::Atom(id) => return interner.resolve(*id) == "[]",
874 Term::List { tail, .. } => current = tail,
875 _ => return false,
876 }
877 }
878}
879
/// Names of the builtins that are written as plain atoms.
pub fn builtin_atom_names() -> &'static [&'static str] {
    const NAMES: [&str; 5] = ["true", "fail", "false", "!", "nl"];
    &NAMES
}
884
/// Name/arity pairs of every compound builtin this module implements.
///
/// The table lists the same pairs that `is_builtin` accepts, including
/// `unify_with_occurs_check/2`.
pub fn builtin_functor_names() -> &'static [(&'static str, usize)] {
    &[
        ("=", 2),
        ("\\=", 2),
        ("unify_with_occurs_check", 2),
        ("is", 2),
        ("<", 2),
        (">", 2),
        ("=<", 2),
        (">=", 2),
        ("=:=", 2),
        ("=\\=", 2),
        ("\\+", 1),
        ("var", 1),
        ("nonvar", 1),
        ("atom", 1),
        ("number", 1),
        ("integer", 1),
        ("float", 1),
        ("compound", 1),
        ("is_list", 1),
        (";", 2),
        ("->", 2),
        (",", 2),
        ("findall", 3),
        ("once", 1),
        ("call", 1),
        ("atom_length", 2),
        ("atom_concat", 3),
        ("atom_chars", 2),
        ("write", 1),
        ("writeln", 1),
        ("compare", 3),
        ("@<", 2),
        ("@>", 2),
        ("@=<", 2),
        ("@>=", 2),
        ("functor", 3),
        ("arg", 3),
        ("=..", 2),
        ("between", 3),
        ("copy_term", 2),
        ("succ", 2),
        ("plus", 3),
        ("msort", 2),
        ("sort", 2),
        ("number_chars", 2),
        ("number_codes", 2),
    ]
}
934
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Interner pre-loaded with the names the tests look up directly.
    fn setup() -> StringInterner {
        let mut interner = StringInterner::new();
        for name in [
            "true", "fail", "!", "=", "\\=", "is", "<", ">", "=<", ">=", "=:=", "=\\=", "\\+",
            "+", "-", "*", "/", "mod", "//", "rem",
        ] {
            interner.intern(name);
        }
        interner
    }

    /// Atom term for an already interned `name`.
    fn atom(interner: &StringInterner, name: &str) -> Term {
        Term::Atom(interner.lookup(name).unwrap())
    }

    /// Compound term whose functor is the already interned `name`.
    fn goal(interner: &StringInterner, name: &str, args: Vec<Term>) -> Term {
        Term::Compound {
            functor: interner.lookup(name).unwrap(),
            args,
        }
    }

    /// Parses `query` and executes its first goal against a fresh substitution.
    fn run_query(query: &str) -> (Result<BuiltinResult, String>, Substitution) {
        let mut interner = setup();
        let goals = Parser::parse_query(query, &mut interner).unwrap();
        let mut subst = Substitution::new();
        let outcome = exec_builtin(&goals[0], &mut subst, &interner);
        (outcome, subst)
    }

    /// Asserts that `query` succeeds and binds variable 0 to `expected`.
    fn assert_binds_first_var(query: &str, expected: Term) {
        let (outcome, subst) = run_query(query);
        assert!(matches!(outcome.unwrap(), BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), expected);
    }

    /// Asserts that `query` fails with an error mentioning "overflow".
    fn assert_overflow(query: &str) {
        let (outcome, _subst) = run_query(query);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("overflow"));
    }

    #[test]
    fn test_is_builtin() {
        let interner = setup();
        assert!(is_builtin(&atom(&interner, "true"), &interner));

        let unification = goal(&interner, "=", vec![Term::Var(0), Term::Atom(0)]);
        assert!(is_builtin(&unification, &interner));
    }

    #[test]
    fn test_exec_true() {
        let interner = setup();
        let mut subst = Substitution::new();
        let outcome = exec_builtin(&atom(&interner, "true"), &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Success));
    }

    #[test]
    fn test_exec_fail() {
        let interner = setup();
        let mut subst = Substitution::new();
        let outcome = exec_builtin(&atom(&interner, "fail"), &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Failure));
    }

    #[test]
    fn test_exec_unify() {
        let interner = setup();
        let mut subst = Substitution::new();
        let unification = goal(&interner, "=", vec![Term::Var(0), Term::Integer(42)]);
        let outcome = exec_builtin(&unification, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Success));
        assert_eq!(subst.walk(&Term::Var(0)), Term::Integer(42));
    }

    #[test]
    fn test_exec_not_unify() {
        let interner = setup();
        let mut subst = Substitution::new();

        let distinct = goal(&interner, "\\=", vec![Term::Integer(1), Term::Integer(2)]);
        let outcome = exec_builtin(&distinct, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Success));

        let identical = goal(&interner, "\\=", vec![Term::Integer(1), Term::Integer(1)]);
        let outcome = exec_builtin(&identical, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Failure));
    }

    #[test]
    fn test_exec_is_arithmetic() {
        assert_binds_first_var("X is 2 + 3 * 4", Term::Integer(14));
    }

    #[test]
    fn test_exec_comparison() {
        let interner = setup();
        let mut subst = Substitution::new();

        let ascending = goal(&interner, "<", vec![Term::Integer(1), Term::Integer(2)]);
        assert!(matches!(
            exec_builtin(&ascending, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));

        let descending = goal(&interner, "<", vec![Term::Integer(2), Term::Integer(1)]);
        assert!(matches!(
            exec_builtin(&descending, &mut subst, &interner).unwrap(),
            BuiltinResult::Failure
        ));
    }

    #[test]
    fn test_exec_cut() {
        let interner = setup();
        let mut subst = Substitution::new();
        let outcome = exec_builtin(&atom(&interner, "!"), &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Cut));
    }

    #[test]
    fn test_type_checking_var() {
        let mut interner = setup();
        interner.intern("var");
        let mut subst = Substitution::new();

        let unbound = goal(&interner, "var", vec![Term::Var(0)]);
        let outcome = exec_builtin(&unbound, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Success));

        let number = goal(&interner, "var", vec![Term::Integer(42)]);
        let outcome = exec_builtin(&number, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Failure));
    }

    #[test]
    fn test_type_checking_atom() {
        let mut interner = setup();
        interner.intern("atom");
        let mut subst = Substitution::new();
        let hello = interner.intern("hello");

        let on_atom = goal(&interner, "atom", vec![Term::Atom(hello)]);
        let outcome = exec_builtin(&on_atom, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Success));

        let on_number = goal(&interner, "atom", vec![Term::Integer(42)]);
        let outcome = exec_builtin(&on_number, &mut subst, &interner).unwrap();
        assert!(matches!(outcome, BuiltinResult::Failure));
    }

    #[test]
    fn test_type_checking_integer() {
        let mut interner = setup();
        interner.intern("integer");
        let mut subst = Substitution::new();

        let on_integer = goal(&interner, "integer", vec![Term::Integer(42)]);
        assert!(matches!(
            exec_builtin(&on_integer, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));

        let on_float = goal(&interner, "integer", vec![Term::Float(3.14)]);
        assert!(matches!(
            exec_builtin(&on_float, &mut subst, &interner).unwrap(),
            BuiltinResult::Failure
        ));
    }

    #[test]
    fn test_type_checking_number() {
        let mut interner = setup();
        interner.intern("number");
        let mut subst = Substitution::new();

        let on_integer = goal(&interner, "number", vec![Term::Integer(42)]);
        assert!(matches!(
            exec_builtin(&on_integer, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));

        let on_float = goal(&interner, "number", vec![Term::Float(3.14)]);
        assert!(matches!(
            exec_builtin(&on_float, &mut subst, &interner).unwrap(),
            BuiltinResult::Success
        ));
    }

    #[test]
    fn test_exec_mod() {
        assert_binds_first_var("X is 10 mod 3", Term::Integer(1));
    }

    #[test]
    fn test_mod_i64_min_divisor() {
        let outcome = arith_mod(&ArithVal::Int(5), &ArithVal::Int(i64::MIN));
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().contains("overflow"));
    }

    #[test]
    fn test_mod_i64_min_dividend_neg1() {
        let outcome = arith_mod(&ArithVal::Int(i64::MIN), &ArithVal::Int(-1));
        match outcome {
            Ok(ArithVal::Int(0)) => {}
            other => panic!("Expected Ok(Int(0)), got {:?}", other),
        }
    }

    #[test]
    fn test_integer_overflow_add() {
        assert_overflow(&format!("X is {} + 1", i64::MAX));
    }

    #[test]
    fn test_integer_overflow_mul() {
        assert_overflow(&format!("X is {} * 2", i64::MAX));
    }

    #[test]
    fn test_arith_abs() {
        assert_binds_first_var("X is abs(-5)", Term::Integer(5));
    }

    #[test]
    fn test_arith_abs_positive() {
        assert_binds_first_var("X is abs(3)", Term::Integer(3));
    }

    #[test]
    fn test_arith_sign() {
        assert_binds_first_var("X is sign(-42)", Term::Integer(-1));
    }

    #[test]
    fn test_arith_sign_zero() {
        assert_binds_first_var("X is sign(0)", Term::Integer(0));
    }

    #[test]
    fn test_arith_max() {
        assert_binds_first_var("X is max(3, 7)", Term::Integer(7));
    }

    #[test]
    fn test_arith_min() {
        assert_binds_first_var("X is min(3, 7)", Term::Integer(3));
    }
}