1use crate::{Type, TypeError};
4use std::collections::HashMap;
5
6pub struct InferenceContext {
8 bindings: HashMap<String, Type>,
10
11 substitutions: HashMap<u32, Type>,
13
14 next_type_var: u32,
16}
17
18impl InferenceContext {
19 pub fn new() -> Self {
21 Self {
22 bindings: HashMap::new(),
23 substitutions: HashMap::new(),
24 next_type_var: 0,
25 }
26 }
27
28 pub fn fresh_type_var(&mut self) -> Type {
30 let id = self.next_type_var;
31 self.next_type_var += 1;
32 Type::TypeVar(id)
33 }
34
35 pub fn bind_variable(&mut self, name: String, typ: Type) {
37 self.bindings.insert(name, typ);
38 }
39
40 pub fn lookup_variable(&self, name: &str) -> Option<&Type> {
42 self.bindings.get(name)
43 }
44
45 pub fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
47 let t1 = self.apply_substitutions(t1);
48 let t2 = self.apply_substitutions(t2);
49
50 match (&t1, &t2) {
51 (Type::Int, Type::Int)
53 | (Type::Float, Type::Float)
54 | (Type::Str, Type::Str)
55 | (Type::Bool, Type::Bool)
56 | (Type::Bytes, Type::Bytes)
57 | (Type::Unit, Type::Unit) => Ok(()),
58
59 (Type::Unknown, _) | (_, Type::Unknown) => Ok(()),
61
62 (Type::TypeVar(id), t) | (t, Type::TypeVar(id)) => {
64 if let Type::TypeVar(id2) = t {
65 if id == id2 {
66 return Ok(());
67 }
68 }
69 self.substitutions.insert(*id, t.clone());
71 Ok(())
72 }
73
74 (Type::List(t1), Type::List(t2)) => self.unify(t1, t2),
76 (Type::Optional(t1), Type::Optional(t2)) => self.unify(t1, t2),
77 (Type::Promise(t1), Type::Promise(t2)) => self.unify(t1, t2),
78
79 (Type::Dict(k1, v1), Type::Dict(k2, v2)) => {
80 self.unify(k1, k2)?;
81 self.unify(v1, v2)
82 }
83
84 (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
85 self.unify(ok1, ok2)?;
86 self.unify(err1, err2)
87 }
88
89 (
91 Type::Function {
92 params: p1,
93 return_type: r1,
94 },
95 Type::Function {
96 params: p2,
97 return_type: r2,
98 },
99 ) => {
100 if p1.len() != p2.len() {
101 return Err(TypeError::ArgumentCountMismatch {
102 expected: p1.len(),
103 found: p2.len(),
104 });
105 }
106
107 for ((_, t1), (_, t2)) in p1.iter().zip(p2.iter()) {
108 self.unify(t1, t2)?;
109 }
110
111 self.unify(r1, r2)
112 }
113
114 _ => Err(TypeError::TypeMismatch {
116 expected: t1,
117 found: t2,
118 }),
119 }
120 }
121
122 pub fn apply_substitutions(&self, typ: &Type) -> Type {
124 match typ {
125 Type::TypeVar(id) => {
126 if let Some(substitution) = self.substitutions.get(id) {
127 self.apply_substitutions(substitution)
128 } else {
129 typ.clone()
130 }
131 }
132 Type::List(t) => Type::List(Box::new(self.apply_substitutions(t))),
133 Type::Dict(k, v) => Type::Dict(
134 Box::new(self.apply_substitutions(k)),
135 Box::new(self.apply_substitutions(v)),
136 ),
137 Type::Optional(t) => Type::Optional(Box::new(self.apply_substitutions(t))),
138 Type::Promise(t) => Type::Promise(Box::new(self.apply_substitutions(t))),
139 Type::Result(ok, err) => Type::Result(
140 Box::new(self.apply_substitutions(ok)),
141 Box::new(self.apply_substitutions(err)),
142 ),
143 Type::Function {
144 params,
145 return_type,
146 } => Type::Function {
147 params: params
148 .iter()
149 .map(|(name, t)| (name.clone(), self.apply_substitutions(t)))
150 .collect(),
151 return_type: Box::new(self.apply_substitutions(return_type)),
152 },
153 _ => typ.clone(),
154 }
155 }
156}
157
158impl Default for InferenceContext {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unify_basic_types() {
        let mut ctx = InferenceContext::new();

        assert!(ctx.unify(&Type::Int, &Type::Int).is_ok());
        assert!(ctx.unify(&Type::Str, &Type::Str).is_ok());

        assert!(ctx.unify(&Type::Int, &Type::Str).is_err());
    }

    #[test]
    fn test_unify_with_unknown() {
        let mut ctx = InferenceContext::new();

        assert!(ctx.unify(&Type::Unknown, &Type::Int).is_ok());
        assert!(ctx.unify(&Type::Str, &Type::Unknown).is_ok());
    }

    #[test]
    fn test_type_variable_substitution() {
        let mut ctx = InferenceContext::new();
        let tvar = ctx.fresh_type_var();

        assert!(ctx.unify(&tvar, &Type::Int).is_ok());

        let resolved = ctx.apply_substitutions(&tvar);
        assert_eq!(resolved, Type::Int);
    }

    #[test]
    fn test_fresh_type_vars_are_distinct() {
        let mut ctx = InferenceContext::new();
        let first = ctx.fresh_type_var();
        let second = ctx.fresh_type_var();
        assert_ne!(first, second);
    }

    #[test]
    fn test_unify_type_var_with_itself_is_noop() {
        let mut ctx = InferenceContext::new();
        let tvar = ctx.fresh_type_var();

        assert!(ctx.unify(&tvar, &tvar).is_ok());
        assert_eq!(ctx.apply_substitutions(&tvar), tvar);
    }

    #[test]
    fn test_type_var_chain_resolves_transitively() {
        let mut ctx = InferenceContext::new();
        let a = ctx.fresh_type_var();
        let b = ctx.fresh_type_var();

        assert!(ctx.unify(&a, &b).is_ok());
        assert!(ctx.unify(&b, &Type::Float).is_ok());

        assert_eq!(ctx.apply_substitutions(&a), Type::Float);
        assert_eq!(ctx.apply_substitutions(&b), Type::Float);
    }

    #[test]
    fn test_unify_compound_types_binds_components() {
        let mut ctx = InferenceContext::new();
        let elem = ctx.fresh_type_var();
        let key = ctx.fresh_type_var();
        let value = ctx.fresh_type_var();
        let ok = ctx.fresh_type_var();

        assert!(ctx
            .unify(
                &Type::List(Box::new(elem.clone())),
                &Type::List(Box::new(Type::Int)),
            )
            .is_ok());
        assert!(ctx
            .unify(
                &Type::Dict(Box::new(key.clone()), Box::new(value.clone())),
                &Type::Dict(Box::new(Type::Str), Box::new(Type::Bool)),
            )
            .is_ok());
        assert!(ctx
            .unify(
                &Type::Result(Box::new(ok.clone()), Box::new(Type::Str)),
                &Type::Result(Box::new(Type::Unit), Box::new(Type::Str)),
            )
            .is_ok());

        assert_eq!(ctx.apply_substitutions(&elem), Type::Int);
        assert_eq!(ctx.apply_substitutions(&key), Type::Str);
        assert_eq!(ctx.apply_substitutions(&value), Type::Bool);
        assert_eq!(ctx.apply_substitutions(&ok), Type::Unit);
    }

    #[test]
    fn test_unify_compound_mismatches() {
        let mut ctx = InferenceContext::new();

        assert!(ctx
            .unify(
                &Type::List(Box::new(Type::Int)),
                &Type::List(Box::new(Type::Str)),
            )
            .is_err());
        assert!(ctx
            .unify(
                &Type::List(Box::new(Type::Int)),
                &Type::Optional(Box::new(Type::Int)),
            )
            .is_err());

        // A variable bound by the key must stay consistent for the value.
        let shared = ctx.fresh_type_var();
        assert!(ctx
            .unify(
                &Type::Dict(Box::new(shared.clone()), Box::new(shared.clone())),
                &Type::Dict(Box::new(Type::Int), Box::new(Type::Str)),
            )
            .is_err());
    }

    #[test]
    fn test_bind_and_lookup_variable() {
        let mut ctx = InferenceContext::new();

        ctx.bind_variable("x".to_string(), Type::Int);
        assert_eq!(ctx.lookup_variable("x"), Some(&Type::Int));

        ctx.bind_variable("x".to_string(), Type::Str);
        assert_eq!(ctx.lookup_variable("x"), Some(&Type::Str));

        assert_eq!(ctx.lookup_variable("missing"), None);
    }
}