// prolog2/resolution/unification.rs

1//! Unification algorithm and substitution management.
2
3use std::{
4    mem::MaybeUninit,
5    ops::{Deref, DerefMut},
6    usize,
7};
8
9use crate::heap::heap::{Cell, Heap, Tag};
10
/// The set of bindings accumulated while unifying a clause term with a goal term.
///
/// Two kinds of binding are stored side by side:
/// * argument registers, indexed by the value held in a clause `Arg` cell, and
/// * heap-to-heap bindings, kept as `(from, to, complex)` triples.
#[derive(Debug, PartialEq)]
pub struct Substitution {
    /// Heap address held by each argument register; `usize::MAX` marks a
    /// register that has not been set.
    arg_regs: [usize; 32],
    /// Fixed-size backing store for the `(from, to, complex)` bindings.
    binding_array: [(usize, usize, bool); 32],
    /// How many leading entries of `binding_array` are currently in use.
    binding_len: usize,
}
21
22impl Deref for Substitution {
23    type Target = [(usize, usize, bool)];
24    fn deref(&self) -> &Self::Target {
25        &self.binding_array[..self.binding_len]
26    }
27}
28
29impl DerefMut for Substitution {
30    fn deref_mut(&mut self) -> &mut Self::Target {
31        &mut self.binding_array[..self.binding_len]
32    }
33}
34
35impl Default for Substitution {
36    fn default() -> Self {
37        Self {
38            arg_regs: [usize::MAX; 32],
39            binding_array: Default::default(),
40            binding_len: Default::default(),
41        }
42    }
43}
44
45impl Substitution {
46    pub fn bound(&self, addr: usize) -> Option<usize> {
47        // println!("{addr}");
48        match self.iter().find(|(a1, _, _)| *a1 == addr) {
49            Some((_, a2, _)) => match self.bound(*a2) {
50                Some(a2) => Some(a2),
51                None => Some(*a2),
52            },
53            None => None,
54        }
55    }
56
57    pub fn push(mut self, binding: (usize, usize, bool)) -> Self {
58        self.binding_array[self.binding_len] = binding;
59        self.binding_len += 1;
60        self
61    }
62
63    pub fn get_arg(&self, arg_idx: usize) -> Option<usize> {
64        if self.arg_regs[arg_idx] == usize::MAX {
65            None
66        } else {
67            Some(self.arg_regs[arg_idx])
68        }
69    }
70
71    pub fn set_arg(&mut self, arg_idx: usize, addr: usize) {
72        self.arg_regs[arg_idx] = addr;
73    }
74
75    pub fn get_bindings(&self) -> Box<[(usize, usize)]> {
76        let mut bindings = Vec::<(usize, usize)>::with_capacity(self.binding_len);
77        for i in 0..self.binding_len {
78            bindings.push((self.binding_array[i].0, self.binding_array[i].1));
79        }
80        bindings.into_boxed_slice()
81    }
82
83    /// Fully dereference an address through both heap references and substitution bindings.
84    pub(crate) fn full_deref(&self, mut addr: usize, heap: &impl Heap) -> usize {
85        loop {
86            // First, dereference through the heap
87            let heap_deref = heap.deref_addr(addr);
88
89            // Then check if there's a pending binding in the substitution
90            match self.bound(heap_deref) {
91                Some(bound_to) => {
92                    let next = heap.deref_addr(bound_to);
93                    if next == heap_deref {
94                        return heap_deref;
95                    }
96                    addr = next;
97                }
98                None => return heap_deref,
99            }
100        }
101    }
102
103    /// Check that no two constrained addresses are bound to the same final target.
104    /// This prevents different meta-variables from unifying to the same predicate symbol.
105    ///
106    /// The constraint check traces through BOTH:
107    /// 1. The heap's reference chains (via deref_addr)
108    /// 2. The substitution's pending bindings (via bound)
109    ///
110    /// Compares cell VALUES at dereferenced addresses, not the addresses themselves.
111    /// This ensures that the same constant symbol at different heap locations is
112    /// correctly detected as a duplicate.
113    pub fn check_constraints(&self, constraints: &[usize], heap: &impl Heap) -> bool {
114        const STACK_CAP: usize = 8;
115        let len = constraints.len();
116
117        if len <= STACK_CAP {
118            // Stack-allocated path: use MaybeUninit to avoid zeroing unused slots
119            let mut buf: [MaybeUninit<(usize, Cell)>; STACK_CAP] =
120                unsafe { MaybeUninit::uninit().assume_init() };
121            for i in 0..len {
122                let addr = constraints[i];
123                let cell = heap[self.full_deref(addr, heap)];
124                buf[i] = MaybeUninit::new((addr, cell));
125            }
126            for i in 0..len {
127                let (addr_i, cell_i) = unsafe { buf[i].assume_init() };
128                for j in (i + 1)..len {
129                    let (addr_j, cell_j) = unsafe { buf[j].assume_init() };
130                    if addr_i != addr_j && cell_i == cell_j {
131                        return false;
132                    }
133                }
134            }
135            true
136        } else {
137            // Fallback for very large constraint sets
138            let targets: Vec<(usize, Cell)> = constraints
139                .iter()
140                .map(|&addr| (addr, heap[self.full_deref(addr, heap)]))
141                .collect();
142            for i in 0..targets.len() {
143                for j in (i + 1)..targets.len() {
144                    if targets[i].0 != targets[j].0 && targets[i].1 == targets[j].1 {
145                        return false;
146                    }
147                }
148            }
149            true
150        }
151    }
152}
153
154/// Unify two terms on the heap, returning a substitution on success.
155pub fn unify(heap: &impl Heap, addr_1: usize, addr_2: usize) -> Option<Substitution> {
156    unify_rec(heap, Substitution::default(), addr_1, addr_2)
157}
158
159///Recursive unification function \
160///@addr_1: Address of program term \
161///@addr_2: Address of goal term
162fn unify_rec(
163    heap: &impl Heap,
164    mut binding: Substitution,
165    mut addr_1: usize,
166    mut addr_2: usize,
167) -> Option<Substitution> {
168    addr_1 = heap.deref_addr(addr_1);
169    addr_2 = heap.deref_addr(addr_2);
170    if heap[addr_1].0 == Tag::Ref {
171        if let Some(addr) = binding.bound(addr_1) {
172            addr_1 = addr;
173        }
174    }
175    if heap[addr_2].0 == Tag::Ref {
176        if let Some(addr) = binding.bound(addr_2) {
177            addr_2 = addr;
178        }
179    }
180
181    if addr_1 == addr_2 {
182        return Some(binding);
183    }
184
185    match (heap[addr_1].0, heap[addr_2].0) {
186        (Tag::Str, Tag::Str) => unify_rec(heap, binding, heap[addr_1].1, heap[addr_2].1),
187        (_, Tag::Str) => unify_rec(heap, binding, addr_1, heap[addr_2].1),
188        (Tag::Str, _) => unify_rec(heap, binding, heap[addr_1].1, addr_2),
189        (_, Tag::Arg) => panic!("Undefined Unification behaviour"),
190        (Tag::Arg, _) => match binding.get_arg(heap[addr_1].1) {
191            Some(addr) => unify_rec(heap, binding, addr, addr_2),
192            None => {
193                binding.set_arg(heap[addr_1].1, addr_2);
194                Some(binding)
195            }
196        },
197        (Tag::Ref, Tag::Lis | Tag::Func | Tag::Set | Tag::Tup) => {
198            Some(binding.push((addr_1, addr_2, true)))
199        }
200        (Tag::Ref, _) => Some(binding.push((addr_1, addr_2, false))),
201        (Tag::Lis | Tag::Func | Tag::Set | Tag::Tup, Tag::Ref) => {
202            Some(binding.push((addr_2, addr_1, true)))
203        }
204        (_, Tag::Ref) => Some(binding.push((addr_2, addr_1, false))),
205        (Tag::Con, Tag::Con) | (Tag::Int, Tag::Int) | (Tag::Flt, Tag::Flt)
206            if heap[addr_1].1 == heap[addr_2].1 =>
207        {
208            Some(binding)
209        }
210        (Tag::Func, Tag::Func) | (Tag::Tup, Tag::Tup) => {
211            unify_func_or_tup(heap, binding, addr_1, addr_2)
212        }
213        (Tag::Set, Tag::Set) => unfiy_set(heap, binding, addr_1, addr_2),
214        (Tag::Lis, Tag::Lis) => unify_list(heap, binding, addr_1, addr_2),
215        (Tag::ELis, Tag::ELis) => Some(binding),
216        _ => None,
217    }
218}
219
220fn unify_func_or_tup(
221    heap: &impl Heap,
222    mut binding: Substitution,
223    addr_1: usize,
224    addr_2: usize,
225) -> Option<Substitution> {
226    if heap[addr_1].1 != heap[addr_2].1 {
227        return None;
228    };
229
230    for i in 1..heap[addr_1].1 + 1 {
231        binding = unify_rec(heap, binding, addr_1 + i, addr_2 + i)?;
232    }
233
234    Some(binding)
235}
236
237fn unfiy_set(
238    _heap: &impl Heap,
239    mut _binding: Substitution,
240    _addr_1: usize,
241    _addr_2: usize,
242) -> Option<Substitution> {
243    unimplemented!("set unification not yet supported")
244}
245
246/**Unfiy two lists together */
247fn unify_list(
248    heap: &impl Heap,
249    mut binding: Substitution,
250    addr_1: usize,
251    addr_2: usize,
252) -> Option<Substitution> {
253    // println!("List:({addr_1},{addr_2})");
254    let addr_1 = heap[addr_1].1;
255    let addr_2 = heap[addr_2].1;
256    binding = unify_rec(heap, binding, addr_1, addr_2)?;
257    unify_rec(heap, binding, addr_1 + 1, addr_2 + 1)
258}
259
#[cfg(test)]
mod tests {
    use super::{unify, unify_rec, Substitution};
    use crate::heap::{heap::Tag, query_heap::QueryHeap, symbol_db::SymbolDB};

