1use std::{
2 cmp::Ordering,
3 ops::{Deref, DerefMut},
4};
5
6use crate::predicate_modules::PredicateFunction;
7
8use super::clause::Clause;
9
/// `(symbol, arity)` key identifying a predicate: `.0` is the symbol id and `.1` is the
/// number of arguments. Tuples compare lexicographically (symbol first, then arity), which is
/// the order the predicate table is searched in.
pub(crate) type SymbolArity = (usize, usize);
12
/// Implementation of a predicate: either a Rust function or a list of clauses.
// `derive(PartialEq)` compares the function pointer held by `Function`; the allow silences
// the lint warning that function-pointer addresses are not guaranteed to be unique or stable,
// so two "equal" functions may compare unequal (and vice versa).
#[allow(unpredictable_function_pointer_comparisons)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum Predicate {
    /// Predicate implemented natively by a `PredicateFunction`.
    Function(PredicateFunction),
    /// Predicate defined by its clauses, kept in the order they were added.
    Clauses(Box<[Clause]>),
}
22
/// One row of the predicate table: the `(symbol, arity)` key and its implementation.
#[derive(PartialEq, Eq, Debug)]
pub struct PredicateEntry {
    /// Lookup key; the table's entries are ordered by this value.
    symbol_arity: SymbolArity,
    /// How the predicate is implemented.
    predicate: Predicate,
}
29
30#[derive(Debug, PartialEq)]
35pub struct PredicateTable {
36 predicates: Vec<PredicateEntry>,
37 body_list: Vec<usize>,
38}
39
/// Outcome of searching the predicate table for a `(symbol, arity)` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FindReturn {
    /// The key is present at this index.
    Index(usize),
    /// The key is absent; inserting a new entry at this index keeps the table sorted.
    InsertPos(usize),
}
46
47impl PredicateTable {
48 pub fn new() -> Self {
49 PredicateTable {
50 predicates: vec![],
51 body_list: vec![],
52 }
53 }
54
55 fn find_predicate(&self, symbol_arity: SymbolArity) -> FindReturn {
57 let mut lb: usize = 0;
58 let mut ub: usize = self.len();
59 let mut mid: usize;
60
61 while ub > lb {
62 mid = (lb + ub) / 2;
63 match symbol_arity.cmp(&self[mid].symbol_arity) {
64 Ordering::Less => ub = mid,
65 Ordering::Equal => return FindReturn::Index(mid),
66 Ordering::Greater => lb = mid + 1,
67 }
68 }
69 FindReturn::InsertPos(lb)
70 }
71
72 pub fn insert_predicate_function(
74 &mut self,
75 symbol_arity: SymbolArity,
76 predicate_fn: PredicateFunction,
77 ) -> Result<(), &str> {
78 match self.find_predicate(symbol_arity) {
79 FindReturn::Index(idx) => match &mut self[idx].predicate {
80 Predicate::Function(old_predicate_fn) => {
81 *old_predicate_fn = predicate_fn;
82 Ok(())
83 }
84 _ => Err("Cannot insert predicate function to clause predicate"),
85 },
86 FindReturn::InsertPos(insert_idx) => {
87 self.insert(
88 insert_idx,
89 PredicateEntry {
90 symbol_arity,
91 predicate: Predicate::Function(predicate_fn),
92 },
93 );
94 Ok(())
95 }
96 }
97 }
98
99 pub fn add_clause_to_predicate(
101 &mut self,
102 clause: Clause,
103 symbol_arity: SymbolArity,
104 ) -> Result<(), &str> {
105 match self.find_predicate(symbol_arity) {
106 FindReturn::Index(idx) => match &mut self.get_mut(idx).unwrap().predicate {
107 Predicate::Function(_) => return Err("Cannot add clause to function predicate"),
108 Predicate::Clauses(clauses) => {
109 *clauses = [&**clauses, &[clause]].concat().into_boxed_slice();
110 }
111 },
112 FindReturn::InsertPos(insert_idx) => {
113 self.insert(
114 insert_idx,
115 PredicateEntry {
116 symbol_arity,
117 predicate: Predicate::Clauses(Box::new([clause])),
118 },
119 );
120 }
121 };
122 Ok(())
123 }
124
125 pub fn get_predicate(&self, symbol_arity: SymbolArity) -> Option<&Predicate> {
127 match self.find_predicate(symbol_arity) {
128 FindReturn::Index(i) => Some(&self[i].predicate),
129 FindReturn::InsertPos(_) => None,
130 }
131 }
132
133 pub fn get_variable_clauses(&self, arity: usize) -> Option<&Box<[Clause]>> {
134 match self.find_predicate((0, arity)) {
135 FindReturn::Index(i) => match &self[i].predicate {
136 Predicate::Clauses(clauses) => Some(clauses),
137 _ => None,
138 },
139 _ => None,
140 }
141 }
142
143 pub fn _remove_predicate(&mut self, symbol_arity: SymbolArity) {
145 if let FindReturn::Index(predicate_idx) = self.find_predicate(symbol_arity) {
146 if let Predicate::Clauses(_clauses) = self.remove(predicate_idx).predicate {
147 self.body_list.retain(|i| *i != predicate_idx);
148 }
149 for i in &mut self.body_list {
150 if *i > predicate_idx {
151 println!("{i}");
152 *i -= 1;
153 }
154 }
155 }
156 }
157
158 pub fn set_body(&mut self, symbol_arity: SymbolArity, value: bool) -> Result<(), &str> {
160 match self.find_predicate(symbol_arity) {
161 FindReturn::Index(idx) => {
162 let predicate = &mut self[idx];
163 if matches!(predicate.predicate, Predicate::Function(_)) {
164 Err("Can't set predicate function to body")
165 } else {
166 if value == false {
167 self.body_list.retain(|&idx2| idx != idx2);
168 } else {
169 self.body_list.push(idx);
170 }
171 Ok(())
172 }
173 }
174 _ => Ok(()), }
176 }
177
178 pub fn get_body_clauses(&self, arity: usize) -> impl Iterator<Item = &Clause> {
180 self.body_list
181 .iter()
182 .filter_map(move |&idx| {
183 if self[idx].symbol_arity.1 != arity {
184 return None;
185 }
186 if let Predicate::Clauses(pred_clauses) = &self[idx].predicate {
187 Some(pred_clauses.iter())
188 } else {
189 None
190 }
191 })
192 .flatten()
193 }
194}
195
196impl Deref for PredicateTable {
197 type Target = Vec<PredicateEntry>;
198
199 fn deref(&self) -> &Self::Target {
200 &self.predicates
201 }
202}
203
204impl DerefMut for PredicateTable {
205 fn deref_mut(&mut self) -> &mut Self::Target {
206 &mut self.predicates
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::{super::clause::Clause, Predicate, PredicateEntry, PredicateTable};
    use crate::{
        heap::{query_heap::QueryHeap, symbol_db::SymbolDB},
        predicate_modules::PredReturn,
        program::{hypothesis::Hypothesis, predicate_table::FindReturn},
        Config,
    };

    /// Stand-in used to populate `Predicate::Function` entries.
    fn pred_fn_placeholder(
        _heap: &mut QueryHeap,
        _hypothesis: &mut Hypothesis,
        _goal: usize,
        _predicate_table: &PredicateTable,
        _config: Config,
    ) -> PredReturn {
        PredReturn::True
    }

    /// Builds a sorted table holding the clause predicates `(0, 2)`, `(p, 2)`, `(q, 2)` and
    /// the function predicate `(func, 2)`, with only `(p, 2)` in the body list.
    ///
    /// Returns the table together with the symbol ids of `p`, `q` and `func`.
    fn setup() -> (PredicateTable, usize, usize, usize) {
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        let p_entry = PredicateEntry {
            symbol_arity: (p, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])),
        };
        let q_entry = PredicateEntry {
            symbol_arity: (q, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![31, 35], None, None),
                Clause::new(vec![39, 43], None, None),
            ])),
        };
        let func_entry = PredicateEntry {
            symbol_arity: (pred_func, 2),
            predicate: Predicate::Function(pred_fn_placeholder),
        };
        let zero_entry = PredicateEntry {
            symbol_arity: (0, 2),
            predicate: Predicate::Clauses(Box::new([
                Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                Clause::new(vec![7, 11], Some(vec![0]), None),
            ])),
        };

        let mut predicates = vec![zero_entry, p_entry, q_entry, func_entry];
        predicates.sort_by_key(|e| e.symbol_arity);

        let p_idx = predicates
            .iter()
            .position(|e| e.symbol_arity == (p, 2))
            .unwrap();

        (
            PredicateTable {
                predicates,
                body_list: vec![p_idx],
            },
            p,
            q,
            pred_func,
        )
    }

    /// Linear-scan reference for `find_predicate`, independent of its binary search.
    fn expected_find(table: &PredicateTable, key: (usize, usize)) -> FindReturn {
        match table.iter().position(|e| e.symbol_arity == key) {
            Some(idx) => FindReturn::Index(idx),
            None => FindReturn::InsertPos(
                table.iter().take_while(|e| e.symbol_arity < key).count(),
            ),
        }
    }

    /// Keys of the predicates the body list refers to, in body-list order.
    fn body_keys(table: &PredicateTable) -> Vec<(usize, usize)> {
        table
            .body_list
            .iter()
            .map(|&idx| table[idx].symbol_arity)
            .collect()
    }

    #[test]
    fn find_predicate() {
        let (pred_table, p, _q, _pred_func) = setup();

        let symbol = SymbolDB::set_const("find_predicate_test_symbol".into());
        let p_idx = pred_table
            .iter()
            .position(|e| e.symbol_arity == (p, 2))
            .unwrap();

        // Every stored key is found at its own position.
        for (idx, entry) in pred_table.iter().enumerate() {
            assert_eq!(
                pred_table.find_predicate(entry.symbol_arity),
                FindReturn::Index(idx)
            );
        }

        assert_eq!(pred_table.find_predicate((0, 1)), FindReturn::InsertPos(0));
        assert_eq!(pred_table.find_predicate((p, 2)), FindReturn::Index(p_idx));

        // Where an unrelated symbol lands depends on the id SymbolDB assigned to it, so check
        // against the linear-scan reference rather than comparing the search with itself.
        assert_eq!(
            pred_table.find_predicate((symbol, 2)),
            expected_find(&pred_table, (symbol, 2))
        );

        assert_eq!(
            pred_table.find_predicate((p, 1)),
            FindReturn::InsertPos(p_idx)
        );

        let pred_table = PredicateTable {
            predicates: vec![],
            body_list: vec![],
        };

        assert_eq!(pred_table.find_predicate((50, 2)), FindReturn::InsertPos(0));
    }

    #[test]
    fn get_predicate() {
        let (pred_table, p, _q, _pred_func) = setup();

        assert_eq!(pred_table.get_predicate((p, 3)), None);
        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(&Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])))
        );
    }

    #[test]
    fn insert_predicate_function() {
        let (mut pred_table, p, _q, pred_func) = setup();

        assert_eq!(
            pred_table.insert_predicate_function((p, 2), pred_fn_placeholder),
            Err("Cannot insert predicate function to clause predicate")
        );

        pred_table
            .insert_predicate_function((pred_func, 3), pred_fn_placeholder)
            .unwrap();
        assert_eq!(
            pred_table.get_predicate((pred_func, 3)),
            Some(&Predicate::Function(pred_fn_placeholder))
        );
    }

    #[test]
    fn add_clause_to_predicate() {
        let (mut pred_table, p, _q, pred_func) = setup();
        let r = SymbolDB::set_const("r".into());

        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (p, 2))
            .unwrap();
        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (r, 2))
            .unwrap();
        assert_eq!(
            pred_table
                .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (pred_func, 2)),
            Err("Cannot add clause to function predicate")
        );

        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(&Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![], Some(vec![]), None)
            ])))
        );
        assert_eq!(
            pred_table.get_predicate((r, 2)),
            Some(&Predicate::Clauses(Box::new([Clause::new(
                vec![],
                Some(vec![]),
                None
            )])))
        );
    }

    #[test]
    fn remove_predicate() {
        // Removing the body predicate also drops it from the body list.
        let (mut pred_table, p, _q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((p, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((p, 2)), None);
        assert!(pred_table.body_list.is_empty());

        // Removing a non-body predicate keeps the body list pointing at `(p, 2)`.
        let (mut pred_table, p, q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((q, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((q, 2)), None);
        assert_eq!(body_keys(&pred_table), vec![(p, 2)]);

        // Removing the first entry shifts the body index of `(p, 2)` down correctly.
        let (mut pred_table, p, q, _pred_func) = setup();
        let len_before = pred_table.len();
        pred_table._remove_predicate((0, 2));
        assert_eq!(pred_table.len(), len_before - 1);
        assert_eq!(pred_table.get_predicate((0, 2)), None);
        assert!(pred_table.get_predicate((p, 2)).is_some());
        assert!(pred_table.get_predicate((q, 2)).is_some());
        assert_eq!(body_keys(&pred_table), vec![(p, 2)]);
    }

    #[test]
    fn set_body() {
        let (mut pred_table, p, q, pred_func) = setup();

        pred_table.set_body((p, 2), false).unwrap();
        pred_table.set_body((q, 2), true).unwrap();

        let q_idx = pred_table
            .iter()
            .position(|e| e.symbol_arity == (q, 2))
            .unwrap();
        assert_eq!(pred_table.body_list, [q_idx]);

        assert_eq!(
            pred_table.set_body((pred_func, 2), true),
            Err("Can't set predicate function to body")
        );
        // Unknown predicates are ignored.
        assert_eq!(pred_table.set_body((q, 7), true), Ok(()));
        assert_eq!(body_keys(&pred_table), vec![(q, 2)]);
    }

    #[test]
    fn get_body_clauses() {
        let (mut pred_table, _p, q, _pred_func) = setup();

        let empty: Vec<&Clause> = pred_table.get_body_clauses(1).collect();
        assert!(empty.is_empty());

        let body2: Vec<&Clause> = pred_table.get_body_clauses(2).collect();
        assert_eq!(
            body2,
            vec![
                &Clause::new(vec![15, 19], None, None),
                &Clause::new(vec![23, 27], None, None),
            ]
        );

        pred_table.set_body((q, 2), true).unwrap();

        let body2_ext: Vec<&Clause> = pred_table.get_body_clauses(2).collect();
        assert_eq!(body2_ext.len(), 4);
        assert!(body2_ext.contains(&&Clause::new(vec![15, 19], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![23, 27], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![31, 35], None, None)));
        assert!(body2_ext.contains(&&Clause::new(vec![39, 43], None, None)));
    }
}