1use crate::term::{Term, VarId};
2use fnv::FnvHashSet;
3
4#[derive(Debug, Clone)]
8pub struct Substitution {
9 bindings: Vec<Option<Term>>,
10 trail: Vec<VarId>,
11}
12
13impl Substitution {
14 pub fn new() -> Self {
15 Substitution {
16 bindings: Vec::new(),
17 trail: Vec::new(),
18 }
19 }
20
21 pub fn with_capacity(n: usize) -> Self {
23 Substitution {
24 bindings: vec![None; n],
25 trail: Vec::new(),
26 }
27 }
28
29 pub fn trail_mark(&self) -> usize {
31 self.trail.len()
32 }
33
34 pub fn undo_to(&mut self, mark: usize) {
36 while self.trail.len() > mark {
37 let var = self.trail.pop().unwrap();
38 self.bindings[var as usize] = None;
39 }
40 }
41
42 fn bind(&mut self, var: VarId, term: Term) {
44 let idx = var as usize;
45 if idx >= self.bindings.len() {
46 self.bindings.resize(idx + 1, None);
47 }
48 self.bindings[idx] = Some(term);
49 self.trail.push(var);
50 }
51
52 fn lookup(&self, var: VarId) -> Option<&Term> {
54 self.bindings.get(var as usize).and_then(|b| b.as_ref())
55 }
56
57 pub fn walk(&self, term: &Term) -> Term {
59 match term {
60 Term::Var(id) => match self.lookup(*id) {
61 Some(bound) => self.walk(bound),
62 None => term.clone(),
63 },
64 _ => term.clone(),
65 }
66 }
67
68 pub fn apply(&self, term: &Term) -> Term {
72 let mut seen = FnvHashSet::default();
73 self.apply_impl(term, &mut seen)
74 }
75
76 fn apply_impl(&self, term: &Term, seen: &mut FnvHashSet<VarId>) -> Term {
77 match term {
78 Term::Var(id) => {
79 if seen.contains(id) {
80 return term.clone();
82 }
83 match self.lookup(*id) {
84 Some(bound) => {
85 seen.insert(*id);
86 let result = self.apply_impl(bound, seen);
87 seen.remove(id);
88 result
89 }
90 None => term.clone(),
91 }
92 }
93 Term::Compound { functor, args } => Term::Compound {
94 functor: *functor,
95 args: args.iter().map(|a| self.apply_impl(a, seen)).collect(),
96 },
97 Term::List { head, tail } => Term::List {
98 head: Box::new(self.apply_impl(head, seen)),
99 tail: Box::new(self.apply_impl(tail, seen)),
100 },
101 _ => term.clone(),
102 }
103 }
104
105 pub fn unify(&mut self, t1: &Term, t2: &Term) -> bool {
108 let t1 = self.walk(t1);
109 let t2 = self.walk(t2);
110
111 match (&t1, &t2) {
112 (Term::Var(a), Term::Var(b)) if a == b => true,
114
115 (Term::Var(id), other) | (other, Term::Var(id)) => {
117 self.bind(*id, other.clone());
118 true
119 }
120
121 (Term::Atom(a), Term::Atom(b)) => a == b,
123
124 (Term::Integer(a), Term::Integer(b)) => a == b,
126
127 (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
129
130 (
132 Term::Compound {
133 functor: f1,
134 args: a1,
135 },
136 Term::Compound {
137 functor: f2,
138 args: a2,
139 },
140 ) => {
141 if f1 != f2 || a1.len() != a2.len() {
142 return false;
143 }
144 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
145 if !self.unify(arg1, arg2) {
146 return false;
147 }
148 }
149 true
150 }
151
152 (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
154 self.unify(h1, h2) && self.unify(t1, t2)
155 }
156
157 _ => false,
159 }
160 }
161
162 pub fn unify_with_occurs_check(&mut self, t1: &Term, t2: &Term) -> bool {
165 let t1 = self.walk(t1);
166 let t2 = self.walk(t2);
167
168 match (&t1, &t2) {
169 (Term::Var(a), Term::Var(b)) if a == b => true,
170 (Term::Var(id), other) | (other, Term::Var(id)) => {
171 if self.occurs_in(*id, other) {
172 return false;
173 }
174 self.bind(*id, other.clone());
175 true
176 }
177 (Term::Atom(a), Term::Atom(b)) => a == b,
178 (Term::Integer(a), Term::Integer(b)) => a == b,
179 (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
180 (
181 Term::Compound {
182 functor: f1,
183 args: a1,
184 },
185 Term::Compound {
186 functor: f2,
187 args: a2,
188 },
189 ) => {
190 if f1 != f2 || a1.len() != a2.len() {
191 return false;
192 }
193 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
194 if !self.unify_with_occurs_check(arg1, arg2) {
195 return false;
196 }
197 }
198 true
199 }
200 (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
201 self.unify_with_occurs_check(h1, h2) && self.unify_with_occurs_check(t1, t2)
202 }
203 _ => false,
204 }
205 }
206
207 fn occurs_in(&self, var: VarId, term: &Term) -> bool {
208 match term {
209 Term::Var(id) => {
210 if *id == var {
211 return true;
212 }
213 match self.lookup(*id) {
214 Some(bound) => self.occurs_in(var, bound),
215 None => false,
216 }
217 }
218 Term::Compound { args, .. } => args.iter().any(|a| self.occurs_in(var, a)),
219 Term::List { head, tail } => self.occurs_in(var, head) || self.occurs_in(var, tail),
220 _ => false,
221 }
222 }
223}
224
225impl Default for Substitution {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    #[test]
    fn test_unify_atoms() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Atom(0), &Term::Atom(0)));
        assert!(!sub.unify(&Term::Atom(0), &Term::Atom(1)));
    }

    #[test]
    fn test_unify_integers() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Integer(42), &Term::Integer(42)));
        assert!(!sub.unify(&Term::Integer(1), &Term::Integer(2)));
    }

    #[test]
    fn test_unify_var_to_atom() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
    }

    #[test]
    fn test_unify_var_to_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(1)));
        assert!(sub.unify(&Term::Var(1), &Term::Atom(5)));
        // Var(0) reaches Atom(5) through the Var(0) -> Var(1) chain.
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_compound() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(1)],
        };
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(2));
    }

    #[test]
    fn test_unify_compound_mismatch_functor() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 1,
            args: vec![Term::Atom(1)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_unify_compound_mismatch_arity() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1), Term::Atom(2)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_no_occurs_check() {
        let mut sub = Substitution::new();
        // Plain `unify` skips the occurs check, so X = f(X) succeeds.
        let t1 = Term::Var(0);
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(sub.unify(&t1, &t2));
        assert!(matches!(sub.walk(&Term::Var(0)), Term::Compound { .. }));
    }

    #[test]
    fn test_occurs_check_rejects_cycle() {
        let mut sub = Substitution::new();
        let cyclic = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(!sub.unify_with_occurs_check(&Term::Var(0), &cyclic));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_occurs_check_follows_bindings() {
        let mut sub = Substitution::new();
        assert!(sub.unify_with_occurs_check(&Term::Var(0), &Term::Var(1)));
        // Var(0) is bound to Var(1), so Var(1) = f(Var(0)) would be cyclic.
        let t = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(!sub.unify_with_occurs_check(&Term::Var(1), &t));
    }

    #[test]
    fn test_occurs_check_binds_acyclic_terms() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(2)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1), Term::Var(1)],
        };
        assert!(sub.unify_with_occurs_check(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
        assert_eq!(sub.walk(&Term::Var(1)), Term::Atom(2));
    }

    #[test]
    fn test_with_capacity_and_growth() {
        let mut sub = Substitution::with_capacity(4);
        assert_eq!(sub.walk(&Term::Var(2)), Term::Var(2));
        // Variables beyond the preallocated slots still bind.
        assert!(sub.unify(&Term::Var(10), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(10)), Term::Atom(1));
    }

    #[test]
    fn test_trail_backtracking() {
        let mut sub = Substitution::new();

        let mark = sub.trail_mark();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));

        sub.undo_to(mark);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_undo_to_mark_beyond_trail_is_noop() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        let mark = sub.trail_mark();
        sub.undo_to(mark + 10);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
    }

    #[test]
    fn test_apply() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(5)));
        assert!(sub.unify(&Term::Var(1), &Term::Integer(42)));

        let term = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Var(1), Term::Var(2)],
        };
        let applied = sub.apply(&term);
        match applied {
            Term::Compound { args, .. } => {
                assert_eq!(args[0], Term::Atom(5));
                assert_eq!(args[1], Term::Integer(42));
                // Var(2) is unbound and stays a variable.
                assert_eq!(args[2], Term::Var(2));
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_apply_resolves_chains_and_stops_on_cycles() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(1)));
        assert!(sub.unify(&Term::Var(1), &Term::Atom(3)));
        let term = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert_eq!(
            sub.apply(&term),
            Term::Compound {
                functor: 0,
                args: vec![Term::Atom(3)],
            }
        );

        // A cyclic binding is expanded once, leaving the inner variable bare.
        let fx = Term::Compound {
            functor: 0,
            args: vec![Term::Var(5)],
        };
        assert!(sub.unify(&Term::Var(5), &fx));
        assert_eq!(sub.apply(&Term::Var(5)), fx);
    }

    #[test]
    fn test_unify_list() {
        let mut sub = Substitution::new();
        let t1 = Term::List {
            head: Box::new(Term::Var(0)),
            tail: Box::new(Term::Atom(10)),
        };
        let t2 = Term::List {
            head: Box::new(Term::Atom(5)),
            tail: Box::new(Term::Atom(10)),
        };
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_list_with_atom_fails() {
        let mut sub = Substitution::new();
        let list = Term::List {
            head: Box::new(Term::Atom(1)),
            tail: Box::new(Term::Atom(10)),
        };
        assert!(!sub.unify(&list, &Term::Atom(1)));
    }

    #[test]
    fn test_unify_same_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(0)));
    }

    #[test]
    fn test_multiple_trail_marks() {
        let mut sub = Substitution::new();

        let mark1 = sub.trail_mark();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));

        let mark2 = sub.trail_mark();
        assert!(sub.unify(&Term::Var(1), &Term::Atom(2)));

        // Undoing to the inner mark keeps bindings made before it.
        sub.undo_to(mark2);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
        assert_eq!(sub.walk(&Term::Var(1)), Term::Var(1));

        // Undoing to the outer mark removes the rest.
        sub.undo_to(mark1);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }
}