1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "subtle")]
9use subtle::{Choice, ConstantTimeEq};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
/// Parameters and reduction routines for arithmetic modulo a fixed modulus `Q`.
///
/// Implementations are normally generated by the `define_field!` macro in this module, which
/// also defines the concrete reduction algorithms described below.
pub trait Field: Copy + Default + Debug + PartialEq {
    /// Integer type holding a single field element's representative.
    ///
    /// `Truncate<u128>` allows narrowing from `u128`; its exact semantics live in
    /// `super::truncate` (not shown here).
    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;

    /// Wider integer type in which the product of two `Int` values is computed
    /// (see the `Mul` impl for `Elem`).
    type Long: PrimInt + From<Self::Int>;

    /// Integer type used for the intermediate `x * BARRETT_MULTIPLIER` product in the
    /// `define_field!` Barrett reduction; presumably must be wide enough for that product.
    type LongLong: PrimInt;

    /// The modulus, as an `Int`.
    const Q: Self::Int;

    /// The modulus, as a `Long`.
    const QL: Self::Long;

    /// The modulus, as a `LongLong`.
    const QLL: Self::LongLong;

    /// Shift amount `k` of the Barrett reduction (`define_field!` uses `2 * bitlen(Q)`).
    const BARRETT_SHIFT: usize;

    /// Barrett multiplier (`define_field!` uses `floor(2^BARRETT_SHIFT / Q)`).
    const BARRETT_MULTIPLIER: Self::LongLong;

    /// Subtracts `Q` at most once: the `define_field!` implementation returns `x` when
    /// `x < Q` and `x - Q` otherwise, so it fully reduces inputs in `[0, 2Q)`.
    fn small_reduce(x: Self::Int) -> Self::Int;

    /// Reduces a `Long` value modulo `Q` into an `Int`.
    ///
    /// NOTE(review): the `define_field!` Barrett implementation is only guaranteed exact for
    /// inputs below `2^BARRETT_SHIFT` — products of two reduced elements satisfy this.
    fn barrett_reduce(x: Self::Long) -> Self::Int;
}
39
/// Defines a zero-sized marker type `$field` and implements `Field` for it with modulus `$q`.
///
/// Arguments, in order:
/// - `$field`: name of the generated marker type;
/// - `$int`: integer type holding one element;
/// - `$long`: integer type wide enough for the product of two elements;
/// - `$longlong`: integer type wide enough for `product * BARRETT_MULTIPLIER`;
/// - `$q`: the modulus, as an integer literal;
/// - `$doc` (optional): doc string for the generated type, defaulting to `"Finite field"`.
///
/// The expansion refers to `$crate::Field` and `$crate::Truncate`, so both must be reachable
/// from the root of the crate that defines this macro.
///
/// ```ignore
/// define_field!(KyberField, u16, u32, u64, 3329);
/// ```
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // k = 2 * bitlen(Q), so every product of two reduced elements (< Q^2) is < 2^k,
            // which is the precondition for the single-correction Barrett step below.
            #[allow(clippy::as_conversions)] // `ilog2` yields a small u32; widening to usize is lossless
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // This division runs only at compile time on public constants, never on field data.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            /// Returns `x - Q` when `x >= Q`, and `x` unchanged otherwise.
            ///
            /// Written without a data-dependent branch: `overflowing_sub` reports a borrow
            /// exactly when `x < Q`; that borrow becomes an all-ones/all-zeros mask, and `Q` is
            /// added back under the mask. The result is identical to
            /// `if x < Q { x } else { x - Q }` for every input. The rewrite only avoids handing
            /// the optimizer an explicit branch on (possibly secret) field data. It does not
            /// guarantee constant-time machine code.
            fn small_reduce(x: Self::Int) -> Self::Int {
                let (diff, borrow) = x.overflowing_sub(Self::Q);
                // Fully qualified so a caller's `num_traits::NumCast::from` cannot make this ambiguous.
                let mask = <$int as ::core::convert::From<bool>>::from(borrow).wrapping_neg();
                diff.wrapping_add(Self::Q & mask)
            }

            /// Reduces `x` modulo `Q` using Barrett's method.
            ///
            /// Exact for `x < 2^BARRETT_SHIFT`, which includes any product of two elements in
            /// `[0, Q)`. In that range the quotient estimate undershoots by at most one. The
            /// remainder therefore lies in `[0, 2Q)`, and `small_reduce` finishes the reduction.
            /// `LongLong` must be wide enough to hold `x * BARRETT_MULTIPLIER`.
            fn barrett_reduce(x: Self::Long) -> Self::Int {
                let x: Self::LongLong = x.into();
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                Self::small_reduce($crate::Truncate::truncate(remainder))
            }
        }
    };
}
88
89#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
98pub struct Elem<F: Field>(pub F::Int);
99
100impl<F: Field> Elem<F> {
101 pub const fn new(x: F::Int) -> Self {
103 Self(x)
104 }
105}
106
/// Constant-time equality of two elements, delegated to `F::Int`'s own `ConstantTimeEq`.
#[cfg(feature = "subtle")]
impl<F> ConstantTimeEq for Elem<F>
where
    F: Field,
    F::Int: ConstantTimeEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}
