Skip to main content

prune_lang/utils/
term.rs

1use itertools::Itertools;
2
3use super::ident::{Ident, IdentCtx};
4use super::lit::{LitType, LitVal};
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::fmt;
9
/// A term tree, generic over its variable (`V`), literal (`L`) and
/// constructor (`C`) representations.
#[derive(Debug, Clone, PartialEq)]
pub enum Term<V, L, C> {
    /// A variable occurrence.
    Var(V),
    /// A literal leaf.
    Lit(L),
    /// A constructor `C` applied to an ordered list of sub-terms (its fields).
    Cons(C, Vec<Term<V, L, C>>),
}
16
17impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> fmt::Display for Term<V, L, C> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        match self {
20            Term::Var(var) => fmt::Display::fmt(&var, f),
21            Term::Lit(lit) => fmt::Display::fmt(&lit, f),
22            Term::Cons(cons, flds) => {
23                if flds.is_empty() && !format!("{}", cons).is_empty() {
24                    fmt::Display::fmt(&cons, f)
25                } else {
26                    let flds = flds.iter().format(", ");
27                    write!(f, "{cons}({flds})")
28                }
29            }
30        }
31    }
32}
33
/// An optional constructor.
///
/// `Some` carries a real constructor. `None` stands in for tuples, which
/// have no constructor of their own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptCons<T> {
    /// A named constructor.
    Some(T),
    /// Placeholder used for tuples (no constructor).
    None,
}
39
40impl<T: fmt::Display> fmt::Display for OptCons<T> {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            OptCons::Some(cons) => fmt::Display::fmt(cons, f),
44            OptCons::None => Ok(()), // tuples' placeholder won't be printed.
45        }
46    }
47}
48
49pub type TermVal<V = Ident> = Term<V, LitVal, OptCons<Ident>>;
50pub type AtomVal<V = Ident> = Term<V, LitVal, Infallible>;
51pub type TermType<V = Ident> = Term<V, LitType, OptCons<Ident>>;
52
53impl<V, L, C> Term<V, L, C> {
54    pub fn is_var(&self) -> bool {
55        matches!(self, Term::Var(_))
56    }
57
58    pub fn is_lit(&self) -> bool {
59        matches!(self, Term::Lit(_))
60    }
61
62    pub fn is_cons(&self) -> bool {
63        matches!(self, Term::Cons(_, _))
64    }
65
66    pub fn height(&self) -> usize {
67        match self {
68            Term::Var(_) => 1,
69            Term::Lit(_) => 1,
70            Term::Cons(_cons, flds) => {
71                let max_fld = flds.iter().map(|fld| fld.height()).max().unwrap_or(0);
72                max_fld + 1
73            }
74        }
75    }
76
77    pub fn size(&self) -> usize {
78        match self {
79            Term::Var(_) => 1,
80            Term::Lit(_) => 1,
81            Term::Cons(_cons, flds) => {
82                let sum_fld: usize = flds.iter().map(|fld| fld.size()).sum();
83                sum_fld + 1
84            }
85        }
86    }
87}
88
89impl<L: Copy, C: Copy> Term<Ident, L, C> {
90    pub fn tag_ctx(&self, ctx: usize) -> Term<IdentCtx, L, C> {
91        match self {
92            Term::Var(var) => Term::Var(var.tag_ctx(ctx)),
93            Term::Lit(lit) => Term::Lit(*lit),
94            Term::Cons(cons, flds) => {
95                let flds = flds.iter().map(|fld| fld.tag_ctx(ctx)).collect();
96                Term::Cons(*cons, flds)
97            }
98        }
99    }
100}
101
102impl<V: Copy, L: Copy, C: Copy> Term<V, L, C> {
103    pub fn to_atom(&self) -> Option<Term<V, L, Infallible>> {
104        match self {
105            Term::Var(var) => Some(Term::Var(*var)),
106            Term::Lit(lit) => Some(Term::Lit(*lit)),
107            Term::Cons(_cons, _flds) => None,
108        }
109    }
110}
111
112impl<V: Copy, L: Copy> Term<V, L, Infallible> {
113    pub fn to_term<C>(&self) -> Term<V, L, C> {
114        match self {
115            Term::Var(var) => Term::Var(*var),
116            Term::Lit(lit) => Term::Lit(*lit),
117            Term::Cons(_cons, _flds) => unreachable!(),
118        }
119    }
120}
121
122impl<V: Copy + Eq, L, C> Term<V, L, C> {
123    pub fn occurs(&self, x: &V) -> bool {
124        match self {
125            Term::Var(y) => x == y,
126            Term::Lit(_) => false,
127            Term::Cons(_cons, flds) => flds.iter().any(|fld| fld.occurs(x)),
128        }
129    }
130
131    pub fn free_vars(&self) -> Vec<V> {
132        let mut vec = Vec::new();
133        self.free_vars_help(&mut vec);
134        vec
135    }
136
137    fn free_vars_help(&self, vec: &mut Vec<V>) {
138        match self {
139            Term::Var(var) => {
140                if !vec.contains(var) {
141                    vec.push(*var);
142                }
143            }
144            Term::Lit(_lit) => {}
145            Term::Cons(_cons, flds) => {
146                flds.iter().for_each(|fld| fld.free_vars_help(vec));
147            }
148        }
149    }
150}
151
152impl<V: Copy + Eq + std::hash::Hash, L: Copy, C: Copy> Term<V, L, C> {
153    pub fn substitute(&self, map: &HashMap<V, Term<V, L, C>>) -> Term<V, L, C> {
154        match self {
155            Term::Var(var) => {
156                if let Some(term) = map.get(var) {
157                    term.clone()
158                } else {
159                    Term::Var(*var)
160                }
161            }
162            Term::Lit(lit) => Term::Lit(*lit),
163            Term::Cons(cons, flds) => {
164                let flds = flds.iter().map(|fld| fld.substitute(map)).collect();
165                Term::Cons(*cons, flds)
166            }
167        }
168    }
169}