use alloc::{boxed::Box, string::String, vec::Vec};
use core::{fmt, iter::repeat_n};
use crate::{
Felt, Word, ZERO,
chiplets::hasher,
mast::{MastForest, MastForestError, MastNode, MastNodeId},
operations::Operation,
prettier::PrettyPrint,
serde::Serializable,
utils::{LookupByIdx, bytes_to_packed_u32_elements},
};
mod op_batch;
pub use op_batch::OpBatch;
use op_batch::OpBatchAccumulator;
pub(crate) use op_batch::collect_immediate_placements;
use super::{MastForestContributor, MastNodeExt};
#[cfg(any(test, feature = "arbitrary"))]
pub mod arbitrary;
#[cfg(test)]
mod tests;
/// Maximum number of operations a single operation group may hold; batch validation rejects any
/// group that contains more operations than this.
pub const GROUP_SIZE: usize = 9;
/// Maximum number of groups in an operation batch. Batch validation requires every batch except
/// the last one to contain exactly this many groups.
pub const BATCH_SIZE: usize = 8;
// Compile-time assertion that `BATCH_SIZE` is a power of two: the array length on the right is
// `1` only when the power-of-two test holds, otherwise the types mismatch and compilation fails.
const _: [(); 1] = [(); ((BATCH_SIZE & (BATCH_SIZE - 1)) == 0) as usize];
/// Domain-separation element pushed first into the element sequence that
/// `fingerprint_basic_block_error_codes` hashes.
const ERROR_CODE_FINGERPRINT_DOMAIN: Felt = Felt::new_unchecked(0x2473_0001);
/// A MAST node without children that stores a sequence of operations grouped into batches,
/// together with the digest reported for the node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicBlockNode {
/// The block's operations, partitioned into operation batches.
op_batches: Vec<OpBatch>,
/// Digest returned by `MastNodeExt::digest`. When the block is built from raw operations without
/// an explicit override, this is the hash of the group elements of all batches.
digest: Word,
}
impl BasicBlockNode {
/// Domain value returned by `MastNodeExt::domain` for basic block nodes.
pub const DOMAIN: Felt = ZERO;
}
/// Constructor and raw/padded index translation helpers.
impl BasicBlockNode {
    /// Builds a basic block from a flat list of operations.
    ///
    /// The operations are grouped into batches and the digest is computed from the resulting
    /// op groups.
    ///
    /// # Errors
    /// Returns [`MastForestError::EmptyBasicBlock`] when `operations` is empty.
    #[cfg(any(test, feature = "arbitrary"))]
    pub(crate) fn new(operations: Vec<Operation>) -> Result<Self, MastForestError> {
        match operations.as_slice() {
            [] => Err(MastForestError::EmptyBasicBlock),
            ops => {
                let (op_batches, digest) = batch_and_hash_ops(ops);
                Ok(Self { op_batches, digest })
            },
        }
    }

    /// Translates raw operation indices into padded operation indices.
    ///
    /// Every index in `asm_ops` is shifted forward by the number of padding operations that
    /// precede it in `op_batches`; the payload paired with each index is passed through
    /// untouched and the order of entries is preserved.
    ///
    /// # Panics
    /// Panics if an index is larger than the number of raw operations in `op_batches`.
    pub fn adjust_asm_op_indices<T: Copy>(
        asm_ops: Vec<(usize, T)>,
        op_batches: &[OpBatch],
    ) -> Vec<(usize, T)> {
        let pads_before = RawToPaddedPrefix::new(op_batches);
        let mut adjusted = Vec::with_capacity(asm_ops.len());
        for (raw_idx, id) in asm_ops {
            adjusted.push((raw_idx + pads_before[raw_idx], id));
        }
        adjusted
    }

