use p3_air::AirBuilder;
use p3_field::AbstractField;
use sp1_stark::air::BaseAirBuilder;
use crate::{
    air::SP1RecursionAirBuilder,
    memory::MemoryCols,
    poseidon2_wide::{
        columns::{
            control_flow::ControlFlow, memory::Memory, opcode_workspace::OpcodeWorkspace,
            syscall_params::SyscallParams,
        },
        Poseidon2WideChip, WIDTH,
    },
};

impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
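    /// Evaluates the memory-related columns: slot flags, start addresses, and both halves of
    /// the memory accesses.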
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn eval_mem<AB: SP1RecursionAirBuilder>(
        &self,
        builder: &mut AB,
        syscall_params: &SyscallParams<AB::Var>,
        local_memory: &Memory<AB::Var>,
        next_memory: &Memory<AB::Var>,
        opcode_workspace: &OpcodeWorkspace<AB::Var>,
        control_flow: &ControlFlow<AB::Var>,
        first_half_memory_access: [AB::Var; WIDTH / 2],
        second_half_memory_access: AB::Var,
    ) {
        let clk = syscall_params.get_raw_params()[0];

        // A row is real iff it belongs to a compress, absorb, or finalize operation.
        let is_real = control_flow.is_compress + control_flow.is_absorb + control_flow.is_finalize;
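
        // Verify the memory slot flags.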
        for i in 0..WIDTH / 2 {
            builder.assert_bool(local_memory.memory_slot_used[i]);

            // Slot flags must be zero on all non-real rows.
            builder.when_not(is_real.clone()).assert_zero(local_memory.memory_slot_used[i]);

            // Compress and finalize rows use every slot.
            builder
                .when(control_flow.is_compress + control_flow.is_finalize)
                .assert_one(local_memory.memory_slot_used[i]);
        }

        // For absorb rows, the used slots must form a single contiguous run. Those constraints
        // are row-wide rather than per-slot, so they are evaluated once, outside the loop.
        self.eval_absorb_memory_slots(builder, control_flow, local_memory, opcode_workspace);
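
        // Verify the start_addr column.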
        {
            // On compress syscall rows, start_addr must be the syscall's left input pointer.
            builder
                .when(control_flow.is_compress * control_flow.is_syscall_row)
                .assert_eq(syscall_params.compress().left_ptr, local_memory.start_addr);

            // On compress output rows, start_addr must be the syscall's destination pointer.
            builder
                .when(control_flow.is_compress_output)
                .assert_eq(syscall_params.compress().dst_ptr, local_memory.start_addr);

            // On absorb syscall rows, start_addr must be the syscall's input pointer; each
            // subsequent absorb row advances it by the number of consumed elements.
            builder
                .when(control_flow.is_absorb)
                .when(control_flow.is_syscall_row)
                .assert_eq(syscall_params.absorb().input_ptr, local_memory.start_addr);
            builder.when(control_flow.is_absorb_not_last_row).assert_eq(
                next_memory.start_addr,
                local_memory.start_addr + opcode_workspace.absorb().num_consumed::<AB>(),
            );

            // On finalize rows, start_addr must be the syscall's output pointer.
            builder
                .when(control_flow.is_finalize)
                .assert_eq(syscall_params.finalize().output_ptr, local_memory.start_addr);
        }
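
        // Constrain the memory accesses for the first half of the state.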
        {
            let mut addr: AB::Expr = local_memory.start_addr.into();
            for i in 0..WIDTH / 2 {
                builder.recursion_eval_memory_access_single(
                    clk + control_flow.is_compress_output,
                    addr.clone(),
                    &local_memory.memory_accesses[i],
                    first_half_memory_access[i],
                );

                // Compress syscall and absorb accesses are read-only, so the value must be
                // unchanged.
                let compress_syscall_row = control_flow.is_compress * control_flow.is_syscall_row;
                builder.when(compress_syscall_row + control_flow.is_absorb).assert_eq(
                    *local_memory.memory_accesses[i].prev_value(),
                    *local_memory.memory_accesses[i].value(),
                );

                // The address advances only past slots that are actually used.
                addr = addr.clone() + local_memory.memory_slot_used[i].into();
            }
        }
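
        // Constrain the memory accesses for the second half of the state.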
        {
            let compress_workspace = opcode_workspace.compress();

            // Verify the start address: the syscall's right input pointer on compress syscall
            // rows, and dst_ptr offset by WIDTH / 2 on compress output rows.
            let is_compress_syscall = control_flow.is_compress * control_flow.is_syscall_row;
            builder
                .when(is_compress_syscall.clone())
                .assert_eq(compress_workspace.start_addr, syscall_params.compress().right_ptr);
            builder.when(control_flow.is_compress_output).assert_eq(
                compress_workspace.start_addr,
                syscall_params.compress().dst_ptr + AB::Expr::from_canonical_usize(WIDTH / 2),
            );

            // Evaluate the memory accesses. Unlike the first half, the addresses here are
            // always contiguous.
            let mut addr: AB::Expr = compress_workspace.start_addr.into();
            for i in 0..WIDTH / 2 {
                builder.recursion_eval_memory_access_single(
                    clk + control_flow.is_compress_output,
                    addr.clone(),
                    &compress_workspace.memory_accesses[i],
                    second_half_memory_access,
                );

                // Compress syscall accesses are read-only, so the value must be unchanged.
                builder.when(is_compress_syscall.clone()).assert_eq(
                    *compress_workspace.memory_accesses[i].prev_value(),
                    *compress_workspace.memory_accesses[i].value(),
                );

                addr = addr.clone() + AB::Expr::one();
            }
        }
    }
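
    /// Evaluates the constraints on the absorb rows' memory slot flags.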
    fn eval_absorb_memory_slots<AB: SP1RecursionAirBuilder>(
        &self,
        builder: &mut AB,
        control_flow: &ControlFlow<AB::Var>,
        local_memory: &Memory<AB::Var>,
        opcode_workspace: &OpcodeWorkspace<AB::Var>,
    ) {
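        // To verify that the absorb memory slots are correct, take the "derivative" of the
        // slot flags (memory_slot_used[i] - memory_slot_used[i - 1]) and assert that:
        //   1) if start_mem_idx_bitmap[i] == 1, the derivative is 1;
        //   2) if end_mem_idx_bitmap[i - 1] == 1, the derivative is -1;
        //   3) otherwise, the derivative is 0.
        // Together these force the used slots to form one contiguous run from the start index
        // to the end index.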
        let mut absorb_builder = builder.when(control_flow.is_absorb);

        let start_mem_idx_bitmap = opcode_workspace.absorb().start_mem_idx_bitmap;
        let end_mem_idx_bitmap = opcode_workspace.absorb().end_mem_idx_bitmap;
        for i in 0..WIDTH / 2 {
            let derivative: AB::Expr = if i == 0 {
                local_memory.memory_slot_used[i].into()
            } else {
                local_memory.memory_slot_used[i] - local_memory.memory_slot_used[i - 1]
            };

            let is_start_mem_idx = start_mem_idx_bitmap[i].into();
            let is_previous_end_mem_idx =
                if i == 0 { AB::Expr::zero() } else { end_mem_idx_bitmap[i - 1].into() };

            absorb_builder.when(is_start_mem_idx.clone()).assert_one(derivative.clone());
            absorb_builder
                .when(is_previous_end_mem_idx.clone())
                .assert_zero(derivative.clone() + AB::Expr::one());
            absorb_builder
                .when_not(is_start_mem_idx + is_previous_end_mem_idx)
                .assert_zero(derivative);
        }
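
        // Both bitmaps must be boolean and have exactly one bit set.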
        let mut start_mem_idx_bitmap_sum = AB::Expr::zero();
        start_mem_idx_bitmap.iter().for_each(|bit| {
            absorb_builder.assert_bool(*bit);
            start_mem_idx_bitmap_sum += (*bit).into();
        });
        absorb_builder.assert_one(start_mem_idx_bitmap_sum);

        let mut end_mem_idx_bitmap_sum = AB::Expr::zero();
        end_mem_idx_bitmap.iter().for_each(|bit| {
            absorb_builder.assert_bool(*bit);
            end_mem_idx_bitmap_sum += (*bit).into();
        });
        absorb_builder.assert_one(end_mem_idx_bitmap_sum);
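
        // The start index encoded by the bitmap must equal the state cursor.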
        let start_mem_idx: AB::Expr = start_mem_idx_bitmap
            .iter()
            .enumerate()
            .map(|(i, bit)| AB::Expr::from_canonical_usize(i) * *bit)
            .sum();
        absorb_builder.assert_eq(start_mem_idx, opcode_workspace.absorb().state_cursor);
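
        // On every absorb row except the last, the run must extend through the final slot
        // (index 7, the last of the WIDTH / 2 slots); on the last row it must end at
        // last_row_ending_cursor.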
        let end_mem_idx: AB::Expr = end_mem_idx_bitmap
            .iter()
            .enumerate()
            .map(|(i, bit)| AB::Expr::from_canonical_usize(i) * *bit)
            .sum();
        absorb_builder
            .when_not(opcode_workspace.absorb().is_last_row::<AB>())
            .assert_zero(end_mem_idx.clone() - AB::Expr::from_canonical_usize(7));
        absorb_builder
            .when(opcode_workspace.absorb().is_last_row::<AB>())
            .assert_eq(end_mem_idx, opcode_workspace.absorb().last_row_ending_cursor);
    }
}