1use itertools::Itertools;
2
3use super::ident::{Ident, IdentCtx};
4use super::lit::{LitType, LitVal};
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::fmt;
9
/// A tree-shaped term, generic over the payload of each kind of node.
///
/// * `V` is the payload of a variable leaf.
/// * `L` is the payload of a literal leaf.
/// * `C` is the constructor tag carried by a compound node.
#[derive(Clone, Debug, PartialEq)]
pub enum Term<V, L, C> {
    /// A variable leaf.
    Var(V),
    /// A literal leaf.
    Lit(L),
    /// A compound node: a constructor tag followed by its ordered fields.
    Cons(C, Vec<Term<V, L, C>>),
}
16
17impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> fmt::Display for Term<V, L, C> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Term::Var(var) => fmt::Display::fmt(&var, f),
21 Term::Lit(lit) => fmt::Display::fmt(&lit, f),
22 Term::Cons(cons, flds) => {
23 if flds.is_empty() && !format!("{cons}").is_empty() {
24 fmt::Display::fmt(&cons, f)
25 } else {
26 let flds = flds.iter().format(", ");
27 write!(f, "{cons}({flds})")
28 }
29 }
30 }
31 }
32}
33
/// An optional constructor tag, mirroring the shape of `Option<T>`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptCons<T> {
    /// A constructor tag is present.
    Some(T),
    /// No constructor tag.
    None,
}
39
40impl<T: fmt::Display> fmt::Display for OptCons<T> {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 OptCons::Some(cons) => fmt::Display::fmt(cons, f),
44 OptCons::None => Ok(()), }
46 }
47}
48
49pub type TermVal<V = Ident> = Term<V, LitVal, OptCons<Ident>>;
50pub type AtomVal<V = Ident> = Term<V, LitVal, Infallible>;
51pub type TermType<V = Ident> = Term<V, LitType, OptCons<Ident>>;
52
53impl<V, L, C> Term<V, L, C> {
54 pub fn is_var(&self) -> bool {
55 matches!(self, Term::Var(_))
56 }
57
58 pub fn is_lit(&self) -> bool {
59 matches!(self, Term::Lit(_))
60 }
61
62 pub fn is_cons(&self) -> bool {
63 matches!(self, Term::Cons(_, _))
64 }
65
66 pub fn height(&self) -> usize {
67 match self {
68 Term::Var(_) | Term::Lit(_) => 1,
69 Term::Cons(_cons, flds) => {
70 let max_fld = flds.iter().map(Term::height).max().unwrap_or(0);
71 max_fld + 1
72 }
73 }
74 }
75
76 pub fn size(&self) -> usize {
77 match self {
78 Term::Var(_) | Term::Lit(_) => 1,
79 Term::Cons(_cons, flds) => {
80 let sum_fld: usize = flds.iter().map(Term::size).sum();
81 sum_fld + 1
82 }
83 }
84 }
85}
86
87impl<L: Copy, C: Copy> Term<Ident, L, C> {
88 pub fn tag_ctx(&self, ctx: usize) -> Term<IdentCtx, L, C> {
89 match self {
90 Term::Var(var) => Term::Var(var.tag_ctx(ctx)),
91 Term::Lit(lit) => Term::Lit(*lit),
92 Term::Cons(cons, flds) => {
93 let flds = flds.iter().map(|fld| fld.tag_ctx(ctx)).collect();
94 Term::Cons(*cons, flds)
95 }
96 }
97 }
98}
99
100impl<V: Copy, L: Copy, C: Copy> Term<V, L, C> {
101 pub fn to_atom(&self) -> Option<Term<V, L, Infallible>> {
102 match self {
103 Term::Var(var) => Some(Term::Var(*var)),
104 Term::Lit(lit) => Some(Term::Lit(*lit)),
105 Term::Cons(_cons, _flds) => None,
106 }
107 }
108}
109
110impl<V: Copy, L: Copy> Term<V, L, Infallible> {
111 pub fn to_term<C>(&self) -> Term<V, L, C> {
112 match self {
113 Term::Var(var) => Term::Var(*var),
114 Term::Lit(lit) => Term::Lit(*lit),
115 Term::Cons(_cons, _flds) => unreachable!(),
116 }
117 }
118}
119
120impl<V: Copy + Eq, L, C> Term<V, L, C> {
121 pub fn occurs(&self, x: &V) -> bool {
122 match self {
123 Term::Var(y) => x == y,
124 Term::Lit(_) => false,
125 Term::Cons(_cons, flds) => flds.iter().any(|fld| fld.occurs(x)),
126 }
127 }
128
129 pub fn free_vars(&self) -> Vec<V> {
130 let mut vec = Vec::new();
131 self.free_vars_help(&mut vec);
132 vec
133 }
134
135 fn free_vars_help(&self, vec: &mut Vec<V>) {
136 match self {
137 Term::Var(var) => {
138 if !vec.contains(var) {
139 vec.push(*var);
140 }
141 }
142 Term::Lit(_lit) => {}
143 Term::Cons(_cons, flds) => {
144 for fld in flds {
145 fld.free_vars_help(vec);
146 }
147 }
148 }
149 }
150}
151
152impl<V: Copy + Eq + std::hash::Hash, L: Copy, C: Copy> Term<V, L, C> {
153 pub fn substitute(&self, map: &HashMap<V, Term<V, L, C>>) -> Term<V, L, C> {
154 match self {
155 Term::Var(var) => {
156 if let Some(term) = map.get(var) {
157 term.clone()
158 } else {
159 Term::Var(*var)
160 }
161 }
162 Term::Lit(lit) => Term::Lit(*lit),
163 Term::Cons(cons, flds) => {
164 let flds = flds.iter().map(|fld| fld.substitute(map)).collect();
165 Term::Cons(*cons, flds)
166 }
167 }
168 }
169}