116
/// Wipes the stored representative using `F::Int`'s `Zeroize` implementation.
#[cfg(feature = "zeroize")]
impl<F> Zeroize for Elem<F>
where
    F: Field,
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(inner) = self;
        inner.zeroize();
    }
}
126
127impl<F: Field> Neg for Elem<F> {
128 type Output = Elem<F>;
129
130 fn neg(self) -> Elem<F> {
131 Elem(F::small_reduce(F::Q - self.0))
132 }
133}
134
135impl<F: Field> Add<Elem<F>> for Elem<F> {
136 type Output = Elem<F>;
137
138 fn add(self, rhs: Elem<F>) -> Elem<F> {
139 Elem(F::small_reduce(self.0 + rhs.0))
140 }
141}
142
143impl<F: Field> Sub<Elem<F>> for Elem<F> {
144 type Output = Elem<F>;
145
146 fn sub(self, rhs: Elem<F>) -> Elem<F> {
147 Elem(F::small_reduce(self.0 + F::Q - rhs.0))
148 }
149}
150
151impl<F: Field> Mul<Elem<F>> for Elem<F> {
152 type Output = Elem<F>;
153
154 fn mul(self, rhs: Elem<F>) -> Elem<F> {
155 let lhs: F::Long = self.0.into();
156 let rhs: F::Long = rhs.0.into();
157 let prod = lhs * rhs;
158 Elem(F::barrett_reduce(prod))
159 }
160}
161
162#[derive(Clone, Copy, Default, Debug, PartialEq)]
167pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
168
169impl<F: Field> Polynomial<F> {
170 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
172 Self(x)
173 }
174}
175
/// Wipes every coefficient through the coefficient array's `Zeroize` implementation.
#[cfg(feature = "zeroize")]
impl<F> Zeroize for Polynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(coefficients) = self;
        coefficients.zeroize();
    }
}
185
186impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
187 type Output = Polynomial<F>;
188
189 fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
190 Polynomial(
191 self.0
192 .iter()
193 .zip(rhs.0.iter())
194 .map(|(&x, &y)| x + y)
195 .collect(),
196 )
197 }
198}
199
200impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
201 type Output = Polynomial<F>;
202
203 fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
204 Polynomial(
205 self.0
206 .iter()
207 .zip(rhs.0.iter())
208 .map(|(&x, &y)| x - y)
209 .collect(),
210 )
211 }
212}
213
214impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
215 type Output = Polynomial<F>;
216
217 fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
218 Polynomial(rhs.0.iter().map(|&x| self * x).collect())
219 }
220}
221
222impl<F: Field> Neg for &Polynomial<F> {
223 type Output = Polynomial<F>;
224
225 fn neg(self) -> Polynomial<F> {
226 Polynomial(self.0.iter().map(|&x| -x).collect())
227 }
228}
229
230#[derive(Clone, Default, Debug, PartialEq)]
234pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
235
236impl<F: Field, K: ArraySize> Vector<F, K> {
237 pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
239 Self(x)
240 }
241}
242
/// Wipes every polynomial in the vector through the array's `Zeroize` implementation.
#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for Vector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(polynomials) = self;
        polynomials.zeroize();
    }
}
252
253impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
254 type Output = Vector<F, K>;
255 fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
256 Add::add(&self, &rhs)
257 }
258}
259impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
260 type Output = Vector<F, K>;
261
262 fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
263 Vector(
264 self.0
265 .iter()
266 .zip(rhs.0.iter())
267 .map(|(x, y)| x + y)
268 .collect(),
269 )
270 }
271}
272
273impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
274 type Output = Vector<F, K>;
275
276 fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
277 Vector(
278 self.0
279 .iter()
280 .zip(rhs.0.iter())
281 .map(|(x, y)| x - y)
282 .collect(),
283 )
284 }
285}
286
287impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
288 type Output = Vector<F, K>;
289
290 fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
291 Vector(rhs.0.iter().map(|x| self * x).collect())
292 }
293}
294
295impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
296 type Output = Vector<F, K>;
297
298 fn neg(self) -> Vector<F, K> {
299 Vector(self.0.iter().map(|x| -x).collect())
300 }
301}
302
303#[derive(Clone, Default, Debug, Eq, PartialEq)]
313pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
314
315impl<F: Field> NttPolynomial<F> {
316 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
318 Self(x)
319 }
320}
321
322impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
323 type Output = NttPolynomial<F>;
324
325 fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
326 NttPolynomial(
327 self.0
328 .iter()
329 .zip(rhs.0.iter())
330 .map(|(&x, &y)| x + y)
331 .collect(),
332 )
333 }
334}
335
336impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
337 type Output = NttPolynomial<F>;
338
339 fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
340 NttPolynomial(
341 self.0
342 .iter()
343 .zip(rhs.0.iter())
344 .map(|(&x, &y)| x - y)
345 .collect(),
346 )
347 }
348}
349
350impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
351 type Output = NttPolynomial<F>;
352
353 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
354 NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
355 }
356}
357
358impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
359where
360 F: Field + MultiplyNtt,
361{
362 type Output = NttPolynomial<F>;
363
364 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
365 F::multiply_ntt(self, rhs)
366 }
367}
368
/// Field-specific multiplication of two polynomials in NTT representation.
///
/// The `&NttPolynomial<F> * &NttPolynomial<F>` operator dispatches to this trait. The
/// algorithm (e.g. pointwise vs. base-case multiplication) is chosen by the implementing
/// field and is not visible here.
pub trait MultiplyNtt: Field {
    /// Returns the NTT-domain product of `lhs` and `rhs`.
    fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
}
374
375impl<F: Field> Neg for &NttPolynomial<F> {
376 type Output = NttPolynomial<F>;
377
378 fn neg(self) -> NttPolynomial<F> {
379 NttPolynomial(self.0.iter().map(|&x| -x).collect())
380 }
381}
382
383impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
384 fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
385 NttPolynomial(f)
386 }
387}
388
389impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
390 fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
391 f_hat.0
392 }
393}
394
/// Constant-time equality, delegated to the element array's `ConstantTimeEq`.
#[cfg(feature = "subtle")]
impl<F> ConstantTimeEq for NttPolynomial<F>
where
    F: Field,
    F::Int: ConstantTimeEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}
