Skip to main content

logic_eval/parse/
repr.rs

1use crate::{
2    prove::{canonical as canon, prover::Integer},
3    Atom,
4};
5use std::{
6    fmt::{self, Debug, Display, Write},
7    ops,
8    vec::IntoIter,
9};
10
11#[derive(Clone, Debug, Hash)]
12pub struct ClauseDataset<T>(pub Vec<Clause<T>>);
13
14impl<T> IntoIterator for ClauseDataset<T> {
15    type Item = Clause<T>;
16    type IntoIter = IntoIter<Self::Item>;
17
18    fn into_iter(self) -> Self::IntoIter {
19        self.0.into_iter()
20    }
21}
22
23impl<T> ops::Deref for ClauseDataset<T> {
24    type Target = Vec<Clause<T>>;
25
26    fn deref(&self) -> &Self::Target {
27        &self.0
28    }
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
32pub struct Clause<T> {
33    pub head: Term<T>,
34    pub body: Option<Expr<T>>,
35}
36
37impl<T> Clause<T> {
38    pub fn fact(head: Term<T>) -> Self {
39        Self { head, body: None }
40    }
41
42    pub fn rule(head: Term<T>, body: Expr<T>) -> Self {
43        Self {
44            head,
45            body: Some(body),
46        }
47    }
48
49    pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Clause<U> {
50        Clause {
51            head: self.head.map(f),
52            body: self.body.map(|expr| expr.map(f)),
53        }
54    }
55
56    pub fn replace_term<F>(&mut self, f: &mut F)
57    where
58        F: FnMut(&Term<T>) -> Option<Term<T>>,
59    {
60        self.head.replace_all(f);
61        if let Some(body) = &mut self.body {
62            body.replace_term(f);
63        }
64    }
65}
66
67impl Clause<Integer> {
68    /// Returns true if the clause needs SLG resolution (tabling).
69    ///
70    /// If a clause has left or mid recursion, it must be handled by tabling.
71    ///
72    /// # Examples
73    /// foo(X, Y) :- foo(A, B) ...     // left recursion
74    /// foo(X, Y) :- ... foo(A, B) ... // mid recursion
75    pub fn needs_tabling(&self) -> bool {
76        return if let Some(body) = &self.body {
77            let mut head = self.head.clone();
78            let mut body = body.clone();
79            canon::canonicalize_term(&mut head);
80            canon::canonicalize_expr_on_term(&mut body);
81            helper(&body.distribute_not(), &head)
82        } else {
83            false
84        };
85
86        // === Internal helper functions ===
87
88        fn helper(expr: &Expr<Integer>, head: &Term<Integer>) -> bool {
89            match expr {
90                Expr::Term(term) => term == head,
91                Expr::Not(arg) => helper(arg, head),
92                Expr::And(args) => {
93                    if let Some((last, first)) = args.split_last() {
94                        first.iter().any(|arg| helper(arg, head)) || helper(last, head)
95                    } else {
96                        false
97                    }
98                }
99                Expr::Or(args) => args.iter().any(|arg| helper(arg, head)),
100            }
101        }
102    }
103}
104
105impl<T: Display> Display for Clause<T> {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        self.head.fmt(f)?;
108        if let Some(body) = &self.body {
109            f.write_str(" :- ")?;
110            body.fmt(f)?;
111        }
112        f.write_char('.')
113    }
114}
115
/// A term: a `functor` applied to zero or more argument terms.
///
/// Atoms and variables are represented as terms without arguments.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Term<T> {
    pub functor: T,
    pub args: Vec<Term<T>>,
}

impl<T> Term<T> {
    /// Creates a term without arguments.
    pub fn atom(functor: T) -> Self {
        let args = Vec::new();
        Term { functor, args }
    }

    /// Creates a term whose arguments are the items of `args`, in iteration order.
    pub fn compound<I: IntoIterator<Item = Term<T>>>(functor: T, args: I) -> Self {
        let args: Vec<Term<T>> = args.into_iter().collect();
        Term { functor, args }
    }

    /// Converts every functor of the term tree with `f`.
    ///
    /// `f` sees a term's functor before its arguments, and arguments left to right
    /// (pre-order).
    pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Term<U> {
        let Term { functor, args } = self;
        let functor = f(functor);
        let args = args.into_iter().map(|arg| arg.map(f)).collect();
        Term { functor, args }
    }