    /// Translates padded operation indices back into raw operation indices.
    ///
    /// This is the inverse direction of [`Self::adjust_asm_op_indices`]: every index is shifted
    /// backward by the number of padding operations located before it.
    ///
    /// # Panics
    /// Panics if an index is larger than the number of padded operations in `op_batches`.
    pub fn unadjust_asm_op_indices<T: Copy>(
        asm_ops: Vec<(usize, T)>,
        op_batches: &[OpBatch],
    ) -> Vec<(usize, T)> {
        let pads_before = PaddedToRawPrefix::new(op_batches);
        let mut restored = Vec::with_capacity(asm_ops.len());
        for (padded_idx, id) in asm_ops {
            restored.push((padded_idx - pads_before[padded_idx], id));
        }
        restored
    }
}
/// Read-only accessors.
impl BasicBlockNode {
    /// Returns the operation batches of this block.
    pub fn op_batches(&self) -> &[OpBatch] {
        self.op_batches.as_slice()
    }

    /// Returns the number of operation batches in this block.
    pub fn num_op_batches(&self) -> usize {
        self.op_batches().len()
    }

    /// Returns the total number of operation groups in this block.
    ///
    /// Every batch but the last counts as `BATCH_SIZE` groups; the group count of the last batch
    /// is rounded up to the next power of two.
    ///
    /// # Panics
    /// Panics if the block contains no batches.
    pub fn num_op_groups(&self) -> usize {
        let (last_batch, full_batches) = self.op_batches.split_last().expect("no last group");
        full_batches.len() * BATCH_SIZE + last_batch.num_groups().next_power_of_two()
    }

    /// Returns the number of operations stored across all batches (as reported by
    /// `OpBatch::ops`).
    ///
    /// # Panics
    /// Panics if the count does not fit into a `u32`.
    pub fn num_operations(&self) -> u32 {
        let total = self.op_batches.iter().fold(0usize, |count, batch| count + batch.ops().len());
        u32::try_from(total).expect("basic block contains more than 2^32 operations")
    }

    /// Iterates over the operations of every batch, in batch order, as returned by
    /// `OpBatch::ops`.
    pub fn operations(&self) -> impl Iterator<Item = &Operation> {
        self.op_batches.iter().flat_map(|batch| batch.ops())
    }

    /// Iterates over the operations of every batch, in batch order, as returned by
    /// `OpBatch::raw_ops`.
    pub fn raw_operations(&self) -> impl Iterator<Item = &Operation> {
        self.op_batches.iter().flat_map(|batch| batch.raw_ops())
    }

