Skip to main content

core_utils/preprocessing/
iterator.rs

1use std::{collections::HashMap, fmt::Debug, future::Future, pin::Pin, vec::IntoIter};
2
3use futures::future::{BoxFuture, Shared};
4use primitives::{
5    algebra::{
6        elliptic_curve::Curve,
7        field::{binary::Gf2_128, mersenne::Mersenne107},
8        BoxedUint,
9    },
10    correlated_randomness::{dabits::DaBit, pow::PowPair, singlets::Singlet, triples::Triple},
11};
12
13use crate::{
14    circuit::{CircuitPreprocessing, FieldCircuitPreprocessing},
15    errors::AbortError,
16    preprocessing::errors::PreprocessingBundlerError,
17};
18
19/// Per-element lazy future returned by [`PreprocessingSource::request_n_elements`].
20pub type NextElement<P> = Pin<Box<dyn Future<Output = Result<P, AbortError>> + Send + 'static>>;
21
22/// Batch future returned by [`PreprocessingSource::request_n_elements_batch`].
23///
24/// Concretely typed so that:
25/// - **`Clone`** is always available, enabling the default `request_n_elements` to share one drive
26///   across all per-element futures.
27/// - No associated type is needed on the trait; every implementor boxes its own async block into
28///   this common alias.
29pub type NextBatch<P> = Shared<BoxFuture<'static, Result<Vec<P>, AbortError>>>;
30
31pub type NextSinglet<F> = NextElement<Singlet<F>>;
32pub type NextSingletBatch<F> = NextBatch<Singlet<F>>;
33
34pub type NextTriple<F> = NextElement<Triple<F>>;
35pub type NextTripleBatch<F> = NextBatch<Triple<F>>;
36
37pub type NextDaBit<F> = NextElement<DaBit<F>>;
38
39pub type NextDaBitBatch<F> = NextBatch<DaBit<F>>;
40
41pub type NextPowPair<F> = NextElement<PowPair<F>>;
42
43pub type NextPowPairBatch<F> = NextBatch<PowPair<F>>;
44
45/// An iterator over per-element preprocessing futures for every gate in a circuit.
46///
47/// Produced by [`crate::protocol::PreprocessingBundler::fetch_for`]. All network
48/// requests have been **dispatched** (but not awaited) by the time this iterator is
49/// constructed; consuming an item via `next_*` hands the caller a lazy
50/// [`NextElement`] future that resolves when the preprocessing value is ready.
51pub struct CerberusPreprocessingIterator<C: Curve> {
52    // Base field iterators
53    pub base_field_dabits: IntoIter<NextDaBit<C::BaseField>>,
54    pub base_field_pow_preprocessing: HashMap<BoxedUint, IntoIter<NextPowPair<C::BaseField>>>,
55    pub base_field_singlets: IntoIter<NextSinglet<C::BaseField>>,
56    pub base_field_triples: IntoIter<NextTriple<C::BaseField>>,
57
58    // Binary field iterators
59    pub binary_singlets: IntoIter<NextSinglet<Gf2_128>>,
60    pub binary_triples: IntoIter<NextTriple<Gf2_128>>,
61
62    // Mersenne107 iterators
63    pub mersenne107_dabits: IntoIter<NextDaBit<Mersenne107>>,
64    pub mersenne107_singlets: IntoIter<NextSinglet<Mersenne107>>,
65    pub mersenne107_triples: IntoIter<NextTriple<Mersenne107>>,
66
67    // Scalar field iterators
68    pub scalar_dabits: IntoIter<NextDaBit<C::Scalar>>,
69    pub scalar_singlets: IntoIter<NextSinglet<C::Scalar>>,
70    pub scalar_triples: IntoIter<NextTriple<C::Scalar>>,
71}
72
73impl<C: Curve> CerberusPreprocessingIterator<C> {
74    pub fn len(&self) -> CircuitPreprocessing {
75        CircuitPreprocessing {
76            base_field_pow_pairs: self
77                .base_field_pow_preprocessing
78                .iter()
79                .map(|(x, y)| (x.clone(), y.len()))
80                .collect(),
81            bit_singlets: self.binary_singlets.len(),
82            bit_triples: self.binary_triples.len(),
83            base_field: FieldCircuitPreprocessing {
84                singlets: self.base_field_singlets.len(),
85                triples: self.base_field_triples.len(),
86                dabits: self.base_field_dabits.len(),
87            },
88            scalar: FieldCircuitPreprocessing {
89                singlets: self.scalar_singlets.len(),
90                triples: self.scalar_triples.len(),
91                dabits: self.scalar_dabits.len(),
92            },
93            mersenne107: FieldCircuitPreprocessing {
94                singlets: self.mersenne107_singlets.len(),
95                triples: self.mersenne107_triples.len(),
96                dabits: self.mersenne107_dabits.len(),
97            },
98        }
99    }
100    pub fn is_empty(&self) -> bool {
101        self.base_field_pow_preprocessing
102            .iter()
103            .all(|(_, x)| x.len() == 0)
104            && self.base_field_dabits.len() == 0
105            && self.base_field_singlets.len() == 0
106            && self.base_field_triples.len() == 0
107            && self.binary_singlets.len() == 0
108            && self.binary_triples.len() == 0
109            && self.mersenne107_dabits.len() == 0
110            && self.mersenne107_singlets.len() == 0
111            && self.mersenne107_triples.len() == 0
112            && self.scalar_dabits.len() == 0
113            && self.scalar_singlets.len() == 0
114            && self.scalar_triples.len() == 0
115    }
116    pub fn next_base_field_dabit(
117        &mut self,
118    ) -> Result<NextDaBit<C::BaseField>, PreprocessingBundlerError> {
119        self.base_field_dabits.next().ok_or_else(|| {
120            PreprocessingBundlerError::InsufficientDaBits(
121                std::any::type_name::<C::BaseField>().to_string(),
122            )
123        })
124    }
125
126    pub fn next_base_field_powpair(
127        &mut self,
128        exp: &BoxedUint,
129    ) -> Result<NextPowPair<C::BaseField>, PreprocessingBundlerError> {
130        self.base_field_pow_preprocessing
131            .get_mut(exp)
132            .ok_or_else(|| PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone()))?
133            .next()
134            .ok_or_else(|| PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone()))
135    }
136
137    pub fn next_base_field_singlet(
138        &mut self,
139    ) -> Result<NextSinglet<C::BaseField>, PreprocessingBundlerError> {
140        self.base_field_singlets.next().ok_or_else(|| {
141            PreprocessingBundlerError::InsufficientSinglets(
142                std::any::type_name::<C::BaseField>().to_string(),
143            )
144        })
145    }
146
147    pub fn next_base_field_triple(
148        &mut self,
149    ) -> Result<NextTriple<C::BaseField>, PreprocessingBundlerError> {
150        self.base_field_triples.next().ok_or_else(|| {
151            PreprocessingBundlerError::InsufficientTriples(
152                std::any::type_name::<C::BaseField>().to_string(),
153            )
154        })
155    }
156
157    pub fn next_bit_singlet(&mut self) -> Result<NextSinglet<Gf2_128>, PreprocessingBundlerError> {
158        self.binary_singlets.next().ok_or_else(|| {
159            PreprocessingBundlerError::InsufficientSinglets(
160                std::any::type_name::<Gf2_128>().to_string(),
161            )
162        })
163    }
164
165    pub fn next_bit_triple(&mut self) -> Result<NextTriple<Gf2_128>, PreprocessingBundlerError> {
166        self.binary_triples.next().ok_or_else(|| {
167            PreprocessingBundlerError::InsufficientTriples(
168                std::any::type_name::<Gf2_128>().to_string(),
169            )
170        })
171    }
172
173    pub fn next_mersenne107_dabit(
174        &mut self,
175    ) -> Result<NextDaBit<Mersenne107>, PreprocessingBundlerError> {
176        self.mersenne107_dabits
177            .next()
178            .ok_or_else(|| PreprocessingBundlerError::InsufficientDaBits("Mersenne107".to_string()))
179    }
180
181    pub fn next_mersenne107_singlet(
182        &mut self,
183    ) -> Result<NextSinglet<Mersenne107>, PreprocessingBundlerError> {
184        self.mersenne107_singlets.next().ok_or_else(|| {
185            PreprocessingBundlerError::InsufficientSinglets("Mersenne107".to_string())
186        })
187    }
188
189    pub fn next_mersenne107_triple(
190        &mut self,
191    ) -> Result<NextTriple<Mersenne107>, PreprocessingBundlerError> {
192        self.mersenne107_triples.next().ok_or_else(|| {
193            PreprocessingBundlerError::InsufficientTriples("Mersenne107".to_string())
194        })
195    }
196
197    pub fn next_scalar_dabit(&mut self) -> Result<NextDaBit<C::Scalar>, PreprocessingBundlerError> {
198        self.scalar_dabits.next().ok_or_else(|| {
199            PreprocessingBundlerError::InsufficientDaBits(
200                std::any::type_name::<C::Scalar>().to_string(),
201            )
202        })
203    }
204
205    pub fn next_scalar_singlet(
206        &mut self,
207    ) -> Result<NextSinglet<C::Scalar>, PreprocessingBundlerError> {
208        self.scalar_singlets.next().ok_or_else(|| {
209            PreprocessingBundlerError::InsufficientSinglets(
210                std::any::type_name::<C::Scalar>().to_string(),
211            )
212        })
213    }
214
215    pub fn next_scalar_triple(
216        &mut self,
217    ) -> Result<NextTriple<C::Scalar>, PreprocessingBundlerError> {
218        self.scalar_triples.next().ok_or_else(|| {
219            PreprocessingBundlerError::InsufficientTriples(
220                std::any::type_name::<C::Scalar>().to_string(),
221            )
222        })
223    }
224}
225
226impl<C: Curve> Debug for CerberusPreprocessingIterator<C> {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(f, "CerberusPreprocessingIterator with len: ")?;
229        self.len().fmt(f)
230    }
231}