use alloc::{
collections::{BTreeMap, BTreeSet},
sync::Arc,
};
use core::fmt;
use miden_assembly::{
Library, Path, PathBuf,
ast::InvocationTarget,
library::{LibraryExport, ProcedureExport},
};
use miden_core::{Word, program::Program};
use miden_mast_package::{MastArtifact, Package};
use midenc_hir::{constants::ConstantData, dialects::builtin, interner::Symbol};
use midenc_session::{
Emit, OutputMode, OutputType, Session, Writer,
diagnostics::{IntoDiagnostic, Report, SourceSpan, Span, WrapErr},
};
use crate::{TraceEvent, lower::NativePtr, masm};
/// A component lowered to Miden Assembly (MASM), which can be emitted as text or assembled into
/// a MAST artifact via [`MasmComponent::assemble`].
pub struct MasmComponent {
    /// Identifier of the component this was lowered from.
    pub id: builtin::ComponentId,
    /// Optional initialization procedure. When present, the generated executable `main` invokes
    /// it (via `exec`) before anything else.
    pub init: Option<masm::InvocationTarget>,
    /// Entrypoint procedure. When present, `assemble` produces an executable program whose
    /// generated `main` calls it; when absent, the modules are assembled into a library.
    pub entrypoint: Option<masm::InvocationTarget>,
    /// Optional kernel library. NOTE(review): not read by the assembly paths in this file.
    pub kernel: Option<masm::KernelLibrary>,
    /// Read-only data segments. Their contents are placed in the assembled artifact's advice map,
    /// keyed by each segment's digest.
    pub rodata: Vec<Rodata>,
    /// NOTE(review): presumably the address at which the heap begins; not used in this file —
    /// confirm against the lowering code.
    pub heap_base: u32,
    /// NOTE(review): presumably the initial stack pointer, if one is set; not used in this file —
    /// confirm against the lowering code.
    pub stack_pointer: Option<u32>,
    /// The MASM modules of this component. May include intrinsics modules, which are skipped when
    /// printing and statically linked first when assembling.
    pub modules: Vec<Arc<masm::Module>>,
}
impl Emit for MasmComponent {
    /// Components are emitted anonymously; no artifact name is derived from them.
    fn name(&self) -> Option<Symbol> {
        None
    }

    /// A component is always emitted as MASM, whatever output mode is requested.
    fn output_type(&self, _mode: OutputMode) -> OutputType {
        OutputType::Masm
    }

    /// Writes the textual MASM rendering of the component (its `Display` output) to `writer`.
    ///
    /// # Errors
    ///
    /// Fails when `mode` is not [`OutputMode::Text`], or when the underlying write fails.
    fn write_to<W: Writer>(
        &self,
        mut writer: W,
        mode: OutputMode,
        _session: &Session,
    ) -> anyhow::Result<()> {
        anyhow::ensure!(mode == OutputMode::Text, "masm emission does not support binary mode");
        write!(writer, "{self}")?;
        Ok(())
    }
}
/// A read-only data segment belonging to a component.
///
/// During assembly, the segment's contents (packed into field elements, see
/// [`Rodata::to_elements`]) are inserted into the artifact's advice map under `digest`.
#[derive(Clone, PartialEq, Eq)]
pub struct Rodata {
    /// The component that owns this segment.
    pub component: builtin::ComponentId,
    /// Key under which the segment's contents are stored in the advice map.
    /// NOTE(review): presumably a hash of `data` — confirm where this is computed.
    pub digest: Word,
    /// Native address at which the segment begins.
    /// NOTE(review): presumably an address in linear memory — confirm.
    pub start: NativePtr,
    /// The raw bytes of the segment.
    pub data: Arc<ConstantData>,
}
impl fmt::Debug for Rodata {
    /// Formats the segment's digest, start address, and a length-only summary of its data, so
    /// that large segments do not flood debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Stand-in for the payload that renders as `ConstantData { len: N, .. }`.
        struct DataSummary(usize);

        impl fmt::Debug for DataSummary {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.debug_struct("ConstantData").field("len", &self.0).finish_non_exhaustive()
            }
        }

        let mut out = f.debug_struct("Rodata");
        out.field("digest", &format_args!("{}", &self.digest));
        out.field("start", &self.start);
        out.field("data", &DataSummary(self.data.len()));
        out.finish()
    }
}
impl Rodata {
    /// Size of the segment's data in bytes.
    pub fn size_in_bytes(&self) -> usize {
        self.data.len()
    }

