Skip to main content

mangle_ast/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod pretty;
16pub use pretty::PrettyPrint;
17
18use bumpalo::Bump;
19use fxhash::FxHashMap;
20use std::cell::RefCell;
21use std::mem;
22use std::num::NonZeroUsize;
23use std::sync::{Arc, Mutex};
24
// Byte capacity requested for the string buffer of an interner created by
// `Interner::new_global_interner`.
const INTERNER_DEFAULT_CAPACITY: NonZeroUsize = NonZeroUsize::new(4096).unwrap();
26
/// A simple way to intern strings and refer to them using u32 indices.
/// This is following the blog post here:
/// <https://matklad.github.io/2020/03/22/fast-simple-rust-interner.html>
///
/// The `'static` lifetimes below stand in for "as long as this interner
/// lives": the string slices point into `buf` and `full`, which this struct
/// owns.
pub struct Interner {
    // Maps every interned string to its index.
    map: FxHashMap<&'static str, u32>,
    // Interned strings in index order: `vec[id]` is the string with index `id`.
    vec: Vec<&'static str>,
    // Buffer that newly interned strings are appended to.
    buf: String,
    // Earlier buffers that ran out of room; kept so the slices handed out
    // from them are not freed.
    full: Vec<String>,
}
36
37impl Interner {
38    fn new_global_interner() -> Arc<Mutex<Interner>> {
39        Arc::new(Mutex::new(Interner::with_capacity(
40            INTERNER_DEFAULT_CAPACITY.into(),
41        )))
42    }
43
44    pub fn with_capacity(cap: usize) -> Interner {
45        let cap = cap.next_power_of_two();
46        let mut interner = Interner {
47            map: FxHashMap::default(),
48            vec: Vec::new(),
49            buf: String::with_capacity(cap),
50            full: Vec::new(),
51        };
52        interner.intern("_");
53        interner
54    }
55
56    /// Interns the argument and returns a unique index.
57    pub fn intern(&mut self, name: &str) -> u32 {
58        if let Some(&id) = self.map.get(name) {
59            return id;
60        }
61        let name = unsafe { self.alloc(name) };
62        let id = self.map.len() as u32;
63        self.map.insert(name, id);
64        self.vec.push(name);
65        debug_assert!(self.lookup(id).expect("expected to find name") == name);
66        debug_assert!(self.intern(name) == id);
67        id
68    }
69
70    pub fn lookup_name_index(&self, name: &str) -> Option<u32> {
71        self.map.get(name).copied()
72    }
73
74    fn lookup(&self, id: u32) -> Option<&'static str> {
75        self.vec.get(id as usize).copied()
76    }
77
78    /// Copies a str into our internal buffer and returns a reference.
79    ///
80    /// # Safety
81    ///
82    /// The reference is only valid for the lifetime of this object.
83    unsafe fn alloc(&mut self, name: &str) -> &'static str {
84        let cap = self.buf.capacity();
85        if cap < self.buf.len() + name.len() {
86            let new_cap = (cap.max(name.len()) + 1).next_power_of_two();
87            let new_buf = String::with_capacity(new_cap);
88            let old_buf = mem::replace(&mut self.buf, new_buf);
89            self.full.push(old_buf);
90        }
91        let interned = {
92            let start = self.buf.len();
93            self.buf.push_str(name);
94            &self.buf[start..]
95        };
96        // SAFETY: This is reference to our internal buffer, which will
97        // be valid as long as this object is valid.
98        unsafe { &*(interned as *const str) }
99    }
100}
101
/// Holds the syntax nodes built through its methods, together with the
/// symbol tables and the interner those nodes refer to by index.
pub struct Arena {
    // Bump allocator that backs every node, slice and string handed out.
    pub(crate) bump: Bump,
    // String interner behind an `Arc<Mutex<..>>`.
    pub(crate) interner: Arc<Mutex<Interner>>,
    // Predicate symbols; a `PredicateIndex` is a position in this vector.
    pub(crate) predicate_syms: RefCell<Vec<PredicateSym>>,
    // Function symbols; a `FunctionIndex` is a position in this vector.
    pub(crate) function_syms: RefCell<Vec<FunctionSym>>,
}
108
109impl<'arena> Arena {
110    pub fn new(interner: Arc<Mutex<Interner>>) -> Self {
111        Self {
112            bump: Bump::new(),
113            interner,
114            predicate_syms: RefCell::new(Vec::new()),
115            function_syms: RefCell::new(Vec::new()),
116        }
117    }
118
119    pub fn new_with_global_interner() -> Self {
120        Self::new(Interner::new_global_interner())
121    }
122
123    pub fn intern(&'arena self, name: &str) -> u32 {
124        self.interner.lock().unwrap().intern(name)
125    }
126
127    // Returns index for name, if it exists.
128    pub fn lookup_opt(&'arena self, name: &str) -> Option<u32> {
129        self.interner.lock().unwrap().map.get(name).copied()
130    }
131
132    pub fn name(&'arena self, name: &str) -> Const<'arena> {
133        Const::Name(self.intern(name))
134    }
135
136    pub fn variable(&'arena self, name: &str) -> &'arena BaseTerm<'arena> {
137        self.alloc(BaseTerm::Variable(self.variable_sym(name)))
138    }
139
140    pub fn const_(&'arena self, c: Const<'arena>) -> &'arena BaseTerm<'arena> {
141        self.alloc(BaseTerm::Const(c))
142    }
143
144    pub fn atom(
145        &'arena self,
146        p: PredicateIndex,
147        args: &[&'arena BaseTerm<'arena>],
148    ) -> &'arena Atom<'arena> {
149        self.alloc(Atom {
150            sym: p,
151            args: self.alloc_slice_copy(args),
152        })
153    }
154
155    pub fn apply_fn(
156        &'arena self,
157        fun: FunctionIndex,
158        args: &[&'arena BaseTerm<'arena>],
159    ) -> &'arena BaseTerm<'arena> {
160        //let fun = &self.function_syms.borrow()[fun.0];
161        //let name = self.lookup_name(fun.name);
162        let args = self.alloc_slice_copy(args);
163        self.alloc(BaseTerm::ApplyFn(fun, args))
164    }
165
166    pub fn alloc<T>(&self, x: T) -> &mut T {
167        self.bump.alloc(x)
168    }
169
170    pub fn alloc_slice_copy<T: Copy>(&self, x: &[T]) -> &[T] {
171        self.bump.alloc_slice_copy(x)
172    }
173
174    pub fn alloc_str(&'arena self, s: &str) -> &'arena str {
175        self.bump.alloc_str(s)
176    }
177
178    pub fn new_query(&'arena self, p: PredicateIndex) -> Atom<'arena> {
179        let arity = self.predicate_syms.borrow()[p.0].arity;
180        let args: Vec<_> = match arity {
181            Some(arity) => (0..arity).map(|_i| &ANY_VAR_TERM).collect(),
182            None => Vec::new(),
183        };
184
185        let args = self.alloc_slice_copy(&args);
186        Atom { sym: p, args }
187    }
188
189    /// Given a name index, returns the name.
190    pub fn lookup_name(&self, name_index: u32) -> Option<&'static str> {
191        self.interner.lock().unwrap().lookup(name_index)
192    }
193
194    /// Given a name, returns the index of the name if it exists in the interner.
195    pub fn lookup_name_index(&self, name: &str) -> Option<u32> {
196        self.interner.lock().unwrap().lookup_name_index(name)
197    }
198
199    /// Given predicate index, returns name of predicate symbol.
200    pub fn predicate_name(&self, predicate_index: PredicateIndex) -> Option<&'static str> {
201        let syms = self.predicate_syms.borrow();
202        let i = predicate_index.0;
203        if i >= syms.len() {
204            return None;
205        }
206        let n = syms[i].name;
207        self.interner.lock().unwrap().lookup(n)
208    }
209
210    /// Given predicate index, returns arity of predicate symbol.
211    pub fn predicate_arity(&self, predicate_index: PredicateIndex) -> Option<u8> {
212        self.predicate_syms
213            .borrow()
214            .get(predicate_index.0)
215            .and_then(|s| s.arity)
216    }
217
218    /// Given function index, returns name of function symbol.
219    pub fn function_name(&self, function_index: FunctionIndex) -> Option<&'static str> {
220        let syms = self.function_syms.borrow();
221        let i = function_index.0;
222        if i >= syms.len() {
223            return None;
224        }
225        let n = syms[i].name;
226        self.interner.lock().unwrap().lookup(n)
227    }
228
229    /// Returns index for this predicate symbol.
230    pub fn lookup_predicate_sym(&'arena self, predicate_name: u32) -> Option<PredicateIndex> {
231        for (index, p) in self.predicate_syms.borrow().iter().enumerate() {
232            if p.name == predicate_name {
233                return Some(PredicateIndex(index));
234            }
235        }
236        None
237    }
238
239    /// Constructs a new variable symbol.
240    pub fn variable_sym(&'arena self, name: &str) -> VariableIndex {
241        let n = self.interner.lock().unwrap().intern(name);
242        VariableIndex(n)
243    }
244
245    // Looks up a function sym, copies if it doesn't exist.
246    pub fn function_sym(&'arena self, name: &str, arity: Option<u8>) -> FunctionIndex {
247        let n = self.interner.lock().unwrap().intern(name);
248        let f = FunctionSym { name: n, arity };
249        for (index, f) in self.function_syms.borrow().iter().enumerate() {
250            if f.name == n {
251                return FunctionIndex(index);
252            }
253        }
254
255        self.function_syms.borrow_mut().push(f);
256        FunctionIndex(self.function_syms.borrow().len() - 1)
257    }
258
259    // Looks up or creates a predicate_sym, copies if it doesn't exist.
260    pub fn predicate_sym(&'arena self, name: &str, arity: Option<u8>) -> PredicateIndex {
261        let n = self.interner.lock().unwrap().intern(name);
262        let p = PredicateSym { name: n, arity };
263        for (index, p) in self.predicate_syms.borrow().iter().enumerate() {
264            if p.name == n {
265                return PredicateIndex(index);
266            }
267        }
268
269        self.predicate_syms.borrow_mut().push(p);
270        PredicateIndex(self.predicate_syms.borrow().len() - 1)
271    }
272
273    pub fn copy_function_sym<'src>(
274        &'arena self,
275        src: &'src Arena,
276        f: FunctionIndex,
277    ) -> FunctionIndex {
278        let function_sym = &src.function_syms.borrow()[f.0];
279        let name = src
280            .lookup_name(function_sym.name)
281            .expect("expected to find name");
282        self.function_sym(name, function_sym.arity)
283    }
284
285    pub fn copy_predicate_sym<'src>(
286        &'arena self,
287        src: &'src Arena,
288        p: PredicateIndex, //predicate_sym: &'src PredicateSym,
289    ) -> PredicateIndex {
290        let predicate_sym = &src.predicate_syms.borrow()[p.0];
291        let name = src
292            .lookup_name(predicate_sym.name)
293            .expect("expected to find name");
294        self.predicate_sym(name, predicate_sym.arity)
295    }
296
297    // Copies BaseTerm from another [`Arena`].
298    pub fn copy_atom<'src>(
299        &'arena self,
300        src: &'src Arena,
301        atom: &'src Atom<'src>,
302    ) -> &'arena Atom<'arena> {
303        let args: Vec<_> = atom
304            .args
305            .iter()
306            .map(|arg| self.copy_base_term(src, arg))
307            .collect();
308        let args = self.alloc_slice_copy(&args);
309        // TODO: have to look up predicate syms from !source!
310        self.alloc(Atom {
311            sym: self.copy_predicate_sym(src, atom.sym),
312            args,
313        })
314    }
315
316    // Copies BaseTerm to another Arena
317    pub fn copy_base_term<'src>(
318        &'arena self,
319        src: &'src Arena,
320        b: &'src BaseTerm<'src>,
321    ) -> &'arena BaseTerm<'arena> {
322        match b {
323            BaseTerm::Const(c) =>
324            // Should it be a reference?
325            {
326                self.alloc(BaseTerm::Const(*self.copy_const(src, c)))
327            }
328            BaseTerm::Variable(v) => {
329                let name = src
330                    .interner
331                    .lock()
332                    .unwrap()
333                    .lookup(v.0)
334                    .expect("expected to find name")
335                    .to_string();
336                let v = self.variable_sym(&name);
337                self.alloc(BaseTerm::Variable(v))
338            }
339            BaseTerm::ApplyFn(fun, args) => {
340                let fun = self.copy_function_sym(src, *fun);
341                //let fun = FunctionSym { name: self.alloc_str(fun.name), arity: fun.arity };
342                let args: Vec<_> = args.iter().map(|a| self.copy_base_term(src, a)).collect();
343                let args = self.alloc_slice_copy(&args);
344                self.alloc(BaseTerm::ApplyFn(fun, args))
345            }
346        }
347    }
348
349    // Copies Const to another Arena
350    pub fn copy_const<'src>(
351        &'arena self,
352        src: &'src Arena,
353        c: &'src Const<'src>,
354    ) -> &'arena Const<'arena> {
355        match c {
356            Const::Name(name) => {
357                let name = src
358                    .interner
359                    .lock()
360                    .unwrap()
361                    .lookup(*name)
362                    .expect("expected to find name");
363                let name = self.interner.lock().unwrap().intern(name);
364                self.alloc(Const::Name(name))
365            }
366            Const::Bool(b) => self.alloc(Const::Bool(*b)),
367            Const::Number(n) => self.alloc(Const::Number(*n)),
368            Const::Float(f) => self.alloc(Const::Float(*f)),
369            Const::Time(t) => self.alloc(Const::Time(*t)),
370            Const::Duration(d) => self.alloc(Const::Duration(*d)),
371            Const::String(s) => {
372                let s = self.alloc_str(s);
373                self.alloc(Const::String(s))
374            }
375            Const::Bytes(b) => {
376                let b = self.alloc_slice_copy(b);
377                self.alloc(Const::Bytes(b))
378            }
379            Const::List(cs) => {
380                let cs: Vec<_> = cs.iter().map(|c| self.copy_const(src, c)).collect();
381                let cs = self.alloc_slice_copy(&cs);
382                self.alloc(Const::List(cs))
383            }
384            Const::Map { keys, values } => {
385                let keys: Vec<_> = keys.iter().map(|c| self.copy_const(src, c)).collect();
386                let keys = self.alloc_slice_copy(&keys);
387
388                let values: Vec<_> = values.iter().map(|c| self.copy_const(src, c)).collect();
389                let values = self.alloc_slice_copy(&values);
390
391                self.alloc(Const::Map { keys, values })
392            }
393            Const::Struct { fields, values } => {
394                let fields: Vec<_> = fields.iter().map(|s| self.alloc_str(s)).collect();
395                let fields = self.alloc_slice_copy(&fields);
396
397                let values: Vec<_> = values.iter().map(|c| self.copy_const(src, c)).collect();
398                let values = self.alloc_slice_copy(&values);
399
400                self.alloc(Const::Struct { fields, values })
401            }
402        }
403    }
404
405    pub fn copy_transform<'src>(
406        &'arena self,
407        src: &'src Arena,
408        stmt: &'src TransformStmt<'src>,
409    ) -> &'arena TransformStmt<'arena> {
410        let TransformStmt { var, app } = stmt;
411        let var = var.map(|s| self.alloc_str(s));
412        let app = self.copy_base_term(src, app);
413        self.alloc(TransformStmt { var, app })
414    }
415
416    pub fn copy_clause<'src>(
417        &'arena self,
418        src: &'src Arena,
419        src_clause: &'src Clause<'src>,
420    ) -> &'arena Clause<'arena> {
421        let Clause {
422            head,
423            head_time,
424            premises,
425            transform,
426        } = src_clause;
427        let premises: Vec<_> = premises.iter().map(|x| self.copy_term(src, x)).collect();
428        let transform: Vec<_> = transform
429            .iter()
430            .map(|x| self.copy_transform(src, x))
431            .collect();
432        self.alloc(Clause {
433            head: self.copy_atom(src, head),
434            head_time: *head_time,
435            premises: self.alloc_slice_copy(&premises),
436            transform: self.alloc_slice_copy(&transform),
437        })
438    }
439
440    fn copy_term<'src>(
441        &'arena self,
442        src: &'src Arena,
443        term: &'src Term<'src>,
444    ) -> &'arena Term<'arena> {
445        match term {
446            Term::Atom(atom) => {
447                let atom = self.copy_atom(src, atom);
448                self.alloc(Term::Atom(atom))
449            }
450            Term::NegAtom(atom) => {
451                let atom = self.copy_atom(src, atom);
452                self.alloc(Term::NegAtom(atom))
453            }
454            Term::Eq(left, right) => {
455                let left = self.copy_base_term(src, left);
456                let right = self.copy_base_term(src, right);
457                self.alloc(Term::Eq(left, right))
458            }
459            Term::Ineq(left, right) => {
460                let left = self.copy_base_term(src, left);
461                let right = self.copy_base_term(src, right);
462                self.alloc(Term::Ineq(left, right))
463            }
464            Term::TemporalAtom(atom, interval) => {
465                let atom = self.copy_atom(src, atom);
466                self.alloc(Term::TemporalAtom(atom, *interval))
467            }
468        }
469    }
470}
471
472// Immutable representation of syntax.
473// We use references instead of a smart pointer,
474// relying on an arena that holds everything.
475// This way we can use pattern matching.
476
/// Unit is a source unit.
/// It consists of declarations and clauses.
/// After parsing, all source units for a Mangle package
/// are merged into one, so unit can also be seen
/// as translation unit.
#[derive(Debug)]
pub struct Unit<'a> {
    /// Declarations of this unit.
    pub decls: &'a [&'a Decl<'a>],
    /// Clauses of this unit.
    pub clauses: &'a [&'a Clause<'a>],
}
487
488// ---------------------------------------------------------------------------
489// Temporal annotations
490// ---------------------------------------------------------------------------
491
/// A bound in a temporal interval.
///
/// Used for both the `start` and the `end` of an [`Interval`].
/// Both infinities are displayed as `_`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemporalBound {
    /// A concrete timestamp (nanoseconds since Unix epoch).
    Timestamp(i64),
    /// A variable to be bound during evaluation.
    Variable(VariableIndex),
    /// Negative infinity (written as `_` in start position).
    NegInf,
    /// Positive infinity (written as `_` in end position).
    PosInf,
}
504
/// A temporal interval `@[start, end]`, inclusive on both endpoints.
/// A point interval `@[T]` is represented as `start == end`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Interval {
    /// Lower endpoint of the interval.
    pub start: TemporalBound,
    /// Upper endpoint of the interval.
    pub end: TemporalBound,
}
512
513// ---------------------------------------------------------------------------
514
/// Predicate, package and use declarations.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Decl<'a> {
    /// The atom this declaration is about.
    pub atom: &'a Atom<'a>,
    /// Atoms attached to the declaration.
    /// NOTE(review): presumably descriptive metadata — confirm against the parser.
    pub descr: &'a [&'a Atom<'a>],
    /// Bound declarations, if any were given.
    pub bounds: Option<&'a [&'a BoundDecl<'a>]>,
    /// Constraints, if any were given.
    pub constraints: Option<&'a Constraints<'a>>,
    /// Flag marking the declaration as temporal.
    /// NOTE(review): presumably set for predicates that carry `@[..]` annotations — confirm.
    pub is_temporal: bool,
}
524
/// One bound declaration of a [`Decl`], given as a list of base terms.
/// NOTE(review): presumably one term per argument of the declared atom — confirm.
#[derive(Debug, PartialEq)]
pub struct BoundDecl<'a> {
    /// The base terms that make up this bound declaration.
    pub base_terms: &'a [&'a BaseTerm<'a>],
}
529
/// Constraints of a [`Decl`]: atoms that must all hold, plus groups of
/// alternatives of which at least one must hold.
#[derive(Debug, Clone, PartialEq)]
pub struct Constraints<'a> {
    // All of these must hold.
    pub consequences: &'a [&'a Atom<'a>],
    // In addition to consequences, at least one of these must hold.
    pub alternatives: &'a [&'a [&'a Atom<'a>]],
}
538
/// A clause: a head atom together with the premises and transform
/// statements that belong to it.
#[derive(Debug)]
pub struct Clause<'a> {
    /// The head atom of the clause.
    pub head: &'a Atom<'a>,
    /// Optional temporal annotation on the head: `head(X)@[S, E] :- ...`
    pub head_time: Option<Interval>,
    /// The terms on the right-hand side of the clause.
    pub premises: &'a [&'a Term<'a>],
    /// Transform statements of the clause.
    /// NOTE(review): presumably applied to the result of the premises — confirm.
    pub transform: &'a [&'a TransformStmt<'a>],
}
547
/// A single statement of a clause's transform.
#[derive(Debug)]
pub struct TransformStmt<'a> {
    /// Optional variable name.
    /// NOTE(review): presumably the variable that receives the result of `app` — confirm.
    pub var: Option<&'a str>,
    /// The base term of this statement.
    pub app: &'a BaseTerm<'a>,
}
553
/// Term that may appear on the right-hand side of a clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Term<'a> {
    /// An atom, displayed as the atom itself.
    Atom(&'a Atom<'a>),
    /// A negated atom, displayed as `!atom`.
    NegAtom(&'a Atom<'a>),
    /// An equality, displayed as `left = right`.
    Eq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
    /// An inequality, displayed as `left != right`.
    Ineq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
    /// Atom with temporal annotation: `p(X)@[S, E]`
    TemporalAtom(&'a Atom<'a>, Interval),
}
564
565impl std::fmt::Display for Term<'_> {
566    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
567        match self {
568            Term::Atom(atom) => write!(f, "{atom}"),
569            Term::NegAtom(atom) => write!(f, "!{atom}"),
570            Term::Eq(left, right) => write!(f, "{left} = {right}"),
571            Term::Ineq(left, right) => write!(f, "{left} != {right}"),
572            Term::TemporalAtom(atom, interval) => write!(f, "{atom}@[{}, {}]", interval.start, interval.end),
573        }
574    }
575}
576
577impl std::fmt::Display for TemporalBound {
578    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
579        match self {
580            TemporalBound::Timestamp(nanos) => write!(f, "t#{nanos}"),
581            TemporalBound::Variable(v) => write!(f, "{v}"),
582            TemporalBound::NegInf | TemporalBound::PosInf => write!(f, "_"),
583        }
584    }
585}
586
587impl<'a> Term<'a> {
588    pub fn apply_subst(
589        &'a self,
590        arena: &'a Arena,
591        subst: &FxHashMap<u32, &'a BaseTerm<'a>>,
592    ) -> &'a Term<'a> {
593        &*arena.alloc(match self {
594            Term::Atom(atom) => Term::Atom(atom.apply_subst(arena, subst)),
595            Term::NegAtom(atom) => Term::NegAtom(atom.apply_subst(arena, subst)),
596            Term::Eq(left, right) => Term::Eq(
597                left.apply_subst(arena, subst),
598                right.apply_subst(arena, subst),
599            ),
600            Term::Ineq(left, right) => Term::Ineq(
601                left.apply_subst(arena, subst),
602                right.apply_subst(arena, subst),
603            ),
604            Term::TemporalAtom(atom, interval) => {
605                Term::TemporalAtom(atom.apply_subst(arena, subst), *interval)
606            }
607        })
608    }
609}
610
/// A term that can appear as an argument: a constant, a variable or a
/// function application.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseTerm<'a> {
    /// A constant value.
    Const(Const<'a>),
    /// A variable, identified by its index.
    Variable(VariableIndex),
    /// A function symbol applied to arguments.
    ApplyFn(FunctionIndex, &'a [&'a BaseTerm<'a>]),
}
617
618impl<'arena> BaseTerm<'arena> {
619    pub fn apply_subst(
620        &'arena self,
621        arena: &'arena Arena,
622        subst: &FxHashMap<u32, &'arena BaseTerm<'arena>>,
623    ) -> &'arena BaseTerm<'arena> {
624        match self {
625            BaseTerm::Const(_) => self,
626            BaseTerm::Variable(v) => subst.get(&v.0).unwrap_or(&self),
627            BaseTerm::ApplyFn(fun, args) => {
628                let args: Vec<&'arena BaseTerm<'arena>> = args
629                    .iter()
630                    .map(|arg| arg.apply_subst(arena, subst))
631                    .collect();
632                let args = arena.alloc_slice_copy(&args);
633                arena.alloc(BaseTerm::ApplyFn(*fun, args))
634            }
635        }
636    }
637}
638
639impl std::fmt::Display for BaseTerm<'_> {
640    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
641        match self {
642            BaseTerm::Const(c) => write!(f, "{c}"),
643            BaseTerm::Variable(v) => write!(f, "{v}"),
644            BaseTerm::ApplyFn(fun, args) => {
645                write!(
646                    f,
647                    "{fun}({})",
648                    args.iter()
649                        .map(|x| x.to_string())
650                        .collect::<Vec<_>>()
651                        .join(",")
652                )
653            }
654        }
655    }
656}
/// A constant value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Const<'a> {
    /// A name, stored as its index in the interner.
    Name(u32),
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Number(i64),
    /// A 64-bit float.
    Float(f64),
    /// A string.
    String(&'a str),
    /// A byte string.
    Bytes(&'a [u8]),
    /// Time as nanoseconds since Unix epoch.
    Time(i64),
    /// Duration as nanoseconds.
    Duration(i64),
    /// A list of constants.
    List(&'a [&'a Const<'a>]),
    /// A map, stored as two slices: `keys[i]` is displayed with `values[i]`.
    Map {
        keys: &'a [&'a Const<'a>],
        values: &'a [&'a Const<'a>],
    },
    /// A struct, stored as two slices: `fields[i]` is displayed with `values[i]`.
    Struct {
        fields: &'a [&'a str],
        values: &'a [&'a Const<'a>],
    },
}
679
// Marks the derived `PartialEq` of `Const` as a full equivalence relation.
// NOTE(review): `Const::Float` holds an `f64`, and NaN is not equal to
// itself, so reflexivity does not hold for a NaN float — confirm that NaN
// constants never occur.
impl Eq for Const<'_> {}
681
682impl std::fmt::Display for Const<'_> {
683    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
684        match *self {
685            Const::Name(v) => write!(f, "n${v}"),
686            Const::Bool(v) => write!(f, "{v}"),
687            Const::Number(v) => write!(f, "{v}"),
688            Const::Float(v) => write!(f, "{v}"),
689            Const::String(v) => write!(f, "{v}"),
690            Const::Bytes(v) => write!(f, "{v:?}"),
691            Const::Time(v) => write!(f, "t#{v}"),
692            Const::Duration(v) => write!(f, "d#{v}"),
693            Const::List(v) => {
694                write!(
695                    f,
696                    "[{}]",
697                    v.iter()
698                        .map(|x| x.to_string())
699                        .collect::<Vec<_>>()
700                        .join(", ")
701                )
702            }
703            Const::Map { keys, values } => {
704                if keys.is_empty() {
705                    write!(f, "fn:map()")
706                } else {
707                    write!(f, "[")?;
708                    for (i, (k, v)) in keys.iter().zip(values.iter()).enumerate() {
709                        if i > 0 {
710                            write!(f, ", ")?;
711                        }
712                        write!(f, "{k}: {v}")?;
713                    }
714                    write!(f, "]")
715                }
716            }
717            Const::Struct { fields, values } => {
718                write!(f, "{{")?;
719                for (i, (field, val)) in fields.iter().zip(values.iter()).enumerate() {
720                    if i > 0 {
721                        write!(f, ", ")?;
722                    }
723                    write!(f, "{field}: {val}")?;
724                }
725                write!(f, "}}")
726            }
727        }
728    }
729}
730
/// A predicate symbol: an interned name plus an optional arity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PredicateSym {
    /// Index of the predicate's name in the interner.
    pub name: u32,
    /// Number of arguments, if known.
    pub arity: Option<u8>,
}
736
/// Position of a predicate symbol in an arena's list of predicate symbols.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PredicateIndex(usize);
739
740impl std::fmt::Display for PredicateIndex {
741    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
742        write!(f, "p${}", self.0)
743    }
744}
745
/// A function symbol: an interned name plus an optional arity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionSym {
    /// Index of the function's name in the interner.
    pub name: u32,
    /// Number of arguments, if known.
    pub arity: Option<u8>,
}
751
/// Position of a function symbol in an arena's list of function symbols.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct FunctionIndex(usize);
754
755impl std::fmt::Display for FunctionIndex {
756    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
757        write!(f, "f${}", self.0)
758    }
759}
760
/// A variable, identified by the interner index of its name.
/// Index 0 is displayed as `_`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VariableIndex(pub u32);
763
764impl std::fmt::Display for VariableIndex {
765    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
766        if self.0 == 0 {
767            write!(f, "_")
768        } else {
769            write!(f, "v${}", self.0)
770        }
771    }
772}
773
/// A predicate symbol applied to a list of argument terms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Atom<'a> {
    /// The predicate symbol of this atom.
    pub sym: PredicateIndex,
    /// The arguments of this atom.
    pub args: &'a [&'a BaseTerm<'a>],
}
779
780impl<'a> Atom<'a> {
781    // A fact matches a query if there is a substitution s.t. subst(query) = fact.
782    // We assume that self is a fact and query_args are the arguments of a query,
783    // i.e. only variables and constants.
784    pub fn matches(&'a self, query_args: &[&BaseTerm]) -> bool {
785        for (fact_arg, query_arg) in self.args.iter().zip(query_args.iter()) {
786            if let BaseTerm::Const(_) = query_arg
787                && fact_arg != query_arg
788            {
789                return false;
790            }
791        }
792        true
793    }
794
795    pub fn apply_subst(
796        &'a self,
797        arena: &'a Arena,
798        subst: &FxHashMap<u32, &'a BaseTerm<'a>>,
799    ) -> &'a Atom<'a> {
800        let args: Vec<&'a BaseTerm<'a>> = self
801            .args
802            .iter()
803            .map(|arg| arg.apply_subst(arena, subst))
804            .collect();
805        arena.atom(self.sym, &args)
806    }
807}
808
809impl std::fmt::Display for Atom<'_> {
810    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
811        write!(f, "{}(", self.sym)?;
812        for (i, arg) in self.args.iter().enumerate() {
813            if i > 0 {
814                write!(f, ", ")?;
815            }
816            write!(f, "{arg}")?;
817        }
818        write!(f, ")")
819    }
820}
821
// Term for the variable with index 0, which is displayed as `_`.
// `Arena::new_query` uses it for every argument of a query.
static ANY_VAR_TERM: BaseTerm = BaseTerm::Variable(VariableIndex(0));
823
// Unit tests. Each test creates its own arena with a fresh interner, in which
// `_` has index 0, so the interner indices asserted below are deterministic.
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    // NOTE(review): despite its name, this test only builds an atom and
    // checks its display form; nothing is copied between arenas.
    #[test]
    fn copying_atom_works() {
        let arena = Arena::new_with_global_interner();
        let foo = arena.const_(arena.name("/foo"));
        let bar = arena.predicate_sym("bar", Some(1));
        let head = arena.atom(bar, &[foo]);
        assert_that!(head.to_string(), eq("p$0(n$1)"));
    }

    // Checks the `Display` output of an atom and of the non-temporal `Term` variants.
    #[test]
    fn atom_display_works() {
        let arena = Arena::new_with_global_interner();
        let bar = arena.const_(arena.name("/bar"));
        let sym = arena.predicate_sym("foo", Some(1));
        let atom = Atom { sym, args: &[bar] };
        assert_that!(atom, displays_as(eq("p$0(n$1)")));

        let tests = vec![
            (Term::Atom(&atom), "p$0(n$1)"),
            (Term::NegAtom(&atom), "!p$0(n$1)"),
            (Term::Eq(bar, bar), "n$1 = n$1"),
            (Term::Ineq(bar, bar), "n$1 != n$1"),
        ];
        for (term, s) in tests {
            assert_that!(term, displays_as(eq(s)));
        }
    }

    // A query has one `_` argument per unit of arity, and none when the arity is unknown.
    #[test]
    fn new_query_works() {
        let arena = Arena::new_with_global_interner();

        let pred = arena.predicate_sym("foo", Some(1));
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$0(_)")));

        let pred = arena.predicate_sym("bar", Some(2));
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$1(_, _)")));

        let pred = arena.predicate_sym("frob", None);
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$2()")));
    }

    // Substituting `x` by `/bar` replaces the variable argument of the atom.
    #[test]
    fn subst_works() {
        let arena = Arena::new_with_global_interner();
        let atom = arena.atom(arena.predicate_sym("foo", Some(1)), &[arena.variable("x")]);

        let mut subst = FxHashMap::default();
        subst.insert(arena.variable_sym("x").0, arena.const_(arena.name("/bar")));

        let subst_atom = atom.apply_subst(&arena, &subst);
        // Interning order: `_` = 0, `foo` = 1, `x` = 2, `/bar` = 3.
        assert_that!(arena.name("/bar"), displays_as(eq("n$3")));
        assert_that!(subst_atom, displays_as(eq("p$0(n$3)")));
    }

    // Interning a name as long as the default capacity forces the interner to
    // start a new buffer; names interned earlier must still resolve.
    #[test]
    fn do_intern_beyond_initial_capacity() {
        let arena = Arena::new_with_global_interner();

        let p = arena.predicate_sym("/foo", Some(1));
        let mut name = "".to_string();
        for _ in 0..INTERNER_DEFAULT_CAPACITY.into() {
            name += "a";
        }
        arena.interner.lock().unwrap().intern(&name);
        assert_that!(arena.predicate_name(p), eq(Some("/foo")));
    }

    // Pretty-printing resolves indices back to names.
    #[test]
    fn pretty_print_works() {
        let arena = Arena::new_with_global_interner();
        let foo = arena.const_(arena.name("/foo"));
        let bar_pred = arena.predicate_sym("bar", Some(1));
        let x_var = arena.variable("X");
        let head = arena.atom(bar_pred, &[x_var]); // bar(X)

        let premise = Term::Eq(x_var, foo);
        let premise_ref = arena.alloc(premise);

        let clause = Clause {
            head,
            head_time: None,
            premises: arena.alloc_slice_copy(&[premise_ref]),
            transform: &[],
        };

        assert_that!(clause.pretty(&arena).to_string(), eq("bar(X) :- X = /foo."));

        let fun = arena.function_sym("f", Some(1));
        let app = arena.apply_fn(fun, &[x_var]);
        assert_that!(app.pretty(&arena).to_string(), eq("f(X)"));
    }
}