Skip to main content

primitives/sharing/authenticated/pairwise/
keys.rs

1use std::{
2    iter::Sum as IterSum,
3    mem::MaybeUninit,
4    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5    sync::Arc,
6};
7
8use serde::{Deserialize, Serialize};
9use subtle::{Choice, ConstantTimeEq};
10use typenum::{Prod, Sum, U1, U2, U3};
11use wincode::{
12    io::{Reader, Writer},
13    ReadResult,
14    SchemaRead,
15    SchemaWrite,
16    WriteResult,
17};
18
19use crate::{
20    algebra::{
21        elliptic_curve::{BaseField, Curve, Point, ScalarAsExtension, ScalarField},
22        field::{FieldElement, FieldExtension},
23    },
24    random::{CryptoRngCore, Random, RandomWith},
25    types::{
26        heap_array::{CurvePoints, FieldElements},
27        ConditionallySelectable,
28        HeapArray,
29        Positive,
30    },
31};
32
33// ============================================================
34// |                     GlobalKey                             |
35// ============================================================
36
37/// α, a global authentication key for field shares. Each party holds a
38/// global key α for each peer, and uses that α to authenticate all its field shares
39/// (alongside a local key β).
40pub type GlobalKey<A> = Arc<A>;
41
42/// Global authentication key for field shares. Alias for [`GlobalKey<FieldElement<F>>`].
43pub type GlobalFieldKey<F> = Arc<FieldElement<F>>;
44/// [`GlobalFieldKey`] for scalar field shares. Alias for [`GlobalFieldKey<ScalarField<C>>`].
45pub type GlobalScalarKey<C> = GlobalFieldKey<ScalarField<C>>;
46/// [`GlobalFieldKey`] for base field shares. Alias for [`GlobalFieldKey<BaseField<C>>`].
47pub type GlobalBaseKey<C> = GlobalFieldKey<BaseField<C>>;
48/// [`GlobalFieldKey`] for curve-point shares (MAC is computed over the scalar field).
49/// Alias for [`GlobalFieldKey<ScalarField<C>>`].
50pub type GlobalCurveKey<C> = GlobalFieldKey<ScalarField<C>>;
51
52// ============================================================
// |                  PairwiseAuthKey<A, B>                   |
54// ============================================================
55
56/// Generic authenticated key: a global key `alpha: A` and a local key `beta: B`,
57/// satisfying `MAC(x) = α · x + β`.
58///
59/// This is the base type for both:
60/// - [`FieldShareKey<F>`]    — single-value authenticated key
61/// - [`FieldShareKeys<F,M>`] — batched authenticated key (M values, shared α)
62// α and β, such that MAC(x) = α · x + β
63// In the context of VOLE, this corresponds to w = Δ · u + v
64#[derive(Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
65#[repr(C)]
66pub struct PairwiseAuthKey<A, B> {
67    pub alpha: GlobalKey<A>, // α, global key
68    pub beta: B,             // β, local key (single value or batch)
69}
70
71impl<A: std::fmt::Debug, B: std::fmt::Debug> std::fmt::Debug for PairwiseAuthKey<A, B> {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("PairwiseAuthKey")
74            .field("alpha", &self.alpha)
75            .field("beta", &self.beta)
76            .finish()
77    }
78}
79
80impl<A, B> PairwiseAuthKey<A, B> {
81    /// Create a new [`PairwiseAuthKey`] with the given global key `alpha` and local key `beta`.
82    #[inline]
83    pub fn new(alpha: Arc<A>, beta: B) -> Self {
84        PairwiseAuthKey { alpha, beta }
85    }
86
87    /// Return a reference to the local key `beta`.
88    #[inline]
89    pub fn get_beta(&self) -> &B {
90        &self.beta
91    }
92
93    /// Return a reference to the global key `alpha`.
94    #[inline]
95    pub fn get_alpha(&self) -> &A {
96        &self.alpha
97    }
98
99    /// Return the value of the global key `alpha`.
100    #[inline]
101    pub fn alpha(&self) -> Arc<A> {
102        self.alpha.clone()
103    }
104}
105
106// --- Aliases --- //
107
108/// [`PairwiseAuthKey`] for a single field value.
109///
110/// See also [`FieldShareKeys<F, M>`] for the batched variant.
111///
112/// α and β, such that MAC(x) = α · x + β
113pub type FieldKey<F> = PairwiseAuthKey<FieldElement<F>, FieldElement<F>>;
114
115/// [`PairwiseAuthKey`] for a single scalar field element. Alias for
116/// [`FieldKey<ScalarField<C>>`].
117pub type ScalarKey<C> = FieldKey<ScalarField<C>>;
118/// [`PairwiseAuthKey`] for a single base field element. Alias for [`FieldKey<BaseField<C>>`].
119pub type BaseFieldKey<C> = FieldKey<BaseField<C>>;
120
121/// [`PairwiseAuthKey`] for a batch of `M` field values sharing a single global key.
122///
123/// See also [`FieldKey<F>`] for the single-value variant.
124///
125/// α and β₁…βₘ, such that MAC(xₜ) = α · xₜ + βₜ
126pub type FieldKeys<F, M> = PairwiseAuthKey<FieldElement<F>, FieldElements<F, M>>;
127
128/// [`PairwiseAuthKey`] for a batch of scalar field elements.
129pub type ScalarKeys<C, M> = FieldKeys<ScalarField<C>, M>;
130/// [`PairwiseAuthKey`] for a batch of base field elements.
131pub type BaseFieldKeys<C, M> = FieldKeys<BaseField<C>, M>;
132
133/// [`PairwiseAuthKey`] for a single curve point.
134pub type CurveKey<C> = PairwiseAuthKey<ScalarAsExtension<C>, Point<C>>;
135/// [`PairwiseAuthKey`] for `M` curve points.
136pub type CurveKeys<C, M> = PairwiseAuthKey<ScalarAsExtension<C>, CurvePoints<C, M>>;
137
138// ========================
139// |       Addition       |
140// ========================
141
// Addition of two keys under the same global key: `key + &other` adds the local keys β and
// keeps the shared α. The other operand combinations (`owned`, `borrowed`, `flipped`) are
// generated by `macros::op_variants`; NOTE(review): their exact expansion lives in that macro.
//
// Panics with "alpha mismatch" if the two keys carry different global keys (compared by value
// through `Arc`'s `PartialEq`).
#[macros::op_variants(owned, borrowed, flipped)]
impl<A: PartialEq, B> Add<&PairwiseAuthKey<A, B>> for PairwiseAuthKey<A, B>
where
    for<'a> B: Add<&'a B, Output = B>,
{
    type Output = PairwiseAuthKey<A, B>;

    #[inline]
    fn add(mut self, other: &PairwiseAuthKey<A, B>) -> Self::Output {
        // Refuse to combine keys authenticated under different α.
        assert!(self.alpha == other.alpha, "alpha mismatch");
        // Reuse `self` (and its α handle) rather than building a fresh key.
        self.beta = self.beta + &other.beta;
        self
    }
}
156
// In-place addition: `key += &other` adds `other`'s local key β into `key`; α is untouched.
// The owned-operand variant is generated by `macros::op_variants`.
//
// Panics with "alpha mismatch" if the two keys carry different global keys.
#[macros::op_variants(owned)]
impl<'a, A: PartialEq, B> AddAssign<&'a PairwiseAuthKey<A, B>> for PairwiseAuthKey<A, B>
where
    for<'b> B: AddAssign<&'b B>,
{
    #[inline]
    fn add_assign(&mut self, other: &'a PairwiseAuthKey<A, B>) {
        // Refuse to combine keys authenticated under different α.
        assert!(self.alpha == other.alpha, "alpha mismatch");
        self.beta += &other.beta;
    }
}
168
// Subtraction of two keys under the same global key: `key - &other` subtracts the local keys β
// and keeps the shared α. The other operand combinations are generated by
// `macros::op_variants`; NOTE(review): their exact expansion lives in that macro.
//
// Panics with "alpha mismatch" if the two keys carry different global keys.
#[macros::op_variants(owned, borrowed, flipped)]
impl<A: PartialEq, B> Sub<&PairwiseAuthKey<A, B>> for PairwiseAuthKey<A, B>
where
    for<'a> B: Sub<&'a B, Output = B>,
{
    type Output = PairwiseAuthKey<A, B>;

    #[inline]
    fn sub(self, other: &PairwiseAuthKey<A, B>) -> Self::Output {
        // Refuse to combine keys authenticated under different α.
        assert!(self.alpha == other.alpha, "alpha mismatch");
        PairwiseAuthKey {
            alpha: self.alpha,
            beta: self.beta - &other.beta,
        }
    }
}
185
// In-place subtraction: `key -= &other` subtracts `other`'s local key β from `key`; α is
// untouched. The owned-operand variant is generated by `macros::op_variants`.
//
// Panics with "alpha mismatch" if the two keys carry different global keys.
#[macros::op_variants(owned)]
impl<'a, A: PartialEq, B> SubAssign<&'a PairwiseAuthKey<A, B>> for PairwiseAuthKey<A, B>
where
    for<'b> B: SubAssign<&'b B>,
{
    #[inline]
    fn sub_assign(&mut self, other: &'a PairwiseAuthKey<A, B>) {
        // Refuse to combine keys authenticated under different α.
        assert!(self.alpha == other.alpha, "alpha mismatch");
        self.beta -= &other.beta;
    }
}
197
// Negation: negates only the local key β and keeps α, so that with MAC(x) = α · x + β we get
// α · (-x) + (-β) = -MAC(x). The borrowed variant is generated by `macros::op_variants`.
#[macros::op_variants(borrowed)]
impl<A, B: Neg<Output = B>> Neg for PairwiseAuthKey<A, B> {
    type Output = PairwiseAuthKey<A, B>;

