use core::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use zkboo::{
backend::{Backend, Frontend},
crypto::{Digest, Seed},
memory::{FlexibleMemoryManager, MemoryManager, RefCount},
word::{ByWordType, CompositeWord, Shape, Word, WordIdx},
};
/// Estimated memory footprint, split into stack-resident and heap-resident
/// byte counts.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MemoryUsage {
    pub stack: usize,
    pub heap: usize,
}

impl MemoryUsage {
    /// Combined stack + heap byte count.
    pub fn total(&self) -> usize {
        self.stack + self.heap
    }

    /// Estimates the footprint of `num_words` words of state: three `usize`s
    /// of bookkeeping per word type on the stack, plus the packed word
    /// payloads (bit width / 8 bytes per word) on the heap.
    pub fn from_num_words(num_words: Shape) -> MemoryUsage {
        let usize_num_bytes = core::mem::size_of::<usize>();
        let num_word_types = num_words.map(|_| 1).sum();
        MemoryUsage {
            stack: 3 * usize_num_bytes * num_word_types,
            heap: num_words.map_with_width(|w, n| w / 8 * n).sum(),
        }
    }
}

impl Add<Self> for MemoryUsage {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        Self {
            stack: self.stack + rhs.stack,
            heap: self.heap + rhs.heap,
        }
    }
}

impl AddAssign<Self> for MemoryUsage {
    fn add_assign(&mut self, rhs: Self) {
        // `MemoryUsage` is `Copy`, so delegate to the binary operator.
        *self = *self + rhs;
    }
}

impl Sub<Self> for MemoryUsage {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        Self {
            stack: self.stack - rhs.stack,
            heap: self.heap - rhs.heap,
        }
    }
}

impl SubAssign<Self> for MemoryUsage {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}

impl Mul<usize> for MemoryUsage {
    type Output = Self;
    fn mul(self, rhs: usize) -> Self {
        Self {
            stack: self.stack * rhs,
            heap: self.heap * rhs,
        }
    }
}

impl MulAssign<usize> for MemoryUsage {
    fn mul_assign(&mut self, rhs: usize) {
        *self = *self * rhs;
    }
}
/// Per-gate-kind tallies recorded while profiling a circuit.
///
/// Every field is a `Shape`: a count broken down by word type/width. Counts
/// are incremented by `N` words per composite-word gate (see the
/// `Backend` impl on `ProfilingBackend`).
#[derive(Debug, Clone, Copy, Default)]
pub struct GateCounts {
    /// Words consumed by input gates.
    pub input: Shape,
    /// Words allocated without an initial value.
    pub alloc: Shape,
    /// Constant-load gates.
    pub constant: Shape,
    /// Word-composition gates (little-endian word array -> composite).
    pub from_le_words: Shape,
    /// Word-decomposition gates (composite -> little-endian word array).
    pub to_le_words: Shape,
    /// Words exposed as circuit outputs.
    pub output: Shape,
    pub not: Shape,
    pub bitxor: Shape,
    /// Nonlinear AND gates (these drive the AND-message sizes below).
    pub bitand: Shape,
    pub bitxor_const: Shape,
    pub bitand_const: Shape,
    pub unbounded_shl: Shape,
    pub unbounded_shr: Shape,
    pub rotate_left: Shape,
    pub rotate_right: Shape,
    pub reverse_bits: Shape,
    pub swap_bytes: Shape,
    /// Casts, indexed by both source (outer) and target (inner) word type.
    pub cast: ByWordType<Shape>,
    /// Carry-propagation gates (nonlinear, like `bitand`).
    pub carry: Shape,
}
/// Aggregate statistics collected by `ProfilingBackend` over one circuit run;
/// returned from `Backend::finalize`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ProfilingData {
    // Gates executed, by gate kind and word type.
    gate_counts: GateCounts,
    // Size value most recently reported by the memory manager's allocator
    // (presumably monotone, i.e. a high-water mark — TODO confirm against
    // `FlexibleMemoryManager::alloc`).
    state_size: Shape,
    // Peak number of simultaneously live word references.
    max_live_wordrefs: Shape,
    // Peak sum of all outstanding reference counts.
    max_cumulative_refcount: Shape,
    // Largest refcount observed on any single word.
    max_refcount: Shape,
}
/// Sizes (in words, per word type) of the pieces making up a single proof
/// response.
#[derive(Debug, Clone, Copy, Default)]
pub struct ResponseData {
    and_msg_size: Shape,
    input_share_size: Shape,
}

