Skip to main content

prolog2/program/
predicate_table.rs

1use std::{
2    cmp::Ordering,
3    ops::{Deref, DerefMut},
4};
5
6use crate::predicate_modules::PredicateFunction;
7
8use super::clause::Clause;
9
/// A `(symbol_id, arity)` pair identifying a predicate.
///
/// Tuples compare lexicographically (symbol id first, then arity); the
/// predicate table's binary search compares keys in exactly that order.
pub(crate) type SymbolArity = (usize, usize);
12
/// A predicate is either a set of compiled clauses or a built-in function.
///
/// The two kinds are never mixed under one key: the predicate table refuses
/// to add a clause to a `Function` entry and refuses to install a function
/// over a `Clauses` entry.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum Predicate {
    /// A native Rust predicate function.
    Function(PredicateFunction),
    /// One or more compiled Prolog clauses.
    Clauses(Box<[Clause]>),
}
21
/// Internal entry in the predicate table.
#[derive(PartialEq, Eq, Debug)]
pub struct PredicateEntry {
    /// Key of this entry; the table is searched by this value.
    symbol_arity: SymbolArity,
    /// The clause set or built-in function registered under the key.
    predicate: Predicate,
}
28
/// The program's predicate table.
///
/// Maps `(symbol, arity)` pairs to predicates (clause sets or built-in functions).
/// Also tracks which predicates are designated as body predicates for MIL learning.
#[derive(Debug, PartialEq)]
pub struct PredicateTable {
    /// Entries kept in ascending `symbol_arity` order so lookups can binary-search.
    predicates: Vec<PredicateEntry>,
    /// Indices into `predicates` of the entries designated as body predicates.
    body_list: Vec<usize>,
}
38
/// Return type for the binary search of predicate keys.
#[derive(Debug, PartialEq, Eq)]
enum FindReturn {
    /// The key is present; holds the index of its entry.
    Index(usize),
    /// The key is absent; holds the index at which a new entry has to be
    /// inserted to keep the table ordered.
    InsertPos(usize),
}
45
46impl PredicateTable {
47    pub fn new() -> Self {
48        PredicateTable {
49            predicates: vec![],
50            body_list: vec![],
51        }
52    }
53
54    //Performs a binary search of the ordered predicate table.
55    fn find_predicate(&self, symbol_arity: SymbolArity) -> FindReturn {
56        let mut lb: usize = 0;
57        let mut ub: usize = self.len();
58        let mut mid: usize;
59
60        while ub > lb {
61            mid = (lb + ub) / 2;
62            match symbol_arity.cmp(&self[mid].symbol_arity) {
63                Ordering::Less => ub = mid,
64                Ordering::Equal => return FindReturn::Index(mid),
65                Ordering::Greater => lb = mid + 1,
66            }
67        }
68        FindReturn::InsertPos(lb)
69    }
70
71    //Inserts a new predicate function to the table
72    pub fn insert_predicate_function(
73        &mut self,
74        symbol_arity: SymbolArity,
75        predicate_fn: PredicateFunction,
76    ) -> Result<(), &str> {
77        match self.find_predicate(symbol_arity) {
78            FindReturn::Index(idx) => match &mut self[idx].predicate {
79                Predicate::Function(old_predicate_fn) => {
80                    *old_predicate_fn = predicate_fn;
81                    Ok(())
82                }
83                _ => Err("Cannot insert predicate function to clause predicate"),
84            },
85            FindReturn::InsertPos(insert_idx) => {
86                self.insert(
87                    insert_idx,
88                    PredicateEntry {
89                        symbol_arity,
90                        predicate: Predicate::Function(predicate_fn),
91                    },
92                );
93                Ok(())
94            }
95        }
96    }
97
98    //Adds a clause to an existing enrty or creates a new entry with a single clause
99    pub fn add_clause_to_predicate(
100        &mut self,
101        clause: Clause,
102        symbol_arity: SymbolArity,
103    ) -> Result<(), &str> {
104        match self.find_predicate(symbol_arity) {
105            FindReturn::Index(idx) => match &mut self.get_mut(idx).unwrap().predicate {
106                Predicate::Function(_) => return Err("Cannot add clause to function predicate"),
107                Predicate::Clauses(clauses) => {
108                    *clauses = [&**clauses, &[clause]].concat().into_boxed_slice();
109                }
110            },
111            FindReturn::InsertPos(insert_idx) => {
112                self.insert(
113                    insert_idx,
114                    PredicateEntry {
115                        symbol_arity,
116                        predicate: Predicate::Clauses(Box::new([clause])),
117                    },
118                );
119            }
120        };
121        Ok(())
122    }
123
124    //Get predicate by SymbolArity key
125    pub fn get_predicate(&self, symbol_arity: SymbolArity) -> Option<Predicate> {
126        match self.find_predicate(symbol_arity) {
127            FindReturn::Index(i) => match &self[i].predicate {
128                Predicate::Function(predicate_fn) => Some(Predicate::Function(*predicate_fn)),
129                Predicate::Clauses(clauses) => Some(Predicate::Clauses(clauses.clone())),
130            },
131            FindReturn::InsertPos(_) => None,
132        }
133    }
134
135    pub fn get_variable_clauses(&self, arity: usize) -> Option<&Box<[Clause]>> {
136        match self.find_predicate((0, arity)) {
137            FindReturn::Index(i) => match &self[i].predicate {
138                Predicate::Clauses(clauses) => Some(clauses),
139                _ => None,
140            },
141            _ => None,
142        }
143    }
144
145    //Remove predicate by SymbolArity key, if clause predicate return the range to remove from clause table
146    pub fn _remove_predicate(&mut self, symbol_arity: SymbolArity) {
147        if let FindReturn::Index(predicate_idx) = self.find_predicate(symbol_arity) {
148            if let Predicate::Clauses(_clauses) = self.remove(predicate_idx).predicate {
149                self.body_list.retain(|i| *i != predicate_idx);
150            }
151            for i in &mut self.body_list {
152                if *i > predicate_idx {
153                    println!("{i}");
154                    *i -= 1;
155                }
156            }
157        }
158    }
159
160    //Remove or add entry index from the body predicate list
161    pub fn set_body(&mut self, symbol_arity: SymbolArity, value: bool) -> Result<(), &str> {
162        match self.find_predicate(symbol_arity) {
163            FindReturn::Index(idx) => {
164                let predicate = &mut self[idx];
165                if matches!(predicate.predicate, Predicate::Function(_)) {
166                    Err("Can't set predicate function to body")
167                } else {
168                    if value == false {
169                        self.body_list.retain(|&idx2| idx != idx2);
170                    } else {
171                        self.body_list.push(idx);
172                    }
173                    Ok(())
174                }
175            }
176            _ => Ok(()), //Err("Can't set non existing predicate to body"),
177        }
178    }
179
180    //Get all clause index ranges from entry indexes in the body_list
181    pub fn get_body_clauses(&self, arity: usize) -> Vec<Clause> {
182        let mut body_clauses = vec![];
183
184        for &idx in &self.body_list {
185            if self[idx].symbol_arity.1 != arity {
186                continue;
187            }
188            if let Predicate::Clauses(pred_clauses) = &self[idx].predicate {
189                body_clauses.extend_from_slice(pred_clauses);
190            }
191        }
192
193        body_clauses
194    }
195}
196
197impl Deref for PredicateTable {
198    type Target = Vec<PredicateEntry>;
199
200    fn deref(&self) -> &Self::Target {
201        &self.predicates
202    }
203}
204
205impl DerefMut for PredicateTable {
206    fn deref_mut(&mut self) -> &mut Self::Target {
207        &mut self.predicates
208    }
209}
210
#[cfg(test)]
mod tests {
    use crate::{
        heap::{query_heap::QueryHeap, symbol_db::SymbolDB},
        predicate_modules::PredReturn,
        program::{hypothesis::Hypothesis, predicate_table::FindReturn},
        Config,
    };
    use std::sync::Arc;

