use super::{
ChipletsBus, ExecutionError, Felt, FieldElement, LookupTableRow, StarkField, TraceFragment,
Vec, BITWISE_AND_LABEL, BITWISE_XOR_LABEL,
};
use crate::{utils::get_trace_len, Matrix};
use vm_core::chiplets::bitwise::{
A_COL_IDX, A_COL_RANGE, BITWISE_AND, BITWISE_XOR, B_COL_IDX, B_COL_RANGE, OP_CYCLE_LEN,
OUTPUT_COL_IDX, PREV_OUTPUT_COL_IDX, TRACE_WIDTH,
};
#[cfg(test)]
mod tests;
/// Initial capacity (in rows) reserved for each trace column; columns grow beyond this as
/// operations are executed.
const INIT_TRACE_CAPACITY: usize = 128;
/// Bitwise chiplet which builds the execution trace for u32 AND and XOR operations.
///
/// Each requested operation appends one cycle of rows (one row per 4-bit limb of the operands)
/// to every column of the trace.
pub struct Bitwise {
    // One vector per trace column; all columns always have equal length (one entry per row).
    trace: [Vec<Felt>; TRACE_WIDTH],
}
impl Bitwise {
    /// Returns a new [Bitwise] chiplet with an empty execution trace.
    pub fn new() -> Self {
        let trace = (0..TRACE_WIDTH)
            .map(|_| Vec::with_capacity(INIT_TRACE_CAPACITY))
            .collect::<Vec<_>>()
            .try_into()
            .expect("failed to convert vector to array");
        Self { trace }
    }

    /// Returns the current number of rows in the execution trace.
    pub fn trace_len(&self) -> usize {
        get_trace_len(&self.trace)
    }

    /// Computes bitwise AND of `a` and `b`, appends one operation cycle to the trace, and
    /// returns the result.
    ///
    /// # Errors
    /// Returns an error if either `a` or `b` does not fit into a u32.
    pub fn u32and(&mut self, a: Felt, b: Felt) -> Result<Felt, ExecutionError> {
        self.execute_bitwise_op(BITWISE_AND, a, b, |x, y| x & y)
    }

    /// Computes bitwise XOR of `a` and `b`, appends one operation cycle to the trace, and
    /// returns the result.
    ///
    /// # Errors
    /// Returns an error if either `a` or `b` does not fit into a u32.
    pub fn u32xor(&mut self, a: Felt, b: Felt) -> Result<Felt, ExecutionError> {
        self.execute_bitwise_op(BITWISE_XOR, a, b, |x, y| x ^ y)
    }

    /// Shared implementation of `u32and`/`u32xor`: validates the operands, then processes them
    /// 4 bits at a time from the most-significant limb down, recording one trace row per limb.
    ///
    /// For each limb, the row records the partial result so far (`PREV_OUTPUT_COL_IDX`), the
    /// operand prefixes and their low 4 bits (via `add_bitwise_trace_row`), and the updated
    /// partial result (`OUTPUT_COL_IDX`). After the final limb the partial result equals
    /// `op(a, b)` over the full 32 bits.
    fn execute_bitwise_op(
        &mut self,
        selector: Felt,
        a: Felt,
        b: Felt,
        op: impl Fn(u64, u64) -> u64,
    ) -> Result<Felt, ExecutionError> {
        let a = assert_u32(a)?.as_int();
        let b = assert_u32(b)?.as_int();
        let mut result = 0u64;
        // bit_offset runs 28, 24, ..., 0 — most-significant 4-bit limb first.
        for bit_offset in (0..32).step_by(4).rev() {
            // Record the result accumulated before this limb is processed.
            self.trace[PREV_OUTPUT_COL_IDX].push(Felt::new(result));
            // Shadow the operands with their prefixes down to the current limb.
            let a = a >> bit_offset;
            let b = b >> bit_offset;
            self.add_bitwise_trace_row(selector, a, b);
            // Apply the operation to the low 4 bits and fold them into the accumulator.
            let result_4_bit = op(a, b) & 0xF;
            result = (result << 4) | result_4_bit;
            self.trace[OUTPUT_COL_IDX].push(Felt::new(result));
        }
        Ok(Felt::new(result))
    }

