1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "ctutils")]
9use ctutils::{Choice, CtEq, CtEqSlice};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13pub trait Field: Copy + Default + Debug + PartialEq {
15 type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
17 type Long: PrimInt + From<Self::Int>;
19 type LongLong: PrimInt;
21
22 const Q: Self::Int;
24 const QL: Self::Long;
26 const QLL: Self::LongLong;
28
29 const BARRETT_SHIFT: usize;
31 const BARRETT_MULTIPLIER: Self::LongLong;
33
34 fn small_reduce(x: Self::Int) -> Self::Int;
36 fn barrett_reduce(x: Self::Long) -> Self::Int;
38}
39
/// Declares a unit struct and implements `Field` for it.
///
/// Arguments, in order: the name of the new type, the `Int`, `Long` and `LongLong` integer
/// types, the modulus as an integer literal, and optionally a documentation string for the
/// generated type (the short form uses `"Finite field"`).
///
/// The integer types must be primitive unsigned integers: the expansion calls inherent
/// methods such as `ilog2` and `overflowing_sub` on them.
///
/// Generated constants: `BARRETT_SHIFT = 2 * (ilog2(q) + 1)` and
/// `BARRETT_MULTIPLIER = (1 << BARRETT_SHIFT) / q`.
#[macro_export]
macro_rules! define_field {
    // Short form: no documentation string supplied, so a generic one is used.
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            // Returns `x` when `x < Q` and `x - Q` otherwise.
            //
            // Written without a data-dependent branch: the borrow flag of `x - Q` is widened
            // into an all-zeros / all-ones mask that decides whether `Q` is added back. This
            // expresses the intent at source level only; the compiler is still free to choose
            // the instructions it emits.
            fn small_reduce(x: Self::Int) -> Self::Int {
                let (difference, borrowed) = x.overflowing_sub(Self::Q);
                // `borrowed` is true exactly when `x < Q`; negating 1 yields an all-ones mask.
                let restore = <$int>::from(borrowed).wrapping_neg();
                difference.wrapping_add(Self::Q & restore)
            }

            // Barrett reduction: estimates `x / Q` with a multiply and a shift, subtracts that
            // many multiples of `Q`, and finishes with `small_reduce`.
            //
            // For `x < 2^BARRETT_SHIFT` the estimate is at most one below the true quotient,
            // so the intermediate remainder lies in `[0, 2 * Q)` and one conditional
            // subtraction is enough. Larger inputs are not guaranteed to be fully reduced.
            fn barrett_reduce(x: Self::Long) -> Self::Int {
                let wide: Self::LongLong = x.into();
                let estimate = (wide * Self::BARRETT_MULTIPLIER) >> Self::BARRETT_SHIFT;
                let remainder = wide - estimate * Self::QLL;
                Self::small_reduce($crate::Truncate::truncate(remainder))
            }
        }
    };
}
88
89#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
98pub struct Elem<F: Field>(pub F::Int);
99
100impl<F: Field> Elem<F> {
101 pub const fn new(x: F::Int) -> Self {
103 Self(x)
104 }
105}
106
/// `CtEq` comparison of two elements, delegated to the wrapped integers.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEq for Elem<F> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `Elem` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for Elem<F>
where
    F: Field,
    F::Int: CtEq,
{
}
119
/// Zeroizes the wrapped integer.
#[cfg(feature = "zeroize")]
impl<F: Field<Int: Zeroize>> Zeroize for Elem<F> {
    fn zeroize(&mut self) {
        let Self(value) = self;
        value.zeroize();
    }
}
129
130impl<F: Field> Neg for Elem<F> {
131 type Output = Elem<F>;
132
133 fn neg(self) -> Elem<F> {
134 Elem(F::small_reduce(F::Q - self.0))
135 }
136}
137
138impl<F: Field> Add<Elem<F>> for Elem<F> {
139 type Output = Elem<F>;
140
141 fn add(self, rhs: Elem<F>) -> Elem<F> {
142 Elem(F::small_reduce(self.0 + rhs.0))
143 }
144}
145
146impl<F: Field> Sub<Elem<F>> for Elem<F> {
147 type Output = Elem<F>;
148
149 fn sub(self, rhs: Elem<F>) -> Elem<F> {
150 Elem(F::small_reduce(self.0 + F::Q - rhs.0))
151 }
152}
153
154impl<F: Field> Mul<Elem<F>> for Elem<F> {
155 type Output = Elem<F>;
156
157 fn mul(self, rhs: Elem<F>) -> Elem<F> {
158 let lhs: F::Long = self.0.into();
159 let rhs: F::Long = rhs.0.into();
160 let prod = lhs * rhs;
161 Elem(F::barrett_reduce(prod))
162 }
163}
164
165#[derive(Clone, Copy, Default, Debug, PartialEq)]
170pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
171
172impl<F: Field> Polynomial<F> {
173 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
175 Self(x)
176 }
177}
178
/// Zeroizes the coefficient array.
#[cfg(feature = "zeroize")]
impl<F: Field<Int: Zeroize>> Zeroize for Polynomial<F> {
    fn zeroize(&mut self) {
        let Self(coefficients) = self;
        coefficients.zeroize();
    }
}
188
189impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
190 type Output = Polynomial<F>;
191
192 fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
193 Polynomial(
194 self.0
195 .iter()
196 .zip(rhs.0.iter())
197 .map(|(&x, &y)| x + y)
198 .collect(),
199 )
200 }
201}
202
203impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
204 type Output = Polynomial<F>;
205
206 fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
207 Polynomial(
208 self.0
209 .iter()
210 .zip(rhs.0.iter())
211 .map(|(&x, &y)| x - y)
212 .collect(),
213 )
214 }
215}
216
217impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
218 type Output = Polynomial<F>;
219
220 fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
221 Polynomial(rhs.0.iter().map(|&x| self * x).collect())
222 }
223}
224
225impl<F: Field> Neg for &Polynomial<F> {
226 type Output = Polynomial<F>;
227
228 fn neg(self) -> Polynomial<F> {
229 Polynomial(self.0.iter().map(|&x| -x).collect())
230 }
231}
232
/// `CtEq` comparison of two polynomials, delegated to the coefficient arrays.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEq for Polynomial<F> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `Polynomial` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for Polynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
}
245
246#[derive(Clone, Default, Debug, PartialEq)]
250pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
251
252impl<F: Field, K: ArraySize> Vector<F, K> {
253 pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
255 Self(x)
256 }
257}
258
/// Zeroizes every polynomial held by the vector.
#[cfg(feature = "zeroize")]
impl<F: Field<Int: Zeroize>, K: ArraySize> Zeroize for Vector<F, K> {
    fn zeroize(&mut self) {
        let Self(polynomials) = self;
        polynomials.zeroize();
    }
}
268
269impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
270 type Output = Vector<F, K>;
271 fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
272 Add::add(&self, &rhs)
273 }
274}
275impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
276 type Output = Vector<F, K>;
277
278 fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
279 Vector(
280 self.0
281 .iter()
282 .zip(rhs.0.iter())
283 .map(|(x, y)| x + y)
284 .collect(),
285 )
286 }
287}
288
289impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
290 type Output = Vector<F, K>;
291
292 fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
293 Vector(
294 self.0
295 .iter()
296 .zip(rhs.0.iter())
297 .map(|(x, y)| x - y)
298 .collect(),
299 )
300 }
301}
302
303impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
304 type Output = Vector<F, K>;
305
306 fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
307 Vector(rhs.0.iter().map(|x| self * x).collect())
308 }
309}
310
311impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
312 type Output = Vector<F, K>;
313
314 fn neg(self) -> Vector<F, K> {
315 Vector(self.0.iter().map(|x| -x).collect())
316 }
317}
318
/// `CtEq` comparison of two vectors, delegated to the underlying arrays.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEq for Vector<F, K> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `Vector` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F, K> CtEqSlice for Vector<F, K>
where
    F: Field,
    F::Int: CtEq,
    K: ArraySize,
{
}
331
332#[derive(Clone, Default, Debug, Eq, PartialEq)]
342pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
343
344impl<F: Field> NttPolynomial<F> {
345 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
347 Self(x)
348 }
349}
350
351impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
352 type Output = NttPolynomial<F>;
353
354 fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
355 NttPolynomial(
356 self.0
357 .iter()
358 .zip(rhs.0.iter())
359 .map(|(&x, &y)| x + y)
360 .collect(),
361 )
362 }
363}
364
365impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
366 type Output = NttPolynomial<F>;
367
368 fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
369 NttPolynomial(
370 self.0
371 .iter()
372 .zip(rhs.0.iter())
373 .map(|(&x, &y)| x - y)
374 .collect(),
375 )
376 }
377}
378
379impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
380 type Output = NttPolynomial<F>;
381
382 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
383 NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
384 }
385}
386
387impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
388where
389 F: Field + MultiplyNtt,
390{
391 type Output = NttPolynomial<F>;
392
393 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
394 F::multiply_ntt(self, rhs)
395 }
396}
397
398pub trait MultiplyNtt: Field {
400 fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
402}
403
404impl<F: Field> Neg for &NttPolynomial<F> {
405 type Output = NttPolynomial<F>;
406
407 fn neg(self) -> NttPolynomial<F> {
408 NttPolynomial(self.0.iter().map(|&x| -x).collect())
409 }
410}
411
412impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
413 fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
414 NttPolynomial(f)
415 }
416}
417
418impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
419 fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
420 f_hat.0
421 }
422}
423
/// `CtEq` comparison of two `NttPolynomial`s, delegated to the underlying arrays.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEq for NttPolynomial<F> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `NttPolynomial` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F> CtEqSlice for NttPolynomial<F>
where
    F: Field,
    F::Int: CtEq,
{
}
436
/// Zeroizes the underlying array.
#[cfg(feature = "zeroize")]
impl<F: Field<Int: Zeroize>> Zeroize for NttPolynomial<F> {
    fn zeroize(&mut self) {
        let Self(entries) = self;
        entries.zeroize();
    }
}
446
447#[derive(Clone, Default, Debug, Eq, PartialEq)]
453pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
454
455impl<F: Field, K: ArraySize> NttVector<F, K> {
456 pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
458 Self(x)
459 }
460}
461
/// `CtEq` comparison of two `NttVector`s, delegated to the underlying arrays.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEq for NttVector<F, K> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `NttVector` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F, K> CtEqSlice for NttVector<F, K>
where
    F: Field,
    F::Int: CtEq,
    K: ArraySize,
{
}
474
/// Zeroizes every `NttPolynomial` held by the vector.
#[cfg(feature = "zeroize")]
impl<F: Field<Int: Zeroize>, K: ArraySize> Zeroize for NttVector<F, K> {
    fn zeroize(&mut self) {
        let Self(polynomials) = self;
        polynomials.zeroize();
    }
}
484
485impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
486 type Output = NttVector<F, K>;
487
488 fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
489 NttVector(
490 self.0
491 .iter()
492 .zip(rhs.0.iter())
493 .map(|(x, y)| x + y)
494 .collect(),
495 )
496 }
497}
498
499impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
500 type Output = NttVector<F, K>;
501
502 fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
503 NttVector(
504 self.0
505 .iter()
506 .zip(rhs.0.iter())
507 .map(|(x, y)| x - y)
508 .collect(),
509 )
510 }
511}
512
513impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
514where
515 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
516{
517 type Output = NttVector<F, K>;
518
519 fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
520 NttVector(rhs.0.iter().map(|x| self * x).collect())
521 }
522}
523
524impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
525where
526 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
527{
528 type Output = NttPolynomial<F>;
529
530 fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
531 self.0
532 .iter()
533 .zip(rhs.0.iter())
534 .map(|(x, y)| x * y)
535 .fold(NttPolynomial::default(), |x, y| &x + &y)
536 }
537}
538
539#[derive(Clone, Default, Debug, PartialEq)]
547pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
548
549impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
550 pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
552 Self(x)
553 }
554}
555
556impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
557where
558 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
559{
560 type Output = NttVector<F, K>;
561
562 fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
563 NttVector(self.0.iter().map(|x| x * rhs).collect())
564 }
565}
566
/// `CtEq` comparison of two matrices, delegated to the underlying arrays of rows.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize, L: ArraySize> CtEq for NttMatrix<F, K, L> {
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Opts `NttMatrix` in to `CtEqSlice`; the empty body relies on the trait's provided items.
#[cfg(feature = "ctutils")]
impl<F, K, L> CtEqSlice for NttMatrix<F, K, L>
where
    F: Field,
    F::Int: CtEq,
    K: ArraySize,
    L: ArraySize,
{
}