404
/// Wipes every entry through the element array's `Zeroize` implementation.
#[cfg(feature = "zeroize")]
impl<F> Zeroize for NttPolynomial<F>
where
    F: Field,
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(values) = self;
        values.zeroize();
    }
}
414
415#[derive(Clone, Default, Debug, Eq, PartialEq)]
421pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
422
423impl<F: Field, K: ArraySize> NttVector<F, K> {
424 pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
426 Self(x)
427 }
428}
429
/// Constant-time equality, delegated to the polynomial array's `ConstantTimeEq`.
#[cfg(feature = "subtle")]
impl<F, K> ConstantTimeEq for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: ConstantTimeEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}
439
/// Wipes every polynomial through the polynomial array's `Zeroize` implementation.
#[cfg(feature = "zeroize")]
impl<F, K> Zeroize for NttVector<F, K>
where
    F: Field,
    K: ArraySize,
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(polynomials) = self;
        polynomials.zeroize();
    }
}
449
450impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
451 type Output = NttVector<F, K>;
452
453 fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
454 NttVector(
455 self.0
456 .iter()
457 .zip(rhs.0.iter())
458 .map(|(x, y)| x + y)
459 .collect(),
460 )
461 }
462}
463
464impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
465 type Output = NttVector<F, K>;
466
467 fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
468 NttVector(
469 self.0
470 .iter()
471 .zip(rhs.0.iter())
472 .map(|(x, y)| x - y)
473 .collect(),
474 )
475 }
476}
477
478impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
479where
480 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
481{
482 type Output = NttVector<F, K>;
483
484 fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
485 NttVector(rhs.0.iter().map(|x| self * x).collect())
486 }
487}
488
489impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
490where
491 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
492{
493 type Output = NttPolynomial<F>;
494
495 fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
496 self.0
497 .iter()
498 .zip(rhs.0.iter())
499 .map(|(x, y)| x * y)
500 .fold(NttPolynomial::default(), |x, y| &x + &y)
501 }
502}
503
504#[derive(Clone, Default, Debug, PartialEq)]
512pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
513
514impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
515 pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
517 Self(x)
518 }
519}
520
521impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
522where
523 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
524{
525 type Output = NttVector<F, K>;
526
527 fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
528 NttVector(self.0.iter().map(|x| x * rhs).collect())
529 }
530}