1use std::{
4 mem::MaybeUninit,
5 ops::{Deref, DerefMut},
6 usize,
7};
8
9use crate::heap::heap::{Cell, Heap, Tag};
10
/// Bindings accumulated while unifying two heap terms.
///
/// Storage is two fixed-size arrays: 32 argument registers and room for 32
/// `(from, to, flag)` bindings. Only the first `binding_len` binding slots are live.
/// `unify_rec` sets `flag` when the bound-to side is a `Lis`, `Func`, `Set` or `Tup` cell.
#[derive(Debug, PartialEq)]
pub struct Substitution {
    /// Argument registers; `usize::MAX` marks a register that has not been assigned.
    arg_regs: [usize; 32],
    /// Backing storage for bindings; entries at index `binding_len` and beyond are unused.
    binding_array: [(usize, usize, bool); 32],
    /// Number of live entries at the front of `binding_array`.
    binding_len: usize,
}
21
22impl Deref for Substitution {
23 type Target = [(usize, usize, bool)];
24 fn deref(&self) -> &Self::Target {
25 &self.binding_array[..self.binding_len]
26 }
27}
28
29impl DerefMut for Substitution {
30 fn deref_mut(&mut self) -> &mut Self::Target {
31 &mut self.binding_array[..self.binding_len]
32 }
33}
34
35impl Default for Substitution {
36 fn default() -> Self {
37 Self {
38 arg_regs: [usize::MAX; 32],
39 binding_array: Default::default(),
40 binding_len: Default::default(),
41 }
42 }
43}
44
45impl Substitution {
46 pub fn bound(&self, addr: usize) -> Option<usize> {
47 match self.iter().find(|(a1, _, _)| *a1 == addr) {
49 Some((_, a2, _)) => match self.bound(*a2) {
50 Some(a2) => Some(a2),
51 None => Some(*a2),
52 },
53 None => None,
54 }
55 }
56
57 pub fn push(mut self, binding: (usize, usize, bool)) -> Self {
58 self.binding_array[self.binding_len] = binding;
59 self.binding_len += 1;
60 self
61 }
62
63 pub fn get_arg(&self, arg_idx: usize) -> Option<usize> {
64 if self.arg_regs[arg_idx] == usize::MAX {
65 None
66 } else {
67 Some(self.arg_regs[arg_idx])
68 }
69 }
70
71 pub fn set_arg(&mut self, arg_idx: usize, addr: usize) {
72 self.arg_regs[arg_idx] = addr;
73 }
74
75 pub fn get_bindings(&self) -> Box<[(usize, usize)]> {
76 let mut bindings = Vec::<(usize, usize)>::with_capacity(self.binding_len);
77 for i in 0..self.binding_len {
78 bindings.push((self.binding_array[i].0, self.binding_array[i].1));
79 }
80 bindings.into_boxed_slice()
81 }
82
83 pub(crate) fn full_deref(&self, mut addr: usize, heap: &impl Heap) -> usize {
85 loop {
86 let heap_deref = heap.deref_addr(addr);
88
89 match self.bound(heap_deref) {
91 Some(bound_to) => {
92 let next = heap.deref_addr(bound_to);
93 if next == heap_deref {
94 return heap_deref;
95 }
96 addr = next;
97 }
98 None => return heap_deref,
99 }
100 }
101 }
102
103 pub fn check_constraints(&self, constraints: &[usize], heap: &impl Heap) -> bool {
114 const STACK_CAP: usize = 8;
115 let len = constraints.len();
116
117 if len <= STACK_CAP {
118 let mut buf: [MaybeUninit<(usize, Cell)>; STACK_CAP] =
120 unsafe { MaybeUninit::uninit().assume_init() };
121 for i in 0..len {
122 let addr = constraints[i];
123 let cell = heap[self.full_deref(addr, heap)];
124 buf[i] = MaybeUninit::new((addr, cell));
125 }
126 for i in 0..len {
127 let (addr_i, cell_i) = unsafe { buf[i].assume_init() };
128 for j in (i + 1)..len {
129 let (addr_j, cell_j) = unsafe { buf[j].assume_init() };
130 if addr_i != addr_j && cell_i == cell_j {
131 return false;
132 }
133 }
134 }
135 true
136 } else {
137 let targets: Vec<(usize, Cell)> = constraints
139 .iter()
140 .map(|&addr| (addr, heap[self.full_deref(addr, heap)]))
141 .collect();
142 for i in 0..targets.len() {
143 for j in (i + 1)..targets.len() {
144 if targets[i].0 != targets[j].0 && targets[i].1 == targets[j].1 {
145 return false;
146 }
147 }
148 }
149 true
150 }
151 }
152}
153
154pub fn unify(heap: &impl Heap, addr_1: usize, addr_2: usize) -> Option<Substitution> {
156 unify_rec(heap, Substitution::default(), addr_1, addr_2)
157}
158
159fn unify_rec(
163 heap: &impl Heap,
164 mut binding: Substitution,
165 mut addr_1: usize,
166 mut addr_2: usize,
167) -> Option<Substitution> {
168 addr_1 = heap.deref_addr(addr_1);
169 addr_2 = heap.deref_addr(addr_2);
170 if heap[addr_1].0 == Tag::Ref {
171 if let Some(addr) = binding.bound(addr_1) {
172 addr_1 = addr;
173 }
174 }
175 if heap[addr_2].0 == Tag::Ref {
176 if let Some(addr) = binding.bound(addr_2) {
177 addr_2 = addr;
178 }
179 }
180
181 if addr_1 == addr_2 {
182 return Some(binding);
183 }
184
185 match (heap[addr_1].0, heap[addr_2].0) {
186 (Tag::Str, Tag::Str) => unify_rec(heap, binding, heap[addr_1].1, heap[addr_2].1),
187 (_, Tag::Str) => unify_rec(heap, binding, addr_1, heap[addr_2].1),
188 (Tag::Str, _) => unify_rec(heap, binding, heap[addr_1].1, addr_2),
189 (_, Tag::Arg) => panic!("Undefined Unification behaviour"),
190 (Tag::Arg, _) => match binding.get_arg(heap[addr_1].1) {
191 Some(addr) => unify_rec(heap, binding, addr, addr_2),
192 None => {
193 binding.set_arg(heap[addr_1].1, addr_2);
194 Some(binding)
195 }
196 },
197 (Tag::Ref, Tag::Lis | Tag::Func | Tag::Set | Tag::Tup) => {
198 Some(binding.push((addr_1, addr_2, true)))
199 }
200 (Tag::Ref, _) => Some(binding.push((addr_1, addr_2, false))),
201 (Tag::Lis | Tag::Func | Tag::Set | Tag::Tup, Tag::Ref) => {
202 Some(binding.push((addr_2, addr_1, true)))
203 }
204 (_, Tag::Ref) => Some(binding.push((addr_2, addr_1, false))),
205 (Tag::Con, Tag::Con) | (Tag::Int, Tag::Int) | (Tag::Flt, Tag::Flt)
206 if heap[addr_1].1 == heap[addr_2].1 =>
207 {
208 Some(binding)
209 }
210 (Tag::Func, Tag::Func) | (Tag::Tup, Tag::Tup) => {
211 unify_func_or_tup(heap, binding, addr_1, addr_2)
212 }
213 (Tag::Set, Tag::Set) => unfiy_set(heap, binding, addr_1, addr_2),
214 (Tag::Lis, Tag::Lis) => unify_list(heap, binding, addr_1, addr_2),
215 (Tag::ELis, Tag::ELis) => Some(binding),
216 _ => None,
217 }
218}
219
220fn unify_func_or_tup(
221 heap: &impl Heap,
222 mut binding: Substitution,
223 addr_1: usize,
224 addr_2: usize,
225) -> Option<Substitution> {
226 if heap[addr_1].1 != heap[addr_2].1 {
227 return None;
228 };
229
230 for i in 1..heap[addr_1].1 + 1 {
231 binding = unify_rec(heap, binding, addr_1 + i, addr_2 + i)?;
232 }
233
234 Some(binding)
235}
236
237fn unfiy_set(
238 _heap: &impl Heap,
239 mut _binding: Substitution,
240 _addr_1: usize,
241 _addr_2: usize,
242) -> Option<Substitution> {
243 unimplemented!("set unification not yet supported")
244}
245
246fn unify_list(
248 heap: &impl Heap,
249 mut binding: Substitution,
250 addr_1: usize,
251 addr_2: usize,
252) -> Option<Substitution> {
253 let addr_1 = heap[addr_1].1;
255 let addr_2 = heap[addr_2].1;
256 binding = unify_rec(heap, binding, addr_1, addr_2)?;
257 unify_rec(heap, binding, addr_1 + 1, addr_2 + 1)
258}
259
#[cfg(test)]
mod tests {
    use super::Substitution;
    use crate::{
        heap::{heap::Tag, query_heap::QueryHeap, symbol_db::SymbolDB},
        resolution::unification::{unify, unify_rec},
    };

