use alloc::boxed::Box;
use miden_air::trace::{
chiplets::hasher::STATE_WIDTH,
log_precompile::{STATE_CAP_RANGE, STATE_RATE_0_RANGE, STATE_RATE_1_RANGE},
};
use super::{DOUBLE_WORD_SIZE, WORD_SIZE_FELT};
use crate::{
ContextId, Felt, MemoryError, ONE, RowIndex, Word, ZERO,
errors::{CryptoError, MerklePathVerificationFailedInner, OperationError},
field::{BasedVectorSpace, QuadFelt},
mast::MastForest,
processor::{
AdviceProviderInterface, HasherInterface, MemoryInterface, Processor, StackInterface,
SystemInterface,
},
tracer::{OperationHelperRegisters, Tracer},
};
#[cfg(test)]
mod tests;
/// Applies the hasher permutation to the top twelve stack elements in place.
///
/// The input state is formed positionally: stack elements `0..8` become state elements `0..8`,
/// and stack elements `8..12` become state elements `8..12`. After the permutation, the portions
/// of the output state selected by `STATE_RATE_0_RANGE`, `STATE_RATE_1_RANGE` and
/// `STATE_CAP_RANGE` are written to the stack words at positions 0, 4 and 8 respectively.
///
/// Returns the hasher address reported by the permutation, and records the input/output states
/// with the tracer.
///
/// # Errors
/// Propagates any error returned by the hasher permutation.
#[inline(always)]
pub(super) fn op_hperm<P: Processor, T: Tracer>(
    processor: &mut P,
    tracer: &mut T,
) -> Result<OperationHelperRegisters, OperationError> {
    let first_eight: [Felt; 8] = processor.stack().get_double_word(0);
    let last_four: Word = processor.stack().get_word(8);

    // Lay the twelve stack elements out in state order: the double word first, then the word.
    let input_state: [Felt; STATE_WIDTH] = core::array::from_fn(|slot| match slot {
        0..=7 => first_eight[slot],
        _ => last_four[slot - 8],
    });

    let (addr, output_state) = processor.hasher().permute(input_state)?;

    let rate0: Word =
        output_state[STATE_RATE_0_RANGE].try_into().expect("r0 slice has length 4");
    let rate1: Word =
        output_state[STATE_RATE_1_RANGE].try_into().expect("r1 slice has length 4");
    let capacity: Word =
        output_state[STATE_CAP_RANGE].try_into().expect("cap slice has length 4");

    for (position, word) in [(0, &rate0), (4, &rate1), (8, &capacity)] {
        processor.stack_mut().set_word(position, word);
    }

    tracer.record_hasher_permute(input_state, output_state);
    Ok(OperationHelperRegisters::HPerm { addr })
}
/// Verifies that a node belongs to a Merkle tree with a given root; the stack is not modified.
///
/// Stack layout read by this operation:
/// - word at position 0: the node value;
/// - position 4: the depth of the node;
/// - position 5: the index of the node;
/// - word at position 6: the expected root.
///
/// The Merkle path is fetched from the advice provider and recorded with the tracer before the
/// hasher's `verify_merkle_root` is invoked.
///
/// # Errors
/// - any error from the advice provider's path lookup;
/// - [`OperationError::MerklePathVerificationFailed`] carrying `err_code` and the message that
///   `program` resolves for it, when the hasher invokes the supplied error constructor.
#[inline(always)]
pub(super) fn op_mpverify<P: Processor, T: Tracer>(
    processor: &mut P,
    err_code: Felt,
    program: &MastForest,
    tracer: &mut T,
) -> Result<OperationHelperRegisters, CryptoError> {
    let node = processor.stack().get_word(0);
    let depth = processor.stack().get(4);
    let index = processor.stack().get(5);
    let root = processor.stack().get_word(6);

    let merkle_path = processor.advice_provider().get_merkle_path(root, depth, index)?;
    tracer.record_hasher_build_merkle_root(node, merkle_path.as_ref(), index, root);

    // Deferred error constructor; presumably the hasher only calls it when verification fails,
    // so the error message is not resolved on the success path.
    let verification_error = || {
        let err_msg = program.resolve_error_message(err_code);
        OperationError::MerklePathVerificationFailed {
            inner: Box::new(MerklePathVerificationFailedInner {
                value: node,
                index,
                root,
                err_code,
                err_msg,
            }),
        }
    };

    let addr = processor.hasher().verify_merkle_root(
        root,
        node,
        merkle_path.as_ref(),
        index,
        verification_error,
    )?;

    Ok(OperationHelperRegisters::MerklePath { addr })
}
/// Replaces a node of a Merkle tree and writes the resulting root to the top of the stack.
///
/// Stack layout read by this operation:
/// - word at position 0: the current (old) value of the node;
/// - position 4: the depth of the node;
/// - position 5: the index of the node;
/// - word at position 6: the claimed root of the tree before the update;
/// - word at position 10: the new value for the node.
///
/// The advice provider's `update_merkle_node` is called first and returns an optional Merkle
/// path; when a path is returned, its length must equal `depth`. The hasher's
/// `update_merkle_root` then receives the claimed old root, the old and new values, the path and
/// the index, and returns the new root, which overwrites the stack word at position 0.
///
/// # Errors
/// - any error returned by the advice provider's `update_merkle_node`;
/// - [`OperationError::InvalidMerklePathLength`] if the returned path length differs from
///   `depth`;
/// - [`OperationError::MerklePathVerificationFailed`] (error code `ZERO`, no message) when the
///   hasher invokes the supplied error constructor.
#[inline(always)]
pub(super) fn op_mrupdate<P: Processor, T: Tracer>(
    processor: &mut P,
    tracer: &mut T,
) -> Result<OperationHelperRegisters, CryptoError> {
    let old_value = processor.stack().get_word(0);
    let depth = processor.stack().get(4);
    let index = processor.stack().get(5);
    let claimed_old_root = processor.stack().get_word(6);
    let new_value = processor.stack().get_word(10);

    // NOTE(review): the advice provider is updated before the path-length and root checks
    // below; confirm whether a failure after this point is expected to leave it modified.
    let path = processor.advice_provider_mut().update_merkle_node(
        claimed_old_root,
        depth,
        index,
        new_value,
    )?;

    // Compare without a lossy `as` cast: `depth` is a field element whose canonical value may
    // not fit in `usize` (e.g. on 32-bit targets), and truncating it could make a mismatched
    // path length appear valid. A depth that does not fit in `usize` is always a mismatch.
    if let Some(path) = &path
        && usize::try_from(depth.as_canonical_u64()).ok() != Some(path.len())
    {
        return Err(OperationError::InvalidMerklePathLength { path_len: path.len(), depth }.into());
    }

    let (addr, new_root) = processor.hasher().update_merkle_root(
        claimed_old_root,
        old_value,
        new_value,
        path.as_ref(),
        index,
        || OperationError::MerklePathVerificationFailed {
            inner: Box::new(MerklePathVerificationFailedInner {
                value: old_value,
                index,
                root: claimed_old_root,
                err_code: ZERO,
                err_msg: None,
            }),
        },
    )?;

    tracer.record_hasher_update_merkle_root(
        old_value,
        new_value,
        path.as_ref(),
        index,
        claimed_old_root,
        new_root,
    );

    processor.stack_mut().set_word(0, &new_root);
    Ok(OperationHelperRegisters::MerklePath { addr })
}
#[inline(always)]
pub(super) fn op_horner_eval_base<P: Processor, T: Tracer>(
processor: &mut P,
tracer: &mut T,
) -> Result<OperationHelperRegisters, crate::MemoryError> {
const ALPHA_ADDR_INDEX: usize = 13;
const ACC_LOW_INDEX: usize = 14;
const ACC_HIGH_INDEX: usize = 15;
let clk = processor.system().clock();
let ctx = processor.system().ctx();
let alpha = {
let addr = processor.stack().get(ALPHA_ADDR_INDEX);
let eval_point_0 = processor.memory_mut().read_element(ctx, addr)?;
let eval_point_1 = processor.memory_mut().read_element(ctx, addr + ONE)?;
tracer.record_memory_read_element_pair(
eval_point_0,
addr,
eval_point_1,
addr + ONE,
ctx,
clk,
);
QuadFelt::from_basis_coefficients_fn(|i: usize| [eval_point_0, eval_point_1][i])
};
let coef: [Felt; 8] = processor.stack().get_double_word(0);
let c0 = QuadFelt::from(coef[0]);
let c1 = QuadFelt::from(coef[1]);
let c2 = QuadFelt::from(coef[2]);
let c3 = QuadFelt::from(coef[3]);
let c4 = QuadFelt::from(coef[4]);
let c5 = QuadFelt::from(coef[5]);
let c6 = QuadFelt::from(coef[6]);
let c7 = QuadFelt::from(coef[7]);
let acc_low = processor.stack().get(ACC_LOW_INDEX);
let acc_high = processor.stack().get(ACC_HIGH_INDEX);
let acc = QuadFelt::from_basis_coefficients_fn(|i: usize| [acc_low, acc_high][i]);
let tmp0 = (acc * alpha + c0) * alpha + c1;
let tmp1 = ((tmp0 * alpha + c2) * alpha + c3) * alpha + c4;
let acc_new = ((tmp1 * alpha + c5) * alpha + c6) * alpha + c7;
let acc_new_base_elements = acc_new.as_basis_coefficients_slice();
processor.stack_mut().set(ACC_HIGH_INDEX, acc_new_base_elements[1]);
processor.stack_mut().set(ACC_LOW_INDEX, acc_new_base_elements[0]);
Ok(OperationHelperRegisters::HornerEvalBase { alpha, tmp0, tmp1 })
}
/// Advances a Horner-style accumulator over four extension-field coefficients from the stack.
///
/// Stack layout read by this operation:
/// - positions `2j` and `2j + 1` (for `j` in `0..4`): the low and high components of
///   coefficient `j`;
/// - position 13: the memory address of a word whose elements 0 and 1 are the low and high
///   components of the evaluation point `alpha`, and whose elements 2 and 3 are returned as
///   `k0` and `k1` (they are not used in the arithmetic here);
/// - positions 14 and 15: the low and high components of the accumulator.
///
/// Coefficients are consumed in order `0, 1, 2, 3`, each step computing
/// `acc = coef + alpha * acc`. Only stack positions 14 and 15 are overwritten. The returned
/// helper registers include the intermediate accumulator after the first two coefficients
/// (`acc_tmp`).
///
/// # Errors
/// Returns a [`MemoryError`](crate::MemoryError) if reading the `alpha` word from memory fails.
#[inline(always)]
pub(super) fn op_horner_eval_ext<P: Processor, T: Tracer>(
    processor: &mut P,
    tracer: &mut T,
) -> Result<OperationHelperRegisters, crate::MemoryError> {
    const ALPHA_ADDR_INDEX: usize = 13;
    const ACC_LOW_INDEX: usize = 14;
    const ACC_HIGH_INDEX: usize = 15;

    let clk = processor.system().clock();
    let ctx = processor.system().ctx();

    // Stack pairs (2j, 2j + 1) hold the (low, high) components of extension coefficient j.
    let coef: [QuadFelt; 4] = core::array::from_fn(|j| {
        let lo = processor.stack().get(2 * j);
        let hi = processor.stack().get(2 * j + 1);
        QuadFelt::from_basis_coefficients_fn(|i: usize| [lo, hi][i])
    });

    let (alpha, k0, k1) = {
        let addr = processor.stack().get(ALPHA_ADDR_INDEX);
        let word = processor.memory_mut().read_word(ctx, addr, clk)?;
        // Record with the `ctx`/`clk` captured above (as `op_horner_eval_base` does) rather
        // than re-querying the system interface for the same values.
        tracer.record_memory_read_word(word, addr, ctx, clk);
        (
            QuadFelt::from_basis_coefficients_fn(|i: usize| [word[0], word[1]][i]),
            word[2],
            word[3],
        )
    };

    let acc_low = processor.stack().get(ACC_LOW_INDEX);
    let acc_high = processor.stack().get(ACC_HIGH_INDEX);
    let acc_old = QuadFelt::from_basis_coefficients_fn(|i: usize| [acc_low, acc_high][i]);

    // Fold the first two coefficients, keep the intermediate for the helper registers, then
    // fold the remaining two.
    let acc_tmp = coef.iter().take(2).fold(acc_old, |acc, coef| *coef + alpha * acc);
    let acc_new = coef.iter().skip(2).fold(acc_tmp, |acc, coef| *coef + alpha * acc);

    let acc_new_base_elements = acc_new.as_basis_coefficients_slice();
    processor.stack_mut().set(ACC_HIGH_INDEX, acc_new_base_elements[1]);
    processor.stack_mut().set(ACC_LOW_INDEX, acc_new_base_elements[0]);

    Ok(OperationHelperRegisters::HornerEvalExt { alpha, k0, k1, acc_tmp })
}
/// Permutes a hasher state seeded from the stack and the precompile transcript, and updates both.
///
/// The input state places the stack word at position 0 (`comm`) in the rate-0 range, the stack
/// word at position 4 (`tag`) in the rate-1 range, and the processor's current precompile
/// transcript state in the capacity range. After the permutation:
/// - the output capacity becomes the new precompile transcript state;
/// - the output rate-0, rate-1 and capacity words overwrite stack words 0, 4 and 8.
///
/// Returns the hasher address of the permutation together with the previous transcript state,
/// and records the input/output states with the tracer.
///
/// # Errors
/// Propagates any error returned by the hasher permutation.
#[inline(always)]
pub(super) fn op_log_precompile<P: Processor, T: Tracer>(
    processor: &mut P,
    tracer: &mut T,
) -> Result<OperationHelperRegisters, OperationError> {
    let comm: Word = processor.stack().get_word(0);
    let tag: Word = processor.stack().get_word(4);
    let cap_prev = processor.precompile_transcript_state();

    // Size the state from `STATE_WIDTH` rather than a literal so it always matches the hasher.
    let mut hasher_state: [Felt; STATE_WIDTH] = [ZERO; STATE_WIDTH];
    hasher_state[STATE_RATE_0_RANGE].copy_from_slice(comm.as_slice());
    hasher_state[STATE_RATE_1_RANGE].copy_from_slice(tag.as_slice());
    hasher_state[STATE_CAP_RANGE].copy_from_slice(cap_prev.as_slice());

    let (addr, output_state) = processor.hasher().permute(hasher_state)?;

    // The range constants are used by value throughout this file, so no `.clone()` is needed.
    let r0: Word = output_state[STATE_RATE_0_RANGE].try_into().expect("r0 slice has length 4");
    let r1: Word = output_state[STATE_RATE_1_RANGE].try_into().expect("r1 slice has length 4");
    let cap_next: Word =
        output_state[STATE_CAP_RANGE].try_into().expect("cap_next slice has length 4");

    processor.set_precompile_transcript_state(cap_next);
    processor.stack_mut().set_word(0, &r0);
    processor.stack_mut().set_word(4, &r1);
    processor.stack_mut().set_word(8, &cap_next);

    tracer.record_hasher_permute(hasher_state, output_state);
    Ok(OperationHelperRegisters::LogPrecompile { addr, cap_prev })
}
#[inline(always)]
pub(super) fn op_crypto_stream<P: Processor, T: Tracer>(
processor: &mut P,
tracer: &mut T,
) -> Result<OperationHelperRegisters, crate::MemoryError> {
const SRC_PTR_IDX: usize = 12;
const DST_PTR_IDX: usize = 13;
let ctx = processor.system().ctx();
let clk = processor.system().clock();
let src_addr = processor.stack().get(SRC_PTR_IDX);
let dst_addr = processor.stack().get(DST_PTR_IDX);
validate_dual_word_stream_addrs(src_addr, dst_addr, ctx, clk)?;
let src_addr_word2 = src_addr + WORD_SIZE_FELT;
let plaintext_word1 = processor.memory_mut().read_word(ctx, src_addr, clk)?;
let plaintext_word2 = processor.memory_mut().read_word(ctx, src_addr_word2, clk)?;
let rate: [Felt; 8] = processor.stack().get_double_word(0);
let ciphertext_word1 = [
plaintext_word1[0] + rate[0],
plaintext_word1[1] + rate[1],
plaintext_word1[2] + rate[2],
plaintext_word1[3] + rate[3],
]
.into();
let ciphertext_word2 = [
plaintext_word2[0] + rate[4],
plaintext_word2[1] + rate[5],
plaintext_word2[2] + rate[6],
plaintext_word2[3] + rate[7],
]
.into();
let dst_addr_word2 = dst_addr + WORD_SIZE_FELT;
processor.memory_mut().write_word(ctx, dst_addr, clk, ciphertext_word1)?;
processor.memory_mut().write_word(ctx, dst_addr_word2, clk, ciphertext_word2)?;
tracer.record_crypto_stream(
[plaintext_word1, plaintext_word2],
src_addr,
[ciphertext_word1, ciphertext_word2],
dst_addr,
ctx,
clk,
);
processor.stack_mut().set_word(0, &ciphertext_word1);
processor.stack_mut().set_word(4, &ciphertext_word2);
processor.stack_mut().set(SRC_PTR_IDX, src_addr + DOUBLE_WORD_SIZE);
processor.stack_mut().set(DST_PTR_IDX, dst_addr + DOUBLE_WORD_SIZE);
Ok(OperationHelperRegisters::Empty)
}
/// Validates the two eight-address ranges touched by `op_crypto_stream`.
///
/// The source range is `[src_addr, src_addr + 8)` and the destination range is
/// `[dst_addr, dst_addr + 8)`. Both base addresses must be valid `u32` values whose range end
/// does not overflow `u32`, and the two ranges must be disjoint.
///
/// # Errors
/// - [`MemoryError::AddressOutOfBounds`] with the offending base address if it does not fit in a
///   `u32` or if adding 8 to it overflows. The source address is checked before the destination.
/// - [`MemoryError::IllegalMemoryAccess`] if the ranges overlap. The reported address is
///   `dst_addr` when `dst_addr >= src_addr`, and `dst_addr + 4` (the start of the second
///   destination word) otherwise.
#[inline(always)]
fn validate_dual_word_stream_addrs(
    src_addr: Felt,
    dst_addr: Felt,
    ctx: ContextId,
    clk: RowIndex,
) -> Result<(), MemoryError> {
    // Number of consecutive addresses covered by two words.
    const SPAN_LEN: u32 = 8;
    // Offset of the second word inside a span.
    const SECOND_WORD_OFFSET: u32 = 4;

    // Maps a base address to its half-open `[start, end)` range, failing with
    // `AddressOutOfBounds` (reporting the base address) if the range is not representable.
    let span = |base: Felt| -> Result<(u32, u32), MemoryError> {
        let raw = base.as_canonical_u64();
        let start =
            u32::try_from(raw).map_err(|_| MemoryError::AddressOutOfBounds { addr: raw })?;
        let end = start
            .checked_add(SPAN_LEN)
            .ok_or_else(|| MemoryError::AddressOutOfBounds { addr: raw })?;
        Ok((start, end))
    };

    let (src_start, src_end) = span(src_addr)?;
    let (dst_start, dst_end) = span(dst_addr)?;

    // Two half-open ranges are disjoint iff one ends at or before the other begins.
    if src_start >= dst_end || dst_start >= src_end {
        return Ok(());
    }

    // Inside an overlap `dst_start < src_end` always holds, so comparing start addresses is
    // enough to tell whether the destination begins inside the source range.
    let offending_addr = if dst_start >= src_start {
        dst_start
    } else {
        // `dst_start + SPAN_LEN` did not overflow, so adding the smaller offset cannot either.
        dst_start + SECOND_WORD_OFFSET
    };

    Err(MemoryError::IllegalMemoryAccess {
        ctx,
        addr: offending_addr,
        clk: Felt::from(clk),
    })
}