    /// A clause argument is bound on first use and unified through its
    /// register afterwards.
    #[test]
    fn arg_to_ref() {
        let p = SymbolDB::set_const("p".into());
        // NOTE(review): `a` is interned from "p" as well; this looks like a
        // copy-paste slip, but no assertion below depends on it.
        let a = SymbolDB::set_const("p".into());

        let heap = vec![
            (Tag::Arg, 0),
            (Tag::Ref, 1),
            (Tag::Ref, 2),
            (Tag::Str, 4),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        // First use: the register is set to the goal address.
        let mut subst = unify(&heap, 0, 1).unwrap();
        assert_eq!(subst.arg_regs[0], 1);
        assert!(subst.arg_regs[1..].iter().all(|&reg| reg == usize::MAX));

        // Second use: the register is kept and the two references are bound.
        subst = unify_rec(&heap, subst, 0, 2).unwrap();
        assert_eq!(subst.arg_regs[0], 1);
        assert!(subst.arg_regs[1..].iter().all(|&reg| reg == usize::MAX));
        assert_eq!(subst.bound(1), Some(2));

        // Pre-set the register to different cells and check what ref 1 ends
        // up bound to: (register content, expected binding of address 1).
        for &(register, expected) in [(3, 4), (4, 4), (5, 5)].iter() {
            subst.binding_array[0] = (0, 0, false);
            subst.binding_len = 0;
            subst.arg_regs[0] = register;
            subst = unify_rec(&heap, subst, 0, 1).unwrap();
            assert_eq!(subst.bound(1), Some(expected));
        }
    }

    /// An argument unified with a `Str` cell is set to the cell it points at.
    #[test]
    fn arg() {
        let p = SymbolDB::set_const("p".into());
        // NOTE(review): `a` is interned from "p" as well; see `arg_to_ref`.
        let a = SymbolDB::set_const("p".into());

        let heap = vec![
            (Tag::Arg, 0),
            (Tag::Str, 2),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        let subst = unify(&heap, 0, 1).unwrap();
        assert_eq!(subst.get_arg(0), Some(2));
    }

    /// A reference unified with an already-bound reference follows the
    /// existing binding.
    #[test]
    fn binding_chain_ref() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Ref, 0),
            (Tag::Ref, 1),
            (Tag::Ref, 2),
            (Tag::Str, 4),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
        ];

