use super::{
super::StackTopState, Felt, OverflowTableRow, Stack, StackInputs, ONE, STACK_TOP_SIZE, ZERO,
};
use alloc::vec::Vec;
use miden_air::trace::{
stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, NUM_STACK_HELPER_COLS},
STACK_TRACE_WIDTH,
};
use vm_core::{FieldElement, StarkField};
/// Values of the stack helper columns for a single row, in the order `read_helpers` reads them
/// from the trace: `[b0, b1, h0]`.
type StackHelpersState = [Felt; NUM_STACK_HELPER_COLS];
/// A stack built from fewer than 16 inputs holds them in reverse order in its top (the rest
/// zero-filled) and reports depth 16 with no overflow.
#[test]
fn initialize() {
    let inputs = [1, 2, 3, 4];
    let stack_inputs = StackInputs::try_from_ints(inputs).unwrap();
    let stack = Stack::new(&stack_inputs, 4, false);

    // the last input ends up on top of the stack
    let reversed: Vec<u64> = inputs.iter().rev().copied().collect();

    assert_eq!(stack.trace_state(), build_stack(&reversed));
    assert_eq!(stack.helpers_state(), [Felt::new(STACK_TOP_SIZE as u64), ZERO, ZERO]);
}
/// A stack built from 19 inputs keeps 16 of them in its top and spills the remaining 3 into
/// the overflow table during initialization.
#[test]
fn initialize_overflow() {
    let inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19];
    let stack_inputs = StackInputs::try_from_ints(inputs).unwrap();
    let stack = Stack::new(&stack_inputs, 4, false);

    // The stack top holds the last 16 inputs in reverse order; inputs 1, 2, 3 go to overflow.
    let mut reversed = inputs;
    reversed.reverse();
    let depth = reversed.len() as u64;
    let num_overflow = depth - STACK_TOP_SIZE as u64;

    // Initial overflow rows are addressed upward from MODULUS - 3; each row records the
    // address of the row before it (ZERO for the first one).
    let base_addr = Felt::MODULUS - 3;
    let expected_rows = vec![
        OverflowTableRow::new(Felt::new(base_addr), ONE, ZERO),
        OverflowTableRow::new(Felt::new(base_addr + 1), Felt::new(2), Felt::new(base_addr)),
        OverflowTableRow::new(Felt::new(base_addr + 2), Felt::new(3), Felt::new(base_addr + 1)),
    ];

    assert_eq!(stack.trace_state(), build_stack(&reversed[..STACK_TOP_SIZE]));
    // b1 is the address of the newest overflow row: MODULUS - 1, which is -1 in the field
    assert_eq!(stack.helpers_state(), [Felt::new(depth), -ONE, Felt::new(num_overflow)]);
    assert_eq!(stack.overflow.active_rows(), vec![0, 1, 2]);
    assert_eq!(stack.overflow.all_rows(), expected_rows);
}
/// Exercises `shift_left` at the minimum depth and with items in the overflow table, checking
/// the stack top and the helper values after every shift.
#[test]
fn shift_left() {
    let inputs = StackInputs::try_from_ints([1, 2, 3, 4]).unwrap();

    // Case 1: at depth 16 the top item is dropped and the depth does not change.
    let mut stack = Stack::new(&inputs, 4, false);
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(stack.trace_state(), build_stack(&[3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // Case 2: on a fresh stack, push two items into the overflow table first.
    let mut stack = Stack::new(&inputs, 4, false);
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_right(0);
    // the first overflow entry is addressed by the clock cycle of the shift that created it
    let first_overflow_addr = stack.current_clk() as usize;
    stack.advance_clock();
    stack.shift_right(0);
    stack.advance_clock();

    // NOTE(review): presumably grows the trace past the initial capacity passed to
    // `Stack::new` before more rows are written — confirm against `ensure_trace_capacity`.
    stack.ensure_trace_capacity();
    stack.shift_left(1);
    stack.advance_clock();
    // one overflow item remains, so the next overflow address falls back to the first entry
    assert_eq!(stack.trace_state(), build_stack(&[0, 4, 3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, first_overflow_addr));

    // popping the last overflow item returns the stack to depth 16
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(stack.trace_state(), build_stack(&[4, 3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));
}
/// Pushes two items onto a stack at minimum depth. Each push must spill one item into the
/// overflow table, and the new entry is addressed by the clock cycle of that shift.
#[test]
fn shift_right() {
    let inputs = StackInputs::try_from_ints([1, 2, 3, 4]).unwrap();
    let mut stack = Stack::new(&inputs, 4, false);
    stack.copy_state(0);
    stack.advance_clock();

    // (number of overflow items after the shift, expected stack top after the shift)
    let steps: [(usize, &[u64]); 2] = [(1, &[0, 4, 3, 2, 1]), (2, &[0, 0, 4, 3, 2, 1])];

    for &(num_overflow, expected_top) in steps.iter() {
        // captured before the shift: the clock cycle of the shift becomes the overflow address
        let overflow_addr = stack.current_clk() as usize;
        stack.shift_right(0);
        stack.advance_clock();
        assert_eq!(stack.trace_state(), build_stack(expected_top));
        assert_eq!(stack.helpers_state(), build_helpers_partial(num_overflow, overflow_addr));
    }
}
/// Checks stack depth and helper values across `start_context` / `restore_context`: first for
/// a caller at the minimum depth, then for a caller whose overflow table must be preserved
/// while the new context runs.
#[test]
fn start_restore_context() {
    // --- context switch while the caller is at depth 16 -----------------------------------
    let stack_init = (0..16).map(|v| v as u64 + 1);
    let stack = StackInputs::try_from_ints(stack_init).unwrap();
    let mut stack = Stack::new(&stack, 8, false);

    stack.copy_state(0);
    stack.advance_clock();

    // enter a new context
    stack.start_context();
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    // a left shift at depth 16 does not reduce the depth
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());

    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    // return to the caller, which was at depth 16 with nothing in the overflow table
    stack.restore_context(16, ZERO);
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    // --- context switch while the caller has one item in overflow ---------------------------
    let stack_init = (0..16).map(|v| v as u64 + 1);
    let stack = StackInputs::try_from_ints(stack_init.clone()).unwrap();
    let mut stack = Stack::new(&stack, 8, false);
    // model of the full stack contents, top first
    let mut stack_state = stack_init.collect::<Vec<_>>();
    stack_state.reverse();

    stack.copy_state(0);
    stack.advance_clock();

    // clk = 1: push one item so the caller's overflow table holds an entry at address 1
    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());
    stack_state.insert(0, 0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, 1));

    // enter a new context. The returned values are what must be handed back to
    // `restore_context`; pin them so the restore below is checked against known values.
    let (ctx0_depth, ctx0_next_overflow_addr) = stack.start_context();
    assert_eq!(17, ctx0_depth);
    assert_eq!(ONE, ctx0_next_overflow_addr);
    stack.copy_state(0);
    stack.advance_clock();
    // the new context starts at depth 16 and does not see the caller's overflow entry
    assert_eq!(16, stack.depth());
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // clk = 3: grow the new context; its overflow entry is addressed by clk 3
    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());
    stack_state.insert(0, 0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, 3));

    // shrink the new context back to depth 16
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());
    stack_state.remove(0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // restore the caller using exactly the values `start_context` returned
    stack.restore_context(ctx0_depth, ctx0_next_overflow_addr);
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(ctx0_depth, stack.depth());
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(
        stack.helpers_state(),
        build_helpers_partial(ctx0_depth - 16, ctx0_next_overflow_addr.as_int() as usize)
    );

    // the caller's overflow entry is still reachable and can be popped
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());
    stack_state.remove(0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));
}
/// Runs the stack through shifts that cross a context boundary, then checks every row of the
/// finalized trace: the 16 stack-top columns and the b0/b1/h0 helper columns.
#[test]
fn generate_trace() {
    let stack_inputs = [1, 2, 3, 4];
    let stack_inputs = StackInputs::try_from_ints(stack_inputs).unwrap();
    let mut stack = Stack::new(&stack_inputs, 16, false);

    // clk = 0
    stack.copy_state(0);
    stack.advance_clock();

    // clk = 1, 2: two right shifts push two items into the overflow table
    stack.shift_right(0);
    stack.advance_clock();
    stack.shift_right(0);
    stack.advance_clock();

    // clk = 3: enter a new context, keeping the values needed to restore the caller later
    let (c0_depth, c0_overflow_addr) = stack.start_context();
    stack.copy_state(0);
    stack.advance_clock();

    // clk = 4..=6: the new context grows by one item and shrinks back
    stack.shift_right(0);
    stack.advance_clock();
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_left(1);
    stack.advance_clock();

    // clk = 7: return to the caller's context
    stack.restore_context(c0_depth, c0_overflow_addr);
    stack.copy_state(0);
    stack.advance_clock();

    // clk = 8, 9: grow once more
    stack.shift_right(0);
    stack.advance_clock();
    stack.copy_state(0);
    stack.advance_clock();

    // clk = 10..=12: drain the overflow table back to depth 16
    for _ in 0..3 {
        stack.shift_left(1);
        stack.advance_clock();
    }

    let trace = stack.into_trace(16, 1);
    let trace = trace.trace;

    // Expected (stack top, depth, next overflow address) for each row. Row i + 1 holds the
    // state produced by the operation executed at clk = i.
    let expected: [(&[u64], u64, u64); 14] = [
        (&[4, 3, 2, 1], 16, 0),
        (&[4, 3, 2, 1], 16, 0),
        (&[0, 4, 3, 2, 1], 17, 1),
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        (&[0, 0, 4, 3, 2, 1], 16, 0),
        (&[0, 0, 0, 4, 3, 2, 1], 17, 4),
        (&[0, 0, 0, 4, 3, 2, 1], 17, 4),
        (&[0, 0, 4, 3, 2, 1], 16, 0),
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        (&[0, 0, 0, 4, 3, 2, 1], 19, 8),
        (&[0, 0, 0, 4, 3, 2, 1], 19, 8),
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        (&[0, 4, 3, 2, 1], 17, 1),
        (&[4, 3, 2, 1], 16, 0),
    ];

    for (row, &(top, depth, overflow_addr)) in expected.iter().enumerate() {
        assert_eq!(
            read_stack_top(&trace, row),
            build_stack(top),
            "stack top mismatch at row {}",
            row
        );
        assert_eq!(
            read_helpers(&trace, row),
            build_helpers(depth, overflow_addr),
            "helper columns mismatch at row {}",
            row
        );
    }
}
/// Builds an expected stack top from `stack_inputs` (top first) and fills the remaining
/// positions with ZERO. Panics if more than `STACK_TOP_SIZE` values are given.
fn build_stack(stack_inputs: &[u64]) -> StackTopState {
    let mut top = [ZERO; STACK_TOP_SIZE];
    stack_inputs
        .iter()
        .enumerate()
        .for_each(|(pos, &value)| top[pos] = Felt::new(value));
    top
}
/// Builds the expected helper columns of a finalized trace row: `[b0, b1, h0]` is the stack
/// depth, the next overflow address, and the field inverse of `depth - STACK_TOP_SIZE`.
fn build_helpers(stack_depth: u64, next_overflow_addr: u64) -> StackHelpersState {
    let overflow_count = Felt::new(stack_depth) - Felt::new(STACK_TOP_SIZE as u64);
    [Felt::new(stack_depth), Felt::new(next_overflow_addr), overflow_count.inv()]
}
/// Builds the helper values the tests compare against `Stack::helpers_state()` for a stack
/// with `num_overflow` items beyond its top. Unlike `build_helpers`, h0 here is the raw
/// `depth - STACK_TOP_SIZE`, not its inverse.
fn build_helpers_partial(num_overflow: usize, next_overflow_addr: usize) -> StackHelpersState {
    let stack_depth = Felt::new((STACK_TOP_SIZE + num_overflow) as u64);
    let overflow_addr = Felt::new(next_overflow_addr as u64);
    [stack_depth, overflow_addr, stack_depth - Felt::new(STACK_TOP_SIZE as u64)]
}
/// Reads the stack-top values of `row` from the first `STACK_TOP_SIZE` trace columns.
fn read_stack_top(trace: &[Vec<Felt>; STACK_TRACE_WIDTH], row: usize) -> StackTopState {
    let mut top = [ZERO; STACK_TOP_SIZE];
    for (pos, column) in trace.iter().take(STACK_TOP_SIZE).enumerate() {
        top[pos] = column[row];
    }
    top
}
/// Reads the b0, b1 and h0 helper-column values of `row` from the trace.
fn read_helpers(trace: &[Vec<Felt>; STACK_TRACE_WIDTH], row: usize) -> StackHelpersState {
    let cell = |col: usize| trace[col][row];
    [cell(B0_COL_IDX), cell(B1_COL_IDX), cell(H0_COL_IDX)]
}