arcium-core-utils 0.4.5

Arcium core utils
Documentation
use std::{collections::HashMap, fmt::Debug, future::Future, pin::Pin, vec::IntoIter};

use futures::future::{BoxFuture, Shared};
use primitives::{
    algebra::{
        elliptic_curve::Curve,
        field::{binary::Gf2_128, mersenne::Mersenne107},
        BoxedUint,
    },
    correlated_randomness::{dabits::DaBit, pow::PowPair, singlets::Singlet, triples::Triple},
};

use crate::{
    circuit::{CircuitPreprocessing, FieldCircuitPreprocessing},
    errors::AbortError,
    preprocessing::errors::PreprocessingBundlerError,
};

/// Lazily-resolving, heap-pinned future for a single preprocessing value, as handed out by
/// [`PreprocessingSource::request_n_elements`].
///
/// The future is `Send` and `'static`, and resolves either to the value or to an
/// [`AbortError`].
pub type NextElement<T> = Pin<Box<dyn Send + Future<Output = Result<T, AbortError>> + 'static>>;

/// Shareable batch future produced by [`PreprocessingSource::request_n_elements_batch`].
///
/// The type is deliberately concrete rather than an associated type:
/// - Being a [`Shared`] future it is always **`Clone`**, which lets the default
///   `request_n_elements` drive one batch on behalf of every per-element future.
/// - Each implementor boxes its own async block into this single alias, so the trait does
///   not need an associated type for it.
pub type NextBatch<T> = Shared<BoxFuture<'static, Result<Vec<T>, AbortError>>>;

// Per-element futures, one alias per kind of correlated randomness.

/// Lazy future yielding one [`Singlet`] over `F`.
pub type NextSinglet<F> = NextElement<Singlet<F>>;
/// Lazy future yielding one [`Triple`] over `F`.
pub type NextTriple<F> = NextElement<Triple<F>>;
/// Lazy future yielding one [`DaBit`] over `F`.
pub type NextDaBit<F> = NextElement<DaBit<F>>;
/// Lazy future yielding one [`PowPair`] over `F`.
pub type NextPowPair<F> = NextElement<PowPair<F>>;

// Shared batch futures, one alias per kind of correlated randomness.

/// Shared future yielding a batch of [`Singlet`]s over `F`.
pub type NextSingletBatch<F> = NextBatch<Singlet<F>>;
/// Shared future yielding a batch of [`Triple`]s over `F`.
pub type NextTripleBatch<F> = NextBatch<Triple<F>>;
/// Shared future yielding a batch of [`DaBit`]s over `F`.
pub type NextDaBitBatch<F> = NextBatch<DaBit<F>>;
/// Shared future yielding a batch of [`PowPair`]s over `F`.
pub type NextPowPairBatch<F> = NextBatch<PowPair<F>>;

/// An iterator over per-element preprocessing futures for every gate in a circuit.
///
/// Produced by [`crate::protocol::PreprocessingBundler::fetch_for`]. All network
/// requests have been **dispatched** (but not awaited) by the time this iterator is
/// constructed; consuming an item via `next_*` hands the caller a lazy
/// [`NextElement`] future that resolves when the preprocessing value is ready.
///
/// Each field is a queue of not-yet-consumed futures for one kind of preprocessing over
/// one field; the `next_*` methods pop from these queues and `len` reports what remains.
pub struct CerberusPreprocessingIterator<C: Curve> {
    // Base field iterators
    /// Remaining daBit futures over the curve's base field (`next_base_field_dabit`).
    pub base_field_dabits: IntoIter<NextDaBit<C::BaseField>>,
    /// Remaining pow-pair futures over the base field, one queue per exponent; looked up by
    /// the `exp` argument of `next_base_field_powpair`.
    pub base_field_pow_preprocessing: HashMap<BoxedUint, IntoIter<NextPowPair<C::BaseField>>>,
    /// Remaining singlet futures over the base field (`next_base_field_singlet`).
    pub base_field_singlets: IntoIter<NextSinglet<C::BaseField>>,
    /// Remaining triple futures over the base field (`next_base_field_triple`).
    pub base_field_triples: IntoIter<NextTriple<C::BaseField>>,

    // Binary field iterators
    /// Remaining singlet futures over `Gf2_128` (`next_bit_singlet`).
    pub binary_singlets: IntoIter<NextSinglet<Gf2_128>>,
    /// Remaining triple futures over `Gf2_128` (`next_bit_triple`).
    pub binary_triples: IntoIter<NextTriple<Gf2_128>>,

