1use super::truncate::Truncate;
2
3use array::{Array, ArraySize, typenum::U256};
4use core::fmt::Debug;
5use core::ops::{Add, Mul, Neg, Sub};
6use num_traits::PrimInt;
7
8#[cfg(feature = "ctutils")]
9use ctutils::{Choice, CtEq, CtEqSlice};
10#[cfg(feature = "zeroize")]
11use zeroize::Zeroize;
12
13pub trait Field: Copy + Default + Debug + PartialEq {
15 type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
17 type Long: PrimInt + From<Self::Int>;
19 type LongLong: PrimInt;
21
22 const Q: Self::Int;
24 const QL: Self::Long;
26 const QLL: Self::LongLong;
28
29 const BARRETT_SHIFT: usize;
31 const BARRETT_MULTIPLIER: Self::LongLong;
33
34 fn small_reduce(x: Self::Int) -> Self::Int;
36 fn barrett_reduce(x: Self::Long) -> Self::Int;
38}
39
/// Defines a unit struct `$field` implementing [`Field`](crate::Field) for the modulus `$q`.
///
/// Parameters, in order:
/// - the new type's name;
/// - the element integer type (`Int`);
/// - the wider product type (`Long`);
/// - the type used inside Barrett reduction (`LongLong`);
/// - the modulus literal;
/// - an optional doc string for the generated struct (default `"Finite field"`).
#[macro_export]
macro_rules! define_field {
    // Short form: forward to the full form with the default doc string.
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        $crate::define_field!($field, $int, $long, $longlong, $q, "Finite field");
    };
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal, $doc:expr) => {
        #[doc = $doc]
        #[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
        pub struct $field;

        impl $crate::Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // Twice the bit length of Q, so 2^BARRETT_SHIFT > Q^2. This bounds the
            // product of any two reduced elements handed to `barrett_reduce`.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // floor(2^BARRETT_SHIFT / Q), evaluated at compile time.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            fn small_reduce(x: Self::Int) -> Self::Int {
                // `is_ge` is 1 when x >= Q and 0 otherwise. Negating it yields an all-ones or
                // all-zeros mask, so Q is subtracted via a mask rather than an `if`.
                let is_ge = (x >= Self::Q) as $int;
                let mask = is_ge.wrapping_neg();
                x - (mask & Self::Q)
            }

            fn barrett_reduce(x: Self::Long) -> Self::Int {
                let wide: Self::LongLong = x.into();
                // For x < 2^BARRETT_SHIFT, `q_hat` underestimates floor(x / Q) by at most one.
                // That leaves `r` in [0, 2Q), which a single `small_reduce` finishes.
                let q_hat = (wide * Self::BARRETT_MULTIPLIER) >> Self::BARRETT_SHIFT;
                let r = wide - Self::QLL * q_hat;
                Self::small_reduce($crate::Truncate::truncate(r))
            }
        }
    };
}
94
95#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)]
104pub struct Elem<F: Field>(pub F::Int);
105
106impl<F: Field> Elem<F> {
107 pub const fn new(x: F::Int) -> Self {
109 Self(x)
110 }
111}
112
/// `CtEq` comparison, delegated to the representative's own `CtEq` impl.
#[cfg(feature = "ctutils")]
impl<F: Field> CtEq for Elem<F>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let Elem(lhs) = self;
        let Elem(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

/// Marks `Elem` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for Elem<F> {}
125
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Elem<F>
where
    F::Int: Zeroize,
{
    /// Clears the stored representative using its `Zeroize` implementation.
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}
135
136impl<F: Field> Neg for Elem<F> {
137 type Output = Elem<F>;
138
139 fn neg(self) -> Elem<F> {
140 Elem(F::small_reduce(F::Q - self.0))
141 }
142}
143
144impl<F: Field> Add<Elem<F>> for Elem<F> {
145 type Output = Elem<F>;
146
147 fn add(self, rhs: Elem<F>) -> Elem<F> {
148 Elem(F::small_reduce(self.0 + rhs.0))
149 }
150}
151
152impl<F: Field> Sub<Elem<F>> for Elem<F> {
153 type Output = Elem<F>;
154
155 fn sub(self, rhs: Elem<F>) -> Elem<F> {
156 Elem(F::small_reduce(self.0 + F::Q - rhs.0))
157 }
158}
159
160impl<F: Field> Mul<Elem<F>> for Elem<F> {
161 type Output = Elem<F>;
162
163 fn mul(self, rhs: Elem<F>) -> Elem<F> {
164 let lhs: F::Long = self.0.into();
165 let rhs: F::Long = rhs.0.into();
166 let prod = lhs * rhs;
167 Elem(F::barrett_reduce(prod))
168 }
169}
170
171#[derive(Clone, Copy, Default, Debug, PartialEq)]
176pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
177
178impl<F: Field> Polynomial<F> {
179 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
181 Self(x)
182 }
183}
184
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Polynomial<F>
where
    F::Int: Zeroize,
{
    /// Clears every coefficient.
    fn zeroize(&mut self) {
        let Self(coefficients) = self;
        coefficients.zeroize();
    }
}
194
195impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
196 type Output = Polynomial<F>;
197
198 fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
199 Polynomial(
200 self.0
201 .iter()
202 .zip(rhs.0.iter())
203 .map(|(&x, &y)| x + y)
204 .collect(),
205 )
206 }
207}
208
209impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
210 type Output = Polynomial<F>;
211
212 fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
213 Polynomial(
214 self.0
215 .iter()
216 .zip(rhs.0.iter())
217 .map(|(&x, &y)| x - y)
218 .collect(),
219 )
220 }
221}
222
223impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
224 type Output = Polynomial<F>;
225
226 fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
227 Polynomial(rhs.0.iter().map(|&x| self * x).collect())
228 }
229}
230
231impl<F: Field> Neg for &Polynomial<F> {
232 type Output = Polynomial<F>;
233
234 fn neg(self) -> Polynomial<F> {
235 Polynomial(self.0.iter().map(|&x| -x).collect())
236 }
237}
238
/// `CtEq` comparison of all coefficients, delegated to the coefficient array.
#[cfg(feature = "ctutils")]
impl<F: Field> CtEq for Polynomial<F>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Marks `Polynomial` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for Polynomial<F> {}
251
252#[derive(Clone, Default, Debug, PartialEq)]
256pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
257
258impl<F: Field, K: ArraySize> Vector<F, K> {
259 pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
261 Self(x)
262 }
263}
264
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for Vector<F, K>
where
    F::Int: Zeroize,
{
    /// Clears every polynomial in the vector.
    fn zeroize(&mut self) {
        let polys = &mut self.0;
        polys.zeroize();
    }
}
274
275impl<F: Field, K: ArraySize> Add<Vector<F, K>> for Vector<F, K> {
276 type Output = Vector<F, K>;
277 fn add(self, rhs: Vector<F, K>) -> Vector<F, K> {
278 Add::add(&self, &rhs)
279 }
280}
281impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
282 type Output = Vector<F, K>;
283
284 fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
285 Vector(
286 self.0
287 .iter()
288 .zip(rhs.0.iter())
289 .map(|(x, y)| x + y)
290 .collect(),
291 )
292 }
293}
294
295impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
296 type Output = Vector<F, K>;
297
298 fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
299 Vector(
300 self.0
301 .iter()
302 .zip(rhs.0.iter())
303 .map(|(x, y)| x - y)
304 .collect(),
305 )
306 }
307}
308
309impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
310 type Output = Vector<F, K>;
311
312 fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
313 Vector(rhs.0.iter().map(|x| self * x).collect())
314 }
315}
316
317impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
318 type Output = Vector<F, K>;
319
320 fn neg(self) -> Vector<F, K> {
321 Vector(self.0.iter().map(|x| -x).collect())
322 }
323}
324
/// `CtEq` comparison of all entries, delegated to the polynomial array.
#[cfg(feature = "ctutils")]
impl<F: Field, K: ArraySize> CtEq for Vector<F, K>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

