Skip to main content

module_lattice/
algebra.rs

1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "ctutils")]
9use ctutils::{Choice, CtEq, CtEqSlice};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13/// Finite field with efficient modular reduction for lattice-based cryptography.
14pub trait Field: Copy + Default + Debug + PartialEq {
15    /// Base integer type used to represent field elements
16    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
17    /// Double-width integer type used for intermediate computations.
18    type Long: PrimInt + From<Self::Int>;
19    /// Quadruple-width integer type used for Barrett reduction.
20    type LongLong: PrimInt;
21
22    /// Field modulus.
23    const Q: Self::Int;
24    /// Field modulus as [`Self::Long`].
25    const QL: Self::Long;
26    /// Field modulus as [`Self::LongLong`].
27    const QLL: Self::LongLong;
28
29    /// Bit shift used in Barrett reduction.
30    const BARRETT_SHIFT: usize;
31    /// Precomputed multiplier for Barrett reduction.
32    const BARRETT_MULTIPLIER: Self::LongLong;
33
34    /// Reduce a value that's already close to the modulus range.
35    fn small_reduce(x: Self::Int) -> Self::Int;
36    /// Reduce a wider value to a field element using Barrett reduction.
37    fn barrett_reduce(x: Self::Long) -> Self::Int;
38}
39
/// The `define_field` macro creates a zero-sized struct and implements the [`Field`] trait for
/// it.
///
/// Arguments, in the order the macro expects them:
///
/// * `$field`: name of the zero-sized struct to create.
/// * `$int`: primitive integer type used to represent field elements.
/// * `$long`: primitive integer type used to represent products of two field elements. It should
///   have roughly twice the bits of `$int`.
/// * `$longlong`: primitive integer type used for the intermediate product inside Barrett
///   reduction (roughly a product of three field elements). It should have roughly four times the
///   bits of `$int`.
/// * `$q`: the prime modulus, as an integer literal.
/// * `$doc` (optional): documentation string for the generated struct; defaults to
///   `"Finite field"`.
///
/// The generated Barrett reduction uses `BARRETT_SHIFT = 2 * bitlen(q)` and
/// `BARRETT_MULTIPLIER = floor(2^BARRETT_SHIFT / q)`. For any input below `2^BARRETT_SHIFT`
/// (in particular any product of two reduced elements) the estimated quotient is low by at most
/// one, so the remainder is below `2 * q` and a single conditional subtraction finishes the job.
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // `ilog2() + 1` is at most 128, so widening it to `usize` cannot truncate.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // Evaluated at compile time on public constants only.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            fn small_reduce(x: Self::Int) -> Self::Int {
                // Conditional subtraction written without a branch: `mask` is all ones when
                // `x >= q` and zero otherwise. `From<bool>` yields exactly 0 or 1, which avoids
                // an `as` cast; the fully-qualified form keeps the call unambiguous even if the
                // invoking module has another `from` (e.g. `NumCast`) in scope.
                //
                // This removes the explicit branch from the source, but the compiler still
                // chooses how to lower it, so it is not by itself a constant-time guarantee.
                // Expects `x < 2 * q`; larger inputs are not fully reduced.
                let mask = <$int as From<bool>>::from(x >= Self::Q).wrapping_neg();
                x - (Self::Q & mask)
            }

            fn barrett_reduce(x: Self::Long) -> Self::Int {
                // Valid for `x < 2^BARRETT_SHIFT`: `quotient` underestimates `x / q` by at most
                // one, so `remainder` lies in `[0, 2 * q)` and fits in `Self::Int`.
                let x: Self::LongLong = x.into();
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                Self::small_reduce($crate::Truncate::truncate(remainder))
            }
        }
    };
}
94
95/// An [`Elem`] is a member of the specified prime-order field.
96///
97/// Elements can be added, subtracted, multiplied, and negated, and the overloaded operators will
98/// ensure both that the integer values remain in the field, and that the reductions are done
99/// efficiently.
100///
101/// For addition and subtraction, a simple conditional subtraction is used; for multiplication,
102/// Barrett reduction.
103#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
104pub struct Elem<F: Field>(pub F::Int);
105
106impl<F: Field> Elem<F> {
107    /// Create a new field element.
108    pub const fn new(x: F::Int) -> Self {
109        Self(x)
110    }
111}
112
#[cfg(feature = "ctutils")]
impl<F> CtEq for Elem<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare the wrapped integers using their own [`CtEq`] implementation.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for Elem<F> {}

