// lisette_semantics/store.rs

1use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
2use std::cell::Cell;
3
4use syntax::ast::{EnumVariant, Expression, StructFieldDefinition};
5use syntax::program::{Definition, File, Interface, MethodSignatures, Module, ModuleId};
6use syntax::types::{SubstitutionMap, Type, substitute};
7
/// Module ID under which the entry program's file is stored.
pub const ENTRY_MODULE_ID: &str = "_entry_";
/// File ID reserved for the entry file. ID 1 is reserved for the prelude, and
/// `Store::new_file_id` hands out IDs starting at 2.
pub const ENTRY_FILE_ID: u32 = 0;
10
11pub struct Store {
12    pub modules: HashMap<String, Module>,
13    pub module_ids: Vec<ModuleId>,
14    /// file ID -> module ID
15    pub files: HashMap<u32, String>,
16    /// Go module ID -> Go package name, from the typedef `// Package:` directive.
17    /// Present only when the package name differs from the final path segment.
18    pub go_package_names: HashMap<String, String>,
19    visited_modules: HashSet<String>,
20    /// File ID counter. Starts at 2 because 0 is reserved for entry, 1 for prelude.
21    next_file_id: Cell<u32>,
22}
23
24impl Default for Store {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl Store {
31    pub fn new() -> Self {
32        let prelude_module = Module::new("prelude");
33        let nominal_module = Module::nominal();
34
35        let modules = vec![
36            (prelude_module.id.clone(), prelude_module),
37            (nominal_module.id.clone(), nominal_module),
38        ]
39        .into_iter()
40        .collect();
41
42        let module_ids = vec!["prelude".to_string()];
43
44        Self {
45            files: Default::default(),
46            modules,
47            module_ids,
48            go_package_names: Default::default(),
49            visited_modules: Default::default(),
50            next_file_id: Cell::new(2), // 0 = entry, 1 = prelude
51        }
52    }
53
54    pub fn new_file_id(&self) -> u32 {
55        let id = self.next_file_id.get();
56        self.next_file_id.set(id + 1);
57        id
58    }
59
60    pub fn register_file(&mut self, file_id: u32, module_id: &str) {
61        self.files.insert(file_id, module_id.to_string());
62    }
63
64    pub fn entry_module_id(&self) -> &'static str {
65        ENTRY_MODULE_ID
66    }
67
68    /// Initializes the entry module with reserved file ID 0.
69    pub fn init_entry_module(&mut self) {
70        self.add_module(ENTRY_MODULE_ID);
71        self.register_file(ENTRY_FILE_ID, ENTRY_MODULE_ID);
72    }
73
74    pub fn store_entry_file(&mut self, filename: &str, source: &str, ast: Vec<Expression>) {
75        self.store_file(
76            ENTRY_MODULE_ID,
77            File {
78                id: ENTRY_FILE_ID,
79                module_id: ENTRY_MODULE_ID.to_string(),
80                name: filename.to_string(),
81                source: source.to_string(),
82                items: ast,
83            },
84        );
85    }
86
87    pub fn store_module(&mut self, module_id: &str, files: Vec<File>) {
88        self.mark_visited(module_id);
89        self.add_module(module_id);
90
91        for file in files {
92            self.store_file(module_id, file);
93        }
94    }
95
96    /// Stores a file in the module and registers the file_id -> module_id mapping.
97    /// .d.lis files go to `typedefs`, .lis files go to `files`.
98    pub fn store_file(&mut self, module_id: &str, file: File) {
99        self.files.insert(file.id, module_id.to_string());
100
101        let module = self
102            .get_module_mut(module_id)
103            .expect("module must exist to store file");
104
105        if file.is_d_lis() {
106            module.typedefs.insert(file.id, file);
107        } else {
108            module.files.insert(file.id, file);
109        }
110    }
111
112    pub fn get_file(&self, file_id: u32) -> Option<&File> {
113        let module_id = self.files.get(&file_id)?;
114        let module = self.get_module(module_id)?;
115        module
116            .get_file(file_id)
117            .or_else(|| module.get_typedef_by_id(file_id))
118    }
119
120    pub fn get_file_mut(&mut self, file_id: u32) -> Option<&mut File> {
121        let module_id = self.files.get(&file_id)?.clone();
122        let module = self.modules.get_mut(&module_id)?;
123        module
124            .files
125            .get_mut(&file_id)
126            .or_else(|| module.typedefs.get_mut(&file_id))
127    }
128
129    pub fn get_module(&self, module_id: &str) -> Option<&Module> {
130        self.modules.get(module_id)
131    }
132
133    pub fn has(&self, module_id: &str) -> bool {
134        self.modules.contains_key(module_id)
135    }
136
137    pub fn add_module(&mut self, module_id: &str) {
138        if self.modules.contains_key(module_id) {
139            return;
140        }
141
142        self.modules
143            .insert(module_id.to_string(), Module::new(module_id));
144        self.module_ids.push(module_id.to_string());
145    }
146
147    pub fn get_module_mut(&mut self, module_id: &str) -> Option<&mut Module> {
148        self.modules.get_mut(module_id)
149    }
150
151    pub fn is_visited(&self, module_id: &str) -> bool {
152        self.visited_modules.contains(module_id)
153    }
154
155    pub fn mark_visited(&mut self, module_id: &str) {
156        self.visited_modules.insert(module_id.to_string());
157    }
158
159    pub fn get_definition(&self, qualified_name: &str) -> Option<&Definition> {
160        let module_name = self.module_for_qualified_name(qualified_name)?;
161
162        self.get_module(module_name)?
163            .definitions
164            .get(qualified_name)
165    }
166
167    pub fn module_for_qualified_name<'a>(&'a self, qualified_name: &'a str) -> Option<&'a str> {
168        if !qualified_name.starts_with("go:") || !qualified_name.contains('/') {
169            let (module_name, _) = qualified_name.split_once('.')?;
170            return Some(module_name);
171        }
172
173        let mut best: Option<&str> = None;
174        for module_id in self.modules.keys() {
175            if qualified_name.starts_with(module_id.as_str())
176                && qualified_name.as_bytes().get(module_id.len()) == Some(&b'.')
177                && best
178                    .as_ref()
179                    .is_none_or(|prev| module_id.len() > prev.len())
180            {
181                best = Some(module_id.as_str());
182            }
183        }
184        best
185    }
186
187    pub fn get_enum_variants(&self, qualified_name: &str) -> Option<&[EnumVariant]> {
188        match self.get_definition(qualified_name)? {
189            Definition::Enum { variants, .. } => Some(variants),
190            _ => None,
191        }
192    }
193
194    pub fn get_struct_fields(&self, qualified_name: &str) -> Option<&[StructFieldDefinition]> {
195        match self.get_definition(qualified_name)? {
196            Definition::Struct { fields, .. } => Some(fields),
197            _ => None,
198        }
199    }
200
201    pub fn get_type(&self, qualified_name: &str) -> Option<&Type> {
202        self.get_definition(qualified_name)
203            .map(|definition| definition.ty())
204    }
205
206    pub fn get_interface(&self, qualified_name: &str) -> Option<&Interface> {
207        match self.get_definition(qualified_name)? {
208            Definition::Interface { definition, .. } => Some(definition),
209            _ => None,
210        }
211    }
212
213    pub fn get_own_methods(&self, qualified_name: &str) -> Option<&MethodSignatures> {
214        match self.get_definition(qualified_name)? {
215            Definition::Struct { methods, .. } => Some(methods),
216            Definition::TypeAlias { methods, .. } => Some(methods),
217            Definition::Enum { methods, .. } => Some(methods),
218            Definition::ValueEnum { methods, .. } => Some(methods),
219            _ => None,
220        }
221    }
222
223    pub fn get_all_methods(
224        &self,
225        ty: &Type,
226        trait_bounds: &HashMap<String, Vec<Type>>,
227    ) -> MethodSignatures {
228        let Type::Constructor { id, .. } = ty.strip_refs().resolve() else {
229            return MethodSignatures::default();
230        };
231        let qualified_name = id;
232
233        if let Some(interface) = self.get_interface(&qualified_name) {
234            let mut all_interface_methods = MethodSignatures::default();
235
236            let type_args = ty.get_type_params().unwrap_or_default();
237            let map: SubstitutionMap = interface
238                .generics
239                .iter()
240                .map(|g| g.name.clone())
241                .zip(type_args.iter().cloned())
242                .collect();
243
244            for (name, method_ty) in &interface.methods {
245                let substituted = substitute(method_ty, &map);
246                all_interface_methods.insert(name.clone(), substituted.with_receiver_placeholder());
247            }
248
249            for parent in &interface.parents {
250                for (name, method_ty) in self.get_all_methods(parent, trait_bounds) {
251                    all_interface_methods.insert(name, method_ty);
252                }
253            }
254
255            return all_interface_methods;
256        }
257
258        if let Some(bound_types) = trait_bounds.get(qualified_name.as_str()) {
259            return bound_types
260                .iter()
261                .flat_map(|interface_ty| self.get_all_methods(interface_ty, trait_bounds))
262                .collect();
263        }
264
265        let mut methods = self
266            .get_own_methods(&qualified_name)
267            .cloned()
268            .unwrap_or_default();
269
270        // Type aliases inherit methods from the underlying type.
271        if let Some(Definition::TypeAlias { ty: alias_ty, .. }) =
272            self.get_definition(&qualified_name)
273        {
274            let underlying = match &alias_ty {
275                Type::Forall { body, .. } => body.as_ref(),
276                other => other,
277            };
278            if let Type::Constructor { id: alias_id, .. } = underlying
279                && alias_id.as_str() != qualified_name.as_str()
280            {
281                let alias_ty = alias_ty.clone();
282                for (name, method_ty) in self.get_all_methods(&alias_ty, trait_bounds) {
283                    methods.entry(name).or_insert(method_ty);
284                }
285            }
286        }
287
288        methods
289    }
290
291    pub fn get_methods_from_bounds(
292        &self,
293        qualified_name: &str,
294        trait_bounds: &HashMap<String, Vec<Type>>,
295    ) -> MethodSignatures {
296        if let Some(bound_types) = trait_bounds.get(qualified_name) {
297            return bound_types
298                .iter()
299                .flat_map(|interface_ty| self.get_all_methods(interface_ty, trait_bounds))
300                .collect();
301        }
302        MethodSignatures::default()
303    }
304}