Skip to main content

miden_test_utils/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5#[cfg(feature = "std")]
6extern crate std;
7
8#[cfg(not(target_family = "wasm"))]
9use alloc::format;
10use alloc::{
11    string::{String, ToString},
12    sync::Arc,
13    vec::Vec,
14};
15
16use assembly::{KernelLibrary, Library};
17pub use assembly::{LibraryPath, SourceFile, SourceManager, diagnostics::Report};
18pub use pretty_assertions::{assert_eq, assert_ne, assert_str_eq};
19pub use processor::{
20    AdviceInputs, AdviceProvider, ContextId, ExecutionError, ExecutionOptions, ExecutionTrace,
21    Process, ProcessState, VmStateIterator,
22};
23use processor::{Program, fast::FastProcessor};
24#[cfg(not(target_family = "wasm"))]
25use proptest::prelude::{Arbitrary, Strategy};
26use prover::utils::range;
27pub use prover::{MemAdviceProvider, MerkleTreeVC, ProvingOptions, prove};
28pub use test_case::test_case;
29pub use verifier::{AcceptableOptions, VerifierError, verify};
30pub use vm_core::{
31    EMPTY_WORD, Felt, FieldElement, ONE, StackInputs, StackOutputs, StarkField, WORD_SIZE, Word,
32    ZERO,
33    chiplets::hasher::{STATE_WIDTH, hash_elements},
34    stack::MIN_STACK_DEPTH,
35    utils::{IntoBytes, ToElements, collections, group_slice_elements},
36};
37use vm_core::{ProgramInfo, chiplets::hasher::apply_permutation};
38
39pub mod math {
40    pub use winter_prover::math::{
41        ExtensionOf, FieldElement, StarkField, ToElements, fft, fields::QuadExtension, polynom,
42    };
43}
44
45pub mod serde {
46    pub use vm_core::utils::{
47        ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader,
48    };
49}
50
51pub mod crypto;
52
53pub mod host;
54use host::TestHost;
55
56#[cfg(not(target_family = "wasm"))]
57pub mod rand;
58
59mod test_builders;
60
61#[cfg(not(target_family = "wasm"))]
62pub use proptest;
63
64// TYPE ALIASES
65// ================================================================================================
66
/// Shorthand for `vm_core::QuadExtension` instantiated over [`Felt`].
///
/// NOTE(review): presumably the quadratic extension field of the base field; confirm against
/// `vm_core`.
pub type QuadFelt = vm_core::QuadExtension<Felt>;
68
69// CONSTANTS
70// ================================================================================================
71
/// The smallest `u64` value that no longer fits into a [u32], i.e. `u32::MAX + 1` (2^32).
pub const U32_BOUND: u64 = 1 << u32::BITS;
74
/// MASM source code of the `truncate_stack` procedure, ready to be embedded into a test program.
///
/// As written, the procedure stores the top word into local 0, keeps dropping words while
/// `sdepth` differs from 16, and finally loads the stored word back.
/// NOTE(review): description derived from the instruction names only; confirm against the MASM
/// instruction reference.
pub const TRUNCATE_STACK_PROC: &str = "
proc.truncate_stack.4
    loc_storew.0 dropw movupw.3
    sdepth neq.16
    while.true
        dropw movupw.3
        sdepth neq.16
    end
    loc_loadw.0