    #[inline]
    fn neg(mut self) -> Self::Output {
        self.beta = -self.beta;
        self
    }
}
208
209impl<A: Default + PartialEq, B: Default> IterSum for PairwiseAuthKey<A, B>
210where
211    for<'a> B: Add<&'a B, Output = B>,
212{
213    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
214        let first = iter.next().unwrap_or_default();
215        iter.fold(first, |acc, x| acc + &x)
216    }
217}
218
219// ========================
220// |   Multiplication     |
221// ========================
222
223impl<A, B, RHS, BPrime> Mul<RHS> for PairwiseAuthKey<A, B>
224where
225    B: Mul<RHS, Output = BPrime>,
226{
227    type Output = PairwiseAuthKey<A, BPrime>;
228    #[inline]
229    fn mul(self, other: RHS) -> Self::Output {
230        PairwiseAuthKey {
231            alpha: self.alpha,
232            beta: self.beta * other,
233        }
234    }
235}
236
237impl<A: Clone, B, RHS, BPrime> Mul<RHS> for &PairwiseAuthKey<A, B>
238where
239    for<'b> &'b B: Mul<RHS, Output = BPrime>,
240{
241    type Output = PairwiseAuthKey<A, BPrime>;
242    #[inline]
243    fn mul(self, other: RHS) -> Self::Output {
244        PairwiseAuthKey {
245            alpha: self.alpha.clone(),
246            beta: &self.beta * other,
247        }
248    }
249}
250
251impl<A, B, RHS> MulAssign<RHS> for PairwiseAuthKey<A, B>
252where
253    B: MulAssign<RHS>,
254{
255    #[inline]
256    fn mul_assign(&mut self, other: RHS) {
257        self.beta *= other;
258    }
259}
260
261// ========================
262// |   Random Generation  |
263// ========================
264
265impl<A: Random, B: Random> Random for PairwiseAuthKey<A, B> {
266    fn random(mut rng: impl CryptoRngCore) -> Self {
267        PairwiseAuthKey {
268            alpha: A::random(&mut rng).into(),
269            beta: B::random(&mut rng),
270        }
271    }
272}
273
274impl<A: Clone, B: Random> RandomWith<GlobalKey<A>> for PairwiseAuthKey<A, B> {
275    fn random_with(mut rng: impl CryptoRngCore, alpha: GlobalKey<A>) -> Self {
276        PairwiseAuthKey {
277            alpha,
278            beta: B::random(&mut rng),
279        }
280    }
281}
282
283// ---------------------------------------
284// |  Constant time Selection / Equality |
285// ---------------------------------------
286
287impl<A: ConstantTimeEq, B: ConstantTimeEq> ConstantTimeEq for PairwiseAuthKey<A, B> {
288    #[inline]
289    fn ct_eq(&self, other: &Self) -> Choice {
290        self.alpha.ct_eq(&other.alpha) & self.beta.ct_eq(&other.beta)
291    }
292}
293
294impl<A: ConditionallySelectable, B: ConditionallySelectable> ConditionallySelectable
295    for PairwiseAuthKey<A, B>
296{
297    #[inline]
298    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
299        PairwiseAuthKey {
300            alpha: A::conditional_select(&a.alpha, &b.alpha, choice).into(),
301            beta: B::conditional_select(&a.beta, &b.beta, choice),
302        }
303    }
304}
305
306// ========================
307// |  SchemaWrite / Read  |
308// ========================
309
310impl<A: SchemaWrite<Src = A>, B: SchemaWrite<Src = B>> SchemaWrite for PairwiseAuthKey<A, B> {
311    type Src = Self;
312
313    fn size_of(src: &Self::Src) -> WriteResult<usize> {
314        Ok(A::size_of(&src.alpha)? + B::size_of(&src.beta)?)
315    }
316
317    fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
318        A::write(writer, &src.alpha)?;
319        B::write(writer, &src.beta)
320    }
321}
322
323impl<'de, A: SchemaRead<'de, Dst = A>, B: SchemaRead<'de, Dst = B>> SchemaRead<'de>
324    for PairwiseAuthKey<A, B>
325{
326    type Dst = Self;
327
328    fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
329        let mut alpha_uninit = MaybeUninit::<A>::uninit();
330        let mut beta_uninit = MaybeUninit::<B>::uninit();
331        A::read(reader, &mut alpha_uninit)?;
332        B::read(reader, &mut beta_uninit)?;
333        // SAFETY: both fields were initialized by the reads above
334        let alpha = unsafe { alpha_uninit.assume_init() }.into();
335        let beta = unsafe { beta_uninit.assume_init() };
336        dst.write(PairwiseAuthKey { alpha, beta });
337        Ok(())
338    }
339}
340
341// -----------------------
342// |   Split and Merge   |
343// -----------------------
344
345/// Generic split and merge operations on any batched key `PairwiseAuthKey<A, HeapArray<B, M>>`.
346/// Covers both `FieldShareKeys<F, M>` and `CurveKeys<C, M>`.
347impl<A: Clone, B: Copy, M: Positive> PairwiseAuthKey<A, HeapArray<B, M>> {
348    #[allow(clippy::type_complexity)]
349    /// Split a batched key into two smaller batched keys with the same global key `alpha` and
350    /// disjoint local keys `beta` (first `M1` values in the first key, next `M2` values in the
351    /// second key).
352    pub fn split<M1, M2>(
353        self,
354    ) -> (
355        PairwiseAuthKey<A, HeapArray<B, M1>>,
356        PairwiseAuthKey<A, HeapArray<B, M2>>,
357    )
358    where
359        M1: Positive,
360        M2: Positive + Add<M1, Output = M>,
361    {
362        let PairwiseAuthKey { alpha, beta } = self;
363        let (betas1, betas2) = beta.split::<M1, M2>();
364        (
365            PairwiseAuthKey::new(alpha.clone(), betas1),
366            PairwiseAuthKey::new(alpha, betas2),
367        )
368    }
369    #[allow(clippy::type_complexity)]
370    /// Split a batched key into two smaller batched keys with the same global key `alpha` and
371    /// disjoint local keys `beta` (first `M/2` values in the first key, next `M/2` values in the
372    /// second key).
373    pub fn split_halves<MDiv2>(
374        self,
375    ) -> (
376        PairwiseAuthKey<A, HeapArray<B, MDiv2>>,
377        PairwiseAuthKey<A, HeapArray<B, MDiv2>>,
378    )
379    where
380        MDiv2: Positive + Mul<U2, Output = M>,
381    {
382        let PairwiseAuthKey { alpha, beta } = self;
383        let (betas1, betas2) = beta.split_halves::<MDiv2>();
384        (
385            PairwiseAuthKey::new(alpha.clone(), betas1),
386            PairwiseAuthKey::new(alpha, betas2),
387        )
388    }
389
390    /// Merge two batched keys with the same global key `alpha` and disjoint local keys `beta`
391    /// (first `M/2` values in the first key, next `M/2` values in the second key) into a single
392    /// batched key with `M` values in the local key.
393    pub fn merge_halves(this: Self, other: Self) -> PairwiseAuthKey<A, HeapArray<B, Prod<M, U2>>>
394    where
395        M: Positive + Mul<U2, Output: Positive>,
396        A: PartialEq,
397    {
398        assert!(this.alpha == other.alpha, "alpha mismatch in merge_halves");
399        PairwiseAuthKey {
400            alpha: this.alpha,
401            beta: HeapArray::merge_halves(this.beta, other.beta),
402        }
403    }
404
405    #[allow(clippy::type_complexity)]
406    /// Split a batched key into three smaller batched keys with the same global key `alpha` and
407    /// disjoint local keys `beta` (first `M1` values in the first key, next `M2` values in the
408    /// second key, last `M3` values in the third key).
409    pub fn split3<M1, M2, M3>(
410        self,
411    ) -> (
412        PairwiseAuthKey<A, HeapArray<B, M1>>,
413        PairwiseAuthKey<A, HeapArray<B, M2>>,
414        PairwiseAuthKey<A, HeapArray<B, M3>>,
415    )
416    where
417        M1: Positive,
418        M2: Positive + Add<M1>,
419        M3: Positive + Add<Sum<M2, M1>, Output = M>,
420    {
421        let PairwiseAuthKey { alpha, beta } = self;
422        let (betas1, betas2, betas3) = beta.split3::<M1, M2, M3>();
423        (
424            PairwiseAuthKey::new(alpha.clone(), betas1),
425            PairwiseAuthKey::new(alpha.clone(), betas2),
426            PairwiseAuthKey::new(alpha, betas3),
427        )
428    }
429
430    #[allow(clippy::type_complexity)]
431    /// Split a batched key into three smaller batched keys with the same global key `alpha` and
432    /// disjoint local keys `beta` (first `M/3` values in the first key, next `M/3` values in the
433    /// second key, last `M/3` values in the third key).
434    pub fn split_thirds<MDiv3>(
435        self,
436    ) -> (
437        PairwiseAuthKey<A, HeapArray<B, MDiv3>>,
438        PairwiseAuthKey<A, HeapArray<B, MDiv3>>,
439        PairwiseAuthKey<A, HeapArray<B, MDiv3>>,
440    )
441    where
442        MDiv3: Positive + Mul<U3, Output = M>,
443    {
444        let PairwiseAuthKey { alpha, beta } = self;
445        let (betas1, betas2, betas3) = beta.split_thirds::<MDiv3>();
446        (
447            PairwiseAuthKey::new(alpha.clone(), betas1),
448            PairwiseAuthKey::new(alpha.clone(), betas2),
449            PairwiseAuthKey::new(alpha, betas3),
450        )
451    }
452
453    /// Merge three batched keys with the same global key `alpha` and disjoint local keys `beta`
454    /// (first `M/3` values in the first key, next `M/3` values in the second key, last `M/3` values
455    /// in the third key) into a single batched key with `M` values in the local key.
456    pub fn merge_thirds(
457        first: Self,
458        second: Self,
459        third: Self,
460    ) -> PairwiseAuthKey<A, HeapArray<B, Prod<M, U3>>>
461    where
462        M: Positive + Mul<U3, Output: Positive>,
463        A: PartialEq,
464    {
465        assert!(
466            first.alpha == second.alpha,
467            "alpha mismatch in merge_thirds"
468        );
469        assert!(first.alpha == third.alpha, "alpha mismatch in merge_thirds");
470        PairwiseAuthKey {
471            alpha: first.alpha,
472            beta: HeapArray::merge_thirds(first.beta, second.beta, third.beta),
473        }
474    }
475}
476
477// ------------------------
478// |   Iterate and Cast   |
479// ------------------------
480
481/// Convert a single-element key into a one-element batched key.
482/// Covers `From<FieldShareKey<F>> for FieldShareKeys<F, U1>` and
483/// `From<CurveKey<C>> for CurveKeys<C, U1>`.
484impl<A, T> From<PairwiseAuthKey<A, T>> for PairwiseAuthKey<A, HeapArray<T, U1>> {
485    fn from(key: PairwiseAuthKey<A, T>) -> Self {
486        PairwiseAuthKey {
487            alpha: key.alpha,
488            beta: HeapArray::from(key.beta),
489        }
490    }
491}
492
493pub struct FieldShareKeysIterator<F: FieldExtension, M: Positive> {
494    keys: FieldKeys<F, M>,
495    index: usize,
496}
497
498impl<F: FieldExtension, M: Positive> Iterator for FieldShareKeysIterator<F, M> {
499    type Item = FieldKey<F>;
500
501    fn next(&mut self) -> Option<Self::Item> {
502        if self.index < M::to_usize() {
503            let key = PairwiseAuthKey {
504                alpha: self.keys.alpha.clone(),
505                beta: self.keys.beta[self.index],
506            };
507            self.index += 1;
508            Some(key)
509        } else {
510            None
511        }
512    }
513}
514
515impl<F: FieldExtension, M: Positive> ExactSizeIterator for FieldShareKeysIterator<F, M> {
516    fn len(&self) -> usize {
517        M::to_usize()
518    }
519}
520
521impl<F: FieldExtension, M: Positive> IntoIterator for FieldKeys<F, M> {
522    type Item = FieldKey<F>;
523    type IntoIter = FieldShareKeysIterator<F, M>;
524
525    fn into_iter(self) -> Self::IntoIter {
526        FieldShareKeysIterator::<F, M> {
527            keys: self,
528            index: 0,
529        }
530    }
531}
532
533pub struct FieldShareKeyRef<'a, F>
534where
535    F: FieldExtension,
536{
537    pub alpha: GlobalFieldKey<F>,
538    pub beta: &'a FieldElement<F>,
539}
540
541impl<'a, F: FieldExtension> From<FieldShareKeyRef<'a, F>> for FieldKey<F> {
542    fn from(val: FieldShareKeyRef<'a, F>) -> Self {
543        PairwiseAuthKey {
544            alpha: val.alpha,
545            beta: *val.beta,
546        }
547    }
548}
549
550pub struct FieldShareKeysRefIterator<'a, F, M>
551where
552    F: FieldExtension,
553    M: Positive,
554{
555    keys: &'a FieldKeys<F, M>,
556    index: usize,
557}
558
559impl<F: FieldExtension, M: Positive> ExactSizeIterator for FieldShareKeysRefIterator<'_, F, M> {
560    fn len(&self) -> usize {
561        M::to_usize()
562    }
563}
564
565impl<'a, F: FieldExtension, M: Positive> Iterator for FieldShareKeysRefIterator<'a, F, M> {
566    type Item = FieldShareKeyRef<'a, F>;
567
568    fn next(&mut self) -> Option<Self::Item> {
569        if self.index < M::to_usize() {
570            let key = FieldShareKeyRef {
571                alpha: self.keys.alpha.clone(),
572                beta: &self.keys.beta[self.index],
573            };
574            self.index += 1;
575            Some(key)
576        } else {
577            None
578        }
579    }
580}
581
582impl<'a, F: FieldExtension, M: Positive> IntoIterator for &'a FieldKeys<F, M> {
583    type Item = FieldShareKeyRef<'a, F>;
584    type IntoIter = FieldShareKeysRefIterator<'a, F, M>;
585
586    fn into_iter(self) -> Self::IntoIter {
587        FieldShareKeysRefIterator {
588            keys: self,
589            index: 0,
590        }
591    }
592}
593
594// --- Type conversions --- //
595
596impl<C: Curve> From<ScalarKey<C>> for CurveKey<C> {
597    #[inline]
598    fn from(scalar_key: ScalarKey<C>) -> Self {
599        PairwiseAuthKey {
600            alpha: scalar_key.alpha,
601            beta: scalar_key.beta * Point::<C>::generator(),
602        }
603    }
604}
605
606impl<C: Curve, M: Positive> From<ScalarKeys<C, M>> for CurveKeys<C, M> {
607    #[inline]
608    fn from(scalar_key: ScalarKeys<C, M>) -> Self {
609        PairwiseAuthKey {
610            alpha: scalar_key.alpha,
611            beta: scalar_key.beta * &Point::<C>::generator(),
612        }
613    }
614}
615
#[cfg(test)]
mod tests {
    use typenum::{U10, U8};

