1use subtle::{
4 Choice,
5 ConditionallySelectable,
6 ConstantTimeGreater,
7};
8use zeroize::{
9 Zeroize,
10 ZeroizeOnDrop,
11};
12
13use crate::coeff::{
14 COEFFICIENTS_IN_SIMD_UNIT,
15 Coefficients,
16 FieldElement,
17 SIMD_UNITS_IN_RING_ELEMENT,
18};
19use crate::constants::{
20 COEFFICIENTS_IN_RING_ELEMENT,
21 FIELD_MODULUS,
22};
23use crate::field::{
24 add_coeffs,
25 reduce_element,
26 reduce_poly_simd,
27 subtract_coeffs,
28};
29use crate::ntt::{
30 intt_montgomery,
31 ntt_forward_simd,
32 ntt_multiply_montgomery,
33};
34
35#[inline]
36fn ct_gt_i32(a: i32, b: i32) -> Choice {
37 let flip = 1u32 << 31;
38 let a_u = (a as u32) ^ flip;
39 let b_u = (b as u32) ^ flip;
40 a_u.ct_gt(&b_u)
41}
42
/// Branch-free absolute value used by `Poly::infinity_norm`.
///
/// `sign` is all-ones for a negative input and zero otherwise. The masked term is
/// therefore `2 * coefficient` exactly when the input is negative, giving
/// `c - 2c = -c`; non-negative inputs pass through unchanged.
///
/// Every step wraps, so the function is total. Inputs with magnitude above 2^30
/// no longer trip debug-build overflow checks, and release-mode results are
/// unchanged. `i32::MIN` has no positive counterpart and is returned as-is.
#[inline]
fn centered_abs_i32(coefficient: i32) -> i32 {
    let sign = coefficient >> 31;
    coefficient.wrapping_sub(sign & coefficient.wrapping_shl(1))
}
48
49#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
51pub struct Poly {
52 pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
54}
55
56impl Poly {
57 #[must_use]
59 pub const fn zero() -> Self {
60 Self {
61 coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
62 }
63 }
64
65 #[must_use]
67 pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
68 Self { coeffs }
69 }
70
71 pub fn add_assign(&mut self, rhs: &Self) {
73 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
74 self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
75 }
76 }
77
78 pub fn sub_assign(&mut self, rhs: &Self) {
80 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
81 self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
82 }
83 }
84
85 pub fn scalar_mul_assign(&mut self, k: i32) {
87 let q = FIELD_MODULUS as i64;
88 for c in &mut self.coeffs {
89 let wide = *c as i64 * k as i64;
90 *c = reduce_element(wide.rem_euclid(q) as i32);
91 }
92 }
93
94 #[must_use]
96 pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
97 let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
98 let q = FIELD_MODULUS as i64;
99 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
100 for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
101 let k = i + j;
102 let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
103 if k < COEFFICIENTS_IN_RING_ELEMENT {
104 acc[k] += prod;
105 } else {
106 let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
107 acc[idx] -= prod;
108 }
109 }
110 }
111 let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
112 for (o, a) in out.iter_mut().zip(acc) {
113 let mut r = a % q;
114 if r < 0 {
115 r += q;
116 }
117 *o = reduce_element(r as i32);
118 }
119 Self { coeffs: out }
120 }
121
122 #[must_use]
128 pub fn infinity_norm(&self) -> i32 {
129 let half = FIELD_MODULUS / 2;
130 let mut m = 0i32;
131 for &c in &self.coeffs {
132 let gt_half = ct_gt_i32(c, half);
133 let centered = i32::conditional_select(&c, &c.wrapping_sub(FIELD_MODULUS), gt_half);
134 let abs = centered_abs_i32(centered);
135 let gt_max = ct_gt_i32(abs, m);
136 m = i32::conditional_select(&m, &abs, gt_max);
137 }
138 m
139 }
140
141 #[must_use]
143 pub fn norm_within_bound(&self, bound: i32) -> Choice {
144 let exceeds = ct_gt_i32(self.infinity_norm(), bound);
145 exceeds ^ Choice::from(1u8)
146 }
147
148 pub fn normalize_mod_q_assign(&mut self) {
151 let q = FIELD_MODULUS;
152 for c in &mut self.coeffs {
153 *c = reduce_element(*c);
154 let sign = *c >> 31;
155 *c += sign & q;
156 }
157 }
158
159 #[must_use]
161 pub fn scalar_mul_by_u32_mod_q(&self, scalar: u32) -> Poly {
162 let q = FIELD_MODULUS as i64;
163 let r = (scalar % FIELD_MODULUS as u32) as i64;
164 let mut out = self.clone();
165 for c in &mut out.coeffs {
166 let v = (*c as i64 * r).rem_euclid(q) as i32;
167 *c = reduce_element(v);
168 }
169 out
170 }
171
172 #[must_use]
174 pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
175 let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
176 for (i, lane) in s.iter_mut().enumerate() {
177 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
178 lane.values
179 .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
180 }
181 s
182 }
183
184 fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
185 let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
186 for (i, lane) in simd.iter().enumerate() {
187 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
188 coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
189 }
190 Self { coeffs }
191 }
192
193 #[must_use]
195 pub fn to_ntt(&self) -> NttPoly {
196 let mut simd = self.to_simd();
197 ntt_forward_simd(&mut simd);
198 NttPoly { simd }
199 }
200}
201
202#[derive(Clone, Debug, PartialEq, Eq, Hash)]
204pub struct NttPoly {
205 pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
206}
207
208impl NttPoly {
209 #[must_use]
211 pub fn zero() -> Self {
212 Self {
213 simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
214 }
215 }
216
217 #[must_use]
219 pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
220 let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
221 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
222 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
223 c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
224 }
225 c
226 }
227
228 #[must_use]
230 pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
231 &self.simd
232 }
233
234 pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
236 &mut self.simd
237 }
238
239 pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
241 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
242 ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
243 }
244 }
245
246 #[must_use]
248 pub fn to_poly(mut self) -> Poly {
249 intt_montgomery(&mut self.simd);
250 reduce_poly_simd(&mut self.simd);
251 Poly::from_simd(&self.simd)
252 }
253
254 pub fn add_assign(&mut self, rhs: &Self) {
256 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
257 add_coeffs(&mut self.simd[i], &rhs.simd[i]);
258 }
259 }
260
261 pub fn sub_assign(&mut self, rhs: &Self) {
263 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
264 subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
265 }
266 }
267}
268
269#[must_use]
271pub fn simd_from_i256(
272 buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
273) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
274 Poly::from_coeffs(*buf).to_simd()
275}
276
277#[must_use]
279pub fn polys_norm_within_bound(polys: &[Poly], bound: i32) -> Choice {
280 let mut acc = Choice::from(1u8);
281 for p in polys {
282 acc &= p.norm_within_bound(bound);
283 }
284 acc
285}
286
// Unit tests: NTT round-trip scaling, NTT vs. schoolbook multiplication,
// constant-time infinity norm vs. a branching reference, and normalization smoke checks.
#[cfg(test)]
mod tests {
    use super::*;

