miden_assembly_syntax/sema/
context.rs1use alloc::{
2 boxed::Box,
3 collections::{BTreeMap, BTreeSet},
4 sync::Arc,
5 vec::Vec,
6};
7
8use miden_debug_types::{SourceFile, SourceManager, SourceSpan, Span, Spanned};
9use miden_utils_diagnostics::{Diagnostic, Severity};
10
11use super::{SemanticAnalysisError, SyntaxError};
12use crate::ast::{
13 constants::{ConstEvalError, eval::CachedConstantValue},
14 *,
15};
16
17pub struct AnalysisContext {
19 constants: BTreeMap<Ident, Constant>,
20 cached_constant_values: BTreeMap<Ident, ConstantValue>,
21 imported: BTreeSet<Ident>,
22 procedures: BTreeSet<ProcedureName>,
23 errors: Vec<SemanticAnalysisError>,
24 source_file: Arc<SourceFile>,
25 source_manager: Arc<dyn SourceManager>,
26 warnings_as_errors: bool,
27}
28
29impl constants::ConstEnvironment for AnalysisContext {
30 type Error = SemanticAnalysisError;
31
32 fn get_source_file_for(&self, span: SourceSpan) -> Option<Arc<SourceFile>> {
33 if span.source_id() == self.source_file.id() {
34 Some(self.source_file.clone())
35 } else {
36 None
37 }
38 }
39 #[inline]
40 fn get(&mut self, name: &Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
41 if let Some(value) = self.cached_constant_values.get(name) {
42 Ok(Some(CachedConstantValue::Hit(value)))
43 } else if let Some(constant) = self.constants.get(name) {
44 Ok(Some(CachedConstantValue::Miss(&constant.value)))
45 } else if self.imported.contains(name) {
46 Ok(None)
48 } else {
49 Err(ConstEvalError::UndefinedSymbol {
50 symbol: name.clone(),
51 source_file: self.get_source_file_for(name.span()),
52 }
53 .into())
54 }
55 }
56 #[inline(always)]
57 fn get_by_path(
58 &mut self,
59 path: Span<&Path>,
60 ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
61 if let Some(name) = path.as_ident() {
62 self.get(&name)
63 } else {
64 Ok(None)
65 }
66 }
67
68 #[inline]
69 fn on_eval_completed(&mut self, name: Span<&Path>, value: &ConstantExpr) {
70 let Some(name) = name.as_ident() else {
71 return;
72 };
73 if let Some(value) = value.as_value() {
74 self.cached_constant_values.insert(name, value);
75 } else {
76 self.cached_constant_values.remove(&name);
77 }
78 }
79}
80
81impl AnalysisContext {
82 pub fn new(source_file: Arc<SourceFile>, source_manager: Arc<dyn SourceManager>) -> Self {
83 Self {
84 constants: Default::default(),
85 cached_constant_values: Default::default(),
86 imported: Default::default(),
87 procedures: Default::default(),
88 errors: Default::default(),
89 source_file,
90 source_manager,
91 warnings_as_errors: false,
92 }
93 }
94
95 pub fn set_warnings_as_errors(&mut self, yes: bool) {
96 self.warnings_as_errors = yes;
97 }
98
99 #[inline(always)]
100 pub fn warnings_as_errors(&self) -> bool {
101 self.warnings_as_errors
102 }
103
104 #[inline(always)]
105 pub fn source_manager(&self) -> Arc<dyn SourceManager> {
106 self.source_manager.clone()
107 }
108
109 pub fn register_procedure_name(&mut self, name: ProcedureName) {
110 self.procedures.insert(name);
111 }
112
113 pub fn register_imported_name(&mut self, name: Ident) {
114 self.imported.insert(name);
115 }
116
117 pub fn define_constant(&mut self, module: &mut Module, constant: Constant) {
121 if let Err(err) = module.define_constant(constant.clone()) {
122 self.errors.push(err);
123 } else {
124 let name = constant.name.clone();
125 self.constants.insert(name, constant);
126 }
127 }
128
129 pub fn register_constant(&mut self, constant: Constant) {
134 let name = constant.name.clone();
135 self.cached_constant_values.remove(&name);
136 if let Some(prev) = self.constants.get(&name) {
137 self.errors.push(SemanticAnalysisError::SymbolConflict {
138 span: constant.span,
139 prev_span: prev.span,
140 });
141 } else {
142 self.constants.insert(name, constant);
143 }
144 }
145
146 pub fn simplify_constants(&mut self) {
150 self.cached_constant_values.clear();
151 let constants = self.constants.keys().cloned().collect::<Vec<_>>();
152
153 for constant in constants.iter() {
154 let expr = ConstantExpr::Var(Span::new(
155 constant.span(),
156 PathBuf::from(constant.clone()).into(),
157 ));
158 match constants::eval::expr(&expr, self) {
159 Ok(value) => {
160 if let Some(cached) = value.as_value() {
161 self.cached_constant_values.insert(constant.clone(), cached);
162 } else {
163 self.cached_constant_values.remove(constant);
164 }
165 self.constants.get_mut(constant).unwrap().value = value;
166 },
167 Err(err) => {
168 self.cached_constant_values.remove(constant);
169 self.errors.push(err);
170 },
171 }
172 }
173 }
174
175 pub fn get_constant(&self, name: &Ident) -> Result<&ConstantExpr, SemanticAnalysisError> {
179 if let Some(expr) = self.constants.get(name) {
180 Ok(&expr.value)
181 } else {
182 Err(SemanticAnalysisError::SymbolResolutionError(Box::new(
183 SymbolResolutionError::undefined(name.span(), &self.source_manager),
184 )))
185 }
186 }
187
188 pub fn error(&mut self, diagnostic: SemanticAnalysisError) {
189 self.errors.push(diagnostic);
190 }
191
192 pub fn has_errors(&self) -> bool {
193 if self.warnings_as_errors() {
194 return !self.errors.is_empty();
195 }
196 self.errors
197 .iter()
198 .any(|err| matches!(err.severity().unwrap_or(Severity::Error), Severity::Error))
199 }
200
201 pub fn has_failed(&mut self) -> Result<(), SyntaxError> {
202 if self.has_errors() {
203 Err(SyntaxError {
204 source_file: self.source_file.clone(),
205 errors: core::mem::take(&mut self.errors),
206 })
207 } else {
208 Ok(())
209 }
210 }
211
212 pub fn into_result(self) -> Result<(), SyntaxError> {
213 if self.has_errors() {
214 Err(SyntaxError {
215 source_file: self.source_file.clone(),
216 errors: self.errors,
217 })
218 } else {
219 self.emit_warnings();
220 Ok(())
221 }
222 }
223
224 #[cfg(feature = "std")]
225 fn emit_warnings(self) {
226 use crate::diagnostics::Report;
227
228 if !self.errors.is_empty() {
229 let warning = Report::from(super::errors::SyntaxWarning {
231 source_file: self.source_file,
232 errors: self.errors,
233 });
234 std::eprintln!("{warning}");
235 }
236 }
237
238 #[cfg(not(feature = "std"))]
239 fn emit_warnings(self) {}
240}
241
#[cfg(test)]
mod tests {
    use alloc::{boxed::Box, string::String, sync::Arc};
    use core::cell::Cell;

