ml_dsa/module_lattice/
algebra.rs

1use super::util::Truncate;
2
3use core::fmt::Debug;
4use core::ops::{Add, Mul, Neg, Sub};
5use hybrid_array::{Array, ArraySize, typenum::U256};
6use num_traits::PrimInt;
7
8#[cfg(feature = "zeroize")]
9use zeroize::Zeroize;
10
11pub trait Field: Copy + Default + Debug + PartialEq {
12    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
13    type Long: PrimInt + From<Self::Int>;
14    type LongLong: PrimInt;
15
16    const Q: Self::Int;
17    const QL: Self::Long;
18    const QLL: Self::LongLong;
19
20    const BARRETT_SHIFT: usize;
21    const BARRETT_MULTIPLIER: Self::LongLong;
22
23    fn small_reduce(x: Self::Int) -> Self::Int;
24    fn barrett_reduce(x: Self::Long) -> Self::Int;
25}
26
/// The `define_field` macro creates a zero-sized struct and an implementation of the `Field`
/// trait for that struct.  The caller must specify, in this order:
///
/// * `$field`: The name of the zero-sized struct to be created.
/// * `$int`: The primitive unsigned integer type used to represent members of the field.
/// * `$long`: The primitive integer type used to represent products of two field members.
///   This type should have roughly twice the bits of `$int`.
/// * `$longlong`: The primitive integer type used for the intermediate product in Barrett
///   reduction (a double-width value times the Barrett multiplier).  This type should have
///   roughly four times the bits of `$int`.
/// * `$q`: The prime number that defines the field.
///
/// The generated `small_reduce` expects its input to be below `2 * $q`, and requires `$q` to
/// be below `2^(BITS - 1)`, where `BITS` is the bit width of `$int`.  The latter already holds
/// whenever the sum of two reduced elements fits in `$int`.
///
/// The `Field` and `Truncate` traits must be in scope where the macro is invoked.
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        #[derive(Copy, Clone, Default, Debug, PartialEq)]
        pub struct $field;

        impl Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // k = 2 * bitlen(Q), so that Q^2 < 2^k for every product of two reduced elements.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // floor(2^k / Q), computed at compile time so no runtime division is needed.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            fn small_reduce(x: Self::Int) -> Self::Int {
                // Conditionally subtract Q using a mask rather than a branch, so that reducing
                // secret values does not introduce a secret-dependent branch in the source.
                let diff: $int = x.wrapping_sub(Self::Q);
                // For x < 2*Q and Q < 2^(BITS-1), the top bit of `diff` is set exactly when
                // x < Q, i.e. when the subtraction wrapped around.
                let borrow: $int = diff >> (<$int>::BITS - 1);
                // All ones when x < Q (add Q back), zero otherwise (keep x - Q).
                let mask: $int = borrow.wrapping_neg();
                diff.wrapping_add(Self::Q & mask)
            }

            fn barrett_reduce(x: Self::Long) -> Self::Int {
                // For x < Q^2 (e.g. the product of two reduced elements), the estimated
                // quotient (x * floor(2^k / Q)) >> k equals floor(x / Q) or is one less, so
                // the remainder lies in [0, 2*Q) and one conditional subtraction finishes it.
                let x: Self::LongLong = x.into();
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                Self::small_reduce(Truncate::truncate(remainder))
            }
        }
    };
}
71
72/// An `Elem` is a member of the specified prime-order field.  Elements can be added,
73/// subtracted, multiplied, and negated, and the overloaded operators will ensure both that the
74/// integer values remain in the field, and that the reductions are done efficiently.  For
75/// addition and subtraction, a simple conditional subtraction is used; for multiplication,
76/// Barrett reduction.
77#[derive(Copy, Clone, Default, Debug, PartialEq)]
78pub struct Elem<F: Field>(pub F::Int);
79
80impl<F: Field> Elem<F> {
81    pub const fn new(x: F::Int) -> Self {
82        Self(x)
83    }
84}
85
/// Clears the element by zeroizing the integer it wraps.
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Elem<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Elem(inner) = self;
        inner.zeroize();
    }
}
95
96impl<F: Field> Neg for Elem<F> {
97    type Output = Elem<F>;
98
99    fn neg(self) -> Elem<F> {
100        Elem(F::small_reduce(F::Q - self.0))
101    }
102}
103
104impl<F: Field> Add<Elem<F>> for Elem<F> {
105    type Output = Elem<F>;
106
107    fn add(self, rhs: Elem<F>) -> Elem<F> {
108        Elem(F::small_reduce(self.0 + rhs.0))
109    }
110}
111
112impl<F: Field> Sub<Elem<F>> for Elem<F> {
113    type Output = Elem<F>;
114
115    fn sub(self, rhs: Elem<F>) -> Elem<F> {
116        Elem(F::small_reduce(self.0 + F::Q - rhs.0))
117    }
118}
119
120impl<F: Field> Mul<Elem<F>> for Elem<F> {
121    type Output = Elem<F>;
122
123    fn mul(self, rhs: Elem<F>) -> Elem<F> {
124        let lhs: F::Long = self.0.into();
125        let rhs: F::Long = rhs.0.into();
126        let prod = lhs * rhs;
127        Elem(F::barrett_reduce(prod))
128    }
129}
130
131/// A `Polynomial` is a member of the ring `R_q = Z_q[X] / (X^256)` of degree-256 polynomials
132/// over the finite field with prime order `q`.  Polynomials can be added, subtracted, negated,
133/// and multiplied by field elements.  We do not define multiplication of polynomials here.
134#[derive(Clone, Default, Debug, PartialEq)]
135pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
136
137impl<F: Field> Polynomial<F> {
138    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
139        Self(x)
140    }
141}
142
/// Clears every coefficient of the polynomial.
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Polynomial<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Polynomial(coefficients) = self;
        coefficients.zeroize();
    }
}
152
153impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
154    type Output = Polynomial<F>;
155
156    fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
157        Polynomial(
158            self.0
159                .iter()
160                .zip(rhs.0.iter())
161                .map(|(&x, &y)| x + y)
162                .collect(),
163        )
164    }
165}
166
167impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
168    type Output = Polynomial<F>;
169
170    fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
171        Polynomial(
172            self.0
173                .iter()
174                .zip(rhs.0.iter())
175                .map(|(&x, &y)| x - y)
176                .collect(),
177        )
178    }
179}
180
181impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
182    type Output = Polynomial<F>;
183
184    fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
185        Polynomial(rhs.0.iter().map(|&x| self * x).collect())
186    }
187}
188
189impl<F: Field> Neg for &Polynomial<F> {
190    type Output = Polynomial<F>;
191
192    fn neg(self) -> Polynomial<F> {
193        Polynomial(self.0.iter().map(|&x| -x).collect())
194    }
195}
196
197/// A `Vector` is a vector of polynomials from `R_q` of length `K`.  Vectors can be
198/// added, subtracted, negated, and multiplied by field elements.
199#[derive(Clone, Default, Debug, PartialEq)]
200pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
201
202impl<F: Field, K: ArraySize> Vector<F, K> {
203    pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
204        Self(x)
205    }
206}
207
/// Clears every polynomial in the vector.
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for Vector<F, K>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Vector(entries) = self;
        entries.zeroize();
    }
}
217
218impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
219    type Output = Vector<F, K>;
220
221    fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
222        Vector(
223            self.0
224                .iter()
225                .zip(rhs.0.iter())
226                .map(|(x, y)| x + y)
227                .collect(),
228        )
229    }
230}
231
232impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
233    type Output = Vector<F, K>;
234
235    fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
236        Vector(
237            self.0
238                .iter()
239                .zip(rhs.0.iter())
240                .map(|(x, y)| x - y)
241                .collect(),
242        )
243    }
244}
245
246impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
247    type Output = Vector<F, K>;
248
249    fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
250        Vector(rhs.0.iter().map(|x| self * x).collect())
251    }
252}
253
254impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
255    type Output = Vector<F, K>;
256
257    fn neg(self) -> Vector<F, K> {
258        Vector(self.0.iter().map(|x| -x).collect())
259    }
260}
261
262/// An `NttPolynomial` is a member of the NTT algebra `T_q = Z_q[X]^256` of 256-tuples of field
263/// elements.  NTT polynomials can be added and
264/// subtracted, negated, and multiplied by scalars.
265/// We do not define multiplication of NTT polynomials here.  We also do not define the
266/// mappings between normal polynomials and NTT polynomials (i.e., between `R_q` and `T_q`).
267#[derive(Clone, Default, Debug, PartialEq)]
268pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
269
270impl<F: Field> NttPolynomial<F> {
271    pub const fn new(x: Array<Elem<F>, U256>) -> Self {
272        Self(x)
273    }
274}
275
/// Clears every element of the NTT polynomial.
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for NttPolynomial<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let NttPolynomial(values) = self;
        values.zeroize();
    }
}
285
286impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
287    type Output = NttPolynomial<F>;
288
289    fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
290        NttPolynomial(
291            self.0
292                .iter()
293                .zip(rhs.0.iter())
294                .map(|(&x, &y)| x + y)
295                .collect(),
296        )
297    }
298}
299
300impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
301    type Output = NttPolynomial<F>;
302
303    fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
304        NttPolynomial(
305            self.0
306                .iter()
307                .zip(rhs.0.iter())
308                .map(|(&x, &y)| x - y)
309                .collect(),
310        )
311    }
312}
313
314impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
315    type Output = NttPolynomial<F>;
316
317    fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
318        NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
319    }
320}
321
322impl<F: Field> Neg for &NttPolynomial<F> {
323    type Output = NttPolynomial<F>;
324
325    fn neg(self) -> NttPolynomial<F> {
326        NttPolynomial(self.0.iter().map(|&x| -x).collect())
327    }
328}
329
330/// An `NttVector` is a vector of polynomials from `T_q` of length `K`.  NTT vectors can be
331/// added and subtracted.  If multiplication is defined for NTT polynomials, then NTT vectors
332/// can be multiplied by NTT polynomials, and "multipled" with each other to produce a dot
333/// product.
334#[derive(Clone, Default, Debug, PartialEq)]
335pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
336
337impl<F: Field, K: ArraySize> NttVector<F, K> {
338    pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
339        Self(x)
340    }
341}
342
/// Clears every NTT polynomial in the vector.
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for NttVector<F, K>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let NttVector(entries) = self;
        entries.zeroize();
    }
}
352
353impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
354    type Output = NttVector<F, K>;
355
356    fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
357        NttVector(
358            self.0
359                .iter()
360                .zip(rhs.0.iter())
361                .map(|(x, y)| x + y)
362                .collect(),
363        )
364    }
365}
366
367impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
368    type Output = NttVector<F, K>;
369
370    fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
371        NttVector(
372            self.0
373                .iter()
374                .zip(rhs.0.iter())
375                .map(|(x, y)| x - y)
376                .collect(),
377        )
378    }
379}
380
381impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
382where
383    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
384{
385    type Output = NttVector<F, K>;
386
387    fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
388        NttVector(rhs.0.iter().map(|x| self * x).collect())
389    }
390}
391
392impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
393where
394    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
395{
396    type Output = NttPolynomial<F>;
397
398    fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
399        self.0
400            .iter()
401            .zip(rhs.0.iter())
402            .map(|(x, y)| x * y)
403            .fold(NttPolynomial::default(), |x, y| &x + &y)
404    }
405}
406
407/// A K x L matrix of NTT-domain polynomials.  Each vector represents a row of the matrix, so that
408/// multiplying on the right just requires iteration.  Multiplication on the right by vectors
409/// is the only defined operation, and is only defined when multiplication of NTT polynomials
410/// is defined.
411#[derive(Clone, Default, Debug, PartialEq)]
412pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
413
414impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
415    pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
416        Self(x)
417    }
418}
419
420impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
421where
422    for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
423{
424    type Output = NttVector<F, K>;
425
426    fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
427        NttVector(self.0.iter().map(|x| x * rhs).collect())
428    }
429}