1use std::collections::HashMap;
4use std::fmt;
5
/// A type variable, identified by a numeric id.
///
/// Rendered by `Display` as `τ` followed by the id, e.g. `τ3`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TyVar(pub u32);

impl fmt::Display for TyVar {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TyVar(id) = self;
        write!(f, "τ{id}")
    }
}
15
/// A monomorphic type expression.
///
/// It may contain type variables ([`MonoType::Var`]) but no quantifiers.
/// Universally quantified types are represented by [`TypeScheme`].
#[derive(Debug, Clone, PartialEq)]
pub enum MonoType {
    /// A type variable; replaced by [`MonoType::substitute`] when it is bound
    /// in the substitution.
    Var(TyVar),
    /// Integer type, displayed as `i32`.
    Int,
    /// Floating-point type, displayed as `f64`.
    Float,
    /// Boolean type, displayed as `bool`.
    Bool,
    /// String type, displayed as `String`.
    String,
    /// Character type, displayed as `char`.
    Char,
    /// Unit type, displayed as `()`.
    Unit,
    /// Function type: argument type, then return type. Displayed as
    /// `(arg -> ret)`.
    Function(Box<MonoType>, Box<MonoType>),
    /// List with the given element type, displayed as `[T]`.
    List(Box<MonoType>),
    /// Tuple of the listed element types, displayed as `(A, B, ...)`.
    Tuple(Vec<MonoType>),
    /// Optional of the inner type, displayed as `T?`.
    Optional(Box<MonoType>),
    /// Result with an ok type and an error type, displayed as `Result<T, E>`.
    Result(Box<MonoType>, Box<MonoType>),
    /// A type referred to only by name; the name is displayed verbatim.
    Named(String),
    /// Reference to the inner type, displayed as `&T`.
    Reference(Box<MonoType>),
    /// Data frame described by `(column name, column type)` pairs, displayed as
    /// `DataFrame[name: T, ...]`.
    DataFrame(Vec<(String, MonoType)>),
    /// Series parameterized by an element (dtype) type, displayed as
    /// `Series<T>`.
    Series(Box<MonoType>),
}
52
53impl fmt::Display for MonoType {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 match self {
56 MonoType::Var(v) => write!(f, "{v}"),
57 MonoType::Int => write!(f, "i32"),
58 MonoType::Float => write!(f, "f64"),
59 MonoType::Bool => write!(f, "bool"),
60 MonoType::String => write!(f, "String"),
61 MonoType::Char => write!(f, "char"),
62 MonoType::Unit => write!(f, "()"),
63 MonoType::Function(arg, ret) => write!(f, "({arg} -> {ret})"),
64 MonoType::List(elem) => write!(f, "[{elem}]"),
65 MonoType::Optional(inner) => write!(f, "{inner}?"),
66 MonoType::Result(ok, err) => write!(f, "Result<{ok}, {err}>"),
67 MonoType::Tuple(types) => {
68 write!(f, "(")?;
69 for (i, ty) in types.iter().enumerate() {
70 if i > 0 {
71 write!(f, ", ")?;
72 }
73 write!(f, "{ty}")?;
74 }
75 write!(f, ")")
76 }
77 MonoType::Named(name) => write!(f, "{name}"),
78 MonoType::Reference(inner) => write!(f, "&{inner}"),
79 MonoType::DataFrame(columns) => {
80 write!(f, "DataFrame[")?;
81 for (i, (name, ty)) in columns.iter().enumerate() {
82 if i > 0 {
83 write!(f, ", ")?;
84 }
85 write!(f, "{name}: {ty}")?;
86 }
87 write!(f, "]")
88 }
89 MonoType::Series(dtype) => write!(f, "Series<{dtype}>"),
90 }
91 }
92}
93
/// A type scheme: the type `ty` universally quantified over `vars`.
///
/// An empty `vars` list denotes a monomorphic scheme (see [`TypeScheme::mono`]).
#[derive(Debug, Clone)]
pub struct TypeScheme {
    /// Quantified type variables; [`TypeScheme::instantiate`] replaces each of
    /// them with a fresh variable.
    pub vars: Vec<TyVar>,
    /// The body of the scheme.
    pub ty: MonoType,
}
102
103impl TypeScheme {
104 #[must_use]
106 pub fn mono(ty: MonoType) -> Self {
107 TypeScheme {
108 vars: Vec::new(),
109 ty,
110 }
111 }
112
113 pub fn instantiate(&self, gen: &mut TyVarGenerator) -> MonoType {
115 if self.vars.is_empty() {
116 self.ty.clone()
117 } else {
118 let subst: HashMap<TyVar, MonoType> = self
119 .vars
120 .iter()
121 .map(|v| (v.clone(), MonoType::Var(gen.fresh())))
122 .collect();
123 self.ty.substitute(&subst)
124 }
125 }
126}
127
128impl fmt::Display for TypeScheme {
129 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130 if self.vars.is_empty() {
131 write!(f, "{}", self.ty)
132 } else {
133 write!(f, "∀")?;
134 for (i, var) in self.vars.iter().enumerate() {
135 if i > 0 {
136 write!(f, ",")?;
137 }
138 write!(f, "{var}")?;
139 }
140 write!(f, ". {}", self.ty)
141 }
142 }
143}
144
/// Source of fresh type variables; ids are handed out in increasing order
/// starting from the generator's current counter.
#[derive(Debug)]
pub struct TyVarGenerator {
    /// Id of the next type variable to hand out.
    next: u32,
}
149
150impl TyVarGenerator {
151 #[must_use]
152 pub fn new() -> Self {
153 TyVarGenerator { next: 0 }
154 }
155
156 pub fn fresh(&mut self) -> TyVar {
157 let var = TyVar(self.next);
158 self.next += 1;
159 var
160 }
161}
162
163impl Default for TyVarGenerator {
164 fn default() -> Self {
165 Self::new()
166 }
167}
168
/// A mapping from type variables to the types that replace them, as applied
/// by [`MonoType::substitute`].
pub type Substitution = HashMap<TyVar, MonoType>;
171
172impl MonoType {
173 #[must_use]
175 pub fn substitute(&self, subst: &Substitution) -> MonoType {
176 match self {
177 MonoType::Var(v) => subst.get(v).cloned().unwrap_or_else(|| self.clone()),
178 MonoType::Function(arg, ret) => MonoType::Function(
179 Box::new(arg.substitute(subst)),
180 Box::new(ret.substitute(subst)),
181 ),
182 MonoType::List(elem) => MonoType::List(Box::new(elem.substitute(subst))),
183 MonoType::Optional(inner) => MonoType::Optional(Box::new(inner.substitute(subst))),
184 MonoType::Result(ok, err) => MonoType::Result(
185 Box::new(ok.substitute(subst)),
186 Box::new(err.substitute(subst)),
187 ),
188 MonoType::DataFrame(columns) => MonoType::DataFrame(
189 columns
190 .iter()
191 .map(|(name, ty)| (name.clone(), ty.substitute(subst)))
192 .collect(),
193 ),
194 MonoType::Series(dtype) => MonoType::Series(Box::new(dtype.substitute(subst))),
195 MonoType::Reference(inner) => MonoType::Reference(Box::new(inner.substitute(subst))),
196 MonoType::Tuple(types) => {
197 MonoType::Tuple(types.iter().map(|ty| ty.substitute(subst)).collect())
198 }
199 _ => self.clone(),
200 }
201 }
202
203 #[must_use]
205 pub fn free_vars(&self) -> Vec<TyVar> {
206 use std::collections::HashSet;
207
208 fn collect_vars(ty: &MonoType, vars: &mut HashSet<TyVar>) {
209 match ty {
210 MonoType::Var(v) => {
211 vars.insert(v.clone());
212 }
213 MonoType::Function(arg, ret) => {
214 collect_vars(arg, vars);
215 collect_vars(ret, vars);
216 }
217 MonoType::List(elem) => collect_vars(elem, vars),
218 MonoType::Optional(inner)
219 | MonoType::Series(inner)
220 | MonoType::Reference(inner) => {
221 collect_vars(inner, vars);
222 }
223 MonoType::Result(ok, err) => {
224 collect_vars(ok, vars);
225 collect_vars(err, vars);
226 }
227 MonoType::DataFrame(columns) => {
228 for (_, ty) in columns {
229 collect_vars(ty, vars);
230 }
231 }
232 MonoType::Tuple(types) => {
233 for ty in types {
234 collect_vars(ty, vars);
235 }
236 }
237 _ => {}
238 }
239 }
240
241 let mut vars = HashSet::new();
242 collect_vars(self, &mut vars);
243 vars.into_iter().collect()
244 }
245}
246
#[cfg(test)]
// Tests intentionally use `unwrap`/`expect`/`panic!` to fail loudly.
#[allow(clippy::unwrap_used, clippy::panic, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_type_display() {
        assert_eq!(MonoType::Int.to_string(), "i32");
        assert_eq!(MonoType::Bool.to_string(), "bool");
        assert_eq!(
            MonoType::Function(Box::new(MonoType::Int), Box::new(MonoType::Bool)).to_string(),
            "(i32 -> bool)"
        );
        assert_eq!(MonoType::List(Box::new(MonoType::Int)).to_string(), "[i32]");
    }

    #[test]
    fn test_type_scheme_instantiation() {
        // Not named `gen`: that is a reserved keyword in the Rust 2024 edition.
        let mut generator = TyVarGenerator::new();
        let var = generator.fresh();

        let scheme = TypeScheme {
            vars: vec![var.clone()],
            ty: MonoType::Function(
                Box::new(MonoType::Var(var.clone())),
                Box::new(MonoType::Var(var.clone())),
            ),
        };
        assert_eq!(scheme.to_string(), "∀τ0. (τ0 -> τ0)");

        let instantiated = scheme.instantiate(&mut generator);
        match instantiated {
            MonoType::Function(arg, ret) => {
                // Both occurrences of the quantified variable must map to the
                // same fresh variable, distinct from the quantified one.
                assert_eq!(arg, ret);
                assert!(matches!(*arg, MonoType::Var(ref v) if *v != var));
            }
            _ => panic!("Expected function type"),
        }
    }

    #[test]
    fn test_substitution() {
        let mut subst = HashMap::new();
        let var = TyVar(0);
        subst.insert(var.clone(), MonoType::Int);

        let ty = MonoType::List(Box::new(MonoType::Var(var)));
        let result = ty.substitute(&subst);

        assert_eq!(result, MonoType::List(Box::new(MonoType::Int)));
    }

    #[test]
    fn test_free_vars() {
        let var1 = TyVar(0);
        let var2 = TyVar(1);

        let ty = MonoType::Function(
            Box::new(MonoType::Var(var1.clone())),
            Box::new(MonoType::List(Box::new(MonoType::Var(var2.clone())))),
        );

        let free = ty.free_vars();
        assert_eq!(free.len(), 2);
        assert!(free.contains(&var1));
        assert!(free.contains(&var2));

        let ty_dup = MonoType::Function(
            Box::new(MonoType::Var(var1.clone())),
            Box::new(MonoType::Var(var1.clone())),
        );
        let free_dup = ty_dup.free_vars();
        assert_eq!(free_dup.len(), 1);
        assert!(free_dup.contains(&var1));
    }
}