use super::{fmt, hasher, Digest, Felt, FieldElement, Operation, Vec};
use crate::{DecoratorIterator, DecoratorList};
use winter_utils::flatten_slice_elements;
/// Maximum number of operations that can be packed into a single operation group.
pub const GROUP_SIZE: usize = 9;
/// Number of group slots in a single batch; a slot holds either a packed operation group or an
/// immediate value.
pub const BATCH_SIZE: usize = 8;
/// Upper bound on the number of operations in a single batch (every slot holding a full group).
const MAX_OPS_PER_BATCH: usize = GROUP_SIZE * BATCH_SIZE;
/// A non-empty sequence of operations, stored as a list of operation batches.
///
/// Alongside the batches, a span keeps the digest computed over its operation groups and the
/// decorators attached to its operations.
#[derive(Clone, Debug)]
pub struct Span {
/// Operations of the span, partitioned into batches.
op_batches: Vec<OpBatch>,
/// Digest of the span's operation groups, computed once when the span is built.
hash: Digest,
/// Decorators as `(operation index, decorator)` entries; expected to be sorted by operation
/// index (this is only checked in debug builds).
decorators: DecoratorList,
}
impl Span {
    /// Builds a span from `operations` with no decorators attached.
    ///
    /// # Panics
    /// Panics if `operations` is empty.
    pub fn new(operations: Vec<Operation>) -> Self {
        assert!(!operations.is_empty());
        let no_decorators = DecoratorList::new();
        Self::with_decorators(operations, no_decorators)
    }

    /// Builds a span from `operations` and the `decorators` attached to them.
    ///
    /// The operations are split into batches and the span hash is computed from the resulting
    /// operation groups.
    ///
    /// # Panics
    /// Panics if `operations` is empty. In debug builds, also panics if `decorators` is not
    /// sorted by operation index or if its last operation index is not smaller than the number
    /// of operations.
    pub fn with_decorators(operations: Vec<Operation>, decorators: DecoratorList) -> Self {
        assert!(!operations.is_empty());

        // the decorator list is only sanity-checked in debug builds
        #[cfg(debug_assertions)]
        validate_decorators(&operations, &decorators);

        let (op_batches, hash) = batch_ops(operations);
        Self { op_batches, hash, decorators }
    }

    /// Returns the hash of this span.
    pub fn hash(&self) -> Digest {
        self.hash
    }

    /// Returns the operation batches of this span.
    pub fn op_batches(&self) -> &[OpBatch] {
        &self.op_batches
    }

    /// Returns a new span whose operations are this span's operations repeated `num_copies`
    /// times.
    ///
    /// Decorators are repeated as well; the operation index of each repeated decorator is
    /// shifted by the number of operations that precede its copy.
    ///
    /// # Panics
    /// Panics if `num_copies` is zero, since the resulting span would have no operations.
    #[must_use]
    pub fn replicate(&self, num_copies: usize) -> Self {
        let base_ops = self.get_ops();
        let ops_per_copy = base_ops.len();

        let mut all_ops = Vec::with_capacity(ops_per_copy * num_copies);
        let mut all_decorators = DecoratorList::new();

        for copy_idx in 0..num_copies {
            // decorators of this copy point past all operations of the previous copies
            let offset = ops_per_copy * copy_idx;
            for entry in &self.decorators {
                all_decorators.push((offset + entry.0, entry.1.clone()));
            }
            all_ops.extend_from_slice(&base_ops);
        }

        Self::with_decorators(all_ops, all_decorators)
    }

    /// Returns the decorators attached to this span's operations.
    pub fn decorators(&self) -> &DecoratorList {
        &self.decorators
    }

    /// Returns a [DecoratorIterator] over this span's decorators.
    pub fn decorator_iter(&self) -> DecoratorIterator {
        DecoratorIterator::new(&self.decorators)
    }