impl ResponseData {
    /// AND-gate message count, per word type.
    pub fn and_msg_size(&self) -> Shape {
        self.and_msg_size
    }

    /// Input-share word count, per word type.
    pub fn input_share_size(&self) -> Shape {
        self.input_share_size
    }

    /// Estimated footprint of one response: a `u8`, two `S` seeds and a `D`
    /// digest on the stack, plus the packed AND messages and input share on
    /// the heap.
    pub fn mem_usage<D: Digest, S: Seed>(&self) -> MemoryUsage {
        let stack = core::mem::size_of::<u8>()
            + 2 * core::mem::size_of::<S>()
            + core::mem::size_of::<D>();
        let heap = self.and_msg_size.map_with_width(|w, n| w / 8 * n).sum()
            + self.input_share_size.map_with_width(|w, n| w / 8 * n).sum();
        MemoryUsage { stack, heap }
    }
}
/// Sizes (in words, per word type) of the committed party views.
#[derive(Debug, Clone, Copy, Default)]
pub struct ViewsData {
    pub and_msgs_size: Shape,
    pub input_share2_size: Shape,
    pub output_shares_size: Shape,
}

// NOTE(review): the fields are already `pub`, so these accessors only exist
// for API symmetry with `ResponseData` (whose fields are private).
impl ViewsData {
    /// AND-gate messages across all views, per word type.
    pub fn and_msgs_size(&self) -> Shape {
        self.and_msgs_size
    }

    /// Second input share's word count, per word type.
    pub fn input_share2_size(&self) -> Shape {
        self.input_share2_size
    }

    /// Output-share words across all views, per word type.
    pub fn output_shares_size(&self) -> Shape {
        self.output_shares_size
    }
}
impl ProfilingData {
    /// Size most recently reported by the memory manager, per word type.
    pub fn state_size(&self) -> Shape {
        self.state_size
    }

    /// Gate tallies recorded during the run.
    pub fn gate_counts(&self) -> &GateCounts {
        &self.gate_counts
    }

    /// Peak number of simultaneously live word references.
    pub fn max_live_wordrefs(&self) -> Shape {
        self.max_live_wordrefs
    }

    /// Peak sum of all outstanding reference counts.
    pub fn max_cumulative_refcount(&self) -> Shape {
        self.max_cumulative_refcount
    }

    /// Largest refcount observed on any single word.
    pub fn max_refcount(&self) -> Shape {
        self.max_refcount
    }

    /// One AND message per `bitand` word plus one per `carry` word (both are
    /// nonlinear gates).
    pub fn and_msg_size(&self) -> Shape {
        let counts = &self.gate_counts;
        counts
            .bitand
            .zip(&counts.carry, |bitand, carry| bitand + carry)
    }

    /// Sizes of one proof response derived from the recorded counts.
    pub fn response_data(&self) -> ResponseData {
        ResponseData {
            and_msg_size: self.and_msg_size(),
            input_share_size: self.gate_counts.input,
        }
    }

    /// Sizes of the committed views; the `* 3` factors cover the three
    /// parties' views.
    pub fn views_data(&self) -> ViewsData {
        ViewsData {
            and_msgs_size: self.and_msg_size().map(|n| n * 3),
            input_share2_size: self.gate_counts.input,
            output_shares_size: self.gate_counts.output.map(|n| n * 3),
        }
    }

    /// Footprint of one copy of the circuit state.
    pub fn state_mem_usage(&self) -> MemoryUsage {
        MemoryUsage::from_num_words(self.state_size)
    }

    /// Footprint of the circuit outputs.
    pub fn output_mem_usage(&self) -> MemoryUsage {
        MemoryUsage::from_num_words(self.gate_counts.output)
    }

    /// Stack bookkeeping for live word references (one `usize` each) and
    /// their outstanding indices (one `usize` each).
    pub fn wordrefs_mem_usage(&self) -> MemoryUsage {
        let usize_num_bytes = core::mem::size_of::<usize>();
        let num_wordrefs = self.max_live_wordrefs.sum();
        let num_idxs = self.max_cumulative_refcount.sum();
        let stack = usize_num_bytes * num_wordrefs + usize_num_bytes * num_idxs;
        MemoryUsage { stack, heap: 0 }
    }

