use crate::{
utils::field::{BaseField, ScalarField},
AsyncMPCCircuit,
};
use core_utils::circuit::{
AlgebraicType,
BitShareBinaryOp,
BitShareUnaryOp,
FieldShareBinaryOp,
FieldShareUnaryOp,
FieldType,
Input,
PointShareBinaryOp,
PointShareUnaryOp,
ShareOrPlaintext,
};
use ff::PrimeField;
use primitives::algebra::elliptic_curve::Curve25519Ristretto;
use std::{
error::Error,
ops::{Index, IndexMut},
};
/// Shorthand for the circuit gate type instantiated with the `Curve25519Ristretto` curve.
type Gate = core_utils::circuit::Gate<Curve25519Ristretto>;
/// Five independent `usize` counters, one per algebraic category.
///
/// `add_gate` accumulates into these counters gate by gate and `network_size`
/// folds them into a single number. The counters can be addressed by field
/// name or by indexing with a `FieldType` / `AlgebraicType` value.
// NOTE(review): presumably each counter is a number of elements exchanged over
// the network for that category — confirm against the protocol implementation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NetworkContent {
/// Counter selected by `AlgebraicType::Bit`.
pub bit: usize,
/// Counter selected by `FieldType::BaseField` / `AlgebraicType::BaseField`.
pub base: usize,
/// Counter selected by `FieldType::ScalarField` / `AlgebraicType::ScalarField`.
pub scalar: usize,
/// Counter selected by `FieldType::Mersenne107` / `AlgebraicType::Mersenne107`.
pub mersenne: usize,
/// Counter selected by `AlgebraicType::Point`.
pub point: usize,
}
impl Index<FieldType> for NetworkContent {
    type Output = usize;

    /// Borrows the counter that corresponds to the given field type.
    fn index(&self, field: FieldType) -> &Self::Output {
        // `bit` and `point` have no `FieldType` counterpart, so they are skipped here.
        let Self {
            base,
            scalar,
            mersenne,
            ..
        } = self;
        match field {
            FieldType::BaseField => base,
            FieldType::ScalarField => scalar,
            FieldType::Mersenne107 => mersenne,
        }
    }
}
impl IndexMut<FieldType> for NetworkContent {
    /// Mutably borrows the counter that corresponds to the given field type.
    fn index_mut(&mut self, field: FieldType) -> &mut Self::Output {
        // `bit` and `point` have no `FieldType` counterpart, so they are skipped here.
        let Self {
            base,
            scalar,
            mersenne,
            ..
        } = self;
        match field {
            FieldType::BaseField => base,
            FieldType::ScalarField => scalar,
            FieldType::Mersenne107 => mersenne,
        }
    }
}
impl Index<AlgebraicType> for NetworkContent {
    type Output = usize;

    /// Borrows the counter that corresponds to the given algebraic type.
    fn index(&self, kind: AlgebraicType) -> &Self::Output {
        let Self {
            bit,
            base,
            scalar,
            mersenne,
            point,
        } = self;
        match kind {
            AlgebraicType::Bit => bit,
            AlgebraicType::BaseField => base,
            AlgebraicType::ScalarField => scalar,
            AlgebraicType::Mersenne107 => mersenne,
            AlgebraicType::Point => point,
        }
    }
}
impl IndexMut<AlgebraicType> for NetworkContent {
    /// Mutably borrows the counter that corresponds to the given algebraic type.
    fn index_mut(&mut self, kind: AlgebraicType) -> &mut Self::Output {
        let Self {
            bit,
            base,
            scalar,
            mersenne,
            point,
        } = self;
        match kind {
            AlgebraicType::Bit => bit,
            AlgebraicType::BaseField => base,
            AlgebraicType::ScalarField => scalar,
            AlgebraicType::Mersenne107 => mersenne,
            AlgebraicType::Point => point,
        }
    }
}
impl TryFrom<&[usize]> for NetworkContent {
    type Error = Box<dyn Error>;