    /// Number of field elements needed to hold the data when packed 4 bytes per element
    /// (rounded up).
    pub fn size_in_felts(&self) -> usize {
        Self::felts_for_bytes(self.data.len())
    }

    /// Number of 4-element words needed to hold the packed data (rounded up).
    pub fn size_in_words(&self) -> usize {
        Self::words_for_felts(self.size_in_felts())
    }

    /// Packs this segment's data into field elements; see [`Rodata::bytes_to_elements`].
    pub fn to_elements(&self) -> Vec<miden_processor::Felt> {
        Self::bytes_to_elements(self.data.as_slice())
    }

    /// Packs `bytes` into field elements, 4 bytes per element in little-endian order.
    ///
    /// A trailing chunk shorter than 4 bytes is zero-extended. The result is then padded with
    /// zero elements to a whole number of 4-element words, so its length is always a multiple
    /// of 4 (and is 0 for empty input).
    pub fn bytes_to_elements(bytes: &[u8]) -> Vec<miden_processor::Felt> {
        use miden_processor::Felt;

        // Final length after word padding; allocate it up front so no reallocation happens.
        let padded_len = Self::words_for_felts(Self::felts_for_bytes(bytes.len())) * 4;
        let mut felts = Vec::with_capacity(padded_len);
        felts.extend(bytes.chunks(4).map(|chunk| {
            // Only the last chunk can be short; zero-extend it to a full u32.
            let mut buf = [0u8; 4];
            buf[..chunk.len()].copy_from_slice(chunk);
            Felt::new(u64::from(u32::from_le_bytes(buf)))
        }));
        felts.resize(padded_len, Felt::ZERO);
        debug_assert_eq!(felts.len() % 4, 0, "expected to be a valid number of words");
        felts
    }

    /// Number of field elements required to hold `bytes` bytes (4 bytes per element).
    fn felts_for_bytes(bytes: usize) -> usize {
        bytes.div_ceil(4)
    }

    /// Number of words required to hold `felts` field elements (4 elements per word).
    fn words_for_felts(felts: usize) -> usize {
        felts.div_ceil(4)
    }
}
// Registers the `--test-harness` compile flag under the name `test_harness`.
// `MasmComponent::assemble_program` reads it via `session.get_flag("test_harness")`. When it is
// set, `generate_main` emits the test-harness prologue before invoking the entrypoint.
inventory::submit! {
    midenc_session::CompileFlag::new("test_harness")
        .long("test-harness")
        .action(midenc_session::FlagAction::SetTrue)
        .help("If present, causes the code generator to emit extra code for the VM test harness")
        .help_heading("Testing")
}
impl fmt::Display for MasmComponent {
    /// Renders the component's own modules as MASM text, each preceded by a `# mod <path>`
    /// header line.
    ///
    /// Intrinsics modules, modules in the `std` namespace, and modules under `miden::core` or
    /// `miden::protocol` are omitted.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        use crate::intrinsics::INTRINSICS_MODULE_NAMES;