    /// Returns `true` when both blocks yield the same sequence from [`Self::operations`],
    /// regardless of how the operations are split into batches.
    #[cfg(test)]
    pub fn semantic_eq(&self, other: &BasicBlockNode) -> bool {
        self.operations().eq(other.operations())
    }
}
impl BasicBlockNode {
/// Runs every batch-level consistency check on this block and reports the first failure.
///
/// The checks run in a fixed order: group counts and the `indptr` layout are verified first,
/// because the later checks slice `ops` with the offsets stored in `indptr`.
///
/// # Errors
/// Returns a human-readable description of the first violated invariant.
pub fn validate_batch_invariants(&self) -> Result<(), String> {
    self.validate_power_of_two_groups()
        .and_then(|()| self.validate_batch_structure())
        .and_then(|()| self.validate_no_immediate_endings())
        .and_then(|()| self.validate_immediate_commitment())
        .and_then(|()| self.validate_padding_semantics())
}
/// Checks the group count of every batch: each batch except the last must hold exactly
/// `BATCH_SIZE` groups, and the last batch must hold a power-of-two number of groups.
fn validate_power_of_two_groups(&self) -> Result<(), String> {
    let batch_count = self.op_batches.len();
    for (batch_idx, batch) in self.op_batches.iter().enumerate() {
        let num_groups = batch.num_groups();
        let is_last_batch = batch_idx + 1 == batch_count;
        if is_last_batch {
            if !num_groups.is_power_of_two() {
                return Err(format!("Batch {batch_idx}: {num_groups} groups is not power of two"));
            }
        } else if num_groups != BATCH_SIZE {
            return Err(format!(
                "Batch {batch_idx}: {num_groups} groups is not full batch size {BATCH_SIZE}"
            ));
        }
    }
    Ok(())
}
fn validate_no_immediate_endings(&self) -> Result<(), String> {
for (batch_idx, batch) in self.op_batches.iter().enumerate() {
let num_groups = batch.num_groups();
let indptr = batch.indptr();
let ops = batch.ops();
for group_idx in 0..num_groups {
let group_start = indptr[group_idx];
let group_end = indptr[group_idx + 1];
if group_start == group_end {
continue;
}
let group_ops = &ops[group_start..group_end];
let is_last_group = group_idx == num_groups - 1;
if is_last_group {
for (op_idx, op) in group_ops.iter().enumerate() {
if op.imm_value().is_some() {
return Err(format!(
"Batch {batch_idx}, group {group_idx}: operation at index {op_idx} requires immediate value, but this is the last group in batch"
));
}
}
} else {
if let Some(last_op) = group_ops.last()
&& last_op.imm_value().is_some()
{
return Err(format!(
"Batch {batch_idx}, group {group_idx}: ends with operation requiring immediate value"
));
}
}
}
}
Ok(())
}
/// Checks the structural layout of every batch:
///
/// 1. the group count does not exceed `BATCH_SIZE`;
/// 2. the whole `indptr` array is non-decreasing;
/// 3. the final `indptr` entry equals the number of operations in the batch;
/// 4. no group contains more than `GROUP_SIZE` operations.
fn validate_batch_structure(&self) -> Result<(), String> {
    for (batch_idx, batch) in self.op_batches.iter().enumerate() {
        let num_groups = batch.num_groups();
        if num_groups > BATCH_SIZE {
            return Err(format!(
                "Batch {}: num_groups {} exceeds maximum {}",
                batch_idx, num_groups, BATCH_SIZE
            ));
        }

        let indptr = batch.indptr();
        let ops_len = batch.ops().len();

        // Every adjacent pair of offsets is checked, not only those below `num_groups`.
        for (i, pair) in indptr.windows(2).enumerate() {
            let (current, next) = (pair[0], pair[1]);
            if current > next {
                return Err(format!(
                    "Batch {}: indptr[{}] {} > indptr[{}] {} - full array not monotonic (required for serialization)",
                    batch_idx,
                    i,
                    current,
                    i + 1,
                    next
                ));
            }
        }

        let final_offset = indptr[indptr.len() - 1];
        if final_offset != ops_len {
            return Err(format!(
                "Batch {}: final indptr value {} doesn't match ops.len() {}",
                batch_idx, final_offset, ops_len
            ));
        }

        // Monotonicity was established above, so the subtraction cannot underflow.
        for group_idx in 0..num_groups {
            let group_size = indptr[group_idx + 1] - indptr[group_idx];
            if group_size > GROUP_SIZE {
                return Err(format!(
                    "Batch {batch_idx}, group {group_idx}: contains {group_size} operations, exceeds maximum {GROUP_SIZE}"
                ));
            }
        }
    }
    Ok(())
}
/// Checks that the committed group elements of every batch agree with its operations.
///
/// For each non-empty group the opcodes are packed into a single value (`Operation::OP_BITS`
/// bits per operation, first operation in the lowest bits) and compared with the stored group
/// element. The immediate placements reported by `collect_immediate_placements` are then
/// compared with the group elements at the reported positions. Finally, every empty group that
/// was not reported as an immediate slot must be `ZERO`.
fn validate_immediate_commitment(&self) -> Result<(), String> {
    for (batch_idx, batch) in self.op_batches.iter().enumerate() {
        let num_groups = batch.num_groups();
        let indptr = batch.indptr();
        let ops = batch.ops();
        let groups = batch.groups();
        // Marks the groups that were confirmed to hold an immediate value.
        let mut holds_immediate = [false; BATCH_SIZE];

        for group_idx in 0..num_groups {
            let (first_op, end_op) = (indptr[group_idx], indptr[group_idx + 1]);
            if first_op == end_op {
                continue;
            }

            let packed_opcodes =
                ops[first_op..end_op].iter().enumerate().fold(0u64, |packed, (slot, op)| {
                    packed | ((op.op_code() as u64) << (Operation::OP_BITS * slot))
                });
            if groups[group_idx] != Felt::new_unchecked(packed_opcodes) {
                return Err(format!(
                    "Batch {batch_idx}, group {group_idx}: committed opcode group does not match operations"
                ));
            }

            let (placements, _next_group_idx) = collect_immediate_placements(
                ops,
                indptr,
                group_idx,
                BATCH_SIZE,
                Some(num_groups),
            )
            .map_err(|err| format!("Batch {batch_idx}: {err}"))?;
            for (imm_group_idx, imm_value) in placements {
                if groups[imm_group_idx] != imm_value {
                    return Err(format!(
                        "Batch {batch_idx}: push immediate value mismatch at index {imm_group_idx}"
                    ));
                }
                holds_immediate[imm_group_idx] = true;
            }
        }

        // An empty group that is not an immediate slot must not carry any data.
        let stray_group = (0..num_groups).find(|&idx| {
            indptr[idx] == indptr[idx + 1] && !holds_immediate[idx] && groups[idx] != ZERO
        });
        if let Some(group_idx) = stray_group {
            return Err(format!("Batch {batch_idx}, group {group_idx}: empty group must be zero"));
        }
    }
    Ok(())
}
/// Delegates to `OpBatch::validate_padding_semantics` for every batch and prefixes the first
/// reported error with the index of the offending batch.
fn validate_padding_semantics(&self) -> Result<(), String> {
    for (batch_idx, batch) in self.op_batches.iter().enumerate() {
        if let Err(err) = batch.validate_padding_semantics() {
            return Err(format!("Batch {batch_idx}: {err}"));
        }
    }
    Ok(())
}
}
/// Formatting adapters used by the `MastNodeExt` implementation.
impl BasicBlockNode {
    /// Returns a value that renders this node through its [`fmt::Display`] implementation.
    ///
    /// The forest is not consulted; the parameter only keeps the signature uniform with the
    /// trait method that forwards here.
    pub(super) fn to_display<'a>(&'a self, _mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
        // A shared reference to a `Display` type is itself `Display`, so borrowing is enough and
        // avoids deep-cloning every operation batch just to format the node.
        self
    }

    /// Returns a value that renders this node through its [`PrettyPrint`] implementation.
    ///
    /// The forest is not consulted; the parameter only keeps the signature uniform with the
    /// trait method that forwards here.
    pub(super) fn to_pretty_print<'a>(
        &'a self,
        _mast_forest: &'a MastForest,
    ) -> impl PrettyPrint + 'a {
        // NOTE(review): this clones the whole node. Returning `self` would avoid the copy if
        // `PrettyPrint` is implemented for references - confirm before changing.
        self.clone()
    }
}
impl MastNodeExt for BasicBlockNode {
    /// Builder type that can recreate this node.
    type Builder = BasicBlockNodeBuilder;

