use alloc::{boxed::Box, vec::Vec};
use core::{
fmt,
iter::{Peekable, repeat_n},
slice::Iter,
};
use miden_crypto::{Felt, Word, ZERO};
use miden_formatting::prettier::PrettyPrint;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
DecoratorList, Operation,
chiplets::hasher,
mast::{DecoratedOpLink, DecoratorId, MastForest, MastForestError, MastNodeId, Remapping},
};
mod op_batch;
pub use op_batch::OpBatch;
use op_batch::OpBatchAccumulator;
use super::{MastNodeErrorContext, MastNodeExt};
#[cfg(any(test, feature = "arbitrary"))]
pub mod arbitrary;
#[cfg(test)]
mod tests;
/// Size parameter of an operation group.
///
/// NOTE(review): presumably the maximum number of operations packed into one group; it is not
/// used in this file, so confirm against `op_batch`.
pub const GROUP_SIZE: usize = 9;
/// Number of operation groups in one [`OpBatch`]; `num_op_groups` counts every non-final batch
/// as exactly this many groups.
pub const BATCH_SIZE: usize = 8;
/// A MAST node holding a straight-line sequence of operations, grouped into [`OpBatch`]es.
///
/// Decorators attached to individual operations are kept in `decorators`. Decorators attached
/// to the node as a whole are kept in `before_enter` / `after_exit`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(all(feature = "arbitrary", test), miden_test_serde_macros::serde_test)]
pub struct BasicBlockNode {
    /// The operations of this block, split into batches by `batch_ops`.
    op_batches: Vec<OpBatch>,
    /// Digest of this node. `new` computes it from the op groups; `new_unsafe` takes it from
    /// the caller as-is.
    digest: Word,
    /// `(op_index, decorator)` pairs, sorted by index.
    ///
    /// The indices are interpreted in the *padded* operation space, i.e. counting padding
    /// operations inserted during batching. `new` converts raw indices into this space, while
    /// `new_unsafe` and `set_decorators` store them verbatim.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    decorators: DecoratorList,
    /// Decorators reported at operation index 0, ahead of any per-operation decorator.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    before_enter: Vec<DecoratorId>,
    /// Decorators reported after the last operation.
    #[cfg_attr(feature = "serde", serde(default, skip_serializing_if = "Vec::is_empty"))]
    after_exit: Vec<DecoratorId>,
}
impl BasicBlockNode {
    /// Domain value of basic block nodes; this is what [`MastNodeExt::domain`] returns.
    pub const DOMAIN: Felt = ZERO;
}
impl BasicBlockNode {
    /// Builds a basic block from `operations` and raw-indexed `decorators`.
    ///
    /// The operations are batched and hashed to produce the node digest. Decorator indices are
    /// given relative to `operations` (raw positions) and are translated into the padded
    /// positions used after batching.
    ///
    /// # Errors
    /// Returns [`MastForestError::EmptyBasicBlock`] when `operations` is empty.
    ///
    /// # Panics
    /// In debug builds, panics if `decorators` is unsorted or references an index greater than
    /// `operations.len()`.
    pub fn new(
        operations: Vec<Operation>,
        decorators: DecoratorList,
    ) -> Result<Self, MastForestError> {
        if operations.is_empty() {
            return Err(MastForestError::EmptyBasicBlock);
        }

        #[cfg(debug_assertions)]
        validate_decorators(operations.len(), &decorators);

        let (op_batches, digest) = batch_and_hash_ops(operations);
        let decorators = BasicBlockNode::adjust_decorators(decorators, &op_batches);

        Ok(Self {
            op_batches,
            digest,
            decorators,
            before_enter: Vec::new(),
            after_exit: Vec::new(),
        })
    }

    /// Shifts every decorator index from raw-operation space into padded-operation space by
    /// adding the number of padding operations that precede it.
    fn adjust_decorators(decorators: DecoratorList, op_batches: &[OpBatch]) -> DecoratorList {
        let pads_before = RawToPaddedPrefix::new(op_batches);
        let mut shifted = Vec::with_capacity(decorators.len());
        for (raw_pos, decorator_id) in decorators {
            shifted.push((raw_pos + pads_before[raw_pos], decorator_id));
        }
        shifted
    }

