use core::{
borrow::{Borrow, BorrowMut},
mem::size_of,
};
use super::{
chiplets::columns::{
AceCols, AceEvalCols, AceReadCols, BitwiseCols, ControllerCols, KernelRomCols, MemoryCols,
PermutationCols, borrow_chiplet,
},
decoder::columns::DecoderCols,
range::columns::RangeCols,
stack::columns::StackCols,
system::columns::SystemCols,
};
use crate::trace::{CHIPLETS_WIDTH, TRACE_WIDTH};
/// Typed view of one row of the main execution trace.
///
/// `#[repr(C)]` fixes the field order, so a `[T]` row of `TRACE_WIDTH` elements can be
/// reinterpreted as this struct (see the `Borrow`/`BorrowMut` impls for `[T]`) and an index
/// array can be transmuted into `MAIN_COL_MAP`. Field order therefore *is* the column order:
/// reordering fields moves columns.
#[repr(C)]
pub struct MainCols<T> {
    /// System columns (`clk`, `ctx`, `fn_hash`, ...).
    pub system: SystemCols<T>,
    /// Decoder columns.
    pub decoder: DecoderCols<T>,
    /// Stack columns.
    pub stack: StackCols<T>,
    /// Range columns (`multiplicity`, `value`).
    pub range: RangeCols<T>,
    /// The first `CHIPLETS_WIDTH - 1` chiplet columns. `s_perm` occupies the column right after
    /// them. Crate-private; presumably so callers use the typed chiplet accessors — confirm.
    pub(crate) chiplets: [T; CHIPLETS_WIDTH - 1],
    /// Chiplet selector stored in the last column of the row. `chiplet_selectors` reports it as
    /// the second selector.
    pub s_perm: T,
}
impl<T> MainCols<T> {
pub fn chiplet_selectors(&self) -> [T; 6]
where
T: Copy,
{
[
self.chiplets[0],
self.s_perm,
self.chiplets[1],
self.chiplets[2],
self.chiplets[3],
self.chiplets[4],
]
}
pub fn bitwise(&self) -> &BitwiseCols<T> {
borrow_chiplet(&self.chiplets[2..15])
}
pub fn memory(&self) -> &MemoryCols<T> {
borrow_chiplet(&self.chiplets[3..18])
}
pub fn ace(&self) -> &AceCols<T> {
borrow_chiplet(&self.chiplets[4..])
}
pub fn kernel_rom(&self) -> &KernelRomCols<T> {
borrow_chiplet(&self.chiplets[5..10])
}
pub fn permutation(&self) -> &PermutationCols<T> {
borrow_chiplet(&self.chiplets[1..])
}
pub fn controller(&self) -> &ControllerCols<T> {
borrow_chiplet(&self.chiplets[1..])
}
}
/// Views a main-trace row, a slice of exactly `TRACE_WIDTH` elements, as a [`MainCols`].
///
/// The length and split checks are `debug_assert`s, so they are only enforced in debug builds.
impl<T> Borrow<MainCols<T>> for [T] {
    fn borrow(&self) -> &MainCols<T> {
        debug_assert_eq!(self.len(), TRACE_WIDTH);
        // SAFETY: `MainCols<T>` is `#[repr(C)]` and is assumed to consist solely of `T` fields,
        // directly or via `#[repr(C)]` column structs (TODO confirm for every nested type). It
        // therefore has `T`'s alignment and the layout of `[T; NUM_MAIN_COLS]`, so any
        // initialized `T`s form a valid `MainCols<T>`.
        let (prefix, shorts, suffix) = unsafe { self.align_to::<MainCols<T>>() };
        // `align_to` does not promise a maximal middle slice. Check that the whole row landed in
        // exactly one `MainCols<T>`. In release builds an empty middle slice panics on the index
        // below instead of being caught here.
        debug_assert!(prefix.is_empty() && suffix.is_empty() && shorts.len() == 1);
        &shorts[0]
    }
}
/// Mutably views a main-trace row, a slice of exactly `TRACE_WIDTH` elements, as a
/// [`MainCols`].
///
/// The length and split checks are `debug_assert`s, so they are only enforced in debug builds.
impl<T> BorrowMut<MainCols<T>> for [T] {
    fn borrow_mut(&mut self) -> &mut MainCols<T> {
        debug_assert_eq!(self.len(), TRACE_WIDTH);
        // SAFETY: same layout argument as the `Borrow` impl. `MainCols<T>` is `#[repr(C)]` and
        // assumed to be built only from `T` fields (TODO confirm for nested column structs), so
        // it is layout-compatible with `[T; NUM_MAIN_COLS]`. Any `T` values written through the
        // returned reference are valid slice elements.
        let (prefix, shorts, suffix) = unsafe { self.align_to_mut::<MainCols<T>>() };
        // `align_to_mut` does not promise a maximal middle slice. Check that the row maps onto
        // exactly one `MainCols<T>`.
        debug_assert!(prefix.is_empty() && suffix.is_empty() && shorts.len() == 1);
        &mut shorts[0]
    }
}
/// Builds the identity array `[0, 1, ..., N - 1]`, where every element equals its own index.
///
/// Usable in `const` context; it is used to construct `MAIN_COL_MAP`.
pub const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    let mut idx = N;
    // `const fn` cannot use iterators or `for`, so walk down from the end with `while`.
    while idx > 0 {
        idx -= 1;
        out[idx] = idx;
    }
    out
}
/// Number of columns in a main-trace row described by [`MainCols`].
///
/// Computed as the byte size of `MainCols<u8>`. With one-byte fields and `#[repr(C)]` there is
/// no padding, so bytes equal columns. This assumes every nested column struct is also a
/// `#[repr(C)]` aggregate of `T` (TODO confirm).
pub const NUM_MAIN_COLS: usize = size_of::<MainCols<u8>>();

