1use crate::{
2 prove::{canonical as canon, prover::Integer},
3 Atom,
4};
use std::{
    convert::TryFrom,
    fmt::{self, Debug, Display, Write},
    ops,
    vec::IntoIter,
};
10
11#[derive(Clone, Debug, Hash)]
12pub struct ClauseDataset<T>(pub Vec<Clause<T>>);
13
14impl<T> IntoIterator for ClauseDataset<T> {
15 type Item = Clause<T>;
16 type IntoIter = IntoIter<Self::Item>;
17
18 fn into_iter(self) -> Self::IntoIter {
19 self.0.into_iter()
20 }
21}
22
23impl<T> ops::Deref for ClauseDataset<T> {
24 type Target = Vec<Clause<T>>;
25
26 fn deref(&self) -> &Self::Target {
27 &self.0
28 }
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
32pub struct Clause<T> {
33 pub head: Term<T>,
34 pub body: Option<Expr<T>>,
35}
36
37impl<T> Clause<T> {
38 pub fn fact(head: Term<T>) -> Self {
39 Self { head, body: None }
40 }
41
42 pub fn rule(head: Term<T>, body: Expr<T>) -> Self {
43 Self {
44 head,
45 body: Some(body),
46 }
47 }
48
49 pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Clause<U> {
50 Clause {
51 head: self.head.map(f),
52 body: self.body.map(|expr| expr.map(f)),
53 }
54 }
55
56 pub fn replace_term<F>(&mut self, f: &mut F)
57 where
58 F: FnMut(&Term<T>) -> Option<Term<T>>,
59 {
60 self.head.replace_all(f);
61 if let Some(body) = &mut self.body {
62 body.replace_term(f);
63 }
64 }
65}
66
67impl Clause<Integer> {
68 pub fn needs_tabling(&self) -> bool {
76 return if let Some(body) = &self.body {
77 let mut head = self.head.clone();
78 let mut body = body.clone();
79 canon::canonicalize_term(&mut head);
80 canon::canonicalize_expr_on_term(&mut body);
81 helper(&body.distribute_not(), &head)
82 } else {
83 false
84 };
85
86 fn helper(expr: &Expr<Integer>, head: &Term<Integer>) -> bool {
89 match expr {
90 Expr::Term(term) => term == head,
91 Expr::Not(arg) => helper(arg, head),
92 Expr::And(args) => {
93 if let Some((last, first)) = args.split_last() {
94 first.iter().any(|arg| helper(arg, head)) || helper(last, head)
95 } else {
96 false
97 }
98 }
99 Expr::Or(args) => args.iter().any(|arg| helper(arg, head)),
100 }
101 }
102 }
103}
104
105impl<T: Display> Display for Clause<T> {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 self.head.fmt(f)?;
108 if let Some(body) = &self.body {
109 f.write_str(" :- ")?;
110 body.fmt(f)?;
111 }
112 f.write_char('.')
113 }
114}
115
/// A term: a `functor` applied to zero or more `args`.
///
/// A term with no arguments is built by `Term::atom`. The derived orderings
/// compare `functor` first and then `args` lexicographically.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Term<T> {
    /// The term's head symbol.
    pub functor: T,
    /// Sub-terms in positional order. Variable terms are expected to have none
    /// (checked in debug builds by `Term::is_variable`).
    pub args: Vec<Term<T>>,
}
121
122impl<T> Term<T> {
123 pub fn atom(functor: T) -> Self {
124 Term {
125 functor,
126 args: vec![],
127 }
128 }
129
130 pub fn compound<I: IntoIterator<Item = Term<T>>>(functor: T, args: I) -> Self {
131 Term {
132 functor,
133 args: args.into_iter().collect(),
134 }
135 }
136
137 pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Term<U> {
138 Term {
139 functor: f(self.functor),
140 args: self.args.into_iter().map(|arg| arg.map(f)).collect(),
141 }
142 }
143
144 pub fn replace_all<F>(&mut self, f: &mut F) -> bool
145 where
146 F: FnMut(&Term<T>) -> Option<Term<T>>,
147 {
148 if let Some(new) = f(self) {
149 *self = new;
150 true
151 } else {
152 let mut replaced = false;
153 for arg in &mut self.args {
154 replaced |= arg.replace_all(f);
155 }
156 replaced
157 }
158 }
159}
160
161impl<T: Clone> Term<T> {
162 pub fn predicate(&self) -> Predicate<T> {
163 Predicate {
164 functor: self.functor.clone(),
165 arity: self.args.len() as u32,
166 }
167 }
168}
169
170impl<T: Atom> Term<T> {
171 pub fn is_variable(&self) -> bool {
172 let is_variable = self.functor.is_variable();
173
174 #[cfg(debug_assertions)]
175 if is_variable {
176 assert!(self.args.is_empty());
177 }
178
179 is_variable
180 }
181
182 pub fn contains_variable(&self) -> bool {
183 if self.is_variable() {
184 return true;
185 }
186
187 self.args.iter().any(|arg| arg.contains_variable())
188 }
189
190 pub fn replace_variables<F: FnMut(&mut T)>(&mut self, mut f: F) {
191 fn helper<T, F>(term: &mut Term<T>, f: &mut F)
192 where
193 T: Atom,
194 F: FnMut(&mut T),
195 {
196 if term.is_variable() {
197 f(&mut term.functor);
198 } else {
199 for arg in &mut term.args {
200 helper(arg, f);
201 }
202 }
203 }
204 helper(self, &mut f)
205 }
206}
207
208impl<T: Display> Display for Term<T> {
209 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210 fmt::Display::fmt(&self.functor, f)?;
211 if !self.args.is_empty() {
212 f.write_char('(')?;
213 for (i, arg) in self.args.iter().enumerate() {
214 arg.fmt(f)?;
215 if i + 1 < self.args.len() {
216 f.write_str(", ")?;
217 }
218 }
219 f.write_char(')')?;
220 }
221 Ok(())
222 }
223}
224
225#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
226pub enum Expr<T> {
227 Term(Term<T>),
228 Not(Box<Expr<T>>),
229 And(Vec<Expr<T>>),
230 Or(Vec<Expr<T>>),
231}
232
233impl<T> Expr<T> {
234 pub fn term(term: Term<T>) -> Self {
235 Self::Term(term)
236 }
237
238 pub fn term_atom(functor: T) -> Self {
239 Self::Term(Term::atom(functor))
240 }
241
242 pub fn term_compound<I: IntoIterator<Item = Term<T>>>(functor: T, args: I) -> Self {
243 Self::Term(Term::compound(functor, args))
244 }
245
246 pub fn expr_not(expr: Expr<T>) -> Self {
247 Self::Not(Box::new(expr))
248 }
249
250 pub fn expr_and<I: IntoIterator<Item = Expr<T>>>(args: I) -> Self {
251 Self::And(args.into_iter().collect())
252 }
253
254 pub fn expr_or<I: IntoIterator<Item = Expr<T>>>(args: I) -> Self {
255 Self::Or(args.into_iter().collect())
256 }
257
258 pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Expr<U> {
259 match self {
260 Self::Term(term) => Expr::Term(term.map(f)),
261 Self::Not(arg) => Expr::Not(Box::new(arg.map(f))),
262 Self::And(args) => Expr::And(args.into_iter().map(|arg| arg.map(f)).collect()),
263 Self::Or(args) => Expr::Or(args.into_iter().map(|arg| arg.map(f)).collect()),
264 }
265 }
266
267 pub fn replace_term<F>(&mut self, f: &mut F)
268 where
269 F: FnMut(&Term<T>) -> Option<Term<T>>,
270 {
271 match self {
272 Self::Term(term) => {
273 term.replace_all(f);
274 }
275 Self::Not(inner) => inner.replace_term(f),
276 Self::And(args) | Self::Or(args) => {
277 for arg in args {
278 arg.replace_term(f);
279 }
280 }
281 }
282 }
283}
284
285impl<T: PartialEq> Expr<T> {
286 pub fn contains_term(&self, term: &Term<T>) -> bool {
287 match self {
288 Self::Term(t) => t == term,
289 Self::Not(arg) => arg.contains_term(term),
290 Self::And(args) | Self::Or(args) => args.iter().any(|arg| arg.contains_term(term)),
291 }
292 }
293
294 pub fn distribute_not(self) -> Self {
296 match self {
297 Self::Term(term) => Self::Term(term),
298 Self::Not(expr) => match *expr {
299 Self::Term(term) => Self::Not(Box::new(Self::Term(term))),
300 Self::Not(inner) => inner.distribute_not(),
301 Self::And(args) => Self::Or(
302 args.into_iter()
303 .map(|arg| Self::Not(Box::new(arg)).distribute_not())
304 .collect(),
305 ),
306 Self::Or(args) => Self::And(
307 args.into_iter()
308 .map(|arg| Self::Not(Box::new(arg)).distribute_not())
309 .collect(),
310 ),
311 },
312 Self::And(args) => Self::And(args.into_iter().map(Self::distribute_not).collect()),
313 Self::Or(args) => Self::Or(args.into_iter().map(Self::distribute_not).collect()),
314 }
315 }
316}
317
318impl<T: Display> Display for Expr<T> {
319 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320 match self {
321 Self::Term(term) => term.fmt(f)?,
322 Self::Not(arg) => {
323 f.write_str("\\+ ")?;
324 if matches!(**arg, Self::And(_) | Self::Or(_)) {
325 f.write_char('(')?;
326 arg.fmt(f)?;
327 f.write_char(')')?;
328 } else {
329 arg.fmt(f)?;
330 }
331 }
332 Self::And(args) => {
333 for (i, arg) in args.iter().enumerate() {
334 if matches!(arg, Self::Or(_)) {
335 f.write_char('(')?;
336 arg.fmt(f)?;
337 f.write_char(')')?;
338 } else {
339 arg.fmt(f)?;
340 }
341 if i + 1 < args.len() {
342 f.write_str(", ")?;
343 }
344 }
345 }
346 Self::Or(args) => {
347 for (i, arg) in args.iter().enumerate() {
348 arg.fmt(f)?;
349 if i + 1 < args.len() {
350 f.write_str("; ")?;
351 }
352 }
353 }
354 }
355 Ok(())
356 }
357}
358
/// A predicate indicator: a functor paired with its arity, as produced by
/// `Term::predicate`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Predicate<T> {
    /// The predicate's name symbol.
    pub functor: T,
    /// The number of arguments.
    pub arity: u32,
}
364
#[cfg(test)]
mod tests {
    use super::{Clause, Expr, Term};

