// lust/typechecker/type_env.rs

use crate::{
    ast::*,
    builtins::{self, BuiltinFunction},
    config::LustConfig,
    error::{LustError, Result},
};
use std::collections::{hash_map::Entry, HashMap, HashSet};
8pub struct TypeEnv {
9    scopes: Vec<HashMap<String, Type>>,
10    refinements: Vec<HashMap<String, Type>>,
11    generic_instances: HashMap<String, HashMap<String, Type>>,
12    functions: HashMap<String, FunctionSignature>,
13    structs: HashMap<String, StructDef>,
14    enums: HashMap<String, EnumDef>,
15    traits: HashMap<String, TraitDef>,
16    type_aliases: HashMap<String, (Vec<String>, Type)>,
17    impls: Vec<ImplBlock>,
18    builtin_types: HashSet<String>,
19}
20
21#[derive(Debug, Clone)]
22pub struct FunctionSignature {
23    pub params: Vec<Type>,
24    pub return_type: Type,
25    pub is_method: bool,
26}
27
28impl TypeEnv {
29    pub fn new() -> Self {
30        Self::with_config(&LustConfig::default())
31    }
32
33    fn register_builtin_function_slice(&mut self, functions: &[BuiltinFunction], span: Span) {
34        for builtin in functions {
35            self.functions
36                .insert(builtin.name.to_string(), builtin.to_signature(span));
37        }
38    }
39
40    pub fn with_config(config: &LustConfig) -> Self {
41        let mut env = Self {
42            scopes: vec![HashMap::new()],
43            refinements: vec![HashMap::new()],
44            generic_instances: HashMap::new(),
45            functions: HashMap::new(),
46            structs: HashMap::new(),
47            enums: HashMap::new(),
48            traits: HashMap::new(),
49            type_aliases: HashMap::new(),
50            impls: Vec::new(),
51            builtin_types: HashSet::new(),
52        };
53        env.register_builtins(config);
54        env
55    }
56
57    fn register_builtins(&mut self, config: &LustConfig) {
58        let dummy_span = Span::new(0, 0, 0, 0);
59        self.register_builtin_type("Task");
60        self.register_builtin_type("TaskStatus");
61        self.register_builtin_type("TaskInfo");
62        let task_status_type = Type::new(TypeKind::Named("TaskStatus".to_string()), dummy_span);
63        let unknown_type = Type::new(TypeKind::Unknown, dummy_span);
64        let option_unknown_type =
65            Type::new(TypeKind::Option(Box::new(unknown_type.clone())), dummy_span);
66        let string_type = Type::new(TypeKind::String, dummy_span);
67        let option_string_type =
68            Type::new(TypeKind::Option(Box::new(string_type.clone())), dummy_span);
69        let task_info_struct = StructDef {
70            name: "TaskInfo".to_string(),
71            type_params: vec![],
72            trait_bounds: vec![],
73            fields: vec![
74                StructField {
75                    name: "state".to_string(),
76                    ty: task_status_type.clone(),
77                    visibility: Visibility::Public,
78                    ownership: FieldOwnership::Strong,
79                    weak_target: None,
80                },
81                StructField {
82                    name: "last_yield".to_string(),
83                    ty: option_unknown_type.clone(),
84                    visibility: Visibility::Public,
85                    ownership: FieldOwnership::Strong,
86                    weak_target: None,
87                },
88                StructField {
89                    name: "last_result".to_string(),
90                    ty: option_unknown_type.clone(),
91                    visibility: Visibility::Public,
92                    ownership: FieldOwnership::Strong,
93                    weak_target: None,
94                },
95                StructField {
96                    name: "error".to_string(),
97                    ty: option_string_type.clone(),
98                    visibility: Visibility::Public,
99                    ownership: FieldOwnership::Strong,
100                    weak_target: None,
101                },
102            ],
103            visibility: Visibility::Public,
104        };
105        self.structs
106            .insert("TaskInfo".to_string(), task_info_struct);
107        self.register_builtin_function_slice(builtins::base_functions(), dummy_span);
108        self.register_builtin_function_slice(builtins::task_functions(), dummy_span);
109        if let Some(global_scope) = self.scopes.first_mut() {
110            global_scope.insert("task".to_string(), Type::new(TypeKind::Unknown, dummy_span));
111        }
112
113        let task_status_enum = EnumDef {
114            name: "TaskStatus".to_string(),
115            type_params: vec![],
116            trait_bounds: vec![],
117            variants: vec![
118                EnumVariant {
119                    name: "Ready".to_string(),
120                    fields: None,
121                },
122                EnumVariant {
123                    name: "Running".to_string(),
124                    fields: None,
125                },
126                EnumVariant {
127                    name: "Yielded".to_string(),
128                    fields: None,
129                },
130                EnumVariant {
131                    name: "Completed".to_string(),
132                    fields: None,
133                },
134                EnumVariant {
135                    name: "Failed".to_string(),
136                    fields: None,
137                },
138                EnumVariant {
139                    name: "Stopped".to_string(),
140                    fields: None,
141                },
142            ],
143            visibility: Visibility::Public,
144        };
145        self.enums
146            .insert("TaskStatus".to_string(), task_status_enum);
147        if config.is_module_enabled("io") {
148            if let Some(global_scope) = self.scopes.first_mut() {
149                global_scope.insert("io".to_string(), Type::new(TypeKind::Unknown, dummy_span));
150            }
151
152            self.register_builtin_function_slice(builtins::io_functions(), dummy_span);
153        }
154
155        if config.is_module_enabled("os") {
156            if let Some(global_scope) = self.scopes.first_mut() {
157                global_scope.insert("os".to_string(), Type::new(TypeKind::Unknown, dummy_span));
158            }
159
160            self.register_builtin_function_slice(builtins::os_functions(), dummy_span);
161        }
162
163        let option_enum = EnumDef {
164            name: "Option".to_string(),
165            type_params: vec!["T".to_string()],
166            trait_bounds: vec![],
167            variants: vec![
168                EnumVariant {
169                    name: "Some".to_string(),
170                    fields: Some(vec![Type::new(
171                        TypeKind::Generic("T".to_string()),
172                        dummy_span,
173                    )]),
174                },
175                EnumVariant {
176                    name: "None".to_string(),
177                    fields: None,
178                },
179            ],
180            visibility: Visibility::Public,
181        };
182        self.enums.insert("Option".to_string(), option_enum);
183        let result_enum = EnumDef {
184            name: "Result".to_string(),
185            type_params: vec!["T".to_string(), "E".to_string()],
186            trait_bounds: vec![],
187            variants: vec![
188                EnumVariant {
189                    name: "Ok".to_string(),
190                    fields: Some(vec![Type::new(
191                        TypeKind::Generic("T".to_string()),
192                        dummy_span,
193                    )]),
194                },
195                EnumVariant {
196                    name: "Err".to_string(),
197                    fields: Some(vec![Type::new(
198                        TypeKind::Generic("E".to_string()),
199                        dummy_span,
200                    )]),
201                },
202            ],
203            visibility: Visibility::Public,
204        };
205        self.enums.insert("Result".to_string(), result_enum);
206        let to_string_trait = TraitDef {
207            name: "ToString".to_string(),
208            type_params: vec![],
209            methods: vec![TraitMethod {
210                name: "to_string".to_string(),
211                type_params: vec![],
212                params: vec![FunctionParam {
213                    name: "self".to_string(),
214                    ty: Type::new(TypeKind::Unknown, dummy_span),
215                    is_self: true,
216                }],
217                return_type: Some(Type::new(TypeKind::String, dummy_span)),
218                default_impl: None,
219            }],
220            visibility: Visibility::Public,
221        };
222        self.traits.insert("ToString".to_string(), to_string_trait);
223        let hashable_trait = TraitDef {
224            name: "Hashable".to_string(),
225            type_params: vec![],
226            methods: vec![TraitMethod {
227                name: "hash".to_string(),
228                type_params: vec![],
229                params: vec![FunctionParam {
230                    name: "self".to_string(),
231                    ty: Type::new(TypeKind::Unknown, dummy_span),
232                    is_self: true,
233                }],
234                return_type: Some(Type::new(TypeKind::Int, dummy_span)),
235                default_impl: None,
236            }],
237            visibility: Visibility::Public,
238        };
239        self.traits.insert("Hashable".to_string(), hashable_trait);
240        let int_hashable_impl = ImplBlock {
241            type_params: vec![],
242            trait_name: Some("Hashable".to_string()),
243            target_type: Type::new(TypeKind::Int, dummy_span),
244            methods: vec![],
245            where_clause: vec![],
246        };
247        self.impls.push(int_hashable_impl);
248        let float_hashable_impl = ImplBlock {
249            type_params: vec![],
250            trait_name: Some("Hashable".to_string()),
251            target_type: Type::new(TypeKind::Float, dummy_span),
252            methods: vec![],
253            where_clause: vec![],
254        };
255        self.impls.push(float_hashable_impl);
256        let bool_hashable_impl = ImplBlock {
257            type_params: vec![],
258            trait_name: Some("Hashable".to_string()),
259            target_type: Type::new(TypeKind::Bool, dummy_span),
260            methods: vec![],
261            where_clause: vec![],
262        };
263        self.impls.push(bool_hashable_impl);
264        let string_hashable_impl = ImplBlock {
265            type_params: vec![],
266            trait_name: Some("Hashable".to_string()),
267            target_type: Type::new(TypeKind::String, dummy_span),
268            methods: vec![],
269            where_clause: vec![],
270        };
271        self.impls.push(string_hashable_impl);
272        let int_to_string_impl = ImplBlock {
273            type_params: vec![],
274            trait_name: Some("ToString".to_string()),
275            target_type: Type::new(TypeKind::Int, dummy_span),
276            methods: vec![],
277            where_clause: vec![],
278        };
279        self.impls.push(int_to_string_impl);
280        let float_to_string_impl = ImplBlock {
281            type_params: vec![],
282            trait_name: Some("ToString".to_string()),
283            target_type: Type::new(TypeKind::Float, dummy_span),
284            methods: vec![],
285            where_clause: vec![],
286        };
287        self.impls.push(float_to_string_impl);
288        let bool_to_string_impl = ImplBlock {
289            type_params: vec![],
290            trait_name: Some("ToString".to_string()),
291            target_type: Type::new(TypeKind::Bool, dummy_span),
292            methods: vec![],
293            where_clause: vec![],
294        };
295        self.impls.push(bool_to_string_impl);
296        let string_to_string_impl = ImplBlock {
297            type_params: vec![],
298            trait_name: Some("ToString".to_string()),
299            target_type: Type::new(TypeKind::String, dummy_span),
300            methods: vec![],
301            where_clause: vec![],
302        };
303        self.impls.push(string_to_string_impl);
304    }
305
306    fn register_builtin_type(&mut self, name: &str) {
307        self.builtin_types.insert(name.to_string());
308    }
309
310    pub fn is_builtin_type(&self, name: &str) -> bool {
311        self.builtin_types.contains(name)
312    }
313
314    pub fn push_scope(&mut self) {
315        self.scopes.push(HashMap::new());
316        self.refinements.push(HashMap::new());
317    }
318
319    pub fn pop_scope(&mut self) {
320        self.scopes.pop();
321        self.refinements.pop();
322    }
323
324    pub fn declare_variable(&mut self, name: String, ty: Type) -> Result<()> {
325        let scope = self
326            .scopes
327            .last_mut()
328            .expect("Type environment has no scope");
329        if scope.contains_key(&name) {
330            return Err(LustError::TypeError {
331                message: format!("Variable '{}' is already declared in this scope", name),
332            });
333        }
334
335        scope.insert(name, ty);
336        Ok(())
337    }
338
339    pub fn lookup_variable(&self, name: &str) -> Option<Type> {
340        for refinement_scope in self.refinements.iter().rev() {
341            if let Some(ty) = refinement_scope.get(name) {
342                return Some(ty.clone());
343            }
344        }
345
346        for scope in self.scopes.iter().rev() {
347            if let Some(ty) = scope.get(name) {
348                return Some(ty.clone());
349            }
350        }
351
352        None
353    }
354
355    pub fn refine_variable_type(&mut self, name: String, refined_type: Type) {
356        if let Some(refinement_scope) = self.refinements.last_mut() {
357            refinement_scope.insert(name, refined_type);
358        }
359    }
360
361    pub fn record_generic_instance(
362        &mut self,
363        var_name: String,
364        type_param: String,
365        concrete_type: Type,
366    ) {
367        self.generic_instances
368            .entry(var_name)
369            .or_insert_with(HashMap::new)
370            .insert(type_param, concrete_type);
371    }
372
373    pub fn lookup_generic_param(&self, var_name: &str, type_param: &str) -> Option<Type> {
374        self.generic_instances
375            .get(var_name)?
376            .get(type_param)
377            .cloned()
378    }
379
380    pub fn register_function(&mut self, name: String, sig: FunctionSignature) -> Result<()> {
381        if self.functions.contains_key(&name) {
382            return Err(LustError::TypeError {
383                message: format!("Function '{}' is already defined", name),
384            });
385        }
386
387        self.functions.insert(name, sig);
388        Ok(())
389    }
390
391    pub fn lookup_function(&self, name: &str) -> Option<&FunctionSignature> {
392        self.functions.get(name)
393    }
394
395    pub fn function_signatures(&self) -> HashMap<String, FunctionSignature> {
396        self.functions.clone()
397    }
398
399    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
400        self.structs.clone()
401    }
402
403    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
404        self.enums.clone()
405    }
406
407    pub fn register_struct(&mut self, s: &StructDef) -> Result<()> {
408        if self.structs.contains_key(&s.name) {
409            return Err(LustError::TypeError {
410                message: format!("Struct '{}' is already defined", s.name),
411            });
412        }
413
414        self.structs.insert(s.name.clone(), s.clone());
415        Ok(())
416    }
417
418    pub fn lookup_struct(&self, name: &str) -> Option<&StructDef> {
419        self.structs.get(name)
420    }
421
422    pub fn register_enum(&mut self, e: &EnumDef) -> Result<()> {
423        if self.enums.contains_key(&e.name) {
424            return Err(LustError::TypeError {
425                message: format!("Enum '{}' is already defined", e.name),
426            });
427        }
428
429        self.enums.insert(e.name.clone(), e.clone());
430        Ok(())
431    }
432
433    pub fn lookup_enum(&self, name: &str) -> Option<&EnumDef> {
434        self.enums.get(name)
435    }
436
437    pub fn register_trait(&mut self, t: &TraitDef) -> Result<()> {
438        if self.traits.contains_key(&t.name) {
439            return Err(LustError::TypeError {
440                message: format!("Trait '{}' is already defined", t.name),
441            });
442        }
443
444        self.traits.insert(t.name.clone(), t.clone());
445        Ok(())
446    }
447
448    pub fn lookup_trait(&self, name: &str) -> Option<&TraitDef> {
449        self.traits.get(name)
450    }
451
452    pub fn register_type_alias(
453        &mut self,
454        name: String,
455        type_params: Vec<String>,
456        target: Type,
457    ) -> Result<()> {
458        if self.type_aliases.contains_key(&name) {
459            return Err(LustError::TypeError {
460                message: format!("Type alias '{}' is already defined", name),
461            });
462        }
463
464        self.type_aliases.insert(name, (type_params, target));
465        Ok(())
466    }
467
468    pub fn register_impl(&mut self, impl_block: &ImplBlock) {
469        self.impls.push(impl_block.clone());
470    }
471
472    pub fn lookup_method(&self, type_name: &str, method_name: &str) -> Option<&FunctionDef> {
473        for impl_block in &self.impls {
474            if let TypeKind::Named(name) = &impl_block.target_type.kind {
475                if name == type_name {
476                    for method in &impl_block.methods {
477                        if method.name.ends_with(&format!(":{}", method_name))
478                            || method.name == method_name
479                        {
480                            return Some(method);
481                        }
482                    }
483                }
484            }
485        }
486
487        None
488    }
489
490    pub fn lookup_struct_field(&self, struct_name: &str, field_name: &str) -> Option<Type> {
491        let struct_def = self.lookup_struct(struct_name)?;
492        for field in &struct_def.fields {
493            if field.name == field_name {
494                return Some(field.ty.clone());
495            }
496        }
497
498        None
499    }
500
501    pub fn type_implements_trait(&self, ty: &Type, trait_name: &str) -> bool {
502        for impl_block in &self.impls {
503            if let Some(impl_trait_name) = &impl_block.trait_name {
504                if impl_trait_name == trait_name {
505                    if self.types_match(&impl_block.target_type, ty) {
506                        return true;
507                    }
508                }
509            }
510        }
511
512        false
513    }
514
515    fn types_match(&self, type1: &Type, type2: &Type) -> bool {
516        match (&type1.kind, &type2.kind) {
517            (TypeKind::Int, TypeKind::Int) => true,
518            (TypeKind::Float, TypeKind::Float) => true,
519            (TypeKind::Bool, TypeKind::Bool) => true,
520            (TypeKind::String, TypeKind::String) => true,
521            (TypeKind::Named(n1), TypeKind::Named(n2)) => n1 == n2,
522            (TypeKind::Array(t1), TypeKind::Array(t2)) => self.types_match(t1, t2),
523            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
524                self.types_match(k1, k2) && self.types_match(v1, v2)
525            }
526
527            (TypeKind::Option(t1), TypeKind::Option(t2)) => self.types_match(t1, t2),
528            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
529                self.types_match(ok1, ok2) && self.types_match(err1, err2)
530            }
531
532            _ => false,
533        }
534    }
535}