1use itertools::Itertools;
2
3use super::ident::{Ident, IdentCtx};
4use super::lit::{LitType, LitVal};
5
6use std::collections::HashMap;
7use std::convert::Infallible;
8use std::fmt;
9
/// A term tree, generic over its variable type `V`, literal type `L` and
/// constructor type `C`.
#[derive(Clone, PartialEq, Debug)]
pub enum Term<V, L, C> {
    /// A variable occurrence.
    Var(V),
    /// A literal value.
    Lit(L),
    /// A constructor applied to an ordered list of argument terms (its fields).
    Cons(C, Vec<Term<V, L, C>>),
}
16
17impl<V: fmt::Display, L: fmt::Display, C: fmt::Display> fmt::Display for Term<V, L, C> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Term::Var(var) => fmt::Display::fmt(&var, f),
21 Term::Lit(lit) => fmt::Display::fmt(&lit, f),
22 Term::Cons(cons, flds) => {
23 if flds.is_empty() && !format!("{}", cons).is_empty() {
24 fmt::Display::fmt(&cons, f)
25 } else {
26 let flds = flds.iter().format(", ");
27 write!(f, "{cons}({flds})")
28 }
29 }
30 }
31 }
32}
33
/// An optional constructor name: `Some(name)` carries a name, `None` carries
/// none and displays as the empty string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptCons<T> {
    /// A constructor with the given name.
    Some(T),
    /// A constructor without a name.
    None,
}

impl<T: fmt::Display> fmt::Display for OptCons<T> {
    /// Writes the wrapped name using the caller's formatter (flags such as
    /// width are honoured), or writes nothing for `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let OptCons::Some(name) = self {
            fmt::Display::fmt(name, f)
        } else {
            Ok(())
        }
    }
}
48
49pub type TermId = Term<Ident, LitVal, OptCons<Ident>>;
50pub type TermCtx = Term<IdentCtx, LitVal, OptCons<Ident>>;
51pub type AtomId = Term<Ident, LitVal, Infallible>;
52pub type AtomCtx = Term<IdentCtx, LitVal, Infallible>;
53pub type TypeId = Term<Ident, LitType, OptCons<Ident>>;
54
55impl<V, L, C> Term<V, L, C> {
56 pub fn is_var(&self) -> bool {
57 matches!(self, Term::Var(_))
58 }
59
60 pub fn is_lit(&self) -> bool {
61 matches!(self, Term::Lit(_))
62 }
63
64 pub fn is_cons(&self) -> bool {
65 matches!(self, Term::Cons(_, _))
66 }
67}
68
69impl<L: Copy, C: Copy> Term<Ident, L, C> {
70 pub fn tag_ctx(&self, ctx: usize) -> Term<IdentCtx, L, C> {
71 match self {
72 Term::Var(var) => Term::Var(var.tag_ctx(ctx)),
73 Term::Lit(lit) => Term::Lit(*lit),
74 Term::Cons(cons, flds) => {
75 let flds = flds.iter().map(|fld| fld.tag_ctx(ctx)).collect();
76 Term::Cons(*cons, flds)
77 }
78 }
79 }
80}
81
82impl<V: Copy, L: Copy, C: Copy> Term<V, L, C> {
83 pub fn to_atom(&self) -> Option<Term<V, L, Infallible>> {
84 match self {
85 Term::Var(var) => Some(Term::Var(*var)),
86 Term::Lit(lit) => Some(Term::Lit(*lit)),
87 Term::Cons(_cons, _flds) => None,
88 }
89 }
90}
91
92impl<V: Copy, L: Copy> Term<V, L, Infallible> {
93 pub fn to_term<C>(&self) -> Term<V, L, C> {
94 match self {
95 Term::Var(var) => Term::Var(*var),
96 Term::Lit(lit) => Term::Lit(*lit),
97 Term::Cons(_cons, _flds) => unreachable!(),
98 }
99 }
100}
101
102impl<V: Copy + Eq, L, C> Term<V, L, C> {
103 pub fn occurs(&self, x: &V) -> bool {
104 match self {
105 Term::Var(y) => x == y,
106 Term::Lit(_) => false,
107 Term::Cons(_cons, flds) => flds.iter().any(|fld| fld.occurs(x)),
108 }
109 }
110
111 pub fn free_vars(&self) -> Vec<V> {
112 let mut vec = Vec::new();
113 self.free_vars_help(&mut vec);
114 vec
115 }
116
117 fn free_vars_help(&self, vec: &mut Vec<V>) {
118 match self {
119 Term::Var(var) => {
120 if !vec.contains(var) {
121 vec.push(*var);
122 }
123 }
124 Term::Lit(_lit) => {}
125 Term::Cons(_cons, flds) => {
126 flds.iter().for_each(|fld| fld.free_vars_help(vec));
127 }
128 }
129 }
130}
131
132impl<V: Copy + Eq + std::hash::Hash, L: Copy, C: Copy> Term<V, L, C> {
133 pub fn substitute(&self, map: &HashMap<V, Term<V, L, C>>) -> Term<V, L, C> {
134 match self {
135 Term::Var(var) => {
136 if let Some(term) = map.get(var) {
137 term.clone()
138 } else {
139 Term::Var(*var)
140 }
141 }
142 Term::Lit(lit) => Term::Lit(*lit),
143 Term::Cons(cons, flds) => {
144 let flds = flds.iter().map(|fld| fld.substitute(map)).collect();
145 Term::Cons(*cons, flds)
146 }
147 }
148 }
149}