Skip to main content

prolog2/program/
predicate_table.rs

1use std::{
2    cmp::Ordering,
3    ops::{Deref, DerefMut},
4};
5
6use crate::predicate_modules::PredicateFunction;
7
8use super::clause::Clause;
9
10/// A `(symbol_id, arity)` pair identifying a predicate.
11pub(crate) type SymbolArity = (usize, usize);
12
13/// A predicate is either a set of compiled clauses or a built-in function.
14#[allow(unpredictable_function_pointer_comparisons)]
15#[derive(PartialEq, Eq, Debug, Clone)]
16pub enum Predicate {
17    /// A native Rust predicate function.
18    Function(PredicateFunction),
19    /// One or more compiled Prolog clauses.
20    Clauses(Box<[Clause]>),
21}
22
23/// Internal entry in the predicate table.
24#[derive(PartialEq, Eq, Debug)]
25pub struct PredicateEntry {
26    symbol_arity: SymbolArity,
27    predicate: Predicate,
28}
29
30/// The program's predicate table.
31///
32/// Maps `(symbol, arity)` pairs to predicates (clause sets or built-in functions).
33/// Also tracks which predicates are designated as body predicates for MIL learning.
34#[derive(Debug, PartialEq)]
35pub struct PredicateTable {
36    predicates: Vec<PredicateEntry>,
37    body_list: Vec<usize>,
38}
39
40//Return type for binary search of predicate keys
41#[derive(Debug, PartialEq, Eq)]
42enum FindReturn {
43    Index(usize),
44    InsertPos(usize),
45}
46
47impl PredicateTable {
48    pub fn new() -> Self {
49        PredicateTable {
50            predicates: vec![],
51            body_list: vec![],
52        }
53    }
54
55    //Performs a binary search of the ordered predicate table.
56    fn find_predicate(&self, symbol_arity: SymbolArity) -> FindReturn {
57        let mut lb: usize = 0;
58        let mut ub: usize = self.len();
59        let mut mid: usize;
60
61        while ub > lb {
62            mid = (lb + ub) / 2;
63            match symbol_arity.cmp(&self[mid].symbol_arity) {
64                Ordering::Less => ub = mid,
65                Ordering::Equal => return FindReturn::Index(mid),
66                Ordering::Greater => lb = mid + 1,
67            }
68        }
69        FindReturn::InsertPos(lb)
70    }
71
72    //Inserts a new predicate function to the table
73    pub fn insert_predicate_function(
74        &mut self,
75        symbol_arity: SymbolArity,
76        predicate_fn: PredicateFunction,
77    ) -> Result<(), &str> {
78        match self.find_predicate(symbol_arity) {
79            FindReturn::Index(idx) => match &mut self[idx].predicate {
80                Predicate::Function(old_predicate_fn) => {
81                    *old_predicate_fn = predicate_fn;
82                    Ok(())
83                }
84                _ => Err("Cannot insert predicate function to clause predicate"),
85            },
86            FindReturn::InsertPos(insert_idx) => {
87                self.insert(
88                    insert_idx,
89                    PredicateEntry {
90                        symbol_arity,
91                        predicate: Predicate::Function(predicate_fn),
92                    },
93                );
94                Ok(())
95            }
96        }
97    }
98
99    //Adds a clause to an existing enrty or creates a new entry with a single clause
100    pub fn add_clause_to_predicate(
101        &mut self,
102        clause: Clause,
103        symbol_arity: SymbolArity,
104    ) -> Result<(), &str> {
105        match self.find_predicate(symbol_arity) {
106            FindReturn::Index(idx) => match &mut self.get_mut(idx).unwrap().predicate {
107                Predicate::Function(_) => return Err("Cannot add clause to function predicate"),
108                Predicate::Clauses(clauses) => {
109                    *clauses = [&**clauses, &[clause]].concat().into_boxed_slice();
110                }
111            },
112            FindReturn::InsertPos(insert_idx) => {
113                self.insert(
114                    insert_idx,
115                    PredicateEntry {
116                        symbol_arity,
117                        predicate: Predicate::Clauses(Box::new([clause])),
118                    },
119                );
120            }
121        };
122        Ok(())
123    }
124
125    //Get predicate by SymbolArity key
126    pub fn get_predicate(&self, symbol_arity: SymbolArity) -> Option<&Predicate> {
127        match self.find_predicate(symbol_arity) {
128            FindReturn::Index(i) => Some(&self[i].predicate),
129            FindReturn::InsertPos(_) => None,
130        }
131    }
132
133    pub fn get_variable_clauses(&self, arity: usize) -> Option<&Box<[Clause]>> {
134        match self.find_predicate((0, arity)) {
135            FindReturn::Index(i) => match &self[i].predicate {
136                Predicate::Clauses(clauses) => Some(clauses),
137                _ => None,
138            },
139            _ => None,
140        }
141    }
142
143    //Remove predicate by SymbolArity key, if clause predicate return the range to remove from clause table
144    pub fn _remove_predicate(&mut self, symbol_arity: SymbolArity) {
145        if let FindReturn::Index(predicate_idx) = self.find_predicate(symbol_arity) {
146            if let Predicate::Clauses(_clauses) = self.remove(predicate_idx).predicate {
147                self.body_list.retain(|i| *i != predicate_idx);
148            }
149            for i in &mut self.body_list {
150                if *i > predicate_idx {
151                    println!("{i}");
152                    *i -= 1;
153                }
154            }
155        }
156    }
157
158    //Remove or add entry index from the body predicate list
159    pub fn set_body(&mut self, symbol_arity: SymbolArity, value: bool) -> Result<(), &str> {
160        match self.find_predicate(symbol_arity) {
161            FindReturn::Index(idx) => {
162                let predicate = &mut self[idx];
163                if matches!(predicate.predicate, Predicate::Function(_)) {
164                    Err("Can't set predicate function to body")
165                } else {
166                    if value == false {
167                        self.body_list.retain(|&idx2| idx != idx2);
168                    } else {
169                        self.body_list.push(idx);
170                    }
171                    Ok(())
172                }
173            }
174            _ => Ok(()), //Err("Can't set non existing predicate to body"),
175        }
176    }
177
178    //Get all clause index ranges from entry indexes in the body_list
179    pub fn get_body_clauses(&self, arity: usize) -> impl Iterator<Item = &Clause> {
180        self.body_list
181            .iter()
182            .filter_map(move |&idx| {
183                if self[idx].symbol_arity.1 != arity {
184                    return None;
185                }
186                if let Predicate::Clauses(pred_clauses) = &self[idx].predicate {
187                    Some(pred_clauses.iter())
188                } else {
189                    None
190                }
191            })
192            .flatten()
193    }
194}
195
196impl Deref for PredicateTable {
197    type Target = Vec<PredicateEntry>;
198
199    fn deref(&self) -> &Self::Target {
200        &self.predicates
201    }
202}
203
204impl DerefMut for PredicateTable {
205    fn deref_mut(&mut self) -> &mut Self::Target {
206        &mut self.predicates
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::{super::clause::Clause, Predicate, PredicateEntry, PredicateTable};
    use crate::{
        heap::{query_heap::QueryHeap, symbol_db::SymbolDB},
        predicate_modules::PredReturn,
        program::{hypothesis::Hypothesis, predicate_table::FindReturn},
        Config,
    };

    /// Native predicate used wherever a `Predicate::Function` entry is needed.
    fn pred_fn_placeholder(
        _heap: &mut QueryHeap,
        _hypothesis: &mut Hypothesis,
        _goal: usize,
        _predicate_table: &PredicateTable,
        _config: Config,
    ) -> PredReturn {
        PredReturn::True
    }

    /// Build a test predicate table with four entries sorted by symbol_arity.
    /// Returns (table, p, q, pred_func) so tests can use the actual symbol IDs.
    fn setup() -> (PredicateTable, usize, usize, usize) {
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        let p_entry = PredicateEntry {
            symbol_arity: (p, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])),
        };
        let q_entry = PredicateEntry {
            symbol_arity: (q, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![31, 35], None, None),
                Clause::new(vec![39, 43], None, None),
            ])),
        };
        let func_entry = PredicateEntry {
            symbol_arity: (pred_func, 2),
            predicate: Predicate::Function(pred_fn_placeholder),
        };
        let zero_entry = PredicateEntry {
            symbol_arity: (0, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                Clause::new(vec![7, 11], Some(vec![0]), None),
            ])),
        };

        let mut predicates = vec![zero_entry, p_entry, q_entry, func_entry];
        predicates.sort_by_key(|e| e.symbol_arity);

        // body_list should point to the index of (p, 2) after sorting
        let p_idx = predicates
            .iter()
            .position(|e| e.symbol_arity == (p, 2))
            .unwrap();

        (
            PredicateTable {
                predicates,
                body_list: vec![p_idx],
            },
            p,
            q,
            pred_func,
        )
    }

    #[test]
    fn find_predicate() {
        let (pred_table, p, _q, _pred_func) = setup();

        let symbol = SymbolDB::set_const("find_predicate_test_symbol".into());
        let p_idx = pred_table
            .iter()
            .position(|e| e.symbol_arity == (p, 2))
            .unwrap();

        assert_eq!(pred_table.find_predicate((0, 1)), FindReturn::InsertPos(0));
        assert_eq!(pred_table.find_predicate((p, 2)), FindReturn::Index(p_idx));

        // An absent key must report the first position holding a larger key (or the end),
        // whatever id the symbol database assigned to the new symbol.
        let expected_pos = pred_table
            .iter()
            .position(|e| e.symbol_arity > (symbol, 2))
            .unwrap_or(pred_table.len());
        assert_eq!(
            pred_table.find_predicate((symbol, 2)),
            FindReturn::InsertPos(expected_pos)
        );

        // Same symbol, different arity should get an insert position
        assert_eq!(
            pred_table.find_predicate((p, 1)),
            FindReturn::InsertPos(p_idx)
        );

        let pred_table = PredicateTable {
            predicates: vec![],
            body_list: vec![],
        };

        assert_eq!(pred_table.find_predicate((50, 2)), FindReturn::InsertPos(0));
    }

    #[test]
    fn get_predicate() {
        let (pred_table, p, _q, _pred_func) = setup();

        assert_eq!(pred_table.get_predicate((p, 3)), None);
        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(&Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])))
        );
    }

    #[test]
    fn insert_predicate_function() {
        let (mut pred_table, p, _q, pred_func) = setup();

        assert_eq!(
            pred_table.insert_predicate_function((p, 2), pred_fn_placeholder),
            Err("Cannot insert predicate function to clause predicate")
        );

        pred_table
            .insert_predicate_function((pred_func, 3), pred_fn_placeholder)
            .unwrap();
        assert_eq!(
            pred_table.get_predicate((pred_func, 3)),
            Some(&Predicate::Function(pred_fn_placeholder))
        );
    }

    #[test]
    fn add_clause_to_predicate() {
        let (mut pred_table, p, _q, pred_func) = setup();
        let r = SymbolDB::set_const("r".into());

        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (p, 2))
            .unwrap();
        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (r, 2))
            .unwrap();
        assert_eq!(
            pred_table
                .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (pred_func, 2)),
            Err("Cannot add clause to function predicate")
        );

        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(&Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![], Some(vec![]), None)
            ])))
        );
        assert_eq!(
            pred_table.get_predicate((r, 2)),
            Some(&Predicate::Clauses(Box::new([Clause::new(
                vec![],
                Some(vec![]),
                None
            )])))
        );
    }

    #[test]
    fn remove_predicate() {
        // Test removing p
        let (mut pred_table, p, _q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((p, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((p, 2)), None);
        // p was the only body predicate, so removing it must empty body_list.
        assert!(pred_table.body_list.is_empty());

        // Test removing q
        let (mut pred_table, _p, q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((q, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((q, 2)), None);

        // Test removing entry at symbol 0
        let (mut pred_table, p, q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((0, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((0, 2)), None);
        // p and q should still be present
        assert!(pred_table.get_predicate((p, 2)).is_some());
        assert!(pred_table.get_predicate((q, 2)).is_some());
        // body_list must still refer to p after the entries around it shift.
        let p_idx = pred_table
            .iter()
            .position(|e| e.symbol_arity == (p, 2))
            .unwrap();
        assert_eq!(pred_table.body_list, [p_idx]);
    }

    #[test]
    fn set_body() {
        let (mut pred_table, p, q, _pred_func) = setup();

        // Remove p from body list
        pred_table.set_body((p, 2), false).unwrap();
        // Add q to body list
        pred_table.set_body((q, 2), true).unwrap();

        let q_idx = pred_table
            .iter()
            .position(|e| e.symbol_arity == (q, 2))
            .unwrap();
        assert_eq!(pred_table.body_list, [q_idx]);
    }

    #[test]
    fn get_body_clauses() {
        let (mut pred_table, _p, q, _pred_func) = setup();

        // No body clauses for arity 1
        let empty: Vec<&Clause> = pred_table.get_body_clauses(1).collect();
        assert!(empty.is_empty());

        // Initially p is the body predicate (set in setup)
        let body2: Vec<&Clause> = pred_table.get_body_clauses(2).collect();
        assert_eq!(
            body2,
            vec![
                &Clause::new(vec![15, 19], None, None),
                &Clause::new(vec![23, 27], None, None),
            ]
        );

        // Add q as body predicate too
        pred_table.set_body((q, 2), true).unwrap();

        let body2_ext: Vec<&Clause> = pred_table.get_body_clauses(2).collect();
        assert_eq!(body2_ext.len(), 4);
        // Should contain both p's and q's clauses
        assert!(body2_ext.contains(&&Clause::new(vec![15, 19], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![23, 27], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![31, 35], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![39, 43], None, None)));
    }
}