// lust/typechecker/type_env.rs

1use crate::{
2    ast::*,
3    builtins::{self, BuiltinFunction},
4    config::LustConfig,
5    error::{LustError, Result},
6};
7use alloc::{
8    boxed::Box,
9    format,
10    string::{String, ToString},
11    vec,
12    vec::Vec,
13};
14use hashbrown::{HashMap, HashSet};
15pub struct TypeEnv {
16    scopes: Vec<HashMap<String, Type>>,
17    refinements: Vec<HashMap<String, Type>>,
18    generic_instances: HashMap<String, HashMap<String, Type>>,
19    functions: HashMap<String, FunctionSignature>,
20    structs: HashMap<String, StructDef>,
21    enums: HashMap<String, EnumDef>,
22    traits: HashMap<String, TraitDef>,
23    type_aliases: HashMap<String, (Vec<String>, Type)>,
24    impls: Vec<ImplBlock>,
25    builtin_types: HashSet<String>,
26}
27
28#[derive(Debug, Clone)]
29pub struct FunctionSignature {
30    pub params: Vec<Type>,
31    pub return_type: Type,
32    pub is_method: bool,
33}
34
35impl TypeEnv {
36    pub fn new() -> Self {
37        Self::with_config(&LustConfig::default())
38    }
39
40    fn register_builtin_function_slice(&mut self, functions: &[BuiltinFunction], span: Span) {
41        for builtin in functions {
42            self.functions
43                .insert(builtin.name.to_string(), builtin.to_signature(span));
44        }
45    }
46
47    pub fn with_config(config: &LustConfig) -> Self {
48        let mut env = Self {
49            scopes: vec![HashMap::new()],
50            refinements: vec![HashMap::new()],
51            generic_instances: HashMap::new(),
52            functions: HashMap::new(),
53            structs: HashMap::new(),
54            enums: HashMap::new(),
55            traits: HashMap::new(),
56            type_aliases: HashMap::new(),
57            impls: Vec::new(),
58            builtin_types: HashSet::new(),
59        };
60        env.register_builtins(config);
61        env
62    }
63
64    fn register_builtins(&mut self, config: &LustConfig) {
65        let dummy_span = Span::new(0, 0, 0, 0);
66        self.register_builtin_type("Task");
67        self.register_builtin_type("TaskStatus");
68        self.register_builtin_type("TaskInfo");
69        let task_status_type = Type::new(TypeKind::Named("TaskStatus".to_string()), dummy_span);
70        let unknown_type = Type::new(TypeKind::Unknown, dummy_span);
71        let option_unknown_type =
72            Type::new(TypeKind::Option(Box::new(unknown_type.clone())), dummy_span);
73        let string_type = Type::new(TypeKind::String, dummy_span);
74        let option_string_type =
75            Type::new(TypeKind::Option(Box::new(string_type.clone())), dummy_span);
76        let task_info_struct = StructDef {
77            name: "TaskInfo".to_string(),
78            type_params: vec![],
79            trait_bounds: vec![],
80            fields: vec![
81                StructField {
82                    name: "state".to_string(),
83                    ty: task_status_type.clone(),
84                    visibility: Visibility::Public,
85                    ownership: FieldOwnership::Strong,
86                    weak_target: None,
87                },
88                StructField {
89                    name: "last_yield".to_string(),
90                    ty: option_unknown_type.clone(),
91                    visibility: Visibility::Public,
92                    ownership: FieldOwnership::Strong,
93                    weak_target: None,
94                },
95                StructField {
96                    name: "last_result".to_string(),
97                    ty: option_unknown_type.clone(),
98                    visibility: Visibility::Public,
99                    ownership: FieldOwnership::Strong,
100                    weak_target: None,
101                },
102                StructField {
103                    name: "error".to_string(),
104                    ty: option_string_type.clone(),
105                    visibility: Visibility::Public,
106                    ownership: FieldOwnership::Strong,
107                    weak_target: None,
108                },
109            ],
110            visibility: Visibility::Public,
111        };
112        self.structs
113            .insert("TaskInfo".to_string(), task_info_struct);
114        self.register_builtin_function_slice(builtins::base_functions(), dummy_span);
115        self.register_builtin_function_slice(builtins::task_functions(), dummy_span);
116        if let Some(global_scope) = self.scopes.first_mut() {
117            global_scope.insert("task".to_string(), Type::new(TypeKind::Unknown, dummy_span));
118        }
119
120        let task_status_enum = EnumDef {
121            name: "TaskStatus".to_string(),
122            type_params: vec![],
123            trait_bounds: vec![],
124            variants: vec![
125                EnumVariant {
126                    name: "Ready".to_string(),
127                    fields: None,
128                },
129                EnumVariant {
130                    name: "Running".to_string(),
131                    fields: None,
132                },
133                EnumVariant {
134                    name: "Yielded".to_string(),
135                    fields: None,
136                },
137                EnumVariant {
138                    name: "Completed".to_string(),
139                    fields: None,
140                },
141                EnumVariant {
142                    name: "Failed".to_string(),
143                    fields: None,
144                },
145                EnumVariant {
146                    name: "Stopped".to_string(),
147                    fields: None,
148                },
149            ],
150            visibility: Visibility::Public,
151        };
152        self.enums
153            .insert("TaskStatus".to_string(), task_status_enum);
154        if config.is_module_enabled("io") {
155            if let Some(global_scope) = self.scopes.first_mut() {
156                global_scope.insert("io".to_string(), Type::new(TypeKind::Unknown, dummy_span));
157            }
158
159            self.register_builtin_function_slice(builtins::io_functions(), dummy_span);
160        }
161
162        if config.is_module_enabled("os") {
163            if let Some(global_scope) = self.scopes.first_mut() {
164                global_scope.insert("os".to_string(), Type::new(TypeKind::Unknown, dummy_span));
165            }
166
167            self.register_builtin_function_slice(builtins::os_functions(), dummy_span);
168        }
169
170        let option_enum = EnumDef {
171            name: "Option".to_string(),
172            type_params: vec!["T".to_string()],
173            trait_bounds: vec![],
174            variants: vec![
175                EnumVariant {
176                    name: "Some".to_string(),
177                    fields: Some(vec![Type::new(
178                        TypeKind::Generic("T".to_string()),
179                        dummy_span,
180                    )]),
181                },
182                EnumVariant {
183                    name: "None".to_string(),
184                    fields: None,
185                },
186            ],
187            visibility: Visibility::Public,
188        };
189        self.enums.insert("Option".to_string(), option_enum);
190        let result_enum = EnumDef {
191            name: "Result".to_string(),
192            type_params: vec!["T".to_string(), "E".to_string()],
193            trait_bounds: vec![],
194            variants: vec![
195                EnumVariant {
196                    name: "Ok".to_string(),
197                    fields: Some(vec![Type::new(
198                        TypeKind::Generic("T".to_string()),
199                        dummy_span,
200                    )]),
201                },
202                EnumVariant {
203                    name: "Err".to_string(),
204                    fields: Some(vec![Type::new(
205                        TypeKind::Generic("E".to_string()),
206                        dummy_span,
207                    )]),
208                },
209            ],
210            visibility: Visibility::Public,
211        };
212        self.enums.insert("Result".to_string(), result_enum);
213        let to_string_trait = TraitDef {
214            name: "ToString".to_string(),
215            type_params: vec![],
216            methods: vec![TraitMethod {
217                name: "to_string".to_string(),
218                type_params: vec![],
219                params: vec![FunctionParam {
220                    name: "self".to_string(),
221                    ty: Type::new(TypeKind::Unknown, dummy_span),
222                    is_self: true,
223                }],
224                return_type: Some(Type::new(TypeKind::String, dummy_span)),
225                default_impl: None,
226            }],
227            visibility: Visibility::Public,
228        };
229        self.traits.insert("ToString".to_string(), to_string_trait);
230        let hashable_trait = TraitDef {
231            name: "Hashable".to_string(),
232            type_params: vec![],
233            methods: vec![TraitMethod {
234                name: "hash".to_string(),
235                type_params: vec![],
236                params: vec![FunctionParam {
237                    name: "self".to_string(),
238                    ty: Type::new(TypeKind::Unknown, dummy_span),
239                    is_self: true,
240                }],
241                return_type: Some(Type::new(TypeKind::Int, dummy_span)),
242                default_impl: None,
243            }],
244            visibility: Visibility::Public,
245        };
246        self.traits.insert("Hashable".to_string(), hashable_trait);
247        let int_hashable_impl = ImplBlock {
248            type_params: vec![],
249            trait_name: Some("Hashable".to_string()),
250            target_type: Type::new(TypeKind::Int, dummy_span),
251            methods: vec![],
252            where_clause: vec![],
253        };
254        self.impls.push(int_hashable_impl);
255        let float_hashable_impl = ImplBlock {
256            type_params: vec![],
257            trait_name: Some("Hashable".to_string()),
258            target_type: Type::new(TypeKind::Float, dummy_span),
259            methods: vec![],
260            where_clause: vec![],
261        };
262        self.impls.push(float_hashable_impl);
263        let bool_hashable_impl = ImplBlock {
264            type_params: vec![],
265            trait_name: Some("Hashable".to_string()),
266            target_type: Type::new(TypeKind::Bool, dummy_span),
267            methods: vec![],
268            where_clause: vec![],
269        };
270        self.impls.push(bool_hashable_impl);
271        let string_hashable_impl = ImplBlock {
272            type_params: vec![],
273            trait_name: Some("Hashable".to_string()),
274            target_type: Type::new(TypeKind::String, dummy_span),
275            methods: vec![],
276            where_clause: vec![],
277        };
278        self.impls.push(string_hashable_impl);
279        let int_to_string_impl = ImplBlock {
280            type_params: vec![],
281            trait_name: Some("ToString".to_string()),
282            target_type: Type::new(TypeKind::Int, dummy_span),
283            methods: vec![],
284            where_clause: vec![],
285        };
286        self.impls.push(int_to_string_impl);
287        let float_to_string_impl = ImplBlock {
288            type_params: vec![],
289            trait_name: Some("ToString".to_string()),
290            target_type: Type::new(TypeKind::Float, dummy_span),
291            methods: vec![],
292            where_clause: vec![],
293        };
294        self.impls.push(float_to_string_impl);
295        let bool_to_string_impl = ImplBlock {
296            type_params: vec![],
297            trait_name: Some("ToString".to_string()),
298            target_type: Type::new(TypeKind::Bool, dummy_span),
299            methods: vec![],
300            where_clause: vec![],
301        };
302        self.impls.push(bool_to_string_impl);
303        let string_to_string_impl = ImplBlock {
304            type_params: vec![],
305            trait_name: Some("ToString".to_string()),
306            target_type: Type::new(TypeKind::String, dummy_span),
307            methods: vec![],
308            where_clause: vec![],
309        };
310        self.impls.push(string_to_string_impl);
311    }
312
313    fn register_builtin_type(&mut self, name: &str) {
314        self.builtin_types.insert(name.to_string());
315    }
316
317    pub fn is_builtin_type(&self, name: &str) -> bool {
318        self.builtin_types.contains(name)
319    }
320
321    pub fn push_scope(&mut self) {
322        self.scopes.push(HashMap::new());
323        self.refinements.push(HashMap::new());
324    }
325
326    pub fn pop_scope(&mut self) {
327        self.scopes.pop();
328        self.refinements.pop();
329    }
330
331    pub fn declare_variable(&mut self, name: String, ty: Type) -> Result<()> {
332        let scope = self
333            .scopes
334            .last_mut()
335            .expect("Type environment has no scope");
336        scope.insert(name, ty);
337        Ok(())
338    }
339
340    pub fn lookup_variable(&self, name: &str) -> Option<Type> {
341        for refinement_scope in self.refinements.iter().rev() {
342            if let Some(ty) = refinement_scope.get(name) {
343                return Some(ty.clone());
344            }
345        }
346
347        for scope in self.scopes.iter().rev() {
348            if let Some(ty) = scope.get(name) {
349                return Some(ty.clone());
350            }
351        }
352
353        None
354    }
355
356    pub fn refine_variable_type(&mut self, name: String, refined_type: Type) {
357        if let Some(refinement_scope) = self.refinements.last_mut() {
358            refinement_scope.insert(name, refined_type);
359        }
360    }
361
362    pub fn record_generic_instance(
363        &mut self,
364        var_name: String,
365        type_param: String,
366        concrete_type: Type,
367    ) {
368        self.generic_instances
369            .entry(var_name)
370            .or_insert_with(HashMap::new)
371            .insert(type_param, concrete_type);
372    }
373
374    pub fn lookup_generic_param(&self, var_name: &str, type_param: &str) -> Option<Type> {
375        self.generic_instances
376            .get(var_name)?
377            .get(type_param)
378            .cloned()
379    }
380
381    pub fn register_function(&mut self, name: String, sig: FunctionSignature) -> Result<()> {
382        if self.functions.contains_key(&name) {
383            return Err(LustError::TypeError {
384                message: format!("Function '{}' is already defined", name),
385            });
386        }
387
388        self.functions.insert(name, sig);
389        Ok(())
390    }
391
392    pub fn lookup_function(&self, name: &str) -> Option<&FunctionSignature> {
393        self.functions.get(name)
394    }
395
396    pub fn function_signatures(&self) -> HashMap<String, FunctionSignature> {
397        self.functions.clone()
398    }
399
400    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
401        self.structs.clone()
402    }
403
404    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
405        self.enums.clone()
406    }
407
408    pub fn register_struct(&mut self, s: &StructDef) -> Result<()> {
409        if self.structs.contains_key(&s.name) {
410            return Err(LustError::TypeError {
411                message: format!("Struct '{}' is already defined", s.name),
412            });
413        }
414
415        self.structs.insert(s.name.clone(), s.clone());
416        Ok(())
417    }
418
419    pub fn lookup_struct(&self, name: &str) -> Option<&StructDef> {
420        self.structs.get(name)
421    }
422
423    pub fn register_enum(&mut self, e: &EnumDef) -> Result<()> {
424        if self.enums.contains_key(&e.name) {
425            return Err(LustError::TypeError {
426                message: format!("Enum '{}' is already defined", e.name),
427            });
428        }
429
430        self.enums.insert(e.name.clone(), e.clone());
431        Ok(())
432    }
433
434    pub fn lookup_enum(&self, name: &str) -> Option<&EnumDef> {
435        self.enums.get(name)
436    }
437
438    pub fn register_trait(&mut self, t: &TraitDef) -> Result<()> {
439        if self.traits.contains_key(&t.name) {
440            return Err(LustError::TypeError {
441                message: format!("Trait '{}' is already defined", t.name),
442            });
443        }
444
445        self.traits.insert(t.name.clone(), t.clone());
446        Ok(())
447    }
448
449    pub fn lookup_trait(&self, name: &str) -> Option<&TraitDef> {
450        self.traits.get(name)
451    }
452
453    pub fn register_type_alias(
454        &mut self,
455        name: String,
456        type_params: Vec<String>,
457        target: Type,
458    ) -> Result<()> {
459        if self.type_aliases.contains_key(&name) {
460            return Err(LustError::TypeError {
461                message: format!("Type alias '{}' is already defined", name),
462            });
463        }
464
465        self.type_aliases.insert(name, (type_params, target));
466        Ok(())
467    }
468
469    pub fn register_impl(&mut self, impl_block: &ImplBlock) {
470        self.impls.push(impl_block.clone());
471    }
472
473    pub fn lookup_method(&self, type_name: &str, method_name: &str) -> Option<&FunctionDef> {
474        for impl_block in &self.impls {
475            if let TypeKind::Named(name) = &impl_block.target_type.kind {
476                if name == type_name {
477                    for method in &impl_block.methods {
478                        if method.name.ends_with(&format!(":{}", method_name))
479                            || method.name == method_name
480                        {
481                            return Some(method);
482                        }
483                    }
484                }
485            }
486        }
487
488        None
489    }
490
491    pub fn lookup_struct_field(&self, struct_name: &str, field_name: &str) -> Option<Type> {
492        let struct_def = self.lookup_struct(struct_name)?;
493        for field in &struct_def.fields {
494            if field.name == field_name {
495                return Some(field.ty.clone());
496            }
497        }
498
499        None
500    }
501
502    pub fn type_implements_trait(&self, ty: &Type, trait_name: &str) -> bool {
503        for impl_block in &self.impls {
504            if let Some(impl_trait_name) = &impl_block.trait_name {
505                if impl_trait_name == trait_name {
506                    if self.types_match(&impl_block.target_type, ty) {
507                        return true;
508                    }
509                }
510            }
511        }
512
513        false
514    }
515
516    fn types_match(&self, type1: &Type, type2: &Type) -> bool {
517        match (&type1.kind, &type2.kind) {
518            (TypeKind::Int, TypeKind::Int) => true,
519            (TypeKind::Float, TypeKind::Float) => true,
520            (TypeKind::Bool, TypeKind::Bool) => true,
521            (TypeKind::String, TypeKind::String) => true,
522            (TypeKind::Named(n1), TypeKind::Named(n2)) => n1 == n2,
523            (TypeKind::Array(t1), TypeKind::Array(t2)) => self.types_match(t1, t2),
524            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
525                self.types_match(k1, k2) && self.types_match(v1, v2)
526            }
527
528            (TypeKind::Option(t1), TypeKind::Option(t2)) => self.types_match(t1, t2),
529            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
530                self.types_match(ok1, ok2) && self.types_match(err1, err2)
531            }
532
533            _ => false,
534        }
535    }
536}