1use std::collections::HashMap;
16
17use bumpalo::Bump;
18
19#[derive(Debug)]
30pub struct Unit<'a> {
31 pub decls: &'a [&'a Decl<'a>],
32 pub clauses: &'a [&'a Clause<'a>],
33}
34
35#[derive(Debug, Clone, Copy, PartialEq)]
37pub struct Decl<'a> {
38 pub atom: &'a Atom<'a>,
39 pub descr: &'a [&'a Atom<'a>],
40 pub bounds: Option<&'a [&'a BoundDecl<'a>]>,
41 pub constraints: Option<&'a Constraints<'a>>,
42}
43
44#[derive(Debug, PartialEq)]
45pub struct BoundDecl<'a> {
46 pub base_terms: &'a [&'a BaseTerm<'a>],
47}
48
49#[derive(Debug, Clone, PartialEq)]
51pub struct Constraints<'a> {
52 pub consequences: &'a [&'a Atom<'a>],
54 pub alternatives: &'a [&'a [&'a Atom<'a>]],
56}
57
58#[derive(Debug)]
59pub struct Clause<'a> {
60 pub head: &'a Atom<'a>,
61 pub premises: &'a [&'a Term<'a>],
62 pub transform: &'a [&'a TransformStmt<'a>],
63}
64
65#[derive(Debug)]
66pub struct TransformStmt<'a> {
67 pub var: Option<&'a str>,
68 pub app: &'a BaseTerm<'a>,
69}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73pub enum Term<'a> {
74 Atom(&'a Atom<'a>),
75 NegAtom(&'a Atom<'a>),
76 Eq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
77 Ineq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
78}
79
80impl<'a> std::fmt::Display for Term<'a> {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Term::Atom(atom) => write!(f, "{atom}"),
84 Term::NegAtom(atom) => write!(f, "!{atom}"),
85 Term::Eq(left, right) => write!(f, "{left} = {right}"),
86 Term::Ineq(left, right) => write!(f, "{left} != {right}"),
87 }
88 }
89}
90
91impl<'a> Term<'a> {
92 pub fn apply_subst<'b>(
93 &'a self,
94 bump: &'b Bump,
95 subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
96 ) -> &'b Term<'b> {
97 &*bump.alloc(match self {
98 Term::Atom(atom) => Term::Atom(atom.apply_subst(bump, subst)),
99 Term::NegAtom(atom) => Term::NegAtom(atom.apply_subst(bump, subst)),
100 Term::Eq(left, right) => Term::Eq(
101 left.apply_subst(bump, subst),
102 right.apply_subst(bump, subst),
103 ),
104 Term::Ineq(left, right) => Term::Ineq(
105 left.apply_subst(bump, subst),
106 right.apply_subst(bump, subst),
107 ),
108 })
109 }
110}
111
112#[derive(Debug, Clone, Copy, PartialEq, Eq)]
113pub enum BaseTerm<'a> {
114 Const(Const<'a>),
115 Variable(&'a str),
116 ApplyFn(FunctionSym<'a>, &'a [&'a BaseTerm<'a>]),
117}
118
119impl<'a> BaseTerm<'a> {
120 pub fn apply_subst<'b>(
121 &'a self,
122 bump: &'b Bump,
123 subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
124 ) -> &'b BaseTerm<'b> {
125 match self {
126 BaseTerm::Const(_) => copy_base_term(bump, self),
127 BaseTerm::Variable(v) => subst
128 .get(v)
129 .map_or(copy_base_term(bump, self), |b| copy_base_term(bump, b)),
130 BaseTerm::ApplyFn(fun, args) => {
131 let args: Vec<&'b BaseTerm<'b>> = args
132 .iter()
133 .map(|arg| arg.apply_subst(bump, subst))
134 .collect();
135 copy_base_term(bump, &BaseTerm::ApplyFn(*fun, &args))
136 }
137 }
138 }
139}
140
141impl<'a> std::fmt::Display for BaseTerm<'a> {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 BaseTerm::Const(c) => write!(f, "{c}"),
145 BaseTerm::Variable(v) => write!(f, "{v}"),
146 BaseTerm::ApplyFn(FunctionSym { name: n, .. }, args) => write!(
147 f,
148 "{n}({})",
149 args.iter()
150 .map(|x| x.to_string())
151 .collect::<Vec<_>>()
152 .join(",")
153 ),
154 }
155 }
156}
/// A constant value.
///
/// Note: `Float` carries an `f64`, so the derived `PartialEq` treats `NaN` as
/// unequal to itself even though `Eq` is asserted below.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Const<'a> {
    /// A name constant such as `/foo`.
    Name(&'a str),
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Number(i64),
    /// A 64-bit float.
    Float(f64),
    /// A string (displayed without quotes).
    String(&'a str),
    /// A byte string (displayed in `Debug` form, e.g. `[1, 2]`).
    Bytes(&'a [u8]),
    /// A list of constants.
    List(&'a [&'a Const<'a>]),
    /// A map given as parallel key and value slices.
    Map {
        keys: &'a [&'a Const<'a>],
        values: &'a [&'a Const<'a>],
    },
    /// A struct given as parallel field-name and value slices.
    Struct {
        fields: &'a [&'a str],
        values: &'a [&'a Const<'a>],
    },
}

impl<'c> Eq for Const<'c> {}

impl<'c> std::fmt::Display for Const<'c> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Const::Name(name) => write!(f, "{}", name),
            Const::Bool(flag) => write!(f, "{}", flag),
            Const::Number(num) => write!(f, "{}", num),
            Const::Float(num) => write!(f, "{}", num),
            Const::String(text) => write!(f, "{}", text),
            Const::Bytes(bytes) => write!(f, "{:?}", bytes),
            Const::List(items) => {
                // Elements are separated by ", " inside square brackets.
                f.write_str("[")?;
                for (idx, item) in items.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", item)?;
                }
                f.write_str("]")
            }
            // Maps and structs are only summarised, not printed in full.
            Const::Map { .. } | Const::Struct { .. } => f.write_str("{...}"),
        }
    }
}
205
/// A predicate symbol: a name plus an optional arity.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct PredicateSym<'a> {
    /// The predicate's name.
    pub name: &'a str,
    /// The arity, when one is recorded.
    pub arity: Option<u8>,
}
211
/// A function symbol: a name plus an optional arity.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FunctionSym<'a> {
    /// The function's name.
    pub name: &'a str,
    /// The arity, when one is recorded.
    pub arity: Option<u8>,
}
217
218#[derive(Debug, Clone, Copy, PartialEq, Eq)]
219pub struct Atom<'a> {
220 pub sym: PredicateSym<'a>,
221
222 pub args: &'a [&'a BaseTerm<'a>],
223}
224
225impl<'a> Atom<'a> {
226 pub fn matches(&'a self, query_args: &[&BaseTerm]) -> bool {
230 for (fact_arg, query_arg) in self.args.iter().zip(query_args.iter()) {
231 if let BaseTerm::Const(_) = query_arg {
232 if fact_arg != query_arg {
233 return false;
234 }
235 }
236 }
237 true
238 }
239
240 pub fn apply_subst<'b>(
241 &'a self,
242 bump: &'b Bump,
243 subst: &HashMap<&'a str, &'a BaseTerm<'a>>,
244 ) -> &'b Atom<'b> {
245 let args: Vec<&'b BaseTerm<'b>> = self
246 .args
247 .iter()
248 .map(|arg| arg.apply_subst(bump, subst))
249 .collect();
250 let args = &*bump.alloc_slice_copy(&args);
251 bump.alloc(Atom {
252 sym: copy_predicate_sym(bump, self.sym),
253 args,
254 })
255 }
256}
257
258impl<'a> std::fmt::Display for Atom<'a> {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 write!(f, "{}(", self.sym.name)?;
261 for arg in self.args {
262 write!(f, "{arg}")?;
263 }
264 write!(f, ")")
265 }
266}
267
268pub fn copy_predicate_sym<'dest>(bump: &'dest Bump, p: PredicateSym) -> PredicateSym<'dest> {
269 PredicateSym {
270 name: bump.alloc_str(p.name),
271 arity: p.arity,
272 }
273}
274
275pub fn copy_atom<'dest, 'src>(bump: &'dest Bump, atom: &'src Atom<'src>) -> &'dest Atom<'dest> {
277 let args: Vec<_> = atom
278 .args
279 .iter()
280 .map(|arg| copy_base_term(bump, arg))
281 .collect();
282 let args = &*bump.alloc_slice_copy(&args);
283 bump.alloc(Atom {
284 sym: copy_predicate_sym(bump, atom.sym),
285 args,
286 })
287}
288
289pub fn copy_base_term<'dest, 'src>(
291 bump: &'dest Bump,
292 b: &'src BaseTerm<'src>,
293) -> &'dest BaseTerm<'dest> {
294 match b {
295 BaseTerm::Const(c) =>
296 {
298 bump.alloc(BaseTerm::Const(*copy_const(bump, c)))
299 }
300 BaseTerm::Variable(s) => bump.alloc(BaseTerm::Variable(bump.alloc_str(s))),
301 BaseTerm::ApplyFn(fun, args) => {
302 let fun = FunctionSym {
303 name: bump.alloc_str(fun.name),
304 arity: fun.arity,
305 };
306 let args: Vec<_> = args.iter().map(|a| copy_base_term(bump, a)).collect();
307 let args = bump.alloc_slice_copy(&args);
308 bump.alloc(BaseTerm::ApplyFn(fun, args))
309 }
310 }
311}
312
313pub fn copy_const<'dest, 'src>(bump: &'dest Bump, c: &'src Const<'src>) -> &'dest Const<'dest> {
315 match c {
316 Const::Name(name) => {
317 let name = &*bump.alloc_str(name);
318 bump.alloc(Const::Name(name))
319 }
320 Const::Bool(b) => bump.alloc(Const::Bool(*b)),
321 Const::Number(n) => bump.alloc(Const::Number(*n)),
322 Const::Float(f) => bump.alloc(Const::Float(*f)),
323 Const::String(s) => {
324 let s = &*bump.alloc_str(s);
325 bump.alloc(Const::String(s))
326 }
327 Const::Bytes(b) => {
328 let b = &*bump.alloc_slice_copy(b);
329 bump.alloc(Const::Bytes(b))
330 }
331 Const::List(cs) => {
332 let cs: Vec<_> = cs.iter().map(|c| copy_const(bump, c)).collect();
333 let cs = &*bump.alloc_slice_copy(&cs);
334 bump.alloc(Const::List(cs))
335 }
336 Const::Map { keys, values } => {
337 let keys: Vec<_> = keys.iter().map(|c| copy_const(bump, c)).collect();
338 let keys = &*bump.alloc_slice_copy(&keys);
339
340 let values: Vec<_> = values.iter().map(|c| copy_const(bump, c)).collect();
341 let values = &*bump.alloc_slice_copy(&values);
342
343 bump.alloc(Const::Map { keys, values })
344 }
345 Const::Struct { fields, values } => {
346 let fields: Vec<_> = fields.iter().map(|s| &*bump.alloc_str(s)).collect();
347 let fields = &*bump.alloc_slice_copy(&fields);
348
349 let values: Vec<_> = values.iter().map(|c| copy_const(bump, c)).collect();
350 let values = &*bump.alloc_slice_copy(&values);
351
352 bump.alloc(Const::Struct { fields, values })
353 }
354 }
355}
356
357pub fn copy_transform<'dest, 'src>(
358 bump: &'dest Bump,
359 stmt: &'src TransformStmt<'src>,
360) -> &'dest TransformStmt<'dest> {
361 let TransformStmt { var, app } = stmt;
362 let var = var.map(|s| &*bump.alloc_str(s));
363 let app = copy_base_term(bump, app);
364 bump.alloc(TransformStmt { var, app })
365}
366
367pub fn copy_clause<'dest, 'src>(
368 bump: &'dest Bump,
369 clause: &'src Clause<'src>,
370) -> &'dest Clause<'dest> {
371 let Clause {
372 head,
373 premises,
374 transform,
375 } = clause;
376 let premises: Vec<_> = premises.iter().map(|x| copy_term(bump, x)).collect();
377 let transform: Vec<_> = transform.iter().map(|x| copy_transform(bump, x)).collect();
378 bump.alloc(Clause {
379 head: copy_atom(bump, head),
380 premises: &*bump.alloc_slice_copy(&premises),
381 transform: &*bump.alloc_slice_copy(&transform),
382 })
383}
384
385fn copy_term<'dest, 'src>(bump: &'dest Bump, term: &'src Term<'src>) -> &'dest Term<'dest> {
386 match term {
387 Term::Atom(atom) => {
388 let atom = copy_atom(bump, atom);
389 bump.alloc(Term::Atom(atom))
390 }
391 Term::NegAtom(atom) => {
392 let atom = copy_atom(bump, atom);
393 bump.alloc(Term::NegAtom(atom))
394 }
395 Term::Eq(left, right) => {
396 let left = copy_base_term(bump, left);
397 let right = copy_base_term(bump, right);
398 bump.alloc(Term::Eq(left, right))
399 }
400 Term::Ineq(left, right) => {
401 let left = copy_base_term(bump, left);
402 let right = copy_base_term(bump, right);
403 bump.alloc(Term::Ineq(left, right))
404 }
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use bumpalo::Bump;
    use googletest::prelude::*;
    use std::collections::HashMap;

    /// `copy_atom` must produce an equal atom whose strings live in the destination arena.
    #[test]
    fn copying_atom_works() {
        let src = Bump::new();
        let foo = &*src.alloc(BaseTerm::Const(Const::Name("/foo")));
        let args = &*src.alloc_slice_copy(&[foo]);
        let atom = &*src.alloc(Atom {
            sym: PredicateSym {
                name: "bar",
                arity: Some(1),
            },
            args,
        });

        // Previously this test never called `copy_atom` at all.
        let dest = Bump::new();
        let copied = copy_atom(&dest, atom);
        assert_that!("bar(/foo)", eq(copied.to_string()));
        assert_that!(*copied, eq(*atom));
        // The name must have been re-allocated, not aliased.
        assert!(!std::ptr::eq(copied.sym.name, atom.sym.name));
    }

    /// `Atom::apply_subst` replaces bound variables and leaves unbound ones alone.
    #[test]
    fn apply_subst_replaces_variables() {
        let bump = Bump::new();
        let x = &*bump.alloc(BaseTerm::Variable("X"));
        let foo = &*bump.alloc(BaseTerm::Const(Const::Name("/foo")));
        let args = &*bump.alloc_slice_copy(&[x]);
        let atom = &*bump.alloc(Atom {
            sym: PredicateSym {
                name: "bar",
                arity: Some(1),
            },
            args,
        });

        let mut subst = HashMap::new();
        subst.insert("X", foo);
        let substituted = atom.apply_subst(&bump, &subst);
        assert_that!("bar(/foo)", eq(substituted.to_string()));

        let empty = HashMap::new();
        let unchanged = atom.apply_subst(&bump, &empty);
        assert_that!("bar(X)", eq(unchanged.to_string()));
    }

    #[test]
    fn atom_display_works() {
        let bar = BaseTerm::Const(Const::Name("/bar"));
        assert_that!(bar, displays_as(eq("/bar")));

        let atom = Atom {
            sym: PredicateSym {
                name: "foo",
                arity: Some(1),
            },
            args: &[&bar],
        };
        assert_that!(atom, displays_as(eq("foo(/bar)")));

        let tests = vec![
            (Term::Atom(&atom), "foo(/bar)"),
            (Term::NegAtom(&atom), "!foo(/bar)"),
            (Term::Eq(&bar, &bar), "/bar = /bar"),
            (Term::Ineq(&bar, &bar), "/bar != /bar"),
        ];
        for (term, s) in tests {
            assert_that!(term, displays_as(eq(s)));
        }
    }
}