    /// An `Arg` cell fills its empty register on first use. Once the register is set,
    /// later unifications go through the register's current value.
    #[test]
    fn arg_to_ref() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Arg, 0),  // 0: argument register 0
            (Tag::Ref, 1),  // 1: self-referencing Ref
            (Tag::Ref, 2),  // 2: self-referencing Ref
            (Tag::Str, 4),  // 3: points at the Func at 4
            (Tag::Func, 2), // 4: Func with elements at 5..=6
            (Tag::Con, p),  // 5
            (Tag::Con, a),  // 6
        ];

        // Register 0 starts empty, so it is assigned the Ref at 1.
        let mut binding = unify(&heap, 0, 1).unwrap();
        assert_eq!(binding.arg_regs[0], 1);
        assert_eq!(binding.arg_regs[1..32], [usize::MAX; 31]);

        // Register 0 now holds 1, so unifying with 2 binds 1 to 2.
        binding = unify_rec(&heap, binding, 0, 2).unwrap();
        assert_eq!(binding.arg_regs[0], 1);
        assert_eq!(binding.arg_regs[1..32], [usize::MAX; 31]);
        assert_eq!(binding.bound(1), Some(2));

        // The scenarios below clear the bindings and preload register 0.
        // Register holds the Str at 3: the Ref is bound to the Func it points at.
        binding.binding_array[0] = (0, 0, false);
        binding.binding_len = 0;
        binding.arg_regs[0] = 3;
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(4));

        // Register holds the Func itself.
        binding.binding_array[0] = (0, 0, false);
        binding.binding_len = 0;
        binding.arg_regs[0] = 4;
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(4));

        // Register holds a constant.
        binding.binding_array[0] = (0, 0, false);
        binding.binding_len = 0;
        binding.arg_regs[0] = 5;
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(1), Some(5));
    }

    /// An empty register unified with a `Str` cell is assigned the cell the `Str` points at.
    #[test]
    fn arg() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Arg, 0),  // 0: argument register 0
            (Tag::Str, 2),  // 1: points at the Func at 2
            (Tag::Func, 2), // 2: Func with elements at 3..=4
            (Tag::Con, p),  // 3
            (Tag::Con, a),  // 4
        ];

        let binding = unify(&heap, 0, 1).unwrap();
        assert_eq!(binding.get_arg(0), Some(2));
    }

    /// Existing bindings are followed before a new binding is recorded.
    #[test]
    fn binding_chain_ref() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Ref, 0),  // 0: self-referencing Ref
            (Tag::Ref, 1),  // 1: self-referencing Ref
            (Tag::Ref, 2),  // 2: self-referencing Ref
            (Tag::Str, 4),  // 3: points at the Func at 4
            (Tag::Func, 2), // 4: Func with elements at 5..=6
            (Tag::Con, p),  // 5
            (Tag::Con, a),  // 6
        ];

        // 1 is bound to another Ref: 0 ends up bound to that Ref.
        let mut binding = Substitution::default();
        binding = binding.push((1, 2, false));

        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(2));

        // 1 is bound to a Str: 0 ends up bound to the Func it points at.
        let mut binding = Substitution::default();
        binding = binding.push((1, 3, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(4));

        // 1 is bound to the Func directly.
        let mut binding = Substitution::default();
        binding = binding.push((1, 4, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(4));

        // 1 is bound to a constant.
        let mut binding = Substitution::default();
        binding = binding.push((1, 5, false));
        binding = unify_rec(&heap, binding, 0, 1).unwrap();
        assert_eq!(binding.bound(0), Some(5));
    }

    /// A `Func` does not unify with a `Tup`, a constant or a list, but does bind a `Ref`.
    #[test]
    fn func() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Func, 2), // 0: Func with elements at 1..=2
            (Tag::Con, p),  // 1
            (Tag::Con, a),  // 2
            (Tag::Tup, 2),  // 3: Tup with elements at 4..=5
            (Tag::Con, p),  // 4
            (Tag::Con, a),  // 5
            (Tag::Ref, 6),  // 6: self-referencing Ref
            (Tag::Lis, 8),  // 7: list with head at 8
            (Tag::Con, p),  // 8
            (Tag::ELis, 0), // 9
        ];

        assert_eq!(unify(&heap, 0, 3), None);
        assert_eq!(unify(&heap, 0, 4), None);
        let binding = unify(&heap, 0, 6).unwrap();
        assert_eq!(binding.bound(6), Some(0));
        assert_eq!(unify(&heap, 0, 7), None);
    }

    /// A `Tup` does not unify with a `Func`, a constant or a list, but does bind a `Ref`.
    #[test]
    fn tup() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());

        let heap = vec![
            (Tag::Tup, 2),  // 0: Tup with elements at 1..=2
            (Tag::Con, p),  // 1
            (Tag::Con, a),  // 2
            (Tag::Func, 2), // 3: Func with elements at 4..=5
            (Tag::Con, p),  // 4
            (Tag::Con, a),  // 5
            (Tag::Ref, 6),  // 6: self-referencing Ref
            (Tag::Lis, 8),  // 7: list with head at 8
            (Tag::Con, p),  // 8
            (Tag::ELis, 0), // 9
        ];

        assert_eq!(unify(&heap, 0, 3), None);
        assert_eq!(unify(&heap, 0, 4), None);
        let binding = unify(&heap, 0, 6).unwrap();
        assert_eq!(binding.bound(6), Some(0));
        assert_eq!(unify(&heap, 0, 7), None);
    }

    /// List unification: heads first, then tails, through `Ref`, `Arg`, nested and
    /// functor-embedded lists.
    #[test]
    fn list() {
        let p = SymbolDB::set_const("p".into());
        let a = SymbolDB::set_const("a".into());
        let b = SymbolDB::set_const("b".into());
        let c = SymbolDB::set_const("c".into());
        let t = SymbolDB::set_const("t".into());

        // [a, b, c] against [a, X, Y].
        let heap = vec![
            (Tag::Lis, 1),  // 0: first list
            (Tag::Con, a),  // 1
            (Tag::Lis, 3),  // 2
            (Tag::Con, b),  // 3
            (Tag::Lis, 5),  // 4
            (Tag::Con, c),  // 5
            (Tag::ELis, 0), // 6
            (Tag::Lis, 8),  // 7: second list
            (Tag::Con, a),  // 8
            (Tag::Lis, 10), // 9
            (Tag::Ref, 10), // 10: X
            (Tag::Lis, 12), // 11
            (Tag::Ref, 12), // 12: Y
            (Tag::ELis, 0), // 13
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.bound(10), Some(3));
        assert_eq!(binding.bound(12), Some(5));

        // [A0, A1, A2 | t] against [a, b, Y | t].
        let heap = vec![
            (Tag::Lis, 1),  // 0: first list
            (Tag::Arg, 0),  // 1
            (Tag::Lis, 3),  // 2
            (Tag::Arg, 1),  // 3
            (Tag::Lis, 5),  // 4
            (Tag::Arg, 2),  // 5
            (Tag::Con, t),  // 6: tail
            (Tag::Lis, 8),  // 7: second list
            (Tag::Con, a),  // 8
            (Tag::Lis, 10), // 9
            (Tag::Con, b),  // 10
            (Tag::Lis, 12), // 11
            (Tag::Ref, 12), // 12: Y
            (Tag::Con, t),  // 13: tail
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(8));
        assert_eq!(binding.get_arg(1), Some(10));
        assert_eq!(binding.get_arg(2), Some(12));

        // [A0, A1, A2 | A3] against [X, Y, Z | W].
        let heap = vec![
            (Tag::Lis, 1),  // 0: first list
            (Tag::Arg, 0),  // 1
            (Tag::Lis, 3),  // 2
            (Tag::Arg, 1),  // 3
            (Tag::Lis, 5),  // 4
            (Tag::Arg, 2),  // 5
            (Tag::Arg, 3),  // 6: tail
            (Tag::Lis, 8),  // 7: second list
            (Tag::Ref, 8),  // 8: X
            (Tag::Lis, 10), // 9
            (Tag::Ref, 10), // 10: Y
            (Tag::Lis, 12), // 11
            (Tag::Ref, 12), // 12: Z
            (Tag::Ref, 13), // 13: W (tail)
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(8));
        assert_eq!(binding.get_arg(1), Some(10));
        assert_eq!(binding.get_arg(2), Some(12));
        assert_eq!(binding.get_arg(3), Some(13));

        // Func(p, [A0, A1, A2]) against Func(p, [X, Y, Z]).
        let heap = vec![
            (Tag::Func, 2), // 0: first Func, elements at 1..=2
            (Tag::Con, p),  // 1
            (Tag::Lis, 6),  // 2: list with head at 6
            (Tag::Func, 2), // 3: second Func, elements at 4..=5
            (Tag::Con, p),  // 4
            (Tag::Lis, 12), // 5: list with head at 12
            (Tag::Arg, 0),  // 6
            (Tag::Lis, 8),  // 7
            (Tag::Arg, 1),  // 8
            (Tag::Lis, 10), // 9
            (Tag::Arg, 2),  // 10
            (Tag::ELis, 0), // 11
            (Tag::Ref, 12), // 12: X
            (Tag::Lis, 14), // 13
            (Tag::Ref, 14), // 14: Y
            (Tag::Lis, 16), // 15
            (Tag::Ref, 16), // 16: Z
            (Tag::ELis, 0), // 17
        ];

        let binding = unify(&heap, 0, 3).unwrap();
        assert_eq!(binding.get_arg(0), Some(12));
        assert_eq!(binding.get_arg(1), Some(14));
        assert_eq!(binding.get_arg(2), Some(16));

        // [[a], [A0], [c]] against [[a], [b] | T].
        let heap = vec![
            (Tag::Lis, 1),  // 0: first outer list
            (Tag::Lis, 12), // 1: [a]
            (Tag::Lis, 3),  // 2
            (Tag::Lis, 14), // 3: [A0]
            (Tag::Lis, 5),  // 4
            (Tag::Lis, 16), // 5: [c]
            (Tag::ELis, 0), // 6
            (Tag::Lis, 8),  // 7: second outer list
            (Tag::Lis, 18), // 8: [a]
            (Tag::Lis, 10), // 9
            (Tag::Lis, 20), // 10: [b]
            (Tag::Ref, 11), // 11: T (tail)
            (Tag::Con, a),  // 12
            (Tag::ELis, 0), // 13
            (Tag::Arg, 0),  // 14
            (Tag::ELis, 0), // 15
            (Tag::Con, c),  // 16
            (Tag::ELis, 0), // 17
            (Tag::Con, a),  // 18
            (Tag::ELis, 0), // 19
            (Tag::Con, b),  // 20
            (Tag::ELis, 0), // 21
        ];

        let binding = unify(&heap, 0, 7).unwrap();
        assert_eq!(binding.get_arg(0), Some(20));
        assert_eq!(binding.bound(11), Some(4));
    }

    /// Equal `Int` cells unify. `Ref` cells in a query term bind to the matching
    /// program cells.
    #[test]
    fn integers() {
        let prev = SymbolDB::set_const("prev".to_string());
        let prog = vec![
            (Tag::Func, 3), // 0: Func with elements at 1..=3
            (Tag::Con, prev),
            (Tag::Int, 4),
            (Tag::Int, 3),
        ];
        let mut heap = QueryHeap::new(&prog, None);
        // Assumes query cells are addressed right after the 4 program cells (from 4),
        // as the addresses used below expect.
        heap.cells.extend(vec![
            (Tag::Func, 3), // 4: Func with elements at 5..=7
            (Tag::Ref, 5),
            (Tag::Int, 4),
            (Tag::Ref, 7),
        ]);

        let binding = unify(&heap, 0, 4).unwrap();
        assert_eq!(binding.bound(5), Some(1));
        assert_eq!(binding.bound(7), Some(3));
    }
}