end
";
87
88// TEST HANDLER
89// ================================================================================================
90
/// Asserts that compiling the given test fails with an `assembly::AssemblyError` matching the
/// provided pattern (optionally followed by an `if` guard).
///
/// Panics if compilation succeeds, if the resulting report cannot be downcast to an
/// `AssemblyError`, or if the downcast error does not match the pattern.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
#[macro_export]
macro_rules! expect_assembly_error {
    ($test:expr, $(|)? $( $pattern:pat_param )|+ $( if $guard: expr )? $(,)?) => {
        let report = $test.compile().expect_err("expected assembly to fail");
        // Resolve the concrete error type first so that a single assertion follows.
        let assembly_error = match report.downcast::<assembly::AssemblyError>() {
            Ok(assembly_error) => assembly_error,
            Err(report) => {
                panic!(r#"
assertion failed (expected assembly error, but got a different type):
    left: `{:?}`,
    right: `{}`"#, report, stringify!($($pattern)|+ $(if $guard)?));
            }
        };
        ::vm_core::assert_matches!(assembly_error, $( $pattern )|+ $( if $guard )?);
    };
}
110
/// Asserts that executing the given test fails with an error matching the provided pattern
/// (optionally followed by an `if` guard).
///
/// Panics with the invocation's file and line if execution succeeds instead.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
#[macro_export]
macro_rules! expect_exec_error_matches {
    ($test:expr, $(|)? $( $pattern:pat_param )|+ $( if $guard: expr )? $(,)?) => {
        match $test.execute() {
            Err(exec_error) => {
                ::vm_core::assert_matches!(exec_error, $( $pattern )|+ $( if $guard )?)
            },
            Ok(_) => panic!("expected execution to fail @ {}:{}", file!(), line!()),
        }
    };
}
122
/// Like [assembly::assert_diagnostic], but matches each non-empty line of the rendered output to a
/// corresponding pattern.
///
/// So if the output has 3 lines, the second of which is empty, and you provide 2 patterns, the
/// assertion passes if the first line matches the first pattern, and the third line matches the
/// second pattern - the second line is ignored because it is empty.
///
/// # Panics
/// Panics if the diagnostic renders fewer non-empty lines than there are patterns, so that an
/// expected pattern can never be skipped silently, or if any line fails to match its pattern.
/// Rendered lines beyond the last pattern are not checked.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! assert_diagnostic_lines {
    ($diagnostic:expr, $($expected:expr),+) => {{
        use assembly::testing::Pattern;
        let actual = format!("{}", assembly::diagnostics::reporting::PrintDiagnostic::new_without_color($diagnostic));
        let patterns = [$(Pattern::from($expected)),*];

        // `zip` stops at the shorter side; without this check surplus patterns would be dropped
        // and the assertion would pass without ever looking at them.
        let rendered_lines = actual.lines().filter(|l| !l.trim().is_empty()).count();
        assert!(
            rendered_lines >= patterns.len(),
            "expected at least {} non-empty diagnostic lines, but only {} were rendered:\n{}",
            patterns.len(),
            rendered_lines,
            actual
        );

        let lines = actual.lines().filter(|l| !l.trim().is_empty()).zip(patterns.into_iter());
        for (actual_line, expected) in lines {
            expected.assert_match_with_context(actual_line, &actual);
        }
    }};
}
141
/// Asserts that compiling `$test` fails and that the rendered diagnostic matches the expected
/// patterns line by line (see `assert_diagnostic_lines!`).
///
/// The nested macro is invoked through `$crate`, so callers only need this macro in scope.
///
/// # Panics
/// Panics if compilation succeeds or if the diagnostic does not match the expected patterns.
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! assert_assembler_diagnostic {
    // A literal is also an expression, so a single `expr` arm covers every accepted input.
    ($test:ident, $($expected:expr),+) => {{
        let error = $test
            .compile()
            .expect_err("expected diagnostic to be raised, but compilation succeeded");
        $crate::assert_diagnostic_lines!(error, $($expected),*);
    }};
}
159
/// This is a container for the data required to run tests, which allows for running several
/// different types of tests.
///
/// Types of valid result tests:
/// - Execution test: check that running a program compiled from the given source has the specified
///   results for the given (optional) inputs.
/// - Proptest: run an execution test inside a proptest.
///
/// Types of failure tests:
/// - Assembly error test: check that attempting to compile the given source causes an
///   AssemblyError which matches the expected pattern (see `expect_assembly_error!`).
/// - Execution error test: check that running a program compiled from the given source causes an
///   ExecutionError which matches the expected pattern (see `expect_exec_error_matches!`).
pub struct Test {
    /// Source manager handed to the assembler and to the processor.
    pub source_manager: Arc<dyn SourceManager>,
    /// Source of the program under test.
    pub source: Arc<SourceFile>,
    /// Optional kernel source; when set, the program is assembled against this kernel.
    pub kernel_source: Option<Arc<SourceFile>>,
    /// Stack inputs the processor is initialized with.
    pub stack_inputs: StackInputs,
    /// Advice inputs used to build the host's advice provider.
    pub advice_inputs: AdviceInputs,
    /// Enables debug mode in both the assembler and the processor.
    pub in_debug_mode: bool,
    /// Libraries added to the assembler; their MAST forests are also loaded into the host.
    pub libraries: Vec<Library>,
    /// Extra `(path, source)` modules linked in as library modules during assembly.
    pub add_modules: Vec<(LibraryPath, String)>,
}
183
184impl Test {
185    // CONSTRUCTOR
186    // --------------------------------------------------------------------------------------------
187
188    /// Creates the simplest possible new test, with only a source string and no inputs.
189    pub fn new(name: &str, source: &str, in_debug_mode: bool) -> Self {
190        let source_manager = Arc::new(assembly::DefaultSourceManager::default());
191        let source = source_manager.load(name, source.to_string());
192        Test {
193            source_manager,
194            source,
195            kernel_source: None,
196            stack_inputs: StackInputs::default(),
197            advice_inputs: AdviceInputs::default(),
198            in_debug_mode,
199            libraries: Vec::default(),
200            add_modules: Vec::default(),
201        }
202    }
203
204    /// Add an extra module to link in during assembly
205    pub fn add_module(&mut self, path: assembly::LibraryPath, source: impl ToString) {
206        self.add_modules.push((path, source.to_string()));
207    }
208
209    // TEST METHODS
210    // --------------------------------------------------------------------------------------------
211
212    /// Builds a final stack from the provided stack-ordered array and asserts that executing the
213    /// test will result in the expected final stack state.
214    #[track_caller]
215    pub fn expect_stack(&self, final_stack: &[u64]) {
216        let result = self.get_last_stack_state().as_int_vec();
217        let expected = resize_to_min_stack_depth(final_stack);
218        assert_eq!(expected, result, "Expected stack to be {:?}, found {:?}", expected, result);
219    }
220
221    /// Executes the test and validates that the process memory has the elements of `expected_mem`
222    /// at address `mem_start_addr` and that the end of the stack execution trace matches the
223    /// `final_stack`.
224    #[track_caller]
225    pub fn expect_stack_and_memory(
226        &self,
227        final_stack: &[u64],
228        mem_start_addr: u32,
229        expected_mem: &[u64],
230    ) {
231        // compile the program
232        let (program, kernel) = self.compile().expect("Failed to compile test source.");
233        let mut host = TestHost::new(MemAdviceProvider::from(self.advice_inputs.clone()));
234        if let Some(kernel) = kernel {
235            host.load_mast_forest(kernel.mast_forest().clone()).unwrap();
236        }
237        for library in &self.libraries {
238            host.load_mast_forest(library.mast_forest().clone()).unwrap();
239        }
240
241        // execute the test
242        let mut process = Process::new(
243            program.kernel().clone(),
244            self.stack_inputs.clone(),
245            ExecutionOptions::default().with_debugging(self.in_debug_mode),
246        )
247        .with_source_manager(self.source_manager.clone());
248        process.execute(&program, &mut host).unwrap();
249
250        // validate the memory state
251        for (addr, mem_value) in
252            (range(mem_start_addr as usize, expected_mem.len())).zip(expected_mem.iter())
253        {
254            let mem_state = process
255                .chiplets
256                .memory
257                .get_value(ContextId::root(), addr as u32)
258                .unwrap_or(ZERO);
259            assert_eq!(
260                *mem_value,
261                mem_state.as_int(),
262                "Expected memory [{}] => {:?}, found {:?}",
263                addr,
264                mem_value,
265                mem_state
266            );
267        }
268
269        // validate the stack states
270        self.expect_stack(final_stack);
271    }
272
273    /// Asserts that executing the test inside a proptest results in the expected final stack state.
274    /// The proptest will return a test failure instead of panicking if the assertion condition
275    /// fails.
276    #[cfg(not(target_family = "wasm"))]
277    pub fn prop_expect_stack(
278        &self,
279        final_stack: &[u64],
280    ) -> Result<(), proptest::prelude::TestCaseError> {
281        let result = self.get_last_stack_state().as_int_vec();
282        proptest::prop_assert_eq!(resize_to_min_stack_depth(final_stack), result);
283
284        Ok(())
285    }
286
287    // UTILITY METHODS
288    // --------------------------------------------------------------------------------------------
289
290    /// Compiles a test's source and returns the resulting Program together with the associated
291    /// kernel library (when specified).
292    ///
293    /// # Errors
294    /// Returns an error if compilation of the program source or the kernel fails.
295    pub fn compile(&self) -> Result<(Program, Option<KernelLibrary>), Report> {
296        use assembly::{Assembler, CompileOptions, ast::ModuleKind};
297
298        let (assembler, kernel_lib) = if let Some(kernel) = self.kernel_source.clone() {
299            let kernel_lib =
300                Assembler::new(self.source_manager.clone()).assemble_kernel(kernel).unwrap();
301
302            (
303                Assembler::with_kernel(self.source_manager.clone(), kernel_lib.clone()),
304                Some(kernel_lib),
305            )
306        } else {
307            (Assembler::new(self.source_manager.clone()), None)
308        };
309
310        let mut assembler = self
311            .add_modules
312            .iter()
313            .fold(assembler, |assembler, (path, source)| {
314                assembler
315                    .with_module_and_options(
316                        source,
317                        CompileOptions::new(ModuleKind::Library, path.clone()).unwrap(),
318                    )
319                    .expect("invalid masm source code")
320            })
321            .with_debug_mode(self.in_debug_mode);
322        for library in &self.libraries {
323            assembler.add_library(library).unwrap();
324        }
325
326        Ok((assembler.assemble_program(self.source.clone())?, kernel_lib))
327    }
328
329    /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a
330    /// resulting execution trace or error.
331    ///
332    /// Internally, this also checks that the slow and fast processors agree on the stack
333    /// outputs.
334    #[track_caller]
335    pub fn execute(&self) -> Result<ExecutionTrace, ExecutionError> {
336        let (program, mut host) = self.get_program_and_host();
337
338        // slow processor
339        let mut process = Process::new(
340            program.kernel().clone(),
341            self.stack_inputs.clone(),
342            ExecutionOptions::default().with_debugging(self.in_debug_mode),
343        )
344        .with_source_manager(self.source_manager.clone());
345        let slow_stack_outputs = process.execute(&program, &mut host)?;
346
347        let trace = ExecutionTrace::new(process, slow_stack_outputs.clone());
348        assert_eq!(&program.hash(), trace.program_hash(), "inconsistent program hash");
349
350        // compare fast and slow processors' stack outputs
351        self.assert_outputs_with_fast_processor(slow_stack_outputs);
352
353        Ok(trace)
354    }
355
356    /// Compiles the test's source to a Program and executes it with the tests inputs. Returns the
357    /// process once execution is finished.
358    pub fn execute_process(&self) -> Result<(Process, TestHost), ExecutionError> {
359        let (program, mut host) = self.get_program_and_host();
360
361        let mut process = Process::new(
362            program.kernel().clone(),
363            self.stack_inputs.clone(),
364            ExecutionOptions::default().with_debugging(self.in_debug_mode),
365        )
366        .with_source_manager(self.source_manager.clone());
367
368        let stack_outputs = process.execute(&program, &mut host)?;
369        self.assert_outputs_with_fast_processor(stack_outputs);
370
371        Ok((process, host))
372    }
373
374    /// Compiles the test's code into a program, then generates and verifies a proof of execution
375    /// using the given public inputs and the specified number of stack outputs. When `test_fail`
376    /// is true, this function will force a failure by modifying the first output.
377    pub fn prove_and_verify(&self, pub_inputs: Vec<u64>, test_fail: bool) {
378        let (program, mut host) = self.get_program_and_host();
379        let stack_inputs = StackInputs::try_from_ints(pub_inputs).unwrap();
380        let (mut stack_outputs, proof) = prover::prove(
381            &program,
382            stack_inputs.clone(),
383            &mut host,
384            ProvingOptions::default(),
385            self.source_manager.clone(),
386        )
387        .unwrap();
388
389        self.assert_outputs_with_fast_processor(stack_outputs.clone());
390
391        let program_info = ProgramInfo::from(program);
392        if test_fail {
393            stack_outputs.stack_mut()[0] += ONE;
394            assert!(verifier::verify(program_info, stack_inputs, stack_outputs, proof).is_err());
395        } else {
396            let result = verifier::verify(program_info, stack_inputs, stack_outputs, proof);
397            assert!(result.is_ok(), "error: {result:?}");
398        }
399    }
400
401    /// Compiles the test's source to a Program and executes it with the tests inputs. Returns a
402    /// VmStateIterator that allows us to iterate through each clock cycle and inspect the process
403    /// state.
404    pub fn execute_iter(&self) -> VmStateIterator {
405        let (program, mut host) = self.get_program_and_host();
406
407        let mut process = Process::new(
408            program.kernel().clone(),
409            self.stack_inputs.clone(),
410            ExecutionOptions::default().with_debugging(self.in_debug_mode),
411        )
412        .with_source_manager(self.source_manager.clone());
413        let result = process.execute(&program, &mut host);
414
415        if let Ok(stack_outputs) = &result {
416            assert_eq!(
417                program.hash(),
418                process.decoder.program_hash().into(),
419                "inconsistent program hash"
420            );
421            self.assert_outputs_with_fast_processor(stack_outputs.clone());
422        }
423        VmStateIterator::new(process, result)
424    }
425
426    /// Returns the last state of the stack after executing a test.
427    #[track_caller]
428    pub fn get_last_stack_state(&self) -> StackOutputs {
429        let trace = self.execute().unwrap();
430
431        trace.last_stack_state()
432    }
433
434    // HELPERS
435    // ------------------------------------------------------------------------------------------
436
437    /// Returns the program and host for the test.
438    ///
439    /// The host is initialized with the advice inputs provided in the test, as well as the kernel
440    /// and library MAST forests.
441    fn get_program_and_host(&self) -> (Program, TestHost) {
442        let (program, kernel) = self.compile().expect("Failed to compile test source.");
443        let mut host = TestHost::new(MemAdviceProvider::from(self.advice_inputs.clone()));
444        if let Some(kernel) = kernel {
445            host.load_mast_forest(kernel.mast_forest().clone()).unwrap();
446        }
447        for library in &self.libraries {
448            host.load_mast_forest(library.mast_forest().clone()).unwrap();
449        }
450
451        (program, host)
452    }
453
454    /// Runs the program on the fast processor, and asserts that the stack outputs match the slow
455    /// processor's stack outputs.
456    fn assert_outputs_with_fast_processor(&self, slow_stack_outputs: StackOutputs) {
457        let (program, mut host) = self.get_program_and_host();
458        let stack_inputs: Vec<Felt> = self.stack_inputs.clone().into_iter().rev().collect();
459        let fast_process = FastProcessor::new(&stack_inputs);
460        let fast_stack_outputs = fast_process.execute(&program, &mut host).unwrap();
461
462        assert_eq!(
463            slow_stack_outputs, fast_stack_outputs,
464            "stack outputs do not match between slow and fast processors"
465        );
466    }
467}
468
469// HELPER FUNCTIONS
470// ================================================================================================
471
472/// Converts a slice of Felts into a vector of u64 values.
473pub fn felt_slice_to_ints(values: &[Felt]) -> Vec<u64> {
474    values.iter().map(|e| (*e).as_int()).collect()
475}
476
477pub fn resize_to_min_stack_depth(values: &[u64]) -> Vec<u64> {
478    let mut result: Vec<u64> = values.to_vec();
479    result.resize(MIN_STACK_DEPTH, 0);
480    result
481}
482
483/// A proptest strategy for generating a random word with 4 values of type T.
484#[cfg(not(target_family = "wasm"))]
485pub fn prop_randw<T: Arbitrary>() -> impl Strategy<Value = Vec<T>> {
486    use proptest::prelude::{any, prop};
487    prop::collection::vec(any::<T>(), 4)
488}
489
490/// Given a hasher state, perform one permutation.
491///
492/// The values of `values` should be:
493/// - 0..4 the capacity
494/// - 4..12 the rate
495///
496/// Return the result of the permutation in stack order.
497pub fn build_expected_perm(values: &[u64]) -> [Felt; STATE_WIDTH] {
498    let mut expected = [ZERO; STATE_WIDTH];
499    for (&value, result) in values.iter().zip(expected.iter_mut()) {
500        *result = Felt::new(value);
501    }
502    apply_permutation(&mut expected);
503    expected.reverse();
504
505    expected
506}
507
508pub fn build_expected_hash(values: &[u64]) -> [Felt; 4] {
509    let digest = hash_elements(&values.iter().map(|&v| Felt::new(v)).collect::<Vec<_>>());
510    let mut expected: [Felt; 4] = digest.into();
511    expected.reverse();
512
513    expected
514}
515
/// Generates the MASM code which pushes the input values during the execution of the program.
///
/// Emits one `push.<value>` line per input, in the given order, each terminated by a newline.
/// An empty slice yields an empty string.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub fn push_inputs(inputs: &[u64]) -> String {
    use core::fmt::Write;

    let mut result = String::new();
    for value in inputs {
        // Format straight into the buffer instead of allocating a temporary `String` per value.
        writeln!(result, "push.{value}").expect("writing to a String cannot fail");
    }
    result
}