// arcium-core-utils 0.4.1
//
// Arcium core utils
// Documentation
use std::{
    collections::HashMap,
    ops::{Add, AddAssign, Index, IndexMut},
};

use primitives::algebra::{elliptic_curve::Curve, BoxedUint};

use crate::circuit::{
    AlgebraicType,
    BitShareBinaryOp,
    Circuit,
    FieldShareBinaryOp,
    FieldShareUnaryOp,
    FieldType,
    Gate,
    GateExt,
    Input,
    PointShareBinaryOp,
    PointShareUnaryOp,
    ShareOrPlaintext,
};

impl<C: Curve> Circuit<C> {
    /// Tallies the preprocessing material needed to evaluate every gate of the circuit.
    ///
    /// The returned [`CircuitPreprocessing`] sums, over all gates:
    /// - singlets, triples and daBits for the base field, the scalar field and the
    ///   Mersenne107 field;
    /// - bit singlets and bit triples;
    /// - base field pow pairs, counted per exponent.
    pub fn required_preprocessing(&self) -> CircuitPreprocessing {
        self.iter_gates_ext()
            .fold(CircuitPreprocessing::default(), |mut totals, gate| {
                self.add_to_required_preprocessing(gate, &mut totals);
                totals
            })
    }

    /// Adds the preprocessing needed by a single `gate` to `circuit_preprocessing`.
    ///
    /// Every contribution is multiplied by the batch size of the gate's output.
    /// Some gates need no preprocessing and leave the accumulator untouched, for example:
    /// - plaintext operations;
    /// - openings, additions, negations and XORs;
    /// - batching helpers.
    pub fn add_to_required_preprocessing(
        &self,
        gate: &GateExt<C>,
        circuit_preprocessing: &mut CircuitPreprocessing,
    ) {
        let count = gate.output.get_batch_size() as usize;
        match &gate.gate {
            // Secret inputs and random values consume one singlet per batch element.
            // Points draw from the scalar-field singlets.
            Gate::Input(Input::SecretPlaintext { algebraic_type, .. })
            | Gate::Random { algebraic_type, .. } => {
                let singlets = match algebraic_type {
                    AlgebraicType::ScalarField | AlgebraicType::Point => {
                        &mut circuit_preprocessing.scalar.singlets
                    }
                    AlgebraicType::BaseField => &mut circuit_preprocessing.base_field.singlets,
                    AlgebraicType::Bit => &mut circuit_preprocessing.bit_singlets,
                    AlgebraicType::Mersenne107 => {
                        &mut circuit_preprocessing.mersenne107.singlets
                    }
                };
                *singlets += count;
            }
            Gate::FieldShareUnaryOp { op, .. } => {
                let field_type = gate.output.get_field_type_unchecked();
                let needs_material = match op {
                    FieldShareUnaryOp::MulInverse | FieldShareUnaryOp::IsZero => true,
                    FieldShareUnaryOp::Open | FieldShareUnaryOp::Neg => false,
                };
                if needs_material {
                    let field = &mut circuit_preprocessing[field_type];
                    field.triples += count;
                    field.singlets += count;
                }
            }
            Gate::FieldShareBinaryOp { op, y, .. } => match op {
                // NOTE(review): only the `y` operand's form is inspected here. This presumably
                // relies on the other operand always being a share for this gate — confirm.
                FieldShareBinaryOp::Mul => {
                    let field_type = gate.output.get_field_type_unchecked();
                    let y_is_share =
                        self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share;
                    if y_is_share {
                        circuit_preprocessing[field_type].triples += count;
                    }
                }
                FieldShareBinaryOp::Add => {}
            },
            Gate::PointShareUnaryOp { op, .. } => match op {
                PointShareUnaryOp::IsZero => {
                    let scalar = &mut circuit_preprocessing.scalar;
                    scalar.triples += count;
                    scalar.singlets += count;
                }
                PointShareUnaryOp::Open | PointShareUnaryOp::Neg => {}
            },
            // A scalar multiplication needs a triple only when both operands are shares.
            // `y` is only inspected if `p` is a share.
            Gate::PointShareBinaryOp { op, p, y, .. } => match op {
                PointShareBinaryOp::ScalarMul => {
                    let both_shared = self.gate_output_unchecked(*p).get_form()
                        == ShareOrPlaintext::Share
                        && self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share;
                    if both_shared {
                        circuit_preprocessing.scalar.triples += count;
                    }
                }
                PointShareBinaryOp::Add => {}
            },
            Gate::BitShareBinaryOp { op, y, .. } => match op {
                BitShareBinaryOp::And | BitShareBinaryOp::Or => {
                    let y_is_share =
                        self.gate_output_unchecked(*y).get_form() == ShareOrPlaintext::Share;
                    if y_is_share {
                        circuit_preprocessing.bit_triples += count;
                    }
                }
                BitShareBinaryOp::Xor => {}
            },
            // Pow pairs are tracked per exponent.
            // Every pow additionally needs base-field triples.
            Gate::BaseFieldPow { exp, .. } => {
                let pairs = circuit_preprocessing
                    .base_field_pow_pairs
                    .entry(exp.clone())
                    .or_default();
                *pairs += count;
                circuit_preprocessing.base_field.triples += count;
            }
            Gate::DaBit { field_type, .. } => {
                circuit_preprocessing[*field_type].dabits += count;
            }

            // Must stay after the `Gate::Input(Input::SecretPlaintext { .. })` arm above.
            Gate::Input(_)
            | Gate::Constant { .. }
            | Gate::BatchSummation { .. }
            | Gate::BitShareUnaryOp { .. }
            | Gate::FieldPlaintextUnaryOp { .. }
            | Gate::FieldPlaintextBinaryOp { .. }
            | Gate::BitPlaintextUnaryOp { .. }
            | Gate::BitPlaintextBinaryOp { .. }
            | Gate::PointPlaintextUnaryOp { .. }
            | Gate::PointPlaintextBinaryOp { .. }
            | Gate::GetDaBitFieldShare { .. }
            | Gate::GetDaBitSharedBit { .. }
            | Gate::BitPlaintextToField { .. }
            | Gate::FieldPlaintextToBit { .. }
            | Gate::ExtractFromBatch { .. }
            | Gate::CollectToBatch { .. }
            | Gate::PointFromPlaintextExtendedEdwards { .. }
            | Gate::PlaintextPointToExtendedEdwards { .. }
            | Gate::PlaintextKeccakF1600 { .. }
            | Gate::CompressPlaintextPoint { .. }
            | Gate::KeyRecoveryPlaintextComputeErrors { .. } => {}
        }
    }
}