/// Column-index map: every field holds its own absolute column index within the main-trace
/// row (e.g. `MAIN_COL_MAP.s_perm`).
// NOTE(review): `dead_code` is presumably allowed because the map is only read in some
// configurations (e.g. the tests in this file). Confirm and state the real reason.
#[allow(dead_code)]
pub const MAIN_COL_MAP: MainCols<usize> = {
    // Same check as the module-level `const _` assertion on `NUM_MAIN_COLS`.
    assert!(NUM_MAIN_COLS == TRACE_WIDTH);
    // SAFETY: `transmute` only compiles if `[usize; NUM_MAIN_COLS]` and `MainCols<usize>` have
    // the same size. Under the layout assumption documented on `NUM_MAIN_COLS`, field slot `i`
    // of `MainCols<usize>` overlays array element `i`, which `indices_arr` sets to `i`.
    unsafe { core::mem::transmute(indices_arr::<NUM_MAIN_COLS>()) }
};
// Per-component column counts. Each is the byte size of the column struct instantiated with
// `u8`, i.e. one byte per column, assuming padding-free `#[repr(C)]` layouts.

/// Number of system columns ([`SystemCols`]).
pub const NUM_SYSTEM_COLS: usize = size_of::<SystemCols<u8>>();
/// Number of decoder columns ([`DecoderCols`]).
pub const NUM_DECODER_COLS: usize = size_of::<DecoderCols<u8>>();
/// Number of stack columns ([`StackCols`]).
pub const NUM_STACK_COLS: usize = size_of::<StackCols<u8>>();
/// Number of range columns ([`RangeCols`]).
pub const NUM_RANGE_COLS: usize = size_of::<RangeCols<u8>>();
/// Number of bitwise chiplet columns ([`BitwiseCols`]).
pub const NUM_BITWISE_COLS: usize = size_of::<BitwiseCols<u8>>();
/// Number of memory chiplet columns ([`MemoryCols`]).
pub const NUM_MEMORY_COLS: usize = size_of::<MemoryCols<u8>>();
/// Number of ACE chiplet columns ([`AceCols`]).
pub const NUM_ACE_COLS: usize = size_of::<AceCols<u8>>();
/// Number of columns in [`AceReadCols`].
pub const NUM_ACE_READ_COLS: usize = size_of::<AceReadCols<u8>>();
/// Number of columns in [`AceEvalCols`].
pub const NUM_ACE_EVAL_COLS: usize = size_of::<AceEvalCols<u8>>();
/// Number of kernel ROM chiplet columns ([`KernelRomCols`]).
pub const NUM_KERNEL_ROM_COLS: usize = size_of::<KernelRomCols<u8>>();
// Compile-time layout checks. They pin the total row width and each component's column count,
// so changing a column struct without updating these numbers fails the build.

