1use std::collections::HashSet;
2use std::fmt::{Debug, Display, Formatter};
3use std::ops::BitAnd;
4
5use derive_more::{Constructor, Display as DeriveDisplay, From};
6use itertools::{concat, Itertools};
7use nonempty_collections::NEVec;
8
9use crate::InferState;
10use crate::substitution::Substitutions;
11use crate::traits::{FreeTypeVars, Substitutable};
12
/// Encodes `id` as a short name built from the characters of `alphabet`.
///
/// * `id == 0` yields the first character of `alphabet`.
/// * With two or more characters, `id` is written in base `alphabet.chars().count()`,
///   least-significant digit first (so with `"abc"`, `3` becomes `"ab"`).
/// * With exactly one character there is no positional encoding, so the result is
///   that character repeated `id + 1` times, which keeps distinct ids distinct.
///
/// # Panics
///
/// Panics if `alphabet` is empty.
// No caller is visible in this part of the file, hence the lint exemption.
#[allow(dead_code)]
fn expand_to_string(id: usize, alphabet: &'static str) -> String {
    let letters: Vec<char> = alphabet.chars().collect();
    let first = *letters
        .first()
        .expect("Alphabet should contain at least one letter");

    if id == 0 {
        return first.to_string();
    }

    let base = letters.len();
    if base == 1 {
        // Dividing by a base of 1 never shrinks the value, so the positional loop
        // below would never terminate; fall back to a unary encoding instead.
        return std::iter::repeat(first)
            .take(id.saturating_add(1))
            .collect();
    }

    let mut remaining = id;
    let mut encoded = String::new();
    while remaining > 0 {
        encoded.push(letters[remaining % base]);
        remaining /= base;
    }
    encoded
}
32
33#[derive(Clone, PartialEq, DeriveDisplay, Eq, Hash)]
34pub enum PrimitiveType {
35 Integral,
36 Floating,
37 Boolean,
38}
39
40#[derive(Copy, Clone, PartialEq, Hash, Eq, From)]
41pub struct TVar(pub(crate) usize);
42
43#[derive(Clone, PartialEq, Eq, Hash, Constructor)]
44
45pub struct Tuple(pub(crate) Vec<MonomorphicType>);
46
47#[derive(Clone, PartialEq, From, Eq, Hash)]
48pub enum MonomorphicType {
49 Primitive(PrimitiveType),
50 Var(TVar),
51 #[from(ignore)]
52 Fn(Box<MonomorphicType>, Box<MonomorphicType>),
53 Tuple(Tuple),
54 Pointer(Box<MonomorphicType>),
55 Constant(String),
56}
57
58#[derive(Clone, PartialEq)]
59pub struct PolymorphicType {
60 pub(crate) bindings: Vec<TVar>,
61 pub(crate) binding_type: MonomorphicType,
62}
63
64pub fn fun1<M: Into<MonomorphicType>, N: Into<MonomorphicType>>(
65 input: N,
66 output: M,
67) -> MonomorphicType {
68 MonomorphicType::Fn(Box::new(input.into()), Box::new(output.into()))
69}
70
71pub fn fun<M: Into<MonomorphicType>>(input: NEVec<MonomorphicType>, output: M) -> MonomorphicType {
72 match (input.head, input.tail.as_slice()) {
73 (x, []) => fun1(x, output),
74 (x, [xs @ .., last]) => fun1(
75 x,
76 xs.iter().fold(fun1(last.clone(), output), |acc, next| {
77 fun1(next.clone(), acc)
78 }),
79 ),
80 }
81}
82
83pub fn var<V: Into<TVar>>(id: V) -> MonomorphicType {
84 MonomorphicType::Var(id.into())
85}
86
87pub fn unit_type() -> MonomorphicType {
88 MonomorphicType::Tuple(Tuple(vec![]))
89}
90
91impl MonomorphicType {
92 fn rename(self, old: usize, new: usize) -> Self {
93 match self {
94 MonomorphicType::Var(TVar(id)) if id == old => TVar(new).into(),
95 MonomorphicType::Primitive(_)
96 | MonomorphicType::Var(_)
97 | MonomorphicType::Constant(_) => self,
98 MonomorphicType::Fn(input, output) => MonomorphicType::Fn(
99 Box::new(input.rename(old, new)),
100 Box::new(output.rename(old, new)),
101 ),
102 MonomorphicType::Tuple(Tuple(vec)) => MonomorphicType::Tuple(Tuple(
103 vec.into_iter().map(|it| it.rename(old, new)).collect(),
104 )),
105 MonomorphicType::Pointer(t) => MonomorphicType::Pointer(Box::new(t.rename(old, new))),
106 }
107 }
108
109 fn extract_vars(&self) -> Vec<usize> {
110 match self {
111 MonomorphicType::Primitive(_) | MonomorphicType::Constant(_) => vec![],
112 MonomorphicType::Var(TVar(x)) => vec![*x],
113 MonomorphicType::Fn(input, output) => {
114 concat([input.extract_vars(), output.extract_vars()])
115 }
116 MonomorphicType::Tuple(Tuple(vec)) => {
117 vec.iter().flat_map(MonomorphicType::extract_vars).collect()
118 }
119 MonomorphicType::Pointer(t) => t.extract_vars(),
120 }
121 }
122
123 pub fn generalize(&self, free: &HashSet<TVar>) -> PolymorphicType {
124 let diff: Vec<_> = self.free_types().difference(free).copied().collect();
125 PolymorphicType {
126 bindings: diff,
127 binding_type: self.clone(),
128 }
129 }
130
131 pub fn normalize(self) -> PolymorphicType {
132 self.generalize(&HashSet::new()).normalize()
133 }
134}
135
136impl PolymorphicType {
137 pub(crate) fn normalize(self) -> Self {
138 let mut free = self.binding_type.extract_vars();
139 free.sort_unstable();
140 free.dedup();
141 let len = free.len();
142 let binding_type = free
143 .iter()
144 .zip(0usize..)
145 .fold(self.binding_type, |acc, (&old, new)| acc.rename(old, new));
146 let bindings = free
147 .into_iter()
148 .zip(0usize..len)
149 .map(|it| TVar(it.1))
150 .collect();
151 Self {
152 bindings,
153 binding_type,
154 }
155 }
156
157 pub(crate) fn instantiate(&self, env: &mut InferState) -> MonomorphicType {
158 let fresh = self.bindings.iter().map(|it| (*it, env.new_var()));
159 let s0 = Substitutions::from_iter(fresh);
160 self.binding_type.substitute(&s0)
161 }
162}
163
164impl<S: Into<MonomorphicType>> From<S> for PolymorphicType {
165 fn from(value: S) -> Self {
166 Self {
167 bindings: vec![],
168 binding_type: value.into(),
169 }
170 }
171}
172
173impl BitAnd<Substitutions> for PolymorphicType {
174 type Output = PolymorphicType;
175
176 fn bitand(self, rhs: Substitutions) -> Self::Output {
177 self.substitute(&rhs)
178 }
179}
180
181impl BitAnd<&Substitutions> for PolymorphicType {
182 type Output = PolymorphicType;
183
184 fn bitand(self, rhs: &Substitutions) -> Self::Output {
185 self.substitute(rhs)
186 }
187}
188
189impl BitAnd<Substitutions> for &PolymorphicType {
190 type Output = PolymorphicType;
191
192 fn bitand(self, rhs: Substitutions) -> Self::Output {
193 self.substitute(&rhs)
194 }
195}
196
197impl BitAnd<&Substitutions> for &PolymorphicType {
198 type Output = PolymorphicType;
199
200 fn bitand(self, rhs: &Substitutions) -> Self::Output {
201 self.substitute(rhs)
202 }
203}
204
205impl BitAnd<Substitutions> for MonomorphicType {
206 type Output = MonomorphicType;
207
208 fn bitand(self, rhs: Substitutions) -> Self::Output {
209 self.substitute(&rhs)
210 }
211}
212
213impl BitAnd<&Substitutions> for MonomorphicType {
214 type Output = MonomorphicType;
215
216 fn bitand(self, rhs: &Substitutions) -> Self::Output {
217 self.substitute(rhs)
218 }
219}
220
221impl BitAnd<Substitutions> for &MonomorphicType {
222 type Output = MonomorphicType;
223
224 fn bitand(self, rhs: Substitutions) -> Self::Output {
225 self.substitute(&rhs)
226 }
227}
228
229impl BitAnd<&Substitutions> for &MonomorphicType {
230 type Output = MonomorphicType;
231
232 fn bitand(self, rhs: &Substitutions) -> Self::Output {
233 self.substitute(rhs)
234 }
235}
236
237impl Tuple {
238 #[must_use]
239 pub const fn unit() -> Tuple {
240 Tuple(vec![])
241 }
242
243 pub fn push(&mut self, value: MonomorphicType) {
244 self.0.push(value);
245 }
246}
247
248impl Display for TVar {
249 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
250 write!(f, "τ{}", self.0)
251 }
252}
253
254impl Display for MonomorphicType {
255 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
256 match self {
257 MonomorphicType::Primitive(p) => write!(f, "{p}"),
258 MonomorphicType::Var(v) => write!(f, "{v}"),
259 MonomorphicType::Fn(input, output) => match input.as_ref() {
260 MonomorphicType::Fn(_, _) => write!(f, "({input}) -> {output}"),
261 _ => write!(f, "{input} -> {output}"),
262 },
263 MonomorphicType::Tuple(Tuple(vec)) => write!(f, "({})", vec.iter().join(", ")),
264 MonomorphicType::Pointer(t) => write!(f, "*{t}"),
265 MonomorphicType::Constant(id) => write!(f, "{id}"),
266 }
267 }
268}
269
270impl Display for PolymorphicType {
271 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
272 if self.bindings.is_empty() {
273 return write!(f, "{}", self.binding_type);
274 }
275 write!(
276 f,
277 "∀{} => {}",
278 self.bindings.iter().join(", "),
279 self.binding_type
280 )
281 }
282}
283
284impl Debug for TVar {
285 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286 write!(f, "{self}")
287 }
288}
289
290impl Debug for MonomorphicType {
291 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
292 write!(f, "{self}")
293 }
294}
295
296impl Debug for PolymorphicType {
297 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298 write!(f, "{self}")
299 }
300}