    /// Collects the operations of all batches, in order, into a single vector.
    fn get_ops(&self) -> Vec<Operation> {
        // a batch can never hold more than MAX_OPS_PER_BATCH operations
        let capacity = self.op_batches.len() * MAX_OPS_PER_BATCH;
        let mut ops = Vec::with_capacity(capacity);
        ops.extend(self.op_batches.iter().flat_map(|batch| batch.ops.iter().copied()));
        ops
    }
}
impl fmt::Display for Span {
    /// Writes the span as `span <op> <op> ... end`, listing the operations of all batches in
    /// order, each preceded by a single space.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("span")?;
        let all_ops = self.op_batches.iter().flat_map(|batch| batch.ops.iter());
        for op in all_ops {
            write!(f, " {op}")?;
        }
        f.write_str(" end")
    }
}
/// A batch of operations together with their encoding into group slots.
#[derive(Clone, Debug)]
pub struct OpBatch {
/// Operations contained in this batch, in the order they were added.
ops: Vec<Operation>,
/// Group slots of the batch: packed opcode groups and immediate values; unused slots are zero.
groups: [Felt; BATCH_SIZE],
/// Number of operations packed into each slot; zero for slots holding an immediate value and
/// for unused slots.
op_counts: [usize; BATCH_SIZE],
/// Number of slots in use, counted from the start of `groups`.
num_groups: usize,
}
impl OpBatch {
/// Returns the operations contained in this batch.
pub fn ops(&self) -> &[Operation] {
&self.ops
}
/// Returns all group slots of this batch, including unused (zero) trailing slots.
pub fn groups(&self) -> &[Felt; BATCH_SIZE] {
&self.groups
}
/// Returns, for every group slot, the number of operations packed into it.
pub fn op_counts(&self) -> &[usize; BATCH_SIZE] {
&self.op_counts
}
/// Returns the number of group slots in use in this batch.
pub fn num_groups(&self) -> usize {
self.num_groups
}
}
/// Incrementally builds a single [OpBatch] from operations added one at a time.
struct OpBatchAccumulator {
/// Operations added so far.
ops: Vec<Operation>,
/// Group slots written so far (finalized opcode groups and immediate values).
groups: [Felt; BATCH_SIZE],
/// Number of operations in each finalized opcode group.
op_counts: [usize; BATCH_SIZE],
/// Opcodes of the group currently being built, packed `Operation::OP_BITS` bits apart.
group: u64,
/// Number of operations already placed in the current group.
op_idx: usize,
/// Slot in `groups` that the current group will be written to.
group_idx: usize,
/// Next free slot in `groups`.
next_group_idx: usize,
}
impl OpBatchAccumulator {
    /// Returns an empty accumulator: the current group occupies slot 0 and slot 1 is the next
    /// free slot.
    pub fn new() -> Self {
        Self {
            ops: Vec::new(),
            groups: [Felt::ZERO; BATCH_SIZE],
            op_counts: [0; BATCH_SIZE],
            group: 0,
            op_idx: 0,
            group_idx: 0,
            next_group_idx: 1,
        }
    }

    /// Returns true if no operations have been added yet.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Returns true if `op` can be added to the batch being built.
    pub fn can_accept_op(&self, op: Operation) -> bool {
        let has_free_slot = self.next_group_idx < BATCH_SIZE;
        match op.imm_value() {
            // the operation fits into the current group; only its immediate needs a slot
            Some(_) if self.op_idx < GROUP_SIZE - 1 => has_free_slot,
            // the operation has to start a new group, so two slots are needed: one for the
            // new group and one for the immediate
            Some(_) => self.next_group_idx + 1 < BATCH_SIZE,
            // either the current group has room, or a new group can be started
            None => self.op_idx < GROUP_SIZE || has_free_slot,
        }
    }

    /// Adds `op` to the batch, starting a new group first if needed.
    ///
    /// Callers are expected to check [Self::can_accept_op] beforehand; no capacity check is made
    /// here, so slot indexing may go out of bounds otherwise.
    pub fn add_op(&mut self, op: Operation) {
        let immediate = op.imm_value();

        // a new group is started when the current one is full, or when an operation carrying an
        // immediate value would otherwise land in the last position of the group
        let group_is_full = self.op_idx == GROUP_SIZE;
        let imm_op_would_be_last = immediate.is_some() && self.op_idx == GROUP_SIZE - 1;
        if group_is_full || imm_op_would_be_last {
            self.finalize_op_group();
        }

        // the immediate value takes the next free slot
        if let Some(value) = immediate {
            self.groups[self.next_group_idx] = value;
            self.next_group_idx += 1;
        }

        // pack the opcode into the current group at the position of the operation
        let shift = Operation::OP_BITS * self.op_idx;
        self.group |= (op.op_code() as u64) << shift;
        self.ops.push(op);
        self.op_idx += 1;
    }