// Also enforced inside `MAIN_COL_MAP`'s initializer.
const _: () = assert!(NUM_MAIN_COLS == TRACE_WIDTH);
const _: () = assert!(NUM_SYSTEM_COLS == 6);
const _: () = assert!(NUM_DECODER_COLS == 24);
const _: () = assert!(NUM_STACK_COLS == 19);
const _: () = assert!(NUM_RANGE_COLS == 2);
const _: () = assert!(NUM_BITWISE_COLS == 13);
const _: () = assert!(NUM_MEMORY_COLS == 15);
const _: () = assert!(NUM_ACE_COLS == 16);
const _: () = assert!(NUM_ACE_READ_COLS == 4);
const _: () = assert!(NUM_ACE_EVAL_COLS == 4);
const _: () = assert!(NUM_KERNEL_ROM_COLS == 5);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trace::{
        CHIPLETS_OFFSET, CLK_COL_IDX, CTX_COL_IDX, DECODER_TRACE_OFFSET, FN_HASH_OFFSET,
        STACK_TRACE_OFFSET, decoder, range, stack,
    };

    /// Checks `(mapped, expected)` column-index pairs in order, failing on the first mismatch.
    fn assert_pairs(pairs: &[(usize, usize)]) {
        for &(mapped, expected) in pairs {
            assert_eq!(mapped, expected);
        }
    }

    /// System entries of `MAIN_COL_MAP` agree with the `crate::trace` system index constants.
    #[test]
    fn col_map_system() {
        let sys = &MAIN_COL_MAP.system;
        assert_pairs(&[
            (sys.clk, CLK_COL_IDX),
            (sys.ctx, CTX_COL_IDX),
            (sys.fn_hash[0], FN_HASH_OFFSET),
            (sys.fn_hash[3], FN_HASH_OFFSET + 3),
        ]);
    }

    /// Decoder entries equal `DECODER_TRACE_OFFSET` plus the `decoder` index constants.
    #[test]
    fn col_map_decoder() {
        let dec = &MAIN_COL_MAP.decoder;
        let base = DECODER_TRACE_OFFSET;
        assert_pairs(&[
            (dec.addr, base + decoder::ADDR_COL_IDX),
            (dec.op_bits[0], base + decoder::OP_BITS_OFFSET),
            (dec.op_bits[6], base + decoder::OP_BITS_OFFSET + 6),
            (dec.hasher_state[0], base + decoder::HASHER_STATE_OFFSET),
            (dec.in_span, base + decoder::IN_SPAN_COL_IDX),
            (dec.group_count, base + decoder::GROUP_COUNT_COL_IDX),
            (dec.op_index, base + decoder::OP_INDEX_COL_IDX),
            (dec.batch_flags[0], base + decoder::OP_BATCH_FLAGS_OFFSET),
            (dec.extra[0], base + decoder::OP_BITS_EXTRA_COLS_OFFSET),
        ]);
    }

    /// Stack entries equal `STACK_TRACE_OFFSET` plus the `stack` index constants.
    #[test]
    fn col_map_stack() {
        let stk = &MAIN_COL_MAP.stack;
        let base = STACK_TRACE_OFFSET;
        assert_pairs(&[
            (stk.top[0], base + stack::STACK_TOP_OFFSET),
            (stk.top[15], base + 15),
            (stk.b0, base + stack::B0_COL_IDX),
            (stk.b1, base + stack::B1_COL_IDX),
            (stk.h0, base + stack::H0_COL_IDX),
        ]);
    }

    /// Range entries agree with the `range` index constants.
    #[test]
    fn col_map_range() {
        let rng = &MAIN_COL_MAP.range;
        assert_pairs(&[(rng.multiplicity, range::M_COL_IDX), (rng.value, range::V_COL_IDX)]);
    }

    /// Chiplet array entries are contiguous from `CHIPLETS_OFFSET`, with `s_perm` right after.
    #[test]
    fn col_map_chiplets() {
        let map = &MAIN_COL_MAP;
        assert_pairs(&[
            (map.chiplets[0], CHIPLETS_OFFSET),
            (map.chiplets[19], CHIPLETS_OFFSET + 19),
            (map.s_perm, CHIPLETS_OFFSET + 20),
        ]);
    }
}