1use std::{
2 cmp::Ordering,
3 ops::{Deref, DerefMut},
4};
5
6use crate::predicate_modules::PredicateFunction;
7
8use super::clause::Clause;
9
/// Predicate key: `(symbol id, arity)`.
///
/// Tuples compare lexicographically (symbol first, then arity). That is the order
/// the predicate table keeps its entries in.
pub(crate) type SymbolArity = (usize, usize);
12
/// How a predicate is defined.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum Predicate {
    /// Implemented natively by a `PredicateFunction`; clauses cannot be added to it.
    Function(PredicateFunction),
    /// Defined by stored clauses, kept in the order they were added.
    Clauses(Box<[Clause]>),
}
21
/// One row of the predicate table: a predicate together with its `(symbol, arity)` key.
#[derive(PartialEq, Eq, Debug)]
pub struct PredicateEntry {
    /// Lookup key; the table's entries are sorted by this value.
    symbol_arity: SymbolArity,
    /// The predicate's definition (native function or stored clauses).
    predicate: Predicate,
}
28
29#[derive(Debug, PartialEq)]
34pub struct PredicateTable {
35 predicates: Vec<PredicateEntry>,
36 body_list: Vec<usize>,
37}
38
/// Outcome of the predicate table's binary search (`find_predicate`).
#[derive(Debug, PartialEq, Eq)]
enum FindReturn {
    /// The key exists at this index of the entry list.
    Index(usize),
    /// The key is absent; inserting at this index keeps the entry list sorted.
    InsertPos(usize),
}
45
46impl PredicateTable {
47 pub fn new() -> Self {
48 PredicateTable {
49 predicates: vec![],
50 body_list: vec![],
51 }
52 }
53
54 fn find_predicate(&self, symbol_arity: SymbolArity) -> FindReturn {
56 let mut lb: usize = 0;
57 let mut ub: usize = self.len();
58 let mut mid: usize;
59
60 while ub > lb {
61 mid = (lb + ub) / 2;
62 match symbol_arity.cmp(&self[mid].symbol_arity) {
63 Ordering::Less => ub = mid,
64 Ordering::Equal => return FindReturn::Index(mid),
65 Ordering::Greater => lb = mid + 1,
66 }
67 }
68 FindReturn::InsertPos(lb)
69 }
70
71 pub fn insert_predicate_function(
73 &mut self,
74 symbol_arity: SymbolArity,
75 predicate_fn: PredicateFunction,
76 ) -> Result<(), &str> {
77 match self.find_predicate(symbol_arity) {
78 FindReturn::Index(idx) => match &mut self[idx].predicate {
79 Predicate::Function(old_predicate_fn) => {
80 *old_predicate_fn = predicate_fn;
81 Ok(())
82 }
83 _ => Err("Cannot insert predicate function to clause predicate"),
84 },
85 FindReturn::InsertPos(insert_idx) => {
86 self.insert(
87 insert_idx,
88 PredicateEntry {
89 symbol_arity,
90 predicate: Predicate::Function(predicate_fn),
91 },
92 );
93 Ok(())
94 }
95 }
96 }
97
98 pub fn add_clause_to_predicate(
100 &mut self,
101 clause: Clause,
102 symbol_arity: SymbolArity,
103 ) -> Result<(), &str> {
104 match self.find_predicate(symbol_arity) {
105 FindReturn::Index(idx) => match &mut self.get_mut(idx).unwrap().predicate {
106 Predicate::Function(_) => return Err("Cannot add clause to function predicate"),
107 Predicate::Clauses(clauses) => {
108 *clauses = [&**clauses, &[clause]].concat().into_boxed_slice();
109 }
110 },
111 FindReturn::InsertPos(insert_idx) => {
112 self.insert(
113 insert_idx,
114 PredicateEntry {
115 symbol_arity,
116 predicate: Predicate::Clauses(Box::new([clause])),
117 },
118 );
119 }
120 };
121 Ok(())
122 }
123
124 pub fn get_predicate(&self, symbol_arity: SymbolArity) -> Option<Predicate> {
126 match self.find_predicate(symbol_arity) {
127 FindReturn::Index(i) => match &self[i].predicate {
128 Predicate::Function(predicate_fn) => Some(Predicate::Function(*predicate_fn)),
129 Predicate::Clauses(clauses) => Some(Predicate::Clauses(clauses.clone())),
130 },
131 FindReturn::InsertPos(_) => None,
132 }
133 }
134
135 pub fn get_variable_clauses(&self, arity: usize) -> Option<&Box<[Clause]>> {
136 match self.find_predicate((0, arity)) {
137 FindReturn::Index(i) => match &self[i].predicate {
138 Predicate::Clauses(clauses) => Some(clauses),
139 _ => None,
140 },
141 _ => None,
142 }
143 }
144
145 pub fn _remove_predicate(&mut self, symbol_arity: SymbolArity) {
147 if let FindReturn::Index(predicate_idx) = self.find_predicate(symbol_arity) {
148 if let Predicate::Clauses(_clauses) = self.remove(predicate_idx).predicate {
149 self.body_list.retain(|i| *i != predicate_idx);
150 }
151 for i in &mut self.body_list {
152 if *i > predicate_idx {
153 println!("{i}");
154 *i -= 1;
155 }
156 }
157 }
158 }
159
160 pub fn set_body(&mut self, symbol_arity: SymbolArity, value: bool) -> Result<(), &str> {
162 match self.find_predicate(symbol_arity) {
163 FindReturn::Index(idx) => {
164 let predicate = &mut self[idx];
165 if matches!(predicate.predicate, Predicate::Function(_)) {
166 Err("Can't set predicate function to body")
167 } else {
168 if value == false {
169 self.body_list.retain(|&idx2| idx != idx2);
170 } else {
171 self.body_list.push(idx);
172 }
173 Ok(())
174 }
175 }
176 _ => Ok(()), }
178 }
179
180 pub fn get_body_clauses(&self, arity: usize) -> Vec<Clause> {
182 let mut body_clauses = vec![];
183
184 for &idx in &self.body_list {
185 if self[idx].symbol_arity.1 != arity {
186 continue;
187 }
188 if let Predicate::Clauses(pred_clauses) = &self[idx].predicate {
189 body_clauses.extend_from_slice(pred_clauses);
190 }
191 }
192
193 body_clauses
194 }
195}
196
197impl Deref for PredicateTable {
198 type Target = Vec<PredicateEntry>;
199
200 fn deref(&self) -> &Self::Target {
201 &self.predicates
202 }
203}
204
205impl DerefMut for PredicateTable {
206 fn deref_mut(&mut self) -> &mut Self::Target {
207 &mut self.predicates
208 }
209}
210
#[cfg(test)]
mod tests {
    use crate::{
        heap::{query_heap::QueryHeap, symbol_db::SymbolDB},
        predicate_modules::PredReturn,
        program::{hypothesis::Hypothesis, predicate_table::FindReturn},
        Config,
    };
    use std::sync::Arc;