/// Preprocessing requirements of a circuit restricted to a single field.
///
/// Each counter is the number of items of that kind of preprocessing material needed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FieldCircuitPreprocessing {
    /// Number of singlets needed.
    pub singlets: usize,
    /// Number of triples needed.
    pub triples: usize,
    /// Number of daBits needed.
    pub dabits: usize,
}

impl AddAssign for FieldCircuitPreprocessing {
    /// Accumulates `rhs` into `self`, counter by counter.
    fn add_assign(&mut self, rhs: Self) {
        // Exhaustive destructuring: adding a field without updating this impl fails to compile.
        let Self {
            singlets,
            triples,
            dabits,
        } = rhs;
        self.singlets += singlets;
        self.triples += triples;
        self.dabits += dabits;
    }
}

impl Add for FieldCircuitPreprocessing {
    type Output = Self;

    /// Returns the counter-wise sum of two requirement sets.
    fn add(mut self, rhs: Self) -> Self::Output {
        self += rhs;
        self
    }
}

/// Aggregate preprocessing requirements of a circuit.
///
/// This is the value returned by `Circuit::required_preprocessing`.
/// - Two requirement sets combine with `+` / `+=`.
/// - Per-field counters can be reached by indexing with a `FieldType`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CircuitPreprocessing {
    /// Number of base field pow pairs needed, keyed by exponent.
    pub base_field_pow_pairs: HashMap<BoxedUint, usize>,
    /// Number of bit singlets needed.
    pub bit_singlets: usize,
    /// Number of bit triples needed.
    pub bit_triples: usize,
    /// Requirements in the base field.
    pub base_field: FieldCircuitPreprocessing,
    /// Requirements in the scalar field.
    pub scalar: FieldCircuitPreprocessing,
    /// Requirements in the Mersenne107 field.
    pub mersenne107: FieldCircuitPreprocessing,
}

impl AddAssign for CircuitPreprocessing {
    /// Accumulates `rhs` into `self`.
    ///
    /// Pow-pair counts are summed per exponent. Exponents present only in `rhs` are
    /// inserted.
    fn add_assign(&mut self, rhs: Self) {
        let Self {
            base_field_pow_pairs,
            bit_singlets,
            bit_triples,
            base_field,
            scalar,
            mersenne107,
        } = rhs;
        self.bit_singlets += bit_singlets;
        self.bit_triples += bit_triples;
        self.base_field += base_field;
        self.scalar += scalar;
        self.mersenne107 += mersenne107;
        for (exponent, count) in base_field_pow_pairs {
            *self.base_field_pow_pairs.entry(exponent).or_default() += count;
        }
    }
}

impl Add for CircuitPreprocessing {
    type Output = Self;

    /// Returns the combined requirements of `self` and `other`.
    fn add(mut self, other: Self) -> Self::Output {
        self += other;
        self
    }
}

impl Index<FieldType> for CircuitPreprocessing {
    type Output = FieldCircuitPreprocessing;

    /// Borrows the requirements belonging to `field_type`.
    fn index(&self, field_type: FieldType) -> &Self::Output {
        match field_type {
            FieldType::ScalarField => &self.scalar,
            FieldType::BaseField => &self.base_field,
            FieldType::Mersenne107 => &self.mersenne107,
        }
    }
}

impl IndexMut<FieldType> for CircuitPreprocessing {
    /// Mutably borrows the requirements belonging to `field_type`.
    fn index_mut(&mut self, field_type: FieldType) -> &mut Self::Output {
        match field_type {
            FieldType::ScalarField => &mut self.scalar,
            FieldType::BaseField => &mut self.base_field,
            FieldType::Mersenne107 => &mut self.mersenne107,
        }
    }
}