use alloc::{string::String, sync::Arc};
use std::string::ToString;
use miden_core::{Kernel, Operation, Program, mast::MastForest};
use miden_utils_testing::get_column_name;
use pretty_assertions::assert_eq;
use rstest::{fixture, rstest};
use winter_prover::Trace;
use super::*;
use crate::{DefaultHost, HostLibrary, fast::FastProcessor};
/// Initial stack inputs used by most of the test cases below.
const DEFAULT_STACK: &[Felt] = &[Felt::new(1), Felt::new(2), Felt::new(3)];
/// Initial stack inputs for the `dyn_program` / `dyncall_program` cases.
///
/// NOTE(review): presumably this is the digest of the extra `[Swap]` root block that
/// `dyn_program` / `dyncall_program` add to their forest, so that the dynamic call resolves
/// to that block — confirm.
const DYN_TARGET_PROC_HASH: &[Felt] = &[
Felt::new(10995436151082118190),
Felt::new(776663942277617877),
Felt::new(3177713792132750309),
Felt::new(10407898805173442467),
];
/// Digest referenced by the external node in `external_program`. It is also used as the stack
/// input for one `dyn_program` case.
///
/// NOTE(review): presumably equal to the digest of the `[Swap, Swap]` root in
/// `create_simple_library` (the library loaded into the host), so the external/dynamic
/// reference resolves through the host — confirm.
const EXTERNAL_LIB_PROC_DIGEST: Word = Word::new([
Felt::new(9552974201798903089),
Felt::new(993192251238261044),
Felt::new(1885027269046469428),
Felt::new(8558115384207742312),
]);
/// Checks that splitting execution into trace fragments does not change the generated trace.
///
/// Each program is executed twice:
/// - once with `fragment_size`;
/// - once with a fragment size large enough that the whole execution fits in one fragment.
///
/// The two main-segment traces must be identical. On a mismatch, the first differing
/// column/row is reported. The multi-fragment trace is then compared against its insta snapshot.
///
/// NOTE: the snapshot name is derived from the test thread name (see the `testname` fixture),
/// which includes the rstest case index. Do not reorder the `#[case]` list.
#[rstest]
#[case(join_program(), 4, DEFAULT_STACK)]
#[case(join_program(), 11, DEFAULT_STACK)]
#[case(split_program(), 5, &[ONE])]
#[case(split_program(), 5, &[ZERO])]
#[case(split_program(), 9, &[ONE])]
#[case(split_program(), 9, &[ZERO])]
#[case(loop_program(), 5, &[ZERO])]
#[case(loop_program(), 6, &[ZERO])]
#[case(loop_program(), 10, &[ONE])]
#[case(loop_program(), 10, &[ONE, ONE])]
#[case(call_program(), 5, DEFAULT_STACK)]
#[case(call_program(), 10, DEFAULT_STACK)]
#[case(syscall_program(), 5, DEFAULT_STACK)]
#[case(syscall_program(), 10, DEFAULT_STACK)]
#[case(basic_block_program_small(), 1, DEFAULT_STACK)]
#[case(basic_block_program_small(), 2, DEFAULT_STACK)]
#[case(basic_block_program_small(), 3, DEFAULT_STACK)]
#[case(basic_block_program_small(), 4, DEFAULT_STACK)]
#[case(basic_block_program_small(), 5, DEFAULT_STACK)]
#[case(basic_block_program_multiple_batches(), 74, DEFAULT_STACK)]
#[case(basic_block_program_multiple_batches(), 76, DEFAULT_STACK)]
#[case(dyn_program(), 12, DYN_TARGET_PROC_HASH)]
#[case(dyn_program(), 16, DYN_TARGET_PROC_HASH)]
#[case(dyncall_program(), 12, DYN_TARGET_PROC_HASH)]
#[case(dyncall_program(), 16, DYN_TARGET_PROC_HASH)]
#[case(external_program(), 5, DEFAULT_STACK)]
#[case(dyn_program(), 12, EXTERNAL_LIB_PROC_DIGEST.as_elements())]
fn test_trace_generation_at_fragment_boundaries(
    testname: String,
    #[case] program: Program,
    #[case] fragment_size: usize,
    #[case] stack_inputs: &[Felt],
) {
    // Large enough that every program in this test executes within a single fragment.
    const MAX_FRAGMENT_SIZE: usize = 1 << 29;

    // Executes `program` on a fresh processor/host with the given fragment size.
    // Returns the built trace together with the number of core trace fragments produced.
    let generate_trace = |size: usize| {
        let processor = FastProcessor::new(stack_inputs);
        let mut host = DefaultHost::default();
        host.load_library(create_simple_library()).unwrap();
        let (execution_output, trace_fragment_contexts) =
            processor.execute_for_trace_sync(&program, &mut host, size).unwrap();
        // Read before `build_trace` consumes the contexts.
        let num_fragments = trace_fragment_contexts.core_trace_contexts.len();
        let trace = build_trace(
            execution_output,
            trace_fragment_contexts,
            program.hash(),
            program.kernel().clone(),
        );
        (trace, num_fragments)
    };

    let (trace_from_fragments, _) = generate_trace(fragment_size);
    let (trace_from_single_fragment, single_run_fragment_count) =
        generate_trace(MAX_FRAGMENT_SIZE);
    // The reference run is only meaningful if it really was a single fragment.
    assert_eq!(
        single_run_fragment_count, 1,
        "reference execution was expected to produce exactly one trace fragment"
    );

    for (col_idx, (col_from_fragments, col_from_single_fragment)) in trace_from_fragments
        .main_segment()
        .columns()
        .zip(trace_from_single_fragment.main_segment().columns())
        .enumerate()
    {
        if col_from_fragments == col_from_single_fragment {
            continue;
        }

        // Report the first row where the overlapping parts of the columns disagree.
        let first_mismatch = col_from_fragments
            .iter()
            .zip(col_from_single_fragment.iter())
            .enumerate()
            .find(|(_, (multiple, single))| multiple != single);
        if let Some((row_idx, (val_from_fragments, val_from_single_fragment))) = first_mismatch {
            panic!(
                "Trace columns do not match between trace generated as multiple fragments vs a single fragment at column {} ({}) row {}: multiple={}, single={}",
                col_idx,
                get_column_name(col_idx),
                row_idx,
                val_from_fragments,
                val_from_single_fragment
            );
        }

        // Every overlapping row matched, so the columns can only differ in length.
        panic!(
            "Trace columns do not match between trace generated as multiple fragments vs a single fragment at column {} ({}): different lengths (multiple={}, single={})",
            col_idx,
            get_column_name(col_idx),
            col_from_fragments.len(),
            col_from_single_fragment.len()
        );
    }

    // Catch-all comparison covering anything the per-column check above does not inspect.
    assert_eq!(format!("{trace_from_fragments:?}"), format!("{trace_from_single_fragment:?}"));

    insta::assert_compact_debug_snapshot!(testname, trace_from_fragments);
}
/// Builds a host library whose forest has a single root: a basic block of two `Swap` ops.
fn create_simple_library() -> HostLibrary {
    let mut forest = MastForest::new();
    let root = forest.add_block(vec![Operation::Swap; 2], Vec::new()).unwrap();
    forest.make_root(root);
    Arc::new(forest).into()
}
/// Builds `join(mul, join(add, swap))`, where each leaf is a one-op basic block.
fn join_program() -> Program {
    let mut forest = MastForest::new();

    let mul = forest.add_block(vec![Operation::Mul], Vec::new()).unwrap();
    let add = forest.add_block(vec![Operation::Add], Vec::new()).unwrap();
    let swap = forest.add_block(vec![Operation::Swap], Vec::new()).unwrap();

    let inner_join = forest.add_join(add, swap).unwrap();
    let outer_join = forest.add_join(mul, inner_join).unwrap();
    forest.make_root(outer_join);

    Program::new(Arc::new(forest), outer_join)
}
/// Builds `join([Swap, Swap], split(add, swap))`.
///
/// The test cases run this with both `ONE` and `ZERO` as the stack input, presumably so that
/// both branches of the split are exercised.
fn split_program() -> Program {
    let mut forest = MastForest::new();

    let swap_swap = forest
        .add_block(vec![Operation::Swap, Operation::Swap], Vec::new())
        .unwrap();
    let on_true = forest.add_block(vec![Operation::Add], Vec::new()).unwrap();
    let on_false = forest.add_block(vec![Operation::Swap], Vec::new()).unwrap();
    let split = forest.add_split(on_true, on_false).unwrap();

    let root = forest.add_join(swap_swap, split).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// Builds `join([Swap, Swap], loop([Pad, Drop]))`.
fn loop_program() -> Program {
    let mut forest = MastForest::new();

    let swap_swap = forest
        .add_block(vec![Operation::Swap, Operation::Swap], Vec::new())
        .unwrap();
    let loop_body = forest
        .add_block(vec![Operation::Pad, Operation::Drop], Vec::new())
        .unwrap();
    let loop_node = forest.add_loop(loop_body).unwrap();

    let root = forest.add_join(swap_swap, loop_node).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// Builds `join([Swap, Swap], call([Swap, Swap]))`, reusing the same basic block as both the
/// first child and the call target.
fn call_program() -> Program {
    let mut forest = MastForest::new();

    let swap_swap = forest
        .add_block(vec![Operation::Swap, Operation::Swap], Vec::new())
        .unwrap();
    let call = forest.add_call(swap_swap).unwrap();

    let root = forest.add_join(swap_swap, call).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// Builds `join([Swap, Swap], syscall([Swap, Swap]))`.
///
/// The program's kernel contains exactly one procedure: the `[Swap, Swap]` block's digest.
fn syscall_program() -> Program {
    let mut forest = MastForest::new();

    let swap_swap = forest
        .add_block(vec![Operation::Swap, Operation::Swap], Vec::new())
        .unwrap();
    let syscall = forest.add_syscall(swap_swap).unwrap();
    let root = forest.add_join(swap_swap, syscall).unwrap();
    // The syscall target is registered as the kernel's only procedure, by digest.
    let kernel_proc_digest = forest[swap_swap].digest();
    forest.make_root(root);

    Program::with_kernel(Arc::new(forest), root, Kernel::new(&[kernel_proc_digest]).unwrap())
}
/// Builds `join([Swap, Push(42)], [Drop])`: a small basic block followed by a one-op block.
fn basic_block_program_small() -> Program {
    let mut forest = MastForest::new();

    let swap_then_push = forest
        .add_block(vec![Operation::Swap, Operation::Push(42_u32.into())], Vec::new())
        .unwrap();
    let drop = forest.add_block(vec![Operation::Drop], Vec::new()).unwrap();

    let root = forest.add_join(swap_then_push, drop).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// Builds `join([Swap; 80], [Drop])`.
///
/// The first basic block is long; it is presumably meant to span several operation batches
/// (per the function name).
fn basic_block_program_multiple_batches() -> Program {
    const NUM_SWAPS: usize = 80;

    let mut forest = MastForest::new();

    let many_swaps = forest
        .add_block(vec![Operation::Swap; NUM_SWAPS], Vec::new())
        .unwrap();
    let drop = forest.add_block(vec![Operation::Drop], Vec::new()).unwrap();

    let root = forest.add_join(many_swaps, drop).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// Builds `join(prologue, dyn)` plus an extra `[Swap]` root that the dynamic call can target.
///
/// The prologue is `Push(40), MStoreW, Drop x4, Push(40)`. NOTE(review): presumably it stores
/// the top stack word (the target digest) at address 40. It then leaves that address on the
/// stack for the `dyn` node to read the callee digest from — confirm against the VM semantics.
fn dyn_program() -> Program {
    const HASH_ADDR: Felt = Felt::new(40);

    let mut forest = MastForest::new();

    let mut prologue_ops = vec![Operation::Push(HASH_ADDR), Operation::MStoreW];
    prologue_ops.extend(vec![Operation::Drop; 4]);
    prologue_ops.push(Operation::Push(HASH_ADDR));
    let prologue = forest.add_block(prologue_ops, Vec::new()).unwrap();
    let dyn_node = forest.add_dyn().unwrap();
    let root = forest.add_join(prologue, dyn_node).unwrap();
    forest.make_root(root);

    // Extra root so the dynamically-called procedure exists in the forest.
    let dyn_target = forest.add_block(vec![Operation::Swap], Vec::new()).unwrap();
    forest.make_root(dyn_target);

    Program::new(Arc::new(forest), root)
}
/// Builds `join(prologue, dyncall)` plus an extra `[Swap]` root that the dynamic call can target.
///
/// The prologue is `Push(40), MStoreW, Drop x4, Push(40)`. NOTE(review): presumably it stores
/// the top stack word (the target digest) at address 40. It then leaves that address on the
/// stack for the `dyncall` node to read the callee digest from — confirm against the VM
/// semantics.
fn dyncall_program() -> Program {
    const HASH_ADDR: Felt = Felt::new(40);

    let mut forest = MastForest::new();

    let mut prologue_ops = vec![Operation::Push(HASH_ADDR), Operation::MStoreW];
    prologue_ops.extend(vec![Operation::Drop; 4]);
    prologue_ops.push(Operation::Push(HASH_ADDR));
    let prologue = forest.add_block(prologue_ops, Vec::new()).unwrap();
    let dyncall_node = forest.add_dyncall().unwrap();
    let root = forest.add_join(prologue, dyncall_node).unwrap();
    forest.make_root(root);

    // Extra root so the dynamically-called procedure exists in the forest.
    let dyncall_target = forest.add_block(vec![Operation::Swap], Vec::new()).unwrap();
    forest.make_root(dyncall_target);

    Program::new(Arc::new(forest), root)
}
/// Builds `join([Pad, Drop], external(EXTERNAL_LIB_PROC_DIGEST))`.
///
/// The external node only carries a digest. NOTE(review): presumably it is resolved at run time
/// through the library the test loads into the host.
fn external_program() -> Program {
    let mut forest = MastForest::new();

    let pad_drop = forest
        .add_block(vec![Operation::Pad, Operation::Drop], Vec::new())
        .unwrap();
    let external = forest.add_external(EXTERNAL_LIB_PROC_DIGEST).unwrap();

    let root = forest.add_join(pad_drop, external).unwrap();
    forest.make_root(root);

    Program::new(Arc::new(forest), root)
}
/// rstest fixture returning the current test thread's name.
///
/// The test uses it as the insta snapshot name. Panics if the current thread is unnamed.
#[fixture]
fn testname() -> String {
    let current_thread = std::thread::current();
    let name = current_thread.name().unwrap();
    name.to_string()
}