    /// Builds a basic block from trusted inputs.
    ///
    /// The digest is taken as given (not recomputed), and decorator indices are stored verbatim
    /// (not remapped). Callers must therefore already supply padded indices.
    ///
    /// # Panics
    /// Panics if `operations` is empty.
    pub fn new_unsafe(operations: Vec<Operation>, decorators: DecoratorList, digest: Word) -> Self {
        assert!(!operations.is_empty());
        Self {
            op_batches: batch_ops(operations),
            digest,
            decorators,
            before_enter: Vec::new(),
            after_exit: Vec::new(),
        }
    }

    /// Test helper: registers each raw decorator in `mast_forest`, then builds the block via
    /// [`BasicBlockNode::new`]. It stops at the first decorator that fails to register.
    #[cfg(test)]
    pub fn new_with_raw_decorators(
        operations: Vec<Operation>,
        decorators: Vec<(usize, crate::Decorator)>,
        mast_forest: &mut crate::mast::MastForest,
    ) -> Result<Self, MastForestError> {
        let registered = decorators
            .into_iter()
            .map(|(op_pos, decorator)| Ok((op_pos, mast_forest.add_decorator(decorator)?)))
            .collect::<Result<Vec<_>, MastForestError>>()?;
        Self::new(operations, registered)
    }
}
impl BasicBlockNode {
    /// Returns the operation batches of this block.
    pub fn op_batches(&self) -> &[OpBatch] {
        self.op_batches.as_slice()
    }

    /// Returns how many operation batches this block contains.
    pub fn num_op_batches(&self) -> usize {
        self.op_batches.len()
    }

    /// Returns the number of operation groups in this block.
    ///
    /// Every batch except the last counts as a full `BATCH_SIZE` groups. The last batch
    /// contributes its group count rounded up to the next power of two.
    ///
    /// # Panics
    /// Panics if the block has no batches.
    pub fn num_op_groups(&self) -> usize {
        let (last_batch, leading_batches) = self.op_batches.split_last().expect("no last group");
        leading_batches.len() * BATCH_SIZE + last_batch.num_groups().next_power_of_two()
    }

    /// Returns the number of operations in this block.
    ///
    /// This counts the entries of each batch's `ops()` (compare `raw_operations`, which walks
    /// `raw_ops()` instead).
    ///
    /// # Panics
    /// Panics if the count does not fit in a `u32`.
    pub fn num_operations(&self) -> u32 {
        let total = self.op_batches.iter().fold(0, |acc, batch| acc + batch.ops().len());
        u32::try_from(total).expect("basic block contains more than 2^32 operations")
    }

    /// Iterates over the per-operation decorators only, each paired with the index of the
    /// operation it precedes. `before_enter` / `after_exit` decorators are not included.
    pub fn indexed_decorator_iter(&self) -> DecoratorOpLinkIterator<'_> {
        let total_ops = self.num_operations() as usize;
        DecoratorOpLinkIterator::new(&[], &self.decorators, &[], total_ops)
    }

    /// Iterates over all decorators (`before_enter`, per-operation, `after_exit`) with their
    /// indices translated back to raw operation positions.
    pub fn raw_decorator_iter(&self) -> RawDecoratorOpLinkIterator<'_> {
        let (before, after) = (&self.before_enter, &self.after_exit);
        RawDecoratorOpLinkIterator::new(before, &self.decorators, after, &self.op_batches)
    }

    /// Iterates over every operation, batch by batch, using each batch's `ops()`.
    pub fn operations(&self) -> impl Iterator<Item = &Operation> {
        self.op_batches.iter().flat_map(|op_batch| op_batch.ops())
    }

    /// Iterates over every operation, batch by batch, using each batch's `raw_ops()`.
    pub fn raw_operations(&self) -> impl Iterator<Item = &Operation> {
        self.op_batches.iter().flat_map(|op_batch| op_batch.raw_ops())
    }

    /// Returns `num_operations()` plus the number of per-operation decorators.
    /// `before_enter` / `after_exit` decorators are not counted.
    ///
    /// # Panics
    /// Panics if the sum does not fit in a `u32`.
    pub fn num_operations_and_decorators(&self) -> u32 {
        let total = self.num_operations() as usize + self.decorators.len();
        u32::try_from(total)
            .expect("basic block contains more than 2^32 operations and decorators")
    }

    /// Iterates over operations and decorators in execution order:
    /// 1. the `before_enter` decorators;
    /// 2. each operation, preceded by the decorators recorded at its index;
    /// 3. the `after_exit` decorators.
    pub fn iter(&self) -> impl Iterator<Item = OperationOrDecorator<'_>> {
        OperationOrDecoratorIterator::new(self)
    }
}
impl BasicBlockNode {
    /// Replaces the per-operation decorator list wholesale.
    ///
    /// The indices are stored verbatim. No raw-to-padded adjustment or validation is performed
    /// here, so callers must pass indices in the same space the node already uses.
    pub fn set_decorators(&mut self, decorator_list: DecoratorList) {
        self.decorators = decorator_list;
    }
}
impl MastNodeErrorContext for BasicBlockNode {
    /// Yields every decorator of the block with its operation index, in this order:
    /// 1. `before_enter` decorators, at index 0;
    /// 2. per-operation decorators, at their stored index;
    /// 3. `after_exit` decorators, at `num_operations()`.
    fn decorators(&self) -> impl Iterator<Item = DecoratedOpLink> {
        DecoratorOpLinkIterator::new(
            &self.before_enter,
            &self.decorators,
            &self.after_exit,
            self.num_operations() as usize,
        )
    }
}
impl BasicBlockNode {
    /// Returns a [`fmt::Display`] adapter for this block. Decorator ids are resolved through
    /// `mast_forest` when rendering.
    pub(super) fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> impl fmt::Display + 'a {
        BasicBlockNodePrettyPrint { block_node: self, mast_forest }
    }

    /// Returns a [`PrettyPrint`] adapter for this block. Decorator ids are resolved through
    /// `mast_forest` when rendering.
    pub(super) fn to_pretty_print<'a>(
        &'a self,
        mast_forest: &'a MastForest,
    ) -> impl PrettyPrint + 'a {
        BasicBlockNodePrettyPrint { block_node: self, mast_forest }
    }
}
impl MastNodeExt for BasicBlockNode {
    fn digest(&self) -> Word {
        self.digest
    }