    /// Consumes the accumulator and returns the resulting batch.
    pub fn into_batch(mut self) -> OpBatch {
        // write back the group in progress; op_idx is checked as well because a group whose
        // packed opcodes are all zero is still a group
        let has_pending_group = self.group != 0 || self.op_idx != 0;
        if has_pending_group {
            self.store_current_group();
        }

        let Self { ops, groups, op_counts, next_group_idx, .. } = self;
        OpBatch {
            ops,
            groups,
            op_counts,
            num_groups: next_group_idx,
        }
    }

    /// Writes the current group to its slot and starts a fresh group in the next free slot.
    fn finalize_op_group(&mut self) {
        self.store_current_group();

        self.group_idx = self.next_group_idx;
        self.next_group_idx += 1;

        self.op_idx = 0;
        self.group = 0;
    }

    /// Writes the packed opcodes and operation count of the current group to its slot.
    fn store_current_group(&mut self) {
        self.groups[self.group_idx] = Felt::new(self.group);
        self.op_counts[self.group_idx] = self.op_idx;
    }
}
/// Splits `ops` into operation batches and computes the hash of the resulting operation groups.
///
/// The hash covers the group slots of all batches, with the slots of the last batch truncated
/// to the count reported by [get_span_op_group_count].
///
/// # Panics
/// Panics if `ops` is empty.
fn batch_ops(ops: Vec<Operation>) -> (Vec<OpBatch>, Digest) {
    let mut batches: Vec<OpBatch> = Vec::new();
    let mut accumulator = OpBatchAccumulator::new();

    for op in ops {
        // when the current batch has no room for the operation, close it and start a new one
        if !accumulator.can_accept_op(op) {
            batches.push(accumulator.into_batch());
            accumulator = OpBatchAccumulator::new();
        }
        accumulator.add_op(op);
    }

    // close the batch that was still being built
    if !accumulator.is_empty() {
        batches.push(accumulator.into_batch());
    }

    // hash the group slots of all batches, dropping the unused tail of the last batch
    let batch_groups: Vec<[Felt; BATCH_SIZE]> =
        batches.iter().map(|batch| *batch.groups()).collect();
    let num_op_groups = get_span_op_group_count(&batches);
    let all_groups = flatten_slice_elements(&batch_groups);
    let hash = hasher::hash_elements(&all_groups[..num_op_groups]);

    (batches, hash)
}
/// Returns the number of operation groups in a span made of the given batches.
///
/// Every batch except the last one counts as `BATCH_SIZE` groups; the group count of the last
/// batch is rounded up to the next power of two.
///
/// # Panics
/// Panics if `op_batches` is empty.
pub fn get_span_op_group_count(op_batches: &[OpBatch]) -> usize {
    let (last_batch, full_batches) = op_batches.split_last().expect("no last group");
    let last_batch_groups = last_batch.num_groups().next_power_of_two();
    full_batches.len() * BATCH_SIZE + last_batch_groups
}
/// Checks that `decorators` is consistent with `operations`; an empty list is always accepted.
///
/// # Panics
/// Panics if the operation indexes in `decorators` are not in non-decreasing order, or if the
/// last operation index is not smaller than the number of operations.
#[cfg(debug_assertions)]
fn validate_decorators(operations: &[Operation], decorators: &DecoratorList) {
    if let Some(last) = decorators.last() {
        // every entry must not precede the one before it
        for i in 1..decorators.len() {
            debug_assert!(
                decorators[i - 1].0 <= decorators[i].0,
                "unsorted decorators list"
            );
        }
        // with a sorted list, checking the last index is enough to bound all of them
        debug_assert!(
            last.0 < operations.len(),
            "last op index in decorator list should be less than number of ops"
        );
    }
}
// Tests for splitting operations into batches and groups, and for the resulting span hash.
#[cfg(test)]
mod tests {
use super::{hasher, Felt, FieldElement, Operation, BATCH_SIZE};
/// Walks `batch_ops` through progressively more involved operation sequences and checks the
/// batches, group slots, per-group operation counts, and hash produced for each of them.
#[test]
fn batch_ops() {
// --- one operation: a single group in a single batch ---
let ops = vec![Operation::Add];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(1, batch.num_groups());
let mut batch_groups = [Felt::ZERO; BATCH_SIZE];
batch_groups[0] = build_group(&ops);
assert_eq!(batch_groups, batch.groups);
assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts);
assert_eq!(hasher::hash_elements(&batch_groups[..1]), hash);
// --- two operations: both land in the same group ---
let ops = vec![Operation::Add, Operation::Mul];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(1, batch.num_groups());
let mut batch_groups = [Felt::ZERO; BATCH_SIZE];
batch_groups[0] = build_group(&ops);
assert_eq!(batch_groups, batch.groups);
assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts);
assert_eq!(hasher::hash_elements(&batch_groups[..1]), hash);
// --- an operation with an immediate value: the value takes the next group slot ---
let ops = vec![Operation::Add, Operation::Push(Felt::new(12345678))];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(2, batch.num_groups());
let mut batch_groups = [Felt::ZERO; BATCH_SIZE];
batch_groups[0] = build_group(&ops);
batch_groups[1] = Felt::new(12345678);
assert_eq!(batch_groups, batch.groups);
assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts);
assert_eq!(hasher::hash_elements(&batch_groups[..2]), hash);
// --- seven immediate values fill every remaining slot of the batch ---
let ops = vec![
Operation::Push(Felt::new(1)),
Operation::Push(Felt::new(2)),
Operation::Push(Felt::new(3)),
Operation::Push(Felt::new(4)),
Operation::Push(Felt::new(5)),
Operation::Push(Felt::new(6)),
Operation::Push(Felt::new(7)),
Operation::Add,
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(8, batch.num_groups());
let batch_groups = [
build_group(&ops),
Felt::new(1),
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::new(7),
];
assert_eq!(batch_groups, batch.groups);
assert_eq!([8_usize, 0, 0, 0, 0, 0, 0, 0], batch.op_counts);
assert_eq!(hasher::hash_elements(&batch_groups), hash);
// --- the last push needs a new group plus a slot for its value, but only one slot is
// left, so it moves to a second batch ---
let ops = vec![
Operation::Add,
Operation::Mul,
Operation::Push(Felt::new(1)),
Operation::Push(Felt::new(2)),
Operation::Push(Felt::new(3)),
Operation::Push(Felt::new(4)),
Operation::Push(Felt::new(5)),
Operation::Push(Felt::new(6)),
Operation::Add,
Operation::Push(Felt::new(7)),
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(2, batches.len());
let batch0 = &batches[0];
assert_eq!(ops[..9], batch0.ops);
assert_eq!(7, batch0.num_groups());
let batch0_groups = [
build_group(&ops[..9]),
Felt::new(1),
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
Felt::new(6),
Felt::ZERO,
];
assert_eq!(batch0_groups, batch0.groups);
assert_eq!([9_usize, 0, 0, 0, 0, 0, 0, 0], batch0.op_counts);
let batch1 = &batches[1];
assert_eq!(vec![ops[9]], batch1.ops);
assert_eq!(2, batch1.num_groups());
let mut batch1_groups = [Felt::ZERO; BATCH_SIZE];
batch1_groups[0] = build_group(&[ops[9]]);
batch1_groups[1] = Felt::new(7);
assert_eq!([1_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts);
assert_eq!(batch1_groups, batch1.groups);
// the hash covers all 8 slots of the first batch plus the 2 used slots of the second
let all_groups = [batch0_groups, batch1_groups].concat();
assert_eq!(hasher::hash_elements(&all_groups[..10]), hash);
// --- the first group fills up, so the tenth operation starts a new group placed after
// the two immediate values ---
let ops = vec![
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Push(Felt::new(7)),
Operation::Add,
Operation::Add,
Operation::Push(Felt::new(11)),
Operation::Mul,
Operation::Mul,
Operation::Add,
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(4, batch.num_groups());
let batch_groups = [
build_group(&ops[..9]),
Felt::new(7),
Felt::new(11),
build_group(&ops[9..]),
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
];
assert_eq!([9_usize, 0, 0, 1, 0, 0, 0, 0], batch.op_counts);
assert_eq!(batch_groups, batch.groups);
assert_eq!(hasher::hash_elements(&batch_groups[..4]), hash);
// --- a push that would be the last operation of a group is moved to a new group ---
let ops = vec![
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Add,
Operation::Add,
Operation::Mul,
Operation::Mul,
Operation::Add,
Operation::Push(Felt::new(11)),
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(3, batch.num_groups());
let batch_groups = [
build_group(&ops[..8]),
build_group(&[ops[8]]),
Felt::new(11),
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
];
assert_eq!(batch_groups, batch.groups);
assert_eq!([8_usize, 1, 0, 0, 0, 0, 0, 0], batch.op_counts);
// 3 groups are in use; the hashed group count is rounded up to the next power of two
assert_eq!(hasher::hash_elements(&batch_groups[..4]), hash);
// --- the second push would be last in the group, so it starts a new group that comes
// after the first push's immediate value ---
let ops = vec![
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Add,
Operation::Add,
Operation::Mul,
Operation::Mul,
Operation::Push(Felt::new(1)),
Operation::Push(Felt::new(2)),
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(1, batches.len());
let batch = &batches[0];
assert_eq!(ops, batch.ops);
assert_eq!(4, batch.num_groups());
let batch_groups = [
build_group(&ops[..8]),
Felt::new(1),
build_group(&[ops[8]]),
Felt::new(2),
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
];
assert_eq!(batch_groups, batch.groups);
assert_eq!([8_usize, 0, 1, 0, 0, 0, 0, 0], batch.op_counts);
assert_eq!(hasher::hash_elements(&batch_groups[..4]), hash);
// --- two opcode groups in the first batch; the final push does not fit there and
// moves, together with the operation after it, to a second batch ---
let ops = vec![
Operation::Add,
Operation::Mul,
Operation::Push(Felt::new(1)),
Operation::Push(Felt::new(2)),
Operation::Push(Felt::new(3)),
Operation::Push(Felt::new(4)),
Operation::Push(Felt::new(5)),
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Mul,
Operation::Add,
Operation::Mul,
Operation::Push(Felt::new(6)),
Operation::Pad,
];
let (batches, hash) = super::batch_ops(ops.clone());
assert_eq!(2, batches.len());
let batch0 = &batches[0];
assert_eq!(ops[..17], batch0.ops);
assert_eq!(7, batch0.num_groups());
let batch0_groups = [
build_group(&ops[..9]),
Felt::new(1),
Felt::new(2),
Felt::new(3),
Felt::new(4),
Felt::new(5),
build_group(&ops[9..17]),
Felt::ZERO,
];
assert_eq!(batch0_groups, batch0.groups);
assert_eq!([9_usize, 0, 0, 0, 0, 0, 8, 0], batch0.op_counts);
let batch1 = &batches[1];
assert_eq!(ops[17..], batch1.ops);
assert_eq!(2, batch1.num_groups());
let batch1_groups = [
build_group(&ops[17..]),
Felt::new(6),
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
Felt::ZERO,
];
assert_eq!(batch1_groups, batch1.groups);
assert_eq!([2_usize, 0, 0, 0, 0, 0, 0, 0], batch1.op_counts);
let all_groups = [batch0_groups, batch1_groups].concat();
assert_eq!(hasher::hash_elements(&all_groups[..10]), hash);
}
/// Packs the opcodes of `ops` into a single element, `Operation::OP_BITS` bits per operation,
/// with the first operation in the lowest bits.
fn build_group(ops: &[Operation]) -> Felt {
let mut group = 0u64;
for (i, op) in ops.iter().enumerate() {
group |= (op.op_code() as u64) << (Operation::OP_BITS * i);
}
Felt::new(group)
}
}