    // Deterministic 64-bit LCG (multiplier 6364136223846793005, increment 1).
    // Returns the high 32 bits of the updated state. Tests depend on the exact
    // order in which this state is consumed.
    fn lcg_step(state: &mut u64) -> u32 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        (*state >> 32) as u32
    }

    // Builds a polynomial with every coefficient in [-bound, bound], drawn via
    // `% (2 * bound + 1)`. This is slightly biased, which is acceptable for tests.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        let width = (2 * bound + 1) as u32;
        for c in &mut coeffs {
            let v = (lcg_step(state) % width) as i32;
            *c = v - bound;
        }
        Poly::from_coeffs(coeffs)
    }

    // A forward + inverse NTT round trip is expected to scale every coefficient by the
    // same constant. The constant is measured from the unit polynomial `1`, then
    // random small polynomials are checked against `reduce_element(orig * scale)`.
    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        let mut one = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        one[0] = 1;
        let scale = Poly::from_coeffs(one).to_ntt().to_poly().coeffs[0];

        let mut st = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut st, 8);
            // NOTE(review): `to_ntt` takes `&self`, so this clone looks redundant.
            let back = p.clone().to_ntt().to_poly();
            for (orig, got) in p.coeffs.iter().zip(back.coeffs.iter()) {
                let expected = reduce_element((*orig as i64 * scale as i64) as i32);
                assert_eq!(expected, *got);
            }
        }
    }

    // The NTT path (forward, pointwise multiply, inverse) must reproduce the
    // schoolbook negacyclic product exactly, coefficient for coefficient.
    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut st = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut st, 8);
            let b = small_poly(&mut st, 8);
            let schoolbook = a.mul_negacyclic(&b);

            let mut ntt = a.to_ntt();
            let b_ntt = b.to_ntt();
            ntt.pointwise_mul_assign(&b_ntt);
            let back = ntt.to_poly();

            assert_eq!(schoolbook, back);
        }
    }

    // Straightforward branching reference for `Poly::infinity_norm`: subtract q from
    // values above q/2, then take the maximum absolute value.
    fn infinity_norm_branchy_reference(p: &Poly) -> i32 {
        let half = FIELD_MODULUS / 2;
        let mut m = 0i32;
        for &c in &p.coeffs {
            let v = if c > half { c - FIELD_MODULUS } else { c };
            m = m.max(v.abs());
        }
        m
    }

    // Random inputs come from `% q`, which yields values in (-q, q), so negative
    // coefficients are exercised as well. The edge loop also checks each edge value
    // alongside its negation.
    #[test]
    fn infinity_norm_matches_branchy_reference() {
        let q = FIELD_MODULUS;
        let mut st = 0xA11CE_u64;
        for _ in 0..256 {
            let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
            for c in &mut coeffs {
                *c = (lcg_step(&mut st) as i32) % q;
            }
            let p = Poly::from_coeffs(coeffs);
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
        for &edge in &[0, 1, q / 2, q / 2 + 1, q - 1] {
            let mut p = Poly::zero();
            p.coeffs[0] = edge;
            p.coeffs[1] = -edge;
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
    }

    // Checks that normalization maps a value above q and a small negative value into
    // [0, q), and that `scalar_mul_by_u32_mod_q` agrees with a direct
    // `reduce_element` for one coefficient.
    #[test]
    fn normalize_mod_q_and_scalar_mul_smoke() {
        let mut p = Poly::zero();
        p.coeffs[0] = FIELD_MODULUS + 5;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[0]));
        p.coeffs[1] = -3;
        p.normalize_mod_q_assign();
        assert!((0..FIELD_MODULUS).contains(&p.coeffs[1]));
        let scaled = p.scalar_mul_by_u32_mod_q(3);
        assert_eq!(scaled.coeffs[0], reduce_element(p.coeffs[0] * 3));
    }
}