Skip to main content

prolog2/resolution/
unification.rs

1//! Unification algorithm and substitution management.
2
3use std::{
4    ops::{Deref, DerefMut},
5    usize,
6};
7
8use crate::heap::heap::{Cell, Heap, Tag};
9
/// The result of a (partial) unification: argument-register assignments plus
/// variable bindings between heap addresses.
///
/// Storage is fixed-size: 32 argument registers and at most 32 bindings. The
/// struct dereferences to the slice of bindings recorded so far.
#[derive(Debug, PartialEq)]
pub struct Substitution {
    /// Heap address held by each argument register; `usize::MAX` marks a
    /// register that has not been set.
    arg_regs: [usize; 32],
    /// Recorded bindings as `(from, to, complex)` triples. `complex` is the
    /// flag supplied by unification when the target is a list, function, set
    /// or tuple cell.
    binding_array: [(usize, usize, bool); 32],
    /// Number of leading entries of `binding_array` that are in use.
    binding_len: usize,
}
20
21impl Deref for Substitution {
22    type Target = [(usize, usize, bool)];
23    fn deref(&self) -> &Self::Target {
24        &self.binding_array[..self.binding_len]
25    }
26}
27
28impl DerefMut for Substitution {
29    fn deref_mut(&mut self) -> &mut Self::Target {
30        &mut self.binding_array[..self.binding_len]
31    }
32}
33
34impl Default for Substitution {
35    fn default() -> Self {
36        Self {
37            arg_regs: [usize::MAX; 32],
38            binding_array: Default::default(),
39            binding_len: Default::default(),
40        }
41    }
42}
43
44impl Substitution {
45    pub fn bound(&self, addr: usize) -> Option<usize> {
46        // println!("{addr}");
47        match self.iter().find(|(a1, _, _)| *a1 == addr) {
48            Some((_, a2, _)) => match self.bound(*a2) {
49                Some(a2) => Some(a2),
50                None => Some(*a2),
51            },
52            None => None,
53        }
54    }
55
56    pub fn push(mut self, binding: (usize, usize, bool)) -> Self {
57        self.binding_array[self.binding_len] = binding;
58        self.binding_len += 1;
59        self
60    }
61
62    pub fn get_arg(&self, arg_idx: usize) -> Option<usize> {
63        if self.arg_regs[arg_idx] == usize::MAX {
64            None
65        } else {
66            Some(self.arg_regs[arg_idx])
67        }
68    }
69
70    pub fn set_arg(&mut self, arg_idx: usize, addr: usize) {
71        self.arg_regs[arg_idx] = addr;
72    }
73
74    pub fn get_bindings(&self) -> Box<[(usize, usize)]> {
75        let mut bindings = Vec::<(usize, usize)>::with_capacity(self.binding_len);
76        for i in 0..self.binding_len {
77            bindings.push((self.binding_array[i].0, self.binding_array[i].1));
78        }
79        bindings.into_boxed_slice()
80    }
81
82    // REPLACE the check_constraints method in src/resolution/unification.rs with:
83
84    /// Fully dereference an address through both heap references and substitution bindings.
85    pub(crate) fn full_deref(&self, mut addr: usize, heap: &impl Heap) -> usize {
86        loop {
87            // First, dereference through the heap
88            let heap_deref = heap.deref_addr(addr);
89
90            // Then check if there's a pending binding in the substitution
91            match self.bound(heap_deref) {
92                Some(bound_to) => {
93                    let next = heap.deref_addr(bound_to);
94                    if next == heap_deref {
95                        return heap_deref;
96                    }
97                    addr = next;
98                }
99                None => return heap_deref,
100            }
101        }
102    }
103
104    /// Check that no two constrained addresses are bound to the same final target.
105    /// This prevents different meta-variables from unifying to the same predicate symbol.
106    ///
107    /// The constraint check traces through BOTH:
108    /// 1. The heap's reference chains (via deref_addr)
109    /// 2. The substitution's pending bindings (via bound)
110    ///
111    /// Compares cell VALUES at dereferenced addresses, not the addresses themselves.
112    /// This ensures that the same constant symbol at different heap locations is
113    /// correctly detected as a duplicate.
114    pub fn check_constraints(&self, constraints: &[usize], heap: &impl Heap) -> bool {
115        // Collect (original_constraint, final_cell) for each constrained address
116        let mut constrained_targets: Vec<(usize, Cell)> = Vec::new();
117
118        for &constraint_addr in constraints {
119            let final_addr = self.full_deref(constraint_addr, heap);
120            let final_cell = heap[final_addr];
121            constrained_targets.push((constraint_addr, final_cell));
122        }
123
124        // Check if any two different constrained addresses resolve to the same cell value
125        for i in 0..constrained_targets.len() {
126            for j in (i + 1)..constrained_targets.len() {
127                if constrained_targets[i].0 != constrained_targets[j].0
128                    && constrained_targets[i].1 == constrained_targets[j].1
129                {
130                    return false;
131                }
132            }
133        }
134
135        true
136    }
137}
138
139/// Unify two terms on the heap, returning a substitution on success.
140pub fn unify(heap: &impl Heap, addr_1: usize, addr_2: usize) -> Option<Substitution> {
141    unify_rec(heap, Substitution::default(), addr_1, addr_2)
142}
143
144///Recursive unification function \
145///@addr_1: Address of program term \
146///@addr_2: Address of goal term
147fn unify_rec(
148    heap: &impl Heap,
149    mut binding: Substitution,
150    mut addr_1: usize,
151    mut addr_2: usize,
152) -> Option<Substitution> {
153    addr_1 = heap.deref_addr(addr_1);
154    addr_2 = heap.deref_addr(addr_2);
155    if heap[addr_1].0 == Tag::Ref {
156        if let Some(addr) = binding.bound(addr_1) {
157            addr_1 = addr;
158        }
159    }
160    if heap[addr_2].0 == Tag::Ref {
161        if let Some(addr) = binding.bound(addr_2) {
162            addr_2 = addr;
163        }
164    }
165
166    if addr_1 == addr_2 {
167        return Some(binding);
168    }
169
170    match (heap[addr_1].0, heap[addr_2].0) {
171        (Tag::Str, Tag::Str) => unify_rec(heap, binding, heap[addr_1].1, heap[addr_2].1),
172        (_, Tag::Str) => unify_rec(heap, binding, addr_1, heap[addr_2].1),
173        (Tag::Str, _) => unify_rec(heap, binding, heap[addr_1].1, addr_2),
174        (_, Tag::Arg) => panic!("Undefined Unification behaviour"),
175        (Tag::Arg, _) => match binding.get_arg(heap[addr_1].1) {
176            Some(addr) => unify_rec(heap, binding, addr, addr_2),
177            None => {
178                binding.set_arg(heap[addr_1].1, addr_2);
179                Some(binding)
180            }
181        },
182        (Tag::Ref, Tag::Lis | Tag::Func | Tag::Set | Tag::Tup) => {
183            Some(binding.push((addr_1, addr_2, true)))
184        }
185        (Tag::Ref, _) => Some(binding.push((addr_1, addr_2, false))),
186        (Tag::Lis | Tag::Func | Tag::Set | Tag::Tup, Tag::Ref) => {
187            Some(binding.push((addr_2, addr_1, true)))
188        }
189        (_, Tag::Ref) => Some(binding.push((addr_2, addr_1, false))),
190        (Tag::Con, Tag::Con)|(Tag::Int, Tag::Int)|(Tag::Flt, Tag::Flt) if heap[addr_1].1 == heap[addr_2].1 => Some(binding),
191        (Tag::Func, Tag::Func)|(Tag::Tup, Tag::Tup) => unify_func_or_tup(heap, binding, addr_1, addr_2),
192        (Tag::Set, Tag::Set) => unfiy_set(heap, binding, addr_1, addr_2),
193        (Tag::Lis, Tag::Lis) => unify_list(heap, binding, addr_1, addr_2),
194        (Tag::ELis, Tag::ELis) => Some(binding),
195        _ => None,
196    }
197}
198
199fn unify_func_or_tup(
200    heap: &impl Heap,
201    mut binding: Substitution,
202    addr_1: usize,
203    addr_2: usize,
204) -> Option<Substitution> {
205    if heap[addr_1].1 != heap[addr_2].1 {
206        return None;
207    };
208
209    for i in 1..heap[addr_1].1 + 1 {
210        binding = unify_rec(heap, binding, addr_1 + i, addr_2 + i)?;
211    }
212
213    Some(binding)
214}
215
216fn unfiy_set(
217    _heap: &impl Heap,
218    mut _binding: Substitution,
219    _addr_1: usize,
220    _addr_2: usize,
221) -> Option<Substitution> {
222    unimplemented!("set unification not yet supported")
223}
224
225/**Unfiy two lists together */
226fn unify_list(
227    heap: &impl Heap,
228    mut binding: Substitution,
229    addr_1: usize,
230    addr_2: usize,
231) -> Option<Substitution> {
232    // println!("List:({addr_1},{addr_2})");
233    let addr_1 = heap[addr_1].1;
234    let addr_2 = heap[addr_2].1;
235    binding = unify_rec(heap, binding, addr_1, addr_2)?;
236    unify_rec(heap, binding, addr_1 + 1, addr_2 + 1)
237}
238
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::Substitution;
    use crate::{
        heap::{heap::Tag, query_heap::QueryHeap, symbol_db::SymbolDB},
        resolution::unification::{unify, unify_rec},
    };

    /// Discard the recorded bindings and point argument register 0 at `addr`.
    fn reset_with_arg0(binding: &mut Substitution, addr: usize) {
        binding.binding_array[0] = (0, 0, false);
        binding.binding_len = 0;
        binding.arg_regs[0] = addr;
    }

    /// An `Arg` cell unified with a variable sets the register; once the
    /// register is set, later unifications bind the variable to whatever the
    /// register holds.
    #[test]
    fn arg_to_ref() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Arg, 0),
            (Tag::Ref, 1),
            (Tag::Ref, 2),
            (Tag::Str, 4),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        let mut binding = unify(&heap, 0, 1).unwrap();
        assert_eq!(binding.arg_regs[0], 1);
        assert_eq!(binding.arg_regs[1..32], [usize::MAX; 31]);

        binding = unify_rec(&heap, binding, 0, 2).unwrap();
        assert_eq!(binding.arg_regs[0], 1);
        assert_eq!(binding.arg_regs[1..32], [usize::MAX; 31]);
        assert_eq!(binding.bound(1), Some(2));

        // Register holds a `Str` pointer: the variable binds to the structure.
        reset_with_arg0(&mut binding, 3);
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(4));

        // Register holds the structure header directly.
        reset_with_arg0(&mut binding, 4);
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(4));

        // Register holds a constant.
        reset_with_arg0(&mut binding, 5);
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(5));
    }

    /// An unset `Arg` register unified with a `Str` pointer stores the address
    /// of the structure it points at.
    #[test]
    fn arg() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Arg, 0),
            (Tag::Str, 2),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        let binding = unify(&heap, 0, 1).unwrap();
        assert_eq!(binding.get_arg(0), Some(2));
    }

    /// A variable unified with an already-bound variable is bound to the end
    /// of that variable's binding chain.
    #[test]
    fn binding_chain_ref() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Ref, 0),
            (Tag::Ref, 1),
            (Tag::Ref, 2),
            (Tag::Str, 4),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        let mut binding = Substitution::default();
        binding = binding.push((1, 2, false));

        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(2));

        let mut binding = Substitution::default();
        binding = binding.push((1, 3, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(4));

        let mut binding = Substitution::default();
        binding = binding.push((1, 4, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(4));

        let mut binding = Substitution::default();
        binding = binding.push((1, 5, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(5));
    }

    /// A function term unifies with a variable but not with a tuple, a
    /// constant or a list.
    #[test]
    fn func() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Tup, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Ref, 6),
            (Tag::Lis, 8),
            (Tag::Con, p),
            (Tag::ELis, 0),
        ];

        assert_eq!(unify(&heap, 0, 3), None);
        assert_eq!(unify(&heap, 0, 4), None);
        let binding = unify(&heap, 0, 6).unwrap();
        assert_eq!(binding.bound(6), Some(0));
        assert_eq!(unify(&heap, 0, 7), None);
    }

    /// A tuple term unifies with a variable but not with a function, a
    /// constant or a list.
    #[test]
    fn tup() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Tup, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Ref, 6),
            (Tag::Lis, 8),
            (Tag::Con, p),
            (Tag::ELis, 0),
        ];

        assert_eq!(unify(&heap, 0, 3), None);
        assert_eq!(unify(&heap, 0, 4), None);
        let binding = unify(&heap, 0, 6).unwrap();
        assert_eq!(binding.bound(6), Some(0));
        assert_eq!(unify(&heap, 0, 7), None);
    }

    /// List unification against variables, argument registers, a list nested
    /// inside a function term, and lists of lists.
    #[test]
    fn list() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());
        let b = SymbolDB::set_const("b".into());
        let c = SymbolDB::set_const("c".into());
        let t = SymbolDB::set_const("t".into());

        // Constants on one side, variables on the other.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Con, a),  //1
            (Tag::Lis, 3),  //2
            (Tag::Con, b),  //3
            (Tag::Lis, 5),  //4
            (Tag::Con, c),  //5
            (Tag::ELis, 0), //6
            (Tag::Lis, 8),  //7
            (Tag::Con, a),  //8
            (Tag::Lis, 10), //9
            (Tag::Ref, 10), //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::ELis, 0), //13
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.bound(10), Some(3));
        assert_eq!(binding.bound(12), Some(5));

        // Argument registers against constants and a variable.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Arg, 0),  //1
            (Tag::Lis, 3),  //2
            (Tag::Arg, 1),  //3
            (Tag::Lis, 5),  //4
            (Tag::Arg, 2),  //5
            (Tag::Con, t),  //6
            (Tag::Lis, 8),  //7
            (Tag::Con, a),  //8
            (Tag::Lis, 10), //9
            (Tag::Con, b),  //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::Con, t),  //13
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(8));
        assert_eq!(binding.get_arg(1), Some(10));
        assert_eq!(binding.get_arg(2), Some(12));

        // Argument registers against variables only, including the tail.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Arg, 0),  //1
            (Tag::Lis, 3),  //2
            (Tag::Arg, 1),  //3
            (Tag::Lis, 5),  //4
            (Tag::Arg, 2),  //5
            (Tag::Arg, 3),  //6
            (Tag::Lis, 8),  //7
            (Tag::Ref, 8),  //8
            (Tag::Lis, 10), //9
            (Tag::Ref, 10), //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::Ref, 13), //13
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(8));
        assert_eq!(binding.get_arg(1), Some(10));
        assert_eq!(binding.get_arg(2), Some(12));
        assert_eq!(binding.get_arg(3), Some(13));

        // Lists nested as the second argument of a function term.
        let heap = vec![
            (Tag::Func, 2), //0
            (Tag::Con, p),  //1
            (Tag::Lis, 6),  //2
            (Tag::Func, 2), //3
            (Tag::Con, p),  //4
            (Tag::Lis, 12), //5
            (Tag::Arg, 0),  //6
            (Tag::Lis, 8),  //7
            (Tag::Arg, 1),  //8
            (Tag::Lis, 10), //9
            (Tag::Arg, 2),  //10
            (Tag::ELis, 0), //11
            (Tag::Ref, 12), //12
            (Tag::Lis, 14), //13
            (Tag::Ref, 14), //14
            (Tag::Lis, 16), //15
            (Tag::Ref, 16), //16
            (Tag::ELis, 0), //17
        ];

        let binding = unify(&heap, 0, 3).unwrap();
        assert_eq!(binding.get_arg(0), Some(12));
        assert_eq!(binding.get_arg(1), Some(14));
        assert_eq!(binding.get_arg(2), Some(16));

        // Lists of lists, with a variable as the tail on the goal side.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Lis, 12), //1
            (Tag::Lis, 3),  //2
            (Tag::Lis, 14), //3
            (Tag::Lis, 5),  //4
            (Tag::Lis, 16), //5
            (Tag::ELis, 0), //6
            (Tag::Lis, 8),  //7
            (Tag::Lis, 18), //8
            (Tag::Lis, 10), //9
            (Tag::Lis, 20), //10
            (Tag::Ref, 11), //11
            (Tag::Con, a),  //12
            (Tag::ELis, 0), //13
            (Tag::Arg, 0),  //14
            (Tag::ELis, 0), //15
            (Tag::Con, c),  //16
            (Tag::ELis, 0), //17
            (Tag::Con, a),  //18
            (Tag::ELis, 0), //19
            (Tag::Con, b),  //20
            (Tag::ELis, 0), //21
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(20));
        assert_eq!(binding.bound(11), Some(4));
    }

    /// Integers unify with equal integers and can be the target of a variable
    /// binding, across the program heap and the query-heap extension.
    #[test]
    fn integers() {
        let prev = SymbolDB::set_const("prev".to_string());
        let heap = vec![
            (Tag::Func, 3),
            (Tag::Con, prev),
            (Tag::Int, 4),
            (Tag::Int, 3),
        ];
        let mut heap = QueryHeap::new(Arc::new(heap), None);
        //possible failure to deref before comparing numbers
        heap.cells.extend(vec![
            (Tag::Func, 3),
            (Tag::Ref, 5),
            (Tag::Int, 4),
            (Tag::Ref, 7),
        ]);

        let binding = unify(&heap, 0, 4).unwrap();
        assert_eq!(binding.bound(5), Some(1));
        assert_eq!(binding.bound(7), Some(3));
    }
}