    use super::AnalysisContext;
    use crate::{
        Path, PathBuf,
        ast::{
            Constant, ConstantExpr, ConstantOp, ConstantValue, Ident, Visibility,
            constants::{self, eval::CachedConstantValue},
        },
        debuginfo::{
            DefaultSourceManager, SourceContent, SourceLanguage, SourceManager, SourceSpan, Span,
            Uri,
        },
        parser::IntValue,
    };

    /// Wraps an [`AnalysisContext`] and tallies how its constant lookups were answered.
    struct CountingEnv<'a> {
        inner: &'a mut AnalysisContext,
        hits: Cell<usize>,
        misses: Cell<usize>,
    }

    impl<'a> CountingEnv<'a> {
        /// Starts counting lookups made through `inner`, with both tallies at zero.
        fn new(inner: &'a mut AnalysisContext) -> Self {
            Self {
                inner,
                hits: Cell::new(0),
                misses: Cell::new(0),
            }
        }

        /// Number of lookups answered from the value cache.
        fn hits(&self) -> usize {
            self.hits.get()
        }

        /// Number of lookups answered with an unevaluated definition.
        fn misses(&self) -> usize {
            self.misses.get()
        }
    }

    /// Adds one to `counter`.
    fn bump(counter: &Cell<usize>) {
        counter.set(counter.get() + 1);
    }

    impl constants::ConstEnvironment for CountingEnv<'_> {
        type Error = super::SemanticAnalysisError;

        fn get_source_file_for(
            &self,
            span: SourceSpan,
        ) -> Option<Arc<crate::debuginfo::SourceFile>> {
            <AnalysisContext as constants::ConstEnvironment>::get_source_file_for(self.inner, span)
        }

        /// Delegates to the wrapped context and records whether the answer was a hit or miss.
        fn get(&mut self, name: &Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
            let found = <AnalysisContext as constants::ConstEnvironment>::get(self.inner, name)?;
            match &found {
                Some(CachedConstantValue::Hit(_)) => bump(&self.hits),
                Some(CachedConstantValue::Miss(_)) => bump(&self.misses),
                None => {},
            }
            Ok(found)
        }

        /// Routes identifier paths through the counting [`Self::get`]; everything else goes
        /// straight to the wrapped context.
        fn get_by_path(
            &mut self,
            path: Span<&Path>,
        ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
            match path.as_ident() {
                Some(name) => self.get(&name),
                None => {
                    <AnalysisContext as constants::ConstEnvironment>::get_by_path(self.inner, path)
                },
            }
        }

        fn on_eval_completed(&mut self, name: Span<&Path>, value: &ConstantExpr) {
            <AnalysisContext as constants::ConstEnvironment>::on_eval_completed(
                self.inner, name, value,
            );
        }
    }

    /// Builds the name of the `i`-th constant in a generated chain.
    fn make_name(i: usize) -> Ident {
        format!("C{i:05}").parse().expect("generated constant name must be valid")
    }

    /// Builds an expression that refers to the constant called `name`.
    fn make_ref(name: Ident) -> ConstantExpr {
        let path = Arc::<Path>::from(PathBuf::from(name));
        ConstantExpr::Var(Span::new(SourceSpan::default(), path))
    }

    /// Registers `depth + 1` constants in which constant `i` is defined as the sum of two
    /// references to constant `i + 1`, and the last constant is the integer `1`.
    fn make_shared_subexpression_chain(context: &mut AnalysisContext, depth: usize) {
        for level in 0..depth {
            let this = make_name(level);
            let dependency = make_name(level + 1);
            let doubled = ConstantExpr::BinaryOp {
                span: SourceSpan::default(),
                op: ConstantOp::Add,
                lhs: Box::new(make_ref(dependency.clone())),
                rhs: Box::new(make_ref(dependency)),
            };
            context.register_constant(Constant::new(
                SourceSpan::default(),
                Visibility::Public,
                this,
                doubled,
            ));
        }

        let leaf = ConstantExpr::Int(Span::new(SourceSpan::default(), IntValue::from(1_u32)));
        context.register_constant(Constant::new(
            SourceSpan::default(),
            Visibility::Public,
            make_name(depth),
            leaf,
        ));
    }

    #[test]
    fn semantic_const_eval_memoizes_shared_subexpressions() {
        let source_manager = Arc::new(DefaultSourceManager::default());
        let uri =
            Uri::from(String::from("mem://const-eval-shared-subexpressions").into_boxed_str());
        let content = SourceContent::new(
            SourceLanguage::Masm,
            uri.clone(),
            String::from("begin\n    nop\nend\n").into_boxed_str(),
        );
        let source_file = source_manager.load_from_raw_parts(uri, content);
        let mut context = AnalysisContext::new(source_file, source_manager);

        // Every constant in the chain references its successor twice, so each dependency is
        // looked up twice while the root is being evaluated.
        let depth = 24;
        make_shared_subexpression_chain(&mut context, depth);

        let root = make_ref(make_name(0));
        let mut env = CountingEnv::new(&mut context);
        let result = constants::eval::expr(&root, &mut env)
            .expect("shared-subexpression constant graph should evaluate");

        assert!(
            matches!(result.as_value(), Some(ConstantValue::Int(_))),
            "evaluation should produce a concrete integer constant value"
        );
        assert_eq!(env.misses(), depth + 1, "each constant in the chain should miss at most once");
        assert_eq!(
            env.hits(),
            depth,
            "the second reference to each dependency should be served from cache"
        );
    }
}