use alloc::{vec, vec::Vec};
use core::{borrow::Borrow, mem::size_of};
use miden_core::{Felt, WORD_SIZE, chiplets::hasher::Hasher, field::PrimeCharacteristicRing};
use super::super::{columns::indices_arr, ext_field::QuadFeltExpr};
use crate::trace::chiplets::{
bitwise::NUM_DECOMP_BITS,
hasher::{CAPACITY_LEN, DIGEST_LEN, HASH_CYCLE_LEN, NUM_SELECTORS, RATE_LEN, STATE_WIDTH},
};
/// Reinterprets a slice of `T` cells as a reference to a single column struct `S`.
///
/// The slice has to cover exactly one `S`: its byte length must equal `size_of::<S>()` and its
/// start address must satisfy `S`'s alignment. Both conditions are checked in every build
/// profile, so the result no longer depends on how `slice::align_to` chooses to split the
/// input (the standard library does not guarantee a maximal middle slice).
///
/// NOTE(review): the function is safe to call but performs an unchecked reinterpretation of
/// the cell values. It is meant for `#[repr(C)]` structs made up solely of `T` fields — confirm
/// that no caller instantiates `S` with anything else.
///
/// # Panics
///
/// Panics if the slice's byte length differs from `size_of::<S>()` or if the slice is not
/// aligned for `S`.
pub fn borrow_chiplet<T, S>(slice: &[T]) -> &S {
    assert!(
        core::mem::size_of_val(slice) == core::mem::size_of::<S>(),
        "slice does not cover exactly one column struct"
    );
    // Alignments are powers of two, so masking with `align - 1` tests divisibility.
    assert!(
        (slice.as_ptr() as usize) & (core::mem::align_of::<S>() - 1) == 0,
        "slice is not aligned for the column struct"
    );
    // SAFETY: the assertions above guarantee that the pointer is aligned for `S` and that the
    // slice spans exactly `size_of::<S>()` initialized bytes, which remain borrowed for the
    // lifetime of the returned reference. Bit-validity of those bytes as an `S` is the
    // caller's responsibility (see the review note in the doc comment).
    unsafe { &*slice.as_ptr().cast::<S>() }
}
/// Column layout of a permutation row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order. A compile-time assertion elsewhere in this file pins
/// `size_of::<PermutationCols<u8>>()` to 19, i.e. the layout spans 19 cells.
#[repr(C)]
pub struct PermutationCols<T> {
/// `NUM_SELECTORS` leading cells.
/// NOTE(review): presumably witness values stored in the slots that `ControllerCols` uses for
/// `s0..s2` — confirm against the trace builder.
pub witnesses: [T; NUM_SELECTORS],
/// `STATE_WIDTH` state cells; `rate()`, `capacity()` and `digest()` return sub-ranges of it.
pub state: [T; STATE_WIDTH],
/// Single cell following the state.
/// NOTE(review): presumably a lookup/bus multiplicity — confirm.
pub multiplicity: T,
/// Three trailing cells with no named meaning here; readable only through `unused_padding()`.
_unused: [T; 3],
}
impl<T: Copy> PermutationCols<T> {
    /// Copies out the rate portion of the state: its first `RATE_LEN` cells.
    pub fn rate(&self) -> [T; RATE_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the capacity portion of the state: the `CAPACITY_LEN` cells that follow the
    /// rate.
    pub fn capacity(&self) -> [T; CAPACITY_LEN] {
        core::array::from_fn(|offset| self.state[RATE_LEN + offset])
    }

    /// Copies out the digest: the first `DIGEST_LEN` state cells.
    pub fn digest(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the first half of the rate (same cells as [`Self::digest`]).
    pub fn rate0(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the second half of the rate: the `DIGEST_LEN` cells right after `rate0`.
    pub fn rate1(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[DIGEST_LEN + offset])
    }

    /// Copies out the three trailing cells that carry no named meaning in this layout.
    pub fn unused_padding(&self) -> [T; 3] {
        let padding: [T; 3] = self._unused;
        padding
    }
}
/// Column layout of a controller row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order. A compile-time assertion elsewhere in this file pins
/// `size_of::<ControllerCols<u8>>()` to 19, i.e. the layout spans 19 cells.
#[repr(C)]
pub struct ControllerCols<T> {
/// First selector cell; a factor of both `f_mu()` and `f_mv()`.
pub s0: T,
/// Second selector cell; a factor of both `f_mu()` and `f_mv()`.
pub s1: T,
/// Third selector cell; `f_mu()` multiplies by it, `f_mv()` by its complement `1 - s2`.
pub s2: T,
/// `STATE_WIDTH` state cells; `rate()`, `capacity()` and `digest()` return sub-ranges of it.
pub state: [T; STATE_WIDTH],
/// NOTE(review): presumably the Merkle node index tracked along a path — confirm.
pub node_index: T,
/// NOTE(review): presumably an identifier tying together rows of one Merkle root update —
/// confirm.
pub mrupdate_id: T,
/// NOTE(review): presumably a flag marking the first/last row of an operation — confirm.
pub is_boundary: T,
/// NOTE(review): presumably the bit selecting left/right placement in a Merkle step — confirm.
pub direction_bit: T,
}
impl<T: Copy> ControllerCols<T> {
    /// Copies out the rate portion of the state: its first `RATE_LEN` cells.
    pub fn rate(&self) -> [T; RATE_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the capacity portion of the state: the `CAPACITY_LEN` cells that follow the
    /// rate.
    pub fn capacity(&self) -> [T; CAPACITY_LEN] {
        core::array::from_fn(|offset| self.state[RATE_LEN + offset])
    }

    /// Copies out the digest: the first `DIGEST_LEN` state cells.
    pub fn digest(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the first half of the rate (same cells as [`Self::digest`]).
    pub fn rate0(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[offset])
    }

    /// Copies out the second half of the rate: the `DIGEST_LEN` cells right after `rate0`.
    pub fn rate1(&self) -> [T; DIGEST_LEN] {
        core::array::from_fn(|offset| self.state[DIGEST_LEN + offset])
    }

    /// Flag expression `s0 * s1 * s2`, evaluated in the ring `E`.
    ///
    /// NOTE(review): the name suggests this selects the "MU" operation — confirm against the
    /// selector encoding.
    pub fn f_mu<E: PrimeCharacteristicRing>(&self) -> E
    where
        T: Into<E>,
    {
        let s0: E = self.s0.into();
        let s1: E = self.s1.into();
        let s2: E = self.s2.into();
        s0 * s1 * s2
    }

    /// Flag expression `s0 * s1 * (1 - s2)`, evaluated in the ring `E`.
    ///
    /// NOTE(review): the name suggests this selects the "MV" operation — confirm against the
    /// selector encoding.
    pub fn f_mv<E: PrimeCharacteristicRing>(&self) -> E
    where
        T: Into<E>,
    {
        let s0: E = self.s0.into();
        let s1: E = self.s1.into();
        let s2: E = self.s2.into();
        s0 * s1 * (E::ONE - s2)
    }
}
/// Column layout of a bitwise chiplet row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order so the struct mirrors a contiguous run of cells.
#[repr(C)]
pub struct BitwiseCols<T> {
/// NOTE(review): presumably selects which bitwise operation the row performs — confirm.
pub op_flag: T,
/// NOTE(review): presumably the first operand (or its running aggregate) — confirm.
pub a: T,
/// NOTE(review): presumably the second operand (or its running aggregate) — confirm.
pub b: T,
/// `NUM_DECOMP_BITS` cells. NOTE(review): presumably the bits of `a` handled by this row.
pub a_bits: [T; NUM_DECOMP_BITS],
/// `NUM_DECOMP_BITS` cells. NOTE(review): presumably the bits of `b` handled by this row.
pub b_bits: [T; NUM_DECOMP_BITS],
/// NOTE(review): presumably the output value carried over from the previous row — confirm.
pub prev_output: T,
/// NOTE(review): presumably the output value after this row — confirm.
pub output: T,
}
/// Column layout of a memory chiplet row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order so the struct mirrors a contiguous run of cells.
/// The field meanings below are inferred from names only and are marked for confirmation.
#[repr(C)]
pub struct MemoryCols<T> {
/// NOTE(review): presumably 1 for a read access and 0 for a write — confirm.
pub is_read: T,
/// NOTE(review): presumably 1 for a word access and 0 for a single-element access — confirm.
pub is_word: T,
/// NOTE(review): presumably the execution context of the access — confirm.
pub ctx: T,
/// NOTE(review): presumably the word-aligned address of the access — confirm.
pub word_addr: T,
/// NOTE(review): presumably the low bit of the element index within the word — confirm.
pub idx0: T,
/// NOTE(review): presumably the high bit of the element index within the word — confirm.
pub idx1: T,
/// NOTE(review): presumably the clock cycle of the access — confirm.
pub clk: T,
/// `WORD_SIZE` value cells.
pub values: [T; WORD_SIZE],
/// NOTE(review): presumably the low limb of the delta to the previous row — confirm.
pub d0: T,
/// NOTE(review): presumably the high limb of the delta to the previous row — confirm.
pub d1: T,
/// NOTE(review): presumably the inverse of the delta — confirm.
pub d_inv: T,
/// NOTE(review): presumably 1 when context and address match the previous row — confirm.
pub is_same_ctx_and_addr: T,
}
/// Column layout of an ACE chiplet row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order. Compile-time assertions in this file pin
/// `size_of::<AceCols<u8>>()` to 16, i.e. the layout spans 16 cells. The last four cells
/// (`mode`) are interpreted through `read()` or `eval()`.
#[repr(C)]
pub struct AceCols<T> {
/// NOTE(review): presumably flags the first row of a section — confirm.
pub s_start: T,
/// Block flag: `f_read()` evaluates to `1 - s_block` and `f_eval()` to `s_block`.
pub s_block: T,
/// NOTE(review): presumably the memory context — confirm.
pub ctx: T,
/// NOTE(review): presumably the memory pointer — confirm.
pub ptr: T,
/// NOTE(review): presumably the clock cycle — confirm.
pub clk: T,
/// NOTE(review): presumably encodes the arithmetic operation of an eval row — confirm.
pub eval_op: T,
/// Identifier cell paired with `v_0`.
pub id_0: T,
/// First value, stored as a `QuadFeltExpr<T>`.
pub v_0: QuadFeltExpr<T>,
/// Identifier cell paired with `v_1`.
pub id_1: T,
/// Second value, stored as a `QuadFeltExpr<T>`.
pub v_1: QuadFeltExpr<T>,
/// Four trailing cells overlaid by `AceReadCols` or `AceEvalCols`; kept private so that they
/// are only reached through `read()` / `eval()`.
mode: [T; 4],
}
impl<T> AceCols<T> {
    /// Views the four `mode` cells through the read-row overlay.
    pub fn read(&self) -> &AceReadCols<T> {
        borrow_chiplet::<T, AceReadCols<T>>(self.mode.as_slice())
    }

    /// Views the four `mode` cells through the eval-row overlay.
    pub fn eval(&self) -> &AceEvalCols<T> {
        borrow_chiplet::<T, AceEvalCols<T>>(self.mode.as_slice())
    }
}
impl<T: Copy> AceCols<T> {
    /// Read-row flag: `1 - s_block`, evaluated in the ring `E`.
    pub fn f_read<E: PrimeCharacteristicRing>(&self) -> E
    where
        T: Into<E>,
    {
        let block_flag: E = self.s_block.into();
        E::ONE - block_flag
    }

    /// Eval-row flag: `s_block` converted into the ring `E`.
    pub fn f_eval<E: PrimeCharacteristicRing>(&self) -> E
    where
        T: Into<E>,
    {
        let block_flag: E = self.s_block.into();
        block_flag
    }
}
/// Read-row interpretation of the four `mode` cells of `AceCols`.
///
/// A compile-time assertion in this file pins `size_of::<AceReadCols<u8>>()` to 4.
#[repr(C)]
pub struct AceReadCols<T> {
/// First mode cell; asserted to sit at the same offset as `AceEvalCols::id_2`.
/// NOTE(review): presumably the number of eval rows in the section — confirm.
pub num_eval: T,
/// Second mode cell; carries no named meaning in read rows.
pub unused: T,
/// Third mode cell; asserted to sit at the same offset as the second component of
/// `AceEvalCols::v_2`. NOTE(review): presumably the multiplicity of the second wire — confirm.
pub m_1: T,
/// Fourth mode cell; asserted to sit at the same offset as `AceEvalCols::m_0`.
/// NOTE(review): presumably the multiplicity of the first wire — confirm.
pub m_0: T,
}
/// Eval-row interpretation of the four `mode` cells of `AceCols`.
///
/// A compile-time assertion in this file pins `size_of::<AceEvalCols<u8>>()` to 4.
#[repr(C)]
pub struct AceEvalCols<T> {
/// First mode cell: identifier paired with `v_2`.
pub id_2: T,
/// Second and third mode cells: third value, stored as a `QuadFeltExpr<T>`.
pub v_2: QuadFeltExpr<T>,
/// Fourth mode cell; asserted to sit at the same offset as `AceReadCols::m_0`.
/// NOTE(review): presumably the multiplicity of the first wire — confirm.
pub m_0: T,
}
/// Index map shaped like `AceCols`, built by transmuting `indices_arr::<16>()`.
///
/// NOTE(review): presumably `indices_arr::<N>()` yields `[0, 1, .., N - 1]`, so that every field
/// holds its own cell offset within an ACE row — confirm against `indices_arr`.
#[allow(dead_code)]
pub const ACE_COL_MAP: AceCols<usize> = {
// Guard the cell count that the transmute below is written against.
assert!(size_of::<AceCols<u8>>() == 16);
// SAFETY (review note): `transmute` only compiles when source and destination have the same
// size; the reinterpretation additionally assumes that `AceCols<usize>` is a plain `#[repr(C)]`
// sequence of `usize` cells, which also depends on the layout of `QuadFeltExpr` — confirm.
unsafe { core::mem::transmute(indices_arr::<{ size_of::<AceCols<u8>>() }>()) }
};
/// Index map shaped like `AceReadCols`, built by transmuting `indices_arr::<4>()`.
///
/// NOTE(review): presumably each field holds its offset relative to the start of the `mode`
/// cells (not an absolute row offset) — confirm against `indices_arr`.
pub const ACE_READ_COL_MAP: AceReadCols<usize> = {
// Guard the cell count that the transmute below is written against.
assert!(size_of::<AceReadCols<u8>>() == 4);
// SAFETY (review note): `transmute` only compiles when both sizes match; `AceReadCols<usize>`
// is `#[repr(C)]` with four `usize` fields.
unsafe { core::mem::transmute(indices_arr::<{ size_of::<AceReadCols<u8>>() }>()) }
};
/// Index map shaped like `AceEvalCols`, built by transmuting `indices_arr::<4>()`.
///
/// NOTE(review): presumably each field holds its offset relative to the start of the `mode`
/// cells (not an absolute row offset) — confirm against `indices_arr`.
pub const ACE_EVAL_COL_MAP: AceEvalCols<usize> = {
// Guard the cell count that the transmute below is written against.
assert!(size_of::<AceEvalCols<u8>>() == 4);
// SAFETY (review note): `transmute` only compiles when both sizes match; the reinterpretation
// also relies on `QuadFeltExpr<usize>` being laid out as two consecutive `usize` cells — confirm.
unsafe { core::mem::transmute(indices_arr::<{ size_of::<AceEvalCols<u8>>() }>()) }
};
/// Value stored in the first `mode` slot of `ACE_COL_MAP`.
///
/// NOTE(review): presumably the offset of the first mode cell within an ACE row, to be added to
/// the relative offsets in `ACE_READ_COL_MAP` / `ACE_EVAL_COL_MAP` — confirm.
#[allow(dead_code)]
pub const MODE_OFFSET: usize = ACE_COL_MAP.mode[0];
// Compile-time checks on the ACE layouts: cell counts and the offsets shared by the two
// overlays of the `mode` cells.
const _: () = {
// An ACE row spans 16 cells; each mode overlay spans 4.
assert!(size_of::<AceCols<u8>>() == 16);
assert!(size_of::<AceReadCols<u8>>() == 4);
assert!(size_of::<AceEvalCols<u8>>() == 4);
// `m_0` occupies the same slot in both overlays.
assert!(ACE_READ_COL_MAP.m_0 == ACE_EVAL_COL_MAP.m_0);
// Read-mode `num_eval` shares its slot with eval-mode `id_2`.
assert!(ACE_READ_COL_MAP.num_eval == ACE_EVAL_COL_MAP.id_2);
// Read-mode `m_1` shares its slot with the second component of eval-mode `v_2`.
assert!(ACE_READ_COL_MAP.m_1 == ACE_EVAL_COL_MAP.v_2.1);
};
/// Column layout of a kernel ROM chiplet row, generic over the cell type `T`.
///
/// `#[repr(C)]` fixes the field order so the struct mirrors a contiguous run of cells.
#[repr(C)]
pub struct KernelRomCols<T> {
/// NOTE(review): presumably how many times the row's root is looked up — confirm.
pub multiplicity: T,
/// `WORD_SIZE` cells. NOTE(review): presumably a kernel procedure root — confirm.
pub root: [T; WORD_SIZE],
}
/// All periodic columns, generic over the cell type `T`: the hasher group followed by the
/// bitwise group.
///
/// `#[repr(C)]` fixes the field order; `periodic_columns()` emits its columns in this same
/// order and the `Borrow` impl on `[T]` overlays this struct on a flat slice. A compile-time
/// assertion in this file pins `size_of::<PeriodicCols<u8>>()` to 18.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct PeriodicCols<T> {
/// Hasher periodic columns (16 cells according to the size assertion in this file).
pub hasher: HasherPeriodicCols<T>,
/// Bitwise periodic columns (2 cells according to the size assertion in this file).
pub bitwise: BitwisePeriodicCols<T>,
}
/// Periodic columns of the hasher, generic over the cell type `T`.
///
/// The row numbers below describe the columns produced by `HasherPeriodicCols::<Vec<Felt>>::new()`,
/// each of which has `HASH_CYCLE_LEN` rows.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct HasherPeriodicCols<T> {
/// Set to one on row 0 only.
pub is_init_ext: T,
/// Set to one on rows 1-3 and 12-14.
pub is_ext: T,
/// Set to one on rows 4-10.
pub is_packed_int: T,
/// Set to one on row 11 only.
pub is_int_ext: T,
/// One round-constant column per state lane. Rows 0-3 hold `Hasher::ARK_EXT_INITIAL`,
/// rows 4-10 hold `Hasher::ARK_INT` constants in lanes 0-2 (zero in the other lanes),
/// rows 11-14 hold `Hasher::ARK_EXT_TERMINAL`, and the remaining row stays zero.
pub ark: [T; STATE_WIDTH],
}
/// Periodic columns of the bitwise chiplet, generic over the cell type `T`.
///
/// The row numbers below describe the 8-row columns produced by
/// `BitwisePeriodicCols::<Vec<Felt>>::new()`.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct BitwisePeriodicCols<T> {
/// One on the first row of the cycle, zero on the other seven.
pub k_first: T,
/// One on the first seven rows of the cycle, zero on the last.
pub k_transition: T,
}
#[allow(clippy::new_without_default)]
impl HasherPeriodicCols<Vec<Felt>> {
    /// Builds the hasher periodic columns, each `HASH_CYCLE_LEN` rows long.
    ///
    /// Row schedule of one cycle:
    /// - row 0: `is_init_ext`, constants `ARK_EXT_INITIAL[0]`;
    /// - rows 1-3: `is_ext`, constants `ARK_EXT_INITIAL[1..=3]`;
    /// - rows 4-10: `is_packed_int`, three `ARK_INT` constants per row in lanes 0-2;
    /// - row 11: `is_int_ext`, constants `ARK_EXT_TERMINAL[0]`;
    /// - rows 12-14: `is_ext`, constants `ARK_EXT_TERMINAL[1..=3]`;
    /// - any remaining row: all selectors and constants stay zero.
    #[allow(clippy::needless_range_loop)]
    pub fn new() -> Self {
        // Selector column that is one on exactly the listed rows.
        let selector = |rows: &[usize]| -> Vec<Felt> {
            let mut col = vec![Felt::ZERO; HASH_CYCLE_LEN];
            for &row in rows {
                col[row] = Felt::ONE;
            }
            col
        };

        let is_init_ext = selector(&[0]);
        let is_ext = selector(&[1, 2, 3, 12, 13, 14]);
        let is_packed_int = selector(&[4, 5, 6, 7, 8, 9, 10]);
        let is_int_ext = selector(&[11]);

        let ark = core::array::from_fn(|lane| {
            let mut col = vec![Felt::ZERO; HASH_CYCLE_LEN];

            // External rounds: the four initial ones start at row 0, the four terminal ones at
            // row 11.
            for round in 0..4 {
                col[round] = Hasher::ARK_EXT_INITIAL[round][lane];
                col[11 + round] = Hasher::ARK_EXT_TERMINAL[round][lane];
            }

            // Packed internal rounds: row `4 + triple` carries three consecutive internal
            // constants, one in each of the first three lanes.
            if lane < 3 {
                for (triple, row) in (4..=10).enumerate() {
                    col[row] = Hasher::ARK_INT[triple * 3 + lane];
                }
            }

            col
        });

        Self {
            is_init_ext,
            is_ext,
            is_packed_int,
            is_int_ext,
            ark,
        }
    }
}
#[allow(clippy::new_without_default)]
impl BitwisePeriodicCols<Vec<Felt>> {
    /// Builds the two bitwise periodic columns, each eight rows long.
    ///
    /// `k_first` is one on the first row only; `k_transition` is one on every row except the
    /// last.
    pub fn new() -> Self {
        const CYCLE_ROWS: usize = 8;

        let mut k_first = vec![Felt::ZERO; CYCLE_ROWS];
        k_first[0] = Felt::ONE;

        let mut k_transition = vec![Felt::ONE; CYCLE_ROWS];
        k_transition[CYCLE_ROWS - 1] = Felt::ZERO;

        Self { k_first, k_transition }
    }
}
impl PeriodicCols<Vec<Felt>> {
    /// Returns every periodic column as a flat list, in `PeriodicCols` field order: the four
    /// hasher selectors, the per-lane round-constant columns, then the two bitwise columns.
    pub fn periodic_columns() -> Vec<Vec<Felt>> {
        let hasher = HasherPeriodicCols::new();
        let bitwise = BitwisePeriodicCols::new();

        let hasher_selectors = [
            hasher.is_init_ext,
            hasher.is_ext,
            hasher.is_packed_int,
            hasher.is_int_ext,
        ];
        let bitwise_cols = [bitwise.k_first, bitwise.k_transition];

        IntoIterator::into_iter(hasher_selectors)
            .chain(hasher.ark)
            .chain(bitwise_cols)
            .collect()
    }
}
/// Number of periodic columns, computed as the byte size of `PeriodicCols<u8>` (one byte per
/// column). A compile-time assertion in this file pins it to 18.
pub const NUM_PERIODIC_COLUMNS: usize = size_of::<PeriodicCols<u8>>();
impl<T> Borrow<PeriodicCols<T>> for [T] {
    /// Views a slice of exactly `NUM_PERIODIC_COLUMNS` cells as a `PeriodicCols<T>`.
    ///
    /// The length and layout checks run in every build profile, and the view is produced by a
    /// direct pointer cast, so the result does not depend on how `slice::align_to` chooses to
    /// split its input (the standard library does not guarantee a maximal middle slice).
    ///
    /// # Panics
    ///
    /// Panics if the slice does not hold exactly `NUM_PERIODIC_COLUMNS` cells, or if its size or
    /// alignment does not match `PeriodicCols<T>`.
    fn borrow(&self) -> &PeriodicCols<T> {
        assert_eq!(self.len(), NUM_PERIODIC_COLUMNS);
        assert!(
            core::mem::size_of_val(self) == core::mem::size_of::<PeriodicCols<T>>(),
            "slice does not cover exactly one PeriodicCols"
        );
        // Alignments are powers of two, so masking with `align - 1` tests divisibility.
        assert!(
            (self.as_ptr() as usize) & (core::mem::align_of::<PeriodicCols<T>>() - 1) == 0,
            "slice is not aligned for PeriodicCols"
        );
        // SAFETY: `PeriodicCols<T>` is `#[repr(C)]` and consists solely of `T` cells (directly
        // or through `#[repr(C)]` sub-structs and arrays of `T`), so any run of valid `T`
        // values of matching size is a valid `PeriodicCols<T>`. The assertions above establish
        // the matching size and alignment, and the slice stays borrowed for the lifetime of the
        // returned reference.
        unsafe { &*self.as_ptr().cast::<PeriodicCols<T>>() }
    }
}
// Compile-time checks on the cell counts of the periodic and hasher row layouts. Instantiating
// the structs with `u8` makes `size_of` equal to the number of cells.
const _: () = {
// 16 hasher columns + 2 bitwise columns.
assert!(size_of::<PeriodicCols<u8>>() == 18);
assert!(size_of::<HasherPeriodicCols<u8>>() == 16);
assert!(size_of::<BitwisePeriodicCols<u8>>() == 2);
// Both hasher row layouts span the same number of cells.
assert!(size_of::<PermutationCols<u8>>() == 19);
assert!(size_of::<ControllerCols<u8>>() == 19);
};
#[cfg(test)]
mod tests {
    use super::*;

