use core::ops::Range;
use super::{Felt, FieldElement, LookupTableRow, StarkField};
use crate::Matrix;
use vm_core::{
chiplets::{
hasher::{
CAPACITY_LEN, DIGEST_LEN, DIGEST_RANGE, LINEAR_HASH_LABEL, MP_VERIFY_LABEL,
MR_UPDATE_NEW_LABEL, MR_UPDATE_OLD_LABEL, RATE_LEN, RETURN_HASH_LABEL,
RETURN_STATE_LABEL, STATE_WIDTH,
},
HASHER_RATE_COL_RANGE, HASHER_STATE_COL_RANGE,
},
utils::collections::Vec,
};
/// Number of leading `alphas` consumed when building a lookup's header value
/// (`HasherLookup::get_header_value` reads `alphas[0]` through `alphas[3]`). The alphas that
/// follow are used to fold in hasher state elements.
const NUM_HEADER_ALPHAS: usize = 4;
/// The context in which a hasher lookup is made.
///
/// The context selects the offset added to the operation label in the lookup header, and which
/// part of the hasher state is folded into the lookup value.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum HasherLookupContext {
/// Start of a hash computation: the header label is offset by 16. The value includes either the
/// whole state (linear hash) or one digest-sized word of the rate (Merkle operations).
Start,
/// Absorption of new elements (linear hash only): the header label is offset by 32. The value
/// includes the difference between the rate at the next row and the rate at the current row.
Absorb,
/// Return of a result: the header label is offset by 32. The value includes the whole state or
/// only the digest, depending on the label.
Return,
}
/// A single lookup into the hasher chiplet trace.
///
/// A lookup is reduced to one field element by `LookupTableRow::to_value`, using a header built
/// from `label`, `addr` and `index`, plus values read from the hasher state columns.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct HasherLookup {
// Operation label, expected to be one of the `*_LABEL` constants from
// `vm_core::chiplets::hasher` (validated in `to_value`).
label: u8,
// Hasher address; the corresponding trace row is `addr - 1`.
addr: u32,
// Index value included in the header. For Merkle start lookups, bit 1 of it selects the left
// or right word of the rate.
index: Felt,
// Which stage of the hash computation this lookup refers to.
context: HasherLookupContext,
}
impl HasherLookup {
    /// Creates a new lookup for the operation `label` at hasher address `addr`, with the given
    /// `index` value and lookup `context`.
    pub(super) fn new(label: u8, addr: u32, index: Felt, context: HasherLookupContext) -> Self {
        Self {
            label,
            addr,
            index,
            context,
        }
    }

    /// Returns `addr - 1`, i.e. the 0-based position corresponding to this lookup's 1-based
    /// hasher address (presumably the clock cycle of the lookup — TODO confirm with callers).
    ///
    /// # Panics
    /// Panics if the address is `0`. Previously this wrapped silently to `u32::MAX` in release
    /// builds.
    pub fn cycle(&self) -> u32 {
        self.addr
            .checked_sub(1)
            .expect("hasher address must be non-zero")
    }