    fn before_enter(&self) -> &[DecoratorId] {
        self.before_enter.as_slice()
    }

    fn after_exit(&self) -> &[DecoratorId] {
        self.after_exit.as_slice()
    }

    fn append_before_enter(&mut self, decorator_ids: &[DecoratorId]) {
        self.before_enter.extend_from_slice(decorator_ids);
    }

    fn append_after_exit(&mut self, decorator_ids: &[DecoratorId]) {
        self.after_exit.extend_from_slice(decorator_ids);
    }

    /// Drops every decorator attached to this block: per-operation, before-enter and
    /// after-exit. The operations and the digest are left untouched.
    fn remove_decorators(&mut self) {
        self.decorators.clear();
        self.before_enter.clear();
        self.after_exit.clear();
    }

    fn to_display<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn fmt::Display + 'a> {
        // Explicit path so this calls the inherent adapter, not this trait method.
        Box::new(BasicBlockNode::to_display(self, mast_forest))
    }

    fn to_pretty_print<'a>(&'a self, mast_forest: &'a MastForest) -> Box<dyn PrettyPrint + 'a> {
        Box::new(BasicBlockNode::to_pretty_print(self, mast_forest))
    }

    /// A basic block has no child nodes, so remapping leaves it unchanged.
    fn remap_children(&self, _remapping: &Remapping) -> Self {
        self.clone()
    }

    fn has_children(&self) -> bool {
        false
    }

    fn append_children_to(&self, _target: &mut Vec<MastNodeId>) {
        // Leaf node: nothing to append.
    }

    fn for_each_child<F>(&self, _f: F)
    where
        F: FnMut(MastNodeId),
    {
        // Leaf node: the callback is never invoked.
    }

    fn domain(&self) -> Felt {
        Self::DOMAIN
    }
}
/// Rendering adapter pairing a block with the forest used to resolve its decorator ids.
struct BasicBlockNodePrettyPrint<'a> {
    /// The block being rendered.
    block_node: &'a BasicBlockNode,
    /// Forest that owns the decorators referenced by `block_node`.
    mast_forest: &'a MastForest,
}
impl BasicBlockNodePrettyPrint<'_> {
    /// Renders one element of the block: either an operation, or a decorator looked up in
    /// the forest.
    fn render_item(&self, item: OperationOrDecorator<'_>) -> crate::prettier::Document {
        match item {
            OperationOrDecorator::Operation(op) => op.render(),
            OperationOrDecorator::Decorator(&decorator_id) => {
                self.mast_forest[decorator_id].render()
            },
        }
    }
}