    use super::*;
    use crate::algebra::elliptic_curve::{Curve25519Ristretto as C, ScalarAsExtension};

    type Fq = ScalarAsExtension<C>;

    #[test]
    fn test_addition() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key1 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: Fq::from(10u32),
        };
        let key2 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: Fq::from(7u32),
        };
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: Fq::from(17u32),
        };
        assert_eq!(key1 + key2, expected_result);
    }

    #[test]
    fn test_subtraction() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key1 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: Fq::from(10u32),
        };
        let key2 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: Fq::from(7u32),
        };
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: Fq::from(3u32),
        };
        assert_eq!(key1 - key2, expected_result);
    }

    #[test]
    fn test_multiplication() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: Fq::from(10u32),
        };
        let scalar = Fq::from(3u32);
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: Fq::from(30u32),
        };
        assert_eq!(key * scalar, expected_result);
    }

    #[test]
    fn test_negation() {
        let key = PairwiseAuthKey {
            alpha: GlobalFieldKey::new(Fq::from(5u32)),
            beta: Fq::from(10u32),
        };
        let expected_result = PairwiseAuthKey {
            alpha: GlobalFieldKey::new(Fq::from(5u32)),
            beta: -Fq::from(10u32),
        };
        assert_eq!(-key, expected_result);
    }

    #[test]
    fn test_batched_addition() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key1 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(10u32)),
        };
        let key2 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(7u32)),
        };
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(17u32)),
        };
        assert_eq!(key1 + key2, expected_result);
    }

    #[test]
    fn test_batched_subtraction() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key1 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(10u32)),
        };
        let key2 = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(7u32)),
        };
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(3u32)),
        };
        assert_eq!(key1 - key2, expected_result);
    }

    #[test]
    fn test_batched_multiplication() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let key = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(10u32)),
        };
        let scalar = Fq::from(3u32);
        let expected_result = PairwiseAuthKey {
            alpha,
            beta: HeapArray::<_, U10>::from_fn(|_| Fq::from(30u32)),
        };
        assert_eq!(key * &scalar, expected_result);
    }

    #[test]
    fn test_batched_negation() {
        let key = PairwiseAuthKey {
            alpha: GlobalFieldKey::new(Fq::from(5u32)),
            beta: HeapArray::<_, U8>::from_fn(|_| Fq::from(10u32)),
        };
        let expected_result = PairwiseAuthKey {
            alpha: GlobalFieldKey::new(Fq::from(5u32)),
            beta: -HeapArray::<_, U8>::from_fn(|_| Fq::from(10u32)),
        };
        assert_eq!(-key, expected_result);
    }

    // Iterating a batched key must yield every local key once, each paired with the shared α.
    #[test]
    fn test_batched_into_iter() {
        let alpha = GlobalFieldKey::new(Fq::from(3u32));
        let keys = PairwiseAuthKey {
            alpha: alpha.clone(),
            beta: HeapArray::<_, U8>::from_fn(|_| Fq::from(4u32)),
        };
        let iter = keys.into_iter();
        assert_eq!(iter.len(), 8);
        let expected = PairwiseAuthKey {
            alpha,
            beta: Fq::from(4u32),
        };
        let mut count = 0;
        for key in iter {
            assert_eq!(key, expected);
            count += 1;
        }
        assert_eq!(count, 8);
    }

    // --- CurveKey (single point) tests --- //

    mod curve_single {
        use super::*;
        // Use the crate's test RNG (as `curve_batched` does) so failures are reproducible.
        use crate::{algebra::elliptic_curve::Curve25519Ristretto as C, random};

        type P = Point<C>;
        type FrExt = ScalarAsExtension<C>;

        #[test]
        fn test_addition() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
            let beta1 = P::random(&mut rng);
            let beta2 = P::random(&mut rng);
            let key1 = CurveKey {
                alpha: alpha.clone(),
                beta: beta1,
            };
            let key2 = CurveKey {
                alpha: alpha.clone(),
                beta: beta2,
            };
            let expected = CurveKey {
                alpha,
                beta: beta1 + beta2,
            };
            assert_eq!(key1 + key2, expected);
        }

        #[test]
        fn test_subtraction() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
            let beta1 = P::random(&mut rng);
            let beta2 = P::random(&mut rng);
            let key1 = CurveKey {
                alpha: alpha.clone(),
                beta: beta1,
            };
            let key2 = CurveKey {
                alpha: alpha.clone(),
                beta: beta2,
            };
            let expected = CurveKey {
                alpha,
                beta: beta1 - beta2,
            };
            assert_eq!(key1 - key2, expected);
        }

        #[test]
        fn test_multiplication() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
            let beta1 = P::random(&mut rng);
            let key = CurveKey {
                alpha: alpha.clone(),
                beta: beta1,
            };
            let scalar = FrExt::from(3u32);
            let expected = CurveKey {
                alpha,
                beta: beta1 * scalar,
            };
            assert_eq!(key * scalar, expected);
        }

        #[test]
        fn test_negation() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
            let beta1 = P::random(&mut rng);
            let key = CurveKey {
                alpha: alpha.clone(),
                beta: beta1,
            };
            let expected = CurveKey {
                alpha,
                beta: -beta1,
            };
            assert_eq!(-key, expected);
        }
    }

    // --- CurveKeys (batched) tests --- //

    mod curve_batched {
        use typenum::U8;

        use super::*;
        use crate::{
            algebra::elliptic_curve::Curve25519Ristretto as C,
            random::{self, Random},
            types::heap_array::CurvePoints,
        };

        type FrExt = ScalarAsExtension<C>;
        type Ps = CurvePoints<C, U8>;

        #[test]
        fn test_addition() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::random(&mut rng);
            let beta1 = Ps::random(&mut rng);
            let beta2 = Ps::random(&mut rng);
            let key1 = CurveKeys {
                alpha: alpha.clone(),
                beta: beta1.clone(),
            };
            let key2 = CurveKeys {
                alpha: alpha.clone(),
                beta: beta2.clone(),
            };
            let expected = CurveKeys {
                alpha,
                beta: beta1 + beta2,
            };
            assert_eq!(key1 + key2, expected);
        }

        #[test]
        fn test_subtraction() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::random(&mut rng);
            let beta1 = Ps::random(&mut rng);
            let beta2 = Ps::random(&mut rng);
            let key1 = CurveKeys {
                alpha: alpha.clone(),
                beta: beta1.clone(),
            };
            let key2 = CurveKeys {
                alpha: alpha.clone(),
                beta: beta2.clone(),
            };
            let expected = CurveKeys {
                alpha,
                beta: beta1 - beta2,
            };
            assert_eq!(key1 - key2, expected);
        }

        #[test]
        fn test_multiplication() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::random(&mut rng);
            let beta1 = Ps::random(&mut rng);
            let key = CurveKeys {
                alpha: alpha.clone(),
                beta: beta1.clone(),
            };
            let scalar = FrExt::from(3u32);
            let expected = CurveKeys {
                alpha,
                beta: beta1 * &scalar,
            };
            assert_eq!(key * &scalar, expected);
        }

        #[test]
        fn test_negation() {
            let mut rng = random::test_rng();
            let alpha = GlobalCurveKey::<C>::random(&mut rng);
            let beta1 = Ps::random(&mut rng);
            let key = CurveKeys {
                alpha: alpha.clone(),
                beta: beta1.clone(),
            };
            let expected = CurveKeys {
                alpha,
                beta: -beta1,
            };
            assert_eq!(-key, expected);
        }
    }
}