//! Per-gate preprocessing futures and the iterator that hands them out.
//!
//! File: `core_utils/preprocessing/iterator.rs`.

use std::{any::type_name, collections::HashMap, future::Future, pin::Pin, vec::IntoIter};

use futures::future::{BoxFuture, Shared};
use primitives::{
    algebra::{
        elliptic_curve::Curve,
        field::{binary::Gf2_128, mersenne::Mersenne107},
        BoxedUint,
    },
    correlated_randomness::{dabits::DaBit, pow::PowPair, singlets::Singlet, triples::Triple},
};

use crate::{errors::AbortError, preprocessing::errors::PreprocessingBundlerError};
14
15/// Per-element lazy future returned by [`PreprocessingSource::request_n_elements`].
16pub type NextElement<P> = Pin<Box<dyn Future<Output = Result<P, AbortError>> + Send + 'static>>;
17
18/// Batch future returned by [`PreprocessingSource::request_n_elements_batch`].
19///
20/// Concretely typed so that:
21/// - **`Clone`** is always available, enabling the default `request_n_elements` to share one drive
22///   across all per-element futures.
23/// - No associated type is needed on the trait; every implementor boxes its own async block into
24///   this common alias.
25pub type NextBatch<P> = Shared<BoxFuture<'static, Result<Vec<P>, AbortError>>>;
26
27pub type NextSinglet<F> = NextElement<Singlet<F>>;
28pub type NextSingletBatch<F> = NextBatch<Singlet<F>>;
29
30pub type NextTriple<F> = NextElement<Triple<F>>;
31pub type NextTripleBatch<F> = NextBatch<Triple<F>>;
32
33pub type NextDaBit<F> = NextElement<DaBit<F>>;
34
35pub type NextDaBitBatch<F> = NextBatch<DaBit<F>>;
36
37pub type NextPowPair<F> = NextElement<PowPair<F>>;
38
39pub type NextPowPairBatch<F> = NextBatch<PowPair<F>>;
40
41/// An iterator over per-element preprocessing futures for every gate in a circuit.
42///
43/// Produced by [`crate::protocol::PreprocessingBundler::fetch_for`]. All network
44/// requests have been **dispatched** (but not awaited) by the time this iterator is
45/// constructed; consuming an item via `next_*` hands the caller a lazy
46/// [`NextElement`] future that resolves when the preprocessing value is ready.
47pub struct CerberusPreprocessingIterator<C: Curve> {
48    // Base field iterators
49    pub base_field_dabits: IntoIter<NextDaBit<C::BaseField>>,
50    pub base_field_pow_preprocessing: HashMap<BoxedUint, IntoIter<NextPowPair<C::BaseField>>>,
51    pub base_field_singlets: IntoIter<NextSinglet<C::BaseField>>,
52    pub base_field_triples: IntoIter<NextTriple<C::BaseField>>,
53
54    // Binary field iterators
55    pub binary_singlets: IntoIter<NextSinglet<Gf2_128>>,
56    pub binary_triples: IntoIter<NextTriple<Gf2_128>>,
57
58    // Mersenne107 iterators
59    pub mersenne107_dabits: IntoIter<NextDaBit<Mersenne107>>,
60    pub mersenne107_singlets: IntoIter<NextSinglet<Mersenne107>>,
61    pub mersenne107_triples: IntoIter<NextTriple<Mersenne107>>,
62
63    // Scalar field iterators
64    pub scalar_dabits: IntoIter<NextDaBit<C::Scalar>>,
65    pub scalar_singlets: IntoIter<NextSinglet<C::Scalar>>,
66    pub scalar_triples: IntoIter<NextTriple<C::Scalar>>,
67}
68
69impl<C: Curve> CerberusPreprocessingIterator<C> {
70    pub fn next_base_field_dabit(
71        &mut self,
72    ) -> Result<NextDaBit<C::BaseField>, PreprocessingBundlerError> {
73        self.base_field_dabits.next().ok_or_else(|| {
74            PreprocessingBundlerError::InsufficientDaBits(
75                std::any::type_name::<C::BaseField>().to_string(),
76            )
77        })
78    }
79
80    pub fn next_base_field_powpair(
81        &mut self,
82        exp: &BoxedUint,
83    ) -> Result<NextPowPair<C::BaseField>, PreprocessingBundlerError> {
84        self.base_field_pow_preprocessing
85            .get_mut(exp)
86            .ok_or_else(|| PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone()))?
87            .next()
88            .ok_or_else(|| PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone()))
89    }
90
91    pub fn next_base_field_singlet(
92        &mut self,
93    ) -> Result<NextSinglet<C::BaseField>, PreprocessingBundlerError> {
94        self.base_field_singlets.next().ok_or_else(|| {
95            PreprocessingBundlerError::InsufficientSinglets(
96                std::any::type_name::<C::BaseField>().to_string(),
97            )
98        })
99    }
100
101    pub fn next_base_field_triple(
102        &mut self,
103    ) -> Result<NextTriple<C::BaseField>, PreprocessingBundlerError> {
104        self.base_field_triples.next().ok_or_else(|| {
105            PreprocessingBundlerError::InsufficientTriples(
106                std::any::type_name::<C::BaseField>().to_string(),
107            )
108        })
109    }
110
111    pub fn next_bit_singlet(&mut self) -> Result<NextSinglet<Gf2_128>, PreprocessingBundlerError> {
112        self.binary_singlets.next().ok_or_else(|| {
113            PreprocessingBundlerError::InsufficientSinglets(
114                std::any::type_name::<Gf2_128>().to_string(),
115            )
116        })
117    }
118
119    pub fn next_bit_triple(&mut self) -> Result<NextTriple<Gf2_128>, PreprocessingBundlerError> {
120        self.binary_triples.next().ok_or_else(|| {
121            PreprocessingBundlerError::InsufficientTriples(
122                std::any::type_name::<Gf2_128>().to_string(),
123            )
124        })
125    }
126
127    pub fn next_mersenne107_dabit(
128        &mut self,
129    ) -> Result<NextDaBit<Mersenne107>, PreprocessingBundlerError> {
130        self.mersenne107_dabits
131            .next()
132            .ok_or_else(|| PreprocessingBundlerError::InsufficientDaBits("Mersenne107".to_string()))
133    }
134
135    pub fn next_mersenne107_singlet(
136        &mut self,
137    ) -> Result<NextSinglet<Mersenne107>, PreprocessingBundlerError> {
138        self.mersenne107_singlets.next().ok_or_else(|| {
139            PreprocessingBundlerError::InsufficientSinglets("Mersenne107".to_string())
140        })
141    }
142
143    pub fn next_mersenne107_triple(
144        &mut self,
145    ) -> Result<NextTriple<Mersenne107>, PreprocessingBundlerError> {
146        self.mersenne107_triples.next().ok_or_else(|| {
147            PreprocessingBundlerError::InsufficientTriples("Mersenne107".to_string())
148        })
149    }
150
151    pub fn next_scalar_dabit(&mut self) -> Result<NextDaBit<C::Scalar>, PreprocessingBundlerError> {
152        self.scalar_dabits.next().ok_or_else(|| {
153            PreprocessingBundlerError::InsufficientDaBits(
154                std::any::type_name::<C::Scalar>().to_string(),
155            )
156        })
157    }
158
159    pub fn next_scalar_singlet(
160        &mut self,
161    ) -> Result<NextSinglet<C::Scalar>, PreprocessingBundlerError> {
162        self.scalar_singlets.next().ok_or_else(|| {
163            PreprocessingBundlerError::InsufficientSinglets(
164                std::any::type_name::<C::Scalar>().to_string(),
165            )
166        })
167    }
168
169    pub fn next_scalar_triple(
170        &mut self,
171    ) -> Result<NextTriple<C::Scalar>, PreprocessingBundlerError> {
172        self.scalar_triples.next().ok_or_else(|| {
173            PreprocessingBundlerError::InsufficientTriples(
174                std::any::type_name::<C::Scalar>().to_string(),
175            )
176        })
177    }
178}