impl PrettyPrint for BasicBlockNodePrettyPrint<'_> {
    /// Offers two layouts and lets the pretty printer choose between them:
    /// - a single line, `basic_block <items...> end`;
    /// - a multi-line form, with the items indented by 4 under `basic_block` and `end` on its
    ///   own line.
    fn render(&self) -> crate::prettier::Document {
        use crate::prettier::*;

        // Produces a fresh sequence of rendered items; each layout consumes one.
        let rendered_items = || self.block_node.iter().map(|item| self.render_item(item));

        let single_line = const_text("basic_block")
            + const_text(" ")
            + rendered_items()
                .reduce(|acc, doc| acc + const_text(" ") + doc)
                .unwrap_or_default()
            + const_text(" ")
            + const_text("end");

        let multi_line = indent(
            4,
            const_text("basic_block")
                + nl()
                + rendered_items().reduce(|acc, doc| acc + nl() + doc).unwrap_or_default(),
        ) + nl()
            + const_text("end");

        single_line | multi_line
    }
}
/// Formats the block through its [`PrettyPrint`] rendering.
impl fmt::Display for BasicBlockNodePrettyPrint<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use crate::prettier::PrettyPrint;
        self.pretty_print(f)
    }
}
/// Iterator over the `(op_index, decorator)` links of a node, yielded in three segments:
/// 1. `before_enter` decorators, at index 0;
/// 2. per-operation decorators, at their stored index;
/// 3. `after_exit` decorators, at `total_ops`.
pub struct DecoratorOpLinkIterator<'a> {
    /// Remaining `before_enter` decorators.
    before: Peekable<Iter<'a, DecoratorId>>,
    /// Remaining per-operation `(op_index, decorator)` pairs.
    middle: Peekable<Iter<'a, (usize, DecoratorId)>>,
    /// Remaining `after_exit` decorators.
    after: Peekable<Iter<'a, DecoratorId>>,
    /// Index reported for every `after_exit` decorator.
    total_ops: usize,
    /// Segment currently being drained.
    seg: Segment,
}
/// Which part of a node's decorators (or operations) an iterator is currently draining.
/// Iterators advance through the variants in declaration order.
enum Segment {
    /// `before_enter` decorators.
    Before,
    /// Per-operation decorators (and, for `OperationOrDecoratorIterator`, the operations).
    Middle,
    /// `after_exit` decorators.
    After,
    /// Nothing left to yield.
    Done,
}
impl<'a> DecoratorOpLinkIterator<'a> {
    /// Creates an iterator that yields, in order:
    /// 1. `before_enter` decorators, at position 0;
    /// 2. `decorators`, at their recorded positions;
    /// 3. `after_exit` decorators, at `total_operations`.
    pub fn new(
        before_enter: &'a [DecoratorId],
        decorators: &'a DecoratorList,
        after_exit: &'a [DecoratorId],
        total_operations: usize,
    ) -> Self {
        let before = before_enter.iter().peekable();
        let middle = decorators.iter().peekable();
        let after = after_exit.iter().peekable();
        Self {
            before,
            middle,
            after,
            total_ops: total_operations,
            seg: Segment::Before,
        }
    }

    /// Yields the next decorator only if it is attached at position `pos`. Otherwise it
    /// returns `None` without consuming any decorator.
    ///
    /// Exhausted segments are skipped while searching for the next pending decorator, which
    /// advances the internal segment state.
    #[inline]
    pub fn next_filtered(&mut self, pos: usize) -> Option<(usize, DecoratorId)> {
        let is_due = loop {
            match self.seg {
                Segment::Before => match self.before.peek() {
                    Some(_) => break pos == 0,
                    None => self.seg = Segment::Middle,
                },
                Segment::Middle => match self.middle.peek() {
                    Some(&&(op_pos, _)) => break pos == op_pos,
                    None => self.seg = Segment::After,
                },
                Segment::After => match self.after.peek() {
                    Some(_) => break pos == self.total_ops,
                    None => self.seg = Segment::Done,
                },
                Segment::Done => break false,
            }
        };

        if is_due { self.next() } else { None }
    }
}
impl<'a> Iterator for DecoratorOpLinkIterator<'a> {
    type Item = (usize, DecoratorId);