    // Mersenne107 iterators
    /// Remaining daBit futures over `Mersenne107` (`next_mersenne107_dabit`).
    pub mersenne107_dabits: IntoIter<NextDaBit<Mersenne107>>,
    /// Remaining singlet futures over `Mersenne107` (`next_mersenne107_singlet`).
    pub mersenne107_singlets: IntoIter<NextSinglet<Mersenne107>>,
    /// Remaining triple futures over `Mersenne107` (`next_mersenne107_triple`).
    pub mersenne107_triples: IntoIter<NextTriple<Mersenne107>>,

    // Scalar field iterators
    /// Remaining daBit futures over the curve's scalar field (`next_scalar_dabit`).
    pub scalar_dabits: IntoIter<NextDaBit<C::Scalar>>,
    /// Remaining singlet futures over the scalar field (`next_scalar_singlet`).
    pub scalar_singlets: IntoIter<NextSinglet<C::Scalar>>,
    /// Remaining triple futures over the scalar field (`next_scalar_triple`).
    pub scalar_triples: IntoIter<NextTriple<C::Scalar>>,
}

impl<C: Curve> CerberusPreprocessingIterator<C> {
    /// Pops the next queued future from `queue`.
    ///
    /// `exhausted` builds the error returned when the queue is empty. It is invoked only on
    /// that path.
    fn pop_next_or<I: Iterator>(
        queue: &mut I,
        exhausted: impl FnOnce() -> PreprocessingBundlerError,
    ) -> Result<I::Item, PreprocessingBundlerError> {
        queue.next().ok_or_else(exhausted)
    }

    /// Rust type name of `T` (via `std::any::type_name`), used to label
    /// "insufficient preprocessing" errors.
    fn type_label<T: ?Sized>() -> String {
        std::any::type_name::<T>().to_string()
    }

    /// Counts of the futures still queued, per preprocessing kind and field.
    ///
    /// Pow pairs are counted separately for each exponent key.
    pub fn len(&self) -> CircuitPreprocessing {
        let base_field = FieldCircuitPreprocessing {
            singlets: self.base_field_singlets.len(),
            triples: self.base_field_triples.len(),
            dabits: self.base_field_dabits.len(),
        };
        let scalar = FieldCircuitPreprocessing {
            singlets: self.scalar_singlets.len(),
            triples: self.scalar_triples.len(),
            dabits: self.scalar_dabits.len(),
        };
        let mersenne107 = FieldCircuitPreprocessing {
            singlets: self.mersenne107_singlets.len(),
            triples: self.mersenne107_triples.len(),
            dabits: self.mersenne107_dabits.len(),
        };

        CircuitPreprocessing {
            base_field_pow_pairs: self
                .base_field_pow_preprocessing
                .iter()
                .map(|(exponent, pairs)| (exponent.clone(), pairs.len()))
                .collect(),
            bit_singlets: self.binary_singlets.len(),
            bit_triples: self.binary_triples.len(),
            base_field,
            scalar,
            mersenne107,
        }
    }

    /// Returns `true` when every queue, including every per-exponent pow-pair queue, is
    /// drained.
    pub fn is_empty(&self) -> bool {
        let remaining = [
            self.base_field_dabits.len(),
            self.base_field_singlets.len(),
            self.base_field_triples.len(),
            self.binary_singlets.len(),
            self.binary_triples.len(),
            self.mersenne107_dabits.len(),
            self.mersenne107_singlets.len(),
            self.mersenne107_triples.len(),
            self.scalar_dabits.len(),
            self.scalar_singlets.len(),
            self.scalar_triples.len(),
        ];
        let pow_drained = self
            .base_field_pow_preprocessing
            .values()
            .all(|pairs| pairs.len() == 0);
        pow_drained && remaining.iter().all(|&count| count == 0)
    }

