Skip to main content

lisette_semantics/
store.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::Cell;
3
4use syntax::ast::{EnumVariant, Expression, StructFieldDefinition};
5use syntax::program::{Definition, File, Interface, MethodSignatures, Module, ModuleId};
6use syntax::types::{SubstitutionMap, Type, substitute};
7
/// ID of the module that holds the entry file.
pub const ENTRY_MODULE_ID: &str = "_entry_";
/// File ID reserved for the entry file.
pub const ENTRY_FILE_ID: u32 = 0;
10
11pub struct Store {
12    pub modules: HashMap<String, Module>,
13    pub module_ids: Vec<ModuleId>,
14    /// file ID -> module ID
15    pub files: HashMap<u32, String>,
16    visited_modules: HashSet<String>,
17    /// File ID counter. Starts at 2 because 0 is reserved for entry, 1 for prelude.
18    next_file_id: Cell<u32>,
19}
20
21impl Default for Store {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl Store {
28    pub fn new() -> Self {
29        let prelude_module = Module::new("prelude");
30        let nominal_module = Module::nominal();
31
32        let modules = vec![
33            (prelude_module.id.clone(), prelude_module),
34            (nominal_module.id.clone(), nominal_module),
35        ]
36        .into_iter()
37        .collect();
38
39        let module_ids = vec!["prelude".to_string()];
40
41        Self {
42            files: Default::default(),
43            modules,
44            module_ids,
45            visited_modules: Default::default(),
46            next_file_id: Cell::new(2), // 0 = entry, 1 = prelude
47        }
48    }
49
50    pub fn new_file_id(&self) -> u32 {
51        let id = self.next_file_id.get();
52        self.next_file_id.set(id + 1);
53        id
54    }
55
56    pub fn register_file(&mut self, file_id: u32, module_id: &str) {
57        self.files.insert(file_id, module_id.to_string());
58    }
59
60    pub fn entry_module_id(&self) -> &'static str {
61        ENTRY_MODULE_ID
62    }
63
64    /// Initializes the entry module with reserved file ID 0.
65    pub fn init_entry_module(&mut self) {
66        self.add_module(ENTRY_MODULE_ID);
67        self.register_file(ENTRY_FILE_ID, ENTRY_MODULE_ID);
68    }
69
70    pub fn store_entry_file(&mut self, filename: &str, source: &str, ast: Vec<Expression>) {
71        self.store_file(
72            ENTRY_MODULE_ID,
73            File {
74                id: ENTRY_FILE_ID,
75                module_id: ENTRY_MODULE_ID.to_string(),
76                name: filename.to_string(),
77                source: source.to_string(),
78                items: ast,
79            },
80        );
81    }
82
83    pub fn store_module(&mut self, module_id: &str, files: Vec<File>) {
84        self.mark_visited(module_id);
85        self.add_module(module_id);
86
87        for file in files {
88            self.store_file(module_id, file);
89        }
90    }
91
92    /// Stores a file in the module and registers the file_id -> module_id mapping.
93    /// .d.lis files go to `typedefs`, .lis files go to `files`.
94    pub fn store_file(&mut self, module_id: &str, file: File) {
95        self.files.insert(file.id, module_id.to_string());
96
97        let module = self
98            .get_module_mut(module_id)
99            .expect("module must exist to store file");
100
101        if file.is_d_lis() {
102            module.typedefs.insert(file.id, file);
103        } else {
104            module.files.insert(file.id, file);
105        }
106    }
107
108    pub fn get_file(&self, file_id: u32) -> Option<&File> {
109        let module_id = self.files.get(&file_id)?;
110        let module = self.get_module(module_id)?;
111        module
112            .get_file(file_id)
113            .or_else(|| module.get_typedef_by_id(file_id))
114    }
115
116    pub fn get_file_mut(&mut self, file_id: u32) -> Option<&mut File> {
117        let module_id = self.files.get(&file_id)?.clone();
118        let module = self.modules.get_mut(&module_id)?;
119        module
120            .files
121            .get_mut(&file_id)
122            .or_else(|| module.typedefs.get_mut(&file_id))
123    }
124
125    pub fn get_module(&self, module_id: &str) -> Option<&Module> {
126        self.modules.get(module_id)
127    }
128
129    pub fn has(&self, module_id: &str) -> bool {
130        self.modules.contains_key(module_id)
131    }
132
133    pub fn add_module(&mut self, module_id: &str) {
134        if self.modules.contains_key(module_id) {
135            return;
136        }
137
138        self.modules
139            .insert(module_id.to_string(), Module::new(module_id));
140        self.module_ids.push(module_id.to_string());
141    }
142
143    pub fn get_module_mut(&mut self, module_id: &str) -> Option<&mut Module> {
144        self.modules.get_mut(module_id)
145    }
146
147    pub fn is_visited(&self, module_id: &str) -> bool {
148        self.visited_modules.contains(module_id)
149    }
150
151    pub fn mark_visited(&mut self, module_id: &str) {
152        self.visited_modules.insert(module_id.to_string());
153    }
154
155    pub fn get_definition(&self, qualified_name: &str) -> Option<&Definition> {
156        let (module_name, _) = qualified_name.split_once('.')?;
157
158        self.get_module(module_name)?
159            .definitions
160            .get(qualified_name)
161    }
162
163    pub fn get_enum_variants(&self, qualified_name: &str) -> Option<&[EnumVariant]> {
164        match self.get_definition(qualified_name)? {
165            Definition::Enum { variants, .. } => Some(variants),
166            _ => None,
167        }
168    }
169
170    pub fn get_struct_fields(&self, qualified_name: &str) -> Option<&[StructFieldDefinition]> {
171        match self.get_definition(qualified_name)? {
172            Definition::Struct { fields, .. } => Some(fields),
173            _ => None,
174        }
175    }
176
177    pub fn get_type(&self, qualified_name: &str) -> Option<&Type> {
178        self.get_definition(qualified_name)
179            .map(|definition| definition.ty())
180    }
181
182    pub fn get_interface(&self, qualified_name: &str) -> Option<&Interface> {
183        match self.get_definition(qualified_name)? {
184            Definition::Interface { definition, .. } => Some(definition),
185            _ => None,
186        }
187    }
188
189    pub fn get_own_methods(&self, qualified_name: &str) -> Option<&MethodSignatures> {
190        match self.get_definition(qualified_name)? {
191            Definition::Struct { methods, .. } => Some(methods),
192            Definition::TypeAlias { methods, .. } => Some(methods),
193            Definition::Enum { methods, .. } => Some(methods),
194            Definition::ValueEnum { methods, .. } => Some(methods),
195            _ => None,
196        }
197    }
198
199    pub fn get_all_methods(
200        &self,
201        ty: &Type,
202        trait_bounds: &HashMap<String, Vec<Type>>,
203    ) -> MethodSignatures {
204        let Type::Constructor { id, .. } = ty.strip_refs().resolve() else {
205            return MethodSignatures::default();
206        };
207        let qualified_name = id;
208
209        if let Some(interface) = self.get_interface(&qualified_name) {
210            let mut all_interface_methods = MethodSignatures::default();
211
212            let type_args = ty.get_type_params().unwrap_or_default();
213            let map: SubstitutionMap = interface
214                .generics
215                .iter()
216                .map(|g| g.name.clone())
217                .zip(type_args.iter().cloned())
218                .collect();
219
220            for (name, method_ty) in &interface.methods {
221                let substituted = substitute(method_ty, &map);
222                all_interface_methods.insert(name.clone(), substituted.with_receiver_placeholder());
223            }
224
225            for parent in &interface.parents {
226                for (name, method_ty) in self.get_all_methods(parent, trait_bounds) {
227                    all_interface_methods.insert(name, method_ty);
228                }
229            }
230
231            return all_interface_methods;
232        }
233
234        if let Some(bound_types) = trait_bounds.get(qualified_name.as_str()) {
235            return bound_types
236                .iter()
237                .flat_map(|interface_ty| self.get_all_methods(interface_ty, trait_bounds))
238                .collect();
239        }
240
241        let mut methods = self
242            .get_own_methods(&qualified_name)
243            .cloned()
244            .unwrap_or_default();
245
246        // Type aliases inherit methods from the underlying type.
247        if let Some(Definition::TypeAlias { ty: alias_ty, .. }) =
248            self.get_definition(&qualified_name)
249        {
250            let underlying = match &alias_ty {
251                Type::Forall { body, .. } => body.as_ref(),
252                other => other,
253            };
254            if let Type::Constructor { id: alias_id, .. } = underlying
255                && alias_id.as_str() != qualified_name.as_str()
256            {
257                let alias_ty = alias_ty.clone();
258                for (name, method_ty) in self.get_all_methods(&alias_ty, trait_bounds) {
259                    methods.entry(name).or_insert(method_ty);
260                }
261            }
262        }
263
264        methods
265    }
266
267    pub fn get_methods_from_bounds(
268        &self,
269        qualified_name: &str,
270        trait_bounds: &HashMap<String, Vec<Type>>,
271    ) -> MethodSignatures {
272        if let Some(bound_types) = trait_bounds.get(qualified_name) {
273            return bound_types
274                .iter()
275                .flat_map(|interface_ty| self.get_all_methods(interface_ty, trait_bounds))
276                .collect();
277        }
278        MethodSignatures::default()
279    }
280}