    /// The flat column list has one entry per `PeriodicCols` cell; hasher columns span a full
    /// hash cycle and bitwise columns span eight rows.
    #[test]
    fn periodic_columns_dimensions() {
        let cols = PeriodicCols::periodic_columns();
        assert_eq!(cols.len(), NUM_PERIODIC_COLUMNS);

        let num_hasher_cols = size_of::<HasherPeriodicCols<u8>>();
        for (idx, col) in cols.iter().enumerate() {
            let expected_len = if idx < num_hasher_cols { HASH_CYCLE_LEN } else { 8 };
            assert_eq!(col.len(), expected_len);
        }
    }

    /// Every step selector is binary and at most one of them is set on any row.
    #[test]
    fn hasher_step_selectors_are_exclusive() {
        let h = HasherPeriodicCols::new();
        for row in 0..HASH_CYCLE_LEN {
            let flags = [h.is_init_ext[row], h.is_ext[row], h.is_packed_int[row], h.is_int_ext[row]];
            for flag in flags {
                assert_eq!(flag * (flag - Felt::ONE), Felt::ZERO);
            }

            let sum = flags[0] + flags[1] + flags[2] + flags[3];
            assert!(sum == Felt::ZERO || sum == Felt::ONE, "row {row}: sum = {sum}");
        }
    }

