1use std::collections::HashMap;
2use std::fmt::Display;
3
4use itertools::Itertools;
5use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
6use triton_vm::prelude::*;
7
8use crate::InitVmState;
9use crate::RustShadowOutputState;
10use crate::dyn_malloc::DYN_MALLOC_ADDRESS;
11use crate::execute_test;
12use crate::execute_with_terminal_state;
13use crate::prelude::Tip5;
14use crate::traits::basic_snippet::SignedOffSnippet;
15use crate::traits::rust_shadow::RustShadow;
16use crate::traits::rust_shadow::RustShadowError;
17
18pub fn rust_final_state<T: RustShadow>(
19 shadowed_snippet: &T,
20 stack: &[BFieldElement],
21 stdin: &[BFieldElement],
22 nondeterminism: &NonDeterminism,
23 sponge: &Option<Tip5>,
24) -> RustShadowOutputState {
25 let mut rust_memory = nondeterminism.ram.clone();
26 let mut rust_stack = stack.to_vec();
27 let mut rust_sponge = sponge.clone();
28
29 let output = shadowed_snippet
31 .rust_shadow_wrapper(
32 stdin,
33 nondeterminism,
34 &mut rust_stack,
35 &mut rust_memory,
36 &mut rust_sponge,
37 )
38 .unwrap();
39
40 RustShadowOutputState {
41 public_output: output,
42 stack: rust_stack,
43 ram: rust_memory,
44 sponge: rust_sponge,
45 }
46}
47
48pub fn tasm_final_state<T: RustShadow>(
49 shadowed_snippet: &T,
50 stack: &[BFieldElement],
51 stdin: &[BFieldElement],
52 nondeterminism: NonDeterminism,
53 sponge: &Option<Tip5>,
54) -> Result<VMState, RustShadowError> {
55 link_and_run_tasm_for_test(
57 shadowed_snippet,
58 &mut stack.to_vec(),
59 stdin.to_vec(),
60 nondeterminism,
61 sponge.to_owned(),
62 )
63}
64
65pub fn verify_stack_equivalence(
67 stack_a_name: &str,
68 stack_a: &[BFieldElement],
69 stack_b_name: &str,
70 stack_b: &[BFieldElement],
71) {
72 let stack_a_name = format!("{stack_a_name}:");
73 let stack_b_name = format!("{stack_b_name}:");
74 let max_stack_name_len = stack_a_name.len().max(stack_b_name.len());
75
76 let stack_a = &stack_a[Digest::LEN..];
77 let stack_b = &stack_b[Digest::LEN..];
78 let display = |stack: &[BFieldElement]| stack.iter().join(",");
79 assert_eq!(
80 stack_a,
81 stack_b,
82 "{stack_a_name} stack must match {stack_b_name} stack\n\n\
83 {stack_a_name:<max_stack_name_len$} {}\n\n\
84 {stack_b_name:<max_stack_name_len$} {}",
85 display(stack_a),
86 display(stack_b),
87 );
88}
89
90pub(crate) fn verify_memory_equivalence(
92 a_name: &str,
93 a_memory: &HashMap<BFieldElement, BFieldElement>,
94 b_name: &str,
95 b_memory: &HashMap<BFieldElement, BFieldElement>,
96) {
97 let memory_without_dyn_malloc = |mem: HashMap<_, _>| -> HashMap<_, _> {
98 mem.into_iter()
99 .filter(|&(k, _)| k != DYN_MALLOC_ADDRESS)
100 .collect()
101 };
102 let a_memory = memory_without_dyn_malloc(a_memory.clone());
103 let b_memory = memory_without_dyn_malloc(b_memory.clone());
104 if a_memory == b_memory {
105 return;
106 }
107
108 fn format_hash_map_iterator<K, V>(map: impl Iterator<Item = (K, V)>) -> String
109 where
110 u64: From<K>,
111 K: Copy + Display,
112 V: Display,
113 {
114 map.sorted_by_key(|(k, _)| u64::from(*k))
115 .map(|(k, v)| format!("({k} => {v})"))
116 .join(", ")
117 }
118
119 let in_a_and_different_in_b = a_memory
120 .iter()
121 .filter(|&(k, v)| b_memory.get(k).map(|b| b != v).unwrap_or(true));
122 let in_b_and_different_in_a = b_memory
123 .iter()
124 .filter(|&(k, v)| a_memory.get(k).map(|b| b != v).unwrap_or(true));
125
126 let in_a_and_different_in_b = format_hash_map_iterator(in_a_and_different_in_b);
127 let in_b_and_different_in_a = format_hash_map_iterator(in_b_and_different_in_a);
128
129 panic!(
130 "Memory for both implementations must match after execution.\n\n\
131 In {b_name}, different in {a_name}: {in_b_and_different_in_a}\n\n\
132 In {a_name}, different in {b_name}: {in_a_and_different_in_b}"
133 );
134}
135
136pub fn verify_stack_growth<T: RustShadow>(
137 shadowed_snippet: &T,
138 initial_stack: &[BFieldElement],
139 final_stack: &[BFieldElement],
140) {
141 let observed_stack_growth: isize = final_stack.len() as isize - initial_stack.len() as isize;
142 let expected_stack_growth: isize = shadowed_snippet.inner().stack_diff();
143 assert_eq!(
144 expected_stack_growth,
145 observed_stack_growth,
146 "Stack must pop and push expected number of elements. Got input: {}\nGot output: {}",
147 initial_stack.iter().map(|x| x.to_string()).join(","),
148 final_stack.iter().map(|x| x.to_string()).join(",")
149 );
150}
151
152pub fn verify_sponge_equivalence(a: &Option<Tip5>, b: &Option<Tip5>) {
153 match (a, b) {
154 (Some(state_a), Some(state_b)) => assert_eq!(state_a.state, state_b.state),
155 (None, None) => (),
156 _ => panic!("{a:?} != {b:?}"),
157 };
158}
159
160pub fn test_rust_equivalence_given_complete_state<T: RustShadow>(
161 shadowed_snippet: &T,
162 stack: &[BFieldElement],
163 stdin: &[BFieldElement],
164 nondeterminism: &NonDeterminism,
165 sponge: &Option<Tip5>,
166 expected_final_stack: Option<&[BFieldElement]>,
167) -> VMState {
168 shadowed_snippet
169 .inner()
170 .assert_all_sign_offs_are_up_to_date();
171
172 let init_stack = stack.to_vec();
173
174 let rust = rust_final_state(shadowed_snippet, stack, stdin, nondeterminism, sponge);
175
176 let tasm = tasm_final_state(
178 shadowed_snippet,
179 stack,
180 stdin,
181 nondeterminism.clone(),
182 sponge,
183 )
184 .unwrap();
185
186 assert_eq!(
187 rust.public_output, tasm.public_output,
188 "Rust shadowing and VM std out must agree"
189 );
190
191 verify_stack_equivalence(
192 "rust-shadow final stack",
193 &rust.stack,
194 "TASM final stack",
195 &tasm.op_stack.stack,
196 );
197 if let Some(expected) = expected_final_stack {
198 verify_stack_equivalence("expected", expected, "actual", &rust.stack);
199 }
200 verify_memory_equivalence("Rust-shadow", &rust.ram, "TVM", &tasm.ram);
201 verify_stack_growth(shadowed_snippet, &init_stack, &tasm.op_stack.stack);
202
203 tasm
204}
205
206pub fn link_and_run_tasm_for_test<T: RustShadow>(
207 snippet_struct: &T,
208 stack: &mut Vec<BFieldElement>,
209 std_in: Vec<BFieldElement>,
210 nondeterminism: NonDeterminism,
211 maybe_sponge: Option<Tip5>,
212) -> Result<VMState, RustShadowError> {
213 let code = snippet_struct.inner().link_for_isolated_run();
214
215 execute_test(
216 &code,
217 stack,
218 snippet_struct.inner().stack_diff(),
219 std_in,
220 nondeterminism,
221 maybe_sponge,
222 )
223}
224
225pub fn test_rust_equivalence_given_execution_state<S: RustShadow>(
226 snippet_struct: &S,
227 execution_state: InitVmState,
228) -> VMState {
229 test_rust_equivalence_given_complete_state::<S>(
230 snippet_struct,
231 &execution_state.stack,
232 &execution_state.public_input,
233 &execution_state.nondeterminism,
234 &execution_state.sponge,
235 None,
236 )
237}
238
239pub fn negative_test<T: RustShadow>(
240 snippet: &T,
241 initial_state: InitVmState,
242 allowed_errors: &[InstructionError],
243) {
244 let err = instruction_error_from_failing_code(snippet, initial_state);
245 assert!(
246 allowed_errors.contains(&err),
247 "Triton VM execution must fail with one of the expected errors:\n- {}\n\n Got:\n{err}",
248 allowed_errors.iter().join("\n- ")
249 );
250}
251
252pub fn test_assertion_failure<S: RustShadow>(
253 snippet_struct: &S,
254 initial_state: InitVmState,
255 expected_error_ids: &[i128],
256) {
257 let err = instruction_error_from_failing_code(snippet_struct, initial_state);
258 let maybe_error_id = match err {
259 InstructionError::AssertionFailed(err)
260 | InstructionError::VectorAssertionFailed(_, err) => err.id,
261 _ => panic!("Triton VM execution failed, but not due to an assertion. Instead, got: {err}"),
262 };
263 let error_id = maybe_error_id.expect(
264 "Triton VM execution failed due to unfulfilled assertion, but that assertion has no \
265 error ID. See `tasm-lib/src/assertion_error_ids.md` to grab a unique ID.",
266 );
267 let expected_error_ids_str = expected_error_ids.iter().join(", ");
268 assert!(
269 expected_error_ids.contains(&error_id),
270 "error ID {error_id} ∉ {{{expected_error_ids_str}}}\nTriton VM execution failed due to \
271 unfulfilled assertion with error ID {error_id}, but expected one of the following IDs: \
272 {{{expected_error_ids_str}}}",
273 );
274}
275
276fn instruction_error_from_failing_code<S: RustShadow>(
277 snippet: &S,
278 init_state: InitVmState,
279) -> InstructionError {
280 let mut rust_stack = init_state.stack.clone();
281 let mut rust_memory = init_state.nondeterminism.ram.clone();
282 let mut rust_sponge = init_state.sponge.clone();
283 let rust_result = snippet.rust_shadow_wrapper(
284 &init_state.public_input,
285 &init_state.nondeterminism,
286 &mut rust_stack,
287 &mut rust_memory,
288 &mut rust_sponge,
289 );
290 rust_result.expect_err("Failed to fail: Rust-shadowing must panic in negative test case");
291
292 let code = snippet.inner().link_for_isolated_run();
293 let tvm_result = execute_with_terminal_state(
294 Program::new(&code),
295 &init_state.public_input,
296 &init_state.stack,
297 &init_state.nondeterminism,
298 init_state.sponge,
299 );
300 tvm_result.expect_err("Failed to fail: Triton VM execution must crash in negative test case")
301}
302
303pub fn prepend_program_with_stack_setup(
304 init_stack: &[BFieldElement],
305 program: &Program,
306) -> Program {
307 let stack_initialization_code = init_stack
308 .iter()
309 .skip(NUM_OP_STACK_REGISTERS)
310 .map(|&word| triton_instr!(push word))
311 .collect_vec();
312
313 Program::new(&[stack_initialization_code, program.labelled_instructions()].concat())
314}
315
316pub fn prepend_program_with_sponge_init(program: &Program) -> Program {
317 Program::new(&[triton_asm!(sponge_init), program.labelled_instructions()].concat())
318}
319
320pub fn maybe_write_tvm_output_to_disk(stark: &Stark, claim: &Claim, proof: &Proof) {
323 use std::io::Write;
324 let Ok(_) = std::env::var("TASMLIB_STORE") else {
325 return;
326 };
327
328 let mut stark_file = std::fs::File::create("stark.json").unwrap();
329 let state = serde_json::to_string(stark).unwrap();
330 write!(stark_file, "{state}").unwrap();
331 let mut claim_file = std::fs::File::create("claim.json").unwrap();
332 let claim = serde_json::to_string(claim).unwrap();
333 write!(claim_file, "{claim}").unwrap();
334 let mut proof_file = std::fs::File::create("proof.json").unwrap();
335 let proof = serde_json::to_string(proof).unwrap();
336 write!(proof_file, "{proof}").unwrap();
337}