use alloc::vec::Vec;
use miden_air::trace::{
STACK_TRACE_WIDTH,
stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX, NUM_STACK_HELPER_COLS},
};
use miden_core::FieldElement;
use super::*;
use crate::stack::OverflowTableRow;
/// Values of the stack helper columns for a single row, laid out as `[b0, b1, h0]` — the order
/// used by `read_helpers` and by the expected values these tests compare against
/// `Stack::helpers_state`.
type StackHelpersState = [Felt; NUM_STACK_HELPER_COLS];
/// Checks a freshly created stack: the inputs sit on top with the last input topmost, the rest is
/// ZERO-filled, and the helpers report the minimum depth with no overflow.
#[test]
fn initialize() {
    let inputs = [1, 2, 3, 4];
    let stack_inputs = StackInputs::try_from_ints(inputs).unwrap();
    let stack = Stack::new(&stack_inputs, 4, false);

    // The last input ends up on top, so the expected top is the inputs in reverse order.
    let expected_top: Vec<u64> = inputs.iter().rev().copied().collect();

    assert_eq!(stack.trace_state(), build_stack(&expected_top));
    assert_eq!(stack.helpers_state(), [Felt::new(MIN_STACK_DEPTH as u64), ZERO, ZERO]);
}
/// Pushes three values onto a stack seeded with 16 inputs and checks the visible top, the helper
/// registers, and the number of elements held in the overflow table.
#[test]
fn stack_overflow() {
    // 1..=16 seed the stack; 17, 18 and 19 are pushed afterwards.
    let all_values: Vec<u64> = (1..=19).collect();
    let stack_inputs = StackInputs::try_from_ints(all_values[..16].to_vec()).unwrap();
    let mut stack = Stack::new(&stack_inputs, 5, false);

    stack.copy_state(0);
    stack.advance_clock();
    // One push per clock cycle (cycles 1, 2 and 3).
    for pushed in [17u8, 18, 19] {
        stack.shift_right(0);
        stack.set(0, Felt::from(pushed));
        stack.advance_clock();
    }

    // The most recent push is on top, so the visible stack is the values in reverse order.
    let expected_top: Vec<u64> =
        all_values.iter().rev().take(MIN_STACK_DEPTH).copied().collect();
    let expected_depth = all_values.len() as u64;
    let num_overflowed = expected_depth - MIN_STACK_DEPTH as u64;

    // Overflow rows are keyed by the cycle of the push that created them (1, 2, 3); the last
    // argument appears to be the address of the previous row (ZERO for the oldest one).
    let first_addr = 1;
    let expected_overflow_rows = [
        OverflowTableRow::new(Felt::new(first_addr), ONE, ZERO),
        OverflowTableRow::new(Felt::new(first_addr + 1), Felt::new(2), Felt::new(first_addr)),
        OverflowTableRow::new(
            Felt::new(first_addr + 2),
            Felt::new(3),
            Felt::new(first_addr + 1),
        ),
    ];

    assert_eq!(stack.trace_state(), build_stack(&expected_top));
    assert_eq!(
        stack.helpers_state(),
        [Felt::new(expected_depth), Felt::new(3u64), Felt::new(num_overflowed)]
    );
    assert_eq!(stack.overflow.total_num_elements(), expected_overflow_rows.len());
}
/// Checks left shifts both at minimum depth and when elements must come back from the overflow
/// table.
#[test]
fn shift_left() {
    let inputs = StackInputs::try_from_ints([1, 2, 3, 4]).unwrap();

    // At minimum depth: the top is dropped, the bottom reads ZERO and the helpers stay at rest.
    let mut stack = Stack::new(&inputs, 4, false);
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(stack.trace_state(), build_stack(&[3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // With overflow: push two elements so that two entries spill into the overflow table.
    let mut stack = Stack::new(&inputs, 4, false);
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_right(0);
    // Address of the first overflow entry; it becomes the "next" address again once the second
    // push has been undone.
    let prev_overflow_addr: usize = stack.current_clk().into();
    stack.advance_clock();
    stack.shift_right(0);
    stack.advance_clock();

    // Make room for rows beyond the initial capacity of 4.
    stack.ensure_trace_capacity();

    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(stack.trace_state(), build_stack(&[0, 4, 3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, prev_overflow_addr));

    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(stack.trace_state(), build_stack(&[4, 3, 2, 1]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));
}
/// Checks two consecutive right shifts: each one grows the overflow by one element and records
/// the cycle of the push as the next overflow address.
#[test]
fn shift_right() {
    let inputs = StackInputs::try_from_ints([1, 2, 3, 4]).unwrap();
    let mut stack = Stack::new(&inputs, 4, false);
    stack.copy_state(0);
    stack.advance_clock();

    // Without a following `set`, the newly exposed top reads as ZERO.
    let expected_tops: [&[u64]; 2] = [&[0, 4, 3, 2, 1], &[0, 0, 4, 3, 2, 1]];
    for (num_overflow, expected_top) in (1..).zip(expected_tops) {
        // The overflow address is the cycle at which the shift happens, so capture it first.
        let expected_helpers = build_helpers_partial(num_overflow, stack.current_clk().into());
        stack.shift_right(0);
        stack.advance_clock();
        assert_eq!(stack.trace_state(), build_stack(expected_top));
        assert_eq!(stack.helpers_state(), expected_helpers);
    }
}
/// Checks that starting a new execution context presents a fresh minimum-depth stack with an
/// empty overflow table. Also checks that restoring the context brings back the caller's depth
/// and overflow table.
///
/// Both restores use the depth returned by `start_context` (as `generate_trace` does) rather than
/// hard-coded literals, so the test stays in sync with the depth actually recorded at the switch.
#[test]
fn start_restore_context() {
    // ---- Context switch while the stack is at its minimum depth ----
    let stack = StackInputs::try_from_ints(1..17).unwrap();
    let mut stack = Stack::new(&stack, 8, false);
    stack.copy_state(0);
    stack.advance_clock();

    let (root_depth, _) = stack.start_context();
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    // Dropping from a minimum-depth stack keeps the depth at 16.
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());

    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    stack.restore_context(root_depth);
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(16, stack.depth());

    // ---- Context switch while the caller has one overflowed element ----
    let stack_init = 1..=16u64;
    let stack = StackInputs::try_from_ints(stack_init.clone()).unwrap();
    let mut stack = Stack::new(&stack, 8, false);
    // Expected full stack contents, top first (the last input ends up on top).
    let mut stack_state: Vec<u64> = stack_init.rev().collect();

    stack.copy_state(0);
    stack.advance_clock();

    // Push at clk 1: one element spills into the caller's overflow table.
    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());
    stack_state.insert(0, 0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, 1));

    // The new context starts at minimum depth with no overflow; the visible top is unchanged.
    let (ctx0_depth, ctx0_next_overflow_addr) = stack.start_context();
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(16, stack.depth());
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // Push at clk 3 inside the new context.
    stack.shift_right(0);
    stack.advance_clock();
    assert_eq!(17, stack.depth());
    stack_state.insert(0, 0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(1, 3));

    // Drop it again; the new context's overflow table is empty afterwards.
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());
    stack_state.remove(0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));

    // Restoring brings back the caller's depth and its overflow address.
    stack.restore_context(ctx0_depth);
    stack.copy_state(0);
    stack.advance_clock();
    assert_eq!(ctx0_depth, stack.depth());
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(
        stack.helpers_state(),
        build_helpers_partial(ctx0_depth - 16, ctx0_next_overflow_addr.as_int() as usize)
    );

    // Dropping the top pulls the caller's overflowed element back onto the stack.
    stack.shift_left(1);
    stack.advance_clock();
    assert_eq!(16, stack.depth());
    stack_state.remove(0);
    assert_eq!(stack.trace_state(), build_stack(&stack_state[..16]));
    assert_eq!(stack.helpers_state(), build_helpers_partial(0, 0));
}
/// Checks that each execution context has its own overflow table. A value pushed in the root
/// context is invisible to nested contexts, reappears once the root context is restored, and the
/// recorded history reflects this at every clock cycle.
#[test]
fn root_context_separate_overflows() {
    const SENTINEL_VALUE: Felt = Felt::new(100);

    let mut overflow_stack = OverflowTable::new(true);

    // clk 1: push the sentinel in the root context.
    overflow_stack.advance_clock();
    overflow_stack.push(SENTINEL_VALUE);

    // clk 2 and 3: enter two nested contexts.
    overflow_stack.advance_clock();
    overflow_stack.start_context();
    overflow_stack.advance_clock();
    overflow_stack.start_context();

    // clk 4: the innermost context has nothing to pop.
    overflow_stack.advance_clock();
    assert!(overflow_stack.pop().is_none());

    // clk 5 and 6: unwind back to the root context.
    overflow_stack.advance_clock();
    overflow_stack.restore_context();
    overflow_stack.advance_clock();
    overflow_stack.restore_context();

    // clk 7: the sentinel is visible again in the root context.
    overflow_stack.advance_clock();
    assert_eq!(overflow_stack.pop(), Some(SENTINEL_VALUE));
    overflow_stack.advance_clock();

    // Expected overflow contents reported by the history for clock cycles 0 through 7.
    let expected_history: [&[Felt]; 8] = [
        &[],
        &[SENTINEL_VALUE],
        &[],
        &[],
        &[],
        &[],
        &[SENTINEL_VALUE],
        &[],
    ];

    let mut at_clk = Vec::new();
    for (clk, expected) in (0_u32..).zip(expected_history) {
        at_clk.clear();
        overflow_stack.append_from_history_at(clk.into(), &mut at_clk);
        assert_eq!(at_clk, expected);
    }
}
/// Runs a sequence of stack operations spanning a context switch and checks every row of the
/// generated trace: the top 16 stack columns and the `[b0, b1, h0]` helper columns.
///
/// Row 0 holds the initial state; row `i` holds the state after the operation executed at clock
/// cycle `i - 1`.
#[test]
fn generate_trace() {
    let stack_inputs = [1, 2, 3, 4];
    let stack_inputs = StackInputs::try_from_ints(stack_inputs).unwrap();
    let mut stack = Stack::new(&stack_inputs, 16, false);

    // clk 0: noop; clk 1-2: two pushes (depth 16 -> 18).
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_right(0);
    stack.advance_clock();
    stack.shift_right(0);
    stack.advance_clock();

    // clk 3: enter a new context, which starts at minimum depth.
    let (c0_depth, _) = stack.start_context();
    stack.copy_state(0);
    stack.advance_clock();
    // clk 4-6: push, noop, drop inside the new context.
    stack.shift_right(0);
    stack.advance_clock();
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_left(1);
    stack.advance_clock();

    // clk 7: restore the original context (depth back to 18).
    stack.restore_context(c0_depth);
    stack.copy_state(0);
    stack.advance_clock();
    // clk 8-12: push, noop, then three drops back down to depth 16.
    stack.shift_right(0);
    stack.advance_clock();
    stack.copy_state(0);
    stack.advance_clock();
    stack.shift_left(1);
    stack.advance_clock();
    stack.shift_left(1);
    stack.advance_clock();
    stack.shift_left(1);
    stack.advance_clock();

    let trace = stack.into_trace(16, 1);
    let trace = trace.trace;

    // Expected (stack top, depth, next overflow address) for each checked trace row.
    let expected_rows: [(&[u64], u64, u64); 14] = [
        (&[4, 3, 2, 1], 16, 0),
        (&[4, 3, 2, 1], 16, 0),
        (&[0, 4, 3, 2, 1], 17, 1),
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        // New context: minimum depth, empty overflow table.
        (&[0, 0, 4, 3, 2, 1], 16, 0),
        (&[0, 0, 0, 4, 3, 2, 1], 17, 4),
        (&[0, 0, 0, 4, 3, 2, 1], 17, 4),
        (&[0, 0, 4, 3, 2, 1], 16, 0),
        // Original context restored.
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        (&[0, 0, 0, 4, 3, 2, 1], 19, 8),
        (&[0, 0, 0, 4, 3, 2, 1], 19, 8),
        (&[0, 0, 4, 3, 2, 1], 18, 2),
        (&[0, 4, 3, 2, 1], 17, 1),
        (&[4, 3, 2, 1], 16, 0),
    ];

    for (row, &(stack_top, depth, next_overflow_addr)) in expected_rows.iter().enumerate() {
        assert_eq!(
            read_stack_top(&trace, row),
            build_stack(stack_top),
            "stack top mismatch at trace row {}",
            row
        );
        assert_eq!(
            read_helpers(&trace, row),
            build_helpers(depth, next_overflow_addr),
            "stack helpers mismatch at trace row {}",
            row
        );
    }
}
/// Builds the expected top 16 stack elements from `stack_inputs`, listed top first. Positions
/// past the supplied values are ZERO.
///
/// Panics if more than `MIN_STACK_DEPTH` values are supplied.
fn build_stack(stack_inputs: &[u64]) -> [Felt; MIN_STACK_DEPTH] {
    let mut top = [ZERO; MIN_STACK_DEPTH];
    for (slot, &value) in top[..stack_inputs.len()].iter_mut().zip(stack_inputs) {
        *slot = Felt::new(value);
    }
    top
}
/// Builds the expected helper columns `[b0, b1, h0]` of a finished trace row:
/// - `b0` is the stack depth;
/// - `b1` is the next overflow address;
/// - `h0` is the inverse of `depth - MIN_STACK_DEPTH`.
///
/// At `depth == MIN_STACK_DEPTH` this inverts zero, so the expectation relies on `Felt::inv`
/// producing whatever the trace generator writes in that case.
fn build_helpers(stack_depth: u64, next_overflow_addr: u64) -> StackHelpersState {
    let depth = Felt::new(stack_depth);
    let depth_above_min = depth - Felt::new(MIN_STACK_DEPTH as u64);
    [depth, Felt::new(next_overflow_addr), depth_above_min.inv()]
}
/// Builds the helper values `[b0, b1, h0]` as compared against `Stack::helpers_state` during
/// execution. Here `h0` holds `depth - MIN_STACK_DEPTH` itself, not its inverse. The finished
/// trace is instead checked against `build_helpers`, which inverts it.
fn build_helpers_partial(num_overflow: usize, next_overflow_addr: usize) -> StackHelpersState {
    let min_depth = Felt::new(MIN_STACK_DEPTH as u64);
    let depth = Felt::new((MIN_STACK_DEPTH + num_overflow) as u64);
    [depth, Felt::new(next_overflow_addr as u64), depth - min_depth]
}
/// Reads the values of the first `MIN_STACK_DEPTH` trace columns (the visible stack top) at
/// `row`.
fn read_stack_top(trace: &[Vec<Felt>; STACK_TRACE_WIDTH], row: usize) -> [Felt; MIN_STACK_DEPTH] {
    core::array::from_fn(|column| trace[column][row])
}
/// Reads the `[b0, b1, h0]` helper columns of the stack trace at `row`.
fn read_helpers(trace: &[Vec<Felt>; STACK_TRACE_WIDTH], row: usize) -> StackHelpersState {
    [B0_COL_IDX, B1_COL_IDX, H0_COL_IDX].map(|column| trace[column][row])
}