1use zeroize::{
4 Zeroize,
5 ZeroizeOnDrop,
6};
7
8use crate::coeff::{
9 COEFFICIENTS_IN_SIMD_UNIT,
10 Coefficients,
11 FieldElement,
12 SIMD_UNITS_IN_RING_ELEMENT,
13};
14use crate::constants::{
15 COEFFICIENTS_IN_RING_ELEMENT,
16 FIELD_MODULUS,
17};
18use crate::field::{
19 add_coeffs,
20 reduce_element,
21 reduce_poly_simd,
22 subtract_coeffs,
23};
24use crate::ntt::{
25 intt_montgomery,
26 ntt_forward_simd,
27 ntt_multiply_montgomery,
28};
29
30#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
32pub struct Poly {
33 pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
35}
36
37impl Poly {
38 #[must_use]
40 pub const fn zero() -> Self {
41 Self {
42 coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
43 }
44 }
45
46 #[must_use]
48 pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
49 Self { coeffs }
50 }
51
52 pub fn add_assign(&mut self, rhs: &Self) {
54 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
55 self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
56 }
57 }
58
59 pub fn sub_assign(&mut self, rhs: &Self) {
61 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
62 self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
63 }
64 }
65
66 pub fn scalar_mul_assign(&mut self, k: i32) {
68 for c in &mut self.coeffs {
69 *c = reduce_element((*c as i64 * k as i64) as i32);
70 }
71 }
72
73 #[must_use]
75 pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
76 let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
77 let q = FIELD_MODULUS as i64;
78 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
79 for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
80 let k = i + j;
81 let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
82 if k < COEFFICIENTS_IN_RING_ELEMENT {
83 acc[k] += prod;
84 } else {
85 let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
86 acc[idx] -= prod;
87 }
88 }
89 }
90 let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
91 for (o, a) in out.iter_mut().zip(acc) {
92 let mut r = a % q;
93 if r < 0 {
94 r += q;
95 }
96 *o = reduce_element(r as i32);
97 }
98 Self { coeffs: out }
99 }
100
101 #[must_use]
103 pub fn infinity_norm(&self) -> i32 {
104 let half = FIELD_MODULUS / 2;
105 let mut m = 0i32;
106 for &c in &self.coeffs {
107 let v = if c > half { c - FIELD_MODULUS } else { c };
108 m = m.max(v.abs());
109 }
110 m
111 }
112
113 #[must_use]
115 pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
116 let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
117 for (i, lane) in s.iter_mut().enumerate() {
118 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
119 lane.values
120 .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
121 }
122 s
123 }
124
125 fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
126 let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
127 for (i, lane) in simd.iter().enumerate() {
128 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
129 coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
130 }
131 Self { coeffs }
132 }
133
134 #[must_use]
136 pub fn to_ntt(&self) -> NttPoly {
137 let mut simd = self.to_simd();
138 ntt_forward_simd(&mut simd);
139 NttPoly { simd }
140 }
141}
142
143#[derive(Clone, Debug, PartialEq, Eq, Hash)]
145pub struct NttPoly {
146 pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
147}
148
149impl NttPoly {
150 #[must_use]
152 pub fn zero() -> Self {
153 Self {
154 simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
155 }
156 }
157
158 #[must_use]
160 pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
161 let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
162 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
163 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
164 c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
165 }
166 c
167 }
168
169 #[must_use]
171 pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
172 &self.simd
173 }
174
175 pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
177 &mut self.simd
178 }
179
180 pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
182 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
183 ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
184 }
185 }
186
187 #[must_use]
189 pub fn to_poly(mut self) -> Poly {
190 intt_montgomery(&mut self.simd);
191 reduce_poly_simd(&mut self.simd);
192 Poly::from_simd(&self.simd)
193 }
194
195 pub fn add_assign(&mut self, rhs: &Self) {
197 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
198 add_coeffs(&mut self.simd[i], &rhs.simd[i]);
199 }
200 }
201
202 pub fn sub_assign(&mut self, rhs: &Self) {
204 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
205 subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
206 }
207 }
208}
209
210#[must_use]
212pub fn simd_from_i256(
213 buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
214) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
215 Poly::from_coeffs(*buf).to_simd()
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit linear congruential generator and returns the high 32 bits
    /// of the new state. Deterministic, so test inputs are reproducible.
    fn lcg_step(state: &mut u64) -> u32 {
        let next = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        *state = next;
        (next >> 32) as u32
    }

    /// Builds a polynomial whose coefficients are drawn from `[-bound, bound]`
    /// using `lcg_step`, one draw per coefficient in index order.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let span = (2 * bound + 1) as u32;
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        for slot in coeffs.iter_mut() {
            let draw = (lcg_step(state) % span) as i32;
            *slot = draw - bound;
        }
        Poly::from_coeffs(coeffs)
    }

    /// The NTT round trip should be linear. Measure its scale on the constant 1,
    /// then check that every random polynomial comes back multiplied by that scale.
    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        let mut unit = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        unit[0] = 1;
        let scale = i64::from(Poly::from_coeffs(unit).to_ntt().to_poly().coeffs[0]);

        let mut rng = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut rng, 8);
            let round_trip = p.to_ntt().to_poly();
            for (&orig, &got) in p.coeffs.iter().zip(round_trip.coeffs.iter()) {
                let expected = reduce_element((i64::from(orig) * scale) as i32);
                assert_eq!(expected, got);
            }
        }
    }

    /// NTT-domain pointwise multiplication must agree with the schoolbook
    /// negacyclic product.
    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut rng = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut rng, 8);
            let b = small_poly(&mut rng, 8);
            let expected = a.mul_negacyclic(&b);

            let mut product = a.to_ntt();
            product.pointwise_mul_assign(&b.to_ntt());
            assert_eq!(expected, product.to_poly());
        }
    }
}