miden-assembly 0.24.1

Miden VM assembly language
Documentation
use alloc::{collections::BTreeSet, sync::Arc, vec::Vec};
use core::ops::ControlFlow;

use miden_assembly_syntax::{
    ast::constants::eval::CachedConstantValue, diagnostics::RelatedError, module::ItemInfo,
    sema::ConstEvalVisitor,
};
use miden_core::Felt;

use crate::{
    ModuleIndex, SourceFile, SourceSpan, Span, Spanned,
    ast::{
        self, InvocationTarget, Invoke, InvokeKind, Procedure, SymbolResolution,
        constants::ConstEnvironment,
        visit::{self, VisitMut},
    },
    linker::{
        LinkerError, Resolver, ResolverCache, SymbolItem, SymbolResolutionContext, SymbolResolver,
    },
};

// MODULE REWRITE CHECK
// ================================================================================================

/// A [ModuleRewriter] handles applying all of the module-wide rewrites to a [Module] that is being
/// added to the module graph of the linker. These rewrites include:
///
/// * Resolving, at least partially, all of the invocation targets in procedures of the module, and
///   rewriting those targets as concretely as possible OR as phantom calls representing procedures
///   referenced by MAST root for which we have no definition.
/// * Running a [ConstEvalVisitor] over every immediate operand that is visited, with this rewriter
///   acting as the [ConstEnvironment] through which constant references are looked up.
///
/// The rewriter is driven through its [VisitMut] implementation, one procedure at a time.
pub struct ModuleRewriter<'a, 'b: 'a> {
    /// Resolver used for invocation targets and constant paths; it also gives access to the
    /// linker's items and to the source manager.
    resolver: &'a SymbolResolver<'b>,
    /// Cache of constant values keyed by global item index; consulted before the linker's items
    /// and populated when a constant has to be materialized.
    cache: &'a mut ResolverCache,
    /// The module whose procedures are being rewritten; used as the resolution context for every
    /// lookup.
    module_id: ModuleIndex,
    /// Invocations collected for the procedure currently being visited. Reset at the start of
    /// each procedure and handed back to the procedure once it has been visited.
    invoked: BTreeSet<Invoke>,
}

// Consumes a constant-evaluation visitor and returns from the *enclosing function* with the
// equivalent `ControlFlow`:
//
// * `Ok(())` from `into_result()` becomes `ControlFlow::Continue(())`
// * `Err(errs)` becomes `ControlFlow::Break(LinkerError::Related { .. })`, with every error wrapped
//   via `RelatedError::wrap`, preserving the original order.
//
// Because the expansion is a `return`, nothing placed after an invocation of this macro runs.
macro_rules! wrap_const_control_flow {
    ($visitor:ident) => {
        return match $visitor.into_result() {
            Ok(()) => ControlFlow::Continue(()),
            Err(errs) => {
                let wrapped: Vec<_> = errs.into_iter().map(RelatedError::wrap).collect();
                ControlFlow::Break(LinkerError::Related {
                    errors: wrapped.into_boxed_slice(),
                })
            },
        }
    };
}

impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> {
    /// Builds a rewriter for the module identified by `module`.
    ///
    /// Symbol lookups go through `resolver`, constant values are looked up in (and added to)
    /// `cache`, and the set of recorded invocations starts out empty.
    pub fn new(
        module: ModuleIndex,
        resolver: &'a SymbolResolver<'b>,
        cache: &'a mut ResolverCache,
    ) -> Self {
        let invoked = BTreeSet::new();
        Self { resolver, cache, module_id: module, invoked }
    }

    /// Builds the error reported when `span` refers to an item that is not a constant, attaching
    /// the source file for `span` when the source manager can provide it.
    fn invalid_constant_ref(&self, span: SourceSpan) -> LinkerError {
        let source_file = self.get_source_file_for(span);
        LinkerError::InvalidConstantRef { span, source_file }
    }

    /// Fetches the value of the constant identified by `gid`.
    ///
    /// Lookup order:
    ///
    /// 1. the resolver cache,
    /// 2. the value stored with an already-compiled constant item,
    /// 3. a constant defined in source, which is first materialized into the cache.
    ///
    /// # Errors
    ///
    /// Returns [LinkerError::InvalidConstantRef] when `gid` refers to an item that is not a
    /// constant, and propagates any error raised while materializing the constant.
    ///
    /// # Panics
    ///
    /// Panics if a value that is known to be in the cache cannot be read back from it.
    fn get_constant_by_gid(
        &mut self,
        gid: ast::GlobalItemIndex,
        span: SourceSpan,
    ) -> Result<Option<CachedConstantValue<'_>>, LinkerError> {
        // NOTE(review): the probe-then-get shape looks deliberate - it keeps the returned borrow of
        // the cache confined to this branch so the cache can still be mutably borrowed below.
        // Confirm before collapsing it into a single `get`.
        if self.cache.constants.contains_key(&gid) {
            let value = self
                .cache
                .constants
                .get(&gid)
                .expect("constant value present in cache must be retrievable");
            return Ok(Some(CachedConstantValue::Hit(value)));
        }

        match self.resolver.linker()[gid].item() {
            // Already-compiled constants carry their value with them.
            SymbolItem::Compiled(ItemInfo::Constant(info)) => {
                Ok(Some(CachedConstantValue::Hit(&info.value)))
            },
            // Source-level constants are materialized in the context of their defining module,
            // after which the value is read back out of the cache.
            SymbolItem::Constant(_) => {
                let mut materializer = Resolver {
                    resolver: self.resolver,
                    cache: self.cache,
                    current_module: gid.module,
                };
                materializer.materialize_constant_by_gid(gid, span)?;
                let value = self
                    .cache
                    .constants
                    .get(&gid)
                    .expect("constant value inserted into cache must be retrievable");
                Ok(Some(CachedConstantValue::Hit(value)))
            },
            // Anything else is not a constant at all.
            SymbolItem::Compiled(_) | SymbolItem::Procedure(_) | SymbolItem::Type(_) => {
                Err(self.invalid_constant_ref(span))
            },
        }
    }

    /// Resolves `target` in the context of the current module and, on an exact match, rewrites it
    /// in place and records the invocation.
    ///
    /// * Resolution failure: the error is logged and returned as `ControlFlow::Break`.
    /// * MAST root (phantom) resolution: only logged; `target` is left as-is and nothing is
    ///   recorded.
    /// * Exact resolution: path and symbol targets are replaced by the resolved path (keeping the
    ///   span of the original target), MAST root targets are left untouched, and the resulting
    ///   target is added to the set of invocations.
    /// * Module resolution: only logged.
    ///
    /// # Panics
    ///
    /// Panics if resolution yields a local or external (non-exact) result.
    fn rewrite_target(
        &mut self,
        kind: InvokeKind,
        target: &mut InvocationTarget,
    ) -> ControlFlow<LinkerError> {
        log::debug!(target: "linker", "    * rewriting {kind} target {target}");
        let context = SymbolResolutionContext {
            span: target.span(),
            module: self.module_id,
            kind: Some(kind),
        };

        let resolution = match self.resolver.resolve_invoke_target(&context, target) {
            Ok(resolution) => resolution,
            Err(err) => {
                log::error!(target: "linker", "    | failed to resolve target {target}");
                return ControlFlow::Break(err);
            },
        };

        match resolution {
            SymbolResolution::MastRoot(_) => {
                log::warn!(target: "linker", "    | resolved phantom target {target}");
            },
            SymbolResolution::Exact { path, .. } => {
                log::debug!(target: "linker", "    | target resolved to {path}");
                match &mut *target {
                    InvocationTarget::MastRoot(_) => (),
                    InvocationTarget::Path(existing) => {
                        let span = existing.span();
                        *existing = path.with_span(span);
                    },
                    symbol @ InvocationTarget::Symbol(_) => {
                        let span = symbol.span();
                        *symbol = InvocationTarget::Path(path.with_span(span));
                    },
                }
                // Record the invocation using the (possibly rewritten) target.
                let invocation = Invoke { kind, target: target.clone() };
                self.invoked.insert(invocation);
            },
            SymbolResolution::Module { id, path } => {
                log::debug!(target: "linker", "    | target resolved to module {id}: '{path}'");
            },
            SymbolResolution::Local(_) | SymbolResolution::External(_) => {
                unreachable!("link-time namespace resolution should produce exact ids")
            },
        }

        ControlFlow::Continue(())
    }
}