/// Marks `Vector` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEqSlice for Vector<F, K> {}
337
338#[derive(Clone, Default, Debug, Eq, PartialEq)]
348pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
349
350impl<F: Field> NttPolynomial<F> {
351 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
353 Self(x)
354 }
355}
356
357impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
358 type Output = NttPolynomial<F>;
359
360 fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
361 NttPolynomial(
362 self.0
363 .iter()
364 .zip(rhs.0.iter())
365 .map(|(&x, &y)| x + y)
366 .collect(),
367 )
368 }
369}
370
371impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
372 type Output = NttPolynomial<F>;
373
374 fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
375 NttPolynomial(
376 self.0
377 .iter()
378 .zip(rhs.0.iter())
379 .map(|(&x, &y)| x - y)
380 .collect(),
381 )
382 }
383}
384
385impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
386 type Output = NttPolynomial<F>;
387
388 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
389 NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
390 }
391}
392
393impl<F> Mul<&NttPolynomial<F>> for &NttPolynomial<F>
394where
395 F: Field + MultiplyNtt,
396{
397 type Output = NttPolynomial<F>;
398
399 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
400 F::multiply_ntt(self, rhs)
401 }
402}
403
404pub trait MultiplyNtt: Field {
406 fn multiply_ntt(lhs: &NttPolynomial<Self>, rhs: &NttPolynomial<Self>) -> NttPolynomial<Self>;
408}
409
410impl<F: Field> Neg for &NttPolynomial<F> {
411 type Output = NttPolynomial<F>;
412
413 fn neg(self) -> NttPolynomial<F> {
414 NttPolynomial(self.0.iter().map(|&x| -x).collect())
415 }
416}
417
418impl<F: Field> From<Array<Elem<F>, U256>> for NttPolynomial<F> {
419 fn from(f: Array<Elem<F>, U256>) -> NttPolynomial<F> {
420 NttPolynomial(f)
421 }
422}
423
424impl<F: Field> From<NttPolynomial<F>> for Array<Elem<F>, U256> {
425 fn from(f_hat: NttPolynomial<F>) -> Array<Elem<F>, U256> {
426 f_hat.0
427 }
428}
429
/// `CtEq` comparison of all elements, delegated to the element array.
#[cfg(feature = "ctutils")]
impl<F: Field> CtEq for NttPolynomial<F>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let Self(lhs) = self;
        let Self(rhs) = other;
        lhs.ct_eq(rhs)
    }
}

