Skip to main content

patch_prolog_core/
unify.rs

use std::borrow::Cow;

use crate::term::{Term, VarId};
2
3/// Vec-based substitution with trail for efficient backtracking.
4/// Bindings are stored in a Vec indexed by VarId (O(1) lookup/bind).
5/// The trail records which VarIds were bound, enabling undo on backtracking.
6#[derive(Debug, Clone)]
7pub struct Substitution {
8    bindings: Vec<Option<Term>>,
9    trail: Vec<VarId>,
10}
11
12impl Substitution {
13    pub fn new() -> Self {
14        Substitution {
15            bindings: Vec::new(),
16            trail: Vec::new(),
17        }
18    }
19
20    /// Create a substitution pre-sized for the given number of variables.
21    pub fn with_capacity(n: usize) -> Self {
22        Substitution {
23            bindings: vec![None; n],
24            trail: Vec::new(),
25        }
26    }
27
28    /// Get the current trail mark (for backtracking).
29    pub fn trail_mark(&self) -> usize {
30        self.trail.len()
31    }
32
33    /// Undo all bindings back to the given trail mark.
34    pub fn undo_to(&mut self, mark: usize) {
35        while self.trail.len() > mark {
36            let var = self.trail.pop().unwrap();
37            self.bindings[var as usize] = None;
38        }
39    }
40
41    /// Bind a variable to a term.
42    fn bind(&mut self, var: VarId, term: Term) {
43        let idx = var as usize;
44        if idx >= self.bindings.len() {
45            self.bindings.resize(idx + 1, None);
46        }
47        self.bindings[idx] = Some(term);
48        self.trail.push(var);
49    }
50
51    /// Look up a variable's binding.
52    fn lookup(&self, var: VarId) -> Option<&Term> {
53        self.bindings.get(var as usize).and_then(|b| b.as_ref())
54    }
55
56    /// Dereference: follow variable chains to their ultimate value.
57    pub fn walk(&self, term: &Term) -> Term {
58        match term {
59            Term::Var(id) => match self.lookup(*id) {
60                Some(bound) => self.walk(bound),
61                None => term.clone(),
62            },
63            _ => term.clone(),
64        }
65    }
66
67    /// Deep walk: recursively substitute all variables in a term.
68    /// Handles circular terms (from unification without occurs check) by
69    /// stopping expansion when a variable cycle is detected.
70    pub fn apply(&self, term: &Term) -> Term {
71        let mut seen = Vec::new();
72        self.apply_impl(term, &mut seen)
73    }
74
75    fn apply_impl(&self, term: &Term, seen: &mut Vec<VarId>) -> Term {
76        match term {
77            Term::Var(id) => {
78                if seen.contains(id) {
79                    // Cycle detected — return variable as-is to break infinite recursion
80                    return term.clone();
81                }
82                match self.lookup(*id) {
83                    Some(bound) => {
84                        seen.push(*id);
85                        let result = self.apply_impl(bound, seen);
86                        seen.pop();
87                        result
88                    }
89                    None => term.clone(),
90                }
91            }
92            Term::Compound { functor, args } => Term::Compound {
93                functor: *functor,
94                args: args.iter().map(|a| self.apply_impl(a, seen)).collect(),
95            },
96            Term::List { head, tail } => Term::List {
97                head: Box::new(self.apply_impl(head, seen)),
98                tail: Box::new(self.apply_impl(tail, seen)),
99            },
100            _ => term.clone(),
101        }
102    }
103
104    /// Unify two terms. Returns true if unification succeeds.
105    /// On failure, bindings made during this attempt remain (caller should undo via trail).
106    pub fn unify(&mut self, t1: &Term, t2: &Term) -> bool {
107        let t1 = self.walk(t1);
108        let t2 = self.walk(t2);
109
110        match (&t1, &t2) {
111            // Both same variable
112            (Term::Var(a), Term::Var(b)) if a == b => true,
113
114            // Bind variable to the other term (ISO: no occurs check for =/2)
115            (Term::Var(id), other) | (other, Term::Var(id)) => {
116                self.bind(*id, other.clone());
117                true
118            }
119
120            // Atom equality
121            (Term::Atom(a), Term::Atom(b)) => a == b,
122
123            // Integer equality
124            (Term::Integer(a), Term::Integer(b)) => a == b,
125
126            // Float equality (use to_bits for structural equality — handles NaN)
127            (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
128
129            // Compound: same functor and arity, then unify args pairwise
130            (
131                Term::Compound {
132                    functor: f1,
133                    args: a1,
134                },
135                Term::Compound {
136                    functor: f2,
137                    args: a2,
138                },
139            ) => {
140                if f1 != f2 || a1.len() != a2.len() {
141                    return false;
142                }
143                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
144                    if !self.unify(arg1, arg2) {
145                        return false;
146                    }
147                }
148                true
149            }
150
151            // List: unify head and tail
152            (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
153                self.unify(h1, h2) && self.unify(t1, t2)
154            }
155
156            // Anything else fails
157            _ => false,
158        }
159    }
160
161    /// Occurs check: does the variable appear in the term?
162    /// Retained for potential unify_with_occurs_check/2 builtin.
163    #[allow(dead_code)]
164    fn occurs_in(&self, var: VarId, term: &Term) -> bool {
165        match term {
166            Term::Var(id) => {
167                if *id == var {
168                    return true;
169                }
170                match self.lookup(*id) {
171                    Some(bound) => self.occurs_in(var, bound),
172                    None => false,
173                }
174            }
175            Term::Compound { args, .. } => args.iter().any(|a| self.occurs_in(var, a)),
176            Term::List { head, tail } => self.occurs_in(var, head) || self.occurs_in(var, tail),
177            _ => false,
178        }
179    }
180}
181
182impl Default for Substitution {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
/// Unit tests for `Substitution`: unification of each term kind, trail-based
/// backtracking, and deep application of bindings.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_unify_atoms() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Atom(0), &Term::Atom(0)));
        assert!(!sub.unify(&Term::Atom(0), &Term::Atom(1)));
    }

    #[test]
    fn test_unify_integers() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Integer(42), &Term::Integer(42)));
        assert!(!sub.unify(&Term::Integer(1), &Term::Integer(2)));
    }

    #[test]
    fn test_unify_floats_structurally() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Float(1.5), &Term::Float(1.5)));
        // Bit-level comparison: positive and negative zero are distinct terms.
        assert!(!sub.unify(&Term::Float(0.0), &Term::Float(-0.0)));
    }

    #[test]
    fn test_unify_var_to_atom() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
    }

    #[test]
    fn test_unify_var_to_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(1)));
        // After binding, both should resolve to the same thing
        assert!(sub.unify(&Term::Var(1), &Term::Atom(5)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_walk_follows_var_chain() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(1)));
        assert!(sub.unify(&Term::Var(1), &Term::Var(2)));
        assert!(sub.unify(&Term::Var(2), &Term::Atom(7)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(7));
        assert_eq!(sub.apply(&Term::Var(0)), Term::Atom(7));
    }

    #[test]
    fn test_unify_compound() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(1)],
        };
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(2));
    }

    #[test]
    fn test_unify_compound_mismatch_functor() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 1,
            args: vec![Term::Atom(1)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_unify_compound_mismatch_arity() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1), Term::Atom(2)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_no_occurs_check() {
        let mut sub = Substitution::new();
        // X = f(X) should succeed (ISO: =/2 does not occurs-check)
        let t1 = Term::Var(0);
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(sub.unify(&t1, &t2));
    }

    #[test]
    fn test_apply_cyclic_term_terminates() {
        let mut sub = Substitution::new();
        let fx = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(sub.unify(&Term::Var(0), &fx));
        // X = f(X): expansion stops at the inner X instead of recursing forever.
        assert_eq!(sub.apply(&Term::Var(0)), fx);
    }

    #[test]
    fn test_trail_backtracking() {
        let mut sub = Substitution::new();

        let mark = sub.trail_mark();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));

        sub.undo_to(mark);
        // Var should be free again
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_failed_unify_is_undoable() {
        let mut sub = Substitution::new();
        let mark = sub.trail_mark();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(3)],
        };
        assert!(!sub.unify(&t1, &t2));
        // The binding made before the mismatch is still in place...
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(2));
        // ...until the caller rolls back to the mark.
        sub.undo_to(mark);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_undo_to_mark_beyond_trail_is_noop() {
        let mut sub = Substitution::new();
        sub.unify(&Term::Var(0), &Term::Atom(1));
        let len = sub.trail_mark();
        sub.undo_to(len + 10);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
        assert_eq!(sub.trail_mark(), len);
    }

    #[test]
    fn test_with_capacity_grows_on_demand() {
        let mut sub = Substitution::with_capacity(2);
        assert!(sub.unify(&Term::Var(5), &Term::Integer(9)));
        assert_eq!(sub.walk(&Term::Var(5)), Term::Integer(9));
        assert_eq!(sub.walk(&Term::Var(1)), Term::Var(1));
    }

    #[test]
    fn test_apply() {
        let mut sub = Substitution::new();
        sub.unify(&Term::Var(0), &Term::Atom(5));
        sub.unify(&Term::Var(1), &Term::Integer(42));

        let term = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Var(1), Term::Var(2)],
        };
        let applied = sub.apply(&term);
        match applied {
            Term::Compound { args, .. } => {
                assert_eq!(args[0], Term::Atom(5));
                assert_eq!(args[1], Term::Integer(42));
                assert_eq!(args[2], Term::Var(2)); // unbound
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_unify_list() {
        let mut sub = Substitution::new();
        let t1 = Term::List {
            head: Box::new(Term::Var(0)),
            tail: Box::new(Term::Atom(10)), // nil
        };
        let t2 = Term::List {
            head: Box::new(Term::Atom(5)),
            tail: Box::new(Term::Atom(10)),
        };
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_same_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(0)));
    }

    #[test]
    fn test_multiple_trail_marks() {
        let mut sub = Substitution::new();

        let mark1 = sub.trail_mark();
        sub.unify(&Term::Var(0), &Term::Atom(1));

        let mark2 = sub.trail_mark();
        sub.unify(&Term::Var(1), &Term::Atom(2));

        // Undo second binding only
        sub.undo_to(mark2);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
        assert_eq!(sub.walk(&Term::Var(1)), Term::Var(1));

        // Undo first binding
        sub.undo_to(mark1);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }
}