    use super::{super::clause::Clause, Predicate, PredicateEntry, PredicateTable};

    /// Stand-in built-in used wherever a `Predicate::Function` is needed.
    fn pred_fn_placeholder(
        _heap: &mut QueryHeap,
        _hypothesis: &mut Hypothesis,
        _goal: usize,
        _predicate_table: Arc<PredicateTable>,
        _config: Config,
    ) -> PredReturn {
        PredReturn::True
    }

    /// Builds the shared fixture: four entries in key order
    /// `(0, 2)`, `(p, 2)`, `(q, 2)`, `(func, 2)`, with `p/2` as the only body
    /// predicate.
    fn setup() -> PredicateTable {
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        // The literal table below is only sorted if the ids come out in this order.
        assert!(
            p <= q && q <= pred_func,
            "predicate table tests need symbol ids ordered p <= q <= func, \
             got p={p}, q={q}, func={pred_func}"
        );

        PredicateTable {
            predicates: vec![
                PredicateEntry {
                    symbol_arity: (0, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                        Clause::new(vec![7, 11], Some(vec![0]), None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (p, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![15, 19], None, None),
                        Clause::new(vec![23, 27], None, None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (q, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![31, 35], None, None),
                        Clause::new(vec![39, 43], None, None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (pred_func, 2),
                    predicate: Predicate::Function(pred_fn_placeholder),
                },
            ],
            body_list: vec![1],
        }
    }

    /// The fixture with the entry at `removed_idx` taken out and `body_list`
    /// set to the given value; used as the expected state after a removal.
    fn setup_without(removed_idx: usize, body_list: Vec<usize>) -> PredicateTable {
        let mut table = setup();
        table.predicates.remove(removed_idx);
        table.body_list = body_list;
        table
    }

    #[test]
    fn find_predicate() {
        let pred_table = setup();

        let symbol = SymbolDB::set_const("find_predicate_test_symbol".into());
        let p = SymbolDB::set_const("p".into());

        assert_eq!(pred_table.find_predicate((0, 1)), FindReturn::InsertPos(0));
        assert_eq!(
            pred_table.find_predicate((symbol, 2)),
            FindReturn::InsertPos(4)
        );
        assert_eq!(pred_table.find_predicate((p, 1)), FindReturn::InsertPos(1));
        assert_eq!(pred_table.find_predicate((p, 2)), FindReturn::Index(1));

        let pred_table = PredicateTable {
            predicates: vec![],
            body_list: vec![],
        };

        assert_eq!(pred_table.find_predicate((50, 2)), FindReturn::InsertPos(0));
    }

    #[test]
    fn get_predicate() {
        let pred_table = setup();
        let p = SymbolDB::set_const("p".into());

        assert_eq!(pred_table.get_predicate((p, 3)), None);
        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])))
        );
    }

    #[test]
    fn insert_predicate_function() {
        let mut pred_table = setup();
        let pred_func = SymbolDB::set_const("func".into());
        let p = SymbolDB::set_const("p".into());

        assert_eq!(
            pred_table.insert_predicate_function((p, 2), pred_fn_placeholder),
            Err("Cannot insert predicate function to clause predicate")
        );

        pred_table
            .insert_predicate_function((pred_func, 3), pred_fn_placeholder)
            .unwrap();
        assert_eq!(
            pred_table.get_predicate((pred_func, 3)),
            Some(Predicate::Function(pred_fn_placeholder))
        );
    }

    #[test]
    fn add_clause_to_predicate() {
        let mut pred_table = setup();
        let p = SymbolDB::set_const("p".into());
        let r = SymbolDB::set_const("r".into());
        let pred_func = SymbolDB::set_const("func".into());

        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (p, 2))
            .unwrap();
        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (r, 2))
            .unwrap();
        assert_eq!(
            pred_table
                .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (pred_func, 2)),
            Err("Cannot add clause to function predicate")
        );

        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![], Some(vec![]), None)
            ])))
        );
        assert_eq!(
            pred_table.get_predicate((r, 2)),
            Some(Predicate::Clauses(Box::new([Clause::new(
                vec![],
                Some(vec![]),
                None
            )])))
        );
    }

    #[test]
    fn remove_predicate() {
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());

        // Removing the body predicate itself empties the body list.
        let mut pred_table = setup();
        pred_table._remove_predicate((p, 2));
        assert_eq!(pred_table, setup_without(1, vec![]));

        // Removing an entry after the body predicate leaves its index alone.
        let mut pred_table = setup();
        pred_table._remove_predicate((q, 2));
        assert_eq!(pred_table, setup_without(2, vec![1]));

        // Removing an entry before the body predicate shifts its index down.
        let mut pred_table = setup();
        pred_table._remove_predicate((0, 2));
        assert_eq!(pred_table, setup_without(0, vec![0]));
    }

    #[test]
    fn set_body() {
        let mut pred_table = setup();
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        pred_table.set_body((p, 2), false).unwrap();
        pred_table.set_body((q, 2), true).unwrap();
        assert_eq!(
            pred_table.set_body((pred_func, 2), true),
            Err("Can't set predicate function to body")
        );

        assert_eq!(pred_table.body_list, [2]);
    }

    #[test]
    fn get_body_clauses() {
        let mut pred_table = setup();
        let q = SymbolDB::set_const("q".into());

        assert_eq!(pred_table.get_body_clauses(1), []);
        assert_eq!(
            pred_table.get_body_clauses(2),
            [
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ]
        );

        pred_table.set_body((q, 2), true).unwrap();

        assert_eq!(
            pred_table.get_body_clauses(2),
            [
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![31, 35], None, None),
                Clause::new(vec![39, 43], None, None),
            ]
        );
    }
}