1use crate::term::{Term, VarId};
2
3#[derive(Debug, Clone)]
7pub struct Substitution {
8 bindings: Vec<Option<Term>>,
9 trail: Vec<VarId>,
10}
11
12impl Substitution {
13 pub fn new() -> Self {
14 Substitution {
15 bindings: Vec::new(),
16 trail: Vec::new(),
17 }
18 }
19
20 pub fn with_capacity(n: usize) -> Self {
22 Substitution {
23 bindings: vec![None; n],
24 trail: Vec::new(),
25 }
26 }
27
28 pub fn trail_mark(&self) -> usize {
30 self.trail.len()
31 }
32
33 pub fn undo_to(&mut self, mark: usize) {
35 while self.trail.len() > mark {
36 let var = self.trail.pop().unwrap();
37 self.bindings[var as usize] = None;
38 }
39 }
40
41 fn bind(&mut self, var: VarId, term: Term) {
43 let idx = var as usize;
44 if idx >= self.bindings.len() {
45 self.bindings.resize(idx + 1, None);
46 }
47 self.bindings[idx] = Some(term);
48 self.trail.push(var);
49 }
50
51 fn lookup(&self, var: VarId) -> Option<&Term> {
53 self.bindings.get(var as usize).and_then(|b| b.as_ref())
54 }
55
56 pub fn walk(&self, term: &Term) -> Term {
58 match term {
59 Term::Var(id) => match self.lookup(*id) {
60 Some(bound) => self.walk(bound),
61 None => term.clone(),
62 },
63 _ => term.clone(),
64 }
65 }
66
67 pub fn apply(&self, term: &Term) -> Term {
71 let mut seen = Vec::new();
72 self.apply_impl(term, &mut seen)
73 }
74
75 fn apply_impl(&self, term: &Term, seen: &mut Vec<VarId>) -> Term {
76 match term {
77 Term::Var(id) => {
78 if seen.contains(id) {
79 return term.clone();
81 }
82 match self.lookup(*id) {
83 Some(bound) => {
84 seen.push(*id);
85 let result = self.apply_impl(bound, seen);
86 seen.pop();
87 result
88 }
89 None => term.clone(),
90 }
91 }
92 Term::Compound { functor, args } => Term::Compound {
93 functor: *functor,
94 args: args.iter().map(|a| self.apply_impl(a, seen)).collect(),
95 },
96 Term::List { head, tail } => Term::List {
97 head: Box::new(self.apply_impl(head, seen)),
98 tail: Box::new(self.apply_impl(tail, seen)),
99 },
100 _ => term.clone(),
101 }
102 }
103
104 pub fn unify(&mut self, t1: &Term, t2: &Term) -> bool {
107 let t1 = self.walk(t1);
108 let t2 = self.walk(t2);
109
110 match (&t1, &t2) {
111 (Term::Var(a), Term::Var(b)) if a == b => true,
113
114 (Term::Var(id), other) | (other, Term::Var(id)) => {
116 self.bind(*id, other.clone());
117 true
118 }
119
120 (Term::Atom(a), Term::Atom(b)) => a == b,
122
123 (Term::Integer(a), Term::Integer(b)) => a == b,
125
126 (Term::Float(a), Term::Float(b)) => a.to_bits() == b.to_bits(),
128
129 (
131 Term::Compound {
132 functor: f1,
133 args: a1,
134 },
135 Term::Compound {
136 functor: f2,
137 args: a2,
138 },
139 ) => {
140 if f1 != f2 || a1.len() != a2.len() {
141 return false;
142 }
143 for (arg1, arg2) in a1.iter().zip(a2.iter()) {
144 if !self.unify(arg1, arg2) {
145 return false;
146 }
147 }
148 true
149 }
150
151 (Term::List { head: h1, tail: t1 }, Term::List { head: h2, tail: t2 }) => {
153 self.unify(h1, h2) && self.unify(t1, t2)
154 }
155
156 _ => false,
158 }
159 }
160
161 #[allow(dead_code)]
164 fn occurs_in(&self, var: VarId, term: &Term) -> bool {
165 match term {
166 Term::Var(id) => {
167 if *id == var {
168 return true;
169 }
170 match self.lookup(*id) {
171 Some(bound) => self.occurs_in(var, bound),
172 None => false,
173 }
174 }
175 Term::Compound { args, .. } => args.iter().any(|a| self.occurs_in(var, a)),
176 Term::List { head, tail } => self.occurs_in(var, head) || self.occurs_in(var, tail),
177 _ => false,
178 }
179 }
180}
181
182impl Default for Substitution {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::term::Term;

    /// Builds the list cell `[head | tail]`.
    fn cons(head: Term, tail: Term) -> Term {
        Term::List {
            head: Box::new(head),
            tail: Box::new(tail),
        }
    }

    #[test]
    fn test_unify_atoms() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Atom(0), &Term::Atom(0)));
        assert!(!sub.unify(&Term::Atom(0), &Term::Atom(1)));
    }

    #[test]
    fn test_unify_integers() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Integer(42), &Term::Integer(42)));
        assert!(!sub.unify(&Term::Integer(1), &Term::Integer(2)));
    }

    #[test]
    fn test_unify_var_to_atom() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
    }

    #[test]
    fn test_unify_var_to_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(1)));
        assert!(sub.unify(&Term::Var(1), &Term::Atom(5)));
        // Binding the end of the chain is visible through its start.
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_compound() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(1)],
        };
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(2));
    }

    #[test]
    fn test_unify_compound_mismatch_functor() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 1,
            args: vec![Term::Atom(1)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_unify_compound_mismatch_arity() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(1), Term::Atom(2)],
        };
        assert!(!sub.unify(&t1, &t2));
    }

    #[test]
    fn test_no_occurs_check() {
        let mut sub = Substitution::new();
        // X = f(X) succeeds because unification skips the occurs check.
        let t1 = Term::Var(0);
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0)],
        };
        assert!(sub.unify(&t1, &t2));
        // `apply` must still terminate on the resulting cyclic binding.
        assert_eq!(sub.apply(&t1), t2);
    }

    #[test]
    fn test_trail_backtracking() {
        let mut sub = Substitution::new();

        let mark = sub.trail_mark();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));

        sub.undo_to(mark);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_failed_unify_can_be_undone() {
        let mut sub = Substitution::new();
        let t1 = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Atom(1)],
        };
        let t2 = Term::Compound {
            functor: 0,
            args: vec![Term::Atom(2), Term::Atom(3)],
        };
        let mark = sub.trail_mark();
        assert!(!sub.unify(&t1, &t2));
        // A failed unification may leave partial bindings; undo clears them.
        sub.undo_to(mark);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }

    #[test]
    fn test_undo_to_past_end_is_noop() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));
        let too_far = sub.trail_mark() + 5;
        sub.undo_to(too_far);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
    }

    #[test]
    fn test_apply() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(5)));
        assert!(sub.unify(&Term::Var(1), &Term::Integer(42)));

        let term = Term::Compound {
            functor: 0,
            args: vec![Term::Var(0), Term::Var(1), Term::Var(2)],
        };
        match sub.apply(&term) {
            Term::Compound { args, .. } => {
                assert_eq!(args[0], Term::Atom(5));
                assert_eq!(args[1], Term::Integer(42));
                // Unbound variables are left in place.
                assert_eq!(args[2], Term::Var(2));
            }
            _ => panic!("Expected compound"),
        }
    }

    #[test]
    fn test_unify_list() {
        let mut sub = Substitution::new();
        let t1 = cons(Term::Var(0), Term::Atom(10));
        let t2 = cons(Term::Atom(5), Term::Atom(10));
        assert!(sub.unify(&t1, &t2));
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(5));
    }

    #[test]
    fn test_unify_list_with_var_tail() {
        let mut sub = Substitution::new();
        let open = cons(Term::Integer(1), cons(Term::Integer(2), Term::Var(0)));
        let closed = cons(
            Term::Integer(1),
            cons(Term::Integer(2), cons(Term::Integer(3), Term::Atom(0))),
        );
        assert!(sub.unify(&open, &closed));
        assert_eq!(
            sub.walk(&Term::Var(0)),
            cons(Term::Integer(3), Term::Atom(0))
        );
        assert_eq!(sub.apply(&open), closed);
    }

    #[test]
    fn test_unify_same_var() {
        let mut sub = Substitution::new();
        assert!(sub.unify(&Term::Var(0), &Term::Var(0)));
        // Unifying a variable with itself must not record a binding.
        assert_eq!(sub.trail_mark(), 0);
    }

    #[test]
    fn test_multiple_trail_marks() {
        let mut sub = Substitution::new();

        let mark1 = sub.trail_mark();
        assert!(sub.unify(&Term::Var(0), &Term::Atom(1)));

        let mark2 = sub.trail_mark();
        assert!(sub.unify(&Term::Var(1), &Term::Atom(2)));

        // Undoing to the inner mark keeps bindings made before it.
        sub.undo_to(mark2);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Atom(1));
        assert_eq!(sub.walk(&Term::Var(1)), Term::Var(1));

        // Undoing to the outer mark clears everything.
        sub.undo_to(mark1);
        assert_eq!(sub.walk(&Term::Var(0)), Term::Var(0));
    }
}