        for module in self.modules.iter() {
            let module_name = module.path().as_str();
            let module_name_trimmed = module_name.trim_start_matches("::");
            // Module paths may carry a leading `::`. Check both spellings so intrinsics are
            // recognized either way; the other checks below already use the trimmed form.
            let is_intrinsic = INTRINSICS_MODULE_NAMES.contains(&module_name)
                || INTRINSICS_MODULE_NAMES.contains(&module_name_trimmed);
            let is_external = module.is_in_namespace(Path::new("std"))
                || module_name_trimmed.starts_with("miden::core")
                || module_name_trimmed.starts_with("miden::protocol");
            if is_intrinsic || is_external {
                continue;
            }
            writeln!(f, "# mod {}\n", &module_name)?;
            writeln!(f, "{module}")?;
        }
        Ok(())
    }
}
impl MasmComponent {
pub fn assemble(
&self,
link_libraries: &[Arc<Library>],
link_packages: &BTreeMap<Symbol, Arc<Package>>,
session: &Session,
) -> Result<MastArtifact, Report> {
if let Some(entrypoint) = self.entrypoint.as_ref() {
self.assemble_program(entrypoint, link_libraries, link_packages, session)
.map(MastArtifact::Executable)
} else {
self.assemble_library(link_libraries, link_packages, session)
.map(MastArtifact::Library)
}
}
fn assemble_program(
&self,
entrypoint: &InvocationTarget,
link_libraries: &[Arc<Library>],
_link_packages: &BTreeMap<Symbol, Arc<Package>>,
session: &Session,
) -> Result<Arc<Program>, Report> {
use miden_assembly::Assembler;
log::debug!(
target: "assembly",
"assembling executable with entrypoint '{entrypoint}'"
);
let mut assembler = Assembler::new(session.source_manager.clone());
let mut lib_modules = BTreeSet::<PathBuf>::default();
for library in link_libraries.iter().cloned() {
for module in library.module_infos() {
log::debug!(target: "assembly", "registering '{}' with assembler", module.path());
lib_modules.insert(module.path().to_path_buf());
}
assembler.link_dynamic_library(library)?;
}
log::debug!(target: "assembly", "start adding the following modules with assembler: {}",
self.modules.iter().map(|m| m.path().to_string()).collect::<Vec<_>>().join(", "));
let mut modules = Vec::with_capacity(self.modules.len());
for module in self.modules.iter().cloned() {
if lib_modules.contains(module.path()) {
log::warn!(
target: "assembly",
"module '{}' is already registered with the assembler as library's module, \
skipping",
module.path()
);
continue;
}
if module.path().as_str().trim_start_matches("::").starts_with("intrinsics") {
log::debug!(target: "assembly", "adding intrinsics '{}' to assembler", module.path());
assembler.compile_and_statically_link(module)?;
} else {
log::debug!(target: "assembly", "adding '{}' for assembler", module.path());
modules.push(module);
}
}
for module in modules.into_iter().rev() {
assembler.compile_and_statically_link(module)?;
}
let emit_test_harness = session.get_flag("test_harness");
let main =
self.generate_main(entrypoint, emit_test_harness, session.source_manager.clone())?;
log::debug!(target: "assembly", "generated executable module:\n{main}");
let program = assembler.assemble_program(main)?;
let advice_map: miden_core::advice::AdviceMap =
self.rodata.iter().map(|rodata| (rodata.digest, rodata.to_elements())).collect();
Ok(Arc::new(program.with_advice_map(advice_map)))
}
fn assemble_library(
&self,
link_libraries: &[Arc<Library>],
_link_packages: &BTreeMap<Symbol, Arc<Package>>,
session: &Session,
) -> Result<Arc<Library>, Report> {
use miden_assembly::Assembler;
log::debug!(
target: "assembly",
"assembling library of {} modules",
self.modules.len()
);
let mut assembler = Assembler::new(session.source_manager.clone());
let mut lib_modules = BTreeSet::<PathBuf>::default();
for library in link_libraries.iter().cloned() {
for module in library.module_infos() {
log::debug!(target: "assembly", "registering '{}' with assembler", module.path());
lib_modules.insert(module.path().to_path_buf());
}
assembler.link_dynamic_library(library)?;
}
log::debug!(target: "assembly", "start adding the following modules with assembler: {}",
self.modules.iter().map(|m| m.path().to_string()).collect::<Vec<_>>().join(", "));
let mut modules = Vec::with_capacity(self.modules.len());
for module in self.modules.iter().cloned() {
if lib_modules.contains(module.path()) {
log::warn!(
target: "assembly",
"module '{}' is already registered with the assembler as library's module, \
skipping",
module.path()
);
continue;
}
if module.path().as_str().trim_start_matches("::").starts_with("intrinsics") {
log::debug!(target: "assembly", "adding intrinsics '{}' to assembler", module.path());
assembler.compile_and_statically_link(module)?;
} else {
log::debug!(target: "assembly", "adding '{}' for assembler", module.path());
modules.push(module);
}
}
let lib = assembler.assemble_library(modules)?;
let advice_map: miden_core::advice::AdviceMap =
self.rodata.iter().map(|rodata| (rodata.digest, rodata.to_elements())).collect();
let converted_exports = recover_wasm_cm_interfaces(&lib);
let mut mast_forest = lib.mast_forest().clone();
drop(lib);
{
let mast = Arc::get_mut(&mut mast_forest).expect("expected unique reference");
mast.advice_map_mut().extend(advice_map);
}
Ok(Library::new(mast_forest, converted_exports).map(Arc::new)?)
}
fn generate_main(
&self,
entrypoint: &InvocationTarget,
emit_test_harness: bool,
source_manager: Arc<dyn midenc_session::SourceManager + Send + Sync>,
) -> Result<Arc<masm::Module>, Report> {
use masm::{Instruction as Inst, Op};
let mut exe = Box::new(masm::Module::new_executable());
let span = SourceSpan::default();
let body = {
let mut block = masm::Block::new(span, Vec::with_capacity(64));
if let Some(init) = self.init.as_ref() {
block.push(Op::Inst(Span::new(span, Inst::Exec(init.clone()))));
}
if emit_test_harness {
self.emit_test_harness(&mut block);
}
block.push(Op::Inst(Span::new(
span,
Inst::Trace(TraceEvent::FrameStart.as_u32().into()),
)));
block.push(Op::Inst(Span::new(span, Inst::Exec(entrypoint.clone()))));
block
.push(Op::Inst(Span::new(span, Inst::Trace(TraceEvent::FrameEnd.as_u32().into()))));
let truncate_stack = {
let name = masm::ProcedureName::new("truncate_stack").unwrap();
let module = masm::LibraryPath::new("::miden::core::sys").unwrap();
let qualified = masm::QualifiedProcedureName::new(module.as_path(), name);
InvocationTarget::Path(Span::new(span, qualified.into_inner()))
};
block.push(Op::Inst(Span::new(span, Inst::Exec(truncate_stack))));
block
};
let start = masm::Procedure::new(
span,
masm::Visibility::Public,
masm::ProcedureName::main(),
0,
body,
);
exe.define_procedure(start, source_manager)
.into_diagnostic()
.wrap_err("failed to define executable `main` procedure")?;
Ok(Arc::from(exe))
}
fn emit_test_harness(&self, block: &mut masm::Block) {
use masm::{Instruction as Inst, IntValue, Op, PushValue};
use miden_core::Felt;
let span = SourceSpan::default();
let pipe_words_to_memory = {
let name = masm::ProcedureName::new("pipe_words_to_memory").unwrap();
let module = masm::LibraryPath::new("::miden::core::mem").unwrap();
let qualified = masm::QualifiedProcedureName::new(module.as_path(), name);
InvocationTarget::Path(Span::new(span, qualified.into_inner()))
};
block.push(Op::Inst(Span::new(span, Inst::AdvPush(1.into()))));
block.push(Op::Inst(Span::new(span, Inst::Dup0)));
block.push(Op::Inst(Span::new(span, Inst::Push(PushValue::Int(IntValue::U8(0)).into()))));
block.push(Op::Inst(Span::new(span, Inst::Gt)));
let mut loop_body = Vec::with_capacity(16);
loop_body.push(Op::Inst(Span::new(span, Inst::SubImm(Felt::ONE.into()))));
loop_body.push(Op::Inst(Span::new(span, Inst::AdvPush(2.into()))));
loop_body
.push(Op::Inst(Span::new(span, Inst::Trace(TraceEvent::FrameStart.as_u32().into()))));
loop_body.push(Op::Inst(Span::new(span, Inst::Exec(pipe_words_to_memory))));
loop_body
.push(Op::Inst(Span::new(span, Inst::Trace(TraceEvent::FrameEnd.as_u32().into()))));
loop_body.push(Op::Inst(Span::new(span, Inst::DropW)));
loop_body.push(Op::Inst(Span::new(span, Inst::DropW)));
loop_body.push(Op::Inst(Span::new(span, Inst::DropW)));
loop_body.push(Op::Inst(Span::new(span, Inst::Drop)));
loop_body.push(Op::Inst(Span::new(span, Inst::Dup0)));
loop_body
.push(Op::Inst(Span::new(span, Inst::Push(PushValue::Int(IntValue::U8(0)).into()))));
loop_body.push(Op::Inst(Span::new(span, Inst::Gt)));
block.push(Op::While {
span,
body: masm::Block::new(span, loop_body),
});
block.push(Op::Inst(Span::new(span, Inst::Drop)));
}
}
fn recover_wasm_cm_interfaces(lib: &Library) -> BTreeMap<Arc<Path>, LibraryExport> {
use crate::intrinsics::INTRINSICS_MODULE_NAMES;
let mut exports = BTreeMap::new();
for export in lib.exports() {
let path = export.path();
let Some(proc_export) = export.as_procedure() else {
exports.insert(path, export.clone());
continue;
};
let Some(module) = proc_export.path.parent() else {
exports.insert(path, export.clone());
continue;
};
let Some(proc_name) = proc_export.path.last() else {
exports.insert(path, export.clone());
continue;
};
if INTRINSICS_MODULE_NAMES.contains(&module.as_str()) || proc_name.starts_with("cabi") {
exports.insert(path, export.clone());
continue;
}
if let Some((component, interface)) = proc_name.rsplit_once('/') {
let (interface, function) =
interface.rsplit_once('#').expect("invalid wasm component model identifier");
let mut module_path = component.to_string();
module_path.push_str("::");
module_path.push_str(interface);
let module_path = masm::LibraryPath::new(&module_path)
.expect("invalid wasm component model identifier");
let name = masm::ProcedureName::from_raw_parts(masm::Ident::from_raw_parts(
Span::unknown(Arc::from(function)),
));
let qualified = masm::QualifiedProcedureName::new(module_path.as_path(), name);
let qualified = qualified.into_inner();
let mut new_export = ProcedureExport::new(proc_export.node, qualified.clone())
.with_attributes(proc_export.attributes.clone());
if let Some(signature) = proc_export.signature.clone() {
new_export = new_export.with_signature(signature);
}
exports.insert(qualified, LibraryExport::Procedure(new_export));
} else {
exports.insert(path, export.clone());
}
}
exports
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Checks the layout that [`Rodata::bytes_to_elements`] produces for `bytes`:
    ///
    /// * the length is the number of 4-byte chunks, rounded up to whole 4-element words;
    /// * each data element holds its chunk packed little-endian, with a short final chunk
    ///   zero-extended;
    /// * every padding element after the data is zero.
    fn validate_bytes_to_elements(bytes: &[u8]) {
        use miden_processor::Felt;

        let result = Rodata::bytes_to_elements(bytes);
        let expected_felts = bytes.len().div_ceil(4);
        let expected_total_felts = expected_felts.div_ceil(4) * 4;
        assert_eq!(
            result.len(),
            expected_total_felts,
            "For {} bytes, expected {} felts (padded from {} felts), but got {}",
            bytes.len(),
            expected_total_felts,
            expected_felts,
            result.len()
        );
        for (i, (felt, chunk)) in result.iter().zip(bytes.chunks(4)).enumerate() {
            let mut buf = [0u8; 4];
            buf[..chunk.len()].copy_from_slice(chunk);
            let expected = Felt::new(u64::from(u32::from_le_bytes(buf)));
            assert_eq!(*felt, expected, "Element {i} should hold bytes {chunk:?} little-endian");
        }
        for (i, felt) in result.iter().enumerate().skip(expected_felts) {
            assert_eq!(*felt, Felt::ZERO, "Padding at index {i} should be zero");
        }
    }

    #[test]
    fn test_bytes_to_elements_edge_cases() {
        validate_bytes_to_elements(&[]);
        validate_bytes_to_elements(&[1]);
        validate_bytes_to_elements(&[0u8; 4]);
        validate_bytes_to_elements(&[0u8; 15]);
        validate_bytes_to_elements(&[0u8; 16]);
        validate_bytes_to_elements(&[0u8; 17]);
        validate_bytes_to_elements(&[0u8; 31]);
        validate_bytes_to_elements(&[0u8; 32]);
        validate_bytes_to_elements(&[0u8; 33]);
        validate_bytes_to_elements(&[0u8; 64]);
    }

    /// Pins the exact packing: little-endian within an element, a zero-extended partial chunk,
    /// and zero padding up to a full word.
    #[test]
    fn test_bytes_to_elements_packs_little_endian() {
        use miden_processor::Felt;

        let result = Rodata::bytes_to_elements(&[0x01, 0x02, 0x03, 0x04, 0x05]);
        assert_eq!(result, [Felt::new(0x0403_0201), Felt::new(0x05), Felt::ZERO, Felt::ZERO]);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1000))]
        #[test]
        fn proptest_bytes_to_elements(bytes in prop::collection::vec(any::<u8>(), 0..=1000)) {
            validate_bytes_to_elements(&bytes);
        }

        #[test]
        fn proptest_bytes_to_elements_word_boundaries(size_factor in 0u32..=100) {
            // Probe sizes within two bytes of each multiple of 16 (one word of packed bytes).
            let base_size = size_factor * 16;
            for offset in -2i32..=2 {
                let size = (base_size as i32 + offset).max(0) as usize;
                let bytes = vec![0u8; size];
                validate_bytes_to_elements(&bytes);
            }
        }
    }
}