impl<'a, 'b: 'a> VisitMut<LinkerError> for ModuleRewriter<'a, 'b> {
    /// Rewrites a single procedure.
    ///
    /// The invocation set is seeded with what the procedure already records, the procedure body
    /// is visited (which may add to the set), and the resulting set is handed back to the
    /// procedure. If visiting the body breaks, the error is returned immediately and the
    /// procedure's recorded invocations are not extended.
    fn visit_mut_procedure(&mut self, procedure: &mut Procedure) -> ControlFlow<LinkerError> {
        log::debug!(target: "linker", "  | visiting {}", procedure.name());

        self.invoked.clear();
        self.invoked.extend(procedure.invoked().cloned());

        if let ControlFlow::Break(err) = visit::visit_mut_procedure(self, procedure) {
            return ControlFlow::Break(err);
        }

        // Move the collected set out, leaving an empty one behind.
        let collected = core::mem::take(&mut self.invoked);
        procedure.extend_invoked(collected);
        ControlFlow::Continue(())
    }

    /// Rewrites the target of a `syscall`.
    fn visit_mut_syscall(&mut self, target: &mut InvocationTarget) -> ControlFlow<LinkerError> {
        self.rewrite_target(InvokeKind::SysCall, target)
    }

    /// Rewrites the target of a `call`.
    fn visit_mut_call(&mut self, target: &mut InvocationTarget) -> ControlFlow<LinkerError> {
        self.rewrite_target(InvokeKind::Call, target)
    }

    /// Rewrites any other invocation target, treating it as an `exec`.
    fn visit_mut_invoke_target(
        &mut self,
        target: &mut InvocationTarget,
    ) -> ControlFlow<LinkerError> {
        self.rewrite_target(InvokeKind::Exec, target)
    }

    // Each immediate visitor below follows the same recipe: run a `ConstEvalVisitor` (backed by
    // this rewriter as its constant environment) over the immediate, ignore the control-flow value
    // it returns, and derive the outcome from the visitor's accumulated result instead.

    /// Runs constant evaluation over a `u8` immediate.
    fn visit_mut_immediate_u8(&mut self, imm: &mut ast::Immediate<u8>) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u8(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over a `u16` immediate.
    fn visit_mut_immediate_u16(
        &mut self,
        imm: &mut ast::Immediate<u16>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u16(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over a `u32` immediate.
    fn visit_mut_immediate_u32(
        &mut self,
        imm: &mut ast::Immediate<u32>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u32(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over an error-message immediate.
    fn visit_mut_immediate_error_message(
        &mut self,
        imm: &mut ast::Immediate<Arc<str>>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_error_message(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over a field-element immediate.
    fn visit_mut_immediate_felt(
        &mut self,
        imm: &mut ast::Immediate<Felt>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_felt(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over a push-value immediate.
    fn visit_mut_immediate_push_value(
        &mut self,
        imm: &mut ast::Immediate<miden_assembly_syntax::parser::PushValue>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_push_value(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Runs constant evaluation over a word-value immediate.
    fn visit_mut_immediate_word_value(
        &mut self,
        imm: &mut ast::Immediate<miden_assembly_syntax::parser::WordValue>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_word_value(imm);
        wrap_const_control_flow!(evaluator)
    }
}

impl<'a, 'b: 'a> ConstEnvironment for ModuleRewriter<'a, 'b> {
    type Error = LinkerError;

    /// Looks up the source file that `span` belongs to, returning `None` when the source manager
    /// cannot provide it.
    fn get_source_file_for(&self, span: SourceSpan) -> Option<Arc<SourceFile>> {
        let source_id = span.source_id();
        let lookup = self.resolver.source_manager().get(source_id);
        lookup.ok()
    }

    /// Resolves the constant called `name` from the point of view of the current module and
    /// returns its value.
    ///
    /// The name is treated as a single-component path carrying the span of `name`, and is then
    /// handled exactly like any other constant path.
    fn get(&mut self, name: &ast::Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
        let span = name.span();
        let spanned_name = Span::new(span, name.as_str());
        let path = ast::Path::new(spanned_name.inner());

        self.get_by_path(Span::new(span, path))
    }

    /// Resolves the constant referenced by `path` from the point of view of the current module
    /// and returns its value.
    ///
    /// Errors from path resolution and from fetching the constant's value are propagated.
    fn get_by_path(
        &mut self,
        path: Span<&ast::Path>,
    ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
        let span = path.span();
        let context = SymbolResolutionContext {
            span,
            module: self.module_id,
            kind: None,
        };
        let gid = self.resolver.resolve_constant_path(&context, path)?;

        self.get_constant_by_gid(gid, span)
    }
}