#[cfg(feature = "zeroize")]
impl<F> Zeroize for Elem<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Zeroize the wrapped integer in place.
    fn zeroize(&mut self) {
        let Self(value) = self;
        value.zeroize();
    }
}
135
136impl<F: Field> Neg for Elem<F> {
137    type Output = Elem<F>;
138
139    fn neg(self) -> Elem<F> {
140        Elem(F::small_reduce(F::Q - self.0))
141    }
142}
143
144impl<F: Field> Add<Elem<F>> for Elem<F> {
145    type Output = Elem<F>;
146
147    fn add(self, rhs: Elem<F>) -> Elem<F> {
148        Elem(F::small_reduce(self.0 + rhs.0))
149    }
150}
151
152impl<F: Field> Sub<Elem<F>> for Elem<F> {
153    type Output = Elem<F>;
154
155    fn sub(self, rhs: Elem<F>) -> Elem<F> {
156        Elem(F::small_reduce(self.0 + F::Q - rhs.0))
157    }
158}
159
160impl<F: Field> Mul<Elem<F>> for Elem<F> {
161    type Output = Elem<F>;
162
163    fn mul(self, rhs: Elem<F>) -> Elem<F> {
164        let lhs: F::Long = self.0.into();
165        let rhs: F::Long = rhs.0.into();
166        let prod = lhs * rhs;
167        Elem(F::barrett_reduce(prod))
168    }
169}
170
171/// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials
172/// over the finite field with prime order `q`.
173///
174/// Polynomials can be added, subtracted, negated, and multiplied by field elements.
175#[derive(Clone, Copy, Default, Debug, PartialEq)]
176pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
177
178impl<F: Field> Polynomial<F> {
179    /// Create a new polynomial.
180    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
181        Self(x)
182    }
183}
184
#[cfg(feature = "zeroize")]
impl<F> Zeroize for Polynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Zeroize the coefficients by delegating to the coefficient array's [`Zeroize`] impl.
    fn zeroize(&mut self) {
        let Self(coefficients) = self;
        coefficients.zeroize();
    }
}
194
195impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
196    type Output = Polynomial<F>;
197
198    fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
199        Polynomial(
200            self.0
201                .iter()
202                .zip(rhs.0.iter())
203                .map(|(&x, &y)| x + y)
204                .collect(),
205        )
206    }
207}
208
209impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
210    type Output = Polynomial<F>;
211
212    fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
213        Polynomial(
214            self.0
215                .iter()
216                .zip(rhs.0.iter())
217                .map(|(&x, &y)| x - y)
218                .collect(),
219        )
220    }
221}
222
223impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
224    type Output = Polynomial<F>;
225
226    fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
227        Polynomial(rhs.0.iter().map(|&x| self * x).collect())
228    }
229}
230
231impl<F: Field> Neg for &Polynomial<F> {
232    type Output = Polynomial<F>;
233
234    fn neg(self) -> Polynomial<F> {
235        Polynomial(self.0.iter().map(|&x| -x).collect())
236    }
237}
238
#[cfg(feature = "ctutils")]
impl<F> CtEq for Polynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare two polynomials by delegating to the coefficient array's [`CtEq`] impl.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for Polynomial<F> {}
251
252/// A `Vector` is a vector of polynomials from `R_q` of length `K`.
253///
254/// Vectors can be added, subtracted, negated, and multiplied by field elements.
255#[derive(Clone, Default, Debug, PartialEq)]
256pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
257
258impl<F: Field, K: ArraySize> Vector<F, K> {
259    /// Create a new vector.
260    pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
261        Self(x)
262    }
263}
264
#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    /// Zeroize every polynomial entry by delegating to the entry array's [`Zeroize`] impl.
    fn zeroize(&mut self) {
        let Self(entries) = self;
        entries.zeroize();
    }
}
274
275impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
276    type Output = Vector<F, K>;
277    fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
278        Add::add(&self, &rhs)
279    }
280}
281impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
282    type Output = Vector<F, K>;
283
284    fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
285        Vector(
286            self.0
287                .iter()
288                .zip(rhs.0.iter())
289                .map(|(x, y)| x + y)
290                .collect(),
291        )
292    }
293}
294
295impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
296    type Output = Vector<F, K>;
297
298    fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
299        Vector(
300            self.0
301                .iter()
302                .zip(rhs.0.iter())
303                .map(|(x, y)| x - y)
304                .collect(),
305        )
306    }
307}
308
309impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
310    type Output = Vector<F, K>;
311
312    fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
313        Vector(rhs.0.iter().map(|x| self * x).collect())
314    }
315}
316
317impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
318    type Output = Vector<F, K>;
319
320    fn neg(self) -> Vector<F, K> {
321        Vector(self.0.iter().map(|x| -x).collect())
322    }
323}
324
#[cfg(feature = "ctutils")]
impl<F, K> CtEq for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
    /// Compare two vectors by delegating to the entry array's [`CtEq`] impl.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEqSlice for Vector<F, K> {}