    #[test]
    fn distribute_not_applies_de_morgan() {
        // \+ (a, (b; c))  ==>  \+ a; (\+ b, \+ c)
        let expr = Expr::expr_not(Expr::expr_and([
            Expr::term_atom("a"),
            Expr::expr_or([Expr::term_atom("b"), Expr::term_atom("c")]),
        ]));

        let expected = Expr::expr_or([
            Expr::expr_not(Expr::term_atom("a")),
            Expr::expr_and([
                Expr::expr_not(Expr::term_atom("b")),
                Expr::expr_not(Expr::term_atom("c")),
            ]),
        ]);

        assert_eq!(expr.distribute_not(), expected);
    }

    #[test]
    fn distribute_not_removes_double_negation() {
        let expr = Expr::expr_not(Expr::expr_not(Expr::term(Term::compound(
            "f",
            [Term::atom("x")],
        ))));

        assert_eq!(
            expr.distribute_not(),
            Expr::term(Term::compound("f", [Term::atom("x")]))
        );
    }

    #[test]
    fn display_parenthesizes_only_where_needed() {
        // A disjunction inside a conjunction needs parentheses; a negated literal does not.
        let conj = Expr::expr_and([
            Expr::term_atom("a"),
            Expr::expr_or([Expr::term_atom("b"), Expr::term_atom("c")]),
            Expr::expr_not(Expr::term_atom("d")),
        ]);
        assert_eq!(conj.to_string(), "a, (b; c), \\+ d");

        // `,` binds tighter than `;`, so a conjunction inside a disjunction stays bare.
        let disj = Expr::expr_or([
            Expr::expr_and([Expr::term_atom("a"), Expr::term_atom("b")]),
            Expr::term_atom("c"),
        ]);
        assert_eq!(disj.to_string(), "a, b; c");

        let negated = Expr::expr_not(Expr::expr_or([Expr::term_atom("a"), Expr::term_atom("b")]));
        assert_eq!(negated.to_string(), "\\+ (a; b)");
    }

    #[test]
    fn clause_display_matches_prolog_syntax() {
        let head = Term::compound("p", [Term::atom("X"), Term::atom("y")]);
        assert_eq!(Clause::fact(head.clone()).to_string(), "p(X, y).");

        let rule = Clause::rule(head, Expr::term_compound("q", [Term::atom("X")]));
        assert_eq!(rule.to_string(), "p(X, y) :- q(X).");
    }

    #[test]
    fn contains_term_looks_through_connectives() {
        let expr = Expr::expr_or([
            Expr::term_atom("a"),
            Expr::expr_not(Expr::expr_and([Expr::term_compound("f", [Term::atom("x")])])),
        ]);
        assert!(expr.contains_term(&Term::compound("f", [Term::atom("x")])));
        // `x` only appears as an argument of a literal, not as a literal itself.
        assert!(!expr.contains_term(&Term::atom("x")));
    }
}