1use std::collections::HashMap;
16
17use bumpalo::Bump;
18
/// Top-level container holding a list of declarations and a list of clauses.
///
/// All contents are borrowed for `'a`.
#[derive(Debug)]
pub struct Unit<'a> {
    /// The declarations of this unit.
    pub decls: &'a [&'a Decl<'a>],
    /// The clauses of this unit.
    pub clauses: &'a [&'a Clause<'a>],
}
34
/// A declaration, centred on a single atom (presumably the declared predicate — confirm
/// against the parser).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Decl<'a> {
    /// The declared atom.
    pub atom: &'a Atom<'a>,
    /// Descriptor atoms attached to the declaration.
    pub descr: &'a [&'a Atom<'a>],
    /// Optional list of bound declarations.
    pub bounds: Option<&'a [&'a BoundDecl<'a>]>,
    /// Optional constraints attached to the declaration.
    pub constraints: Option<&'a Constraints<'a>>,
}
43
/// A bound declaration: a sequence of base terms.
#[derive(Debug, PartialEq)]
pub struct BoundDecl<'a> {
    /// The terms making up this bound.
    pub base_terms: &'a [&'a BaseTerm<'a>],
}
48
/// Constraints attached to a declaration.
#[derive(Debug, Clone, PartialEq)]
pub struct Constraints<'a> {
    /// Atoms listed as consequences.
    pub consequences: &'a [&'a Atom<'a>],
    /// Groups of alternative atoms; each inner slice is one group.
    pub alternatives: &'a [&'a [&'a Atom<'a>]],
}
57
/// A clause: a head atom, the premise terms of its body, and transform statements.
#[derive(Debug)]
pub struct Clause<'a> {
    /// The atom this clause derives.
    pub head: &'a Atom<'a>,
    /// Body premises (possibly empty).
    pub premises: &'a [&'a Term<'a>],
    /// Transform statements applied by the clause (possibly empty).
    pub transform: &'a [&'a TransformStmt<'a>],
}
64
/// One statement of a clause transform: an application, optionally bound to a variable.
#[derive(Debug)]
pub struct TransformStmt<'a> {
    /// Name of the variable the result is bound to, if any.
    pub var: Option<&'a str>,
    /// The term being evaluated (presumably a function application — confirm).
    pub app: &'a BaseTerm<'a>,
}
70
/// A premise that can appear in the body of a [`Clause`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Term<'a> {
    /// A positive atom; displayed as the atom itself.
    Atom(&'a Atom<'a>),
    /// A negated atom; displayed with a leading `!`.
    NegAtom(&'a Atom<'a>),
    /// Equality between two base terms; displayed as `left = right`.
    Eq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
    /// Inequality between two base terms; displayed as `left != right`.
    Ineq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
}
79
80impl<'a> std::fmt::Display for Term<'a> {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            Term::Atom(atom) => write!(f, "{atom}"),
84            Term::NegAtom(atom) => write!(f, "!{atom}"),
85            Term::Eq(left, right) => write!(f, "{left} = {right}"),
86            Term::Ineq(left, right) => write!(f, "{left} != {right}"),
87        }
88    }
89}
90
91impl<'a> Term<'a> {
92    pub fn apply_subst<'b>(
93        &'a self,
94        bump: &'b Bump,
95        subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
96    ) -> &'b Term<'b> {
97        &*bump.alloc(match self {
98            Term::Atom(atom) => Term::Atom(atom.apply_subst(bump, subst)),
99            Term::NegAtom(atom) => Term::NegAtom(atom.apply_subst(bump, subst)),
100            Term::Eq(left, right) => Term::Eq(
101                left.apply_subst(bump, subst),
102                right.apply_subst(bump, subst),
103            ),
104            Term::Ineq(left, right) => Term::Ineq(
105                left.apply_subst(bump, subst),
106                right.apply_subst(bump, subst),
107            ),
108        })
109    }
110}
111
/// A term usable as an argument of an [`Atom`]: a constant, a variable, or a function
/// application.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseTerm<'a> {
    /// A constant value.
    Const(Const<'a>),
    /// A variable, identified by its name.
    Variable(&'a str),
    /// A function symbol applied to a list of argument terms.
    ApplyFn(FunctionSym<'a>, &'a [&'a BaseTerm<'a>]),
}
118
119impl<'a> BaseTerm<'a> {
120    pub fn apply_subst<'b>(
121        &'a self,
122        bump: &'b Bump,
123        subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
124    ) -> &'b BaseTerm<'b> {
125        match self {
126            BaseTerm::Const(_) => copy_base_term(bump, self),
127            BaseTerm::Variable(v) => subst
128                .get(v)
129                .map_or(copy_base_term(bump, self), |b| copy_base_term(bump, b)),
130            BaseTerm::ApplyFn(fun, args) => {
131                let args: Vec<&'b BaseTerm<'b>> = args
132                    .iter()
133                    .map(|arg| arg.apply_subst(bump, subst))
134                    .collect();
135                copy_base_term(bump, &BaseTerm::ApplyFn(*fun, &args))
136            }
137        }
138    }
139}
140
141impl<'a> std::fmt::Display for BaseTerm<'a> {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            BaseTerm::Const(c) => write!(f, "{c}"),
145            BaseTerm::Variable(v) => write!(f, "{v}"),
146            BaseTerm::ApplyFn(FunctionSym { name: n, .. }, args) => write!(
147                f,
148                "{n}({})",
149                args.iter()
150                    .map(|x| x.to_string())
151                    .collect::<Vec<_>>()
152                    .join(",")
153            ),
154        }
155    }
156}
/// A constant value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Const<'a> {
    /// A name constant, e.g. `/foo`.
    Name(&'a str),
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Number(i64),
    /// A 64-bit float.
    Float(f64),
    /// A string.
    String(&'a str),
    /// A byte string.
    Bytes(&'a [u8]),
    /// A list of constants.
    List(&'a [&'a Const<'a>]),
    /// A map. NOTE(review): `keys` and `values` are presumably paired by index — confirm.
    Map {
        keys: &'a [&'a Const<'a>],
        values: &'a [&'a Const<'a>],
    },
    /// A struct. NOTE(review): `fields` and `values` are presumably paired by index —
    /// confirm.
    Struct {
        fields: &'a [&'a str],
        values: &'a [&'a Const<'a>],
    },
}
175
// Marker impl that lets `BaseTerm`, `Atom` and `Term` derive `Eq`.
// NOTE(review): `Const::Float` holds an `f64`, and `f64` equality is not reflexive
// (`NaN != NaN`), so this impl does not strictly satisfy the `Eq` contract.
impl<'a> Eq for Const<'a> {}
177
178impl<'a> std::fmt::Display for Const<'a> {
179    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
180        match *self {
181            Const::Name(v) => write!(f, "{v}"),
182            Const::Bool(v) => write!(f, "{v}"),
183            Const::Number(v) => write!(f, "{v}"),
184            Const::Float(v) => write!(f, "{v}"),
185            Const::String(v) => write!(f, "{v}"),
186            Const::Bytes(v) => write!(f, "{:?}", v),
187            Const::List(v) => {
188                write!(
189                    f,
190                    "[{}]",
191                    v.iter()
192                        .map(|x| x.to_string())
193                        .collect::<Vec<_>>()
194                        .join(", ")
195                )
196            }
197            Const::Map { keys: _, values: _ } => write!(f, "{{...}}"),
198            Const::Struct {
199                fields: _,
200                values: _,
201            } => write!(f, "{{...}}"),
202        }
203    }
204}
205
/// A predicate symbol: a name plus an optional arity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PredicateSym<'a> {
    /// The predicate name, printed before an atom's argument list.
    pub name: &'a str,
    /// Number of arguments; `None` presumably means "unspecified" — confirm.
    pub arity: Option<u8>,
}
211
/// A function symbol used in [`BaseTerm::ApplyFn`]: a name plus an optional arity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionSym<'a> {
    /// The function name, printed before the argument list.
    pub name: &'a str,
    /// Number of arguments; `None` presumably means "unspecified" — confirm.
    pub arity: Option<u8>,
}
217
/// An atom: a predicate symbol applied to a list of base-term arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Atom<'a> {
    /// The predicate being applied.
    pub sym: PredicateSym<'a>,

    /// The arguments, in order.
    pub args: &'a [&'a BaseTerm<'a>],
}
224
225impl<'a> Atom<'a> {
226    pub fn matches(&'a self, query_args: &[&BaseTerm]) -> bool {
230        for (fact_arg, query_arg) in self.args.iter().zip(query_args.iter()) {
231            if let BaseTerm::Const(_) = query_arg {
232                if fact_arg != query_arg {
233                    return false;
234                }
235            }
236        }
237        true
238    }
239
240    pub fn apply_subst<'b>(
241        &'a self,
242        bump: &'b Bump,
243        subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
244    ) -> &'b Atom<'b> {
245        let args: Vec<&'b BaseTerm<'b>> = self
246            .args
247            .iter()
248            .map(|arg| arg.apply_subst(bump, subst))
249            .collect();
250        let args = &*bump.alloc_slice_copy(&args);
251        bump.alloc(Atom {
252            sym: copy_predicate_sym(bump, self.sym),
253            args,
254        })
255    }
256}
257
258impl<'a> std::fmt::Display for Atom<'a> {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        write!(f, "{}(", self.sym.name)?;
261        for arg in self.args {
262            write!(f, "{arg}")?;
263        }
264        write!(f, ")")
265    }
266}
267
268pub fn copy_predicate_sym<'dest>(bump: &'dest Bump, p: PredicateSym) -> PredicateSym<'dest> {
269    PredicateSym {
270        name: bump.alloc_str(p.name),
271        arity: p.arity,
272    }
273}
274
275pub fn copy_atom<'dest, 'src>(bump: &'dest Bump, atom: &'src Atom<'src>) -> &'dest Atom<'dest> {
277    let args: Vec<_> = atom
278        .args
279        .iter()
280        .map(|arg| copy_base_term(bump, arg))
281        .collect();
282    let args = &*bump.alloc_slice_copy(&args);
283    bump.alloc(Atom {
284        sym: copy_predicate_sym(bump, atom.sym),
285        args,
286    })
287}
288
289pub fn copy_base_term<'dest, 'src>(
291    bump: &'dest Bump,
292    b: &'src BaseTerm<'src>,
293) -> &'dest BaseTerm<'dest> {
294    match b {
295        BaseTerm::Const(c) =>
296        {
298            bump.alloc(BaseTerm::Const(*copy_const(bump, c)))
299        }
300        BaseTerm::Variable(s) => bump.alloc(BaseTerm::Variable(bump.alloc_str(s))),
301        BaseTerm::ApplyFn(fun, args) => {
302            let fun = FunctionSym {
303                name: bump.alloc_str(fun.name),
304                arity: fun.arity,
305            };
306            let args: Vec<_> = args.iter().map(|a| copy_base_term(bump, a)).collect();
307            let args = bump.alloc_slice_copy(&args);
308            bump.alloc(BaseTerm::ApplyFn(fun, args))
309        }
310    }
311}
312
313pub fn copy_const<'dest, 'src>(bump: &'dest Bump, c: &'src Const<'src>) -> &'dest Const<'dest> {
315    match c {
316        Const::Name(name) => {
317            let name = &*bump.alloc_str(name);
318            bump.alloc(Const::Name(name))
319        }
320        Const::Bool(b) => bump.alloc(Const::Bool(*b)),
321        Const::Number(n) => bump.alloc(Const::Number(*n)),
322        Const::Float(f) => bump.alloc(Const::Float(*f)),
323        Const::String(s) => {
324            let s = &*bump.alloc_str(s);
325            bump.alloc(Const::String(s))
326        }
327        Const::Bytes(b) => {
328            let b = &*bump.alloc_slice_copy(b);
329            bump.alloc(Const::Bytes(b))
330        }
331        Const::List(cs) => {
332            let cs: Vec<_> = cs.iter().map(|c| copy_const(bump, c)).collect();
333            let cs = &*bump.alloc_slice_copy(&cs);
334            bump.alloc(Const::List(cs))
335        }
336        Const::Map { keys, values } => {
337            let keys: Vec<_> = keys.iter().map(|c| copy_const(bump, c)).collect();
338            let keys = &*bump.alloc_slice_copy(&keys);
339
340            let values: Vec<_> = values.iter().map(|c| copy_const(bump, c)).collect();
341            let values = &*bump.alloc_slice_copy(&values);
342
343            bump.alloc(Const::Map { keys, values })
344        }
345        Const::Struct { fields, values } => {
346            let fields: Vec<_> = fields.iter().map(|s| &*bump.alloc_str(s)).collect();
347            let fields = &*bump.alloc_slice_copy(&fields);
348
349            let values: Vec<_> = values.iter().map(|c| copy_const(bump, c)).collect();
350            let values = &*bump.alloc_slice_copy(&values);
351
352            bump.alloc(Const::Struct { fields, values })
353        }
354    }
355}
356
357pub fn copy_transform<'dest, 'src>(
358    bump: &'dest Bump,
359    stmt: &'src TransformStmt<'src>,
360) -> &'dest TransformStmt<'dest> {
361    let TransformStmt { var, app } = stmt;
362    let var = var.map(|s| &*bump.alloc_str(s));
363    let app = copy_base_term(bump, app);
364    bump.alloc(TransformStmt { var, app })
365}
366
367pub fn copy_clause<'dest, 'src>(
368    bump: &'dest Bump,
369    clause: &'src Clause<'src>,
370) -> &'dest Clause<'dest> {
371    let Clause {
372        head,
373        premises,
374        transform,
375    } = clause;
376    let premises: Vec<_> = premises.iter().map(|x| copy_term(bump, x)).collect();
377    let transform: Vec<_> = transform.iter().map(|x| copy_transform(bump, x)).collect();
378    bump.alloc(Clause {
379        head: copy_atom(bump, head),
380        premises: &*bump.alloc_slice_copy(&premises),
381        transform: &*bump.alloc_slice_copy(&transform),
382    })
383}
384
385fn copy_term<'dest, 'src>(bump: &'dest Bump, term: &'src Term<'src>) -> &'dest Term<'dest> {
386    match term {
387        Term::Atom(atom) => {
388            let atom = copy_atom(bump, atom);
389            bump.alloc(Term::Atom(atom))
390        }
391        Term::NegAtom(atom) => {
392            let atom = copy_atom(bump, atom);
393            bump.alloc(Term::NegAtom(atom))
394        }
395        Term::Eq(left, right) => {
396            let left = copy_base_term(bump, left);
397            let right = copy_base_term(bump, right);
398            bump.alloc(Term::Eq(left, right))
399        }
400        Term::Ineq(left, right) => {
401            let left = copy_base_term(bump, left);
402            let right = copy_base_term(bump, right);
403            bump.alloc(Term::Ineq(left, right))
404        }
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use bumpalo::Bump;
    use googletest::prelude::*;
    use std::collections::HashMap;

    /// `copy_atom` must produce an equal atom that lives in the destination arena.
    #[test]
    fn copying_atom_works() {
        let src = Bump::new();
        let foo = &*src.alloc(BaseTerm::Const(Const::Name("/foo")));
        let args = &*src.alloc_slice_copy(&[foo]);
        let original = &*src.alloc(Atom {
            sym: PredicateSym {
                name: "bar",
                arity: Some(1),
            },
            args,
        });

        // Copy into a separate arena: the copy must compare equal and print the same,
        // but must not share storage with the source.
        let dest = Bump::new();
        let copied = copy_atom(&dest, original);
        assert_that!(*copied, eq(*original));
        assert_that!(*copied, displays_as(eq("bar(/foo)")));
        assert_that!(std::ptr::eq(copied.sym.name, original.sym.name), eq(false));
    }

    /// Bound variables are replaced; unbound ones are left as variables.
    #[test]
    fn apply_subst_replaces_bound_variables() {
        let bump = Bump::new();
        let x = BaseTerm::Variable("X");
        let y = BaseTerm::Variable("Y");
        let foo = BaseTerm::Const(Const::Name("/foo"));
        let atom = Atom {
            sym: PredicateSym {
                name: "bar",
                arity: Some(2),
            },
            args: &[&x, &y],
        };
        let subst = HashMap::from([("X", &foo)]);

        let result = atom.apply_subst(&bump, &subst);

        let expected = Atom {
            sym: atom.sym,
            args: &[&foo, &y],
        };
        assert_that!(*result, eq(expected));
    }

    #[test]
    fn atom_display_works() {
        let bar = BaseTerm::Const(Const::Name("/bar"));
        assert_that!(bar, displays_as(eq("/bar")));

        let atom = Atom {
            sym: PredicateSym {
                name: "foo",
                arity: Some(1),
            },
            args: &[&bar],
        };
        assert_that!(atom, displays_as(eq("foo(/bar)")));

        let tests = vec![
            (Term::Atom(&atom), "foo(/bar)"),
            (Term::NegAtom(&atom), "!foo(/bar)"),
            (Term::Eq(&bar, &bar), "/bar = /bar"),
            (Term::Ineq(&bar, &bar), "/bar != /bar"),
        ];
        for (term, s) in tests {
            assert_that!(term, displays_as(eq(s)));
        }
    }
}