    /// Bookkeeping held by the memory manager itself: per-word-type tables on
    /// the stack, plus per-word refcounts and per-64-word-group structures on
    /// the heap.
    pub fn memory_manager_mem_usage<RC: RefCount>(&self) -> MemoryUsage {
        let usize_num_bytes = core::mem::size_of::<usize>();
        let rc_num_bytes = core::mem::size_of::<RC>();
        let state_size = self.state_size;
        let num_word_types = state_size.map(|_| 1).sum();
        // 3 + 6 `usize`s of table state per word type.
        let stack =
            3 * usize_num_bytes * num_word_types + 6 * usize_num_bytes * num_word_types;
        // One `RC` per state word, plus one entry per 64-word group.
        // NOTE(review): the per-group term is also sized with
        // `size_of::<RC>()` — confirm this matches `FlexibleMemoryManager`'s
        // actual layout (a 64-bit occupancy bitmap would be 8 bytes/group).
        let heap = rc_num_bytes * state_size.sum()
            + rc_num_bytes * state_size.map(|v| (v + 63) / 64).sum();
        MemoryUsage { stack, heap }
    }

    /// Total footprint of a single-party executor: one state copy, wordref
    /// bookkeeping, the memory manager, and the output buffer.
    pub fn executor_mem_usage<RC: RefCount>(&self) -> MemoryUsage {
        self.state_mem_usage()
            + self.wordrefs_mem_usage()
            + self.memory_manager_mem_usage::<RC>()
            + self.output_mem_usage()
    }

    /// Prover footprint: three state copies (one per party) plus shared
    /// bookkeeping.
    pub fn prover_mem_usage<RC: RefCount>(&self) -> MemoryUsage {
        self.state_mem_usage() * 3
            + self.wordrefs_mem_usage()
            + self.memory_manager_mem_usage::<RC>()
    }

    /// Verifier footprint: two state copies (the two opened parties) plus
    /// shared bookkeeping.
    pub fn verifier_mem_usage<RC: RefCount>(&self) -> MemoryUsage {
        self.state_mem_usage() * 2
            + self.wordrefs_mem_usage()
            + self.memory_manager_mem_usage::<RC>()
    }
}
/// A `Backend` that executes no gate logic and instead records gate counts
/// and memory-manager statistics into a `ProfilingData` (returned by
/// `finalize`).
#[derive(Debug)]
pub struct ProfilingBackend {
    // Statistics accumulated so far.
    data: ProfilingData,
    // Real memory manager, driven during profiling so that allocation sizes
    // and refcounts are realistic even though no word data is computed.
    memory_manager: FlexibleMemoryManager<usize>,
    // Currently outstanding word references, per word type.
    live_wordrefs: Shape,
    // Current sum of all refcounts, per word type.
    cumulative_refcount: Shape,
}
impl ProfilingBackend {
    /// Creates a backend with empty statistics and a fresh memory manager.
    pub fn new() -> Self {
        Self {
            data: ProfilingData::default(),
            memory_manager: FlexibleMemoryManager::new(),
            live_wordrefs: Shape::zero(),
            cumulative_refcount: Shape::zero(),
        }
    }

    /// Wraps the backend in a `Frontend` so a circuit can be run against it.
    pub fn into_profiler(self) -> Frontend<Self> {
        self.into_frontend()
    }
}

// `new` takes no arguments, so also expose it through the standard `Default`
// trait (clippy::new_without_default).
impl Default for ProfilingBackend {
    fn default() -> Self {
        Self::new()
    }
}
impl Backend for ProfilingBackend {
    type FinalizeArg = ();
    type FinalizeResult = ProfilingData;

    /// Hands back the statistics accumulated over the run.
    fn finalize(self, _arg: Self::FinalizeArg) -> Self::FinalizeResult {
        return self.data;
    }

    /// Counts an input gate. The word value is ignored, but allocation still
    /// goes through the real memory manager so `state_size` reflects what a
    /// non-profiling backend would allocate.
    fn input<W: Word, const N: usize>(&mut self, _word: CompositeWord<W, N>) -> WordIdx<W, N> {
        let (idx, size) = self.memory_manager.alloc::<W, N>();
        *self.data.gate_counts.input.as_value_mut::<W>() += N;
        // Overwrites with the allocator-reported size; presumably that value
        // only grows, making this a high-water mark — TODO confirm.
        *self.data.state_size.as_value_mut::<W>() = size;
        return idx;
    }

