kodept_inference/
type.rs

1use std::collections::HashSet;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::BitAnd;
4
5use derive_more::{Constructor, Display as DeriveDisplay, From};
6use itertools::{concat, Itertools};
7use nonempty_collections::NEVec;
8
9use crate::InferState;
10use crate::substitution::Substitutions;
11use crate::traits::{FreeTypeVars, Substitutable};
12
/// Encodes `id` as a name spelled with the characters of `alphabet`.
///
/// * `id == 0` yields the first character of the alphabet.
/// * Otherwise `id` is written in base `alphabet.chars().count()`, using the
///   alphabet's characters as digits, least-significant digit first.
/// * With a one-character alphabet positional notation is impossible, so the
///   single character is repeated `id + 1` times (a unary encoding that agrees
///   with the `id == 0` case).
///
/// # Panics
///
/// Panics if `alphabet` is empty.
// Not referenced anywhere in the visible part of this file, hence the lint override.
#[allow(dead_code)]
fn expand_to_string(id: usize, alphabet: &'static str) -> String {
    let letters: Vec<char> = alphabet.chars().collect();
    let first = *letters
        .first()
        .expect("Alphabet should contain at least one letter");

    if id == 0 {
        return first.to_string();
    }

    let base = letters.len();
    if base == 1 {
        // Dividing by a base of 1 never shrinks the value, so the digit loop
        // below would never terminate; fall back to a unary encoding.
        return std::iter::repeat(first)
            .take(id.saturating_add(1))
            .collect();
    }

    let mut remaining = id;
    let mut encoded = String::new();
    while remaining > 0 {
        encoded.push(letters[remaining % base]);
        remaining /= base;
    }
    encoded
}
32
33#[derive(Clone, PartialEq, DeriveDisplay, Eq, Hash)]
34pub enum PrimitiveType {
35    Integral,
36    Floating,
37    Boolean,
38}
39
40#[derive(Copy, Clone, PartialEq, Hash, Eq, From)]
41pub struct TVar(pub(crate) usize);
42
43#[derive(Clone, PartialEq, Eq, Hash, Constructor)]
44
45pub struct Tuple(pub(crate) Vec<MonomorphicType>);
46
47#[derive(Clone, PartialEq, From, Eq, Hash)]
48pub enum MonomorphicType {
49    Primitive(PrimitiveType),
50    Var(TVar),
51    #[from(ignore)]
52    Fn(Box<MonomorphicType>, Box<MonomorphicType>),
53    Tuple(Tuple),
54    Pointer(Box<MonomorphicType>),
55    Constant(String),
56}
57
58#[derive(Clone, PartialEq)]
59pub struct PolymorphicType {
60    pub(crate) bindings: Vec<TVar>,
61    pub(crate) binding_type: MonomorphicType,
62}
63
64pub fn fun1<M: Into<MonomorphicType>, N: Into<MonomorphicType>>(
65    input: N,
66    output: M,
67) -> MonomorphicType {
68    MonomorphicType::Fn(Box::new(input.into()), Box::new(output.into()))
69}
70
71pub fn fun<M: Into<MonomorphicType>>(input: NEVec<MonomorphicType>, output: M) -> MonomorphicType {
72    match (input.head, input.tail.as_slice()) {
73        (x, []) => fun1(x, output),
74        (x, [xs @ .., last]) => fun1(
75            x,
76            xs.iter().fold(fun1(last.clone(), output), |acc, next| {
77                fun1(next.clone(), acc)
78            }),
79        ),
80    }
81}
82
83pub fn var<V: Into<TVar>>(id: V) -> MonomorphicType {
84    MonomorphicType::Var(id.into())
85}
86
87pub fn unit_type() -> MonomorphicType {
88    MonomorphicType::Tuple(Tuple(vec![]))
89}
90
91impl MonomorphicType {
92    fn rename(self, old: usize, new: usize) -> Self {
93        match self {
94            MonomorphicType::Var(TVar(id)) if id == old => TVar(new).into(),
95            MonomorphicType::Primitive(_)
96            | MonomorphicType::Var(_)
97            | MonomorphicType::Constant(_) => self,
98            MonomorphicType::Fn(input, output) => MonomorphicType::Fn(
99                Box::new(input.rename(old, new)),
100                Box::new(output.rename(old, new)),
101            ),
102            MonomorphicType::Tuple(Tuple(vec)) => MonomorphicType::Tuple(Tuple(
103                vec.into_iter().map(|it| it.rename(old, new)).collect(),
104            )),
105            MonomorphicType::Pointer(t) => MonomorphicType::Pointer(Box::new(t.rename(old, new))),
106        }
107    }
108
109    fn extract_vars(&self) -> Vec<usize> {
110        match self {
111            MonomorphicType::Primitive(_) | MonomorphicType::Constant(_) => vec![],
112            MonomorphicType::Var(TVar(x)) => vec![*x],
113            MonomorphicType::Fn(input, output) => {
114                concat([input.extract_vars(), output.extract_vars()])
115            }
116            MonomorphicType::Tuple(Tuple(vec)) => {
117                vec.iter().flat_map(MonomorphicType::extract_vars).collect()
118            }
119            MonomorphicType::Pointer(t) => t.extract_vars(),
120        }
121    }
122
123    pub fn generalize(&self, free: &HashSet<TVar>) -> PolymorphicType {
124        let diff: Vec<_> = self.free_types().difference(free).copied().collect();
125        PolymorphicType {
126            bindings: diff,
127            binding_type: self.clone(),
128        }
129    }
130
131    pub fn normalize(self) -> PolymorphicType {
132        self.generalize(&HashSet::new()).normalize()
133    }
134}
135
136impl PolymorphicType {
137    pub(crate) fn normalize(self) -> Self {
138        let mut free = self.binding_type.extract_vars();
139        free.sort_unstable();
140        free.dedup();
141        let len = free.len();
142        let binding_type = free
143            .iter()
144            .zip(0usize..)
145            .fold(self.binding_type, |acc, (&old, new)| acc.rename(old, new));
146        let bindings = free
147            .into_iter()
148            .zip(0usize..len)
149            .map(|it| TVar(it.1))
150            .collect();
151        Self {
152            bindings,
153            binding_type,
154        }
155    }
156
157    pub(crate) fn instantiate(&self, env: &mut InferState) -> MonomorphicType {
158        let fresh = self.bindings.iter().map(|it| (*it, env.new_var()));
159        let s0 = Substitutions::from_iter(fresh);
160        self.binding_type.substitute(&s0)
161    }
162}
163
164impl<S: Into<MonomorphicType>> From<S> for PolymorphicType {
165    fn from(value: S) -> Self {
166        Self {
167            bindings: vec![],
168            binding_type: value.into(),
169        }
170    }
171}
172
173impl BitAnd<Substitutions> for PolymorphicType {
174    type Output = PolymorphicType;
175
176    fn bitand(self, rhs: Substitutions) -> Self::Output {
177        self.substitute(&rhs)
178    }
179}
180
181impl BitAnd<&Substitutions> for PolymorphicType {
182    type Output = PolymorphicType;
183
184    fn bitand(self, rhs: &Substitutions) -> Self::Output {
185        self.substitute(rhs)
186    }
187}
188
189impl BitAnd<Substitutions> for &PolymorphicType {
190    type Output = PolymorphicType;
191
192    fn bitand(self, rhs: Substitutions) -> Self::Output {
193        self.substitute(&rhs)
194    }
195}
196
197impl BitAnd<&Substitutions> for &PolymorphicType {
198    type Output = PolymorphicType;
199
200    fn bitand(self, rhs: &Substitutions) -> Self::Output {
201        self.substitute(rhs)
202    }
203}
204
205impl BitAnd<Substitutions> for MonomorphicType {
206    type Output = MonomorphicType;
207
208    fn bitand(self, rhs: Substitutions) -> Self::Output {
209        self.substitute(&rhs)
210    }
211}
212
213impl BitAnd<&Substitutions> for MonomorphicType {
214    type Output = MonomorphicType;
215
216    fn bitand(self, rhs: &Substitutions) -> Self::Output {
217        self.substitute(rhs)
218    }
219}
220
221impl BitAnd<Substitutions> for &MonomorphicType {
222    type Output = MonomorphicType;
223
224    fn bitand(self, rhs: Substitutions) -> Self::Output {
225        self.substitute(&rhs)
226    }
227}
228
229impl BitAnd<&Substitutions> for &MonomorphicType {
230    type Output = MonomorphicType;
231
232    fn bitand(self, rhs: &Substitutions) -> Self::Output {
233        self.substitute(rhs)
234    }
235}
236
237impl Tuple {
238    #[must_use]
239    pub const fn unit() -> Tuple {
240        Tuple(vec![])
241    }
242
243    pub fn push(&mut self, value: MonomorphicType) {
244        self.0.push(value);
245    }
246}
247
248impl Display for TVar {
249    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
250        write!(f, "τ{}", self.0)
251    }
252}
253
254impl Display for MonomorphicType {
255    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256        match self {
257            MonomorphicType::Primitive(p) => write!(f, "{p}"),
258            MonomorphicType::Var(v) => write!(f, "{v}"),
259            MonomorphicType::Fn(input, output) => match input.as_ref() {
260                MonomorphicType::Fn(_, _) => write!(f, "({input}) -> {output}"),
261                _ => write!(f, "{input} -> {output}"),
262            },
263            MonomorphicType::Tuple(Tuple(vec)) => write!(f, "({})", vec.iter().join(", ")),
264            MonomorphicType::Pointer(t) => write!(f, "*{t}"),
265            MonomorphicType::Constant(id) => write!(f, "{id}"),
266        }
267    }
268}
269
270impl Display for PolymorphicType {
271    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
272        if self.bindings.is_empty() {
273            return write!(f, "{}", self.binding_type);
274        }
275        write!(
276            f,
277            "∀{} => {}",
278            self.bindings.iter().join(", "),
279            self.binding_type
280        )
281    }
282}
283
284impl Debug for TVar {
285    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286        write!(f, "{self}")
287    }
288}
289
290impl Debug for MonomorphicType {
291    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292        write!(f, "{self}")
293    }
294}
295
296impl Debug for PolymorphicType {
297    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298        write!(f, "{self}")
299    }
300}