    /// Returns the digest stored in this node.
    fn digest(&self) -> Word {
        self.digest
    }

    /// Boxes the value produced by the inherent `BasicBlockNode::to_display`.
    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
        // The explicit path selects the inherent method rather than this trait method.
        let display = BasicBlockNode::to_display(self, mast_forest);
        Box::new(display)
    }

    /// Boxes the value produced by the inherent `BasicBlockNode::to_pretty_print`.
    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
        // The explicit path selects the inherent method rather than this trait method.
        let printer = BasicBlockNode::to_pretty_print(self, mast_forest);
        Box::new(printer)
    }

    /// A basic block never has child nodes.
    fn has_children(&self) -> bool {
        false
    }

    /// Appends nothing: a basic block has no children.
    fn append_children_to(&self, _target: &mut Vec<MastNodeId>) {}

    /// Never invokes the callback: a basic block has no children.
    fn for_each_child<F: FnMut(MastNodeId)>(&self, _f: F) {}

    /// Returns [`BasicBlockNode::DOMAIN`].
    fn domain(&self) -> Felt {
        BasicBlockNode::DOMAIN
    }

    /// Converts the node into a builder that reuses its batches and digest unchanged.
    fn to_builder(self, _forest: &MastForest) -> Self::Builder {
        let Self { op_batches, digest } = self;
        BasicBlockNodeBuilder::from_op_batches(op_batches, digest)
    }
}
impl PrettyPrint for BasicBlockNode {
    /// Renders the block as `basic_block <ops> end`.
    ///
    /// Two layouts are built and combined with `|`: one with all operations separated by single
    /// spaces, and one with each operation on its own line, indented by four columns.
    #[rustfmt::skip]
    fn render(&self) -> crate::prettier::Document {
        use crate::prettier::*;

        // Operations joined by single spaces, for the one-line layout.
        let ops_inline = self
            .operations()
            .map(PrettyPrint::render)
            .reduce(|acc, doc| acc + const_text(" ") + doc)
            .unwrap_or_default();
        // Operations joined by line breaks, for the multi-line layout.
        let ops_stacked = self
            .operations()
            .map(PrettyPrint::render)
            .reduce(|acc, doc| acc + nl() + doc)
            .unwrap_or_default();

        let single_line = const_text("basic_block")
            + const_text(" ")
            + ops_inline
            + const_text(" ")
            + const_text("end");
        let multi_line = indent(4, const_text("basic_block") + nl() + ops_stacked)
            + nl()
            + const_text("end");

        single_line | multi_line
    }
}
impl fmt::Display for BasicBlockNode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.pretty_print(f)
}
}
#[derive(Debug, Clone)]
pub struct RawToPaddedPrefix(Vec<usize>);
impl RawToPaddedPrefix {
pub fn new(op_batches: &[OpBatch]) -> Self {
let mut v = Vec::new();
let mut pads_so_far = 0usize;
for b in op_batches {
let n = b.num_groups();
let indptr = b.indptr();
let padding = b.padding();
for g in 0..n {
let group_len = indptr[g + 1] - indptr[g];
let has_pad = padding[g] as usize;
let raw_in_g = group_len - has_pad;
v.extend(repeat_n(pads_so_far, raw_in_g));
pads_so_far += has_pad; }
}
v.push(pads_so_far);
RawToPaddedPrefix(v)
}
}
impl core::ops::Index<usize> for RawToPaddedPrefix {
type Output = usize;
#[inline]
fn index(&self, idx: usize) -> &Self::Output {
&self.0[idx]
}
}
#[derive(Debug, Clone)]
pub struct PaddedToRawPrefix(Vec<usize>);
impl PaddedToRawPrefix {
pub fn new(op_batches: &[OpBatch]) -> Self {
let padded_ops = op_batches
.iter()
.map(|b| {
let n = b.num_groups();
let indptr = b.indptr();
indptr[1..=n]
.iter()
.zip(&indptr[..n])
.map(|(end, start)| end - start)
.sum::<usize>()
})
.sum::<usize>();
let mut v = Vec::with_capacity(padded_ops + 1);
let mut pads_so_far = 0usize;
for b in op_batches {
let n = b.num_groups();
let indptr = b.indptr();
let padding = b.padding();
for g in 0..n {
let group_len = indptr[g + 1] - indptr[g];
let has_pad = padding[g] as usize;
let raw_in_g = group_len - has_pad;
v.extend(repeat_n(pads_so_far, raw_in_g));
if has_pad == 1 {
v.push(pads_so_far);
pads_so_far += 1; }
}
}
v.push(pads_so_far);
PaddedToRawPrefix(v)
}
}
impl core::ops::Index<usize> for PaddedToRawPrefix {
type Output = usize;
#[inline]
fn index(&self, idx: usize) -> &Self::Output {
&self.0[idx]
}
}
fn batch_and_hash_ops(ops: &[Operation]) -> (Vec<OpBatch>, Word) {
let batches = batch_ops(ops);
let op_groups: Vec<Felt> = batches.iter().flat_map(|batch| batch.groups).collect();
let hash = hasher::hash_elements(&op_groups);
(batches, hash)
}
/// Derives a fingerprint that covers both the block digest and the data serialized for its
/// error-code-carrying operations.
///
/// When no such data exists, `block_digest` is returned unchanged. Otherwise the hash of the
/// following element sequence is returned: the fingerprint domain element, the digest elements,
/// the low and high 32 bits of the serialized byte length, and the serialized bytes packed into
/// elements.
fn fingerprint_basic_block_error_codes(block_digest: Word, op_batches: &[OpBatch]) -> Word {
    let payload = serialize_basic_block_error_codes(op_batches);
    if payload.is_empty() {
        return block_digest;
    }

    // The length is committed as two 32-bit halves.
    let payload_len = payload.len() as u64;
    let len_lo = payload_len as u32;
    let len_hi = (payload_len >> 32) as u32;

    // Capacity hint: the fixed header (presumably 1 domain + 4 digest + 2 length elements) plus
    // the packed payload (presumably four bytes per element).
    let mut elements = Vec::with_capacity(7 + payload.len().div_ceil(4));
    elements.push(ERROR_CODE_FINGERPRINT_DOMAIN);
    elements.extend_from_slice(block_digest.as_elements());
    elements.extend([Felt::from_u32(len_lo), Felt::from_u32(len_hi)]);
    elements.extend(bytes_to_packed_u32_elements(&payload));
    hasher::hash_elements(&elements)
}
/// Serializes every `Assert`, `U32assert2` and `MpVerify` operation of the block.
///
/// For each such operation, the raw operation index (as a little-endian `u64`) is written,
/// followed by the operation's own serialization. The result is empty when the block contains
/// none of these operations.
fn serialize_basic_block_error_codes(op_batches: &[OpBatch]) -> Vec<u8> {
    let mut data = Vec::new();
    let raw_ops = op_batches.iter().flat_map(OpBatch::raw_ops);
    for (raw_op_idx, op) in raw_ops.enumerate() {
        let carries_error_code = matches!(
            op,
            Operation::Assert(_) | Operation::U32assert2(_) | Operation::MpVerify(_)
        );
        if !carries_error_code {
            continue;
        }
        let position = raw_op_idx as u64;
        data.extend_from_slice(&position.to_le_bytes());
        op.write_into(&mut data);
    }
    data
}
/// Splits `ops` into operation batches.
///
/// Operations are fed to an accumulator one at a time; whenever the accumulator cannot accept
/// the next operation, it is turned into a batch and a fresh accumulator takes over. A trailing
/// non-empty accumulator becomes the last batch, so empty input yields no batches.
fn batch_ops(ops: &[Operation]) -> Vec<OpBatch> {
    let mut batches: Vec<OpBatch> = Vec::new();
    let mut accumulator = OpBatchAccumulator::new();
    for &op in ops {
        if !accumulator.can_accept_op(op) {
            batches.push(accumulator.into_batch());
            accumulator = OpBatchAccumulator::new();
        }
        accumulator.add_op(op);
    }
    if !accumulator.is_empty() {
        batches.push(accumulator.into_batch());
    }
    batches
}
/// The two forms in which a `BasicBlockNodeBuilder` can hold its operations.
#[derive(Debug)]
enum OperationData {
/// A flat list of operations that is grouped into batches when the builder is consumed.
Raw { operations: Vec<Operation> },
/// Operations that are already grouped into batches and are used as-is.
Batched { op_batches: Vec<OpBatch> },
}
/// Builder for `BasicBlockNode`, created either from raw operations or from existing batches.
#[derive(Debug)]
pub struct BasicBlockNodeBuilder {
/// The operations of the block, raw or already batched.
operation_data: OperationData,
/// Digest to assign to the built node. `None` means the digest is computed from the operations;
/// builders holding batched operations are expected to always carry a digest.
digest: Option<Word>,
}
impl BasicBlockNodeBuilder {
pub fn new(operations: Vec<Operation>) -> Self {
Self {
operation_data: OperationData::Raw { operations },
digest: None,
}
}
pub(crate) fn from_op_batches(op_batches: Vec<OpBatch>, digest: Word) -> Self {
Self {
operation_data: OperationData::Batched { op_batches },
digest: Some(digest),
}
}
#[doc(hidden)]
pub fn from_op_batches_preserving_digest(op_batches: Vec<OpBatch>, digest: Word) -> Self {
Self::from_op_batches(op_batches, digest)
}
pub fn build(self) -> Result<BasicBlockNode, MastForestError> {
let (op_batches, digest) = match self.operation_data {
OperationData::Raw { operations } => {
if operations.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}
let (op_batches, computed_digest) = batch_and_hash_ops(&operations);
let digest = self.digest.unwrap_or(computed_digest);
(op_batches, digest)
},
OperationData::Batched { op_batches } => {
if op_batches.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}
let digest = self.digest.expect("digest must be set for batched operations");
(op_batches, digest)
},
};
Ok(BasicBlockNode { op_batches, digest })
}
pub(in crate::mast) fn add_to_forest_relaxed(
self,
forest: &mut MastForest,
) -> Result<MastNodeId, MastForestError> {
let (op_batches, digest) = match self.operation_data {
OperationData::Raw { operations } => {
if operations.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}
let (op_batches, computed_digest) = batch_and_hash_ops(&operations);
let digest = self.digest.unwrap_or(computed_digest);
(op_batches, digest)
},
OperationData::Batched { op_batches } => {
if op_batches.is_empty() {
return Err(MastForestError::EmptyBasicBlock);
}
let digest = self.digest.expect("digest must be set for batched operations");
(op_batches, digest)
},
};
let node_id = forest
.nodes
.push(MastNode::Block(BasicBlockNode { op_batches, digest }))
.map_err(|_| MastForestError::TooManyNodes)?;
Ok(node_id)
}
}
impl MastForestContributor for BasicBlockNodeBuilder {
    /// Builds the node via `BasicBlockNodeBuilder::build` and appends it to `forest`.
    ///
    /// # Errors
    /// Returns [`MastForestError::EmptyBasicBlock`] if the builder holds no operations or no
    /// batches, and [`MastForestError::TooManyNodes`] if the forest cannot accept another node.
    ///
    /// # Panics
    /// Panics if the builder holds batched operations but no digest.
    fn add_to_forest(self, forest: &mut MastForest) -> Result<MastNodeId, MastForestError> {
        let block = self.build()?;
        forest
            .nodes
            .push(MastNode::Block(block))
            .map_err(|_| MastForestError::TooManyNodes)
    }

