1use super::util::Truncate;
2
3use core::fmt::Debug;
4use core::ops::{Add, Mul, Neg, Sub};
5use hybrid_array::{Array, ArraySize, typenum::U256};
6use num_traits::PrimInt;
7
8#[cfg(feature = "zeroize")]
9use zeroize::Zeroize;
10
/// A prime field with modulus `Q`, together with the integer types used to store a
/// reduced element (`Int`), a product of two elements (`Long`), and the intermediates of
/// Barrett reduction (`LongLong`).
///
/// Concrete implementations are produced by the `define_field!` macro.
pub trait Field: Copy + Default + Debug + PartialEq {
    /// Integer type holding a single field element.
    type Int: PrimInt + Default + Debug + From<u8> + Into<u128> + Into<Self::Long> + Truncate<u128>;
    /// Wider integer type; `Elem` multiplication forms the product of two `Int`s in it.
    type Long: PrimInt + From<Self::Int>;
    /// Integer type for the intermediate products computed during Barrett reduction
    /// (in the `define_field!` implementation).
    type LongLong: PrimInt;

    /// The field modulus, typed as `Int`.
    const Q: Self::Int;
    /// The field modulus, typed as `Long`.
    const QL: Self::Long;
    /// The field modulus, typed as `LongLong`.
    const QLL: Self::LongLong;

    /// Right shift applied to the Barrett product; `define_field!` sets it to twice the
    /// bit length of `Q`.
    const BARRETT_SHIFT: usize;
    /// Barrett multiplier; `define_field!` sets it to `floor(2^BARRETT_SHIFT / Q)`.
    const BARRETT_MULTIPLIER: Self::LongLong;