    /// Yields, in order:
    /// 1. `before_enter` decorators, at position 0;
    /// 2. per-operation decorators, at their recorded positions;
    /// 3. `after_exit` decorators, at `total_ops`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.seg {
                Segment::Before => {
                    if let Some(&id) = self.before.next() {
                        return Some((0, id));
                    }
                    self.seg = Segment::Middle;
                },
                Segment::Middle => {
                    if let Some(&(pos, id)) = self.middle.next() {
                        return Some((pos, id));
                    }
                    self.seg = Segment::After;
                },
                Segment::After => {
                    if let Some(&id) = self.after.next() {
                        return Some((self.total_ops, id));
                    }
                    self.seg = Segment::Done;
                },
                Segment::Done => return None,
            }
        }
    }

    /// Reports the exact number of links still to be yielded.
    ///
    /// This override is required because this type implements [`ExactSizeIterator`], whose
    /// contract demands an exact `size_hint`. It also lets `collect` pre-allocate.
    /// `Peekable::len` already counts any element buffered by `peek` (e.g. in `next_filtered`).
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.before.len() + self.middle.len() + self.after.len();
        (remaining, Some(remaining))
    }
}
/// The remaining length is the sum of what is left in the three segments.
impl<'a> ExactSizeIterator for DecoratorOpLinkIterator<'a> {
    #[inline]
    fn len(&self) -> usize {
        // `Peekable::len` includes an element already buffered by `peek`.
        self.before.len() + self.middle.len() + self.after.len()
    }
}
pub struct RawDecoratorOpLinkIterator<'a> {
before: core::slice::Iter<'a, DecoratorId>,
middle: core::slice::Iter<'a, (usize, DecoratorId)>, after: core::slice::Iter<'a, DecoratorId>,
pad2raw: PaddedToRawPrefix, total_raw_ops: usize, seg: Segment,
}
impl<'a> RawDecoratorOpLinkIterator<'a> {
    /// Creates an iterator over the given decorators that reports raw operation indices.
    ///
    /// `decorators` are expected to carry padded indices into `op_batches`, as stored by
    /// [`BasicBlockNode`]. They are translated back to raw indices as they are yielded.
    pub fn new(
        before_enter: &'a [DecoratorId],
        decorators: &'a DecoratorList,
        after_exit: &'a [DecoratorId],
        op_batches: &'a [OpBatch],
    ) -> Self {
        let pad2raw = PaddedToRawPrefix::new(op_batches);
        // `pad2raw` holds one entry per padded op, plus a trailing sentinel equal to the total
        // padding count, so the raw op count follows directly from it. This avoids building a
        // whole `RawToPaddedPrefix` table just to read its length.
        let padded_ops = pad2raw.0.len() - 1;
        let total_raw_ops = padded_ops - pad2raw[padded_ops];
        Self {
            before: before_enter.iter(),
            middle: decorators.iter(),
            after: after_exit.iter(),
            pad2raw,
            total_raw_ops,
            seg: Segment::Before,
        }
    }
}
impl<'a> Iterator for RawDecoratorOpLinkIterator<'a> {
    type Item = (usize, DecoratorId);

    /// Yields, in order:
    /// 1. `before_enter` decorators, at index 0;
    /// 2. per-operation decorators, converted from padded to raw index;
    /// 3. `after_exit` decorators, at `total_raw_ops`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let candidate = match self.seg {
                Segment::Before => self.before.next().map(|&id| (0, id)),
                Segment::Middle => self
                    .middle
                    .next()
                    .map(|&(padded_idx, id)| (padded_idx - self.pad2raw[padded_idx], id)),
                Segment::After => self.after.next().map(|&id| (self.total_raw_ops, id)),
                Segment::Done => return None,
            };

            if let Some(link) = candidate {
                return Some(link);
            }

            // The current segment is exhausted; move on to the next one.
            self.seg = match self.seg {
                Segment::Before => Segment::Middle,
                Segment::Middle => Segment::After,
                Segment::After | Segment::Done => Segment::Done,
            };
        }
    }
}
/// One element yielded by [`BasicBlockNode::iter`]: either an operation or a decorator id.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OperationOrDecorator<'a> {
    /// An operation of the block.
    Operation(&'a Operation),
    /// A decorator id; resolve it through the owning [`MastForest`].
    Decorator(&'a DecoratorId),
}
/// Iterator behind [`BasicBlockNode::iter`]. It walks, in order:
/// 1. the `before_enter` decorators;
/// 2. the operations, interleaved with their per-operation decorators;
/// 3. the `after_exit` decorators.
struct OperationOrDecoratorIterator<'a> {
    /// The node being iterated.
    node: &'a BasicBlockNode,
    /// Remaining `before_enter` decorators.
    before: core::slice::Iter<'a, DecoratorId>,
    /// Remaining `after_exit` decorators.
    after: core::slice::Iter<'a, DecoratorId>,
    /// Index of the batch currently being walked.
    batch_index: usize,
    /// Index of the next operation within the current batch's `ops`.
    op_index_in_batch: usize,
    /// Number of operations yielded so far. It is compared against the indices stored in
    /// `node.decorators` to decide when a decorator is due.
    op_index: usize,
    /// Position in `node.decorators` of the next decorator to yield.
    decorator_list_next_index: usize,
    /// Segment currently being drained.
    seg: Segment,
}
impl<'a> OperationOrDecoratorIterator<'a> {
    /// Creates an iterator positioned before the node's first `before_enter` decorator.
    fn new(node: &'a BasicBlockNode) -> Self {
        let before = node.before_enter().iter();
        let after = node.after_exit().iter();
        Self {
            node,
            before,
            after,
            batch_index: 0,
            op_index_in_batch: 0,
            op_index: 0,
            decorator_list_next_index: 0,
            seg: Segment::Before,
        }
    }

