1use crate::ntt64::arith::{mod_mul_barrett, Ntt64Arith};
29use crate::ntt64::context::{ntt_forward, ntt_inverse, Ntt64Context};
30use alloc::vec;
31use alloc::vec::Vec;
32#[cfg(feature = "rand")]
33use rand::Rng;
34#[cfg(feature = "rand")]
35use rand_distr::{Distribution, Normal};
36
/// A polynomial over `Z_q` with 64-bit coefficients, tagged with the domain it is stored in.
///
/// The modulus `q` is not stored here; every operation takes it from the `Ntt64Arith`
/// argument, so callers must use the same arithmetic context consistently.
///
/// Equality compares both the coefficient vector and the domain flag.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Poly64 {
    /// Coefficients (or NTT evaluations when `is_ntt` is set). Operations in this module
    /// assume each entry is already reduced into `[0, q)`.
    pub data: Vec<u64>,
    /// `true` after `forward_ntt` has been applied, `false` in the coefficient domain.
    pub is_ntt: bool,
}
52
53impl Poly64 {
54 #[inline]
60 pub fn new_zero(n: usize) -> Self {
61 Self {
62 data: vec![0u64; n],
63 is_ntt: false,
64 }
65 }
66
67 #[cfg(feature = "rand")]
71 pub fn new_random(n: usize, arith: &Ntt64Arith) -> Self {
72 let mut rng = rand::thread_rng();
73 let q = arith.modulus;
74 let data: Vec<u64> = (0..n).map(|_| rng.gen_range(0..q)).collect();
75 Self {
76 data,
77 is_ntt: false,
78 }
79 }
80
81 #[cfg(feature = "rand")]
88 pub fn new_ternary(n: usize, arith: &Ntt64Arith) -> Self {
89 let mut rng = rand::thread_rng();
90 let q = arith.modulus;
91 let data: Vec<u64> = (0..n)
92 .map(|_| match rng.gen_range(0u32..3) {
93 0 => 0,
94 1 => 1,
95 _ => q - 1,
96 })
97 .collect();
98 Self {
99 data,
100 is_ntt: false,
101 }
102 }
103
104 #[cfg(feature = "rand")]
111 pub fn new_gaussian(n: usize, sigma: f64, arith: &Ntt64Arith) -> Self {
112 let mut rng = rand::thread_rng();
113 let q = arith.modulus;
114 let normal = Normal::new(0.0, sigma).expect("sigma must be > 0");
115 let data: Vec<u64> = (0..n)
116 .map(|_| {
117 let sample: f64 = normal.sample(&mut rng);
118 let rounded = sample.round() as i64;
119 if rounded >= 0 {
120 (rounded as u64) % q
121 } else {
122 let abs_val = (-rounded) as u64;
123 let r = abs_val % q;
124 if r == 0 {
125 0
126 } else {
127 q - r
128 }
129 }
130 })
131 .collect();
132 Self {
133 data,
134 is_ntt: false,
135 }
136 }
137
138 pub fn forward_ntt(&mut self, ntt_ctx: &Ntt64Context) {
147 assert!(!self.is_ntt, "polynomial is already in NTT domain");
148 ntt_forward(&mut self.data, ntt_ctx);
149 self.is_ntt = true;
150 }
151
152 pub fn inverse_ntt(&mut self, ntt_ctx: &Ntt64Context) {
157 assert!(self.is_ntt, "polynomial is not in NTT domain");
158 ntt_inverse(&mut self.data, ntt_ctx);
159 self.is_ntt = false;
160 }
161
162 pub fn add_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
173 assert_eq!(
174 self.is_ntt, other.is_ntt,
175 "polynomials must be in the same domain"
176 );
177 assert_eq!(
178 self.data.len(),
179 other.data.len(),
180 "polynomials must have the same size"
181 );
182 let q = arith.modulus;
183 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
184 let sum = *a + b;
185 let (sub, borrow) = sum.overflowing_sub(q);
187 *a = if borrow { sum } else { sub };
188 }
189 }
190
191 pub fn sub_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
198 assert_eq!(
199 self.is_ntt, other.is_ntt,
200 "polynomials must be in the same domain"
201 );
202 assert_eq!(
203 self.data.len(),
204 other.data.len(),
205 "polynomials must have the same size"
206 );
207 let q = arith.modulus;
208 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
209 let (sub, borrow) = (*a).overflowing_sub(b);
210 *a = if borrow { sub.wrapping_add(q) } else { sub };
211 }
212 }
213
214 pub fn mul_assign(&mut self, other: &Poly64, arith: &Ntt64Arith) {
222 assert!(
223 self.is_ntt && other.is_ntt,
224 "both polynomials must be in NTT domain for multiplication"
225 );
226 assert_eq!(
227 self.data.len(),
228 other.data.len(),
229 "polynomials must have the same size"
230 );
231 for (a, &b) in self.data.iter_mut().zip(other.data.iter()) {
232 *a = mod_mul_barrett(*a, b, arith);
233 }
234 }
235
236 pub fn scalar_mul(&mut self, scalar: u64, arith: &Ntt64Arith) {
238 for a in self.data.iter_mut() {
239 *a = mod_mul_barrett(*a, scalar, arith);
240 }
241 }
242
243 pub fn negate(&mut self, arith: &Ntt64Arith) {
245 let q = arith.modulus;
246 for a in self.data.iter_mut() {
247 *a = if *a == 0 { 0 } else { q - *a };
251 }
252 }
253
254 #[inline]
260 pub fn len(&self) -> usize {
261 self.data.len()
262 }
263
264 #[inline]
266 pub fn is_empty(&self) -> bool {
267 self.data.is_empty()
268 }
269}
270
/// Schoolbook reference product of `a` and `b` in `Z_q[X] / (X^n + 1)`, with `n = a.len()`.
///
/// Terms whose degree reaches `n` wrap around with a sign flip (`X^n = -1`). Every output
/// coefficient lies in `[0, q)`. Intermediate products are formed in `u128`, so they never
/// overflow.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths, or (for non-empty input) if `q == 0`.
#[cfg(test)]
fn naive_poly_mul(a: &[u64], b: &[u64], q: u64) -> Vec<u64> {
    let n = a.len();
    assert_eq!(n, b.len());
    let modulus = u128::from(q);
    let mut acc = vec![0u64; n];

    for (i, &ai) in a.iter().enumerate() {
        for (j, &bj) in b.iter().enumerate() {
            let term = (u128::from(ai) * u128::from(bj)) % modulus;
            let degree = i + j;
            // Below n the term adds directly; at or above n it wraps and is subtracted.
            let (slot, contribution) = if degree < n {
                (degree, term)
            } else {
                (degree - n, (modulus - term) % modulus)
            };
            acc[slot] = ((u128::from(acc[slot]) + contribution) % modulus) as u64;
        }
    }
    acc
}
301
/// Unit tests for `Poly64`.
///
/// Tests that sample random polynomials need the optional `rand` feature (the samplers are
/// gated behind it), so they are compiled only when it is enabled. The remaining tests use
/// deterministic inputs and always run.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ntt64::arith::Ntt64Arith;
    use crate::ntt64::context::Ntt64Context;

    /// Test modulus: 7681 is prime and 7681 - 1 = 2^9 * 15, so q ≡ 1 (mod 2 * TEST_N).
    const TEST_Q: u64 = 7681;
    /// Ring dimension used by the NTT-based tests.
    const TEST_N: usize = 256;

    fn test_arith() -> Ntt64Arith {
        Ntt64Arith::new(TEST_Q)
    }

    fn test_ntt_ctx() -> Ntt64Context {
        Ntt64Context::new(TEST_N, test_arith())
    }

    /// Builds a coefficient-domain polynomial from explicit coefficients.
    fn poly_from(coeffs: &[u64]) -> Poly64 {
        let mut poly = Poly64::new_zero(coeffs.len());
        poly.data.copy_from_slice(coeffs);
        poly
    }

    #[test]
    fn test_add_sub_negate_boundaries() {
        let arith = test_arith();
        let q = TEST_Q;
        let a = poly_from(&[0, 1, q - 1, q - 1]);
        let b = poly_from(&[0, q - 1, 1, q - 1]);

        let mut sum = a.clone();
        sum.add_assign(&b, &arith);
        assert_eq!(sum.data, vec![0, 0, 0, q - 2]);

        let mut diff = a.clone();
        diff.sub_assign(&b, &arith);
        assert_eq!(diff.data, vec![0, 2, q - 2, 0]);

        let mut neg = a.clone();
        neg.negate(&arith);
        assert_eq!(neg.data, vec![0, q - 1, 1, 1]);

        let mut doubled = a.clone();
        doubled.scalar_mul(2, &arith);
        assert_eq!(doubled.data, vec![0, 2, q - 2, q - 2]);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_sub() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut c = a.clone();
        c.add_assign(&b, &arith);
        c.sub_assign(&b, &arith);

        assert_eq!(c.data, a.data, "add/sub roundtrip failed");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_add_commutative() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);
        let b = Poly64::new_random(TEST_N, &arith);

        let mut ab = a.clone();
        ab.add_assign(&b, &arith);

        let mut ba = b.clone();
        ba.add_assign(&a, &arith);

        assert_eq!(ab.data, ba.data, "add is not commutative");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_negate() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut neg_a = a.clone();
        neg_a.negate(&arith);

        let mut sum = a.clone();
        sum.add_assign(&neg_a, &arith);

        assert!(sum.data.iter().all(|&c| c == 0), "a + (-a) != 0");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_scalar_mul() {
        let arith = test_arith();
        let a = Poly64::new_random(TEST_N, &arith);

        let mut doubled = a.clone();
        doubled.scalar_mul(2, &arith);

        let mut sum = a.clone();
        sum.add_assign(&a, &arith);

        assert_eq!(doubled.data, sum.data, "2*a != a+a");
    }

    #[test]
    fn test_poly_mul_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        // (1 + x) * (1 + x^2)
        let mut a = Poly64::new_zero(TEST_N);
        a.data[0] = 1;
        a.data[1] = 1;

        let mut b = Poly64::new_zero(TEST_N);
        b.data[0] = 1;
        b.data[2] = 1;

        let expected = naive_poly_mul(&a.data, &b.data, TEST_Q);

        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_eq!(a.data, expected, "NTT mul != naive");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_poly_mul_random_ntt() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();

        let a_orig = Poly64::new_random(TEST_N, &arith);
        let b_orig = Poly64::new_random(TEST_N, &arith);

        let expected = naive_poly_mul(&a_orig.data, &b_orig.data, TEST_Q);

        let mut a = a_orig.clone();
        let mut b = b_orig.clone();
        a.forward_ntt(&ntt_ctx);
        b.forward_ntt(&ntt_ctx);
        a.mul_assign(&b, &arith);
        a.inverse_ntt(&ntt_ctx);

        assert_eq!(a.data, expected, "NTT mul != naive");
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ternary_distribution() {
        let arith = test_arith();
        let poly = Poly64::new_ternary(1024, &arith);

        for (i, &coeff) in poly.data.iter().enumerate() {
            assert!(
                coeff == 0 || coeff == 1 || coeff == TEST_Q - 1,
                "invalid ternary coefficient at index {i}: {coeff}"
            );
        }

        // With 1024 uniform draws, each value is missing with probability (2/3)^1024.
        let count_of = |v: u64| poly.data.iter().filter(|&&c| c == v).count();
        assert!(count_of(0) > 0);
        assert!(count_of(1) > 0);
        assert!(count_of(TEST_Q - 1) > 0);
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_gaussian_distribution() {
        let arith = test_arith();
        let sigma = 3.2;
        let n = 8192;
        let poly = Poly64::new_gaussian(n, sigma, &arith);

        // Lift residues back to the centered range (-q/2, q/2].
        let q = TEST_Q as f64;
        let half_q = q / 2.0;
        let centered: Vec<f64> = poly
            .data
            .iter()
            .map(|&c| {
                let c = c as f64;
                if c > half_q {
                    c - q
                } else {
                    c
                }
            })
            .collect();

        let mean = centered.iter().sum::<f64>() / n as f64;
        assert!(mean.abs() < 0.5, "mean too far from 0: {mean}");

        let variance = centered.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        let std_dev = variance.sqrt();
        assert!(
            (std_dev - sigma).abs() < 1.0,
            "stddev too far from {sigma}: {std_dev}"
        );
    }

    #[cfg(feature = "rand")]
    #[test]
    fn test_ntt_roundtrip() {
        let arith = test_arith();
        let ntt_ctx = test_ntt_ctx();
        let original = Poly64::new_random(TEST_N, &arith);

        let mut poly = original.clone();
        poly.forward_ntt(&ntt_ctx);
        assert!(poly.is_ntt);
        poly.inverse_ntt(&ntt_ctx);
        assert!(!poly.is_ntt);

        assert_eq!(poly.data, original.data, "NTT roundtrip failed");
    }

    #[test]
    fn test_new_zero() {
        let poly = Poly64::new_zero(64);
        assert_eq!(poly.len(), 64);
        assert!(!poly.is_empty());
        assert!(!poly.is_ntt);
        assert!(poly.data.iter().all(|&c| c == 0));
        assert!(Poly64::new_zero(0).is_empty());
    }

    #[test]
    #[should_panic(expected = "already in NTT domain")]
    fn test_double_forward_ntt_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.forward_ntt(&ntt_ctx);
        poly.forward_ntt(&ntt_ctx);
    }

    #[test]
    #[should_panic(expected = "not in NTT domain")]
    fn test_inverse_ntt_without_forward_panics() {
        let ntt_ctx = test_ntt_ctx();
        let mut poly = Poly64::new_zero(TEST_N);
        poly.inverse_ntt(&ntt_ctx);
    }
}