    /// Counts an uninitialized allocation; same bookkeeping as `input`.
    fn alloc<W: Word, const N: usize>(&mut self) -> WordIdx<W, N> {
        let (idx, size) = self.memory_manager.alloc::<W, N>();
        *self.data.gate_counts.alloc.as_value_mut::<W>() += N;
        *self.data.state_size.as_value_mut::<W>() = size;
        return idx;
    }

    /// Counts a constant load (value and destination are ignored).
    fn constant<W: Word, const N: usize>(
        &mut self,
        _word: CompositeWord<W, N>,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.constant.as_value_mut::<W>() += N;
    }

    /// Counts a little-endian word-composition gate.
    fn from_le_words<W: Word, const N: usize>(
        &mut self,
        _ins: [WordIdx<W, 1>; N],
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.from_le_words.as_value_mut::<W>() += N;
    }

    /// Counts a little-endian word-decomposition gate.
    fn to_le_words<W: Word, const N: usize>(
        &mut self,
        _in_: WordIdx<W, N>,
        _outs: [WordIdx<W, 1>; N],
    ) {
        *self.data.gate_counts.to_le_words.as_value_mut::<W>() += N;
    }

    /// Counts an output gate.
    fn output<W: Word, const N: usize>(&mut self, _out: WordIdx<W, N>) {
        *self.data.gate_counts.output.as_value_mut::<W>() += N;
    }