    /// Builds a `NetworkContent` from the first five entries of `value`, read
    /// in the order `bit`, `base`, `scalar`, `mersenne`, `point`.
    ///
    /// Entries beyond the fifth are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error message when `value` holds fewer than five entries.
    fn try_from(value: &[usize]) -> Result<Self, Self::Error> {
        let expected = 5;
        // The trailing `..` keeps longer slices acceptable; only short ones fail.
        let &[bit, base, scalar, mersenne, point, ..] = value else {
            return Err(format!("Expected {} records, got {}", expected, value.len()).into());
        };
        Ok(Self {
            bit,
            base,
            scalar,
            mersenne,
            point,
        })
    }
}
impl NetworkContent {
    /// Adds the contribution of one `gate` to the counters.
    ///
    /// Each counted operation has a fixed per-element weight; the weight is
    /// multiplied by `batch_size` and added to the counter(s) for the
    /// operation's type. Operations with weight zero leave the counters
    /// unchanged.
    // NOTE(review): the weights (1, 2, 3) presumably reflect how many elements
    // the corresponding sub-protocol exchanges — confirm against the protocol.
    pub fn add_gate(&mut self, gate: &Gate, batch_size: usize) {
        match gate {
            // Only secret-plaintext inputs are counted; any other input kind is free.
            Gate::Input {
                input_type: Input::SecretPlaintext { algebraic_type, .. },
            } => self[*algebraic_type] += batch_size,
            Gate::Input { .. } => {}

            Gate::FieldShareUnaryOp { op, field_type, .. } => {
                let weight = match op {
                    FieldShareUnaryOp::Neg => 0,
                    FieldShareUnaryOp::Open => 1,
                    FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => 3,
                };
                self[*field_type] += weight * batch_size;
            }
            Gate::FieldShareBinaryOp {
                y_form,
                op,
                field_type,
                ..
            } => {
                // Only a multiplication whose second operand is a share is counted.
                let weight = match (op, y_form) {
                    (FieldShareBinaryOp::Mul, ShareOrPlaintext::Share) => 2,
                    (FieldShareBinaryOp::Mul, ShareOrPlaintext::Plaintext) => 0,
                    (FieldShareBinaryOp::Add, _) => 0,
                };
                self[*field_type] += weight * batch_size;
            }

            Gate::BitShareUnaryOp { op, .. } => {
                let weight = match op {
                    BitShareUnaryOp::Not => 0,
                    BitShareUnaryOp::Open => 1,
                };
                self.bit += weight * batch_size;
            }
            Gate::BitShareBinaryOp { y_form, op, .. } => {
                // Only OR/AND with a shared second operand is counted.
                let weight = match (op, y_form) {
                    (BitShareBinaryOp::Or | BitShareBinaryOp::And, ShareOrPlaintext::Share) => 2,
                    (BitShareBinaryOp::Or | BitShareBinaryOp::And, ShareOrPlaintext::Plaintext) => {
                        0
                    }
                    (BitShareBinaryOp::Xor, _) => 0,
                };
                self.bit += weight * batch_size;
            }

            Gate::PointShareUnaryOp { op, .. } => {
                let (scalar_weight, point_weight) = match op {
                    PointShareUnaryOp::Neg => (0, 0),
                    PointShareUnaryOp::Open => (0, 1),
                    PointShareUnaryOp::IsZero => (1, 2),
                };
                self.scalar += scalar_weight * batch_size;
                self.point += point_weight * batch_size;
            }
            Gate::PointShareBinaryOp {
                p_form, y_form, op, ..
            } => {
                // A scalar multiplication is counted only when both operands are shares.
                let both_shared = matches!(
                    (y_form, p_form),
                    (ShareOrPlaintext::Share, ShareOrPlaintext::Share)
                );
                let (scalar_weight, point_weight) = match op {
                    PointShareBinaryOp::ScalarMul if both_shared => (1, 2),
                    PointShareBinaryOp::ScalarMul | PointShareBinaryOp::Add => (0, 0),
                };
                self.scalar += scalar_weight * batch_size;
                self.point += point_weight * batch_size;
            }

            Gate::BaseFieldPow { .. } => self.base += 3 * batch_size,

            // Every remaining gate kind leaves the counters untouched. They are
            // listed explicitly so that adding a `Gate` variant forces a decision here.
            Gate::BatchSummation { .. }
            | Gate::FieldPlaintextUnaryOp { .. }
            | Gate::FieldPlaintextBinaryOp { .. }
            | Gate::BitPlaintextUnaryOp { .. }
            | Gate::BitPlaintextBinaryOp { .. }
            | Gate::PointPlaintextUnaryOp { .. }
            | Gate::PointPlaintextBinaryOp { .. }
            | Gate::DaBit { .. }
            | Gate::GetDaBitFieldShare { .. }
            | Gate::GetDaBitSharedBit { .. }
            | Gate::BitPlaintextToField { .. }
            | Gate::FieldPlaintextToBit { .. }
            | Gate::BatchGetIndex { .. }
            | Gate::CollectToBatch { .. }
            | Gate::PointFromPlaintextExtendedEdwards { .. }
            | Gate::PlaintextPointToExtendedEdwards { .. }
            | Gate::PlaintextKeccakF1600 { .. }
            | Gate::CompressPlaintextPoint { .. }
            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => {}
        }
    }

    /// Accumulates the counters of every gate in `circuit`.
    ///
    /// Gates are visited in iteration order; a gate's position is used as its
    /// label when looking up its batch size in the map returned by
    /// `determine_batched_gates`. Gates without an entry use a batch size of 1.
    pub fn from_circuit(circuit: &AsyncMPCCircuit) -> NetworkContent {
        let batch_sizes = circuit.determine_batched_gates();
        circuit
            .iter()
            .enumerate()
            .fold(Self::default(), |mut totals, (label, gate)| {
                let batch_size = batch_sizes.get(&(label as u32)).copied().unwrap_or(1);
                totals.add_gate(gate, batch_size);
                totals
            })
    }

    /// Folds the counters into a single size figure.
    ///
    /// Field and point counters are weighted by a per-element width — the
    /// field's `NUM_BITS` rounded up to whole bytes, 107 bits rounded up to
    /// whole bytes for the Mersenne counter, and 32 for points — and their sum
    /// is doubled. Each bit counter entry then adds `1 + 16`.
    // NOTE(review): the result is presumably a byte count; the rationale for
    // the factor 2 and for `1 + 16` per bit is not visible here — confirm.
    pub fn network_size(&self) -> usize {
        let base_width = BaseField::NUM_BITS.div_ceil(8) as usize;
        let scalar_width = ScalarField::NUM_BITS.div_ceil(8) as usize;
        let mersenne_width = 107usize.div_ceil(8);
        let point_width = 32;
        let bit_width = 1 + 16;

        let algebraic = self.base * base_width
            + self.scalar * scalar_width
            + self.mersenne * mersenne_width
            + self.point * point_width;

        2 * algebraic + self.bit * bit_width
    }
}