lust/typechecker/
type_env.rs

1use crate::{
2    ast::*,
3    builtins::{self, BuiltinFunction},
4    config::LustConfig,
5    error::{LustError, Result},
6};
7use alloc::{
8    boxed::Box,
9    format,
10    string::{String, ToString},
11    vec,
12    vec::Vec,
13};
14use hashbrown::{HashMap, HashSet};
/// Symbol tables and scope state for the type checker.
///
/// Holds the stack of variable scopes (with a parallel stack of type
/// refinements) together with the registered functions, structs, enums,
/// traits, type aliases, impl blocks and builtin type names.
pub struct TypeEnv {
    /// Stack of variable scopes (name -> declared type); the last entry is the innermost.
    scopes: Vec<HashMap<String, Type>>,
    /// Stack of refined variable types, pushed and popped together with `scopes`.
    refinements: Vec<HashMap<String, Type>>,
    /// Per-variable generic bindings: variable name -> (type parameter -> concrete type).
    generic_instances: HashMap<String, HashMap<String, Type>>,
    /// Function signatures keyed by function name (builtins included).
    functions: HashMap<String, FunctionSignature>,
    /// Struct definitions keyed by struct name.
    structs: HashMap<String, StructDef>,
    /// Enum definitions keyed by enum name.
    enums: HashMap<String, EnumDef>,
    /// Trait definitions keyed by trait name.
    traits: HashMap<String, TraitDef>,
    /// Type aliases keyed by alias name: (type parameter names, target type).
    type_aliases: HashMap<String, (Vec<String>, Type)>,
    /// Every registered impl block, in registration order.
    impls: Vec<ImplBlock>,
    /// Names of types registered as builtin.
    builtin_types: HashSet<String>,
}
27
/// Type-level description of a callable stored in the function table.
#[derive(Debug, Clone)]
pub struct FunctionSignature {
    /// Types of the parameters (presumably in declaration order — confirm against producers).
    pub params: Vec<Type>,
    /// Type produced by a call.
    pub return_type: Type,
    /// NOTE(review): presumably `true` when the callable is a method; not set in this file.
    pub is_method: bool,
}
34
35impl TypeEnv {
36    pub fn new() -> Self {
37        Self::with_config(&LustConfig::default())
38    }
39
40    fn register_builtin_function_slice(&mut self, functions: &[BuiltinFunction], span: Span) {
41        for builtin in functions {
42            self.functions
43                .insert(builtin.name.to_string(), builtin.to_signature(span));
44        }
45    }
46
47    pub fn with_config(config: &LustConfig) -> Self {
48        let mut env = Self {
49            scopes: vec![HashMap::new()],
50            refinements: vec![HashMap::new()],
51            generic_instances: HashMap::new(),
52            functions: HashMap::new(),
53            structs: HashMap::new(),
54            enums: HashMap::new(),
55            traits: HashMap::new(),
56            type_aliases: HashMap::new(),
57            impls: Vec::new(),
58            builtin_types: HashSet::new(),
59        };
60        env.register_builtins(config);
61        env
62    }
63
64    fn register_builtins(&mut self, config: &LustConfig) {
65        let dummy_span = Span::new(0, 0, 0, 0);
66        self.register_builtin_type("Task");
67        self.register_builtin_type("TaskStatus");
68        self.register_builtin_type("TaskInfo");
69        let task_status_type = Type::new(TypeKind::Named("TaskStatus".to_string()), dummy_span);
70        let unknown_type = Type::new(TypeKind::Unknown, dummy_span);
71        let option_unknown_type =
72            Type::new(TypeKind::Option(Box::new(unknown_type.clone())), dummy_span);
73        let string_type = Type::new(TypeKind::String, dummy_span);
74        let option_string_type =
75            Type::new(TypeKind::Option(Box::new(string_type.clone())), dummy_span);
76        let task_info_struct = StructDef {
77            name: "TaskInfo".to_string(),
78            type_params: vec![],
79            trait_bounds: vec![],
80            fields: vec![
81                StructField {
82                    name: "state".to_string(),
83                    ty: task_status_type.clone(),
84                    visibility: Visibility::Public,
85                    ownership: FieldOwnership::Strong,
86                    weak_target: None,
87                },
88                StructField {
89                    name: "last_yield".to_string(),
90                    ty: option_unknown_type.clone(),
91                    visibility: Visibility::Public,
92                    ownership: FieldOwnership::Strong,
93                    weak_target: None,
94                },
95                StructField {
96                    name: "last_result".to_string(),
97                    ty: option_unknown_type.clone(),
98                    visibility: Visibility::Public,
99                    ownership: FieldOwnership::Strong,
100                    weak_target: None,
101                },
102                StructField {
103                    name: "error".to_string(),
104                    ty: option_string_type.clone(),
105                    visibility: Visibility::Public,
106                    ownership: FieldOwnership::Strong,
107                    weak_target: None,
108                },
109            ],
110            visibility: Visibility::Public,
111        };
112        self.structs
113            .insert("TaskInfo".to_string(), task_info_struct);
114        self.register_builtin_function_slice(builtins::base_functions(), dummy_span);
115        self.register_builtin_function_slice(builtins::task_functions(), dummy_span);
116        if let Some(global_scope) = self.scopes.first_mut() {
117            global_scope.insert("task".to_string(), Type::new(TypeKind::Unknown, dummy_span));
118        }
119
120        let task_status_enum = EnumDef {
121            name: "TaskStatus".to_string(),
122            type_params: vec![],
123            trait_bounds: vec![],
124            variants: vec![
125                EnumVariant {
126                    name: "Ready".to_string(),
127                    fields: None,
128                },
129                EnumVariant {
130                    name: "Running".to_string(),
131                    fields: None,
132                },
133                EnumVariant {
134                    name: "Yielded".to_string(),
135                    fields: None,
136                },
137                EnumVariant {
138                    name: "Completed".to_string(),
139                    fields: None,
140                },
141                EnumVariant {
142                    name: "Failed".to_string(),
143                    fields: None,
144                },
145                EnumVariant {
146                    name: "Stopped".to_string(),
147                    fields: None,
148                },
149            ],
150            visibility: Visibility::Public,
151        };
152        self.enums
153            .insert("TaskStatus".to_string(), task_status_enum);
154        if config.is_module_enabled("io") {
155            if let Some(global_scope) = self.scopes.first_mut() {
156                global_scope.insert("io".to_string(), Type::new(TypeKind::Unknown, dummy_span));
157            }
158
159            self.register_builtin_function_slice(builtins::io_functions(), dummy_span);
160        }
161
162        if config.is_module_enabled("os") {
163            if let Some(global_scope) = self.scopes.first_mut() {
164                global_scope.insert("os".to_string(), Type::new(TypeKind::Unknown, dummy_span));
165            }
166
167            self.register_builtin_function_slice(builtins::os_functions(), dummy_span);
168        }
169
170        let option_enum = EnumDef {
171            name: "Option".to_string(),
172            type_params: vec!["T".to_string()],
173            trait_bounds: vec![],
174            variants: vec![
175                EnumVariant {
176                    name: "Some".to_string(),
177                    fields: Some(vec![Type::new(
178                        TypeKind::Generic("T".to_string()),
179                        dummy_span,
180                    )]),
181                },
182                EnumVariant {
183                    name: "None".to_string(),
184                    fields: None,
185                },
186            ],
187            visibility: Visibility::Public,
188        };
189        self.enums.insert("Option".to_string(), option_enum);
190        let result_enum = EnumDef {
191            name: "Result".to_string(),
192            type_params: vec!["T".to_string(), "E".to_string()],
193            trait_bounds: vec![],
194            variants: vec![
195                EnumVariant {
196                    name: "Ok".to_string(),
197                    fields: Some(vec![Type::new(
198                        TypeKind::Generic("T".to_string()),
199                        dummy_span,
200                    )]),
201                },
202                EnumVariant {
203                    name: "Err".to_string(),
204                    fields: Some(vec![Type::new(
205                        TypeKind::Generic("E".to_string()),
206                        dummy_span,
207                    )]),
208                },
209            ],
210            visibility: Visibility::Public,
211        };
212        self.enums.insert("Result".to_string(), result_enum);
213        let to_string_trait = TraitDef {
214            name: "ToString".to_string(),
215            type_params: vec![],
216            methods: vec![TraitMethod {
217                name: "to_string".to_string(),
218                type_params: vec![],
219                params: vec![FunctionParam {
220                    name: "self".to_string(),
221                    ty: Type::new(TypeKind::Unknown, dummy_span),
222                    is_self: true,
223                }],
224                return_type: Some(Type::new(TypeKind::String, dummy_span)),
225                default_impl: None,
226            }],
227            visibility: Visibility::Public,
228        };
229        self.traits.insert("ToString".to_string(), to_string_trait);
230        let hashable_trait = TraitDef {
231            name: "Hashable".to_string(),
232            type_params: vec![],
233            methods: vec![TraitMethod {
234                name: "hash".to_string(),
235                type_params: vec![],
236                params: vec![FunctionParam {
237                    name: "self".to_string(),
238                    ty: Type::new(TypeKind::Unknown, dummy_span),
239                    is_self: true,
240                }],
241                return_type: Some(Type::new(TypeKind::Int, dummy_span)),
242                default_impl: None,
243            }],
244            visibility: Visibility::Public,
245        };
246        self.traits.insert("Hashable".to_string(), hashable_trait);
247        let int_hashable_impl = ImplBlock {
248            type_params: vec![],
249            trait_name: Some("Hashable".to_string()),
250            target_type: Type::new(TypeKind::Int, dummy_span),
251            methods: vec![],
252            where_clause: vec![],
253        };
254        self.impls.push(int_hashable_impl);
255        let float_hashable_impl = ImplBlock {
256            type_params: vec![],
257            trait_name: Some("Hashable".to_string()),
258            target_type: Type::new(TypeKind::Float, dummy_span),
259            methods: vec![],
260            where_clause: vec![],
261        };
262        self.impls.push(float_hashable_impl);
263        let bool_hashable_impl = ImplBlock {
264            type_params: vec![],
265            trait_name: Some("Hashable".to_string()),
266            target_type: Type::new(TypeKind::Bool, dummy_span),
267            methods: vec![],
268            where_clause: vec![],
269        };
270        self.impls.push(bool_hashable_impl);
271        let string_hashable_impl = ImplBlock {
272            type_params: vec![],
273            trait_name: Some("Hashable".to_string()),
274            target_type: Type::new(TypeKind::String, dummy_span),
275            methods: vec![],
276            where_clause: vec![],
277        };
278        self.impls.push(string_hashable_impl);
279        let int_to_string_impl = ImplBlock {
280            type_params: vec![],
281            trait_name: Some("ToString".to_string()),
282            target_type: Type::new(TypeKind::Int, dummy_span),
283            methods: vec![],
284            where_clause: vec![],
285        };
286        self.impls.push(int_to_string_impl);
287        let float_to_string_impl = ImplBlock {
288            type_params: vec![],
289            trait_name: Some("ToString".to_string()),
290            target_type: Type::new(TypeKind::Float, dummy_span),
291            methods: vec![],
292            where_clause: vec![],
293        };
294        self.impls.push(float_to_string_impl);
295        let bool_to_string_impl = ImplBlock {
296            type_params: vec![],
297            trait_name: Some("ToString".to_string()),
298            target_type: Type::new(TypeKind::Bool, dummy_span),
299            methods: vec![],
300            where_clause: vec![],
301        };
302        self.impls.push(bool_to_string_impl);
303        let string_to_string_impl = ImplBlock {
304            type_params: vec![],
305            trait_name: Some("ToString".to_string()),
306            target_type: Type::new(TypeKind::String, dummy_span),
307            methods: vec![],
308            where_clause: vec![],
309        };
310        self.impls.push(string_to_string_impl);
311    }
312
313    fn register_builtin_type(&mut self, name: &str) {
314        self.builtin_types.insert(name.to_string());
315    }
316
317    pub fn is_builtin_type(&self, name: &str) -> bool {
318        self.builtin_types.contains(name)
319    }
320
321    pub fn push_scope(&mut self) {
322        self.scopes.push(HashMap::new());
323        self.refinements.push(HashMap::new());
324    }
325
326    pub fn pop_scope(&mut self) {
327        self.scopes.pop();
328        self.refinements.pop();
329    }
330
331    pub fn declare_variable(&mut self, name: String, ty: Type) -> Result<()> {
332        let scope = self
333            .scopes
334            .last_mut()
335            .expect("Type environment has no scope");
336        if scope.contains_key(&name) {
337            return Err(LustError::TypeError {
338                message: format!("Variable '{}' is already declared in this scope", name),
339            });
340        }
341
342        scope.insert(name, ty);
343        Ok(())
344    }
345
346    pub fn lookup_variable(&self, name: &str) -> Option<Type> {
347        for refinement_scope in self.refinements.iter().rev() {
348            if let Some(ty) = refinement_scope.get(name) {
349                return Some(ty.clone());
350            }
351        }
352
353        for scope in self.scopes.iter().rev() {
354            if let Some(ty) = scope.get(name) {
355                return Some(ty.clone());
356            }
357        }
358
359        None
360    }
361
362    pub fn refine_variable_type(&mut self, name: String, refined_type: Type) {
363        if let Some(refinement_scope) = self.refinements.last_mut() {
364            refinement_scope.insert(name, refined_type);
365        }
366    }
367
368    pub fn record_generic_instance(
369        &mut self,
370        var_name: String,
371        type_param: String,
372        concrete_type: Type,
373    ) {
374        self.generic_instances
375            .entry(var_name)
376            .or_insert_with(HashMap::new)
377            .insert(type_param, concrete_type);
378    }
379
380    pub fn lookup_generic_param(&self, var_name: &str, type_param: &str) -> Option<Type> {
381        self.generic_instances
382            .get(var_name)?
383            .get(type_param)
384            .cloned()
385    }
386
387    pub fn register_function(&mut self, name: String, sig: FunctionSignature) -> Result<()> {
388        if self.functions.contains_key(&name) {
389            return Err(LustError::TypeError {
390                message: format!("Function '{}' is already defined", name),
391            });
392        }
393
394        self.functions.insert(name, sig);
395        Ok(())
396    }
397
398    pub fn lookup_function(&self, name: &str) -> Option<&FunctionSignature> {
399        self.functions.get(name)
400    }
401
402    pub fn function_signatures(&self) -> HashMap<String, FunctionSignature> {
403        self.functions.clone()
404    }
405
406    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
407        self.structs.clone()
408    }
409
410    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
411        self.enums.clone()
412    }
413
414    pub fn register_struct(&mut self, s: &StructDef) -> Result<()> {
415        if self.structs.contains_key(&s.name) {
416            return Err(LustError::TypeError {
417                message: format!("Struct '{}' is already defined", s.name),
418            });
419        }
420
421        self.structs.insert(s.name.clone(), s.clone());
422        Ok(())
423    }
424
425    pub fn lookup_struct(&self, name: &str) -> Option<&StructDef> {
426        self.structs.get(name)
427    }
428
429    pub fn register_enum(&mut self, e: &EnumDef) -> Result<()> {
430        if self.enums.contains_key(&e.name) {
431            return Err(LustError::TypeError {
432                message: format!("Enum '{}' is already defined", e.name),
433            });
434        }
435
436        self.enums.insert(e.name.clone(), e.clone());
437        Ok(())
438    }
439
440    pub fn lookup_enum(&self, name: &str) -> Option<&EnumDef> {
441        self.enums.get(name)
442    }
443
444    pub fn register_trait(&mut self, t: &TraitDef) -> Result<()> {
445        if self.traits.contains_key(&t.name) {
446            return Err(LustError::TypeError {
447                message: format!("Trait '{}' is already defined", t.name),
448            });
449        }
450
451        self.traits.insert(t.name.clone(), t.clone());
452        Ok(())
453    }
454
455    pub fn lookup_trait(&self, name: &str) -> Option<&TraitDef> {
456        self.traits.get(name)
457    }
458
459    pub fn register_type_alias(
460        &mut self,
461        name: String,
462        type_params: Vec<String>,
463        target: Type,
464    ) -> Result<()> {
465        if self.type_aliases.contains_key(&name) {
466            return Err(LustError::TypeError {
467                message: format!("Type alias '{}' is already defined", name),
468            });
469        }
470
471        self.type_aliases.insert(name, (type_params, target));
472        Ok(())
473    }
474
475    pub fn register_impl(&mut self, impl_block: &ImplBlock) {
476        self.impls.push(impl_block.clone());
477    }
478
479    pub fn lookup_method(&self, type_name: &str, method_name: &str) -> Option<&FunctionDef> {
480        for impl_block in &self.impls {
481            if let TypeKind::Named(name) = &impl_block.target_type.kind {
482                if name == type_name {
483                    for method in &impl_block.methods {
484                        if method.name.ends_with(&format!(":{}", method_name))
485                            || method.name == method_name
486                        {
487                            return Some(method);
488                        }
489                    }
490                }
491            }
492        }
493
494        None
495    }
496
497    pub fn lookup_struct_field(&self, struct_name: &str, field_name: &str) -> Option<Type> {
498        let struct_def = self.lookup_struct(struct_name)?;
499        for field in &struct_def.fields {
500            if field.name == field_name {
501                return Some(field.ty.clone());
502            }
503        }
504
505        None
506    }
507
508    pub fn type_implements_trait(&self, ty: &Type, trait_name: &str) -> bool {
509        for impl_block in &self.impls {
510            if let Some(impl_trait_name) = &impl_block.trait_name {
511                if impl_trait_name == trait_name {
512                    if self.types_match(&impl_block.target_type, ty) {
513                        return true;
514                    }
515                }
516            }
517        }
518
519        false
520    }
521
522    fn types_match(&self, type1: &Type, type2: &Type) -> bool {
523        match (&type1.kind, &type2.kind) {
524            (TypeKind::Int, TypeKind::Int) => true,
525            (TypeKind::Float, TypeKind::Float) => true,
526            (TypeKind::Bool, TypeKind::Bool) => true,
527            (TypeKind::String, TypeKind::String) => true,
528            (TypeKind::Named(n1), TypeKind::Named(n2)) => n1 == n2,
529            (TypeKind::Array(t1), TypeKind::Array(t2)) => self.types_match(t1, t2),
530            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
531                self.types_match(k1, k2) && self.types_match(v1, v2)
532            }
533
534            (TypeKind::Option(t1), TypeKind::Option(t2)) => self.types_match(t1, t2),
535            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
536                self.types_match(ok1, ok2) && self.types_match(err1, err2)
537            }
538
539            _ => false,
540        }
541    }
542}