    /// Conditionally subtracts `Q` once (in the `define_field!` implementation).
    /// Presumably callers pass values in `[0, 2Q)` so the result lands in `[0, Q)` — TODO confirm.
    fn small_reduce(x: Self::Int) -> Self::Int;
    /// Reduces a `Long` value, such as the product of two elements, modulo `Q` into an `Int`.
    fn barrett_reduce(x: Self::Long) -> Self::Int;
}
26
/// Declares a unit struct `$field` and implements `Field` for it.
///
/// Arguments, in order: the struct name, the element type (`Int`), the product type
/// (`Long`), the Barrett intermediate type (`LongLong`), and the modulus as an integer
/// literal. The expansion calls `.into()` from `$long` to `$longlong`, so that conversion
/// must exist.
///
/// NOTE(review): the expansion names `Field` and `Truncate` unqualified, so both must be
/// in scope wherever the macro is invoked.
#[macro_export]
macro_rules! define_field {
    ($field:ident, $int:ty, $long:ty, $longlong:ty, $q:literal) => {
        #[derive(Copy, Clone, Default, Debug, PartialEq)]
        pub struct $field;

        impl Field for $field {
            type Int = $int;
            type Long = $long;
            type LongLong = $longlong;

            // The same literal modulus, typed for each of the three widths.
            const Q: Self::Int = $q;
            const QL: Self::Long = $q;
            const QLL: Self::LongLong = $q;

            // `ilog2() + 1` is the bit length k of Q, so 2^BARRETT_SHIFT = 2^(2k) > Q^2.
            #[allow(clippy::as_conversions)]
            const BARRETT_SHIFT: usize = 2 * (Self::Q.ilog2() + 1) as usize;
            // m = floor(2^(2k) / Q), evaluated at compile time.
            #[allow(clippy::integer_division_remainder_used)]
            const BARRETT_MULTIPLIER: Self::LongLong = (1 << Self::BARRETT_SHIFT) / Self::QLL;

            // Subtracts Q once when x >= Q, mapping [0, 2Q) onto [0, Q).
            // NOTE(review): this branches on `x`; if `x` can be secret, confirm the
            // compiled code is branch-free on the targets that matter.
            fn small_reduce(x: Self::Int) -> Self::Int {
                if x < Self::Q { x } else { x - Self::Q }
            }

            // Barrett reduction. The quotient estimate floor(x * m / 2^(2k)) never exceeds
            // floor(x / Q). For x < 2^(2k) (e.g. a product of two elements below Q) it
            // undershoots by at most one, so `remainder` lies in [0, 2Q) and a single
            // `small_reduce` finishes the job. `$longlong` must hold
            // `x * BARRETT_MULTIPLIER` without overflow, and `$int` must hold values
            // below 2Q for the truncation to be lossless.
            fn barrett_reduce(x: Self::Long) -> Self::Int {
                let x: Self::LongLong = x.into();
                let product = x * Self::BARRETT_MULTIPLIER;
                let quotient = product >> Self::BARRETT_SHIFT;
                let remainder = x - quotient * Self::QLL;
                Self::small_reduce(Truncate::truncate(remainder))
            }
        }
    };
}
71
72#[derive(Copy, Clone, Default, Debug, PartialEq)]
78pub struct Elem<F: Field>(pub F::Int);
79
80impl<F: Field> Elem<F> {
81 pub const fn new(x: F::Int) -> Self {
82 Self(x)
83 }
84}
85
/// Wipes the element's integer representative (compiled only with the `zeroize` feature).
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Elem<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(inner) = self;
        inner.zeroize();
    }
}
95
96impl<F: Field> Neg for Elem<F> {
97 type Output = Elem<F>;
98
99 fn neg(self) -> Elem<F> {
100 Elem(F::small_reduce(F::Q - self.0))
101 }
102}
103
104impl<F: Field> Add<Elem<F>> for Elem<F> {
105 type Output = Elem<F>;
106
107 fn add(self, rhs: Elem<F>) -> Elem<F> {
108 Elem(F::small_reduce(self.0 + rhs.0))
109 }
110}
111
112impl<F: Field> Sub<Elem<F>> for Elem<F> {
113 type Output = Elem<F>;
114
115 fn sub(self, rhs: Elem<F>) -> Elem<F> {
116 Elem(F::small_reduce(self.0 + F::Q - rhs.0))
117 }
118}
119
120impl<F: Field> Mul<Elem<F>> for Elem<F> {
121 type Output = Elem<F>;
122
123 fn mul(self, rhs: Elem<F>) -> Elem<F> {
124 let lhs: F::Long = self.0.into();
125 let rhs: F::Long = rhs.0.into();
126 let prod = lhs * rhs;
127 Elem(F::barrett_reduce(prod))
128 }
129}
130
131#[derive(Clone, Default, Debug, PartialEq)]
135pub struct Polynomial<F: Field>(pub Array<Elem<F>, U256>);
136
137impl<F: Field> Polynomial<F> {
138 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
139 Self(x)
140 }
141}
142
/// Wipes every element of the polynomial (compiled only with the `zeroize` feature).
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for Polynomial<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}
152
153impl<F: Field> Add<&Polynomial<F>> for &Polynomial<F> {
154 type Output = Polynomial<F>;
155
156 fn add(self, rhs: &Polynomial<F>) -> Polynomial<F> {
157 Polynomial(
158 self.0
159 .iter()
160 .zip(rhs.0.iter())
161 .map(|(&x, &y)| x + y)
162 .collect(),
163 )
164 }
165}
166
167impl<F: Field> Sub<&Polynomial<F>> for &Polynomial<F> {
168 type Output = Polynomial<F>;
169
170 fn sub(self, rhs: &Polynomial<F>) -> Polynomial<F> {
171 Polynomial(
172 self.0
173 .iter()
174 .zip(rhs.0.iter())
175 .map(|(&x, &y)| x - y)
176 .collect(),
177 )
178 }
179}
180
181impl<F: Field> Mul<&Polynomial<F>> for Elem<F> {
182 type Output = Polynomial<F>;
183
184 fn mul(self, rhs: &Polynomial<F>) -> Polynomial<F> {
185 Polynomial(rhs.0.iter().map(|&x| self * x).collect())
186 }
187}
188
189impl<F: Field> Neg for &Polynomial<F> {
190 type Output = Polynomial<F>;
191
192 fn neg(self) -> Polynomial<F> {
193 Polynomial(self.0.iter().map(|&x| -x).collect())
194 }
195}
196
197#[derive(Clone, Default, Debug, PartialEq)]
200pub struct Vector<F: Field, K: ArraySize>(pub Array<Polynomial<F>, K>);
201
202impl<F: Field, K: ArraySize> Vector<F, K> {
203 pub const fn new(x: Array<Polynomial<F>, K>) -> Self {
204 Self(x)
205 }
206}
207
/// Wipes every polynomial in the vector (compiled only with the `zeroize` feature).
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for Vector<F, K>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}
217
218impl<F: Field, K: ArraySize> Add<&Vector<F, K>> for &Vector<F, K> {
219 type Output = Vector<F, K>;
220
221 fn add(self, rhs: &Vector<F, K>) -> Vector<F, K> {
222 Vector(
223 self.0
224 .iter()
225 .zip(rhs.0.iter())
226 .map(|(x, y)| x + y)
227 .collect(),
228 )
229 }
230}
231
232impl<F: Field, K: ArraySize> Sub<&Vector<F, K>> for &Vector<F, K> {
233 type Output = Vector<F, K>;
234
235 fn sub(self, rhs: &Vector<F, K>) -> Vector<F, K> {
236 Vector(
237 self.0
238 .iter()
239 .zip(rhs.0.iter())
240 .map(|(x, y)| x - y)
241 .collect(),
242 )
243 }
244}
245
246impl<F: Field, K: ArraySize> Mul<&Vector<F, K>> for Elem<F> {
247 type Output = Vector<F, K>;
248
249 fn mul(self, rhs: &Vector<F, K>) -> Vector<F, K> {
250 Vector(rhs.0.iter().map(|x| self * x).collect())
251 }
252}
253
254impl<F: Field, K: ArraySize> Neg for &Vector<F, K> {
255 type Output = Vector<F, K>;
256
257 fn neg(self) -> Vector<F, K> {
258 Vector(self.0.iter().map(|x| -x).collect())
259 }
260}
261
262#[derive(Clone, Default, Debug, PartialEq)]
268pub struct NttPolynomial<F: Field>(pub Array<Elem<F>, U256>);
269
270impl<F: Field> NttPolynomial<F> {
271 pub const fn new(x: Array<Elem<F>, U256>) -> Self {
272 Self(x)
273 }
274}
275
/// Wipes every element (compiled only with the `zeroize` feature).
#[cfg(feature = "zeroize")]
impl<F: Field> Zeroize for NttPolynomial<F>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(elements) = self;
        elements.zeroize();
    }
}
285
286impl<F: Field> Add<&NttPolynomial<F>> for &NttPolynomial<F> {
287 type Output = NttPolynomial<F>;
288
289 fn add(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
290 NttPolynomial(
291 self.0
292 .iter()
293 .zip(rhs.0.iter())
294 .map(|(&x, &y)| x + y)
295 .collect(),
296 )
297 }
298}
299
300impl<F: Field> Sub<&NttPolynomial<F>> for &NttPolynomial<F> {
301 type Output = NttPolynomial<F>;
302
303 fn sub(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
304 NttPolynomial(
305 self.0
306 .iter()
307 .zip(rhs.0.iter())
308 .map(|(&x, &y)| x - y)
309 .collect(),
310 )
311 }
312}
313
314impl<F: Field> Mul<&NttPolynomial<F>> for Elem<F> {
315 type Output = NttPolynomial<F>;
316
317 fn mul(self, rhs: &NttPolynomial<F>) -> NttPolynomial<F> {
318 NttPolynomial(rhs.0.iter().map(|&x| self * x).collect())
319 }
320}
321
322impl<F: Field> Neg for &NttPolynomial<F> {
323 type Output = NttPolynomial<F>;
324
325 fn neg(self) -> NttPolynomial<F> {
326 NttPolynomial(self.0.iter().map(|&x| -x).collect())
327 }
328}
329
330#[derive(Clone, Default, Debug, PartialEq)]
335pub struct NttVector<F: Field, K: ArraySize>(pub Array<NttPolynomial<F>, K>);
336
337impl<F: Field, K: ArraySize> NttVector<F, K> {
338 pub const fn new(x: Array<NttPolynomial<F>, K>) -> Self {
339 Self(x)
340 }
341}
342
/// Wipes every entry of the vector (compiled only with the `zeroize` feature).
#[cfg(feature = "zeroize")]
impl<F: Field, K: ArraySize> Zeroize for NttVector<F, K>
where
    F::Int: Zeroize,
{
    fn zeroize(&mut self) {
        Zeroize::zeroize(&mut self.0);
    }
}
352
353impl<F: Field, K: ArraySize> Add<&NttVector<F, K>> for &NttVector<F, K> {
354 type Output = NttVector<F, K>;
355
356 fn add(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
357 NttVector(
358 self.0
359 .iter()
360 .zip(rhs.0.iter())
361 .map(|(x, y)| x + y)
362 .collect(),
363 )
364 }
365}
366
367impl<F: Field, K: ArraySize> Sub<&NttVector<F, K>> for &NttVector<F, K> {
368 type Output = NttVector<F, K>;
369
370 fn sub(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
371 NttVector(
372 self.0
373 .iter()
374 .zip(rhs.0.iter())
375 .map(|(x, y)| x - y)
376 .collect(),
377 )
378 }
379}
380
381impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttPolynomial<F>
382where
383 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
384{
385 type Output = NttVector<F, K>;
386
387 fn mul(self, rhs: &NttVector<F, K>) -> NttVector<F, K> {
388 NttVector(rhs.0.iter().map(|x| self * x).collect())
389 }
390}
391
392impl<F: Field, K: ArraySize> Mul<&NttVector<F, K>> for &NttVector<F, K>
393where
394 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
395{
396 type Output = NttPolynomial<F>;
397
398 fn mul(self, rhs: &NttVector<F, K>) -> NttPolynomial<F> {
399 self.0
400 .iter()
401 .zip(rhs.0.iter())
402 .map(|(x, y)| x * y)
403 .fold(NttPolynomial::default(), |x, y| &x + &y)
404 }
405}
406
407#[derive(Clone, Default, Debug, PartialEq)]
412pub struct NttMatrix<F: Field, K: ArraySize, L: ArraySize>(pub Array<NttVector<F, L>, K>);
413
414impl<F: Field, K: ArraySize, L: ArraySize> NttMatrix<F, K, L> {
415 pub const fn new(x: Array<NttVector<F, L>, K>) -> Self {
416 Self(x)
417 }
418}
419
420impl<F: Field, K: ArraySize, L: ArraySize> Mul<&NttVector<F, L>> for &NttMatrix<F, K, L>
421where
422 for<'a> &'a NttPolynomial<F>: Mul<&'a NttPolynomial<F>, Output = NttPolynomial<F>>,
423{
424 type Output = NttVector<F, K>;
425
426 fn mul(self, rhs: &NttVector<F, L>) -> NttVector<F, K> {
427 NttVector(self.0.iter().map(|x| x * rhs).collect())
428 }
429}