use alloc::{collections::BTreeSet, sync::Arc, vec::Vec};
use core::ops::ControlFlow;
use miden_assembly_syntax::{
ast::constants::eval::CachedConstantValue, diagnostics::RelatedError, library::ItemInfo,
sema::ConstEvalVisitor,
};
use miden_core::Felt;
use crate::{
ModuleIndex, SourceFile, SourceSpan, Span, Spanned,
ast::{
self, Alias, AliasTarget, InvocationTarget, Invoke, InvokeKind, Procedure,
SymbolResolution,
constants::ConstEnvironment,
visit::{self, VisitMut},
},
linker::{
LinkerError, Resolver, ResolverCache, SymbolItem, SymbolResolutionContext, SymbolResolver,
},
};
/// Rewrites the AST of a single module: invocation and alias targets are resolved in place,
/// and immediates that reference constants are evaluated.
pub struct ModuleRewriter<'a, 'b: 'a> {
    /// Index of the module whose AST is being rewritten.
    module_id: ModuleIndex,
    /// Resolver used to look up invocation targets, alias targets and constant names.
    resolver: &'a SymbolResolver<'b>,
    /// Cache whose `constants` map holds constant values materialized during rewriting.
    cache: &'a mut ResolverCache,
    /// Invocations collected for the procedure currently being visited.
    invoked: BTreeSet<Invoke>,
}
/// Consumes a finished `ConstEvalVisitor` and *returns from the enclosing function* with
/// the corresponding `ControlFlow`.
///
/// If the visitor reports errors, each one is wrapped in a `RelatedError`, and all of them
/// are returned together as a single `LinkerError::Related`.
macro_rules! wrap_const_control_flow {
    ($visitor:ident) => {{
        let outcome = $visitor.into_result();
        return match outcome {
            Ok(()) => ControlFlow::Continue(()),
            Err(errs) => ControlFlow::Break(LinkerError::Related {
                // The target boxed-slice type is inferred from the `errors` field.
                errors: errs.into_iter().map(RelatedError::wrap).collect(),
            }),
        };
    }};
}
impl<'a, 'b: 'a> ModuleRewriter<'a, 'b> {
    /// Creates a rewriter for `module`.
    ///
    /// Symbols are resolved through `resolver`, and constant values are materialized into
    /// `cache`. The set of recorded invocations starts out empty.
    pub fn new(
        module: ModuleIndex,
        resolver: &'a SymbolResolver<'b>,
        cache: &'a mut ResolverCache,
    ) -> Self {
        let invoked = BTreeSet::new();
        Self { resolver, cache, module_id: module, invoked }
    }

    /// Builds a `LinkerError::InvalidConstantRef` for `span`, attaching the source file
    /// that contains it when the source manager can provide one.
    fn invalid_constant_ref(&self, span: SourceSpan) -> LinkerError {
        let source_file = self.get_source_file_for(span);
        LinkerError::InvalidConstantRef { span, source_file }
    }

    /// Returns the value of the constant identified by `gid`.
    ///
    /// Values already present in `self.cache.constants` are returned directly.
    /// - Compiled constant items are read from their item info.
    /// - Constants defined in source are first materialized into the cache, using a
    ///   `Resolver` scoped to the constant's own module, and then read back.
    ///
    /// # Errors
    ///
    /// Returns `LinkerError::InvalidConstantRef` when `gid` names a procedure, a type, or a
    /// non-constant compiled item. Also propagates any error raised during materialization.
    ///
    /// # Panics
    ///
    /// Panics if `gid` refers to an alias item.
    fn get_constant_by_gid(
        &mut self,
        gid: ast::GlobalItemIndex,
        span: SourceSpan,
    ) -> Result<Option<CachedConstantValue<'_>>, LinkerError> {
        // The cache is probed with `contains_key` instead of `if let Some(..) = get(..)`.
        // Returning a borrow taken in the condition would keep the cache borrowed on the
        // fall-through path, and NLL rejects that.
        if self.cache.constants.contains_key(&gid) {
            let hit = self
                .cache
                .constants
                .get(&gid)
                .expect("constant value present in cache must be retrievable");
            Ok(Some(CachedConstantValue::Hit(hit)))
        } else {
            match self.resolver.linker()[gid].item() {
                SymbolItem::Compiled(ItemInfo::Constant(info)) => {
                    Ok(Some(CachedConstantValue::Hit(&info.value)))
                },
                SymbolItem::Constant(_) => {
                    // Evaluate the constant in the context of the module that defines it.
                    Resolver {
                        resolver: self.resolver,
                        cache: self.cache,
                        current_module: gid.module,
                    }
                    .materialize_constant_by_gid(gid, span)?;
                    let hit = self
                        .cache
                        .constants
                        .get(&gid)
                        .expect("constant value inserted into cache must be retrievable");
                    Ok(Some(CachedConstantValue::Hit(hit)))
                },
                SymbolItem::Compiled(_) | SymbolItem::Procedure(_) | SymbolItem::Type(_) => {
                    Err(self.invalid_constant_ref(span))
                },
                SymbolItem::Alias { .. } => unreachable!(),
            }
        }
    }

    /// Resolves `target`, an invocation of the given `kind` made from the current module.
    ///
    /// What happens next depends on the resolution:
    /// - An exact item: `target` is rewritten in place to the resolved path, keeping its
    ///   original span, and the invocation is recorded in `self.invoked`. A MAST-root
    ///   target is left unchanged but is still recorded.
    /// - An externally-defined path: `target` is rewritten in the same way but is not
    ///   recorded.
    /// - A local item, a module, or a MAST root: the outcome is only logged.
    ///
    /// A resolution failure is logged and returned as `ControlFlow::Break`.
    ///
    /// # Panics
    ///
    /// Panics if a MAST-root target resolves to an external path.
    fn rewrite_target(
        &mut self,
        kind: InvokeKind,
        target: &mut InvocationTarget,
    ) -> ControlFlow<LinkerError> {
        log::debug!(target: "linker", " * rewriting {kind} target {target}");
        let context = SymbolResolutionContext {
            span: target.span(),
            module: self.module_id,
            kind: Some(kind),
        };
        let resolution = match self.resolver.resolve_invoke_target(&context, target) {
            Ok(resolution) => resolution,
            Err(err) => {
                log::error!(target: "linker", " | failed to resolve target {target}");
                return ControlFlow::Break(err);
            },
        };
        match resolution {
            SymbolResolution::Exact { path, .. } => {
                log::debug!(target: "linker", " | target resolved to {path}");
                // Reborrow so `target` remains usable when recording the invocation below.
                match &mut *target {
                    InvocationTarget::MastRoot(_) => (),
                    InvocationTarget::Path(existing) => {
                        *existing = path.with_span(existing.span());
                    },
                    symbol @ InvocationTarget::Symbol(_) => {
                        *symbol = InvocationTarget::Path(path.with_span(symbol.span()));
                    },
                }
                let invocation = Invoke { kind, target: target.clone() };
                self.invoked.insert(invocation);
            },
            SymbolResolution::External(path) => {
                log::debug!(target: "linker", " | target is externally defined at {path}");
                match &mut *target {
                    InvocationTarget::MastRoot(_) => unreachable!(),
                    InvocationTarget::Path(existing) => {
                        *existing = path.with_span(existing.span());
                    },
                    symbol @ InvocationTarget::Symbol(_) => {
                        *symbol = InvocationTarget::Path(path.with_span(symbol.span()));
                    },
                }
            },
            SymbolResolution::Local(item) => {
                log::debug!(target: "linker", " | target is already resolved locally to {item}");
            },
            SymbolResolution::Module { id, path } => {
                log::debug!(target: "linker", " | target resolved to module {id}: '{path}'");
            },
            SymbolResolution::MastRoot(_) => {
                log::warn!(target: "linker", " | resolved phantom target {target}");
            },
        }
        ControlFlow::Continue(())
    }
}
impl<'a, 'b: 'a> VisitMut<LinkerError> for ModuleRewriter<'a, 'b> {
    /// Rewrites the body of `procedure`.
    ///
    /// Every invocation recorded along the way is then merged into the procedure's own
    /// invocation set.
    fn visit_mut_procedure(&mut self, procedure: &mut Procedure) -> ControlFlow<LinkerError> {
        log::debug!(target: "linker", " | visiting {}", procedure.name());
        // Seed the working set with the invocations the procedure already records.
        self.invoked = procedure.invoked().cloned().collect();
        visit::visit_mut_procedure(self, procedure)?;
        let recorded = core::mem::take(&mut self.invoked);
        procedure.extend_invoked(recorded);
        ControlFlow::Continue(())
    }

    /// Rewrites a target invoked with `InvokeKind::SysCall`.
    fn visit_mut_syscall(&mut self, target: &mut InvocationTarget) -> ControlFlow<LinkerError> {
        let kind = InvokeKind::SysCall;
        self.rewrite_target(kind, target)
    }

    /// Rewrites a target invoked with `InvokeKind::Call`.
    fn visit_mut_call(&mut self, target: &mut InvocationTarget) -> ControlFlow<LinkerError> {
        let kind = InvokeKind::Call;
        self.rewrite_target(kind, target)
    }

    /// Rewrites a target reached through the generic invoke-target hook, treating it as
    /// `InvokeKind::Exec`.
    fn visit_mut_invoke_target(
        &mut self,
        target: &mut InvocationTarget,
    ) -> ControlFlow<LinkerError> {
        let kind = InvokeKind::Exec;
        self.rewrite_target(kind, target)
    }

    /// Resolves a relative alias target and replaces it with the resolved path.
    ///
    /// The following targets are left untouched:
    /// - MAST-root targets and already-absolute paths.
    /// - Targets that resolve to a MAST root; these are only logged.
    ///
    /// A resolution failure is logged and returned as `ControlFlow::Break`.
    fn visit_mut_alias(&mut self, alias: &mut Alias) -> ControlFlow<LinkerError> {
        let nothing_to_resolve = match alias.target() {
            AliasTarget::MastRoot(_) => true,
            AliasTarget::Path(path) => path.is_absolute(),
        };
        if nothing_to_resolve {
            return ControlFlow::Continue(());
        }
        log::debug!(target: "linker", " * rewriting alias target {}", alias.target());
        let span = alias.target().span();
        let context = SymbolResolutionContext { span, module: self.module_id, kind: None };
        let resolution = match self.resolver.resolve_alias_target(&context, &*alias) {
            Ok(resolution) => resolution,
            Err(err) => {
                log::error!(target: "linker", " | failed to resolve target {}", alias.target());
                return ControlFlow::Break(err);
            },
        };
        match resolution {
            SymbolResolution::Exact { gid, path } => {
                log::debug!(target: "linker", " | target resolved to item '{path}' (id {gid})");
                *alias.target_mut() = AliasTarget::Path(path.with_span(span));
            },
            SymbolResolution::Module { id, path } => {
                log::debug!(target: "linker", " | target resolved to module '{path}' (id {id})");
                *alias.target_mut() = AliasTarget::Path(path.with_span(span));
            },
            SymbolResolution::MastRoot(digest) => {
                log::warn!(target: "linker", " | target resolved to mast root {digest}");
            },
            SymbolResolution::Local(_) | SymbolResolution::External(_) => unreachable!(),
        }
        ControlFlow::Continue(())
    }

    /// Alias targets are handled inside `visit_mut_alias`, which never descends into them.
    ///
    /// Reaching this method therefore panics.
    fn visit_mut_alias_target(&mut self, _target: &mut AliasTarget) -> ControlFlow<LinkerError> {
        unreachable!("expected all alias targets to be reached via an alias")
    }

    // Each immediate hook below runs a `ConstEvalVisitor` over the immediate. The visitor's own
    // `ControlFlow` is discarded; the outcome is taken from its `into_result`, which is read
    // by `wrap_const_control_flow!`.

    /// Evaluates a `u8` immediate that may reference a constant.
    fn visit_mut_immediate_u8(&mut self, imm: &mut ast::Immediate<u8>) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u8(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates a `u16` immediate that may reference a constant.
    fn visit_mut_immediate_u16(
        &mut self,
        imm: &mut ast::Immediate<u16>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u16(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates a `u32` immediate that may reference a constant.
    fn visit_mut_immediate_u32(
        &mut self,
        imm: &mut ast::Immediate<u32>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_u32(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates an error-message immediate that may reference a constant.
    fn visit_mut_immediate_error_message(
        &mut self,
        imm: &mut ast::Immediate<Arc<str>>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_error_message(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates a field-element immediate that may reference a constant.
    fn visit_mut_immediate_felt(
        &mut self,
        imm: &mut ast::Immediate<Felt>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_felt(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates a push-value immediate that may reference a constant.
    fn visit_mut_immediate_push_value(
        &mut self,
        imm: &mut ast::Immediate<miden_assembly_syntax::parser::PushValue>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_push_value(imm);
        wrap_const_control_flow!(evaluator)
    }

    /// Evaluates a word-value immediate that may reference a constant.
    fn visit_mut_immediate_word_value(
        &mut self,
        imm: &mut ast::Immediate<miden_assembly_syntax::parser::WordValue>,
    ) -> ControlFlow<LinkerError> {
        let mut evaluator = ConstEvalVisitor::new(self);
        let _ = evaluator.visit_mut_immediate_word_value(imm);
        wrap_const_control_flow!(evaluator)
    }
}
impl<'a, 'b: 'a> ConstEnvironment for ModuleRewriter<'a, 'b> {
    type Error = LinkerError;

    /// Returns the source file that contains `span`, or `None` if the resolver's source
    /// manager cannot provide it.
    fn get_source_file_for(&self, span: SourceSpan) -> Option<Arc<SourceFile>> {
        let source_id = span.source_id();
        self.resolver.source_manager().get(source_id).ok()
    }

    /// Resolves `name` in the scope of the current module and returns the constant's value.
    ///
    /// - Exact and local resolutions are looked up by global item index.
    /// - External resolutions are forwarded to `get_by_path`.
    /// - Module and MAST-root resolutions yield `LinkerError::InvalidConstantRef`.
    fn get(&mut self, name: &ast::Ident) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
        let name = Span::new(name.span(), name.as_str());
        let span = name.span();
        let context = SymbolResolutionContext { span, module: self.module_id, kind: None };
        let resolution = self.resolver.resolve_local(&context, &name)?;
        match resolution {
            SymbolResolution::Exact { gid, .. } => self.get_constant_by_gid(gid, span),
            SymbolResolution::Local(item) => {
                let gid = self.module_id + item.into_inner();
                self.get_constant_by_gid(gid, span)
            },
            SymbolResolution::External(path) => self.get_by_path(path.as_deref()),
            SymbolResolution::Module { .. } | SymbolResolution::MastRoot(_) => {
                Err(self.invalid_constant_ref(span))
            },
        }
    }

    /// Resolves `path` from the current module and returns the constant's value.
    ///
    /// Module and MAST-root resolutions yield `LinkerError::InvalidConstantRef`.
    ///
    /// # Panics
    ///
    /// Panics if the path resolves to an external symbol.
    fn get_by_path(
        &mut self,
        path: Span<&ast::Path>,
    ) -> Result<Option<CachedConstantValue<'_>>, Self::Error> {
        let span = path.span();
        let context = SymbolResolutionContext { span, module: self.module_id, kind: None };
        let resolution = self.resolver.resolve_path(&context, path)?;
        match resolution {
            SymbolResolution::Exact { gid, .. } => self.get_constant_by_gid(gid, span),
            SymbolResolution::Local(item) => {
                let gid = self.module_id + item.into_inner();
                self.get_constant_by_gid(gid, span)
            },
            SymbolResolution::MastRoot(_) | SymbolResolution::Module { .. } => {
                Err(self.invalid_constant_ref(span))
            },
            SymbolResolution::External(_) => unreachable!(),
        }
    }
}