    use super::{super::clause::Clause, Predicate, PredicateEntry, PredicateTable};

    /// Stand-in native predicate used wherever a `Predicate::Function` is needed.
    ///
    /// Always returns `PredReturn::True`.
    fn pred_fn_placeholder(
        _heap: &mut QueryHeap,
        _hypothesis: &mut Hypothesis,
        _goal: usize,
        _predicate_table: Arc<PredicateTable>,
        _config: Config,
    ) -> PredReturn {
        PredReturn::True
    }

    /// Builds the fixture table shared by most tests:
    ///
    /// | index | key         | predicate                 |
    /// |-------|-------------|---------------------------|
    /// | 0     | (0, 2)      | 2 clauses                 |
    /// | 1     | (p, 2)      | 2 clauses, in `body_list` |
    /// | 2     | (q, 2)      | 2 clauses                 |
    /// | 3     | (func, 2)   | `pred_fn_placeholder`     |
    ///
    /// Lookups binary-search the entries, so the fixture is only valid if the interned
    /// ids satisfy `p < q < func`. The assertion checks this. (It also assumes `p > 0`
    /// so that `(0, 2)` sorts first.)
    fn setup() -> PredicateTable {
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        assert!(
            p < q && q < pred_func,
            "predicate table tests require symbol ids ordered p < q < func (got p={p}, q={q}, func={pred_func})"
        );

        PredicateTable {
            predicates: vec![
                PredicateEntry {
                    symbol_arity: (0, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                        Clause::new(vec![7, 11], Some(vec![0]), None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (p, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![15, 19], None, None),
                        Clause::new(vec![23, 27], None, None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (q, 2),
                    predicate: Predicate::Clauses(Box::new([
                        Clause::new(vec![31, 35], None, None),
                        Clause::new(vec![39, 43], None, None),
                    ])),
                },
                PredicateEntry {
                    symbol_arity: (pred_func, 2),
                    predicate: Predicate::Function(pred_fn_placeholder),
                },
            ],
            body_list: vec![1],
        }
    }

    #[test]
    fn find_predicate() {
        let pred_table = setup();

        let symbol = SymbolDB::set_const("find_predicate_test_symbol".into());
        let p = SymbolDB::set_const("p".into());

        assert_eq!(pred_table.find_predicate((0, 1)), FindReturn::InsertPos(0));
        assert_eq!(
            pred_table.find_predicate((symbol, 2)),
            FindReturn::InsertPos(4)
        );
        assert_eq!(pred_table.find_predicate((p, 1)), FindReturn::InsertPos(1));
        assert_eq!(pred_table.find_predicate((p, 2)), FindReturn::Index(1));

        // An empty table always reports insertion at the front.
        let pred_table = PredicateTable {
            predicates: vec![],
            body_list: vec![],
        };

        assert_eq!(pred_table.find_predicate((50, 2)), FindReturn::InsertPos(0));
    }

    #[test]
    fn get_predicate() {
        let pred_table = setup();
        let p = SymbolDB::set_const("p".into());

        assert_eq!(pred_table.get_predicate((p, 3)), None);
        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ])))
        );
    }

    #[test]
    fn insert_predicate_function() {
        let mut pred_table = setup();
        let pred_func = SymbolDB::set_const("func".into());
        let p = SymbolDB::set_const("p".into());

        assert_eq!(
            pred_table.insert_predicate_function((p, 2), pred_fn_placeholder),
            Err("Cannot insert predicate function to clause predicate")
        );

        pred_table
            .insert_predicate_function((pred_func, 3), pred_fn_placeholder)
            .unwrap();
        assert_eq!(
            pred_table.get_predicate((pred_func, 3)),
            Some(Predicate::Function(pred_fn_placeholder))
        );
    }

    #[test]
    fn add_clause_to_predicate() {
        let mut pred_table = setup();
        let p = SymbolDB::set_const("p".into());
        let r = SymbolDB::set_const("r".into());
        let pred_func = SymbolDB::set_const("func".into());

        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (p, 2))
            .unwrap();
        pred_table
            .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (r, 2))
            .unwrap();
        assert_eq!(
            pred_table
                .add_clause_to_predicate(Clause::new(vec![], Some(vec![]), None), (pred_func, 2)),
            Err("Cannot add clause to function predicate")
        );

        assert_eq!(
            pred_table.get_predicate((p, 2)),
            Some(Predicate::Clauses(Box::new([
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![], Some(vec![]), None)
            ])))
        );
        assert_eq!(
            pred_table.get_predicate((r, 2)),
            Some(Predicate::Clauses(Box::new([Clause::new(
                vec![],
                Some(vec![]),
                None
            )])))
        );
    }

    #[test]
    fn remove_predicate() {
        let mut pred_table = setup();
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        // Removing a body predicate also drops it from the body list.
        pred_table._remove_predicate((p, 2));

        assert_eq!(
            pred_table,
            PredicateTable {
                predicates: vec![
                    PredicateEntry {
                        symbol_arity: (0, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                            Clause::new(vec![7, 11], Some(vec![0]), None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (q, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![31, 35], None, None),
                            Clause::new(vec![39, 43], None, None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (pred_func, 2),
                        predicate: Predicate::Function(pred_fn_placeholder),
                    },
                ],
                body_list: vec![],
            }
        );

        // Removing an entry after the body predicate leaves its index untouched.
        let mut pred_table = setup();
        pred_table._remove_predicate((q, 2));

        assert_eq!(
            pred_table,
            PredicateTable {
                predicates: vec![
                    PredicateEntry {
                        symbol_arity: (0, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![0, 3], Some(vec![0, 1]), None),
                            Clause::new(vec![7, 11], Some(vec![0]), None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (p, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![15, 19], None, None),
                            Clause::new(vec![23, 27], None, None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (pred_func, 2),
                        predicate: Predicate::Function(pred_fn_placeholder),
                    },
                ],
                body_list: vec![1],
            }
        );

        // Removing an entry before the body predicate shifts its index down.
        let mut pred_table = setup();
        pred_table._remove_predicate((0, 2));

        assert_eq!(
            pred_table,
            PredicateTable {
                predicates: vec![
                    PredicateEntry {
                        symbol_arity: (p, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![15, 19], None, None),
                            Clause::new(vec![23, 27], None, None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (q, 2),
                        predicate: Predicate::Clauses(Box::new([
                            Clause::new(vec![31, 35], None, None),
                            Clause::new(vec![39, 43], None, None),
                        ])),
                    },
                    PredicateEntry {
                        symbol_arity: (pred_func, 2),
                        predicate: Predicate::Function(pred_fn_placeholder),
                    },
                ],
                body_list: vec![0],
            }
        );
    }

    #[test]
    fn set_body() {
        let mut pred_table = setup();
        let p = SymbolDB::set_const("p".into());
        let q = SymbolDB::set_const("q".into());
        let pred_func = SymbolDB::set_const("func".into());

        pred_table.set_body((p, 2), false).unwrap();
        pred_table.set_body((q, 2), true).unwrap();

        assert_eq!(pred_table.body_list, [2]);
    }

    #[test]
    fn get_body_clauses() {
        let mut pred_table = setup();
        let q = SymbolDB::set_const("q".into());

        assert_eq!(pred_table.get_body_clauses(1), []);
        assert_eq!(
            pred_table.get_body_clauses(2),
            [
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
            ]
        );

        pred_table.set_body((q, 2), true).unwrap();

        assert_eq!(
            pred_table.get_body_clauses(2),
            [
                Clause::new(vec![15, 19], None, None),
                Clause::new(vec![23, 27], None, None),
                Clause::new(vec![31, 35], None, None),
                Clause::new(vec![39, 43], None, None),
            ]
        );
    }
}