// glyph_types/inference.rs

//! Type inference engine for Glyph
3use crate::{Type, TypeError};
4use std::collections::HashMap;
5
6/// Type inference context
7pub struct InferenceContext {
8    /// Variable bindings
9    bindings: HashMap<String, Type>,
10
11    /// Type variable substitutions
12    substitutions: HashMap<u32, Type>,
13
14    /// Next type variable ID
15    next_type_var: u32,
16}
17
18impl InferenceContext {
19    /// Create a new inference context
20    pub fn new() -> Self {
21        Self {
22            bindings: HashMap::new(),
23            substitutions: HashMap::new(),
24            next_type_var: 0,
25        }
26    }
27
28    /// Create a fresh type variable
29    pub fn fresh_type_var(&mut self) -> Type {
30        let id = self.next_type_var;
31        self.next_type_var += 1;
32        Type::TypeVar(id)
33    }
34
35    /// Add a variable binding
36    pub fn bind_variable(&mut self, name: String, typ: Type) {
37        self.bindings.insert(name, typ);
38    }
39
40    /// Look up a variable's type
41    pub fn lookup_variable(&self, name: &str) -> Option<&Type> {
42        self.bindings.get(name)
43    }
44
45    /// Unify two types
46    pub fn unify(&mut self, t1: &Type, t2: &Type) -> Result<(), TypeError> {
47        let t1 = self.apply_substitutions(t1);
48        let t2 = self.apply_substitutions(t2);
49
50        match (&t1, &t2) {
51            // Same types unify
52            (Type::Int, Type::Int)
53            | (Type::Float, Type::Float)
54            | (Type::Str, Type::Str)
55            | (Type::Bool, Type::Bool)
56            | (Type::Bytes, Type::Bytes)
57            | (Type::Unit, Type::Unit) => Ok(()),
58
59            // Unknown unifies with anything
60            (Type::Unknown, _) | (_, Type::Unknown) => Ok(()),
61
62            // Type variable unification
63            (Type::TypeVar(id), t) | (t, Type::TypeVar(id)) => {
64                if let Type::TypeVar(id2) = t {
65                    if id == id2 {
66                        return Ok(());
67                    }
68                }
69                // Occurs check would go here for full correctness
70                self.substitutions.insert(*id, t.clone());
71                Ok(())
72            }
73
74            // Container types
75            (Type::List(t1), Type::List(t2)) => self.unify(t1, t2),
76            (Type::Optional(t1), Type::Optional(t2)) => self.unify(t1, t2),
77            (Type::Promise(t1), Type::Promise(t2)) => self.unify(t1, t2),
78
79            (Type::Dict(k1, v1), Type::Dict(k2, v2)) => {
80                self.unify(k1, k2)?;
81                self.unify(v1, v2)
82            }
83
84            (Type::Result(ok1, err1), Type::Result(ok2, err2)) => {
85                self.unify(ok1, ok2)?;
86                self.unify(err1, err2)
87            }
88
89            // Function types
90            (
91                Type::Function {
92                    params: p1,
93                    return_type: r1,
94                },
95                Type::Function {
96                    params: p2,
97                    return_type: r2,
98                },
99            ) => {
100                if p1.len() != p2.len() {
101                    return Err(TypeError::ArgumentCountMismatch {
102                        expected: p1.len(),
103                        found: p2.len(),
104                    });
105                }
106
107                for ((_, t1), (_, t2)) in p1.iter().zip(p2.iter()) {
108                    self.unify(t1, t2)?;
109                }
110
111                self.unify(r1, r2)
112            }
113
114            // Type mismatch
115            _ => Err(TypeError::TypeMismatch {
116                expected: t1,
117                found: t2,
118            }),
119        }
120    }
121
122    /// Apply substitutions to a type
123    pub fn apply_substitutions(&self, typ: &Type) -> Type {
124        match typ {
125            Type::TypeVar(id) => {
126                if let Some(substitution) = self.substitutions.get(id) {
127                    self.apply_substitutions(substitution)
128                } else {
129                    typ.clone()
130                }
131            }
132            Type::List(t) => Type::List(Box::new(self.apply_substitutions(t))),
133            Type::Dict(k, v) => Type::Dict(
134                Box::new(self.apply_substitutions(k)),
135                Box::new(self.apply_substitutions(v)),
136            ),
137            Type::Optional(t) => Type::Optional(Box::new(self.apply_substitutions(t))),
138            Type::Promise(t) => Type::Promise(Box::new(self.apply_substitutions(t))),
139            Type::Result(ok, err) => Type::Result(
140                Box::new(self.apply_substitutions(ok)),
141                Box::new(self.apply_substitutions(err)),
142            ),
143            Type::Function {
144                params,
145                return_type,
146            } => Type::Function {
147                params: params
148                    .iter()
149                    .map(|(name, t)| (name.clone(), self.apply_substitutions(t)))
150                    .collect(),
151                return_type: Box::new(self.apply_substitutions(return_type)),
152            },
153            _ => typ.clone(),
154        }
155    }
156}
157
158impl Default for InferenceContext {
159    fn default() -> Self {
160        Self::new()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_unify_basic_types() {
        let mut ctx = InferenceContext::new();

        // Same types should unify
        assert!(ctx.unify(&Type::Int, &Type::Int).is_ok());
        assert!(ctx.unify(&Type::Str, &Type::Str).is_ok());

        // Different types should not unify
        assert!(ctx.unify(&Type::Int, &Type::Str).is_err());
    }

    #[test]
    fn test_unify_with_unknown() {
        let mut ctx = InferenceContext::new();

        // Unknown should unify with anything
        assert!(ctx.unify(&Type::Unknown, &Type::Int).is_ok());
        assert!(ctx.unify(&Type::Str, &Type::Unknown).is_ok());
    }

    #[test]
    fn test_type_variable_substitution() {
        let mut ctx = InferenceContext::new();
        let tvar = ctx.fresh_type_var();

        // Unify type variable with concrete type
        assert!(ctx.unify(&tvar, &Type::Int).is_ok());

        // Apply substitutions should resolve the type variable
        let resolved = ctx.apply_substitutions(&tvar);
        assert_eq!(resolved, Type::Int);
    }

    #[test]
    fn test_fresh_type_vars_are_distinct() {
        let mut ctx = InferenceContext::new();

        // Ids are handed out sequentially starting at 0.
        assert_eq!(ctx.fresh_type_var(), Type::TypeVar(0));
        assert_eq!(ctx.fresh_type_var(), Type::TypeVar(1));
    }

    #[test]
    fn test_variable_bindings() {
        let mut ctx = InferenceContext::new();

        assert!(ctx.lookup_variable("x").is_none());
        ctx.bind_variable("x".to_string(), Type::Int);
        assert_eq!(ctx.lookup_variable("x"), Some(&Type::Int));

        // Rebinding replaces the previous type.
        ctx.bind_variable("x".to_string(), Type::Str);
        assert_eq!(ctx.lookup_variable("x"), Some(&Type::Str));
    }

    #[test]
    fn test_unify_inside_containers() {
        let mut ctx = InferenceContext::new();
        let key = ctx.fresh_type_var();
        let value = ctx.fresh_type_var();

        let generic = Type::Dict(
            Box::new(key.clone()),
            Box::new(Type::List(Box::new(value.clone()))),
        );
        let concrete = Type::Dict(
            Box::new(Type::Str),
            Box::new(Type::List(Box::new(Type::Bool))),
        );

        // Unification recurses into nested containers and binds both variables.
        assert!(ctx.unify(&generic, &concrete).is_ok());
        assert_eq!(ctx.apply_substitutions(&key), Type::Str);
        assert_eq!(ctx.apply_substitutions(&value), Type::Bool);
        assert_eq!(ctx.apply_substitutions(&generic), concrete);
    }

    #[test]
    fn test_type_variable_chain_resolves() {
        let mut ctx = InferenceContext::new();
        let a = ctx.fresh_type_var();
        let b = ctx.fresh_type_var();

        // A variable unifies with itself without being bound.
        assert!(ctx.unify(&a, &a).is_ok());
        assert_eq!(ctx.apply_substitutions(&a), a);

        // a -> b -> Float resolves through the chain.
        assert!(ctx.unify(&a, &b).is_ok());
        assert!(ctx.unify(&b, &Type::Float).is_ok());
        assert_eq!(ctx.apply_substitutions(&a), Type::Float);

        // Once resolved, the variable behaves like its concrete type.
        assert!(ctx.unify(&a, &Type::Int).is_err());
    }

    #[test]
    fn test_container_mismatch_is_reported() {
        let mut ctx = InferenceContext::new();
        let ints = Type::List(Box::new(Type::Int));
        let strs = Type::List(Box::new(Type::Str));

        // Mismatched element types surface as a TypeMismatch.
        assert!(matches!(
            ctx.unify(&ints, &strs),
            Err(TypeError::TypeMismatch { .. })
        ));

        // Different container kinds never unify.
        assert!(ctx
            .unify(&ints, &Type::Optional(Box::new(Type::Int)))
            .is_err());
    }
}