    /// Forwards to the memory manager, then updates the peak statistics:
    /// cumulative refcount, live wordref count, and per-word max refcount.
    fn increase_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>) {
        self.memory_manager.increase_refcount(idx);
        let cumulative_refcount = self.cumulative_refcount.as_value_mut::<W>();
        let max_cumulative_refcount = self.data.max_cumulative_refcount.as_value_mut::<W>();
        // One refcount per word in the composite.
        *cumulative_refcount += N;
        // Comparing two `&mut usize` compares the pointed-to values.
        if cumulative_refcount > max_cumulative_refcount {
            *max_cumulative_refcount = *cumulative_refcount;
        }
        let live_wordrefs = self.live_wordrefs.as_value_mut::<W>();
        let max_live_wordrefs = self.data.max_live_wordrefs.as_value_mut::<W>();
        // A composite counts as one wordref regardless of N (mirrored by the
        // `-= 1` in `decrease_refcount`).
        *live_wordrefs += 1;
        if live_wordrefs > max_live_wordrefs {
            *max_live_wordrefs = *live_wordrefs;
        }
        // NOTE(review): only the first word's refcount is inspected — this
        // assumes all N words of a composite carry the same count; confirm
        // against `FlexibleMemoryManager`.
        let refcount = self.memory_manager.refcounts().as_vec::<W>()[idx.into_array()[0]];
        let max_refcount = self.data.max_refcount.as_value_mut::<W>();
        if refcount > *max_refcount {
            *max_refcount = refcount;
        }
    }

    /// Reverses the bookkeeping done in `increase_refcount` (peaks are left
    /// untouched).
    fn decrease_refcount<W: Word, const N: usize>(&mut self, idx: WordIdx<W, N>) {
        self.memory_manager.decrease_refcount(idx);
        *self.cumulative_refcount.as_value_mut::<W>() -= N;
        *self.live_wordrefs.as_value_mut::<W>() -= 1;
    }

    /// Counts a bitwise NOT gate.
    fn not<W: Word, const N: usize>(&mut self, _in_: WordIdx<W, N>, _out: WordIdx<W, N>) {
        *self.data.gate_counts.not.as_value_mut::<W>() += N;
    }

    /// Counts a bitwise XOR gate.
    fn bitxor<W: Word, const N: usize>(
        &mut self,
        _inl: WordIdx<W, N>,
        _inr: WordIdx<W, N>,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.bitxor.as_value_mut::<W>() += N;
    }

    /// Counts a bitwise AND gate (feeds `and_msg_size`).
    fn bitand<W: Word, const N: usize>(
        &mut self,
        _inl: WordIdx<W, N>,
        _inr: WordIdx<W, N>,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.bitand.as_value_mut::<W>() += N;
    }

    /// Counts an XOR-with-constant gate.
    fn bitxor_const<W: Word, const N: usize>(
        &mut self,
        _inl: WordIdx<W, N>,
        _inr: CompositeWord<W, N>,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.bitxor_const.as_value_mut::<W>() += N;
    }

    /// Counts an AND-with-constant gate (linear; no AND message).
    fn bitand_const<W: Word, const N: usize>(
        &mut self,
        _inl: WordIdx<W, N>,
        _inr: CompositeWord<W, N>,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.bitand_const.as_value_mut::<W>() += N;
    }

    /// Counts a left shift. For N > 1 the gate is tallied as its multi-word
    /// lowering (rotate + two constant masks + XOR per word) — presumably
    /// mirroring the executing backend's decomposition; TODO confirm.
    fn unbounded_shl<W: Word, const N: usize>(
        &mut self,
        _in_: WordIdx<W, N>,
        _shift: usize,
        _out: WordIdx<W, N>,
    ) {
        if N == 1 {
            *self.data.gate_counts.unbounded_shl.as_value_mut::<W>() += 1;
        } else {
            *self.data.gate_counts.rotate_left.as_value_mut::<W>() += N;
            *self.data.gate_counts.bitand_const.as_value_mut::<W>() += 2 * N;
            *self.data.gate_counts.bitxor.as_value_mut::<W>() += N;
        }
    }

    /// Counts a right shift; multi-word case handled as in `unbounded_shl`.
    fn unbounded_shr<W: Word, const N: usize>(
        &mut self,
        _in_: WordIdx<W, N>,
        _shift: usize,
        _out: WordIdx<W, N>,
    ) {
        if N == 1 {
            *self.data.gate_counts.unbounded_shr.as_value_mut::<W>() += 1;
        } else {
            *self.data.gate_counts.rotate_right.as_value_mut::<W>() += N;
            *self.data.gate_counts.bitand_const.as_value_mut::<W>() += 2 * N;
            *self.data.gate_counts.bitxor.as_value_mut::<W>() += N;
        }
    }

    /// Counts a left rotate; the multi-word lowering also uses per-word
    /// rotates, masks and XORs.
    fn rotate_left<W: Word, const N: usize>(
        &mut self,
        _in_: WordIdx<W, N>,
        _shift: usize,
        _out: WordIdx<W, N>,
    ) {
        if N == 1 {
            *self.data.gate_counts.rotate_left.as_value_mut::<W>() += 1;
        } else {
            *self.data.gate_counts.rotate_left.as_value_mut::<W>() += N;
            *self.data.gate_counts.bitand_const.as_value_mut::<W>() += 2 * N;
            *self.data.gate_counts.bitxor.as_value_mut::<W>() += N;
        }
    }

    /// Counts a right rotate; multi-word case handled as in `rotate_left`.
    fn rotate_right<W: Word, const N: usize>(
        &mut self,
        _in_: WordIdx<W, N>,
        _shift: usize,
        _out: WordIdx<W, N>,
    ) {
        if N == 1 {
            *self.data.gate_counts.rotate_right.as_value_mut::<W>() += 1;
        } else {
            *self.data.gate_counts.rotate_right.as_value_mut::<W>() += N;
            *self.data.gate_counts.bitand_const.as_value_mut::<W>() += 2 * N;
            *self.data.gate_counts.bitxor.as_value_mut::<W>() += N;
        }
    }

    /// Counts a bit-reversal gate.
    fn reverse_bits<W: Word, const N: usize>(&mut self, _in_: WordIdx<W, N>, _out: WordIdx<W, N>) {
        *self.data.gate_counts.reverse_bits.as_value_mut::<W>() += N;
    }

    /// Counts a byte-swap gate.
    fn swap_bytes<W: Word, const N: usize>(&mut self, _in_: WordIdx<W, N>, _out: WordIdx<W, N>) {
        *self.data.gate_counts.swap_bytes.as_value_mut::<W>() += N;
    }

    /// Counts a cast, keyed by both source type `W` and target type `T`.
    fn cast<W: Word, T: Word>(&mut self, _in_: WordIdx<W, 1>, _out: WordIdx<T, 1>) {
        *self
            .data
            .gate_counts
            .cast
            .as_value_mut::<W>()
            .as_value_mut::<T>() += 1;
    }

    /// Counts a carry-propagation gate (nonlinear; feeds `and_msg_size`).
    fn carry<W: Word, const N: usize>(
        &mut self,
        _p: WordIdx<W, N>,
        _g: WordIdx<W, N>,
        _carry_in: bool,
        _out: WordIdx<W, N>,
    ) {
        *self.data.gate_counts.carry.as_value_mut::<W>() += N;
    }
}