    /// Returns the next per-operation decorator if it is attached at the current operation
    /// index, advancing past it. Otherwise returns `None` and consumes nothing.
    #[inline]
    fn next_decorator_if_due(&mut self) -> Option<OperationOrDecorator<'a>> {
        let (op_idx, deco) = self.node.decorators.get(self.decorator_list_next_index)?;
        if *op_idx != self.op_index {
            return None;
        }
        self.decorator_list_next_index += 1;
        Some(OperationOrDecorator::Decorator(deco))
    }
}
impl<'a> Iterator for OperationOrDecoratorIterator<'a> {
    type Item = OperationOrDecorator<'a>;

    /// Yields, in order:
    /// 1. every `before_enter` decorator;
    /// 2. for each operation, the decorators recorded at its index, then the operation itself;
    /// 3. any decorators recorded after the last operation;
    /// 4. every `after_exit` decorator.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.seg {
                Segment::Before => match self.before.next() {
                    Some(id) => return Some(OperationOrDecorator::Decorator(id)),
                    None => self.seg = Segment::Middle,
                },
                Segment::Middle => {
                    if let Some(decorator) = self.next_decorator_if_due() {
                        return Some(decorator);
                    }
                    let Some(batch) = self.node.op_batches.get(self.batch_index) else {
                        self.seg = Segment::After;
                        continue;
                    };
                    match batch.ops.get(self.op_index_in_batch) {
                        Some(op) => {
                            self.op_index_in_batch += 1;
                            self.op_index += 1;
                            return Some(OperationOrDecorator::Operation(op));
                        },
                        None => {
                            // Current batch exhausted; continue with the next one.
                            self.batch_index += 1;
                            self.op_index_in_batch = 0;
                        },
                    }
                },
                Segment::After => match self.after.next() {
                    Some(id) => return Some(OperationOrDecorator::Decorator(id)),
                    None => self.seg = Segment::Done,
                },
                Segment::Done => return None,
            }
        }
    }
}
/// Debug-only sanity check for a raw decorator list.
///
/// It asserts two things:
/// - `decorators` is sorted by operation index;
/// - no index exceeds `operations_len` (an index equal to it is allowed).
#[cfg(debug_assertions)]
pub(crate) fn validate_decorators(operations_len: usize, decorators: &DecoratorList) {
    if decorators.is_empty() {
        return;
    }

    for pair in decorators.windows(2) {
        debug_assert!(pair[1].0 >= pair[0].0, "unsorted decorators list");
    }

    let (last_op_idx, _) = decorators.last().expect("empty decorators list");
    debug_assert!(
        operations_len >= *last_op_idx,
        "last op index in decorator list should be less than or equal to the number of ops"
    );
}
/// Prefix table indexed by raw operation index. Each entry is the number of padding operations
/// that precede that raw op in padded order.
///
/// A trailing sentinel entry holds the total padding count.
#[derive(Debug, Clone)]
pub struct RawToPaddedPrefix(Vec<usize>);
impl RawToPaddedPrefix {
pub fn new(op_batches: &[OpBatch]) -> Self {
let mut v = Vec::new();
let mut pads_so_far = 0usize;
for b in op_batches {
let n = b.num_groups();
let indptr = b.indptr();
let padding = b.padding();
for g in 0..n {
let group_len = indptr[g + 1] - indptr[g];
let has_pad = padding[g] as usize;
let raw_in_g = group_len - has_pad;
v.extend(repeat_n(pads_so_far, raw_in_g));
pads_so_far += has_pad; }
}
v.push(pads_so_far);
RawToPaddedPrefix(v)
}
#[inline]
pub fn raw_ops(&self) -> usize {
self.0.len() - 1
}
}
/// Returns the number of padding ops preceding raw op `idx`.
///
/// Panics if `idx` is past the sentinel entry, because it delegates to `Vec` indexing.
impl core::ops::Index<usize> for RawToPaddedPrefix {
    type Output = usize;
    #[inline]
    fn index(&self, idx: usize) -> &Self::Output {
        &self.0[idx]
    }
}
/// Prefix table indexed by padded operation index. Each entry is the number of padding
/// operations strictly before that position.
///
/// A trailing sentinel entry holds the total padding count.
#[derive(Debug, Clone)]
pub struct PaddedToRawPrefix(Vec<usize>);
impl PaddedToRawPrefix {
    /// Builds the padded-to-raw prefix table for `op_batches`.
    ///
    /// Entry `p` is the number of padding ops strictly before padded position `p`, so the raw
    /// index of that position is `p - self[p]`. For a padding op itself, this is the index of
    /// the next raw op. The final entry, at index `padded_ops`, holds the total number of
    /// padding ops, so a position just past the last op also maps correctly.
    pub fn new(op_batches: &[OpBatch]) -> Self {
        // Per-group lengths `indptr[g + 1] - indptr[g]` telescope, so a batch's padded op
        // count is simply `indptr[n] - indptr[0]`.
        let padded_ops: usize = op_batches
            .iter()
            .map(|b| {
                let indptr = b.indptr();
                indptr[b.num_groups()] - indptr[0]
            })
            .sum();

        // One entry per padded op plus the trailing sentinel.
        let mut v = Vec::with_capacity(padded_ops + 1);
        let mut pads_so_far = 0usize;
        for b in op_batches {
            let n = b.num_groups();
            let indptr = b.indptr();
            let padding = b.padding();
            for g in 0..n {
                let group_len = indptr[g + 1] - indptr[g];
                let has_pad = padding[g] as usize;
                let raw_in_g = group_len - has_pad;
                v.extend(repeat_n(pads_so_far, raw_in_g));
                if has_pad == 1 {
                    // The padding op itself counts only the pads strictly before it.
                    v.push(pads_so_far);
                    pads_so_far += 1;
                }
            }
        }
        v.push(pads_so_far);
        PaddedToRawPrefix(v)
    }
}
/// Returns the number of padding ops strictly before padded position `idx`.
///
/// Panics if `idx` is past the sentinel entry, because it delegates to `Vec` indexing.
impl core::ops::Index<usize> for PaddedToRawPrefix {
    type Output = usize;
    #[inline]
    fn index(&self, idx: usize) -> &Self::Output {
        &self.0[idx]
    }
}
/// Splits `ops` into batches and computes the block digest. The digest is the
/// `hasher::hash_elements` hash of the concatenation of every batch's op groups.
fn batch_and_hash_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, Word) {
    let batches = batch_ops(ops);

    let mut group_elements: Vec<Felt> = Vec::new();
    for batch in &batches {
        group_elements.extend(batch.groups);
    }

    let digest = hasher::hash_elements(&group_elements);
    (batches, digest)
}
/// Packs `ops` into batches greedily, in order.
///
/// A new batch is started whenever the current accumulator reports it cannot accept the next
/// operation. A trailing non-empty accumulator becomes the final batch.
fn batch_ops(ops: Vec<Operation>) -> Vec<OpBatch> {
    let mut batches = Vec::new();
    let mut accumulator = OpBatchAccumulator::new();

    for op in ops {
        if !accumulator.can_accept_op(op) {
            // Seal the full batch and continue with a fresh accumulator.
            let full = core::mem::replace(&mut accumulator, OpBatchAccumulator::new());
            batches.push(full.into_batch());
        }
        accumulator.add_op(op);
    }

    if !accumulator.is_empty() {
        batches.push(accumulator.into_batch());
    }

    batches
}