337
338/// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field
339/// elements.
340///
341/// NTT polynomials can be added and subtracted, negated, and multiplied by scalars.
342/// We do not define multiplication of NTT polynomials here: that is defined by the downstream
343/// crate using the [`MultiplyNtt`] trait.
344///
345/// We also do not define the mappings between normal polynomials and NTT polynomials (i.e., between
346/// `R_q` and `T_q`).
347#[derive(Clone, Default, Debug, Eq, PartialEq)]
348pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
349
350impl<F: Field> NttPolynomial<F> {
351    /// Create a new NTT polynomial.
352    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
353        Self(x)
354    }
355}
356
357impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
358    type Output = NttPolynomial<F>;
359
360    fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
361        NttPolynomial(
362            self.0
363                .iter()
364                .zip(rhs.0.iter())
365                .map(|(&x, &y)| x + y)
366                .collect(),
367        )
368    }
369}
370
371impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
372    type Output = NttPolynomial<F>;
373
374    fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
375        NttPolynomial(
376            self.0
377                .iter()
378                .zip(rhs.0.iter())
379                .map(|(&x, &y)| x - y)
380                .collect(),
381        )
382    }
383}
384
385impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
386    type Output = NttPolynomial<F>;
387
388    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
389        NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
390    }
391}
392
393impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
394where
395    F: Field + MultiplyNtt,
396{
397    type Output = NttPolynomial<F>;
398
399    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
400        F::multiply_ntt(self, rhs)
401    }
402}
403
404/// Perform multiplication in the NTT domain.
405pub trait MultiplyNtt: Field {
406    /// Multiply two NTT polynomials.
407    fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
408}
409
410impl<F: Field> Neg for &NttPolynomial<F> {
411    type Output = NttPolynomial<F>;
412
413    fn neg(self) -> NttPolynomial<F> {
414        NttPolynomial(self.0.iter().map(|&x| -x).collect())
415    }
416}
417
418impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
419    fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
420        NttPolynomial(f)
421    }
422}
423
424impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
425    fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
426        f_hat.0
427    }
428}
429
#[cfg(feature = "ctutils")]
impl<F> CtEq for NttPolynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
    /// Compare two NTT polynomials by delegating to the entry array's [`CtEq`] impl.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for NttPolynomial<F> {}

#[cfg(feature = "zeroize")]
impl<F> Zeroize for NttPolynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    /// Zeroize the entries by delegating to the entry array's [`Zeroize`] impl.
    fn zeroize(&mut self) {
        let Self(entries) = self;
        entries.zeroize();
    }
}
452
453/// An [`NttVector`] is a vector of polynomials from `T_q` of length `K`.
454///
455/// NTT vectors can be added and subtracted.  If multiplication is defined for NTT polynomials, then
456/// NTT vectors can be multiplied by NTT polynomials, and "multiplied" with each other to produce a
457/// dot product.
458#[derive(Clone, Default, Debug, Eq, PartialEq)]
459pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
460
461impl<F: Field, K: ArraySize> NttVector<F, K> {
462    /// Create a new NTT vector.
463    pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
464        Self(x)
465    }
466}
467
#[cfg(feature = "ctutils")]
impl<F, K> CtEq for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: CtEq,
{
    /// Compare two NTT vectors by delegating to the entry array's [`CtEq`] impl.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEqSlice for NttVector<F, K> {}

#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    /// Zeroize every entry by delegating to the entry array's [`Zeroize`] impl.
    fn zeroize(&mut self) {
        let Self(entries) = self;
        entries.zeroize();
    }
}
490
491impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
492    type Output = NttVector<F, K>;
493
494    fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
495        NttVector(
496            self.0
497                .iter()
498                .zip(rhs.0.iter())
499                .map(|(x, y)| x + y)
500                .collect(),
501        )
502    }
503}
504
505impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
506    type Output = NttVector<F, K>;
507
508    fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
509        NttVector(
510            self.0
511                .iter()
512                .zip(rhs.0.iter())
513                .map(|(x, y)| x - y)
514                .collect(),
515        )
516    }
517}
518
519impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
520where
521    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
522{
523    type Output = NttVector<F, K>;
524
525    fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
526        NttVector(rhs.0.iter().map(|x| self * x).collect())
527    }
528}
529
530impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
531where
532    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
533{
534    type Output = NttPolynomial<F>;
535
536    fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
537        self.0
538            .iter()
539            .zip(rhs.0.iter())
540            .map(|(x, y)| x * y)
541            .fold(NttPolynomial::default(), |x, y| &x + &y)
542    }
543}
544
545/// A `K x L` matrix of NTT-domain polynomials.
546///
547/// Each vector represents a row of the matrix, so that multiplying on the right just requires
548/// iteration.
549///
550/// Multiplication on the right by vectors is the only defined operation, and is only defined when
551/// multiplication of NTT polynomials is defined.
552#[derive(Clone, Default, Debug, PartialEq)]
553pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
554
555impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
556    /// Create a new NTT matrix.
557    pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
558        Self(x)
559    }
560}
561
562impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
563where
564    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
565{
566    type Output = NttVector<F, K>;
567
568    fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
569        NttVector(self.0.iter().map(|x| x * rhs).collect())
570    }
571}
572
#[cfg(feature = "ctutils")]
impl<F, K, L> CtEq for NttMatrix<F, K, L>
where
    F: Field,
    K: ArraySize,
    L: ArraySize,
    F::Int: CtEq,
{
    /// Compare two matrices by delegating to the row array's [`CtEq`] impl.
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(rows) = self;
        let Self(other_rows) = other;
        rows.ct_eq(other_rows)
    }
}

#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize, L: ArraySize> CtEqSlice for NttMatrix<F, K, L> {}