Skip to main content

mangle_ast/
lib.rs

1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod pretty;
16pub use pretty::PrettyPrint;
17
18use bumpalo::Bump;
19use fxhash::FxHashMap;
20use std::cell::RefCell;
21use std::mem;
22use std::num::NonZeroUsize;
23use std::sync::{Arc, Mutex};
24
/// Capacity, in bytes, requested for the string buffer of an interner created
/// by `Interner::new_global_interner`.
const INTERNER_DEFAULT_CAPACITY: NonZeroUsize = match NonZeroUsize::new(4096) {
    Some(capacity) => capacity,
    None => panic!("default interner capacity must be non-zero"),
};
26
27/// A simple way to intern strings and refer to them using u32 indices.
28/// This is following the blog post here:
29/// <https://matklad.github.io/2020/03/22/fast-simple-rust-interner.html>
30pub struct Interner {
31    map: FxHashMap<&'static str, u32>,
32    vec: Vec<&'static str>,
33    buf: String,
34    full: Vec<String>,
35}
36
37impl Interner {
38    fn new_global_interner() -> Arc<Mutex<Interner>> {
39        Arc::new(Mutex::new(Interner::with_capacity(
40            INTERNER_DEFAULT_CAPACITY.into(),
41        )))
42    }
43
44    pub fn with_capacity(cap: usize) -> Interner {
45        let cap = cap.next_power_of_two();
46        let mut interner = Interner {
47            map: FxHashMap::default(),
48            vec: Vec::new(),
49            buf: String::with_capacity(cap),
50            full: Vec::new(),
51        };
52        interner.intern("_");
53        interner
54    }
55
56    /// Interns the argument and returns a unique index.
57    pub fn intern(&mut self, name: &str) -> u32 {
58        if let Some(&id) = self.map.get(name) {
59            return id;
60        }
61        let name = unsafe { self.alloc(name) };
62        let id = self.map.len() as u32;
63        self.map.insert(name, id);
64        self.vec.push(name);
65        debug_assert!(self.lookup(id).expect("expected to find name") == name);
66        debug_assert!(self.intern(name) == id);
67        id
68    }
69
70    pub fn lookup_name_index(&self, name: &str) -> Option<u32> {
71        self.map.get(name).copied()
72    }
73
74    fn lookup(&self, id: u32) -> Option<&'static str> {
75        self.vec.get(id as usize).copied()
76    }
77
78    /// Copies a str into our internal buffer and returns a reference.
79    ///
80    /// # Safety
81    ///
82    /// The reference is only valid for the lifetime of this object.
83    unsafe fn alloc(&mut self, name: &str) -> &'static str {
84        let cap = self.buf.capacity();
85        if cap < self.buf.len() + name.len() {
86            let new_cap = (cap.max(name.len()) + 1).next_power_of_two();
87            let new_buf = String::with_capacity(new_cap);
88            let old_buf = mem::replace(&mut self.buf, new_buf);
89            self.full.push(old_buf);
90        }
91        let interned = {
92            let start = self.buf.len();
93            self.buf.push_str(name);
94            &self.buf[start..]
95        };
96        // SAFETY: This is reference to our internal buffer, which will
97        // be valid as long as this object is valid.
98        unsafe { &*(interned as *const str) }
99    }
100}
101
102pub struct Arena {
103    pub(crate) bump: Bump,
104    pub(crate) interner: Arc<Mutex<Interner>>,
105    pub(crate) predicate_syms: RefCell<Vec<PredicateSym>>,
106    pub(crate) function_syms: RefCell<Vec<FunctionSym>>,
107}
108
109impl<'arena> Arena {
110    pub fn new(interner: Arc<Mutex<Interner>>) -> Self {
111        Self {
112            bump: Bump::new(),
113            interner,
114            predicate_syms: RefCell::new(Vec::new()),
115            function_syms: RefCell::new(Vec::new()),
116        }
117    }
118
119    pub fn new_with_global_interner() -> Self {
120        Self::new(Interner::new_global_interner())
121    }
122
123    pub fn intern(&'arena self, name: &str) -> u32 {
124        self.interner.lock().unwrap().intern(name)
125    }
126
127    // Returns index for name, if it exists.
128    pub fn lookup_opt(&'arena self, name: &str) -> Option<u32> {
129        self.interner.lock().unwrap().map.get(name).copied()
130    }
131
132    pub fn name(&'arena self, name: &str) -> Const<'arena> {
133        Const::Name(self.intern(name))
134    }
135
136    pub fn variable(&'arena self, name: &str) -> &'arena BaseTerm<'arena> {
137        self.alloc(BaseTerm::Variable(self.variable_sym(name)))
138    }
139
140    pub fn const_(&'arena self, c: Const<'arena>) -> &'arena BaseTerm<'arena> {
141        self.alloc(BaseTerm::Const(c))
142    }
143
144    pub fn atom(
145        &'arena self,
146        p: PredicateIndex,
147        args: &[&'arena BaseTerm<'arena>],
148    ) -> &'arena Atom<'arena> {
149        self.alloc(Atom {
150            sym: p,
151            args: self.alloc_slice_copy(args),
152        })
153    }
154
155    pub fn apply_fn(
156        &'arena self,
157        fun: FunctionIndex,
158        args: &[&'arena BaseTerm<'arena>],
159    ) -> &'arena BaseTerm<'arena> {
160        //let fun = &self.function_syms.borrow()[fun.0];
161        //let name = self.lookup_name(fun.name);
162        let args = self.alloc_slice_copy(args);
163        self.alloc(BaseTerm::ApplyFn(fun, args))
164    }
165
166    pub fn alloc<T>(&self, x: T) -> &mut T {
167        self.bump.alloc(x)
168    }
169
170    pub fn alloc_slice_copy<T: Copy>(&self, x: &[T]) -> &[T] {
171        self.bump.alloc_slice_copy(x)
172    }
173
174    pub fn alloc_str(&'arena self, s: &str) -> &'arena str {
175        self.bump.alloc_str(s)
176    }
177
178    pub fn new_query(&'arena self, p: PredicateIndex) -> Atom<'arena> {
179        let arity = self.predicate_syms.borrow()[p.0].arity;
180        let args: Vec<_> = match arity {
181            Some(arity) => (0..arity).map(|_i| &ANY_VAR_TERM).collect(),
182            None => Vec::new(),
183        };
184
185        let args = self.alloc_slice_copy(&args);
186        Atom { sym: p, args }
187    }
188
189    /// Given a name index, returns the name.
190    pub fn lookup_name(&self, name_index: u32) -> Option<&'static str> {
191        self.interner.lock().unwrap().lookup(name_index)
192    }
193
194    /// Given a name, returns the index of the name if it exists in the interner.
195    pub fn lookup_name_index(&self, name: &str) -> Option<u32> {
196        self.interner.lock().unwrap().lookup_name_index(name)
197    }
198
199    /// Given predicate index, returns name of predicate symbol.
200    pub fn predicate_name(&self, predicate_index: PredicateIndex) -> Option<&'static str> {
201        let syms = self.predicate_syms.borrow();
202        let i = predicate_index.0;
203        if i >= syms.len() {
204            return None;
205        }
206        let n = syms[i].name;
207        self.interner.lock().unwrap().lookup(n)
208    }
209
210    /// Given predicate index, returns arity of predicate symbol.
211    pub fn predicate_arity(&self, predicate_index: PredicateIndex) -> Option<u8> {
212        self.predicate_syms
213            .borrow()
214            .get(predicate_index.0)
215            .and_then(|s| s.arity)
216    }
217
218    /// Given function index, returns name of function symbol.
219    pub fn function_name(&self, function_index: FunctionIndex) -> Option<&'static str> {
220        let syms = self.function_syms.borrow();
221        let i = function_index.0;
222        if i >= syms.len() {
223            return None;
224        }
225        let n = syms[i].name;
226        self.interner.lock().unwrap().lookup(n)
227    }
228
229    /// Returns index for this predicate symbol.
230    pub fn lookup_predicate_sym(&'arena self, predicate_name: u32) -> Option<PredicateIndex> {
231        for (index, p) in self.predicate_syms.borrow().iter().enumerate() {
232            if p.name == predicate_name {
233                return Some(PredicateIndex(index));
234            }
235        }
236        None
237    }
238
239    /// Constructs a new variable symbol.
240    pub fn variable_sym(&'arena self, name: &str) -> VariableIndex {
241        let n = self.interner.lock().unwrap().intern(name);
242        VariableIndex(n)
243    }
244
245    // Looks up a function sym, copies if it doesn't exist.
246    pub fn function_sym(&'arena self, name: &str, arity: Option<u8>) -> FunctionIndex {
247        let n = self.interner.lock().unwrap().intern(name);
248        let f = FunctionSym { name: n, arity };
249        for (index, f) in self.function_syms.borrow().iter().enumerate() {
250            if f.name == n {
251                return FunctionIndex(index);
252            }
253        }
254
255        self.function_syms.borrow_mut().push(f);
256        FunctionIndex(self.function_syms.borrow().len() - 1)
257    }
258
259    // Looks up or creates a predicate_sym, copies if it doesn't exist.
260    pub fn predicate_sym(&'arena self, name: &str, arity: Option<u8>) -> PredicateIndex {
261        let n = self.interner.lock().unwrap().intern(name);
262        let p = PredicateSym { name: n, arity };
263        for (index, p) in self.predicate_syms.borrow().iter().enumerate() {
264            if p.name == n {
265                return PredicateIndex(index);
266            }
267        }
268
269        self.predicate_syms.borrow_mut().push(p);
270        PredicateIndex(self.predicate_syms.borrow().len() - 1)
271    }
272
273    pub fn copy_function_sym<'src>(
274        &'arena self,
275        src: &'src Arena,
276        f: FunctionIndex,
277    ) -> FunctionIndex {
278        let function_sym = &src.function_syms.borrow()[f.0];
279        let name = src
280            .lookup_name(function_sym.name)
281            .expect("expected to find name");
282        self.function_sym(name, function_sym.arity)
283    }
284
285    pub fn copy_predicate_sym<'src>(
286        &'arena self,
287        src: &'src Arena,
288        p: PredicateIndex, //predicate_sym: &'src PredicateSym,
289    ) -> PredicateIndex {
290        let predicate_sym = &src.predicate_syms.borrow()[p.0];
291        let name = src
292            .lookup_name(predicate_sym.name)
293            .expect("expected to find name");
294        self.predicate_sym(name, predicate_sym.arity)
295    }
296
297    // Copies BaseTerm from another [`Arena`].
298    pub fn copy_atom<'src>(
299        &'arena self,
300        src: &'src Arena,
301        atom: &'src Atom<'src>,
302    ) -> &'arena Atom<'arena> {
303        let args: Vec<_> = atom
304            .args
305            .iter()
306            .map(|arg| self.copy_base_term(src, arg))
307            .collect();
308        let args = self.alloc_slice_copy(&args);
309        // TODO: have to look up predicate syms from !source!
310        self.alloc(Atom {
311            sym: self.copy_predicate_sym(src, atom.sym),
312            args,
313        })
314    }
315
316    // Copies BaseTerm to another Arena
317    pub fn copy_base_term<'src>(
318        &'arena self,
319        src: &'src Arena,
320        b: &'src BaseTerm<'src>,
321    ) -> &'arena BaseTerm<'arena> {
322        match b {
323            BaseTerm::Const(c) =>
324            // Should it be a reference?
325            {
326                self.alloc(BaseTerm::Const(*self.copy_const(src, c)))
327            }
328            BaseTerm::Variable(v) => {
329                let name = src
330                    .interner
331                    .lock()
332                    .unwrap()
333                    .lookup(v.0)
334                    .expect("expected to find name")
335                    .to_string();
336                let v = self.variable_sym(&name);
337                self.alloc(BaseTerm::Variable(v))
338            }
339            BaseTerm::ApplyFn(fun, args) => {
340                let fun = self.copy_function_sym(src, *fun);
341                //let fun = FunctionSym { name: self.alloc_str(fun.name), arity: fun.arity };
342                let args: Vec<_> = args.iter().map(|a| self.copy_base_term(src, a)).collect();
343                let args = self.alloc_slice_copy(&args);
344                self.alloc(BaseTerm::ApplyFn(fun, args))
345            }
346        }
347    }
348
349    // Copies Const to another Arena
350    pub fn copy_const<'src>(
351        &'arena self,
352        src: &'src Arena,
353        c: &'src Const<'src>,
354    ) -> &'arena Const<'arena> {
355        match c {
356            Const::Name(name) => {
357                let name = src
358                    .interner
359                    .lock()
360                    .unwrap()
361                    .lookup(*name)
362                    .expect("expected to find name");
363                let name = self.interner.lock().unwrap().intern(name);
364                self.alloc(Const::Name(name))
365            }
366            Const::Bool(b) => self.alloc(Const::Bool(*b)),
367            Const::Number(n) => self.alloc(Const::Number(*n)),
368            Const::Float(f) => self.alloc(Const::Float(*f)),
369            Const::String(s) => {
370                let s = self.alloc_str(s);
371                self.alloc(Const::String(s))
372            }
373            Const::Bytes(b) => {
374                let b = self.alloc_slice_copy(b);
375                self.alloc(Const::Bytes(b))
376            }
377            Const::List(cs) => {
378                let cs: Vec<_> = cs.iter().map(|c| self.copy_const(src, c)).collect();
379                let cs = self.alloc_slice_copy(&cs);
380                self.alloc(Const::List(cs))
381            }
382            Const::Map { keys, values } => {
383                let keys: Vec<_> = keys.iter().map(|c| self.copy_const(src, c)).collect();
384                let keys = self.alloc_slice_copy(&keys);
385
386                let values: Vec<_> = values.iter().map(|c| self.copy_const(src, c)).collect();
387                let values = self.alloc_slice_copy(&values);
388
389                self.alloc(Const::Map { keys, values })
390            }
391            Const::Struct { fields, values } => {
392                let fields: Vec<_> = fields.iter().map(|s| self.alloc_str(s)).collect();
393                let fields = self.alloc_slice_copy(&fields);
394
395                let values: Vec<_> = values.iter().map(|c| self.copy_const(src, c)).collect();
396                let values = self.alloc_slice_copy(&values);
397
398                self.alloc(Const::Struct { fields, values })
399            }
400        }
401    }
402
403    pub fn copy_transform<'src>(
404        &'arena self,
405        src: &'src Arena,
406        stmt: &'src TransformStmt<'src>,
407    ) -> &'arena TransformStmt<'arena> {
408        let TransformStmt { var, app } = stmt;
409        let var = var.map(|s| self.alloc_str(s));
410        let app = self.copy_base_term(src, app);
411        self.alloc(TransformStmt { var, app })
412    }
413
414    pub fn copy_clause<'src>(
415        &'arena self,
416        src: &'src Arena,
417        src_clause: &'src Clause<'src>,
418    ) -> &'arena Clause<'arena> {
419        let Clause {
420            head,
421            premises,
422            transform,
423        } = src_clause;
424        let premises: Vec<_> = premises.iter().map(|x| self.copy_term(src, x)).collect();
425        let transform: Vec<_> = transform
426            .iter()
427            .map(|x| self.copy_transform(src, x))
428            .collect();
429        self.alloc(Clause {
430            head: self.copy_atom(src, head),
431            premises: self.alloc_slice_copy(&premises),
432            transform: self.alloc_slice_copy(&transform),
433        })
434    }
435
436    fn copy_term<'src>(
437        &'arena self,
438        src: &'src Arena,
439        term: &'src Term<'src>,
440    ) -> &'arena Term<'arena> {
441        match term {
442            Term::Atom(atom) => {
443                let atom = self.copy_atom(src, atom);
444                self.alloc(Term::Atom(atom))
445            }
446            Term::NegAtom(atom) => {
447                let atom = self.copy_atom(src, atom);
448                self.alloc(Term::NegAtom(atom))
449            }
450            Term::Eq(left, right) => {
451                let left = self.copy_base_term(src, left);
452                let right = self.copy_base_term(src, right);
453                self.alloc(Term::Eq(left, right))
454            }
455            Term::Ineq(left, right) => {
456                let left = self.copy_base_term(src, left);
457                let right = self.copy_base_term(src, right);
458                self.alloc(Term::Ineq(left, right))
459            }
460        }
461    }
462}
463
464// Immutable representation of syntax.
465// We use references instead of a smart pointer,
466// relying on an arena that holds everything.
467// This way we can use pattern matching.
468
469// Unit is a source unit.
470// It consists of declarations and clauses.
471// After parsing, all source units for a Mangle package
472// are merged into one, so unit can also be seen
473// as translation unit.
474#[derive(Debug)]
475pub struct Unit<'a> {
476    pub decls: &'a [&'a Decl<'a>],
477    pub clauses: &'a [&'a Clause<'a>],
478}
479
480// Predicate, package and use declarations.
481#[derive(Debug, Clone, Copy, PartialEq)]
482pub struct Decl<'a> {
483    pub atom: &'a Atom<'a>,
484    pub descr: &'a [&'a Atom<'a>],
485    pub bounds: Option<&'a [&'a BoundDecl<'a>]>,
486    pub constraints: Option<&'a Constraints<'a>>,
487}
488
489#[derive(Debug, PartialEq)]
490pub struct BoundDecl<'a> {
491    pub base_terms: &'a [&'a BaseTerm<'a>],
492}
493
494//
495#[derive(Debug, Clone, PartialEq)]
496pub struct Constraints<'a> {
497    // All of these must hold.
498    pub consequences: &'a [&'a Atom<'a>],
499    // In addition to consequences, at least one of these must hold.
500    pub alternatives: &'a [&'a [&'a Atom<'a>]],
501}
502
503#[derive(Debug)]
504pub struct Clause<'a> {
505    pub head: &'a Atom<'a>,
506    pub premises: &'a [&'a Term<'a>],
507    pub transform: &'a [&'a TransformStmt<'a>],
508}
509
510#[derive(Debug)]
511pub struct TransformStmt<'a> {
512    pub var: Option<&'a str>,
513    pub app: &'a BaseTerm<'a>,
514}
515
516// Term that may appear on righthand-side of a clause.
517#[derive(Debug, Clone, Copy, PartialEq, Eq)]
518pub enum Term<'a> {
519    Atom(&'a Atom<'a>),
520    NegAtom(&'a Atom<'a>),
521    Eq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
522    Ineq(&'a BaseTerm<'a>, &'a BaseTerm<'a>),
523}
524
525impl std::fmt::Display for Term<'_> {
526    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        match self {
528            Term::Atom(atom) => write!(f, "{atom}"),
529            Term::NegAtom(atom) => write!(f, "!{atom}"),
530            Term::Eq(left, right) => write!(f, "{left} = {right}"),
531            Term::Ineq(left, right) => write!(f, "{left} != {right}"),
532        }
533    }
534}
535
536impl<'a> Term<'a> {
537    pub fn apply_subst(
538        &'a self,
539        arena: &'a Arena,
540        subst: &FxHashMap<u32, &'a BaseTerm<'a>>,
541    ) -> &'a Term<'a> {
542        &*arena.alloc(match self {
543            Term::Atom(atom) => Term::Atom(atom.apply_subst(arena, subst)),
544            Term::NegAtom(atom) => Term::NegAtom(atom.apply_subst(arena, subst)),
545            Term::Eq(left, right) => Term::Eq(
546                left.apply_subst(arena, subst),
547                right.apply_subst(arena, subst),
548            ),
549            Term::Ineq(left, right) => Term::Ineq(
550                left.apply_subst(arena, subst),
551                right.apply_subst(arena, subst),
552            ),
553        })
554    }
555}
556
557#[derive(Debug, Clone, Copy, PartialEq, Eq)]
558pub enum BaseTerm<'a> {
559    Const(Const<'a>),
560    Variable(VariableIndex),
561    ApplyFn(FunctionIndex, &'a [&'a BaseTerm<'a>]),
562}
563
564impl<'arena> BaseTerm<'arena> {
565    pub fn apply_subst(
566        &'arena self,
567        arena: &'arena Arena,
568        subst: &FxHashMap<u32, &'arena BaseTerm<'arena>>,
569    ) -> &'arena BaseTerm<'arena> {
570        match self {
571            BaseTerm::Const(_) => self,
572            BaseTerm::Variable(v) => subst.get(&v.0).unwrap_or(&self),
573            BaseTerm::ApplyFn(fun, args) => {
574                let args: Vec<&'arena BaseTerm<'arena>> = args
575                    .iter()
576                    .map(|arg| arg.apply_subst(arena, subst))
577                    .collect();
578                let args = arena.alloc_slice_copy(&args);
579                arena.alloc(BaseTerm::ApplyFn(*fun, args))
580            }
581        }
582    }
583}
584
585impl std::fmt::Display for BaseTerm<'_> {
586    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
587        match self {
588            BaseTerm::Const(c) => write!(f, "{c}"),
589            BaseTerm::Variable(v) => write!(f, "{v}"),
590            BaseTerm::ApplyFn(fun, args) => {
591                write!(
592                    f,
593                    "{fun}({})",
594                    args.iter()
595                        .map(|x| x.to_string())
596                        .collect::<Vec<_>>()
597                        .join(",")
598                )
599            }
600        }
601    }
602}
/// A constant value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Const<'a> {
    /// A name, stored as its interned index.
    Name(u32),
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Number(i64),
    /// A 64-bit floating point number.
    Float(f64),
    /// A string.
    String(&'a str),
    /// A byte string.
    Bytes(&'a [u8]),
    /// A list of constants.
    List(&'a [&'a Const<'a>]),
    /// A map, stored as two slices: `keys[i]` belongs to `values[i]`.
    Map {
        keys: &'a [&'a Const<'a>],
        values: &'a [&'a Const<'a>],
    },
    /// A struct, stored as two slices: `fields[i]` names `values[i]`.
    Struct {
        fields: &'a [&'a str],
        values: &'a [&'a Const<'a>],
    },
}

// NOTE(review): `Float` holds an `f64`, so the derived `PartialEq` is not
// reflexive for NaN even though this impl promises a full equivalence relation.
impl Eq for Const<'_> {}
623
624impl std::fmt::Display for Const<'_> {
625    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
626        match *self {
627            Const::Name(v) => write!(f, "n${v}"),
628            Const::Bool(v) => write!(f, "{v}"),
629            Const::Number(v) => write!(f, "{v}"),
630            Const::Float(v) => write!(f, "{v}"),
631            Const::String(v) => write!(f, "{v}"),
632            Const::Bytes(v) => write!(f, "{v:?}"),
633            Const::List(v) => {
634                write!(
635                    f,
636                    "[{}]",
637                    v.iter()
638                        .map(|x| x.to_string())
639                        .collect::<Vec<_>>()
640                        .join(", ")
641                )
642            }
643            Const::Map { keys, values } => {
644                if keys.is_empty() {
645                    write!(f, "fn:map()")
646                } else {
647                    write!(f, "[")?;
648                    for (i, (k, v)) in keys.iter().zip(values.iter()).enumerate() {
649                        if i > 0 {
650                            write!(f, ", ")?;
651                        }
652                        write!(f, "{k}: {v}")?;
653                    }
654                    write!(f, "]")
655                }
656            }
657            Const::Struct { fields, values } => {
658                write!(f, "{{")?;
659                for (i, (field, val)) in fields.iter().zip(values.iter()).enumerate() {
660                    if i > 0 {
661                        write!(f, ", ")?;
662                    }
663                    write!(f, "{field}: {val}")?;
664                }
665                write!(f, "}}")
666            }
667        }
668    }
669}
670
/// A predicate symbol: a name with an optional arity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PredicateSym {
    /// Interned index of the predicate's name.
    pub name: u32,
    /// Arity of the predicate, if one is recorded.
    pub arity: Option<u8>,
}
676
/// Identifies a predicate symbol by its position in an arena's predicate
/// table.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PredicateIndex(usize);
679
680impl std::fmt::Display for PredicateIndex {
681    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
682        write!(f, "p${}", self.0)
683    }
684}
685
/// A function symbol: a name with an optional arity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionSym {
    /// Interned index of the function's name.
    pub name: u32,
    /// Arity of the function, if one is recorded.
    pub arity: Option<u8>,
}
691
/// Identifies a function symbol by its position in an arena's function table.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct FunctionIndex(usize);
694
695impl std::fmt::Display for FunctionIndex {
696    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
697        write!(f, "f${}", self.0)
698    }
699}
700
/// Identifies a variable; the wrapped value is the interned index of the
/// variable's name.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct VariableIndex(pub u32);
703
704impl std::fmt::Display for VariableIndex {
705    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
706        if self.0 == 0 {
707            write!(f, "_")
708        } else {
709            write!(f, "v${}", self.0)
710        }
711    }
712}
713
714#[derive(Debug, Clone, Copy, PartialEq, Eq)]
715pub struct Atom<'a> {
716    pub sym: PredicateIndex,
717    pub args: &'a [&'a BaseTerm<'a>],
718}
719
720impl<'a> Atom<'a> {
721    // A fact matches a query if there is a substitution s.t. subst(query) = fact.
722    // We assume that self is a fact and query_args are the arguments of a query,
723    // i.e. only variables and constants.
724    pub fn matches(&'a self, query_args: &[&BaseTerm]) -> bool {
725        for (fact_arg, query_arg) in self.args.iter().zip(query_args.iter()) {
726            if let BaseTerm::Const(_) = query_arg
727                && fact_arg != query_arg
728            {
729                return false;
730            }
731        }
732        true
733    }
734
735    pub fn apply_subst(
736        &'a self,
737        arena: &'a Arena,
738        subst: &FxHashMap<u32, &'a BaseTerm<'a>>,
739    ) -> &'a Atom<'a> {
740        let args: Vec<&'a BaseTerm<'a>> = self
741            .args
742            .iter()
743            .map(|arg| arg.apply_subst(arena, subst))
744            .collect();
745        arena.atom(self.sym, &args)
746    }
747}
748
749impl std::fmt::Display for Atom<'_> {
750    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
751        write!(f, "{}(", self.sym)?;
752        for (i, arg) in self.args.iter().enumerate() {
753            if i > 0 {
754                write!(f, ", ")?;
755            }
756            write!(f, "{arg}")?;
757        }
758        write!(f, ")")
759    }
760}
761
762static ANY_VAR_TERM: BaseTerm = BaseTerm::Variable(VariableIndex(0));
763
// Each test builds its own arena with its own interner, in which "_" always
// has index 0; the `n$..` and `p$..` numbers expected below follow from the
// order in which names and predicates are created afterwards.
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    #[test]
    fn copying_atom_works() {
        let arena = Arena::new_with_global_interner();
        let foo = arena.const_(arena.name("/foo"));
        let bar = arena.predicate_sym("bar", Some(1));
        let head = arena.atom(bar, &[foo]);
        assert_that!(head.to_string(), eq("p$0(n$1)"));

        // The destination has its own interner and already knows another
        // predicate, so both the name index and the predicate index of the
        // copy differ from the source.
        let dst = Arena::new_with_global_interner();
        dst.predicate_sym("other", Some(0));
        let copied = dst.copy_atom(&arena, head);
        assert_that!(copied.to_string(), eq("p$1(n$2)"));
        assert_that!(dst.predicate_name(copied.sym), eq(Some("bar")));
        assert_that!(dst.predicate_arity(copied.sym), eq(Some(1)));
        assert_that!(dst.lookup_name(2), eq(Some("/foo")));
    }

    #[test]
    fn atom_display_works() {
        let arena = Arena::new_with_global_interner();
        let bar = arena.const_(arena.name("/bar"));
        let sym = arena.predicate_sym("foo", Some(1));
        let atom = Atom { sym, args: &[bar] };
        assert_that!(atom, displays_as(eq("p$0(n$1)")));

        let tests = vec![
            (Term::Atom(&atom), "p$0(n$1)"),
            (Term::NegAtom(&atom), "!p$0(n$1)"),
            (Term::Eq(bar, bar), "n$1 = n$1"),
            (Term::Ineq(bar, bar), "n$1 != n$1"),
        ];
        for (term, s) in tests {
            assert_that!(term, displays_as(eq(s)));
        }
    }

    #[test]
    fn new_query_works() {
        let arena = Arena::new_with_global_interner();

        let pred = arena.predicate_sym("foo", Some(1));
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$0(_)")));

        let pred = arena.predicate_sym("bar", Some(2));
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$1(_, _)")));

        // No recorded arity: the query has no arguments.
        let pred = arena.predicate_sym("frob", None);
        let query = arena.new_query(pred);
        assert_that!(query, displays_as(eq("p$2()")));
    }

    #[test]
    fn subst_works() {
        let arena = Arena::new_with_global_interner();
        let atom = arena.atom(arena.predicate_sym("foo", Some(1)), &[arena.variable("x")]);

        let mut subst = FxHashMap::default();
        subst.insert(arena.variable_sym("x").0, arena.const_(arena.name("/bar")));

        let subst_atom = atom.apply_subst(&arena, &subst);
        assert_that!(arena.name("/bar"), displays_as(eq("n$3")));
        assert_that!(subst_atom, displays_as(eq("p$0(n$3)")));
    }

    #[test]
    fn do_intern_beyond_initial_capacity() {
        let arena = Arena::new_with_global_interner();

        let p = arena.predicate_sym("/foo", Some(1));
        // One name as long as the whole initial buffer cannot fit next to the
        // names already stored, so the interner has to start a new buffer.
        let name = "a".repeat(INTERNER_DEFAULT_CAPACITY.get());
        arena.interner.lock().unwrap().intern(&name);
        // Names interned before the switch must still be readable.
        assert_that!(arena.predicate_name(p), eq(Some("/foo")));
    }

    #[test]
    fn pretty_print_works() {
        let arena = Arena::new_with_global_interner();
        let foo = arena.const_(arena.name("/foo"));
        let bar_pred = arena.predicate_sym("bar", Some(1));
        let x_var = arena.variable("X");
        let head = arena.atom(bar_pred, &[x_var]); // bar(X)

        let premise = Term::Eq(x_var, foo);
        let premise_ref = arena.alloc(premise);

        let clause = Clause {
            head,
            premises: arena.alloc_slice_copy(&[premise_ref]),
            transform: &[],
        };

        assert_that!(clause.pretty(&arena).to_string(), eq("bar(X) :- X = /foo."));

        let fun = arena.function_sym("f", Some(1));
        let app = arena.apply_fn(fun, &[x_var]);
        assert_that!(app.pretty(&arena).to_string(), eq("f(X)"));
    }
}