Skip to main content

miden_assembly_syntax/sema/
context.rs

1use alloc::{
2    boxed::Box,
3    collections::{BTreeMap, BTreeSet},
4    sync::Arc,
5    vec::Vec,
6};
7
8use miden_debug_types::{SourceFile, SourceManager, SourceSpan, Span, Spanned};
9use miden_utils_diagnostics::{Diagnostic, Severity};
10
11use super::{SemanticAnalysisError, SyntaxError};
12use crate::ast::{
13    constants::{ConstEvalError, eval::CachedConstantValue},
14    *,
15};
16
17/// This maintains the state for semantic analysis of a single [Module].
18pub struct AnalysisContext {
19    constants: BTreeMap<Ident, Constant>,
20    cached_constant_values: BTreeMap<Ident, ConstantValue>,
21    imported: BTreeSet<Ident>,
22    procedures: BTreeSet<ProcedureName>,
23    errors: Vec<SemanticAnalysisError>,
24    source_file: Arc<SourceFile>,
25    source_manager: Arc<dyn SourceManager>,
26    warnings_as_errors: bool,
27}
28
29impl constants::ConstEnvironment for AnalysisContext {
30    type Error = SemanticAnalysisError;
31
32    fn get_source_file_for(&self, span: SourceSpan) -> Option<Arc<SourceFile>> {
33        if span.source_id() == self.source_file.id() {
34            Some(self.source_file.clone())
35        } else {
36            None
37        }
38    }
39    #[inline]
40    fn get(&mut self, name: &Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
41        if let Some(value) = self.cached_constant_values.get(name) {
42            Ok(Some(CachedConstantValue::Hit(value)))
43        } else if let Some(constant) = self.constants.get(name) {
44            Ok(Some(CachedConstantValue::Miss(&constant.value)))
45        } else if self.imported.contains(name) {
46            // We don't have the definition available yet
47            Ok(None)
48        } else {
49            Err(ConstEvalError::UndefinedSymbol {
50                symbol: name.clone(),
51                source_file: self.get_source_file_for(name.span()),
52            }
53            .into())
54        }
55    }
56    #[inline(always)]
57    fn get_by_path(
58        &mut self,
59        path: Span<&Path>,
60    ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
61        if let Some(name) = path.as_ident() {
62            self.get(&name)
63        } else {
64            Ok(None)
65        }
66    }
67
68    #[inline]
69    fn on_eval_completed(&mut self, name: Span<&Path>, value: &ConstantExpr) {
70        let Some(name) = name.as_ident() else {
71            return;
72        };
73        if let Some(value) = value.as_value() {
74            self.cached_constant_values.insert(name, value);
75        } else {
76            self.cached_constant_values.remove(&name);
77        }
78    }
79}
80
81impl AnalysisContext {
82    pub fn new(source_file: Arc<SourceFile>, source_manager: Arc<dyn SourceManager>) -> Self {
83        Self {
84            constants: Default::default(),
85            cached_constant_values: Default::default(),
86            imported: Default::default(),
87            procedures: Default::default(),
88            errors: Default::default(),
89            source_file,
90            source_manager,
91            warnings_as_errors: false,
92        }
93    }
94
95    pub fn set_warnings_as_errors(&mut self, yes: bool) {
96        self.warnings_as_errors = yes;
97    }
98
99    #[inline(always)]
100    pub fn warnings_as_errors(&self) -> bool {
101        self.warnings_as_errors
102    }
103
104    #[inline(always)]
105    pub fn source_manager(&self) -> Arc<dyn SourceManager> {
106        self.source_manager.clone()
107    }
108
109    pub fn register_procedure_name(&mut self, name: ProcedureName) {
110        self.procedures.insert(name);
111    }
112
113    pub fn register_imported_name(&mut self, name: Ident) {
114        self.imported.insert(name);
115    }
116
117    /// Define a new constant `constant`
118    ///
119    /// Returns `Err` if a constant with the same name is already defined
120    pub fn define_constant(&mut self, module: &mut Module, constant: Constant) {
121        if let Err(err) = module.define_constant(constant.clone()) {
122            self.errors.push(err);
123        } else {
124            let name = constant.name.clone();
125            self.constants.insert(name, constant);
126        }
127    }
128
129    /// Register a constant for semantic analysis without defining it in the module.
130    ///
131    /// This is used for enum variants so we can fold discriminants without
132    /// attempting to define the same constant twice.
133    pub fn register_constant(&mut self, constant: Constant) {
134        let name = constant.name.clone();
135        self.cached_constant_values.remove(&name);
136        if let Some(prev) = self.constants.get(&name) {
137            self.errors.push(SemanticAnalysisError::SymbolConflict {
138                span: constant.span,
139                prev_span: prev.span,
140            });
141        } else {
142            self.constants.insert(name, constant);
143        }
144    }
145
146    /// Rewrite all constant declarations by performing const evaluation of their expressions.
147    ///
148    /// This also has the effect of validating that the constant expressions themselves are valid.
149    pub fn simplify_constants(&mut self) {
150        self.cached_constant_values.clear();
151        let constants = self.constants.keys().cloned().collect::<Vec<_>>();
152
153        for constant in constants.iter() {
154            let expr = ConstantExpr::Var(Span::new(
155                constant.span(),
156                PathBuf::from(constant.clone()).into(),
157            ));
158            match constants::eval::expr(&expr, self) {
159                Ok(value) => {
160                    if let Some(cached) = value.as_value() {
161                        self.cached_constant_values.insert(constant.clone(), cached);
162                    } else {
163                        self.cached_constant_values.remove(constant);
164                    }
165                    self.constants.get_mut(constant).unwrap().value = value;
166                },
167                Err(err) => {
168                    self.cached_constant_values.remove(constant);
169                    self.errors.push(err);
170                },
171            }
172        }
173    }
174
175    /// Get the constant value bound to `name`
176    ///
177    /// Returns `Err` if the symbol is undefined
178    pub fn get_constant(&self, name: &Ident) -> Result<&ConstantExpr, SemanticAnalysisError> {
179        if let Some(expr) = self.constants.get(name) {
180            Ok(&expr.value)
181        } else {
182            Err(SemanticAnalysisError::SymbolResolutionError(Box::new(
183                SymbolResolutionError::undefined(name.span(), &self.source_manager),
184            )))
185        }
186    }
187
188    pub fn error(&mut self, diagnostic: SemanticAnalysisError) {
189        self.errors.push(diagnostic);
190    }
191
192    pub fn has_errors(&self) -> bool {
193        if self.warnings_as_errors() {
194            return !self.errors.is_empty();
195        }
196        self.errors
197            .iter()
198            .any(|err| matches!(err.severity().unwrap_or(Severity::Error), Severity::Error))
199    }
200
201    pub fn has_failed(&mut self) -> Result<(), SyntaxError> {
202        if self.has_errors() {
203            Err(SyntaxError {
204                source_file: self.source_file.clone(),
205                errors: core::mem::take(&mut self.errors),
206            })
207        } else {
208            Ok(())
209        }
210    }
211
212    pub fn into_result(self) -> Result<(), SyntaxError> {
213        if self.has_errors() {
214            Err(SyntaxError {
215                source_file: self.source_file.clone(),
216                errors: self.errors,
217            })
218        } else {
219            self.emit_warnings();
220            Ok(())
221        }
222    }
223
224    #[cfg(feature = "std")]
225    fn emit_warnings(self) {
226        use crate::diagnostics::Report;
227
228        if !self.errors.is_empty() {
229            // Emit warnings to stderr
230            let warning = Report::from(super::errors::SyntaxWarning {
231                source_file: self.source_file,
232                errors: self.errors,
233            });
234            std::eprintln!("{warning}");
235        }
236    }
237
238    #[cfg(not(feature = "std"))]
239    fn emit_warnings(self) {}
240}
241
#[cfg(test)]
mod tests {
    use alloc::{boxed::Box, string::String, sync::Arc};

