1#![recursion_limit = "2048"]
4
5extern crate self as tasm_lib;
14
15use std::collections::HashMap;
16use std::io::Write;
17
18use itertools::Itertools;
19use memory::dyn_malloc;
20use num_traits::Zero;
21use triton_vm::isa::op_stack::NUM_OP_STACK_REGISTERS;
22use triton_vm::prelude::*;
23use web_time::SystemTime;
24
25pub mod arithmetic;
26pub mod array;
27pub mod data_type;
28pub mod exported_snippets;
29pub mod hashing;
30pub mod io;
31pub mod library;
32pub mod linker;
33pub mod list;
34pub mod memory;
35pub mod mmr;
36pub mod neptune;
37pub mod prelude;
38pub mod rust_shadowing_helper_functions;
39pub mod snippet_bencher;
40pub mod structure;
41pub mod test_helpers;
42pub mod traits;
43pub mod verifier;
44
45pub use triton_vm;
47use triton_vm::isa::instruction::AnInstruction;
48use triton_vm::prelude::TableId;
49pub use triton_vm::twenty_first;
50
51use crate::test_helpers::prepend_program_with_stack_setup;
52use crate::traits::rust_shadow::RustShadowError;
53
/// Message for a `u32` → `usize` conversion that, per its text, can only fail on
/// targets where `usize` is narrower than 32 bits. Presumably passed to
/// `expect` at conversion sites elsewhere in the crate — confirm at call sites.
pub(crate) const U32_TO_USIZE_ERR: &str =
    "internal error: type `usize` should have at least 32 bits";
/// Message for a `usize` → `u64` conversion that, per its text, can only fail on
/// targets where `usize` is wider than 64 bits. Presumably passed to `expect`
/// at conversion sites elsewhere in the crate — confirm at call sites.
pub(crate) const USIZE_TO_U64_ERR: &str =
    "internal error: type `usize` should have at most 64 bits";
