Skip to main content

patch_prolog_core/
unify.rs

1use crate::term::{Term, VarId};
2use fnv::FnvHashSet;
3
4/// Vec-based substitution with trail for efficient backtracking.
5/// Bindings are stored in a Vec indexed by VarId (O(1) lookup/bind).
6/// The trail records which VarIds were bound, enabling undo on backtracking.
7#[derive(Debug, Clone)]
8pub struct Substitution {
9    bindings: Vec<Option<Term>>,
10    trail: Vec<VarId>,
11}
12
13impl Substitution {
14    pub fn new() -> Self {
15        Substitution {
16            bindings: Vec::new(),
17            trail: Vec::new(),
18        }
19    }
20
21    /// Create a substitution pre-sized for the given number of variables.
22    pub fn with_capacity(n: usize) -> Self {
23        Substitution {
24            bindings: vec![None; n],
25            trail: Vec::new(),
26        }
27    }
28
29    /// Get the current trail mark (for backtracking).
30    pub fn trail_mark(&self) -> usize {
31        self.trail.len()
32    }
33
34    /// Undo all bindings back to the given trail mark.
35    pub fn undo_to(&mut self, mark: usize) {
36        while self.trail.len() > mark {
37            let var = self.trail.pop().unwrap();
38            self.bindings[var as usize] = None;
39        }
40    }
41
42    /// Bind a variable to a term.
43    fn bind(&mut self, var: VarId, term: Term) {
44        let idx = var as usize;
45        if idx >= self.bindings.len() {
46            self.bindings.resize(idx + 1, None);
47        }
48        self.bindings[idx] = Some(term);
49        self.trail.push(var);
50    }
51
52    /// Look up a variable's binding.
53    fn lookup(&self, var: VarId) -> Option<&Term> {
54        self.bindings.get(var as usize).and_then(|b| b.as_ref())
55    }
56
57    /// Dereference: follow variable chains to their ultimate value.
58    pub fn walk(&self, term: &Term) -> Term {
59        match term {
60            Term::Var(id) => match self.lookup(*id) {
61                Some(bound) => self.walk(bound),
62                None => term.clone(),
63            },
64            _ => term.clone(),
65        }
66    }
67
68    /// Deep walk: recursively substitute all variables in a term.
69    /// Handles circular terms (from unification without occurs check) by
70    /// stopping expansion when a variable cycle is detected.
71    pub fn apply(&self, term: &Term) -> Term {
72        let mut seen = FnvHashSet::default();
73        self.apply_impl(term, &mut seen)
74    }
75
76    fn apply_impl(&self, term: &Term, seen: &mut FnvHashSet<VarId>) -> Term {
77        match term {
78            Term::Var(id) => {
79                if seen.contains(id) {
80                    // Cycle detected — return variable as-is to break infinite recursion
81                    return term.clone();
82                }
83                match self.lookup(*id) {
84                    Some(bound) => {
85                        seen.insert(*id);
86                        let result = self.apply_impl(bound, seen);
87                        seen.remove(id);
88                        result
89                    }
90                    None => term.clone(),
91                }
92            }
93            Term::Compound { functor, args } => Term::Compound {
94                functor: *functor,
95                args: args.iter().map(|a| self.apply_impl(a, seen)).collect(),
96            },
97            Term::List { head, tail } => Term::List {
98                head: Box::new(self.apply_impl(head, seen)),
99                tail: Box::new(self.apply_impl(tail, seen)),
100            },
101            _ => term.clone(),
102        }
103    }
104
105    /// Unify two terms. Returns true if unification succeeds.
106    /// On failure, bindings made during this attempt remain (caller should undo via trail).
107    pub fn unify(&mut self, t1: &Term, t2: &Term) -> bool {
108        let t1 = self.walk(t1);
109        let t2 = self.walk(t2);
110
111        match (&t1, &t2) {
112            // Both same variable
113            (Term::Var(a), Term::Var(b)) if a == b => true,
114
115            // Bind variable to the other term (ISO: no occurs check for =/2)
116            (Term::Var(id), other) | (other, Term::Var(id)) => {
117                self.bind(*id, other.clone());
118                true
119            }
120
121            // Atom equality
122            (Term::Atom(a), Term::Atom(b)) => a == b,
123
124            // Integer equality
125            (Term::Integer(a), Term::Integer(b)) => a == b,
126
127            // Float equality (use to_bits for structural equality — handles NaN)
128            (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
129
130            // Compound: same functor and arity, then unify args pairwise
131            (
132                Term::Compound {
133                    functor: f1,
134                    args: a1,
135                },
136                Term::Compound {
137                    functor: f2,
138                    args: a2,
139                },
140            ) => {
141                if f1 != f2 || a1.len() != a2.len() {
142                    return false;
143                }
144                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
145                    if !self.unify(arg1, arg2) {
146                        return false;
147                    }
148                }
149                true
150            }
151
152            // List: unify head and tail
153            (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
154                self.unify(h1, h2) && self.unify(t1, t2)
155            }
156
157            // Anything else fails
158            _ => false,
159        }
160    }
161
162    /// Unify with occurs check (ISO 8.2.2).
163    /// Fails if binding a variable would create a circular term.
164    pub fn unify_with_occurs_check(&mut self, t1: &Term, t2: &Term) -> bool {
165        let t1 = self.walk(t1);
166        let t2 = self.walk(t2);
167
168        match (&t1, &t2) {
169            (Term::Var(a), Term::Var(b)) if a == b => true,
170            (Term::Var(id), other) | (other, Term::Var(id)) => {
171                if self.occurs_in(*id, other) {
172                    return false;
173                }
174                self.bind(*id, other.clone());
175                true
176            }
177            (Term::Atom(a), Term::Atom(b)) => a == b,
178            (Term::Integer(a), Term::Integer(b)) => a == b,
179            (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
180            (
181                Term::Compound {
182                    functor: f1,
183                    args: a1,
184                },
185                Term::Compound {
186                    functor: f2,
187                    args: a2,
188                },
189            ) => {
190                if f1 != f2 || a1.len() != a2.len() {
191                    return false;
192                }
193                for (arg1, arg2) in a1.iter().zip(a2.iter()) {
194                    if !self.unify_with_occurs_check(arg1, arg2) {
195                        return false;
196                    }
197                }
198                true
199            }
200            (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
201                self.unify_with_occurs_check(h1, h2) && self.unify_with_occurs_check(t1, t2)
202            }
203            _ => false,
204        }
205    }
206
207    fn occurs_in(&self, var: VarId, term: &Term) -> bool {
208        match term {
209            Term::Var(id) => {
210                if *id == var {
211                    return true;
212                }
213                match self.lookup(*id) {
214                    Some(bound) => self.occurs_in(var, bound),
215                    None => false,
216                }
217            }
218            Term::Compound { args, .. } => args.iter().any(|a| self.occurs_in(var, a)),
219            Term::List { head, tail } => self.occurs_in(var, head) || self.occurs_in(var, tail),
220            _ => false,
221        }
222    }
223}
224
225impl Default for Substitution {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_unify_atoms() {
        let mut s = Substitution::new();
        let (a0, a1) = (Term::Atom(0), Term::Atom(1));
        assert!(s.unify(&a0, &a0));
        assert!(!s.unify(&a0, &a1));
    }

    #[test]
    fn test_unify_integers() {
        let mut s = Substitution::new();
        let answer = Term::Integer(42);
        assert!(s.unify(&answer, &answer));
        assert!(!s.unify(&Term::Integer(1), &Term::Integer(2)));
    }

    #[test]
    fn test_unify_var_to_atom() {
        let mut s = Substitution::new();
        let x = Term::Var(0);
        assert!(s.unify(&x, &Term::Atom(1)));
        assert_eq!(s.walk(&x), Term::Atom(1));
    }

    #[test]
    fn test_unify_var_to_var() {
        let mut s = Substitution::new();
        let (x, y) = (Term::Var(0), Term::Var(1));
        assert!(s.unify(&x, &y));
        // Once aliased, a later binding of y is observable through x.
        assert!(s.unify(&y, &Term::Atom(5)));
        assert_eq!(s.walk(&x), Term::Atom(5));
    }

    #[test]
    fn test_unify_compound() {
        let mut s = Substitution::new();
        let pattern = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let ground = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(1)],
        };
        assert!(s.unify(&pattern, &ground));
        assert_eq!(s.walk(&Term::Var(0)), Term::Atom(2));
    }

    #[test]
    fn test_unify_compound_mismatch_functor() {
        let mut s = Substitution::new();
        let (lhs, rhs) = (
            Term::Compound {
                functor: 0,
                args: vec![Term::Atom(1)],
            },
            Term::Compound {
                functor: 1,
                args: vec![Term::Atom(1)],
            },
        );
        assert!(!s.unify(&lhs, &rhs));
    }

    #[test]
    fn test_unify_compound_mismatch_arity() {
        let mut s = Substitution::new();
        let (unary, binary) = (
            Term::Compound {
                functor: 0,
                args: vec![Term::Atom(1)],
            },
            Term::Compound {
                functor: 0,
                args: vec![Term::Atom(1), Term::Atom(2)],
            },
        );
        assert!(!s.unify(&unary, &binary));
    }

    #[test]
    fn test_no_occurs_check() {
        let mut s = Substitution::new();
        // X = f(X) must succeed: ISO =/2 performs no occurs check.
        let x = Term::Var(0);
        let f_of_x = Term::Compound {
            functor: 0,
            args: vec![x.clone()],
        };
        assert!(s.unify(&x, &f_of_x));
    }

    #[test]
    fn test_trail_backtracking() {
        let mut s = Substitution::new();
        let x = Term::Var(0);

        let before = s.trail_mark();
        assert!(s.unify(&x, &Term::Atom(1)));
        assert_eq!(s.walk(&x), Term::Atom(1));

        s.undo_to(before);
        // x is free again after rolling back
        assert_eq!(s.walk(&x), x);
    }

    #[test]
    fn test_apply() {
        let mut s = Substitution::new();
        s.unify(&Term::Var(0), &Term::Atom(5));
        s.unify(&Term::Var(1), &Term::Integer(42));

        let term = Term::Compound {
            functor: 0,
            args: (0..3).map(Term::Var).collect(),
        };
        if let Term::Compound { args, .. } = s.apply(&term) {
            assert_eq!(args[0], Term::Atom(5));
            assert_eq!(args[1], Term::Integer(42));
            assert_eq!(args[2], Term::Var(2)); // never bound
        } else {
            panic!("Expected compound");
        }
    }

    #[test]
    fn test_unify_list() {
        let mut s = Substitution::new();
        let nil = Term::Atom(10);
        let open = Term::List {
            head: Box::new(Term::Var(0)),
            tail: Box::new(nil.clone()),
        };
        let closed = Term::List {
            head: Box::new(Term::Atom(5)),
            tail: Box::new(nil),
        };
        assert!(s.unify(&open, &closed));
        assert_eq!(s.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_same_var() {
        let mut s = Substitution::new();
        let x = Term::Var(0);
        assert!(s.unify(&x, &x));
    }

    #[test]
    fn test_multiple_trail_marks() {
        let mut s = Substitution::new();
        let (x, y) = (Term::Var(0), Term::Var(1));

        let outer = s.trail_mark();
        s.unify(&x, &Term::Atom(1));

        let inner = s.trail_mark();
        s.unify(&y, &Term::Atom(2));

        // Rolling back to `inner` releases only y.
        s.undo_to(inner);
        assert_eq!(s.walk(&x), Term::Atom(1));
        assert_eq!(s.walk(&y), y);

        // Rolling back to `outer` releases x as well.
        s.undo_to(outer);
        assert_eq!(s.walk(&x), x);
    }
}