use alloc::vec::Vec;
use miden_air::trace::chiplets::bitwise::{
A_COL_IDX, A_COL_RANGE, B_COL_IDX, B_COL_RANGE, BITWISE_AND, BITWISE_XOR, OUTPUT_COL_IDX,
PREV_OUTPUT_COL_IDX, TRACE_WIDTH,
};
use crate::{Felt, operation::OperationError, trace::TraceFragment};
#[cfg(test)]
mod tests;
/// Number of rows pre-reserved in every trace column when a new [`Bitwise`] is created.
const INIT_TRACE_CAPACITY: usize = 128;
/// Builder for the bitwise chiplet's execution trace.
///
/// The trace is stored column-major: one `Vec<Felt>` per trace column, `TRACE_WIDTH` columns
/// in total. Every column always holds the same number of rows. Each 32-bit AND/XOR operation
/// performed through this type appends 8 rows, one per 4-bit limb, processed from the most
/// significant limb down.
#[derive(Debug)]
pub struct Bitwise {
    // One vector per trace column; all vectors grow in lockstep, one element per row.
    trace: [Vec<Felt>; TRACE_WIDTH],
}
impl Bitwise {
pub fn new() -> Self {
let trace = (0..TRACE_WIDTH)
.map(|_| Vec::with_capacity(INIT_TRACE_CAPACITY))
.collect::<Vec<_>>()
.try_into()
.expect("failed to convert vector to array");
Self { trace }
}
pub fn trace_len(&self) -> usize {
self.trace[0].len()
}
pub fn u32and(&mut self, a: Felt, b: Felt) -> Result<Felt, OperationError> {
let a = assert_u32(a)? as u64;
let b = assert_u32(b)? as u64;
let mut result = 0u64;
for bit_offset in (0..32).step_by(4).rev() {
self.trace[PREV_OUTPUT_COL_IDX].push(Felt::new(result));
let a = a >> bit_offset;
let b = b >> bit_offset;
self.add_bitwise_trace_row(BITWISE_AND, a, b);
let result_4_bit = (a & b) & 0xf;
result = (result << 4) | result_4_bit;
self.trace[OUTPUT_COL_IDX].push(Felt::new(result));
}
Ok(Felt::new(result))
}
pub fn u32xor(&mut self, a: Felt, b: Felt) -> Result<Felt, OperationError> {
let a = assert_u32(a)? as u64;
let b = assert_u32(b)? as u64;
let mut result = 0u64;
for bit_offset in (0..32).step_by(4).rev() {
self.trace[PREV_OUTPUT_COL_IDX].push(Felt::new(result));
let a = a >> bit_offset;
let b = b >> bit_offset;
self.add_bitwise_trace_row(BITWISE_XOR, a, b);
let result_4_bit = (a ^ b) & 0xf;
result = (result << 4) | result_4_bit;
self.trace[OUTPUT_COL_IDX].push(Felt::new(result));
}
Ok(Felt::new(result))
}
pub fn fill_trace(self, trace: &mut TraceFragment) {
debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
debug_assert_eq!(TRACE_WIDTH, trace.width(), "inconsistent trace widths");
for (out_column, column) in trace.columns().zip(self.trace) {
out_column.copy_from_slice(&column);
}
}
fn add_bitwise_trace_row(&mut self, selector: Felt, a: u64, b: u64) {
self.trace[0].push(selector);
self.trace[A_COL_IDX].push(Felt::new(a));
self.trace[B_COL_IDX].push(Felt::new(b));
self.trace[A_COL_RANGE.start].push(Felt::new(a & 1));
self.trace[A_COL_RANGE.start + 1].push(Felt::new((a >> 1) & 1));
self.trace[A_COL_RANGE.start + 2].push(Felt::new((a >> 2) & 1));
self.trace[A_COL_RANGE.start + 3].push(Felt::new((a >> 3) & 1));
self.trace[B_COL_RANGE.start].push(Felt::new(b & 1));
self.trace[B_COL_RANGE.start + 1].push(Felt::new((b >> 1) & 1));
self.trace[B_COL_RANGE.start + 2].push(Felt::new((b >> 2) & 1));
self.trace[B_COL_RANGE.start + 3].push(Felt::new((b >> 3) & 1));
}
}
impl Default for Bitwise {
fn default() -> Self {
Self::new()
}
}
/// Narrows a field element to a `u32`.
///
/// # Errors
/// Returns [`OperationError::NotU32Values`] carrying `value` when its canonical integer
/// representation is larger than `u32::MAX`.
pub fn assert_u32(value: Felt) -> Result<u32, OperationError> {
    let canonical = value.as_canonical_u64();
    match u32::try_from(canonical) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(OperationError::NotU32Values { values: vec![value] }),
    }
}