    /// Takes the next base-field daBit future.
    ///
    /// # Errors
    /// `InsufficientDaBits` (labelled with the base-field type name) when none remain.
    pub fn next_base_field_dabit(
        &mut self,
    ) -> Result<NextDaBit<C::BaseField>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.base_field_dabits, || {
            PreprocessingBundlerError::InsufficientDaBits(Self::type_label::<C::BaseField>())
        })
    }

    /// Takes the next base-field pow-pair future for exponent `exp`.
    ///
    /// # Errors
    /// `InsufficientPowPreprocessing(exp)` in either of two cases: no queue exists for `exp`,
    /// or that queue is empty.
    pub fn next_base_field_powpair(
        &mut self,
        exp: &BoxedUint,
    ) -> Result<NextPowPair<C::BaseField>, PreprocessingBundlerError> {
        // Captures only a shared reference, so the closure is `Copy` and usable twice.
        let shortage = || PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone());
        let queue = self
            .base_field_pow_preprocessing
            .get_mut(exp)
            .ok_or_else(shortage)?;
        Self::pop_next_or(queue, shortage)
    }

    /// Takes the next base-field singlet future.
    ///
    /// # Errors
    /// `InsufficientSinglets` (labelled with the base-field type name) when none remain.
    pub fn next_base_field_singlet(
        &mut self,
    ) -> Result<NextSinglet<C::BaseField>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.base_field_singlets, || {
            PreprocessingBundlerError::InsufficientSinglets(Self::type_label::<C::BaseField>())
        })
    }

    /// Takes the next base-field triple future.
    ///
    /// # Errors
    /// `InsufficientTriples` (labelled with the base-field type name) when none remain.
    pub fn next_base_field_triple(
        &mut self,
    ) -> Result<NextTriple<C::BaseField>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.base_field_triples, || {
            PreprocessingBundlerError::InsufficientTriples(Self::type_label::<C::BaseField>())
        })
    }

    /// Takes the next `Gf2_128` singlet future.
    ///
    /// # Errors
    /// `InsufficientSinglets` (labelled with the `Gf2_128` type name) when none remain.
    pub fn next_bit_singlet(&mut self) -> Result<NextSinglet<Gf2_128>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.binary_singlets, || {
            PreprocessingBundlerError::InsufficientSinglets(Self::type_label::<Gf2_128>())
        })
    }

    /// Takes the next `Gf2_128` triple future.
    ///
    /// # Errors
    /// `InsufficientTriples` (labelled with the `Gf2_128` type name) when none remain.
    pub fn next_bit_triple(&mut self) -> Result<NextTriple<Gf2_128>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.binary_triples, || {
            PreprocessingBundlerError::InsufficientTriples(Self::type_label::<Gf2_128>())
        })
    }

    /// Takes the next `Mersenne107` daBit future.
    ///
    /// # Errors
    /// `InsufficientDaBits("Mersenne107")` when none remain.
    pub fn next_mersenne107_dabit(
        &mut self,
    ) -> Result<NextDaBit<Mersenne107>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.mersenne107_dabits, || {
            PreprocessingBundlerError::InsufficientDaBits("Mersenne107".to_string())
        })
    }

    /// Takes the next `Mersenne107` singlet future.
    ///
    /// # Errors
    /// `InsufficientSinglets("Mersenne107")` when none remain.
    pub fn next_mersenne107_singlet(
        &mut self,
    ) -> Result<NextSinglet<Mersenne107>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.mersenne107_singlets, || {
            PreprocessingBundlerError::InsufficientSinglets("Mersenne107".to_string())
        })
    }

    /// Takes the next `Mersenne107` triple future.
    ///
    /// # Errors
    /// `InsufficientTriples("Mersenne107")` when none remain.
    pub fn next_mersenne107_triple(
        &mut self,
    ) -> Result<NextTriple<Mersenne107>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.mersenne107_triples, || {
            PreprocessingBundlerError::InsufficientTriples("Mersenne107".to_string())
        })
    }

    /// Takes the next scalar-field daBit future.
    ///
    /// # Errors
    /// `InsufficientDaBits` (labelled with the scalar type name) when none remain.
    pub fn next_scalar_dabit(&mut self) -> Result<NextDaBit<C::Scalar>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.scalar_dabits, || {
            PreprocessingBundlerError::InsufficientDaBits(Self::type_label::<C::Scalar>())
        })
    }

    /// Takes the next scalar-field singlet future.
    ///
    /// # Errors
    /// `InsufficientSinglets` (labelled with the scalar type name) when none remain.
    pub fn next_scalar_singlet(
        &mut self,
    ) -> Result<NextSinglet<C::Scalar>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.scalar_singlets, || {
            PreprocessingBundlerError::InsufficientSinglets(Self::type_label::<C::Scalar>())
        })
    }

    /// Takes the next scalar-field triple future.
    ///
    /// # Errors
    /// `InsufficientTriples` (labelled with the scalar type name) when none remain.
    pub fn next_scalar_triple(
        &mut self,
    ) -> Result<NextTriple<C::Scalar>, PreprocessingBundlerError> {
        Self::pop_next_or(&mut self.scalar_triples, || {
            PreprocessingBundlerError::InsufficientTriples(Self::type_label::<C::Scalar>())
        })
    }
}

impl<C: Curve> Debug for CerberusPreprocessingIterator<C> {
    /// Writes a fixed prefix, then the `Debug` rendering of [`Self::len`].
    ///
    /// The caller's formatter is passed straight through, so flags such as `{:#?}` still
    /// apply to the counts.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CerberusPreprocessingIterator with len: ")?;
        Debug::fmt(&self.len(), f)
    }
}