    use super::AnalysisContext;
    use crate::{
        Path, PathBuf,
        ast::{
            Constant, ConstantExpr, ConstantOp, ConstantValue, Ident, Visibility,
            constants::{self, eval::CachedConstantValue},
        },
        debuginfo::{
            DefaultSourceManager, SourceContent, SourceLanguage, SourceManager, SourceSpan, Span,
            Uri,
        },
        parser::IntValue,
    };

    /// Forwards every lookup to an [`AnalysisContext`], tallying how each successful lookup was
    /// served: from the memoized-value cache (`hits`) or from the declared expression (`misses`).
    struct CountingEnv<'a> {
        inner: &'a mut AnalysisContext,
        hits: usize,
        misses: usize,
    }

    impl<'a> CountingEnv<'a> {
        fn new(inner: &'a mut AnalysisContext) -> Self {
            Self { inner, hits: 0, misses: 0 }
        }
    }

    impl constants::ConstEnvironment for CountingEnv<'_> {
        type Error = super::SemanticAnalysisError;

        fn get_source_file_for(
            &self,
            span: SourceSpan,
        ) -> Option<Arc<crate::debuginfo::SourceFile>> {
            <AnalysisContext as constants::ConstEnvironment>::get_source_file_for(self.inner, span)
        }

        fn get(&mut self, name: &Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
            let found = <AnalysisContext as constants::ConstEnvironment>::get(self.inner, name)?;
            // `hits`/`misses` are fields disjoint from `inner`, so they can be bumped while
            // `found` still borrows the wrapped context.
            match &found {
                Some(CachedConstantValue::Hit(_)) => self.hits += 1,
                Some(CachedConstantValue::Miss(_)) => self.misses += 1,
                None => {},
            }
            Ok(found)
        }

        fn get_by_path(
            &mut self,
            path: Span<&Path>,
        ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
            // Route single-name paths through `get` so they are counted.
            match path.as_ident() {
                Some(ident) => self.get(&ident),
                None => {
                    <AnalysisContext as constants::ConstEnvironment>::get_by_path(self.inner, path)
                },
            }
        }

        fn on_eval_completed(&mut self, name: Span<&Path>, value: &ConstantExpr) {
            <AnalysisContext as constants::ConstEnvironment>::on_eval_completed(
                self.inner, name, value,
            );
        }
    }

    /// Deterministic, zero-padded constant name for chain link `i` (e.g. `C00003`).
    fn make_name(i: usize) -> Ident {
        format!("C{i:05}").parse().expect("generated constant name must be valid")
    }

    /// A by-name reference to the constant `name`.
    fn make_ref(name: Ident) -> ConstantExpr {
        let target = Arc::<Path>::from(PathBuf::from(name));
        ConstantExpr::Var(Span::new(SourceSpan::default(), target))
    }

    /// Register `C0 .. C{depth}`, where each `Ci = C(i+1) + C(i+1)` and `C{depth} = 1`.
    fn make_shared_subexpression_chain(context: &mut AnalysisContext, depth: usize) {
        let span = SourceSpan::default();

        for i in 0..depth {
            let successor = make_name(i + 1);
            let body = ConstantExpr::BinaryOp {
                span,
                op: ConstantOp::Add,
                lhs: Box::new(make_ref(successor.clone())),
                rhs: Box::new(make_ref(successor)),
            };
            context.register_constant(Constant::new(span, Visibility::Public, make_name(i), body));
        }

        let leaf = ConstantExpr::Int(Span::new(span, IntValue::from(1_u32)));
        context.register_constant(Constant::new(span, Visibility::Public, make_name(depth), leaf));
    }

    #[test]
    fn semantic_const_eval_memoizes_shared_subexpressions() {
        // Each Ci references C(i+1) twice, so without memoization the number of misses would
        // grow exponentially with depth.
        const DEPTH: usize = 24;

        let manager = Arc::new(DefaultSourceManager::default());
        let uri =
            Uri::from(String::from("mem://const-eval-shared-subexpressions").into_boxed_str());
        let text = String::from("begin\n    nop\nend\n").into_boxed_str();
        let content = SourceContent::new(SourceLanguage::Masm, uri.clone(), text);
        let file = manager.load_from_raw_parts(uri, content);

        let mut context = AnalysisContext::new(file, manager);
        make_shared_subexpression_chain(&mut context, DEPTH);

        let root = make_ref(make_name(0));
        let mut env = CountingEnv::new(&mut context);
        let folded = constants::eval::expr(&root, &mut env)
            .expect("shared-subexpression constant graph should evaluate");

        assert!(
            matches!(folded.as_value(), Some(ConstantValue::Int(_))),
            "evaluation should produce a concrete integer constant value"
        );
        assert_eq!(env.misses, DEPTH + 1, "each constant in the chain should miss at most once");
        assert_eq!(
            env.hits, DEPTH,
            "the second reference to each dependency should be served from cache"
        );
    }
}