primitives/sharing/authenticated/curve_key.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3use serde::{Deserialize, Serialize};
4use subtle::{Choice, ConstantTimeEq};
5
6use super::{GlobalFieldKey, ScalarKey};
7use crate::{
8    algebra::elliptic_curve::{Curve, Point, Scalar, ScalarAsExtension, ScalarField},
9    errors::{PrimitiveError, VerificationError::InvalidMAC},
10    random::{CryptoRngCore, Random, RandomWith},
11    sharing::OpenPointShare,
12    types::ConditionallySelectable,
13};
14
15pub type GlobalCurveKey<C> = GlobalFieldKey<ScalarField<C>>;
16
17// α and β, such that MAC(x) = α · x + β
18// In the context of VOLE, this corresponds to w = Δ · u + v
19#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(bound = "C: Curve")]
21pub struct CurveKey<C: Curve> {
22    pub(crate) alpha: GlobalCurveKey<C>,
23    pub(crate) beta: Point<C>,
24}
25
26impl<C: Curve> CurveKey<C> {
27    pub fn new(alpha: GlobalCurveKey<C>, beta: Point<C>) -> Self {
28        CurveKey { alpha, beta }
29    }
30
31    pub fn compute_mac(&self, value: &Point<C>) -> Point<C> {
32        self.beta + value * *self.alpha
33    }
34
35    #[inline]
36    pub fn verify_mac(&self, open_share: &OpenPointShare<C>) -> Result<(), PrimitiveError> {
37        let expected_mac = self.compute_mac(&open_share.value);
38        if expected_mac == open_share.mac {
39            Ok(())
40        } else {
41            Err(InvalidMAC(
42                serde_json::to_string(&expected_mac).unwrap(),
43                serde_json::to_string(&open_share.mac).unwrap(),
44            )
45            .into())
46        }
47    }
48
49    pub fn get_alpha(&self) -> GlobalCurveKey<C> {
50        self.alpha.clone()
51    }
52
53    pub fn get_alpha_value(&self) -> ScalarAsExtension<C> {
54        *self.alpha
55    }
56
57    pub fn get_beta(&self) -> Point<C> {
58        self.beta
59    }
60
61    pub fn zero_batch(alphas: Vec<GlobalCurveKey<C>>) -> Vec<CurveKey<C>> {
62        alphas
63            .iter()
64            .map(|alpha| CurveKey::new(alpha.clone(), Point::<C>::identity()))
65            .collect()
66    }
67}
68
69impl<C: Curve> ConditionallySelectable for CurveKey<C> {
70    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
71        CurveKey {
72            alpha: GlobalCurveKey::<C>::conditional_select(&a.alpha, &b.alpha, choice),
73            beta: Point::conditional_select(&a.beta, &b.beta, choice),
74        }
75    }
76}
77
78// -------------------------
79// |   Random Generation   |
80// -------------------------
81
82impl<C: Curve> Random for CurveKey<C> {
83    fn random(mut rng: impl CryptoRngCore) -> Self {
84        let alpha = GlobalCurveKey::<C>::random(&mut rng);
85        let beta = Point::<C>::random(&mut rng);
86        CurveKey { alpha, beta }
87    }
88}
89
90impl<C: Curve> RandomWith<GlobalCurveKey<C>> for CurveKey<C> {
91    fn random_with(mut rng: impl CryptoRngCore, alpha: GlobalCurveKey<C>) -> Self {
92        CurveKey {
93            alpha,
94            beta: Point::random(&mut rng),
95        }
96    }
97}
98
99// ------------------------------------
100// | Curve Arithmetic Implementations |
101// ------------------------------------
102
103// === Addition === //
104
105impl<C: Curve> Add for CurveKey<C> {
106    type Output = CurveKey<C>;
107
108    #[inline]
109    fn add(self, other: Self) -> Self::Output {
110        debug_assert_eq!(self.alpha, other.alpha);
111        CurveKey {
112            beta: self.beta + other.beta,
113            ..self
114        }
115    }
116}
117
118impl<C: Curve> Add for &CurveKey<C> {
119    type Output = CurveKey<C>;
120
121    #[inline]
122    fn add(self, other: Self) -> Self::Output {
123        debug_assert_eq!(self.alpha, other.alpha);
124        CurveKey {
125            alpha: self.alpha.clone(),
126            beta: self.beta + other.beta,
127        }
128    }
129}
130
131impl<C: Curve> Add<CurveKey<C>> for &CurveKey<C> {
132    type Output = CurveKey<C>;
133
134    #[inline]
135    fn add(self, other: CurveKey<C>) -> Self::Output {
136        debug_assert_eq!(self.alpha, other.alpha);
137        CurveKey {
138            beta: self.beta + other.beta,
139            ..other
140        }
141    }
142}
143
144impl<'a, C: Curve> Add<&'a CurveKey<C>> for CurveKey<C> {
145    type Output = CurveKey<C>;
146
147    #[inline]
148    fn add(self, other: &'a CurveKey<C>) -> Self::Output {
149        debug_assert_eq!(self.alpha, other.alpha);
150        CurveKey {
151            beta: self.beta + other.beta,
152            ..self
153        }
154    }
155}
156
157// === AddAssign === //
158
159impl<C: Curve> AddAssign for CurveKey<C> {
160    #[inline]
161    fn add_assign(&mut self, rhs: Self) {
162        self.beta += rhs.beta;
163    }
164}
165
166impl<'a, C: Curve> AddAssign<&'a CurveKey<C>> for CurveKey<C> {
167    #[inline]
168    fn add_assign(&mut self, rhs: &'a CurveKey<C>) {
169        self.beta += rhs.beta;
170    }
171}
172
173// === Subtraction === //
174
175impl<C: Curve> Sub for CurveKey<C> {
176    type Output = CurveKey<C>;
177
178    #[inline]
179    fn sub(self, other: Self) -> Self::Output {
180        debug_assert_eq!(self.alpha, other.alpha);
181        CurveKey {
182            beta: self.beta - other.beta,
183            ..self
184        }
185    }
186}
187
188impl<C: Curve> Sub for &CurveKey<C> {
189    type Output = CurveKey<C>;
190
191    #[inline]
192    fn sub(self, other: Self) -> Self::Output {
193        debug_assert_eq!(self.alpha, other.alpha);
194        CurveKey {
195            alpha: self.alpha.clone(),
196            beta: self.beta - other.beta,
197        }
198    }
199}
200
201impl<C: Curve> Sub<CurveKey<C>> for &CurveKey<C> {
202    type Output = CurveKey<C>;
203
204    #[inline]
205    fn sub(self, other: CurveKey<C>) -> Self::Output {
206        debug_assert_eq!(self.alpha, other.alpha);
207        CurveKey {
208            beta: self.beta - other.beta,
209            ..other
210        }
211    }
212}
213
214impl<'a, C: Curve> Sub<&'a CurveKey<C>> for CurveKey<C> {
215    type Output = CurveKey<C>;
216
217    #[inline]
218    fn sub(self, other: &'a CurveKey<C>) -> Self::Output {
219        debug_assert_eq!(self.alpha, other.alpha);
220        CurveKey {
221            beta: self.beta - other.beta,
222            ..self
223        }
224    }
225}
226
227// === SubAssign === //
228
229impl<C: Curve> SubAssign for CurveKey<C> {
230    #[inline]
231    fn sub_assign(&mut self, rhs: Self) {
232        self.beta -= rhs.beta;
233    }
234}
235
236impl<'a, C: Curve> SubAssign<&'a CurveKey<C>> for CurveKey<C> {
237    #[inline]
238    fn sub_assign(&mut self, rhs: &'a CurveKey<C>) {
239        self.beta -= rhs.beta;
240    }
241}
242
243// === Multiplication === //
244
245impl<C: Curve> Mul<ScalarAsExtension<C>> for CurveKey<C> {
246    type Output = CurveKey<C>;
247
248    #[inline]
249    fn mul(self, other: ScalarAsExtension<C>) -> Self::Output {
250        CurveKey {
251            alpha: self.alpha,
252            beta: self.beta * other,
253        }
254    }
255}
256
257impl<'a, C: Curve> Mul<&'a ScalarAsExtension<C>> for &'a CurveKey<C> {
258    type Output = CurveKey<C>;
259
260    #[inline]
261    fn mul(self, other: &'a ScalarAsExtension<C>) -> Self::Output {
262        CurveKey {
263            alpha: self.alpha.clone(),
264            beta: self.beta * other,
265        }
266    }
267}
268
269impl<'a, C: Curve> Mul<&'a ScalarAsExtension<C>> for CurveKey<C> {
270    type Output = CurveKey<C>;
271
272    #[inline]
273    fn mul(self, other: &'a ScalarAsExtension<C>) -> Self::Output {
274        CurveKey {
275            beta: self.beta * other,
276            ..self
277        }
278    }
279}
280
281impl<C: Curve> Mul<ScalarAsExtension<C>> for &CurveKey<C> {
282    type Output = CurveKey<C>;
283
284    #[inline]
285    fn mul(self, other: ScalarAsExtension<C>) -> Self::Output {
286        CurveKey {
287            alpha: self.alpha.clone(),
288            beta: self.beta * other,
289        }
290    }
291}
292
293impl<C: Curve> Mul<Scalar<C>> for CurveKey<C> {
294    type Output = CurveKey<C>;
295
296    #[inline]
297    fn mul(self, other: Scalar<C>) -> Self::Output {
298        CurveKey {
299            alpha: self.alpha,
300            beta: self.beta * other,
301        }
302    }
303}
304
305impl<'a, C: Curve> Mul<&'a Scalar<C>> for &'a CurveKey<C> {
306    type Output = CurveKey<C>;
307
308    #[inline]
309    fn mul(self, other: &'a Scalar<C>) -> Self::Output {
310        CurveKey {
311            alpha: self.alpha.clone(),
312            beta: self.beta * other,
313        }
314    }
315}
316
317impl<'a, C: Curve> Mul<&'a Scalar<C>> for CurveKey<C> {
318    type Output = CurveKey<C>;
319
320    #[inline]
321    fn mul(self, other: &'a Scalar<C>) -> Self::Output {
322        CurveKey {
323            beta: self.beta * other,
324            ..self
325        }
326    }
327}
328
329impl<C: Curve> Mul<Scalar<C>> for &CurveKey<C> {
330    type Output = CurveKey<C>;
331
332    #[inline]
333    fn mul(self, other: Scalar<C>) -> Self::Output {
334        CurveKey {
335            alpha: self.alpha.clone(),
336            beta: self.beta * other,
337        }
338    }
339}
340
341// === MulAssign === //
342
343impl<C: Curve> MulAssign<ScalarAsExtension<C>> for CurveKey<C> {
344    #[inline]
345    fn mul_assign(&mut self, rhs: ScalarAsExtension<C>) {
346        self.beta *= rhs;
347    }
348}
349
350impl<'a, C: Curve> MulAssign<&'a ScalarAsExtension<C>> for CurveKey<C> {
351    #[inline]
352    fn mul_assign(&mut self, rhs: &'a ScalarAsExtension<C>) {
353        self.beta *= rhs;
354    }
355}
356
357impl<C: Curve> MulAssign<Scalar<C>> for CurveKey<C> {
358    #[inline]
359    fn mul_assign(&mut self, rhs: Scalar<C>) {
360        self.beta *= rhs;
361    }
362}
363
364impl<'a, C: Curve> MulAssign<&'a Scalar<C>> for CurveKey<C> {
365    #[inline]
366    fn mul_assign(&mut self, rhs: &'a Scalar<C>) {
367        self.beta *= rhs;
368    }
369}
370
371// === Negation === //
372
373impl<C: Curve> Neg for CurveKey<C> {
374    type Output = Self;
375
376    #[inline]
377    fn neg(self) -> Self::Output {
378        CurveKey {
379            alpha: self.alpha,
380            beta: -self.beta,
381        }
382    }
383}
384
385impl<C: Curve> Neg for &CurveKey<C> {
386    type Output = CurveKey<C>;
387
388    #[inline]
389    fn neg(self) -> Self::Output {
390        CurveKey {
391            alpha: self.alpha.clone(),
392            beta: -self.beta,
393        }
394    }
395}
396
397// === Equality === //
398
399impl<C: Curve> ConstantTimeEq for CurveKey<C> {
400    #[inline]
401    fn ct_eq(&self, other: &Self) -> Choice {
402        self.alpha.ct_eq(&other.alpha) & self.beta.ct_eq(&other.beta)
403    }
404}
405
406// === Type conversions === //
407
408impl<C: Curve> From<ScalarKey<C>> for CurveKey<C> {
409    #[inline]
410    fn from(scalar_key: ScalarKey<C>) -> Self {
411        CurveKey {
412            alpha: scalar_key.alpha,
413            beta: scalar_key.beta * Point::<C>::generator(),
414        }
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto as C;

    pub type FrExt = ScalarAsExtension<C>;
    pub type P = Point<C>;

    #[test]
    fn test_addition() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta1 = P::random(&mut rng);
        let beta2 = P::random(&mut rng);

        let key1 = CurveKey {
            alpha: alpha.clone(),
            beta: beta1,
        };
        let key2 = CurveKey {
            alpha: alpha.clone(),
            beta: beta2,
        };
        let expected_result = CurveKey {
            alpha,
            beta: beta1 + beta2,
        };
        assert_eq!(key1 + key2, expected_result);
    }

    #[test]
    fn test_subtraction() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta1 = P::random(&mut rng);
        let beta2 = P::random(&mut rng);

        let key1 = CurveKey {
            alpha: alpha.clone(),
            beta: beta1,
        };
        let key2 = CurveKey {
            alpha: alpha.clone(),
            beta: beta2,
        };
        let expected_result = CurveKey {
            alpha,
            beta: beta1 - beta2,
        };
        assert_eq!(key1 - key2, expected_result);
    }

    #[test]
    fn test_multiplication() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta1 = P::random(&mut rng);

        let key = CurveKey {
            alpha: alpha.clone(),
            beta: beta1,
        };
        let scalar = FrExt::from(3u32);
        let expected_result = CurveKey {
            alpha,
            beta: beta1 * scalar,
        };
        assert_eq!(key * scalar, expected_result);
    }

    #[test]
    fn test_negation() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta1 = P::random(&mut rng);

        let key = CurveKey {
            alpha: alpha.clone(),
            beta: beta1,
        };
        let expected_result = CurveKey {
            alpha,
            beta: -beta1,
        };
        assert_eq!(-key, expected_result);
    }

    /// Key arithmetic must commute with MAC computation: operating on keys and on the
    /// authenticated values in lockstep yields the correspondingly combined MACs.
    #[test]
    fn test_compute_mac_is_linear() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let key1 = CurveKey::<C>::new(alpha.clone(), P::random(&mut rng));
        let key2 = CurveKey::<C>::new(alpha, P::random(&mut rng));
        let x1 = P::random(&mut rng);
        let x2 = P::random(&mut rng);

        // (β1 + β2) + (x1 + x2)·α == (β1 + x1·α) + (β2 + x2·α)
        let summed = &key1 + &key2;
        assert_eq!(
            summed.compute_mac(&(x1 + x2)),
            key1.compute_mac(&x1) + key2.compute_mac(&x2)
        );

        // c·β + (c·x)·α == c·(β + x·α)
        let c = FrExt::from(7u32);
        assert_eq!(
            (&key1 * c).compute_mac(&(x1 * c)),
            key1.compute_mac(&x1) * c
        );

        // -β + (-x)·α == -(β + x·α)
        assert_eq!((-&key1).compute_mac(&-x1), -key1.compute_mac(&x1));
    }

    #[test]
    fn test_zero_batch() {
        let mut rng = rand::thread_rng();
        let alphas: Vec<_> = (0..3)
            .map(|_| GlobalCurveKey::<C>::new(FrExt::random(&mut rng)))
            .collect();

        let keys = CurveKey::<C>::zero_batch(alphas.clone());
        assert_eq!(keys.len(), alphas.len());
        for (key, alpha) in keys.iter().zip(&alphas) {
            assert_eq!(&key.alpha, alpha);
            assert_eq!(key.beta, P::identity());
        }

        assert!(CurveKey::<C>::zero_batch(Vec::new()).is_empty());
    }

    #[test]
    fn test_ct_eq_agrees_with_eq() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let key = CurveKey::<C>::new(alpha.clone(), P::random(&mut rng));
        let other = CurveKey::<C>::new(alpha, P::random(&mut rng));

        assert!(bool::from(key.ct_eq(&key.clone())));
        assert_eq!(bool::from(key.ct_eq(&other)), key == other);
    }
}