/// Marks `NttPolynomial` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>> CtEqSlice for NttPolynomial<F> {}
442
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for NttPolynomial<F>
where
    F::Int: Zeroize,
{
    /// Clears every element.
    fn zeroize(&mut self) {
        let Self(elements) = self;
        elements.zeroize();
    }
}
452
453#[derive(Clone, Default, Debug, Eq, PartialEq)]
459pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
460
461impl<F: Field, K: ArraySize> NttVector<F, K> {
462 pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
464 Self(x)
465 }
466}
467
/// `CtEq` comparison of all entries, delegated to the polynomial array.
#[cfg(feature = "ctutils")]
impl<F: Field, K: ArraySize> CtEq for NttVector<F, K>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

/// Marks `NttVector` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize> CtEqSlice for NttVector<F, K> {}
480
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for NttVector<F, K>
where
    F::Int: Zeroize,
{
    /// Clears every polynomial in the vector.
    fn zeroize(&mut self) {
        let polys = &mut self.0;
        polys.zeroize();
    }
}
490
491impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
492 type Output = NttVector<F, K>;
493
494 fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
495 NttVector(
496 self.0
497 .iter()
498 .zip(rhs.0.iter())
499 .map(|(x, y)| x + y)
500 .collect(),
501 )
502 }
503}
504
505impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
506 type Output = NttVector<F, K>;
507
508 fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
509 NttVector(
510 self.0
511 .iter()
512 .zip(rhs.0.iter())
513 .map(|(x, y)| x - y)
514 .collect(),
515 )
516 }
517}
518
519impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
520where
521 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
522{
523 type Output = NttVector<F, K>;
524
525 fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
526 NttVector(rhs.0.iter().map(|x| self * x).collect())
527 }
528}
529
530impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
531where
532 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
533{
534 type Output = NttPolynomial<F>;
535
536 fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
537 self.0
538 .iter()
539 .zip(rhs.0.iter())
540 .map(|(x, y)| x * y)
541 .fold(NttPolynomial::default(), |x, y| &x + &y)
542 }
543}
544
545#[derive(Clone, Default, Debug, PartialEq)]
553pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
554
555impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
556 pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
558 Self(x)
559 }
560}
561
562impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
563where
564 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
565{
566 type Output = NttVector<F, K>;
567
568 fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
569 NttVector(self.0.iter().map(|x| x * rhs).collect())
570 }
571}
572
/// `CtEq` comparison of all rows, delegated to the row array.
#[cfg(feature = "ctutils")]
impl<F: Field, K: ArraySize, L: ArraySize> CtEq for NttMatrix<F, K, L>
where
    F::Int: CtEq,
{
    fn ct_eq(&self, other: &Self) -> Choice {
        let (Self(lhs), Self(rhs)) = (self, other);
        lhs.ct_eq(rhs)
    }
}

/// Marks `NttMatrix` as usable for slice-wise `CtEq` comparison.
#[cfg(feature = "ctutils")]
impl<F: Field<Int: CtEq>, K: ArraySize, L: ArraySize> CtEqSlice for NttMatrix<F, K, L> {}