    /// Computes the header part of the lookup value:
    /// `alphas[0] + alphas[1] * transition_label + alphas[2] * addr + alphas[3] * index`.
    ///
    /// The transition label is the operation label offset by 16 for `Start` lookups and by 32
    /// for `Absorb`/`Return` lookups.
    ///
    /// # Panics
    /// Panics if `alphas` has fewer than 4 elements.
    fn get_header_value<E: FieldElement<BaseField = Felt>>(&self, alphas: &[E]) -> E {
        // Listed exhaustively (no catch-all) so that adding a context variant forces a decision
        // about its label offset.
        let transition_label = match self.context {
            HasherLookupContext::Start => E::from(self.label) + E::from(16_u8),
            HasherLookupContext::Absorb | HasherLookupContext::Return => {
                E::from(self.label) + E::from(32_u8)
            }
        };

        alphas[0]
            + alphas[1].mul(transition_label)
            + alphas[2].mul(E::from(self.addr))
            + alphas[3].mul_base(self.index)
    }
}
impl LookupTableRow for HasherLookup {
    /// Reduces this lookup to a single field element using `alphas`.
    ///
    /// The first `NUM_HEADER_ALPHAS` alphas build the header (label, address, index). The
    /// following `STATE_WIDTH` alphas weight values read from the hasher state at the row for
    /// `self.addr`. Which values are used depends on context and label:
    /// - `Start` + `LINEAR_HASH_LABEL`: the whole state.
    /// - `Start` + a Merkle label: the left or right `DIGEST_LEN`-sized word of the rate
    ///   (columns `CAPACITY_LEN..STATE_WIDTH`), selected by bit 1 of `self.index`.
    /// - `Absorb` (linear hash only): the rate at the next row minus the rate at this row.
    /// - `Return` + `RETURN_STATE_LABEL`: the whole state.
    /// - `Return` + `RETURN_HASH_LABEL`: the state elements in `DIGEST_RANGE`.
    ///
    /// # Panics
    /// Panics with "unrecognized hash operation" if the label is not valid for the context, or
    /// if `alphas` has fewer than `NUM_HEADER_ALPHAS + STATE_WIDTH` elements.
    fn to_value<E: FieldElement<BaseField = Felt>>(
        &self,
        main_trace: &Matrix<Felt>,
        alphas: &[E],
    ) -> E {
        let header = self.get_header_value(&alphas[..NUM_HEADER_ALPHAS]);
        let state_alphas = &alphas[NUM_HEADER_ALPHAS..(NUM_HEADER_ALPHAS + STATE_WIDTH)];

        // The state-dependent term; added to the header once at the end.
        let payload = match self.context {
            HasherLookupContext::Start if self.label == LINEAR_HASH_LABEL => {
                let full_state = get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH);
                build_value(state_alphas, &full_state)
            }
            HasherLookupContext::Start => {
                let rate = get_hasher_state_at(self.addr, main_trace, CAPACITY_LEN..STATE_WIDTH);
                let is_merkle_label = self.label == MR_UPDATE_OLD_LABEL
                    || self.label == MR_UPDATE_NEW_LABEL
                    || self.label == MP_VERIFY_LABEL;
                assert!(is_merkle_label, "unrecognized hash operation");

                // Both words are reduced with the same alphas, so whichever word the index bit
                // selects contributes in exactly the same way.
                let bit = (self.index.as_int() >> 1) & 1;
                let word_alphas = &state_alphas[DIGEST_RANGE];
                let (left, right) = rate.split_at(DIGEST_LEN);
                let left_value = build_value(word_alphas, left);
                let right_value = build_value(word_alphas, right);
                E::from(1 - bit).mul(left_value) + E::from(bit).mul(right_value)
            }
            HasherLookupContext::Absorb => {
                assert!(self.label == LINEAR_HASH_LABEL, "unrecognized hash operation");
                let (curr_rate, next_rate) = get_adjacent_hasher_rates(self.addr, main_trace);
                let rate_alphas = &state_alphas[CAPACITY_LEN..];
                let next_value = build_value(rate_alphas, &next_rate);
                let curr_value = build_value(rate_alphas, &curr_rate);
                next_value - curr_value
            }
            HasherLookupContext::Return if self.label == RETURN_STATE_LABEL => {
                let full_state = get_hasher_state_at(self.addr, main_trace, 0..STATE_WIDTH);
                build_value(state_alphas, &full_state)
            }
            HasherLookupContext::Return => {
                assert!(self.label == RETURN_HASH_LABEL, "unrecognized hash operation");
                let digest_alphas = &state_alphas[DIGEST_RANGE];
                let digest = get_hasher_state_at(self.addr, main_trace, DIGEST_RANGE);
                build_value(digest_alphas, &digest)
            }
        };

        header + payload
    }
}
/// Computes `sum(alphas[i] * elements[i])`, the linear combination of `elements` weighted by
/// `alphas`.
///
/// Pairs are formed with `zip`, so any surplus elements in the longer slice are ignored. An
/// empty input yields `E::ZERO`.
fn build_value<E: FieldElement<BaseField = Felt>>(alphas: &[E], elements: &[Felt]) -> E {
    alphas
        .iter()
        .zip(elements)
        .fold(E::ZERO, |acc, (&alpha, &element)| acc + alpha.mul_base(element))
}
/// Reads hasher state values at the trace row corresponding to `addr`.
///
/// `col_range` holds offsets relative to `HASHER_STATE_COL_RANGE.start`. The values are
/// returned in the order of `col_range`.
fn get_hasher_state_at(addr: u32, main_trace: &Matrix<Felt>, col_range: Range<usize>) -> Vec<Felt> {
    let row = get_row_from_addr(addr);
    let first_state_col = HASHER_STATE_COL_RANGE.start;

    let mut state = Vec::with_capacity(col_range.len());
    for offset in col_range {
        state.push(main_trace.get(first_state_col + offset, row));
    }
    state
}
/// Reads the hasher rate columns (`HASHER_RATE_COL_RANGE`) at the row for `addr` and at the
/// row immediately after it.
///
/// Returns `(current_rate, next_rate)`, each ordered by column.
///
/// # Panics
/// Panics if the row after `addr`'s row is outside the trace.
fn get_adjacent_hasher_rates(
    addr: u32,
    main_trace: &Matrix<Felt>,
) -> ([Felt; RATE_LEN], [Felt; RATE_LEN]) {
    let row = get_row_from_addr(addr);
    let (mut curr_rate, mut next_rate) = ([Felt::ZERO; RATE_LEN], [Felt::ZERO; RATE_LEN]);

    for (slot, col_idx) in HASHER_RATE_COL_RANGE.enumerate() {
        let values = main_trace.get_column(col_idx);
        curr_rate[slot] = values[row];
        next_rate[slot] = values[row + 1];
    }

    (curr_rate, next_rate)
}
/// Converts a 1-based hasher address into the 0-based index of its trace row
/// (address `1` maps to row `0`).
///
/// # Panics
/// Panics if `addr` is `0`, which has no corresponding row. The previous `addr as usize - 1`
/// hit an opaque overflow panic in debug builds and wrapped to `usize::MAX` in release builds.
fn get_row_from_addr(addr: u32) -> usize {
    (addr as usize)
        .checked_sub(1)
        .expect("hasher address must be non-zero")
}