Skip to main content

module_lattice/
algebra.rs

1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "subtle")]
9use subtle::{Choice, ConstantTimeEq};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13/// Finite field with efficient modular reduction for lattice-based cryptography.
14pub trait Field: Copy + Default + Debug + PartialEq {
15    /// Base integer type used to represent field elements
16    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
17    /// Double-width integer type used for intermediate computations.
18    type Long: PrimInt + From<Self::Int>;
19    /// Quadruple-width integer type used for Barrett reduction.
20    type LongLong: PrimInt;
21
22    /// Field modulus.
23    const Q: Self::Int;
24    /// Field modulus as [`Self::Long`].
25    const QL: Self::Long;
26    /// Field modulus as [`Self::LongLong`].
27    const QLL: Self::LongLong;
28
29    /// Bit shift used in Barrett reduction.
30    const BARRETT_SHIFT: usize;
31    /// Precomputed multiplier for Barrett reduction.
32    const BARRETT_MULTIPLIER: Self::LongLong;
33
34    /// Reduce a value that's already close to the modulus range.
35    fn small_reduce(x: Self::Int) -> Self::Int;
36    /// Reduce a wider value to a field element using Barrett reduction.
37    fn barrett_reduce(x: Self::Long) -> Self::Int;
38}
39
/// Define a zero-sized type that implements [`Field`] for a given prime modulus.
///
/// The `define_field` macro creates a zero-sized struct and an implementation of the [`Field`]
/// trait for that struct. The arguments, in the order the macro expects them, are:
///
/// * `$field`: The name of the zero-sized struct to be created.
/// * `$int`: The primitive integer type used to represent members of the field.
/// * `$long`: The primitive integer type used to represent products of two field members.
///   This type should have roughly twice the bits of `$int`.
/// * `$longlong`: The primitive integer type used for the intermediate product in Barrett
///   reduction (a `$long` value times the Barrett multiplier). This type should have roughly
///   four times the bits of `$int`.
/// * `$q`: The prime number that defines the field, given as an integer literal.
/// * `$doc` (optional): Documentation string attached to the generated struct. When omitted it
///   defaults to `"Finite field"`.
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // `q < 2^(ilog2(q) + 1)`, so `2^BARRETT_SHIFT > q^2`: every product of two reduced
            // elements is in range for `barrett_reduce`. The cast is lossless because `ilog2`
            // of a primitive integer is at most 127.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // Evaluated once at compile time, so this division never runs on secret data.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            // One conditional subtraction: maps `[0, 2q)` onto `[0, q)`. Inputs of `2q` or
            // more are not fully reduced.
            // NOTE(review): branches on a possibly secret value; relies on the compiler
            // lowering this to a conditional move. Confirm constant-time codegen.
            fn small_reduce(x: Self::Int) -> Self::Int {
                if x < Self::Q { x } else { x - Self::Q }
            }

            // For `x < 2^BARRETT_SHIFT` the estimated quotient is at most one below
            // `floor(x / q)`. The remainder therefore lies in `[0, 2q)`, and a single
            // `small_reduce` yields the canonical representative.
            fn barrett_reduce(x: Self::Long) -> Self::Int {
                let x: Self::LongLong = x.into();
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                Self::small_reduce($crate::Truncate::truncate(remainder))
            }
        }
    };
}
88
89/// An [`Elem`] is a member of the specified prime-order field.
90///
91/// Elements can be added, subtracted, multiplied, and negated, and the overloaded operators will
92/// ensure both that the integer values remain in the field, and that the reductions are done
93/// efficiently.
94///
95/// For addition and subtraction, a simple conditional subtraction is used; for multiplication,
96/// Barrett reduction.
97#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
98pub struct Elem<F: Field>(pub F::Int);
99
100impl<F: Field> Elem<F> {
101    /// Create a new field element.
102    pub const fn new(x: F::Int) -> Self {
103        Self(x)
104    }
105}
106
#[cfg(feature = "subtle")]
impl<F: Field> ConstantTimeEq for Elem<F>
where
    F::Int: ConstantTimeEq,
{
    /// Compare two elements by delegating to the constant-time equality of their integers.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Elem<F>
where
    F::Int: Zeroize,
{
    /// Zeroize the wrapped integer in place.
    fn zeroize(&mut self) {
        let Elem(inner) = self;
        inner.zeroize();
    }
}
126
127impl<F: Field> Neg for Elem<F> {
128    type Output = Elem<F>;
129
130    fn neg(self) -> Elem<F> {
131        Elem(F::small_reduce(F::Q - self.0))
132    }
133}
134
135impl<F: Field> Add<Elem<F>> for Elem<F> {
136    type Output = Elem<F>;
137
138    fn add(self, rhs: Elem<F>) -> Elem<F> {
139        Elem(F::small_reduce(self.0 + rhs.0))
140    }
141}
142
143impl<F: Field> Sub<Elem<F>> for Elem<F> {
144    type Output = Elem<F>;
145
146    fn sub(self, rhs: Elem<F>) -> Elem<F> {
147        Elem(F::small_reduce(self.0 + F::Q - rhs.0))
148    }
149}
150
151impl<F: Field> Mul<Elem<F>> for Elem<F> {
152    type Output = Elem<F>;
153
154    fn mul(self, rhs: Elem<F>) -> Elem<F> {
155        let lhs: F::Long = self.0.into();
156        let rhs: F::Long = rhs.0.into();
157        let prod = lhs * rhs;
158        Elem(F::barrett_reduce(prod))
159    }
160}
161
162/// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials
163/// over the finite field with prime order `q`.
164///
165/// Polynomials can be added, subtracted, negated, and multiplied by field elements.
166#[derive(Clone, Copy, Default, Debug, PartialEq)]
167pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
168
169impl<F: Field> Polynomial<F> {
170    /// Create a new polynomial.
171    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
172        Self(x)
173    }
174}
175
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Polynomial<F>
where
    F::Int: Zeroize,
{
    /// Zeroize all coefficients in place.
    fn zeroize(&mut self) {
        let Polynomial(coefficients) = self;
        coefficients.zeroize();
    }
}
185
186impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
187    type Output = Polynomial<F>;
188
189    fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
190        Polynomial(
191            self.0
192                .iter()
193                .zip(rhs.0.iter())
194                .map(|(&x, &y)| x + y)
195                .collect(),
196        )
197    }
198}
199
200impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
201    type Output = Polynomial<F>;
202
203    fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
204        Polynomial(
205            self.0
206                .iter()
207                .zip(rhs.0.iter())
208                .map(|(&x, &y)| x - y)
209                .collect(),
210        )
211    }
212}
213
214impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
215    type Output = Polynomial<F>;
216
217    fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
218        Polynomial(rhs.0.iter().map(|&x| self * x).collect())
219    }
220}
221
222impl<F: Field> Neg for &Polynomial<F> {
223    type Output = Polynomial<F>;
224
225    fn neg(self) -> Polynomial<F> {
226        Polynomial(self.0.iter().map(|&x| -x).collect())
227    }
228}
229
230/// A `Vector` is a vector of polynomials from `R_q` of length `K`.
231///
232/// Vectors can be added, subtracted, negated, and multiplied by field elements.
233#[derive(Clone, Default, Debug, PartialEq)]
234pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
235
236impl<F: Field, K: ArraySize> Vector<F, K> {
237    /// Create a new vector.
238    pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
239        Self(x)
240    }
241}
242
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for Vector<F, K>
where
    F::Int: Zeroize,
{
    /// Zeroize every polynomial entry in place.
    fn zeroize(&mut self) {
        let Vector(entries) = self;
        entries.zeroize();
    }
}
252
253impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
254    type Output = Vector<F, K>;
255    fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
256        Add::add(&self, &rhs)
257    }
258}
259impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
260    type Output = Vector<F, K>;
261
262    fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
263        Vector(
264            self.0
265                .iter()
266                .zip(rhs.0.iter())
267                .map(|(x, y)| x + y)
268                .collect(),
269        )
270    }
271}
272
273impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
274    type Output = Vector<F, K>;
275
276    fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
277        Vector(
278            self.0
279                .iter()
280                .zip(rhs.0.iter())
281                .map(|(x, y)| x - y)
282                .collect(),
283        )
284    }
285}
286
287impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
288    type Output = Vector<F, K>;
289
290    fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
291        Vector(rhs.0.iter().map(|x| self * x).collect())
292    }
293}
294
295impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
296    type Output = Vector<F, K>;
297
298    fn neg(self) -> Vector<F, K> {
299        Vector(self.0.iter().map(|x| -x).collect())
300    }
301}
302
303/// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field
304/// elements.
305///
306/// NTT polynomials can be added and subtracted, negated, and multiplied by scalars.
307/// We do not define multiplication of NTT polynomials here: that is defined by the downstream
308/// crate using the [`MultiplyNtt`] trait.
309///
310/// We also do not define the mappings between normal polynomials and NTT polynomials (i.e., between
311/// `R_q` and `T_q`).
312#[derive(Clone, Default, Debug, Eq, PartialEq)]
313pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
314
315impl<F: Field> NttPolynomial<F> {
316    /// Create a new NTT polynomial.
317    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
318        Self(x)
319    }
320}
321
322impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
323    type Output = NttPolynomial<F>;
324
325    fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
326        NttPolynomial(
327            self.0
328                .iter()
329                .zip(rhs.0.iter())
330                .map(|(&x, &y)| x + y)
331                .collect(),
332        )
333    }
334}
335
336impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
337    type Output = NttPolynomial<F>;
338
339    fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
340        NttPolynomial(
341            self.0
342                .iter()
343                .zip(rhs.0.iter())
344                .map(|(&x, &y)| x - y)
345                .collect(),
346        )
347    }
348}
349
350impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
351    type Output = NttPolynomial<F>;
352
353    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
354        NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
355    }
356}
357
358impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
359where
360    F: Field + MultiplyNtt,
361{
362    type Output = NttPolynomial<F>;
363
364    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
365        F::multiply_ntt(self, rhs)
366    }
367}
368
369/// Perform multiplication in the NTT domain.
370pub trait MultiplyNtt: Field {
371    /// Multiply two NTT polynomials.
372    fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
373}
374
375impl<F: Field> Neg for &NttPolynomial<F> {
376    type Output = NttPolynomial<F>;
377
378    fn neg(self) -> NttPolynomial<F> {
379        NttPolynomial(self.0.iter().map(|&x| -x).collect())
380    }
381}
382
383impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
384    fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
385        NttPolynomial(f)
386    }
387}
388
389impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
390    fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
391        f_hat.0
392    }
393}
394
#[cfg(feature = "subtle")]
impl<F: Field> ConstantTimeEq for NttPolynomial<F>
where
    F::Int: ConstantTimeEq,
{
    /// Compare all entries by delegating to the array's constant-time equality.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for NttPolynomial<F>
where
    F::Int: Zeroize,
{
    /// Zeroize every entry in place.
    fn zeroize(&mut self) {
        let NttPolynomial(entries) = self;
        entries.zeroize();
    }
}
414
415/// An [`NttVector`] is a vector of polynomials from `T_q` of length `K`.
416///
417/// NTT vectors can be added and subtracted.  If multiplication is defined for NTT polynomials, then
418/// NTT vectors can be multiplied by NTT polynomials, and "multiplied" with each other to produce a
419/// dot product.
420#[derive(Clone, Default, Debug, Eq, PartialEq)]
421pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
422
423impl<F: Field, K: ArraySize> NttVector<F, K> {
424    /// Create a new NTT vector.
425    pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
426        Self(x)
427    }
428}
429
#[cfg(feature = "subtle")]
impl<F: Field, K: ArraySize> ConstantTimeEq for NttVector<F, K>
where
    F::Int: ConstantTimeEq,
{
    /// Compare all entries by delegating to the array's constant-time equality.
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for NttVector<F, K>
where
    F::Int: Zeroize,
{
    /// Zeroize every NTT polynomial entry in place.
    fn zeroize(&mut self) {
        let NttVector(entries) = self;
        entries.zeroize();
    }
}
449
450impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
451    type Output = NttVector<F, K>;
452
453    fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
454        NttVector(
455            self.0
456                .iter()
457                .zip(rhs.0.iter())
458                .map(|(x, y)| x + y)
459                .collect(),
460        )
461    }
462}
463
464impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
465    type Output = NttVector<F, K>;
466
467    fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
468        NttVector(
469            self.0
470                .iter()
471                .zip(rhs.0.iter())
472                .map(|(x, y)| x - y)
473                .collect(),
474        )
475    }
476}
477
478impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
479where
480    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
481{
482    type Output = NttVector<F, K>;
483
484    fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
485        NttVector(rhs.0.iter().map(|x| self * x).collect())
486    }
487}
488
489impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
490where
491    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
492{
493    type Output = NttPolynomial<F>;
494
495    fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
496        self.0
497            .iter()
498            .zip(rhs.0.iter())
499            .map(|(x, y)| x * y)
500            .fold(NttPolynomial::default(), |x, y| &x + &y)
501    }
502}
503
504/// A `K x L` matrix of NTT-domain polynomials.
505///
506/// Each vector represents a row of the matrix, so that multiplying on the right just requires
507/// iteration.
508///
509/// Multiplication on the right by vectors is the only defined operation, and is only defined when
510/// multiplication of NTT polynomials is defined.
511#[derive(Clone, Default, Debug, PartialEq)]
512pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
513
514impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
515    /// Create a new NTT matrix.
516    pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
517        Self(x)
518    }
519}
520
521impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
522where
523    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
524{
525    type Output = NttVector<F, K>;
526
527    fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
528        NttVector(self.0.iter().map(|x| x * rhs).collect())
529    }
530}