1use subtle::{
4 Choice,
5 ConditionallySelectable,
6 ConstantTimeGreater,
7};
8use zeroize::{
9 Zeroize,
10 ZeroizeOnDrop,
11};
12
13use crate::coeff::{
14 COEFFICIENTS_IN_SIMD_UNIT,
15 Coefficients,
16 FieldElement,
17 SIMD_UNITS_IN_RING_ELEMENT,
18};
19use crate::constants::{
20 COEFFICIENTS_IN_RING_ELEMENT,
21 FIELD_MODULUS,
22};
23use crate::field::{
24 add_coeffs,
25 reduce_element,
26 reduce_poly_simd,
27 subtract_coeffs,
28};
29use crate::ntt::{
30 intt_montgomery,
31 ntt_forward_simd,
32 ntt_multiply_montgomery,
33};
34
35#[inline]
36fn ct_gt_i32(a: i32, b: i32) -> Choice {
37 let flip = 1u32 << 31;
38 let a_u = (a as u32) ^ flip;
39 let b_u = (b as u32) ^ flip;
40 a_u.ct_gt(&b_u)
41}
42
/// Branch-free absolute value of a signed coefficient.
///
/// `sign` is all ones for negative inputs and zero otherwise. Therefore `(x ^ sign) - sign`
/// is `-x` for negative `x` and `x` otherwise, with no data-dependent branch.
/// Wrapping subtraction keeps the function total over all of `i32`: there is no
/// debug-mode overflow panic for large-magnitude negative inputs. Like
/// `i32::wrapping_abs`, `i32::MIN` maps to itself.
#[inline]
fn centered_abs_i32(coefficient: i32) -> i32 {
    let sign = coefficient >> 31;
    (coefficient ^ sign).wrapping_sub(sign)
}
48
49#[derive(Clone, Debug, Eq, PartialEq, Hash, Zeroize, ZeroizeOnDrop)]
51pub struct Poly {
52 pub coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
54}
55
56impl Poly {
57 #[must_use]
59 pub const fn zero() -> Self {
60 Self {
61 coeffs: [0; COEFFICIENTS_IN_RING_ELEMENT],
62 }
63 }
64
65 #[must_use]
67 pub const fn from_coeffs(coeffs: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT]) -> Self {
68 Self { coeffs }
69 }
70
71 pub fn add_assign(&mut self, rhs: &Self) {
73 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
74 self.coeffs[i] = reduce_element(self.coeffs[i] + rhs.coeffs[i]);
75 }
76 }
77
78 pub fn sub_assign(&mut self, rhs: &Self) {
80 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
81 self.coeffs[i] = reduce_element(self.coeffs[i] - rhs.coeffs[i]);
82 }
83 }
84
85 pub fn scalar_mul_assign(&mut self, k: i32) {
87 for c in &mut self.coeffs {
88 *c = reduce_element((*c as i64 * k as i64) as i32);
89 }
90 }
91
92 #[must_use]
94 pub fn mul_negacyclic(&self, rhs: &Self) -> Self {
95 let mut acc = [0i64; COEFFICIENTS_IN_RING_ELEMENT];
96 let q = FIELD_MODULUS as i64;
97 for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
98 for j in 0..COEFFICIENTS_IN_RING_ELEMENT {
99 let k = i + j;
100 let prod = (self.coeffs[i] as i64).wrapping_mul(rhs.coeffs[j] as i64);
101 if k < COEFFICIENTS_IN_RING_ELEMENT {
102 acc[k] += prod;
103 } else {
104 let idx = k - COEFFICIENTS_IN_RING_ELEMENT;
105 acc[idx] -= prod;
106 }
107 }
108 }
109 let mut out = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
110 for (o, a) in out.iter_mut().zip(acc) {
111 let mut r = a % q;
112 if r < 0 {
113 r += q;
114 }
115 *o = reduce_element(r as i32);
116 }
117 Self { coeffs: out }
118 }
119
120 #[must_use]
126 pub fn infinity_norm(&self) -> i32 {
127 let half = FIELD_MODULUS / 2;
128 let mut m = 0i32;
129 for &c in &self.coeffs {
130 let gt_half = ct_gt_i32(c, half);
131 let centered = i32::conditional_select(&c, &c.wrapping_sub(FIELD_MODULUS), gt_half);
132 let abs = centered_abs_i32(centered);
133 let gt_max = ct_gt_i32(abs, m);
134 m = i32::conditional_select(&m, &abs, gt_max);
135 }
136 m
137 }
138
139 #[must_use]
141 pub fn norm_within_bound(&self, bound: i32) -> Choice {
142 let exceeds = ct_gt_i32(self.infinity_norm(), bound);
143 exceeds ^ Choice::from(1u8)
144 }
145
146 pub fn normalize_mod_q_assign(&mut self) {
149 let q = FIELD_MODULUS;
150 for c in &mut self.coeffs {
151 *c = reduce_element(*c);
152 let sign = *c >> 31;
153 *c += sign & q;
154 }
155 }
156
157 #[must_use]
159 pub fn scalar_mul_by_u32_mod_q(&self, scalar: u32) -> Poly {
160 let q = FIELD_MODULUS as i64;
161 let r = (scalar % FIELD_MODULUS as u32) as i64;
162 let mut out = self.clone();
163 for c in &mut out.coeffs {
164 let v = (*c as i64 * r).rem_euclid(q) as i32;
165 *c = reduce_element(v);
166 }
167 out
168 }
169
170 #[must_use]
172 pub fn to_simd(&self) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
173 let mut s = [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT];
174 for (i, lane) in s.iter_mut().enumerate() {
175 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
176 lane.values
177 .copy_from_slice(&self.coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT]);
178 }
179 s
180 }
181
182 fn from_simd(simd: &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT]) -> Self {
183 let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
184 for (i, lane) in simd.iter().enumerate() {
185 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
186 coeffs[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&lane.values);
187 }
188 Self { coeffs }
189 }
190
191 #[must_use]
193 pub fn to_ntt(&self) -> NttPoly {
194 let mut simd = self.to_simd();
195 ntt_forward_simd(&mut simd);
196 NttPoly { simd }
197 }
198}
199
200#[derive(Clone, Debug, PartialEq, Eq, Hash)]
202pub struct NttPoly {
203 pub(crate) simd: [Coefficients; SIMD_UNITS_IN_RING_ELEMENT],
204}
205
206impl NttPoly {
207 #[must_use]
209 pub fn zero() -> Self {
210 Self {
211 simd: [Coefficients::default(); SIMD_UNITS_IN_RING_ELEMENT],
212 }
213 }
214
215 #[must_use]
217 pub fn packed_ntt_coefficients(&self) -> [FieldElement; COEFFICIENTS_IN_RING_ELEMENT] {
218 let mut c = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
219 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
220 let base = i * COEFFICIENTS_IN_SIMD_UNIT;
221 c[base..base + COEFFICIENTS_IN_SIMD_UNIT].copy_from_slice(&self.simd[i].values);
222 }
223 c
224 }
225
226 #[must_use]
228 pub fn as_simd(&self) -> &[Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
229 &self.simd
230 }
231
232 pub fn as_simd_mut(&mut self) -> &mut [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
234 &mut self.simd
235 }
236
237 pub fn pointwise_mul_assign(&mut self, rhs: &Self) {
239 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
240 ntt_multiply_montgomery(&mut self.simd[i], &rhs.simd[i]);
241 }
242 }
243
244 #[must_use]
246 pub fn to_poly(mut self) -> Poly {
247 intt_montgomery(&mut self.simd);
248 reduce_poly_simd(&mut self.simd);
249 Poly::from_simd(&self.simd)
250 }
251
252 pub fn add_assign(&mut self, rhs: &Self) {
254 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
255 add_coeffs(&mut self.simd[i], &rhs.simd[i]);
256 }
257 }
258
259 pub fn sub_assign(&mut self, rhs: &Self) {
261 for i in 0..SIMD_UNITS_IN_RING_ELEMENT {
262 subtract_coeffs(&mut self.simd[i], &rhs.simd[i]);
263 }
264 }
265}
266
267#[must_use]
269pub fn simd_from_i256(
270 buf: &[i32; COEFFICIENTS_IN_RING_ELEMENT],
271) -> [Coefficients; SIMD_UNITS_IN_RING_ELEMENT] {
272 Poly::from_coeffs(*buf).to_simd()
273}
274
275#[must_use]
277pub fn polys_norm_within_bound(polys: &[Poly], bound: i32) -> Choice {
278 let mut acc = Choice::from(1u8);
279 for p in polys {
280 acc &= p.norm_within_bound(bound);
281 }
282 acc
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances a 64-bit LCG and returns the high 32 bits of the new state.
    fn lcg_step(state: &mut u64) -> u32 {
        let next = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        *state = next;
        (next >> 32) as u32
    }

    /// Draws a polynomial with coefficients uniform-ish in `[-bound, bound]`.
    fn small_poly(state: &mut u64, bound: i32) -> Poly {
        let span = (2 * bound + 1) as u32;
        let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
        coeffs
            .iter_mut()
            .for_each(|slot| *slot = (lcg_step(state) % span) as i32 - bound);
        Poly::from_coeffs(coeffs)
    }

    #[test]
    fn ntt_inverse_has_expected_linear_scale() {
        // The round trip of the constant polynomial 1 reveals the scale factor.
        let mut unit = Poly::zero();
        unit.coeffs[0] = 1;
        let scale = unit.to_ntt().to_poly().coeffs[0];

        let mut rng = 0xC0DEC0DE_u64;
        for _ in 0..16 {
            let p = small_poly(&mut rng, 8);
            let round_trip = p.to_ntt().to_poly();
            for (&orig, &got) in p.coeffs.iter().zip(round_trip.coeffs.iter()) {
                let expected = reduce_element((i64::from(orig) * i64::from(scale)) as i32);
                assert_eq!(expected, got);
            }
        }
    }

    #[test]
    fn ntt_pointwise_matches_schoolbook_for_small_coeffs() {
        let mut rng = 0xDEADBEEF_u64;
        for _ in 0..4 {
            let a = small_poly(&mut rng, 8);
            let b = small_poly(&mut rng, 8);
            let expected = a.mul_negacyclic(&b);

            let mut product = a.to_ntt();
            product.pointwise_mul_assign(&b.to_ntt());
            assert_eq!(expected, product.to_poly());
        }
    }

    /// Straightforward (branching) version of `Poly::infinity_norm` used as an oracle.
    fn infinity_norm_branchy_reference(p: &Poly) -> i32 {
        let half = FIELD_MODULUS / 2;
        p.coeffs
            .iter()
            .map(|&c| if c > half { c - FIELD_MODULUS } else { c })
            .map(i32::abs)
            .fold(0, i32::max)
    }

    #[test]
    fn infinity_norm_matches_branchy_reference() {
        let q = FIELD_MODULUS;
        let mut rng = 0xA11CE_u64;
        for _ in 0..256 {
            let mut coeffs = [0i32; COEFFICIENTS_IN_RING_ELEMENT];
            coeffs
                .iter_mut()
                .for_each(|c| *c = (lcg_step(&mut rng) as i32) % q);
            let p = Poly::from_coeffs(coeffs);
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
        for edge in [0, 1, q / 2, q / 2 + 1, q - 1] {
            let mut p = Poly::zero();
            p.coeffs[0] = edge;
            p.coeffs[1] = -edge;
            assert_eq!(p.infinity_norm(), infinity_norm_branchy_reference(&p));
        }
    }

    #[test]
    fn normalize_mod_q_and_scalar_mul_smoke() {
        let canonical = 0..FIELD_MODULUS;
        let mut p = Poly::zero();

        p.coeffs[0] = FIELD_MODULUS + 5;
        p.normalize_mod_q_assign();
        assert!(canonical.contains(&p.coeffs[0]));

        p.coeffs[1] = -3;
        p.normalize_mod_q_assign();
        assert!(canonical.contains(&p.coeffs[1]));

        let tripled = p.scalar_mul_by_u32_mod_q(3);
        assert_eq!(tripled.coeffs[0], reduce_element(p.coeffs[0] * 3));
    }
}