        // (existing binding target of address 1, expected binding of address 0)
        for &(target, expected) in [(2, 2), (3, 4), (4, 4), (5, 5)].iter() {
            let seeded = Substitution::default().push((1, target, false));
            let subst = unify_rec(&heap, seeded, 0, 1).unwrap();
            assert_eq!(subst.bound(0), Some(expected));
        }
    }

    #[test]
    fn func() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Tup, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Ref, 6),
            (Tag::Lis, 8),
            (Tag::Con, p),
            (Tag::ELis, 0),
        ];

        assert!(unify(&heap, 0, 3).is_none());
        assert!(unify(&heap, 0, 4).is_none());
        let subst = unify(&heap, 0, 6).unwrap();
        assert_eq!(subst.bound(6), Some(0));
        assert!(unify(&heap, 0, 7).is_none());
    }

    #[test]
    fn tup() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Tup, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Func, 2),
            (Tag::Con, p),
            (Tag::Con, a),
            (Tag::Ref, 6),
            (Tag::Lis, 8),
            (Tag::Con, p),
            (Tag::ELis, 0),
        ];

        assert!(unify(&heap, 0, 3).is_none());
        assert!(unify(&heap, 0, 4).is_none());
        let subst = unify(&heap, 0, 6).unwrap();
        assert_eq!(subst.bound(6), Some(0));
        assert!(unify(&heap, 0, 7).is_none());
    }

    #[test]
    fn list() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());
        let b = SymbolDB::set_const("b".into());
        let c = SymbolDB::set_const("c".into());
        let t = SymbolDB::set_const("t".into());

        // Constants on the program side, references on the goal side.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Con, a),  //1
            (Tag::Lis, 3),  //2
            (Tag::Con, b),  //3
            (Tag::Lis, 5),  //4
            (Tag::Con, c),  //5
            (Tag::ELis, 0), //6
            (Tag::Lis, 8),  //7
            (Tag::Con, a),  //8
            (Tag::Lis, 10), //9
            (Tag::Ref, 10), //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::ELis, 0), //13
        ];

        let subst = unify(&heap, 0, 7).unwrap();
        assert_eq!(subst.bound(10), Some(3));
        assert_eq!(subst.bound(12), Some(5));

        // Arguments on the program side, matching constant tails.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Arg, 0),  //1
            (Tag::Lis, 3),  //2
            (Tag::Arg, 1),  //3
            (Tag::Lis, 5),  //4
            (Tag::Arg, 2),  //5
            (Tag::Con, t),  //6
            (Tag::Lis, 8),  //7
            (Tag::Con, a),  //8
            (Tag::Lis, 10), //9
            (Tag::Con, b),  //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::Con, t),  //13
        ];

        let subst = unify(&heap, 0, 7).unwrap();
        assert_eq!(subst.get_arg(0), Some(8));
        assert_eq!(subst.get_arg(1), Some(10));
        assert_eq!(subst.get_arg(2), Some(12));

        // Arguments everywhere on the program side, including the tail.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Arg, 0),  //1
            (Tag::Lis, 3),  //2
            (Tag::Arg, 1),  //3
            (Tag::Lis, 5),  //4
            (Tag::Arg, 2),  //5
            (Tag::Arg, 3),  //6
            (Tag::Lis, 8),  //7
            (Tag::Ref, 8),  //8
            (Tag::Lis, 10), //9
            (Tag::Ref, 10), //10
            (Tag::Lis, 12), //11
            (Tag::Ref, 12), //12
            (Tag::Ref, 13), //13
        ];

        let subst = unify(&heap, 0, 7).unwrap();
        assert_eq!(subst.get_arg(0), Some(8));
        assert_eq!(subst.get_arg(1), Some(10));
        assert_eq!(subst.get_arg(2), Some(12));
        assert_eq!(subst.get_arg(3), Some(13));

        // Lists nested inside a function term.
        let heap = vec![
            (Tag::Func, 2), //0
            (Tag::Con, p),  //1
            (Tag::Lis, 6),  //2
            (Tag::Func, 2), //3
            (Tag::Con, p),  //4
            (Tag::Lis, 12), //5
            (Tag::Arg, 0),  //6
            (Tag::Lis, 8),  //7
            (Tag::Arg, 1),  //8
            (Tag::Lis, 10), //9
            (Tag::Arg, 2),  //10
            (Tag::ELis, 0), //11
            (Tag::Ref, 12), //12
            (Tag::Lis, 14), //13
            (Tag::Ref, 14), //14
            (Tag::Lis, 16), //15
            (Tag::Ref, 16), //16
            (Tag::ELis, 0), //17
        ];

        let subst = unify(&heap, 0, 3).unwrap();
        assert_eq!(subst.get_arg(0), Some(12));
        assert_eq!(subst.get_arg(1), Some(14));
        assert_eq!(subst.get_arg(2), Some(16));

        // Lists of lists, with a reference standing in for a goal tail.
        let heap = vec![
            (Tag::Lis, 1),  //0
            (Tag::Lis, 12), //1
            (Tag::Lis, 3),  //2
            (Tag::Lis, 14), //3
            (Tag::Lis, 5),  //4
            (Tag::Lis, 16), //5
            (Tag::ELis, 0), //6
            (Tag::Lis, 8),  //7
            (Tag::Lis, 18), //8
            (Tag::Lis, 10), //9
            (Tag::Lis, 20), //10
            (Tag::Ref, 11), //11
            (Tag::Con, a),  //12
            (Tag::ELis, 0), //13
            (Tag::Arg, 0),  //14
            (Tag::ELis, 0), //15
            (Tag::Con, c),  //16
            (Tag::ELis, 0), //17
            (Tag::Con, a),  //18
            (Tag::ELis, 0), //19
            (Tag::Con, b),  //20
            (Tag::ELis, 0), //21
        ];

        let subst = unify(&heap, 0, 7).unwrap();
        assert_eq!(subst.get_arg(0), Some(20));
        assert_eq!(subst.bound(11), Some(4));
    }

    #[test]
    fn integers() {
        let prev = SymbolDB::set_const("prev".to_string());
        let prog = vec![
            (Tag::Func, 3),
            (Tag::Con, prev),
            (Tag::Int, 4),
            (Tag::Int, 3),
        ];
        let mut query_heap = QueryHeap::new(&prog, None);
        // Guards against comparing numbers before the cells are dereferenced.
        query_heap.cells.extend(vec![
            (Tag::Func, 3),
            (Tag::Ref, 5),
            (Tag::Int, 4),
            (Tag::Ref, 7),
        ]);

        let subst = unify(&query_heap, 0, 4).unwrap();
        assert_eq!(subst.bound(5), Some(1));
        assert_eq!(subst.bound(7), Some(3));
    }
}