use std::{
collections::HashMap,
ops::{Add, AddAssign, Index, IndexMut},
};
use primitives::algebra::{elliptic_curve::Curve, BoxedUint};
use crate::circuit::{
AlgebraicType,
BitShareBinaryOp,
Circuit,
FieldShareBinaryOp,
FieldShareUnaryOp,
FieldType,
Gate,
GateExt,
Input,
PointShareBinaryOp,
PointShareUnaryOp,
ShareOrPlaintext,
};
impl<C: Curve> Circuit<C> {
    /// Tallies the preprocessing material needed to evaluate every gate of this circuit.
    ///
    /// Starts from an empty [`CircuitPreprocessing`] and accumulates the contribution of
    /// each gate yielded by `iter_gates_ext` via [`Self::add_to_required_preprocessing`].
    pub fn required_preprocessing(&self) -> CircuitPreprocessing {
        let mut total = CircuitPreprocessing::default();
        for gate_ext in self.iter_gates_ext() {
            self.add_to_required_preprocessing(gate_ext, &mut total);
        }
        total
    }

    /// Adds the preprocessing consumed by a single `gate` to `circuit_preprocessing`.
    ///
    /// Every counter touched here is incremented by the batch size of the gate's output.
    /// Gates listed in the final no-op arm (plaintext operations, batching helpers,
    /// DaBit projections, ...) contribute nothing.
    pub fn add_to_required_preprocessing(
        &self,
        gate: &GateExt<C>,
        circuit_preprocessing: &mut CircuitPreprocessing,
    ) {
        let count = gate.output.get_batch_size() as usize;
        // True when the output of the referenced gate is in share form (not plaintext).
        let is_share = |id| self.gate_output_unchecked(id).get_form() == ShareOrPlaintext::Share;

        match &gate.gate {
            // Secret inputs and random values draw one singlet per batch element.
            Gate::Input(Input::SecretPlaintext { algebraic_type, .. })
            | Gate::Random { algebraic_type, .. } => {
                let singlets = match algebraic_type {
                    // Point-typed values are charged against the scalar field's singlets.
                    AlgebraicType::ScalarField | AlgebraicType::Point => {
                        &mut circuit_preprocessing.scalar.singlets
                    }
                    AlgebraicType::BaseField => &mut circuit_preprocessing.base_field.singlets,
                    AlgebraicType::Mersenne107 => &mut circuit_preprocessing.mersenne107.singlets,
                    AlgebraicType::Bit => &mut circuit_preprocessing.bit_singlets,
                };
                *singlets += count;
            }
            Gate::FieldShareUnaryOp { op, .. } => {
                let field_type = gate.output.get_field_type_unchecked();
                match op {
                    FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => {
                        let field = &mut circuit_preprocessing[field_type];
                        field.triples += count;
                        field.singlets += count;
                    }
                    FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => {}
                }
            }
            Gate::FieldShareBinaryOp { op, y, .. } => match op {
                FieldShareBinaryOp::Mul => {
                    let field_type = gate.output.get_field_type_unchecked();
                    // A triple is only counted when the operand `y` is itself a share.
                    if is_share(*y) {
                        circuit_preprocessing[field_type].triples += count;
                    }
                }
                FieldShareBinaryOp::Add => {}
            },
            Gate::PointShareUnaryOp { op, .. } => match op {
                PointShareUnaryOp::IsZero => {
                    let scalar = &mut circuit_preprocessing.scalar;
                    scalar.triples += count;
                    scalar.singlets += count;
                }
                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => {}
            },
            Gate::PointShareBinaryOp { op, p, y, .. } => match op {
                PointShareBinaryOp::ScalarMul => {
                    // Scalar triples are only counted when both operands are shares.
                    if is_share(*p) && is_share(*y) {
                        circuit_preprocessing.scalar.triples += count;
                    }
                }
                PointShareBinaryOp::Add => {}
            },
            Gate::BitShareBinaryOp { op, y, .. } => match op {
                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
                    if is_share(*y) {
                        circuit_preprocessing.bit_triples += count;
                    }
                }
                BitShareBinaryOp::Xor => {}
            },
            // One power pair per element (tracked per exponent) plus one base-field triple.
            Gate::BaseFieldPow { exp, .. } => {
                let pairs = circuit_preprocessing
                    .base_field_pow_pairs
                    .entry(exp.clone())
                    .or_default();
                *pairs += count;
                circuit_preprocessing.base_field.triples += count;
            }
            Gate::DaBit { field_type, .. } => {
                circuit_preprocessing[*field_type].dabits += count;
            }
            Gate::Input(_)
            | Gate::Constant { .. }
            | Gate::BatchSummation { .. }
            | Gate::BitShareUnaryOp { .. }
            | Gate::FieldPlaintextUnaryOp { .. }
            | Gate::FieldPlaintextBinaryOp { .. }
            | Gate::BitPlaintextUnaryOp { .. }
            | Gate::BitPlaintextBinaryOp { .. }
            | Gate::PointPlaintextUnaryOp { .. }
            | Gate::PointPlaintextBinaryOp { .. }
            | Gate::GetDaBitFieldShare { .. }
            | Gate::GetDaBitSharedBit { .. }
            | Gate::BitPlaintextToField { .. }
            | Gate::FieldPlaintextToBit { .. }
            | Gate::ExtractFromBatch { .. }
            | Gate::CollectToBatch { .. }
            | Gate::PointFromPlaintextExtendedEdwards { .. }
            | Gate::PlaintextPointToExtendedEdwards { .. }
            | Gate::PlaintextKeccakF1600 { .. }
            | Gate::CompressPlaintextPoint { .. }
            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => {}
        }
    }
}
/// Per-field preprocessing requirements: how many singlets, triples and daBits are needed.
///
/// Values combine field-wise through `+` / `+=`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets required.
    pub singlets: usize,
    /// Number of triples required.
    pub triples: usize,
    /// Number of daBits required.
    pub dabits: usize,
}