    /// Computes the fingerprint the node would have, without consuming the builder.
    ///
    /// The fingerprint combines the node digest (explicit digest if set, computed digest
    /// otherwise) with the data serialized for the block's error-code-carrying operations. The
    /// forest and the hash lookup are not consulted because a basic block has no children.
    ///
    /// # Panics
    /// Panics if the builder holds batched operations but no digest.
    fn fingerprint_for_node(
        &self,
        _forest: &MastForest,
        _hash_by_node_id: &impl LookupByIdx<MastNodeId, Word>,
    ) -> Result<Word, MastForestError> {
        let fingerprint = match &self.operation_data {
            OperationData::Raw { operations } => {
                let (op_batches, computed_digest) = batch_and_hash_ops(operations);
                let digest = self.digest.unwrap_or(computed_digest);
                fingerprint_basic_block_error_codes(digest, &op_batches)
            },
            OperationData::Batched { op_batches } => {
                let digest = self.digest.expect("digest must be set for batched operations");
                // The batches are only read, so they are borrowed instead of cloned.
                fingerprint_basic_block_error_codes(digest, op_batches)
            },
        };
        Ok(fingerprint)
    }

    /// Returns the builder unchanged: a basic block has no children to remap.
    fn remap_children(self, _remapping: &impl LookupByIdx<MastNodeId, MastNodeId>) -> Self {
        self
    }

    /// Sets the digest to assign to the built node, replacing any previously stored digest.
    fn with_digest(mut self, digest: Word) -> Self {
        self.digest = Some(digest);
        self
    }
}
#[cfg(any(test, feature = "arbitrary"))]
impl proptest::prelude::Arbitrary for BasicBlockNodeBuilder {
    type Parameters = arbitrary::BasicBlockNodeParams;
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates builders from operation sequences produced by
    /// `op_non_control_sequence_strategy`, bounded by `params.max_ops_len`.
    fn arbitrary_with(params: Self::Parameters) -> Self::Strategy {
        use proptest::prelude::*;

        use super::arbitrary::op_non_control_sequence_strategy;

        let operations = op_non_control_sequence_strategy(params.max_ops_len);
        operations.prop_map(BasicBlockNodeBuilder::new).boxed()
    }
}