use crate::system::ContextId;
use super::{
crypto::MerklePath, utils, ChipletsTrace, ExecutionError, Felt, FieldElement, RangeChecker,
TraceFragment, Word, CHIPLETS_WIDTH, EMPTY_WORD, ONE, ZERO,
};
use alloc::vec::Vec;
use miden_air::trace::chiplets::hasher::{Digest, HasherState};
use vm_core::{code_blocks::OpBatch, Kernel};
mod bitwise;
use bitwise::Bitwise;
mod hasher;
#[cfg(test)]
pub(crate) use hasher::init_state_from_words;
use hasher::Hasher;
mod memory;
use memory::Memory;
mod kernel_rom;
use kernel_rom::KernelRom;
mod aux_trace;
pub(crate) use aux_trace::AuxTraceBuilder;
#[cfg(test)]
mod tests;
pub struct Chiplets {
clk: u32,
hasher: Hasher,
bitwise: Bitwise,
memory: Memory,
kernel_rom: KernelRom,
}
impl Chiplets {
pub fn new(kernel: Kernel) -> Self {
Self {
clk: 0,
hasher: Hasher::default(),
bitwise: Bitwise::default(),
memory: Memory::default(),
kernel_rom: KernelRom::new(kernel),
}
}
pub fn trace_len(&self) -> usize {
self.hasher.trace_len()
+ self.bitwise.trace_len()
+ self.memory.trace_len()
+ self.kernel_rom.trace_len()
+ 1
}
pub fn bitwise_start(&self) -> usize {
self.hasher.trace_len()
}
pub fn memory_start(&self) -> usize {
self.bitwise_start() + self.bitwise.trace_len()
}
pub fn kernel_rom_start(&self) -> usize {
self.memory_start() + self.memory.trace_len()
}
pub fn padding_start(&self) -> usize {
self.kernel_rom_start() + self.kernel_rom.trace_len()
}
pub const fn kernel(&self) -> &Kernel {
self.kernel_rom.kernel()
}
pub fn permute(&mut self, state: HasherState) -> (Felt, HasherState) {
let (addr, return_state) = self.hasher.permute(state);
(addr, return_state)
}
pub fn build_merkle_root(
&mut self,
value: Word,
path: &MerklePath,
index: Felt,
) -> (Felt, Word) {
let (addr, root) = self.hasher.build_merkle_root(value, path, index);
(addr, root)
}
pub fn update_merkle_root(
&mut self,
old_value: Word,
new_value: Word,
path: &MerklePath,
index: Felt,
) -> MerkleRootUpdate {
self.hasher.update_merkle_root(old_value, new_value, path, index)
}
pub fn hash_control_block(
&mut self,
h1: Word,
h2: Word,
domain: Felt,
expected_hash: Digest,
) -> Felt {
let (addr, result) = self.hasher.hash_control_block(h1, h2, domain, expected_hash);
debug_assert_eq!(expected_hash, result.into());
addr
}
pub fn hash_span_block(&mut self, op_batches: &[OpBatch], expected_hash: Digest) -> Felt {
let (addr, result) = self.hasher.hash_span_block(op_batches, expected_hash);
debug_assert_eq!(expected_hash, result.into());
addr
}
pub fn u32and(&mut self, a: Felt, b: Felt) -> Result<Felt, ExecutionError> {
let result = self.bitwise.u32and(a, b)?;
Ok(result)
}
pub fn u32xor(&mut self, a: Felt, b: Felt) -> Result<Felt, ExecutionError> {
let result = self.bitwise.u32xor(a, b)?;
Ok(result)
}
pub fn read_mem(&mut self, ctx: ContextId, addr: u32) -> Word {
self.memory.read(ctx, addr, self.clk)
}
pub fn read_mem_double(&mut self, ctx: ContextId, addr: u32) -> [Word; 2] {
let addr2 = addr + 1;
[self.memory.read(ctx, addr, self.clk), self.memory.read(ctx, addr2, self.clk)]
}
pub fn write_mem(&mut self, ctx: ContextId, addr: u32, word: Word) {
self.memory.write(ctx, addr, self.clk, word);
}
pub fn write_mem_element(&mut self, ctx: ContextId, addr: u32, value: Felt) -> Word {
let old_word = self.memory.get_old_value(ctx, addr);
let new_word = [value, old_word[1], old_word[2], old_word[3]];
self.memory.write(ctx, addr, self.clk, new_word);
old_word
}
pub fn write_mem_double(&mut self, ctx: ContextId, addr: u32, words: [Word; 2]) {
let addr2 = addr + 1;
self.memory.write(ctx, addr, self.clk, words[0]);
self.memory.write(ctx, addr2, self.clk, words[1]);
}
pub fn get_mem_value(&self, ctx: ContextId, addr: u32) -> Option<Word> {
self.memory.get_value(ctx, addr)
}
pub fn get_mem_state_at(&self, ctx: ContextId, clk: u32) -> Vec<(u64, Word)> {
self.memory.get_state_at(ctx, clk)
}
#[cfg(test)]
pub fn get_mem_size(&self) -> usize {
self.memory.size()
}
pub fn access_kernel_proc(&mut self, proc_hash: Digest) -> Result<(), ExecutionError> {
self.kernel_rom.access_proc(proc_hash)?;
Ok(())
}
pub fn advance_clock(&mut self) {
self.clk += 1;
}
pub fn append_range_checks(&self, range_checker: &mut RangeChecker) {
self.memory.append_range_checks(self.memory_start(), range_checker);
}
pub fn into_trace(self, trace_len: usize, num_rand_rows: usize) -> ChipletsTrace {
assert!(self.trace_len() + num_rand_rows <= trace_len, "target trace length too small");
let mut trace = (0..CHIPLETS_WIDTH)
.map(|_| Felt::zeroed_vector(trace_len))
.collect::<Vec<_>>()
.try_into()
.expect("failed to convert vector to array");
self.fill_trace(&mut trace);
ChipletsTrace {
trace,
aux_builder: AuxTraceBuilder::default(),
}
}
fn fill_trace(self, trace: &mut [Vec<Felt>; CHIPLETS_WIDTH]) {
let bitwise_start = self.bitwise_start();
let memory_start = self.memory_start();
let kernel_rom_start = self.kernel_rom_start();
let padding_start = self.padding_start();
let Chiplets {
clk: _,
hasher,
bitwise,
memory,
kernel_rom,
} = self;
trace[0][bitwise_start..].fill(ONE);
trace[1][memory_start..].fill(ONE);
trace[2][kernel_rom_start..].fill(ONE);
trace[3][padding_start..].fill(ONE);
let mut hasher_fragment = TraceFragment::new(CHIPLETS_WIDTH);
let mut bitwise_fragment = TraceFragment::new(CHIPLETS_WIDTH);
let mut memory_fragment = TraceFragment::new(CHIPLETS_WIDTH);
let mut kernel_rom_fragment = TraceFragment::new(CHIPLETS_WIDTH);
for (column_num, column) in trace.iter_mut().enumerate().skip(1) {
match column_num {
1 | 15..=17 => {
hasher_fragment.push_column_slice(column, hasher.trace_len());
}
2 => {
let rest = hasher_fragment.push_column_slice(column, hasher.trace_len());
bitwise_fragment.push_column_slice(rest, bitwise.trace_len());
}
3 | 10..=14 => {
let rest = hasher_fragment.push_column_slice(column, hasher.trace_len());
let rest = bitwise_fragment.push_column_slice(rest, bitwise.trace_len());
memory_fragment.push_column_slice(rest, memory.trace_len());
}
4..=9 => {
let rest = hasher_fragment.push_column_slice(column, hasher.trace_len());
let rest = bitwise_fragment.push_column_slice(rest, bitwise.trace_len());
let rest = memory_fragment.push_column_slice(rest, memory.trace_len());
kernel_rom_fragment.push_column_slice(rest, kernel_rom.trace_len());
}
_ => panic!("invalid column index"),
}
}
hasher.fill_trace(&mut hasher_fragment);
bitwise.fill_trace(&mut bitwise_fragment);
memory.fill_trace(&mut memory_fragment);
kernel_rom.fill_trace(&mut kernel_rom_fragment);
}
}
/// Outcome of a Merkle root update performed by the hasher chiplet (returned by
/// `Chiplets::update_merkle_root`).
///
/// Holds the address reported by the hasher for the computation, and the Merkle roots computed
/// for the old and the new value.
#[derive(Clone, Copy, Debug)]
pub struct MerkleRootUpdate {
    address: Felt,
    old_root: Word,
    new_root: Word,
}

impl MerkleRootUpdate {
    /// Returns the address reported by the hasher for this update.
    pub fn get_address(&self) -> Felt {
        let Self { address, .. } = *self;
        address
    }

    /// Returns the Merkle root computed for the old value.
    pub fn get_old_root(&self) -> Word {
        let Self { old_root, .. } = *self;
        old_root
    }

    /// Returns the Merkle root computed for the new value.
    pub fn get_new_root(&self) -> Word {
        let Self { new_root, .. } = *self;
        new_root
    }
}