impl AddAssign for FieldCircuitPreprocessing {
    /// Adds each counter of `rhs` onto the matching counter of `self`.
    fn add_assign(&mut self, rhs: Self) {
        // Exhaustive destructuring: adding a field forces this impl to be updated.
        let Self {
            singlets,
            triples,
            dabits,
        } = rhs;
        self.singlets += singlets;
        self.triples += triples;
        self.dabits += dabits;
    }
}

impl Add for FieldCircuitPreprocessing {
    type Output = Self;

    /// Returns the field-wise sum of `self` and `rhs`.
    fn add(mut self, rhs: Self) -> Self::Output {
        self += rhs;
        self
    }
}
/// Aggregate preprocessing requirements of a circuit.
///
/// Holds bit-level counters, a requirement set for each supported field, and the number
/// of base-field power pairs needed per exponent. Indexable by [`FieldType`] to reach
/// the per-field requirements.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of base-field power pairs required, keyed by exponent.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets required.
    pub bit_singlets: usize,
    /// Number of bit triples required.
    pub bit_triples: usize,
    /// Requirements in the base field.
    pub base_field: FieldCircuitPreprocessing,
    /// Requirements in the scalar field.
    pub scalar: FieldCircuitPreprocessing,
    /// Requirements in the Mersenne-107 field.
    pub mersenne107: FieldCircuitPreprocessing,
}

impl AddAssign for CircuitPreprocessing {
    /// Merges `rhs` into `self`.
    ///
    /// Scalar counters and per-field requirements are summed. Power-pair counts are
    /// accumulated per exponent, inserting exponents that `self` did not yet have.
    fn add_assign(&mut self, rhs: Self) {
        let Self {
            base_field_pow_pairs,
            bit_singlets,
            bit_triples,
            base_field,
            scalar,
            mersenne107,
        } = rhs;
        self.bit_singlets += bit_singlets;
        self.bit_triples += bit_triples;
        self.base_field += base_field;
        self.scalar += scalar;
        self.mersenne107 += mersenne107;
        for (exponent, pairs) in base_field_pow_pairs {
            *self.base_field_pow_pairs.entry(exponent).or_default() += pairs;
        }
    }
}

impl Add for CircuitPreprocessing {
    type Output = Self;

    /// Returns the merged requirements of `self` and `other`.
    fn add(mut self, other: Self) -> Self::Output {
        self += other;
        self
    }
}

impl Index<FieldType> for CircuitPreprocessing {
    type Output = FieldCircuitPreprocessing;

    /// Borrows the requirements for the field selected by `index`.
    fn index(&self, index: FieldType) -> &Self::Output {
        let Self {
            base_field,
            scalar,
            mersenne107,
            ..
        } = self;
        match index {
            FieldType::ScalarField => scalar,
            FieldType::BaseField => base_field,
            FieldType::Mersenne107 => mersenne107,
        }
    }
}

impl IndexMut<FieldType> for CircuitPreprocessing {
    /// Mutably borrows the requirements for the field selected by `index`.
    fn index_mut(&mut self, index: FieldType) -> &mut Self::Output {
        let Self {
            base_field,
            scalar,
            mersenne107,
            ..
        } = self;
        match index {
            FieldType::ScalarField => scalar,
            FieldType::BaseField => base_field,
            FieldType::Mersenne107 => mersenne107,
        }
    }
}