    /// Fills the provided trace fragment with this chiplet's trace and sends one response per
    /// completed operation to the chiplets bus.
    ///
    /// `bitwise_start_row` is the offset of this chiplet's first row within the overall
    /// chiplets trace, used to address bus responses.
    ///
    /// # Panics (debug)
    /// Panics in debug builds if the fragment's dimensions don't match this chiplet's trace.
    pub fn fill_trace(
        self,
        trace: &mut TraceFragment,
        chiplets_bus: &mut ChipletsBus,
        bitwise_start_row: usize,
    ) {
        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
        debug_assert_eq!(TRACE_WIDTH, trace.width(), "inconsistent trace widths");
        // Each operation's final (fully accumulated) values live on the last row of its cycle.
        for row in ((OP_CYCLE_LEN - 1)..self.trace_len()).step_by(OP_CYCLE_LEN) {
            let a = self.trace[A_COL_IDX][row];
            let b = self.trace[B_COL_IDX][row];
            let z = self.trace[OUTPUT_COL_IDX][row];
            // Column 0 holds the operation selector written by `add_bitwise_trace_row`.
            let op_selector: Felt = self.trace[0][row];
            let label = if op_selector == BITWISE_AND {
                BITWISE_AND_LABEL
            } else {
                assert!(
                    op_selector == BITWISE_XOR,
                    "Unrecognized operation selectors in Bitwise chiplet"
                );
                BITWISE_XOR_LABEL
            };
            let lookup = BitwiseLookup::new(label, a, b, z);
            chiplets_bus.provide_bitwise_operation(lookup, (bitwise_start_row + row) as u32);
        }
        // Copy the column data into the destination fragment.
        for (out_column, column) in trace.columns().zip(self.trace) {
            out_column.copy_from_slice(&column);
        }
    }

    /// Appends a row recording the selector, the operand prefixes `a` and `b`, and the binary
    /// decomposition of their low 4 bits (one column per bit, least-significant first).
    fn add_bitwise_trace_row(&mut self, selector: Felt, a: u64, b: u64) {
        self.trace[0].push(selector);
        self.trace[A_COL_IDX].push(Felt::new(a));
        self.trace[B_COL_IDX].push(Felt::new(b));
        self.trace[A_COL_RANGE.start].push(Felt::new(a & 1));
        self.trace[A_COL_RANGE.start + 1].push(Felt::new((a >> 1) & 1));
        self.trace[A_COL_RANGE.start + 2].push(Felt::new((a >> 2) & 1));
        self.trace[A_COL_RANGE.start + 3].push(Felt::new((a >> 3) & 1));
        self.trace[B_COL_RANGE.start].push(Felt::new(b & 1));
        self.trace[B_COL_RANGE.start + 1].push(Felt::new((b >> 1) & 1));
        self.trace[B_COL_RANGE.start + 2].push(Felt::new((b >> 2) & 1));
        self.trace[B_COL_RANGE.start + 3].push(Felt::new((b >> 3) & 1));
    }
}
impl Default for Bitwise {
fn default() -> Self {
Self::new()
}
}
/// Returns `value` unchanged if its integer representation fits into a u32; otherwise returns
/// a `NotU32Value` error wrapping the offending value.
pub fn assert_u32(value: Felt) -> Result<Felt, ExecutionError> {
    if value.as_int() <= u32::MAX as u64 {
        Ok(value)
    } else {
        Err(ExecutionError::NotU32Value(value))
    }
}
/// A single lookup-table row describing one completed bitwise operation, used for
/// communication over the chiplets bus.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct BitwiseLookup {
    // Operation label (BITWISE_AND_LABEL or BITWISE_XOR_LABEL as assigned in `fill_trace`).
    op_id: Felt,
    // First operand.
    a: Felt,
    // Second operand.
    b: Felt,
    // Result of the operation.
    z: Felt,
}
impl BitwiseLookup {
pub fn new(op_id: Felt, a: Felt, b: Felt, z: Felt) -> Self {
Self { op_id, a, b, z }
}
}
impl LookupTableRow for BitwiseLookup {
    /// Reduces this row to a single element in the extension field: a random linear
    /// combination of the row's fields using `alphas[0..5]` as the coefficients.
    fn to_value<E: FieldElement<BaseField = Felt>>(
        &self,
        _main_trace: &Matrix<Felt>,
        alphas: &[E],
    ) -> E {
        let fields = [self.op_id, self.a, self.b, self.z];
        // alphas[1..5] pair with op_id, a, b, z respectively; alphas[0] is the constant term.
        fields
            .iter()
            .zip(alphas[1..5].iter())
            .fold(alphas[0], |acc, (&field, alpha)| acc + alpha.mul_base(field))
    }
}