58
59#[derive(Clone, Debug, Default)]
60pub struct InitVmState {
61 pub stack: Vec<BFieldElement>,
62 pub public_input: Vec<BFieldElement>,
63 pub nondeterminism: NonDeterminism,
64 pub sponge: Option<Tip5>,
65}
66
67impl InitVmState {
68 pub fn with_stack(stack: Vec<BFieldElement>) -> Self {
69 InitVmState {
70 stack,
71 public_input: vec![],
72 nondeterminism: NonDeterminism::default(),
73 sponge: None,
74 }
75 }
76
77 pub fn with_stack_and_memory(
78 stack: Vec<BFieldElement>,
79 memory: HashMap<BFieldElement, BFieldElement>,
80 ) -> Self {
81 InitVmState {
82 stack,
83 public_input: vec![],
84 nondeterminism: NonDeterminism::default().with_ram(memory),
85 sponge: None,
86 }
87 }
88}
89
/// State produced by a Rust shadow — presumably the Rust reference
/// implementation of a snippet, against which the compiled code's terminal VM
/// state is compared (confirm against `traits::rust_shadow`).
#[derive(Clone, Debug)]
pub struct RustShadowOutputState {
    /// Elements written to public output.
    pub public_output: Vec<BFieldElement>,
    /// The op stack after the shadow ran.
    pub stack: Vec<BFieldElement>,
    /// RAM contents after the shadow ran, as an address → value map.
    pub ram: HashMap<BFieldElement, BFieldElement>,
    /// The sponge state after the shadow ran, if a sponge is in use.
    pub sponge: Option<Tip5>,
}
97
98pub fn empty_stack() -> Vec<BFieldElement> {
99 vec![BFieldElement::zero(); NUM_OP_STACK_REGISTERS]
100}
101
102pub fn push_encodable<T: BFieldCodec>(stack: &mut Vec<BFieldElement>, value: &T) {
103 stack.extend(value.encode().into_iter().rev());
104}
105
106pub fn pop_encodable<T: BFieldCodec>(stack: &mut Vec<BFieldElement>) -> Result<T, RustShadowError> {
116 let Some(len) = T::static_length() else {
117 return Err(RustShadowError::Other);
118 };
119 let limbs: Vec<_> = (0..len)
120 .map(|_| stack.pop().ok_or(RustShadowError::StackUnderflow))
121 .try_collect()?;
122
123 T::decode(&limbs)
124 .map(|t| *t)
125 .map_err(|_| RustShadowError::DecodingError)
126}
127
128pub(crate) fn execute_test(
131 code: &[LabelledInstruction],
132 stack: &mut Vec<BFieldElement>,
133 expected_stack_diff: isize,
134 std_in: Vec<BFieldElement>,
135 nondeterminism: NonDeterminism,
136 maybe_sponge: Option<Tip5>,
137) -> Result<VMState, RustShadowError> {
138 let init_stack = stack.to_owned();
139 let public_input = PublicInput::new(std_in.clone());
140 let program = Program::new(code);
141
142 let mut vm_state = VMState::new(
143 program.clone(),
144 public_input.clone(),
145 nondeterminism.clone(),
146 );
147 vm_state.op_stack.stack.clone_from(&init_stack);
148 vm_state.sponge = maybe_sponge;
149
150 maybe_write_debuggable_vm_state_to_disk(&vm_state);
151
152 if let Err(err) = vm_state.run() {
153 eprintln!("{err}\n\nFinal state was: {vm_state}");
154 return Err(RustShadowError::VmError);
155 }
156 let terminal_state = vm_state;
157
158 if !terminal_state.jump_stack.is_empty() {
159 eprintln!("Jump stack must be unchanged after code execution");
160 return Err(RustShadowError::VmError);
161 }
162
163 let final_stack_height = terminal_state.op_stack.stack.len() as isize;
164 let initial_stack_height = init_stack.len();
165 if expected_stack_diff != final_stack_height - initial_stack_height as isize {
166 eprintln!(
167 "Code must grow/shrink stack with expected number of elements.\n
168 init height: {initial_stack_height}\n
169 end height: {final_stack_height}\n
170 expected difference: {expected_stack_diff}\n\n
171 initial stack: {}\n
172 final stack: {}",
173 init_stack.iter().skip(NUM_OP_STACK_REGISTERS).join(","),
174 terminal_state
175 .op_stack
176 .stack
177 .iter()
178 .skip(NUM_OP_STACK_REGISTERS)
179 .join(","),
180 );
181 return Err(RustShadowError::VmError);
182 }
183
184 if std::env::var("DYING_TO_PROVE").is_ok() {
190 prove_and_verify(program, &std_in, &nondeterminism, Some(init_stack));
191 }
192
193 stack.clone_from(&terminal_state.op_stack.stack);
194 Ok(terminal_state)
195}
196
197pub fn maybe_write_debuggable_vm_state_to_disk(vm_state: &VMState) {
207 let Ok(_) = std::env::var("TASMLIB_TRITON_TUI") else {
208 return;
209 };
210
211 let mut state_file = std::fs::File::create("vm_state.json").unwrap();
212 let state = serde_json::to_string(&vm_state).unwrap();
213 write!(state_file, "{state}").unwrap();
214}
215
216pub(crate) fn execute_with_terminal_state(
218 program: Program,
219 std_in: &[BFieldElement],
220 stack: &[BFieldElement],
221 nondeterminism: &NonDeterminism,
222 maybe_sponge: Option<Tip5>,
223) -> Result<VMState, InstructionError> {
224 let public_input = PublicInput::new(std_in.into());
225 let mut vm_state = VMState::new(program, public_input, nondeterminism.to_owned());
226 stack.clone_into(&mut vm_state.op_stack.stack);
227 vm_state.sponge = maybe_sponge;
228
229 maybe_write_debuggable_vm_state_to_disk(&vm_state);
230 match vm_state.run() {
231 Ok(()) => {
232 println!("Triton VM execution successful.");
233 Ok(vm_state)
234 }
235 Err(err) => {
236 if let Some(ref sponge) = vm_state.sponge {
237 println!("tasm final sponge:");
238 println!("{}", sponge.state.iter().join(", "));
239 }
240 println!("Triton VM execution failed. Final state:\n{vm_state}");
241 Err(err)
242 }
243 }
244}
245
246pub fn prove_and_verify(
254 program: Program,
255 std_in: &[BFieldElement],
256 nondeterminism: &NonDeterminism,
257 init_stack: Option<Vec<BFieldElement>>,
258) {
259 let labelled_instructions = program.labelled_instructions();
260 let timing_report_label = match labelled_instructions.first().unwrap() {
261 LabelledInstruction::Instruction(AnInstruction::Call(func)) => func,
262 LabelledInstruction::Label(label) => label,
263 _ => "Some program",
264 };
265
266 let program = match &init_stack {
269 Some(init_stack) => prepend_program_with_stack_setup(init_stack, &program),
270 None => program,
271 };
272
273 let claim = Claim::about_program(&program).with_input(std_in.to_owned());
274 let (aet, public_output) = VM::trace_execution(
275 program.clone(),
276 PublicInput::new(std_in.to_owned()),
277 nondeterminism.clone(),
278 )
279 .unwrap();
280 let claim = claim.with_output(public_output);
281
282 let stark = Stark::default();
283 let tick = SystemTime::now();
284 triton_vm::profiler::start(timing_report_label);
285 let proof = stark.prove(&claim, &aet).unwrap();
286 let profile = triton_vm::profiler::finish();
287 let measured_time = tick.elapsed().expect("Don't mess with time");
288
289 let padded_height = proof.padded_height().unwrap();
290 let fri = stark.fri(padded_height).unwrap();
291 let report = profile
292 .with_cycle_count(aet.processor_trace.nrows())
293 .with_padded_height(padded_height)
294 .with_fri_domain_len(fri.domain.len());
295 println!("{report}");
296
297 println!("Done proving. Elapsed time: {measured_time:?}");
298 println!(
299 "Proof was generated from:
300 table lengths:
301 processor table: {}
302 hash table: {}
303 u32 table: {}
304 op-stack table: {}
305 RAM table: {}
306 Program table: {}
307 Cascade table: {}
308 Lookup table: {}",
309 aet.height_of_table(TableId::Processor),
310 aet.height_of_table(TableId::Hash),
311 aet.height_of_table(TableId::U32),
312 aet.height_of_table(TableId::OpStack),
313 aet.height_of_table(TableId::Ram),
314 aet.height_of_table(TableId::Program),
315 aet.height_of_table(TableId::Cascade),
316 aet.height_of_table(TableId::Lookup),
317 );
318
319 assert!(
320 triton_vm::verify(stark, &claim, &proof),
321 "Generated proof must verify for program:\n{program}",
322 );
323}
324
325pub fn generate_full_profile(
327 name: &str,
328 program: Program,
329 public_input: &PublicInput,
330 nondeterminism: &NonDeterminism,
331) -> String {
332 let (_output, profile) =
333 VM::profile(program, public_input.clone(), nondeterminism.clone()).unwrap();
334 format!("{name}:\n{profile}")
335}
336
/// Test-only collection of macros and re-exports, gathered so that test modules
/// can presumably pull them in with one import (confirm usage in test modules).
#[cfg(test)]
pub mod test_prelude {
    /// Marks `$item` as a unit test. When compiling for `wasm32`, it is
    /// additionally marked as a `wasm_bindgen_test`.
    macro_rules! test {
        ($item:item) => {
            #[test]
            #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
            $item
        };
    }
    pub(crate) use test;

