triton_vm/
aet.rs

1use std::ops::AddAssign;
2
3use air::AIR;
4use air::table::TableId;
5use air::table::hash::HashTable;
6use air::table::hash::PermutationTrace;
7use air::table::op_stack;
8use air::table::processor;
9use air::table::ram;
10use air::table_column::HashMainColumn::CI;
11use air::table_column::MasterMainColumn;
12use arbitrary::Arbitrary;
13use indexmap::IndexMap;
14use indexmap::map::Entry::Occupied;
15use indexmap::map::Entry::Vacant;
16use isa::error::InstructionError;
17use isa::error::InstructionError::InstructionPointerOverflow;
18use isa::instruction::Instruction;
19use isa::program::Program;
20use itertools::Itertools;
21use ndarray::Array2;
22use ndarray::s;
23use strum::EnumCount;
24use strum::IntoEnumIterator;
25use twenty_first::prelude::*;
26
27use crate::ndarray_helper::ROW_AXIS;
28use crate::table;
29use crate::table::op_stack::OpStackTableEntry;
30use crate::table::ram::RamTableCall;
31use crate::table::u32::U32TableEntry;
32use crate::vm::CoProcessorCall;
33use crate::vm::VMState;
34
/// An Algebraic Execution Trace (AET) is the primary witness required for proof
/// generation. It holds every intermediate state of the processor and all
/// co-processors, alongside additional witness information, such as the number
/// of times each instruction has been looked up (equivalently, how often each
/// instruction has been executed).
#[derive(Debug, Clone)]
pub struct AlgebraicExecutionTrace {
    /// The program that was executed in order to generate the trace.
    pub program: Program,

    /// The number of times each instruction has been executed.
    ///
    /// Each instruction in the `program` has one associated entry in
    /// `instruction_multiplicities`, counting the number of times this
    /// specific instruction at that location in the program memory has been
    /// executed.
    pub instruction_multiplicities: Vec<u32>,

    /// Records the state of the processor after each instruction.
    ///
    /// The number of rows determines the height of both the Processor Table
    /// and the Jump Stack Table.
    pub processor_trace: Array2<BFieldElement>,

    /// Holds one row for every op stack entry recorded as a co-processor call.
    ///
    /// The number of rows determines the height of the Op Stack Table.
    pub op_stack_underflow_trace: Array2<BFieldElement>,

    /// Holds one row for every RAM call recorded as a co-processor call.
    ///
    /// The number of rows determines the height of the RAM Table.
    pub ram_trace: Array2<BFieldElement>,

    /// The trace of hashing the program whose execution generated this
    /// `AlgebraicExecutionTrace`. The resulting digest
    /// 1. ties a [`Proof`](crate::proof::Proof) to the program it was produced
    ///    from, and
    /// 1. is accessible to the program being executed.
    pub program_hash_trace: Array2<BFieldElement>,

    /// For the `hash` instruction, the hash trace records the internal state of
    /// the Tip5 permutation for each round.
    pub hash_trace: Array2<BFieldElement>,

    /// For the Sponge instructions, i.e., `sponge_init`, `sponge_absorb`,
    /// `sponge_absorb_mem`, and `sponge_squeeze`, the Sponge trace records the
    /// internal state of the Tip5 permutation for each round.
    pub sponge_trace: Array2<BFieldElement>,

    /// The u32 entries hold all pairs of BFieldElements that were written to
    /// the U32 Table, alongside the u32 instruction that was executed at
    /// the time. Additionally, it records how often the instruction was
    /// executed with these arguments.
    //
    // `IndexMap` over `HashMap` for deterministic iteration order. This is not
    // needed for correctness of the STARK.
    pub u32_entries: IndexMap<U32TableEntry, u64>,

    /// Records how often each entry in the cascade table was looked up.
    ///
    /// Keys are 16-bit limbs; only limbs that were looked up at least once
    /// have an entry.
    //
    // `IndexMap` over `HashMap` for the same reasons as for field `u32_entries`
    pub cascade_table_lookup_multiplicities: IndexMap<u16, u64>,

