// lust/typechecker/type_env.rs
1use crate::{
2    ast::*,
3    builtins::{self, BuiltinFunction},
4    config::LustConfig,
5    error::{LustError, Result},
6};
7use alloc::{
8    boxed::Box,
9    format,
10    string::{String, ToString},
11    vec,
12    vec::Vec,
13};
14use core::fmt;
15use hashbrown::{HashMap, HashSet};
16pub struct TypeEnv {
17    scopes: Vec<HashMap<String, Type>>,
18    refinements: Vec<HashMap<String, Type>>,
19    generic_instances: HashMap<String, HashMap<String, Type>>,
20    functions: HashMap<String, FunctionSignature>,
21    structs: HashMap<String, StructDef>,
22    enums: HashMap<String, EnumDef>,
23    traits: HashMap<String, TraitDef>,
24    type_aliases: HashMap<String, (Vec<String>, Type)>,
25    impls: Vec<ImplBlock>,
26    builtin_types: HashSet<String>,
27    constants: HashMap<String, Type>,
28}
29
30#[derive(Debug, Clone)]
31pub struct FunctionSignature {
32    pub params: Vec<Type>,
33    pub return_type: Type,
34    pub is_method: bool,
35}
36
37impl fmt::Display for FunctionSignature {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        let params = self
40            .params
41            .iter()
42            .map(|param| param.to_string())
43            .collect::<Vec<_>>()
44            .join(", ");
45        write!(f, "function({}): {}", params, self.return_type)
46    }
47}
48
49impl TypeEnv {
50    pub fn new() -> Self {
51        Self::with_config(&LustConfig::default())
52    }
53
54    fn register_builtin_function_slice(&mut self, functions: &[BuiltinFunction], span: Span) {
55        for builtin in functions {
56            self.functions
57                .insert(builtin.name.to_string(), builtin.to_signature(span));
58        }
59    }
60
61    pub fn with_config(config: &LustConfig) -> Self {
62        let mut env = Self {
63            scopes: vec![HashMap::new()],
64            refinements: vec![HashMap::new()],
65            generic_instances: HashMap::new(),
66            functions: HashMap::new(),
67            structs: HashMap::new(),
68            enums: HashMap::new(),
69            traits: HashMap::new(),
70            type_aliases: HashMap::new(),
71            impls: Vec::new(),
72            builtin_types: HashSet::new(),
73            constants: HashMap::new(),
74        };
75        env.register_builtins(config);
76        env
77    }
78
79    fn register_builtins(&mut self, config: &LustConfig) {
80        let dummy_span = Span::new(0, 0, 0, 0);
81        self.register_builtin_type("Task");
82        self.register_builtin_type("TaskStatus");
83        self.register_builtin_type("TaskInfo");
84        self.register_builtin_type("Iterator");
85        self.register_builtin_type("LuaValue");
86        self.register_builtin_type("LuaTable");
87        self.register_builtin_type("LuaFunction");
88        self.register_builtin_type("LuaUserdata");
89        self.register_builtin_type("LuaThread");
90        let task_status_type = Type::new(TypeKind::Named("TaskStatus".to_string()), dummy_span);
91        let unknown_type = Type::new(TypeKind::Unknown, dummy_span);
92        let option_unknown_type =
93            Type::new(TypeKind::Option(Box::new(unknown_type.clone())), dummy_span);
94        let string_type = Type::new(TypeKind::String, dummy_span);
95        let option_string_type =
96            Type::new(TypeKind::Option(Box::new(string_type.clone())), dummy_span);
97        let task_info_struct = StructDef {
98            name: "TaskInfo".to_string(),
99            type_params: vec![],
100            trait_bounds: vec![],
101            fields: vec![
102                StructField {
103                    name: "state".to_string(),
104                    ty: task_status_type.clone(),
105                    visibility: Visibility::Public,
106                    ownership: FieldOwnership::Strong,
107                    weak_target: None,
108                },
109                StructField {
110                    name: "last_yield".to_string(),
111                    ty: option_unknown_type.clone(),
112                    visibility: Visibility::Public,
113                    ownership: FieldOwnership::Strong,
114                    weak_target: None,
115                },
116                StructField {
117                    name: "last_result".to_string(),
118                    ty: option_unknown_type.clone(),
119                    visibility: Visibility::Public,
120                    ownership: FieldOwnership::Strong,
121                    weak_target: None,
122                },
123                StructField {
124                    name: "error".to_string(),
125                    ty: option_string_type.clone(),
126                    visibility: Visibility::Public,
127                    ownership: FieldOwnership::Strong,
128                    weak_target: None,
129                },
130            ],
131            visibility: Visibility::Public,
132        };
133        self.structs
134            .insert("TaskInfo".to_string(), task_info_struct);
135        self.structs.insert(
136            "LuaTable".to_string(),
137            StructDef {
138                name: "LuaTable".to_string(),
139                type_params: vec![],
140                trait_bounds: vec![],
141                fields: vec![
142                    StructField {
143                        name: "table".to_string(),
144                        ty: Type::new(
145                            TypeKind::Map(
146                                Box::new(Type::new(
147                                    TypeKind::Named("LuaValue".to_string()),
148                                    dummy_span,
149                                )),
150                                Box::new(Type::new(
151                                    TypeKind::Named("LuaValue".to_string()),
152                                    dummy_span,
153                                )),
154                            ),
155                            dummy_span,
156                        ),
157                        visibility: Visibility::Public,
158                        ownership: FieldOwnership::Strong,
159                        weak_target: None,
160                    },
161                    StructField {
162                        name: "metamethods".to_string(),
163                        ty: Type::new(
164                            TypeKind::Map(
165                                Box::new(Type::new(TypeKind::String, dummy_span)),
166                                Box::new(Type::new(
167                                    TypeKind::Named("LuaValue".to_string()),
168                                    dummy_span,
169                                )),
170                            ),
171                            dummy_span,
172                        ),
173                        visibility: Visibility::Public,
174                        ownership: FieldOwnership::Strong,
175                        weak_target: None,
176                    },
177                ],
178                visibility: Visibility::Public,
179            },
180        );
181        for name in ["LuaFunction", "LuaThread"] {
182            self.structs.insert(
183                name.to_string(),
184                StructDef {
185                    name: name.to_string(),
186                    type_params: vec![],
187                    trait_bounds: vec![],
188                    fields: vec![StructField {
189                        name: "handle".to_string(),
190                        ty: Type::new(TypeKind::Int, dummy_span),
191                        visibility: Visibility::Public,
192                        ownership: FieldOwnership::Strong,
193                        weak_target: None,
194                    }],
195                    visibility: Visibility::Public,
196                },
197            );
198        }
199        self.structs.insert(
200            "LuaUserdata".to_string(),
201            StructDef {
202                name: "LuaUserdata".to_string(),
203                type_params: vec![],
204                trait_bounds: vec![],
205                fields: vec![
206                    StructField {
207                        name: "handle".to_string(),
208                        ty: Type::new(TypeKind::Int, dummy_span),
209                        visibility: Visibility::Public,
210                        ownership: FieldOwnership::Strong,
211                        weak_target: None,
212                    },
213                    StructField {
214                        name: "ptr".to_string(),
215                        ty: Type::new(TypeKind::Int, dummy_span),
216                        visibility: Visibility::Public,
217                        ownership: FieldOwnership::Strong,
218                        weak_target: None,
219                    },
220                    StructField {
221                        name: "state".to_string(),
222                        ty: Type::new(TypeKind::Int, dummy_span),
223                        visibility: Visibility::Public,
224                        ownership: FieldOwnership::Strong,
225                        weak_target: None,
226                    },
227                    StructField {
228                        name: "metamethods".to_string(),
229                        ty: Type::new(
230                            TypeKind::Map(
231                                Box::new(Type::new(TypeKind::String, dummy_span)),
232                                Box::new(Type::new(
233                                    TypeKind::Named("LuaValue".to_string()),
234                                    dummy_span,
235                                )),
236                            ),
237                            dummy_span,
238                        ),
239                        visibility: Visibility::Public,
240                        ownership: FieldOwnership::Strong,
241                        weak_target: None,
242                    },
243                ],
244                visibility: Visibility::Public,
245            },
246        );
247        self.register_builtin_function_slice(builtins::base_functions(), dummy_span);
248        self.register_builtin_function_slice(builtins::task_functions(), dummy_span);
249        self.register_builtin_function_slice(builtins::lua_functions(), dummy_span);
250        if let Some(global_scope) = self.scopes.first_mut() {
251            global_scope.insert("task".to_string(), Type::new(TypeKind::Unknown, dummy_span));
252            global_scope.insert("lua".to_string(), Type::new(TypeKind::Unknown, dummy_span));
253        }
254
255        let task_status_enum = EnumDef {
256            name: "TaskStatus".to_string(),
257            type_params: vec![],
258            trait_bounds: vec![],
259            variants: vec![
260                EnumVariant {
261                    name: "Ready".to_string(),
262                    fields: None,
263                },
264                EnumVariant {
265                    name: "Running".to_string(),
266                    fields: None,
267                },
268                EnumVariant {
269                    name: "Yielded".to_string(),
270                    fields: None,
271                },
272                EnumVariant {
273                    name: "Completed".to_string(),
274                    fields: None,
275                },
276                EnumVariant {
277                    name: "Failed".to_string(),
278                    fields: None,
279                },
280                EnumVariant {
281                    name: "Stopped".to_string(),
282                    fields: None,
283                },
284            ],
285            visibility: Visibility::Public,
286        };
287        self.enums
288            .insert("TaskStatus".to_string(), task_status_enum);
289        let lua_value_enum = EnumDef {
290            name: "LuaValue".to_string(),
291            type_params: vec![],
292            trait_bounds: vec![],
293            variants: vec![
294                EnumVariant {
295                    name: "Nil".to_string(),
296                    fields: None,
297                },
298                EnumVariant {
299                    name: "Bool".to_string(),
300                    fields: Some(vec![Type::new(TypeKind::Bool, dummy_span)]),
301                },
302                EnumVariant {
303                    name: "Int".to_string(),
304                    fields: Some(vec![Type::new(TypeKind::Int, dummy_span)]),
305                },
306                EnumVariant {
307                    name: "Float".to_string(),
308                    fields: Some(vec![Type::new(TypeKind::Float, dummy_span)]),
309                },
310                EnumVariant {
311                    name: "String".to_string(),
312                    fields: Some(vec![Type::new(TypeKind::String, dummy_span)]),
313                },
314                EnumVariant {
315                    name: "Table".to_string(),
316                    fields: Some(vec![Type::new(
317                        TypeKind::Named("LuaTable".to_string()),
318                        dummy_span,
319                    )]),
320                },
321                EnumVariant {
322                    name: "Function".to_string(),
323                    fields: Some(vec![Type::new(
324                        TypeKind::Named("LuaFunction".to_string()),
325                        dummy_span,
326                    )]),
327                },
328                EnumVariant {
329                    name: "Userdata".to_string(),
330                    fields: Some(vec![Type::new(
331                        TypeKind::Named("LuaUserdata".to_string()),
332                        dummy_span,
333                    )]),
334                },
335                EnumVariant {
336                    name: "Thread".to_string(),
337                    fields: Some(vec![Type::new(
338                        TypeKind::Named("LuaThread".to_string()),
339                        dummy_span,
340                    )]),
341                },
342                EnumVariant {
343                    name: "LightUserdata".to_string(),
344                    fields: Some(vec![Type::new(TypeKind::Unknown, dummy_span)]),
345                },
346            ],
347            visibility: Visibility::Public,
348        };
349        self.enums.insert("LuaValue".to_string(), lua_value_enum);
350        if config.is_module_enabled("io") {
351            if let Some(global_scope) = self.scopes.first_mut() {
352                global_scope.insert("io".to_string(), Type::new(TypeKind::Unknown, dummy_span));
353            }
354
355            self.register_builtin_function_slice(builtins::io_functions(), dummy_span);
356        }
357
358        if config.is_module_enabled("string") {
359            if let Some(global_scope) = self.scopes.first_mut() {
360                global_scope
361                    .insert("string".to_string(), Type::new(TypeKind::Unknown, dummy_span));
362            }
363
364            self.register_builtin_function_slice(builtins::string_functions(), dummy_span);
365        }
366
367        if config.is_module_enabled("os") {
368            if let Some(global_scope) = self.scopes.first_mut() {
369                global_scope.insert("os".to_string(), Type::new(TypeKind::Unknown, dummy_span));
370            }
371
372            self.register_builtin_function_slice(builtins::os_functions(), dummy_span);
373        }
374
375        let option_enum = EnumDef {
376            name: "Option".to_string(),
377            type_params: vec!["T".to_string()],
378            trait_bounds: vec![],
379            variants: vec![
380                EnumVariant {
381                    name: "Some".to_string(),
382                    fields: Some(vec![Type::new(
383                        TypeKind::Generic("T".to_string()),
384                        dummy_span,
385                    )]),
386                },
387                EnumVariant {
388                    name: "None".to_string(),
389                    fields: None,
390                },
391            ],
392            visibility: Visibility::Public,
393        };
394        self.enums.insert("Option".to_string(), option_enum);
395        let result_enum = EnumDef {
396            name: "Result".to_string(),
397            type_params: vec!["T".to_string(), "E".to_string()],
398            trait_bounds: vec![],
399            variants: vec![
400                EnumVariant {
401                    name: "Ok".to_string(),
402                    fields: Some(vec![Type::new(
403                        TypeKind::Generic("T".to_string()),
404                        dummy_span,
405                    )]),
406                },
407                EnumVariant {
408                    name: "Err".to_string(),
409                    fields: Some(vec![Type::new(
410                        TypeKind::Generic("E".to_string()),
411                        dummy_span,
412                    )]),
413                },
414            ],
415            visibility: Visibility::Public,
416        };
417        self.enums.insert("Result".to_string(), result_enum);
418        let to_string_trait = TraitDef {
419            name: "ToString".to_string(),
420            type_params: vec![],
421            methods: vec![TraitMethod {
422                name: "to_string".to_string(),
423                type_params: vec![],
424                params: vec![FunctionParam {
425                    name: "self".to_string(),
426                    ty: Type::new(TypeKind::Unknown, dummy_span),
427                    is_self: true,
428                }],
429                return_type: Some(Type::new(TypeKind::String, dummy_span)),
430                default_impl: None,
431            }],
432            visibility: Visibility::Public,
433        };
434        self.traits.insert("ToString".to_string(), to_string_trait);
435        let hash_key_trait = TraitDef {
436            name: "HashKey".to_string(),
437            type_params: vec![],
438            methods: vec![TraitMethod {
439                name: "to_hashkey".to_string(),
440                type_params: vec![],
441                params: vec![FunctionParam {
442                    name: "self".to_string(),
443                    ty: Type::new(TypeKind::Unknown, dummy_span),
444                    is_self: true,
445                }],
446                return_type: Some(Type::new(TypeKind::Unknown, dummy_span)),
447                default_impl: None,
448            }],
449            visibility: Visibility::Public,
450        };
451        self.traits.insert("HashKey".to_string(), hash_key_trait);
452        let int_to_string_impl = ImplBlock {
453            type_params: vec![],
454            trait_name: Some("ToString".to_string()),
455            target_type: Type::new(TypeKind::Int, dummy_span),
456            methods: vec![],
457            where_clause: vec![],
458        };
459        self.impls.push(int_to_string_impl);
460        let float_to_string_impl = ImplBlock {
461            type_params: vec![],
462            trait_name: Some("ToString".to_string()),
463            target_type: Type::new(TypeKind::Float, dummy_span),
464            methods: vec![],
465            where_clause: vec![],
466        };
467        self.impls.push(float_to_string_impl);
468        let bool_to_string_impl = ImplBlock {
469            type_params: vec![],
470            trait_name: Some("ToString".to_string()),
471            target_type: Type::new(TypeKind::Bool, dummy_span),
472            methods: vec![],
473            where_clause: vec![],
474        };
475        self.impls.push(bool_to_string_impl);
476        let string_to_string_impl = ImplBlock {
477            type_params: vec![],
478            trait_name: Some("ToString".to_string()),
479            target_type: Type::new(TypeKind::String, dummy_span),
480            methods: vec![],
481            where_clause: vec![],
482        };
483        self.impls.push(string_to_string_impl);
484        let luavalue_to_string_impl = ImplBlock {
485            type_params: vec![],
486            trait_name: Some("ToString".to_string()),
487            target_type: Type::new(TypeKind::Named("LuaValue".to_string()), dummy_span),
488            methods: vec![],
489            where_clause: vec![],
490        };
491        self.impls.push(luavalue_to_string_impl);
492    }
493
494    fn register_builtin_type(&mut self, name: &str) {
495        self.builtin_types.insert(name.to_string());
496    }
497
498    pub fn is_builtin_type(&self, name: &str) -> bool {
499        self.builtin_types.contains(name)
500    }
501
502    pub fn push_scope(&mut self) {
503        self.scopes.push(HashMap::new());
504        self.refinements.push(HashMap::new());
505    }
506
507    pub fn pop_scope(&mut self) {
508        self.scopes.pop();
509        self.refinements.pop();
510    }
511
512    pub fn declare_variable(&mut self, name: String, ty: Type) -> Result<()> {
513        let scope = self
514            .scopes
515            .last_mut()
516            .expect("Type environment has no scope");
517        scope.insert(name, ty);
518        Ok(())
519    }
520
521    pub fn lookup_variable(&self, name: &str) -> Option<Type> {
522        for refinement_scope in self.refinements.iter().rev() {
523            if let Some(ty) = refinement_scope.get(name) {
524                return Some(ty.clone());
525            }
526        }
527
528        for scope in self.scopes.iter().rev() {
529            if let Some(ty) = scope.get(name) {
530                return Some(ty.clone());
531            }
532        }
533
534        None
535    }
536
537    pub fn refine_variable_type(&mut self, name: String, refined_type: Type) {
538        if let Some(refinement_scope) = self.refinements.last_mut() {
539            refinement_scope.insert(name, refined_type);
540        }
541    }
542
543    pub fn register_constant(&mut self, name: String, ty: Type) -> Result<()> {
544        if let Some(existing) = self.constants.get(&name) {
545            if existing != &ty {
546                return Err(LustError::TypeError {
547                    message: format!(
548                        "Constant '{}' is already defined with a different type",
549                        name
550                    ),
551                });
552            }
553
554            return Ok(());
555        }
556
557        self.constants.insert(name, ty);
558        Ok(())
559    }
560
561    pub fn lookup_constant(&self, name: &str) -> Option<Type> {
562        self.constants.get(name).cloned()
563    }
564
565    pub fn record_generic_instance(
566        &mut self,
567        var_name: String,
568        type_param: String,
569        concrete_type: Type,
570    ) {
571        self.generic_instances
572            .entry(var_name)
573            .or_insert_with(HashMap::new)
574            .insert(type_param, concrete_type);
575    }
576
577    pub fn lookup_generic_param(&self, var_name: &str, type_param: &str) -> Option<Type> {
578        self.generic_instances
579            .get(var_name)?
580            .get(type_param)
581            .cloned()
582    }
583
584    pub fn register_function(&mut self, name: String, sig: FunctionSignature) -> Result<()> {
585        if self.functions.contains_key(&name) {
586            return Err(LustError::TypeError {
587                message: format!("Function '{}' is already defined", name),
588            });
589        }
590
591        self.functions.insert(name, sig);
592        Ok(())
593    }
594
595    pub fn register_or_update_function(
596        &mut self,
597        name: String,
598        sig: FunctionSignature,
599    ) -> Result<()> {
600        if let Some(existing) = self.functions.get_mut(&name) {
601            if existing.params != sig.params || existing.return_type != sig.return_type {
602                return Err(LustError::TypeError {
603                    message: format!(
604                        "Function '{}' is already defined with a different signature",
605                        name
606                    ),
607                });
608            }
609
610            if sig.is_method && !existing.is_method {
611                existing.is_method = true;
612            }
613            return Ok(());
614        }
615
616        self.functions.insert(name, sig);
617        Ok(())
618    }
619
620    pub fn lookup_function(&self, name: &str) -> Option<&FunctionSignature> {
621        self.functions.get(name)
622    }
623
624    pub fn function_signatures(&self) -> HashMap<String, FunctionSignature> {
625        self.functions.clone()
626    }
627
628    pub fn struct_definitions(&self) -> HashMap<String, StructDef> {
629        self.structs.clone()
630    }
631
632    pub fn enum_definitions(&self) -> HashMap<String, EnumDef> {
633        self.enums.clone()
634    }
635
636    pub fn register_struct(&mut self, s: &StructDef) -> Result<()> {
637        if self.structs.contains_key(&s.name) {
638            return Err(LustError::TypeError {
639                message: format!("Struct '{}' is already defined", s.name),
640            });
641        }
642
643        self.structs.insert(s.name.clone(), s.clone());
644        Ok(())
645    }
646
647    pub fn lookup_struct(&self, name: &str) -> Option<&StructDef> {
648        self.structs.get(name)
649    }
650
651    pub fn register_enum(&mut self, e: &EnumDef) -> Result<()> {
652        if self.enums.contains_key(&e.name) {
653            return Err(LustError::TypeError {
654                message: format!("Enum '{}' is already defined", e.name),
655            });
656        }
657
658        self.enums.insert(e.name.clone(), e.clone());
659        Ok(())
660    }
661
662    pub fn lookup_enum(&self, name: &str) -> Option<&EnumDef> {
663        self.enums.get(name)
664    }
665
666    pub fn register_trait(&mut self, t: &TraitDef) -> Result<()> {
667        if self.traits.contains_key(&t.name) {
668            return Err(LustError::TypeError {
669                message: format!("Trait '{}' is already defined", t.name),
670            });
671        }
672
673        self.traits.insert(t.name.clone(), t.clone());
674        Ok(())
675    }
676
677    pub fn lookup_trait(&self, name: &str) -> Option<&TraitDef> {
678        self.traits.get(name)
679    }
680
681    pub fn register_type_alias(
682        &mut self,
683        name: String,
684        type_params: Vec<String>,
685        target: Type,
686    ) -> Result<()> {
687        if self.type_aliases.contains_key(&name) {
688            return Err(LustError::TypeError {
689                message: format!("Type alias '{}' is already defined", name),
690            });
691        }
692
693        self.type_aliases.insert(name, (type_params, target));
694        Ok(())
695    }
696
697    pub fn register_impl(&mut self, impl_block: &ImplBlock) {
698        self.impls.push(impl_block.clone());
699    }
700
701    pub fn lookup_method(&self, type_name: &str, method_name: &str) -> Option<&FunctionDef> {
702        for impl_block in &self.impls {
703            if let TypeKind::Named(name) = &impl_block.target_type.kind {
704                if name == type_name {
705                    for method in &impl_block.methods {
706                        if method.name.ends_with(&format!(":{}", method_name))
707                            || method.name == method_name
708                        {
709                            return Some(method);
710                        }
711                    }
712                }
713            }
714        }
715
716        None
717    }
718
719    pub fn lookup_struct_field(&self, struct_name: &str, field_name: &str) -> Option<Type> {
720        let struct_def = self.lookup_struct(struct_name)?;
721        for field in &struct_def.fields {
722            if field.name == field_name {
723                return Some(field.ty.clone());
724            }
725        }
726
727        None
728    }
729
730    pub fn type_implements_trait(&self, ty: &Type, trait_name: &str) -> bool {
731        for impl_block in &self.impls {
732            if let Some(impl_trait_name) = &impl_block.trait_name {
733                if impl_trait_name == trait_name {
734                    if self.types_match(&impl_block.target_type, ty) {
735                        return true;
736                    }
737                }
738            }
739        }
740
741        false
742    }
743
744    fn types_match(&self, type1: &Type, type2: &Type) -> bool {
745        match (&type1.kind, &type2.kind) {
746            (TypeKind::Int, TypeKind::Int) => true,
747            (TypeKind::Float, TypeKind::Float) => true,
748            (TypeKind::Bool, TypeKind::Bool) => true,
749            (TypeKind::String, TypeKind::String) => true,
750            (TypeKind::Named(n1), TypeKind::Named(n2)) => n1 == n2,
751            (TypeKind::Array(t1), TypeKind::Array(t2)) => self.types_match(t1, t2),
752            (TypeKind::Map(k1, v1), TypeKind::Map(k2, v2)) => {
753                self.types_match(k1, k2) && self.types_match(v1, v2)
754            }
755
756            (TypeKind::Option(t1), TypeKind::Option(t2)) => self.types_match(t1, t2),
757            (TypeKind::Result(ok1, err1), TypeKind::Result(ok2, err2)) => {
758                self.types_match(ok1, ok2) && self.types_match(err1, err2)
759            }
760
761            _ => false,
762        }
763    }
764}