    /// Replaces subterms in place, searching top-down.
    ///
    /// `f` is first offered the whole term. If it returns `Some(new)`, the term becomes `new`
    /// and the replacement is not searched any further. Otherwise every argument is
    /// processed recursively, left to right.
    ///
    /// Returns `true` if at least one replacement was made.
    pub fn replace_all<F>(&mut self, f: &mut F) -> bool
    where
        F: FnMut(&Term<T>) -> Option<Term<T>>,
    {
        match f(self) {
            Some(replacement) => {
                *self = replacement;
                true
            }
            // `|` (not `||`) so every argument is visited even after a hit.
            None => self
                .args
                .iter_mut()
                .fold(false, |replaced, arg| arg.replace_all(f) | replaced),
        }
    }
}
160
161impl<T: Clone> Term<T> {
162    pub fn predicate(&self) -> Predicate<T> {
163        Predicate {
164            functor: self.functor.clone(),
165            arity: self.args.len() as u32,
166        }
167    }
168}
169
170impl<T: Atom> Term<T> {
171    pub fn is_variable(&self) -> bool {
172        let is_variable = self.functor.is_variable();
173
174        #[cfg(debug_assertions)]
175        if is_variable {
176            assert!(self.args.is_empty());
177        }
178
179        is_variable
180    }
181
182    pub fn contains_variable(&self) -> bool {
183        if self.is_variable() {
184            return true;
185        }
186
187        self.args.iter().any(|arg| arg.contains_variable())
188    }
189
190    pub fn replace_variables<F: FnMut(&mut T)>(&mut self, mut f: F) {
191        fn helper<T, F>(term: &mut Term<T>, f: &mut F)
192        where
193            T: Atom,
194            F: FnMut(&mut T),
195        {
196            if term.is_variable() {
197                f(&mut term.functor);
198            } else {
199                for arg in &mut term.args {
200                    helper(arg, f);
201                }
202            }
203        }
204        helper(self, &mut f)
205    }
206}
207
208impl<T: Display> Display for Term<T> {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        fmt::Display::fmt(&self.functor, f)?;
211        if !self.args.is_empty() {
212            f.write_char('(')?;
213            for (i, arg) in self.args.iter().enumerate() {
214                arg.fmt(f)?;
215                if i + 1 < self.args.len() {
216                    f.write_str(", ")?;
217                }
218            }
219            f.write_char(')')?;
220        }
221        Ok(())
222    }
223}
224
225#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
226pub enum Expr<T> {
227    Term(Term<T>),
228    Not(Box<Expr<T>>),
229    And(Vec<Expr<T>>),
230    Or(Vec<Expr<T>>),
231}
232
233impl<T> Expr<T> {
234    pub fn term(term: Term<T>) -> Self {
235        Self::Term(term)
236    }
237
238    pub fn term_atom(functor: T) -> Self {
239        Self::Term(Term::atom(functor))
240    }
241
242    pub fn term_compound<I: IntoIterator<Item = Term<T>>>(functor: T, args: I) -> Self {
243        Self::Term(Term::compound(functor, args))
244    }
245
246    pub fn expr_not(expr: Expr<T>) -> Self {
247        Self::Not(Box::new(expr))
248    }
249
250    pub fn expr_and<I: IntoIterator<Item = Expr<T>>>(args: I) -> Self {
251        Self::And(args.into_iter().collect())
252    }
253
254    pub fn expr_or<I: IntoIterator<Item = Expr<T>>>(args: I) -> Self {
255        Self::Or(args.into_iter().collect())
256    }
257
258    pub fn map<U, F: FnMut(T) -> U>(self, f: &mut F) -> Expr<U> {
259        match self {
260            Self::Term(term) => Expr::Term(term.map(f)),
261            Self::Not(arg) => Expr::Not(Box::new(arg.map(f))),
262            Self::And(args) => Expr::And(args.into_iter().map(|arg| arg.map(f)).collect()),
263            Self::Or(args) => Expr::Or(args.into_iter().map(|arg| arg.map(f)).collect()),
264        }
265    }
266
267    pub fn replace_term<F>(&mut self, f: &mut F)
268    where
269        F: FnMut(&Term<T>) -> Option<Term<T>>,
270    {
271        match self {
272            Self::Term(term) => {
273                term.replace_all(f);
274            }
275            Self::Not(inner) => inner.replace_term(f),
276            Self::And(args) | Self::Or(args) => {
277                for arg in args {
278                    arg.replace_term(f);
279                }
280            }
281        }
282    }
283}
284
285impl<T: PartialEq> Expr<T> {
286    pub fn contains_term(&self, term: &Term<T>) -> bool {
287        match self {
288            Self::Term(t) => t == term,
289            Self::Not(arg) => arg.contains_term(term),
290            Self::And(args) | Self::Or(args) => args.iter().any(|arg| arg.contains_term(term)),
291        }
292    }
293
294    /// e.g. ¬(A ∧ (B ∨ C)) -> ¬A ∨ (¬B ∧ ¬C)
295    pub fn distribute_not(self) -> Self {
296        match self {
297            Self::Term(term) => Self::Term(term),
298            Self::Not(expr) => match *expr {
299                Self::Term(term) => Self::Not(Box::new(Self::Term(term))),
300                Self::Not(inner) => inner.distribute_not(),
301                Self::And(args) => Self::Or(
302                    args.into_iter()
303                        .map(|arg| Self::Not(Box::new(arg)).distribute_not())
304                        .collect(),
305                ),
306                Self::Or(args) => Self::And(
307                    args.into_iter()
308                        .map(|arg| Self::Not(Box::new(arg)).distribute_not())
309                        .collect(),
310                ),
311            },
312            Self::And(args) => Self::And(args.into_iter().map(Self::distribute_not).collect()),
313            Self::Or(args) => Self::Or(args.into_iter().map(Self::distribute_not).collect()),
314        }
315    }
316}
317
318impl<T: Display> Display for Expr<T> {
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        match self {
321            Self::Term(term) => term.fmt(f)?,
322            Self::Not(arg) => {
323                f.write_str("\\+ ")?;
324                if matches!(**arg, Self::And(_) | Self::Or(_)) {
325                    f.write_char('(')?;
326                    arg.fmt(f)?;
327                    f.write_char(')')?;
328                } else {
329                    arg.fmt(f)?;
330                }
331            }
332            Self::And(args) => {
333                for (i, arg) in args.iter().enumerate() {
334                    if matches!(arg, Self::Or(_)) {
335                        f.write_char('(')?;
336                        arg.fmt(f)?;
337                        f.write_char(')')?;
338                    } else {
339                        arg.fmt(f)?;
340                    }
341                    if i + 1 < args.len() {
342                        f.write_str(", ")?;
343                    }
344                }
345            }
346            Self::Or(args) => {
347                for (i, arg) in args.iter().enumerate() {
348                    arg.fmt(f)?;
349                    if i + 1 < args.len() {
350                        f.write_str("; ")?;
351                    }
352                }
353            }
354        }
355        Ok(())
356    }
357}
358
/// Identifies a predicate by its functor and its arity (number of arguments).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Predicate<T> {
    /// The predicate's functor.
    pub functor: T,
    /// The number of arguments the predicate takes.
    pub arity: u32,
}
364
#[cfg(test)]
mod tests {
    use super::{Expr, Term};

    /// Shorthand for a goal that is an argument-less term.
    fn atom(name: &'static str) -> Expr<&'static str> {
        Expr::term_atom(name)
    }

    /// Shorthand for a negated argument-less goal.
    fn not_atom(name: &'static str) -> Expr<&'static str> {
        Expr::expr_not(atom(name))
    }

    #[test]
    fn distribute_not_applies_de_morgan() {
        // \+ (a, (b; c))  ==>  \+ a; (\+ b, \+ c)
        let input = Expr::expr_not(Expr::expr_and([
            atom("a"),
            Expr::expr_or([atom("b"), atom("c")]),
        ]));
        let expected = Expr::expr_or([
            not_atom("a"),
            Expr::expr_and([not_atom("b"), not_atom("c")]),
        ]);

        assert_eq!(input.distribute_not(), expected);
    }

    #[test]
    fn distribute_not_removes_double_negation() {
        let f_of_x = || Expr::term(Term::compound("f", [Term::atom("x")]));
        let input = Expr::expr_not(Expr::expr_not(f_of_x()));

        assert_eq!(input.distribute_not(), f_of_x());
    }
}