    /// Records how often each entry in the lookup table was looked up.
    ///
    /// Indexed by the 8-bit value being looked up.
    pub lookup_table_lookup_multiplicities: [u64; AlgebraicExecutionTrace::LOOKUP_TABLE_HEIGHT],
}
93
/// The height of one specific table, together with the table's identifier.
///
/// Note that the ordering of `TableHeight`s takes only the `height` into
/// account, whereas equality also compares the `table`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct TableHeight {
    /// The table the `height` belongs to.
    pub table: TableId,

    /// The number of rows of the `table`.
    pub height: usize,
}
99
100impl AlgebraicExecutionTrace {
101    pub(crate) const LOOKUP_TABLE_HEIGHT: usize = 1 << 8;
102
103    pub fn new(program: Program) -> Self {
104        const PROCESSOR_WIDTH: usize = <processor::ProcessorTable as AIR>::MainColumn::COUNT;
105        const OP_STACK_WIDTH: usize = <op_stack::OpStackTable as AIR>::MainColumn::COUNT;
106        const RAM_WIDTH: usize = <ram::RamTable as AIR>::MainColumn::COUNT;
107        const HASH_WIDTH: usize = <HashTable as AIR>::MainColumn::COUNT;
108
109        let program_len = program.len_bwords();
110
111        let mut aet = Self {
112            program,
113            instruction_multiplicities: vec![0_u32; program_len],
114            processor_trace: Array2::default([0, PROCESSOR_WIDTH]),
115            op_stack_underflow_trace: Array2::default([0, OP_STACK_WIDTH]),
116            ram_trace: Array2::default([0, RAM_WIDTH]),
117            program_hash_trace: Array2::default([0, HASH_WIDTH]),
118            hash_trace: Array2::default([0, HASH_WIDTH]),
119            sponge_trace: Array2::default([0, HASH_WIDTH]),
120            u32_entries: IndexMap::new(),
121            cascade_table_lookup_multiplicities: IndexMap::new(),
122            lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT],
123        };
124        aet.fill_program_hash_trace();
125        aet
126    }
127
128    /// The height of the [AET](AlgebraicExecutionTrace) after [padding][pad].
129    ///
130    /// Guaranteed to be a power of two.
131    ///
132    /// [pad]: table::master_table::MasterMainTable::pad
133    pub fn padded_height(&self) -> usize {
134        self.height().height.next_power_of_two()
135    }
136
137    /// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad].
138    /// Corresponds to the height of the longest table.
139    ///
140    /// [pad]: table::master_table::MasterMainTable::pad
141    pub fn height(&self) -> TableHeight {
142        TableId::iter()
143            .map(|t| TableHeight::new(t, self.height_of_table(t)))
144            .max()
145            .unwrap()
146    }
147
148    pub fn height_of_table(&self, table: TableId) -> usize {
149        let hash_table_height = || {
150            self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
151        };
152
153        match table {
154            TableId::Program => Self::padded_program_length(&self.program),
155            TableId::Processor => self.processor_trace.nrows(),
156            TableId::OpStack => self.op_stack_underflow_trace.nrows(),
157            TableId::Ram => self.ram_trace.nrows(),
158            TableId::JumpStack => self.processor_trace.nrows(),
159            TableId::Hash => hash_table_height(),
160            TableId::Cascade => self.cascade_table_lookup_multiplicities.len(),
161            TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT,
162            TableId::U32 => self.u32_table_height(),
163        }
164    }
165
166    /// # Panics
167    ///
168    /// - if the table height exceeds [`u32::MAX`]
169    /// - if the table height exceeds [`usize::MAX`]
170    fn u32_table_height(&self) -> usize {
171        let entry_len = U32TableEntry::table_height_contribution;
172        let height = self.u32_entries.keys().map(entry_len).sum::<u32>();
173        height.try_into().unwrap()
174    }
175
176    fn padded_program_length(program: &Program) -> usize {
177        // Padding is at least one 1.
178        // Also note that the Program Table's side of the instruction lookup
179        // argument requires at least one padding row to account for the
180        // processor's “next instruction or argument.” Both of these are
181        // captured by the “+ 1” in the following line.
182        (program.len_bwords() + 1).next_multiple_of(Tip5::RATE)
183    }
184
185    /// Hash the program and record the entire Sponge's trace for program
186    /// attestation.
187    fn fill_program_hash_trace(&mut self) {
188        let padded_program = Self::hash_input_pad_program(&self.program);
189        let mut program_sponge = Tip5::init();
190        for chunk_to_absorb in padded_program.chunks(Tip5::RATE) {
191            program_sponge.state[..Tip5::RATE]
192                .iter_mut()
193                .zip_eq(chunk_to_absorb)
194                .for_each(|(sponge_state_elem, &absorb_elem)| *sponge_state_elem = absorb_elem);
195            let hash_trace = program_sponge.trace();
196            let trace_addendum = table::hash::trace_to_table_rows(hash_trace);
197
198            self.increase_lookup_multiplicities(hash_trace);
199            self.program_hash_trace
200                .append(ROW_AXIS, trace_addendum.view())
201                .expect("shapes must be identical");
202        }
203
204        let instruction_column_index = CI.main_index();
205        let mut instruction_column = self.program_hash_trace.column_mut(instruction_column_index);
206        instruction_column.fill(Instruction::Hash.opcode_b());
207
208        // consistency check
209        let program_digest = program_sponge.state[..Digest::LEN].try_into().unwrap();
210        let program_digest = Digest::new(program_digest);
211        let expected_digest = self.program.hash();
212        assert_eq!(expected_digest, program_digest);
213    }
214
215    fn hash_input_pad_program(program: &Program) -> Vec<BFieldElement> {
216        let padded_program_length = Self::padded_program_length(program);
217
218        // padding is one 1, then as many zeros as necessary: [1, 0, 0, …]
219        let program_iter = program.to_bwords().into_iter();
220        let one = bfe_array![1];
221        let zeros = bfe_array![0; tip5::RATE];
222        program_iter
223            .chain(one)
224            .chain(zeros)
225            .take(padded_program_length)
226            .collect()
227    }
228
229    pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> {
230        self.record_instruction_lookup(state.instruction_pointer)?;
231        self.append_state_to_processor_trace(state);
232        Ok(())
233    }
234
235    fn record_instruction_lookup(
236        &mut self,
237        instruction_pointer: usize,
238    ) -> Result<(), InstructionError> {
239        if instruction_pointer >= self.instruction_multiplicities.len() {
240            return Err(InstructionPointerOverflow);
241        }
242        self.instruction_multiplicities[instruction_pointer] += 1;
243        Ok(())
244    }
245
246    fn append_state_to_processor_trace(&mut self, state: &VMState) {
247        self.processor_trace
248            .push_row(state.to_processor_row().view())
249            .unwrap()
250    }
251
252    pub(crate) fn record_co_processor_call(&mut self, co_processor_call: CoProcessorCall) {
253        match co_processor_call {
254            CoProcessorCall::Tip5Trace(Instruction::Hash, trace) => self.append_hash_trace(*trace),
255            CoProcessorCall::SpongeStateReset => self.append_initial_sponge_state(),
256            CoProcessorCall::Tip5Trace(instruction, trace) => {
257                self.append_sponge_trace(instruction, *trace)
258            }
259            CoProcessorCall::U32(u32_entry) => self.record_u32_table_entry(u32_entry),
260            CoProcessorCall::OpStack(op_stack_entry) => self.record_op_stack_entry(op_stack_entry),
261            CoProcessorCall::Ram(ram_call) => self.record_ram_call(ram_call),
262        }
263    }
264
265    fn append_hash_trace(&mut self, trace: PermutationTrace) {
266        self.increase_lookup_multiplicities(trace);
267        let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace);
268        hash_trace_addendum
269            .slice_mut(s![.., CI.main_index()])
270            .fill(Instruction::Hash.opcode_b());
271        self.hash_trace
272            .append(ROW_AXIS, hash_trace_addendum.view())
273            .expect("shapes must be identical");
274    }
275
276    fn append_initial_sponge_state(&mut self) {
277        let round_number = 0;
278        let initial_state = Tip5::init().state;
279        let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number);
280        hash_table_row[CI.main_index()] = Instruction::SpongeInit.opcode_b();
281        self.sponge_trace.push_row(hash_table_row.view()).unwrap();
282    }
283
284    fn append_sponge_trace(&mut self, instruction: Instruction, trace: PermutationTrace) {
285        assert!(matches!(
286            instruction,
287            Instruction::SpongeAbsorb | Instruction::SpongeSqueeze
288        ));
289        self.increase_lookup_multiplicities(trace);
290        let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace);
291        sponge_trace_addendum
292            .slice_mut(s![.., CI.main_index()])
293            .fill(instruction.opcode_b());
294        self.sponge_trace
295            .append(ROW_AXIS, sponge_trace_addendum.view())
296            .expect("shapes must be identical");
297    }
298
299    /// Given a trace of the hash function's permutation, determines how often
300    /// each entry in the
301    /// - cascade table was looked up, and
302    /// - lookup table was looked up;
303    ///
304    /// and increases the multiplicities accordingly
305    fn increase_lookup_multiplicities(&mut self, trace: PermutationTrace) {
306        // The last row in the trace is the permutation's result: no lookups are
307        // performed for it.
308        let rows_for_which_lookups_are_performed = trace.iter().dropping_back(1);
309        for row in rows_for_which_lookups_are_performed {
310            self.increase_lookup_multiplicities_for_row(row);
311        }
312    }
313
314    /// Given one row of the hash function's permutation trace, increase the
315    /// multiplicities of the relevant entries in the cascade table and/or
316    /// the lookup table.
317    fn increase_lookup_multiplicities_for_row(&mut self, row: &[BFieldElement; tip5::STATE_SIZE]) {
318        for &state_element in &row[0..tip5::NUM_SPLIT_AND_LOOKUP] {
319            self.increase_lookup_multiplicities_for_state_element(state_element);
320        }
321    }
322
323    /// Given one state element, increase the multiplicities of the
324    /// corresponding entries in the cascade table and/or the lookup table.
325    fn increase_lookup_multiplicities_for_state_element(&mut self, state_element: BFieldElement) {
326        for limb in table::hash::base_field_element_into_16_bit_limbs(state_element) {
327            match self.cascade_table_lookup_multiplicities.entry(limb) {
328                Occupied(mut cascade_table_entry) => *cascade_table_entry.get_mut() += 1,
329                Vacant(cascade_table_entry) => {
330                    cascade_table_entry.insert(1);
331                    self.increase_lookup_table_multiplicities_for_limb(limb);
332                }
333            }
334        }
335    }
336
337    /// Given one 16-bit limb, increase the multiplicities of the corresponding
338    /// entries in the lookup table.
339    fn increase_lookup_table_multiplicities_for_limb(&mut self, limb: u16) {
340        let limb_lo = limb & 0xff;
341        let limb_hi = (limb >> 8) & 0xff;
342        self.lookup_table_lookup_multiplicities[limb_lo as usize] += 1;
343        self.lookup_table_lookup_multiplicities[limb_hi as usize] += 1;
344    }
345
346    fn record_u32_table_entry(&mut self, u32_entry: U32TableEntry) {
347        self.u32_entries.entry(u32_entry).or_insert(0).add_assign(1)
348    }
349
350    fn record_op_stack_entry(&mut self, op_stack_entry: OpStackTableEntry) {
351        let op_stack_table_row = op_stack_entry.to_main_table_row();
352        self.op_stack_underflow_trace
353            .push_row(op_stack_table_row.view())
354            .unwrap();
355    }
356
357    fn record_ram_call(&mut self, ram_call: RamTableCall) {
358        self.ram_trace
359            .push_row(ram_call.to_table_row().view())
360            .unwrap();
361    }
362}
363
364impl TableHeight {
365    fn new(table: TableId, height: usize) -> Self {
366        Self { table, height }
367    }
368}
369
370impl PartialOrd for TableHeight {
371    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
372        Some(self.cmp(other))
373    }
374}
375
376impl Ord for TableHeight {
377    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
378        self.height.cmp(&other.height)
379    }
380}
381
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use isa::triton_asm;
    use isa::triton_program;

    use super::*;
    use crate::prelude::*;

    /// A program whose length is one short of a multiple of the hash
    /// function's rate is padded with the mandatory 1 and no zeros at all.
    #[test]
    fn pad_program_requiring_no_padding_zeros() {
        let eight_nops = triton_asm![nop; 8];
        let program = triton_program!({&eight_nops} halt);

        let mut expected = program.to_bwords();
        expected.extend(bfe_vec![1]);

        let padded_program = AlgebraicExecutionTrace::hash_input_pad_program(&program);
        assert!(expected == padded_program);
    }

    /// Computing the height of the trace and of each individual table must not
    /// panic.
    #[test]
    fn height_of_any_table_can_be_computed() {
        let program = triton_program!(halt);
        let public_input = PublicInput::default();
        let non_determinism = NonDeterminism::default();
        let (aet, _) = VM::trace_execution(program, public_input, non_determinism).unwrap();

        let _ = aet.height();
        TableId::iter().for_each(|table| {
            let _ = aet.height_of_table(table);
        });
    }
}