    /// Marks `$item` as a property test via `test_strategy::proptest`,
    /// forwarding the optional parenthesized configuration tokens that follow
    /// the item. When compiling for `wasm32`, it is additionally marked as a
    /// `wasm_bindgen_test`.
    macro_rules! proptest {
        ($item:item $(($($config:tt)*))?) => {
            #[test_strategy::proptest $(($($config)*))?]
            #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
            $item
        };
    }
    pub(crate) use proptest;

    pub use std::collections::HashMap;

    // Third-party helpers: iterator extensions, proptest strategies and
    // assertions, and (seedable) randomness.
    pub use itertools::Itertools;
    pub use proptest::prelude::Just;
    pub use proptest::prelude::Strategy;
    pub use proptest::prelude::TestCaseError;
    pub use proptest::prelude::any;
    pub use proptest::prelude::prop_assert;
    pub use proptest::prelude::prop_assert_eq;
    pub use proptest::prelude::prop_assume;
    pub use proptest::test_runner::TestCaseResult;
    pub use proptest_arbitrary_adapter::arb;
    pub use rand::Rng;
    pub use rand::RngCore;
    pub use rand::SeedableRng;
    pub use rand::prelude::IndexedMutRandom;
    pub use rand::prelude::IndexedRandom;
    pub use rand::prelude::IteratorRandom;
    pub use rand::rngs::StdRng;
    pub use test_strategy::Arbitrary;

    // Crate-local helpers and the snippet traits with their shadowed wrappers
    // and initial-state types.
    pub use crate::InitVmState;
    pub use crate::memory::encode_to_memory;
    pub(crate) use crate::pop_encodable;
    pub use crate::push_encodable;
    pub use crate::snippet_bencher::BenchmarkCase;
    pub use crate::test_helpers::test_assertion_failure;
    pub use crate::test_helpers::test_rust_equivalence_given_complete_state;
    pub use crate::traits::accessor::Accessor;
    pub use crate::traits::accessor::AccessorInitialState;
    pub use crate::traits::accessor::ShadowedAccessor;
    pub use crate::traits::algorithm::Algorithm;
    pub use crate::traits::algorithm::AlgorithmInitialState;
    pub use crate::traits::algorithm::ShadowedAlgorithm;
    pub use crate::traits::closure::Closure;
    pub use crate::traits::closure::ShadowedClosure;
    pub use crate::traits::function::Function;
    pub use crate::traits::function::FunctionInitialState;
    pub use crate::traits::function::ShadowedFunction;
    pub use crate::traits::mem_preserver::MemPreserver;
    pub use crate::traits::mem_preserver::MemPreserverInitialState;
    pub use crate::traits::mem_preserver::ShadowedMemPreserver;
    pub use crate::traits::procedure::Procedure;
    pub use crate::traits::procedure::ProcedureInitialState;
    pub use crate::traits::procedure::ShadowedProcedure;
    pub use crate::traits::read_only_algorithm::ReadOnlyAlgorithm;
    pub use crate::traits::read_only_algorithm::ReadOnlyAlgorithmInitialState;
    pub use crate::traits::read_only_algorithm::ShadowedReadOnlyAlgorithm;
    pub use crate::traits::rust_shadow::RustShadow;
    pub use crate::traits::rust_shadow::RustShadowError;
}