    /// Rows 0-3 carry the initial external constants and rows 11-14 the terminal ones.
    #[test]
    fn external_round_constants_correct() {
        let h = HasherPeriodicCols::new();
        for lane in 0..STATE_WIDTH {
            for round in 0..4 {
                assert_eq!(h.ark[lane][round], Hasher::ARK_EXT_INITIAL[round][lane]);
                assert_eq!(h.ark[lane][11 + round], Hasher::ARK_EXT_TERMINAL[round][lane]);
            }
        }
    }

    /// Rows 4-10 carry three internal constants each in lanes 0-2 and zeros elsewhere.
    #[test]
    fn internal_round_constants_correct() {
        let h = HasherPeriodicCols::new();
        for (triple, row) in (4..=10).enumerate() {
            for (lane, col) in h.ark.iter().enumerate() {
                if lane < 3 {
                    let k = lane;
                    let ark_idx = triple * 3 + k;
                    assert_eq!(
                        col[row],
                        Hasher::ARK_INT[ark_idx],
                        "mismatch at row {row}, int constant {k} (ARK_INT[{ark_idx}])"
                    );
                } else {
                    assert_eq!(
                        col[row],
                        Felt::ZERO,
                        "ark[{lane}] nonzero at packed-int row {row}"
                    );
                }
            }
        }
    }

    /// Row 15 carries no round constants in any lane.
    #[test]
    fn boundary_row_all_zero() {
        let h = HasherPeriodicCols::new();
        for lane in 0..STATE_WIDTH {
            assert_eq!(h.ark[lane][15], Felt::ZERO, "ark column {lane} nonzero at row 15");
        }
    }
}