1use itertools::Itertools;
2
3use super::ident::{Ident, IdentCtx};
4use super::lit::{LitType, LitVal};
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::fmt;
9
/// A tree-shaped term built from variables `V`, literals `L`, and
/// constructors `C` applied to sub-terms.
#[derive(Debug, Clone, PartialEq)]
pub enum Term<V, L, C> {
    /// A variable leaf.
    Var(V),
    /// A literal leaf.
    Lit(L),
    /// A constructor applied to an ordered list of field sub-terms
    /// (the list may be empty).
    Cons(C, Vec<Self>),
}
16
17impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> fmt::Display for Term<V, L, C> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Term::Var(var) => fmt::Display::fmt(&var, f),
21 Term::Lit(lit) => fmt::Display::fmt(&lit, f),
22 Term::Cons(cons, flds) => {
23 if flds.is_empty() && !format!("{}", cons).is_empty() {
24 fmt::Display::fmt(&cons, f)
25 } else {
26 let flds = flds.iter().format(", ");
27 write!(f, "{cons}({flds})")
28 }
29 }
30 }
31 }
32}
33
/// An optional constructor name.
///
/// Behaves like `Option<T>`, but is a separate type so it can carry its own
/// `Display` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptCons<T> {
    /// A named constructor.
    Some(T),
    /// No constructor name.
    None,
}
39
40impl<T: fmt::Display> fmt::Display for OptCons<T> {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 OptCons::Some(cons) => fmt::Display::fmt(cons, f),
44 OptCons::None => Ok(()), }
46 }
47}
48
49pub type TermVal<V = Ident> = Term<V, LitVal, OptCons<Ident>>;
50pub type AtomVal<V = Ident> = Term<V, LitVal, Infallible>;
51pub type TermType<V = Ident> = Term<V, LitType, OptCons<Ident>>;
52
53impl<V, L, C> Term<V, L, C> {
54 pub fn is_var(&self) -> bool {
55 matches!(self, Term::Var(_))
56 }
57
58 pub fn is_lit(&self) -> bool {
59 matches!(self, Term::Lit(_))
60 }
61
62 pub fn is_cons(&self) -> bool {
63 matches!(self, Term::Cons(_, _))
64 }
65
66 pub fn height(&self) -> usize {
67 match self {
68 Term::Var(_) => 1,
69 Term::Lit(_) => 1,
70 Term::Cons(_cons, flds) => {
71 let max_fld = flds.iter().map(|fld| fld.height()).max().unwrap_or(0);
72 max_fld + 1
73 }
74 }
75 }
76
77 pub fn size(&self) -> usize {
78 match self {
79 Term::Var(_) => 1,
80 Term::Lit(_) => 1,
81 Term::Cons(_cons, flds) => {
82 let sum_fld: usize = flds.iter().map(|fld| fld.size()).sum();
83 sum_fld + 1
84 }
85 }
86 }
87}
88
89impl<L: Copy, C: Copy> Term<Ident, L, C> {
90 pub fn tag_ctx(&self, ctx: usize) -> Term<IdentCtx, L, C> {
91 match self {
92 Term::Var(var) => Term::Var(var.tag_ctx(ctx)),
93 Term::Lit(lit) => Term::Lit(*lit),
94 Term::Cons(cons, flds) => {
95 let flds = flds.iter().map(|fld| fld.tag_ctx(ctx)).collect();
96 Term::Cons(*cons, flds)
97 }
98 }
99 }
100}
101
102impl<V: Copy, L: Copy, C: Copy> Term<V, L, C> {
103 pub fn to_atom(&self) -> Option<Term<V, L, Infallible>> {
104 match self {
105 Term::Var(var) => Some(Term::Var(*var)),
106 Term::Lit(lit) => Some(Term::Lit(*lit)),
107 Term::Cons(_cons, _flds) => None,
108 }
109 }
110}
111
112impl<V: Copy, L: Copy> Term<V, L, Infallible> {
113 pub fn to_term<C>(&self) -> Term<V, L, C> {
114 match self {
115 Term::Var(var) => Term::Var(*var),
116 Term::Lit(lit) => Term::Lit(*lit),
117 Term::Cons(_cons, _flds) => unreachable!(),
118 }
119 }
120}
121
122impl<V: Copy + Eq, L, C> Term<V, L, C> {
123 pub fn occurs(&self, x: &V) -> bool {
124 match self {
125 Term::Var(y) => x == y,
126 Term::Lit(_) => false,
127 Term::Cons(_cons, flds) => flds.iter().any(|fld| fld.occurs(x)),
128 }
129 }
130
131 pub fn free_vars(&self) -> Vec<V> {
132 let mut vec = Vec::new();
133 self.free_vars_help(&mut vec);
134 vec
135 }
136
137 fn free_vars_help(&self, vec: &mut Vec<V>) {
138 match self {
139 Term::Var(var) => {
140 if !vec.contains(var) {
141 vec.push(*var);
142 }
143 }
144 Term::Lit(_lit) => {}
145 Term::Cons(_cons, flds) => {
146 flds.iter().for_each(|fld| fld.free_vars_help(vec));
147 }
148 }
149 }
150}
151
152impl<V: Copy + Eq + std::hash::Hash, L: Copy, C: Copy> Term<V, L, C> {
153 pub fn substitute(&self, map: &HashMap<V, Term<V, L, C>>) -> Term<V, L, C> {
154 match self {
155 Term::Var(var) => {
156 if let Some(term) = map.get(var) {
157 term.clone()
158 } else {
159 Term::Var(*var)
160 }
161 }
162 Term::Lit(lit) => Term::Lit(*lit),
163 Term::Cons(cons, flds) => {
164 let flds = flds.iter().map(|fld| fld.substitute(map)).collect();
165 Term::Cons(*cons, flds)
166 }
167 }
168 }
169}