Skip to main content

module_lattice/
algebra.rs

1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "ctutils")]
9use ctutils::{Choice, CtEq, CtEqSlice};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13/// Finite field with efficient modular reduction for lattice-based cryptography.
14pub trait Field: Copy + Default + Debug + PartialEq {
15    /// Base integer type used to represent field elements
16    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
17    /// Double-width integer type used for intermediate computations.
18    type Long: PrimInt + From<Self::Int>;
19    /// Quadruple-width integer type used for Barrett reduction.
20    type LongLong: PrimInt;
21
22    /// Field modulus.
23    const Q: Self::Int;
24    /// Field modulus as [`Self::Long`].
25    const QL: Self::Long;
26    /// Field modulus as [`Self::LongLong`].
27    const QLL: Self::LongLong;
28
29    /// Bit shift used in Barrett reduction.
30    const BARRETT_SHIFT: usize;
31    /// Precomputed multiplier for Barrett reduction.
32    const BARRETT_MULTIPLIER: Self::LongLong;
33
34    /// Reduce a value that's already close to the modulus range.
35    fn small_reduce(x: Self::Int) -> Self::Int;
36    /// Reduce a wider value to a field element using Barrett reduction.
37    fn barrett_reduce(x: Self::Long) -> Self::Int;
38}
39
/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`]
/// trait for that struct.
///
/// Invocation: `define_field!($field, $int, $long, $longlong, $q)` or
/// `define_field!($field, $int, $long, $longlong, $q, $doc)`, where:
///
/// * `$field`: The name of the zero-sized struct to be created.
/// * `$int`: The primitive integer type used to represent members of the field.
/// * `$long`: The primitive integer type used to represent products of two field members.
///   This type should have roughly twice the bits of `$int`.
/// * `$longlong`: The primitive integer type used to represent products of three field
///   members. This type should have roughly four times the bits of `$int`.
/// * `$q`: The prime number that defines the field, given as an integer literal.
/// * `$doc` (optional): Documentation attached to the generated struct. When omitted it
///   defaults to `"Finite field"`.
///
/// The Barrett constants are derived from `$q`. `BARRETT_SHIFT` is twice the bit length of `q`
/// (`2 * (ilog2(q) + 1)`), and `BARRETT_MULTIPLIER` is `floor(2^BARRETT_SHIFT / q)`.
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // Twice the bit length of `q`, so `2^BARRETT_SHIFT > q^2`.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // `floor(2^BARRETT_SHIFT / q)`, evaluated at compile time.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            fn small_reduce(x: Self::Int) -> Self::Int {
                // Subtract `q` at most once, mapping `[0, 2q)` onto `[0, q)`.
                // NOTE(review): this is a value-dependent selection; whether it compiles to a
                // branch depends on the compiler and target. Check codegen if timing matters.
                if x < Self::Q { x } else { x - Self::Q }
            }

            fn barrett_reduce(x: Self::Long) -> Self::Int {
                // Widen so that `x * BARRETT_MULTIPLIER` fits (see the `$longlong` sizing note).
                let x: Self::LongLong = x.into();
                // Estimate `floor(x / q)` as `floor(x * m / 2^k)`. For `x < 2^k`, which covers
                // the product of two reduced elements since `2^k > q^2`, the estimate is at most
                // one too small, so `remainder < 2q`.
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                // A single conditional subtraction completes the reduction.
                Self::small_reduce($crate::Truncate::truncate(remainder))
            }
        }
    };
}
88
89/// An [`Elem`] is a member of the specified prime-order field.
90///
91/// Elements can be added, subtracted, multiplied, and negated, and the overloaded operators will
92/// ensure both that the integer values remain in the field, and that the reductions are done
93/// efficiently.
94///
95/// For addition and subtraction, a simple conditional subtraction is used; for multiplication,
96/// Barrett reduction.
97#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
98pub struct Elem<F: Field>(pub F::Int);
99
100impl<F: Field> Elem<F> {
101    /// Create a new field element.
102    pub const fn new(x: F::Int) -> Self {
103        Self(x)
104    }
105}
106
#[cfg(feature = "ctutils")]
impl<F> CtEq for Elem<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare two elements using the constant-time equality of their raw integer values.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for Elem<F>
where
    F: Field,
    F::Int: CtEq,
{
}

#[cfg(feature = "zeroize")]
impl<F> Zeroize for Elem<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Wipe the raw integer value in place.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
129
130impl<F: Field> Neg for Elem<F> {
131    type Output = Elem<F>;
132
133    fn neg(self) -> Elem<F> {
134        Elem(F::small_reduce(F::Q - self.0))
135    }
136}
137
138impl<F: Field> Add<Elem<F>> for Elem<F> {
139    type Output = Elem<F>;
140
141    fn add(self, rhs: Elem<F>) -> Elem<F> {
142        Elem(F::small_reduce(self.0 + rhs.0))
143    }
144}
145
146impl<F: Field> Sub<Elem<F>> for Elem<F> {
147    type Output = Elem<F>;
148
149    fn sub(self, rhs: Elem<F>) -> Elem<F> {
150        Elem(F::small_reduce(self.0 + F::Q - rhs.0))
151    }
152}
153
154impl<F: Field> Mul<Elem<F>> for Elem<F> {
155    type Output = Elem<F>;
156
157    fn mul(self, rhs: Elem<F>) -> Elem<F> {
158        let lhs: F::Long = self.0.into();
159        let rhs: F::Long = rhs.0.into();
160        let prod = lhs * rhs;
161        Elem(F::barrett_reduce(prod))
162    }
163}
164
165/// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials
166/// over the finite field with prime order `q`.
167///
168/// Polynomials can be added, subtracted, negated, and multiplied by field elements.
169#[derive(Clone, Copy, Default, Debug, PartialEq)]
170pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
171
172impl<F: Field> Polynomial<F> {
173    /// Create a new polynomial.
174    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
175        Self(x)
176    }
177}
178
#[cfg(feature = "zeroize")]
impl<F> Zeroize for Polynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Wipe the coefficient array in place.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
188
189impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
190    type Output = Polynomial<F>;
191
192    fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
193        Polynomial(
194            self.0
195                .iter()
196                .zip(rhs.0.iter())
197                .map(|(&x, &y)| x + y)
198                .collect(),
199        )
200    }
201}
202
203impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
204    type Output = Polynomial<F>;
205
206    fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
207        Polynomial(
208            self.0
209                .iter()
210                .zip(rhs.0.iter())
211                .map(|(&x, &y)| x - y)
212                .collect(),
213        )
214    }
215}
216
217impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
218    type Output = Polynomial<F>;
219
220    fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
221        Polynomial(rhs.0.iter().map(|&x| self * x).collect())
222    }
223}
224
225impl<F: Field> Neg for &Polynomial<F> {
226    type Output = Polynomial<F>;
227
228    fn neg(self) -> Polynomial<F> {
229        Polynomial(self.0.iter().map(|&x| -x).collect())
230    }
231}
232
#[cfg(feature = "ctutils")]
impl<F> CtEq for Polynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare the full coefficient arrays in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for Polynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
}
245
246/// A `Vector` is a vector of polynomials from `R_q` of length `K`.
247///
248/// Vectors can be added, subtracted, negated, and multiplied by field elements.
249#[derive(Clone, Default, Debug, PartialEq)]
250pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
251
252impl<F: Field, K: ArraySize> Vector<F, K> {
253    /// Create a new vector.
254    pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
255        Self(x)
256    }
257}
258
#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    /// Wipe every polynomial of the vector in place.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
268
269impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
270    type Output = Vector<F, K>;
271    fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
272        Add::add(&self, &rhs)
273    }
274}
275impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
276    type Output = Vector<F, K>;
277
278    fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
279        Vector(
280            self.0
281                .iter()
282                .zip(rhs.0.iter())
283                .map(|(x, y)| x + y)
284                .collect(),
285        )
286    }
287}
288
289impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
290    type Output = Vector<F, K>;
291
292    fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
293        Vector(
294            self.0
295                .iter()
296                .zip(rhs.0.iter())
297                .map(|(x, y)| x - y)
298                .collect(),
299        )
300    }
301}
302
303impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
304    type Output = Vector<F, K>;
305
306    fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
307        Vector(rhs.0.iter().map(|x| self * x).collect())
308    }
309}
310
311impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
312    type Output = Vector<F, K>;
313
314    fn neg(self) -> Vector<F, K> {
315        Vector(self.0.iter().map(|x| -x).collect())
316    }
317}
318
#[cfg(feature = "ctutils")]
impl<F, K> CtEq for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
    /// Compare all `K` polynomials in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F, K> CtEqSlice for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
}
331
332/// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field
333/// elements.
334///
335/// NTT polynomials can be added and subtracted, negated, and multiplied by scalars.
336/// We do not define multiplication of NTT polynomials here: that is defined by the downstream
337/// crate using the [`MultiplyNtt`] trait.
338///
339/// We also do not define the mappings between normal polynomials and NTT polynomials (i.e., between
340/// `R_q` and `T_q`).
341#[derive(Clone, Default, Debug, Eq, PartialEq)]
342pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
343
344impl<F: Field> NttPolynomial<F> {
345    /// Create a new NTT polynomial.
346    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
347        Self(x)
348    }
349}
350
351impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
352    type Output = NttPolynomial<F>;
353
354    fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
355        NttPolynomial(
356            self.0
357                .iter()
358                .zip(rhs.0.iter())
359                .map(|(&x, &y)| x + y)
360                .collect(),
361        )
362    }
363}
364
365impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
366    type Output = NttPolynomial<F>;
367
368    fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
369        NttPolynomial(
370            self.0
371                .iter()
372                .zip(rhs.0.iter())
373                .map(|(&x, &y)| x - y)
374                .collect(),
375        )
376    }
377}
378
379impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
380    type Output = NttPolynomial<F>;
381
382    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
383        NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
384    }
385}
386
387impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
388where
389    F: Field + MultiplyNtt,
390{
391    type Output = NttPolynomial<F>;
392
393    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
394        F::multiply_ntt(self, rhs)
395    }
396}
397
398/// Perform multiplication in the NTT domain.
399pub trait MultiplyNtt: Field {
400    /// Multiply two NTT polynomials.
401    fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
402}
403
404impl<F: Field> Neg for &NttPolynomial<F> {
405    type Output = NttPolynomial<F>;
406
407    fn neg(self) -> NttPolynomial<F> {
408        NttPolynomial(self.0.iter().map(|&x| -x).collect())
409    }
410}
411
412impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
413    fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
414        NttPolynomial(f)
415    }
416}
417
418impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
419    fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
420        f_hat.0
421    }
422}
423
#[cfg(feature = "ctutils")]
impl<F> CtEq for NttPolynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare all 256 entries in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for NttPolynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
}

#[cfg(feature = "zeroize")]
impl<F> Zeroize for NttPolynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Wipe the entry array in place.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
446
447/// An [`NttVector`] is a vector of polynomials from `T_q` of length `K`.
448///
449/// NTT vectors can be added and subtracted.  If multiplication is defined for NTT polynomials, then
450/// NTT vectors can be multiplied by NTT polynomials, and "multiplied" with each other to produce a
451/// dot product.
452#[derive(Clone, Default, Debug, Eq, PartialEq)]
453pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
454
455impl<F: Field, K: ArraySize> NttVector<F, K> {
456    /// Create a new NTT vector.
457    pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
458        Self(x)
459    }
460}
461
#[cfg(feature = "ctutils")]
impl<F, K> CtEq for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
    /// Compare all `K` NTT polynomials in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F, K> CtEqSlice for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
}

#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    /// Wipe every NTT polynomial of the vector in place.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
484
485impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
486    type Output = NttVector<F, K>;
487
488    fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
489        NttVector(
490            self.0
491                .iter()
492                .zip(rhs.0.iter())
493                .map(|(x, y)| x + y)
494                .collect(),
495        )
496    }
497}
498
499impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
500    type Output = NttVector<F, K>;
501
502    fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
503        NttVector(
504            self.0
505                .iter()
506                .zip(rhs.0.iter())
507                .map(|(x, y)| x - y)
508                .collect(),
509        )
510    }
511}
512
513impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
514where
515    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
516{
517    type Output = NttVector<F, K>;
518
519    fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
520        NttVector(rhs.0.iter().map(|x| self * x).collect())
521    }
522}
523
524impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
525where
526    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
527{
528    type Output = NttPolynomial<F>;
529
530    fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
531        self.0
532            .iter()
533            .zip(rhs.0.iter())
534            .map(|(x, y)| x * y)
535            .fold(NttPolynomial::default(), |x, y| &x + &y)
536    }
537}
538
539/// A `K x L` matrix of NTT-domain polynomials.
540///
541/// Each vector represents a row of the matrix, so that multiplying on the right just requires
542/// iteration.
543///
544/// Multiplication on the right by vectors is the only defined operation, and is only defined when
545/// multiplication of NTT polynomials is defined.
546#[derive(Clone, Default, Debug, PartialEq)]
547pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
548
549impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
550    /// Create a new NTT matrix.
551    pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
552        Self(x)
553    }
554}
555
556impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
557where
558    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
559{
560    type Output = NttVector<F, K>;
561
562    fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
563        NttVector(self.0.iter().map(|x| x * rhs).collect())
564    }
565}
566
#[cfg(feature = "ctutils")]
impl<F, K, L> CtEq for NttMatrix<F, K, L>
where
    F: Field,
    K: ArraySize,
    L: ArraySize,
    F::Int: CtEq,
{
    /// Compare all `K` rows in constant time.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.0.ct_eq(&other.0)
    }
}

#[cfg(feature = "ctutils")]
impl<F, K, L> CtEqSlice for NttMatrix<F, K, L>
where
    F: Field,
    K: ArraySize,
    L: ArraySize,
    F::Int: CtEq,
{
}