1use super::{Poly, Representation};
4use crate::{Error, Result};
5use itertools::{izip, Itertools};
6use ndarray::Array2;
7use num_bigint::BigUint;
8use std::{
9 cmp::min,
10 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12use zeroize::Zeroize;
13
14impl AddAssign<&Poly> for Poly {
15 fn add_assign(&mut self, p: &Poly) {
16 assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
17 assert_ne!(
18 self.representation,
19 Representation::NttShoup,
20 "Cannot add to a polynomial in NttShoup representation"
21 );
22 assert_eq!(
23 self.representation, p.representation,
24 "Incompatible representations"
25 );
26 debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
27 self.allow_variable_time_computations |= p.allow_variable_time_computations;
28 if self.allow_variable_time_computations {
29 izip!(
30 self.coefficients.outer_iter_mut(),
31 p.coefficients.outer_iter(),
32 self.ctx.q.iter()
33 )
34 .for_each(|(mut v1, v2, qi)| unsafe {
35 qi.add_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
36 });
37 } else {
38 izip!(
39 self.coefficients.outer_iter_mut(),
40 p.coefficients.outer_iter(),
41 self.ctx.q.iter()
42 )
43 .for_each(|(mut v1, v2, qi)| {
44 qi.add_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
45 });
46 }
47 }
48}
49
50impl Add<&Poly> for &Poly {
51 type Output = Poly;
52 fn add(self, p: &Poly) -> Poly {
53 let mut q = self.clone();
54 q += p;
55 q
56 }
57}
58
59impl Add for Poly {
60 type Output = Poly;
61 fn add(self, mut p: Poly) -> Poly {
62 p += &self;
63 p
64 }
65}
66
67impl SubAssign<&Poly> for Poly {
68 fn sub_assign(&mut self, p: &Poly) {
69 assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
70 assert_ne!(
71 self.representation,
72 Representation::NttShoup,
73 "Cannot subtract from a polynomial in NttShoup representation"
74 );
75 assert_eq!(
76 self.representation, p.representation,
77 "Incompatible representations"
78 );
79 debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
80 self.allow_variable_time_computations |= p.allow_variable_time_computations;
81 if self.allow_variable_time_computations {
82 izip!(
83 self.coefficients.outer_iter_mut(),
84 p.coefficients.outer_iter(),
85 self.ctx.q.iter()
86 )
87 .for_each(|(mut v1, v2, qi)| unsafe {
88 qi.sub_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
89 });
90 } else {
91 izip!(
92 self.coefficients.outer_iter_mut(),
93 p.coefficients.outer_iter(),
94 self.ctx.q.iter()
95 )
96 .for_each(|(mut v1, v2, qi)| {
97 qi.sub_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
98 });
99 }
100 }
101}
102
103impl Sub<&Poly> for &Poly {
104 type Output = Poly;
105 fn sub(self, p: &Poly) -> Poly {
106 let mut q = self.clone();
107 q -= p;
108 q
109 }
110}
111
112impl MulAssign<&Poly> for Poly {
113 fn mul_assign(&mut self, p: &Poly) {
114 assert!(!p.has_lazy_coefficients);
115 assert_ne!(
116 self.representation,
117 Representation::NttShoup,
118 "Cannot multiply to a polynomial in NttShoup representation"
119 );
120 if self.has_lazy_coefficients && self.representation == Representation::Ntt {
121 assert!(
122 p.representation == Representation::NttShoup,
123 "Can only multiply a polynomial with lazy coefficients by an NttShoup representation."
124 );
125 } else {
126 assert_eq!(
127 self.representation,
128 Representation::Ntt,
129 "Multiplication requires an Ntt representation."
130 );
131 }
132 debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
133 self.allow_variable_time_computations |= p.allow_variable_time_computations;
134
135 match p.representation {
136 Representation::Ntt => {
137 if self.allow_variable_time_computations {
138 unsafe {
139 izip!(
140 self.coefficients.outer_iter_mut(),
141 p.coefficients.outer_iter(),
142 self.ctx.q.iter()
143 )
144 .for_each(|(mut v1, v2, qi)| {
145 qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap());
146 });
147 }
148 } else {
149 izip!(
150 self.coefficients.outer_iter_mut(),
151 p.coefficients.outer_iter(),
152 self.ctx.q.iter()
153 )
154 .for_each(|(mut v1, v2, qi)| {
155 qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
156 });
157 }
158 }
159 Representation::NttShoup => {
160 if self.allow_variable_time_computations {
161 izip!(
162 self.coefficients.outer_iter_mut(),
163 p.coefficients.outer_iter(),
164 p.coefficients_shoup.as_ref().unwrap().outer_iter(),
165 self.ctx.q.iter()
166 )
167 .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe {
168 qi.mul_shoup_vec_vt(
169 v1.as_slice_mut().unwrap(),
170 v2.as_slice().unwrap(),
171 v2_shoup.as_slice().unwrap(),
172 )
173 });
174 } else {
175 izip!(
176 self.coefficients.outer_iter_mut(),
177 p.coefficients.outer_iter(),
178 p.coefficients_shoup.as_ref().unwrap().outer_iter(),
179 self.ctx.q.iter()
180 )
181 .for_each(|(mut v1, v2, v2_shoup, qi)| {
182 qi.mul_shoup_vec(
183 v1.as_slice_mut().unwrap(),
184 v2.as_slice().unwrap(),
185 v2_shoup.as_slice().unwrap(),
186 )
187 });
188 }
189 self.has_lazy_coefficients = false
190 }
191 _ => {
192 panic!("Multiplication requires a multipliand in Ntt or NttShoup representation.")
193 }
194 }
195 }
196}
197
198impl MulAssign<&BigUint> for Poly {
199 fn mul_assign(&mut self, p: &BigUint) {
200 assert_ne!(
201 self.representation,
202 Representation::NttShoup,
203 "Cannot multiply a polynomial in NttShoup representation by a scalar"
204 );
205
206 let scalar_crt = self.ctx.rns.project(p);
208
209 if self.allow_variable_time_computations {
210 unsafe {
211 izip!(
212 self.coefficients.outer_iter_mut(),
213 scalar_crt.iter(),
214 self.ctx.q.iter()
215 )
216 .for_each(|(mut v1, scalar_qi, qi)| {
217 qi.scalar_mul_vec_vt(v1.as_slice_mut().unwrap(), *scalar_qi)
218 });
219 }
220 } else {
221 izip!(
222 self.coefficients.outer_iter_mut(),
223 scalar_crt.iter(),
224 self.ctx.q.iter()
225 )
226 .for_each(|(mut v1, scalar_qi, qi)| {
227 qi.scalar_mul_vec(v1.as_slice_mut().unwrap(), *scalar_qi)
228 });
229 }
230 }
231}
232
233impl Mul<&Poly> for &Poly {
234 type Output = Poly;
235 fn mul(self, p: &Poly) -> Poly {
236 match self.representation {
237 Representation::NttShoup => {
238 let mut q = p.clone();
240 if q.representation == Representation::NttShoup {
241 q.coefficients_shoup
242 .as_mut()
243 .unwrap()
244 .as_slice_mut()
245 .unwrap()
246 .zeroize();
247 unsafe { q.override_representation(Representation::Ntt) }
248 }
249 q *= self;
250 q
251 }
252 _ => {
253 let mut q = self.clone();
254 q *= p;
255 q
256 }
257 }
258 }
259}
260
261impl Mul<&BigUint> for &Poly {
262 type Output = Poly;
263 fn mul(self, p: &BigUint) -> Poly {
264 let mut q = self.clone();
265 q *= p;
266 q
267 }
268}
269
270impl Mul<&Poly> for &BigUint {
271 type Output = Poly;
272 fn mul(self, p: &Poly) -> Poly {
273 p * self
274 }
275}
276
277impl Neg for &Poly {
278 type Output = Poly;
279
280 fn neg(self) -> Poly {
281 assert!(!self.has_lazy_coefficients);
282 let mut out = self.clone();
283 if self.allow_variable_time_computations {
284 izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter())
285 .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) });
286 } else {
287 izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter())
288 .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap()));
289 }
290 out
291 }
292}
293
294impl Neg for Poly {
295 type Output = Poly;
296
297 fn neg(mut self) -> Poly {
298 assert!(!self.has_lazy_coefficients);
299 if self.allow_variable_time_computations {
300 izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter())
301 .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) });
302 } else {
303 izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter())
304 .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap()));
305 }
306 self
307 }
308}
309
/// Fused multiply-accumulate over 128-bit accumulators:
/// `out[k] += x[k] * y[k]` for every index `k`.
///
/// Each `u64 * u64` product is widened to `u128` before it is added. Callers
/// are responsible for ensuring the accumulated sums cannot overflow `u128`.
///
/// # Safety
/// The body performs no unchecked accesses. The function stays `unsafe` only
/// so existing call sites keep compiling unchanged.
///
/// # Panics
/// Panics if `x` or `y` does not have the same length as `out`.
unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
    let len = out.len();
    assert_eq!(x.len(), len);
    assert_eq!(y.len(), len);

    /// Accumulates the element-wise products of `lhs` and `rhs` into `acc`.
    fn accumulate(acc: &mut [u128], lhs: &[u64], rhs: &[u64]) {
        for ((a, &l), &r) in acc.iter_mut().zip(lhs).zip(rhs) {
            *a += u128::from(l) * u128::from(r);
        }
    }

    // Work in fixed blocks of 16 lanes (exact-size chunks let the optimizer
    // drop bounds checks and unroll), then finish the tail.
    const LANES: usize = 16;
    let mut out_blocks = out.chunks_exact_mut(LANES);
    let mut x_blocks = x.chunks_exact(LANES);
    let mut y_blocks = y.chunks_exact(LANES);
    for ((o, a), b) in (&mut out_blocks).zip(&mut x_blocks).zip(&mut y_blocks) {
        accumulate(o, a, b);
    }
    accumulate(
        out_blocks.into_remainder(),
        x_blocks.remainder(),
        y_blocks.remainder(),
    );
}
347
348pub fn dot_product<'a, 'b, I, J>(p: I, q: J) -> Result<Poly>
352where
353 I: Iterator<Item = &'a Poly> + Clone,
354 J: Iterator<Item = &'b Poly> + Clone,
355{
356 debug_assert!(!p
357 .clone()
358 .any(|pi| pi.representation == Representation::PowerBasis));
359 debug_assert!(!q
360 .clone()
361 .any(|qi| qi.representation == Representation::PowerBasis));
362
363 let count = min(p.clone().count(), q.clone().count());
364 if count == 0 {
365 return Err(Error::Default("At least one iterator is empty".to_string()));
366 }
367
368 let p_first = p.clone().next().unwrap();
369
370 let mut acc: Array2<u128> = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree));
372 let acc_ptr = acc.as_mut_ptr();
373
374 let mut num_acc = vec![1u128; p_first.ctx.q.len()];
376 let num_acc_ptr = num_acc.as_mut_ptr();
377
378 let max_acc = p_first
380 .ctx
381 .q
382 .iter()
383 .map(|qi| 1u128 << (2 * (*qi).leading_zeros()))
384 .collect_vec();
385 let max_acc_ptr = max_acc.as_ptr();
386
387 let q_ptr = p_first.ctx.q.as_ptr();
388 let degree = p_first.ctx.degree as isize;
389
390 let min_of_max = max_acc.iter().min().unwrap();
391
392 let out_slice = acc.as_slice_mut().unwrap();
393 if count as u128 > *min_of_max {
394 for (pi, qi) in izip!(p, q) {
395 let pij = pi.coefficients();
396 let qij = qi.coefficients();
397 let pi_slice = pij.as_slice().unwrap();
398 let qi_slice = qij.as_slice().unwrap();
399 unsafe {
400 fma(out_slice, pi_slice, qi_slice);
401
402 for j in 0..p_first.ctx.q.len() as isize {
403 let qj = &*q_ptr.offset(j);
404 *num_acc_ptr.offset(j) += 1;
405 if *num_acc_ptr.offset(j) == *max_acc_ptr.offset(j) {
406 if p_first.allow_variable_time_computations {
407 for i in j * degree..(j + 1) * degree {
408 *acc_ptr.offset(i) = qj.reduce_u128_vt(*acc_ptr.offset(i)) as u128;
409 }
410 } else {
411 for i in j * degree..(j + 1) * degree {
412 *acc_ptr.offset(i) = qj.reduce_u128(*acc_ptr.offset(i)) as u128;
413 }
414 }
415 *num_acc_ptr.offset(j) = 1;
416 }
417 }
418 }
419 }
420 } else {
421 for (pi, qi) in izip!(p, q) {
424 let pij = pi.coefficients();
425 let qij = qi.coefficients();
426 let pi_slice = pij.as_slice().unwrap();
427 let qi_slice = qij.as_slice().unwrap();
428 unsafe { fma(out_slice, pi_slice, qi_slice) }
429 }
430 }
431 let mut coeffs: Array2<u64> = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree));
433 izip!(
434 coeffs.outer_iter_mut(),
435 acc.outer_iter(),
436 p_first.ctx.q.iter()
437 )
438 .for_each(|(mut coeffsj, accj, m)| {
439 if p_first.allow_variable_time_computations {
440 izip!(coeffsj.iter_mut(), accj.iter())
441 .for_each(|(cj, accjk)| *cj = unsafe { m.reduce_u128_vt(*accjk) });
442 } else {
443 izip!(coeffsj.iter_mut(), accj.iter())
444 .for_each(|(cj, accjk)| *cj = m.reduce_u128(*accjk));
445 }
446 });
447
448 Ok(Poly {
449 ctx: p_first.ctx.clone(),
450 representation: Representation::Ntt,
451 allow_variable_time_computations: p_first.allow_variable_time_computations,
452 coefficients: coeffs,
453 coefficients_shoup: None,
454 has_lazy_coefficients: false,
455 })
456}
457
#[cfg(test)]
mod tests {
    use itertools::{izip, Itertools};
    use num_bigint::BigUint;
    use rand::rng;

    use super::dot_product;
    use crate::{
        rq::{Context, Poly, Representation},
        zq::Modulus,
    };
    use std::{error::Error, sync::Arc};

    static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393];

    /// Degree of every polynomial context built in these tests. In a
    /// multi-modulus context, `Vec::<u64>::from(&poly)` holds one contiguous
    /// chunk of `DEGREE` coefficients per modulus, in `MODULI` order.
    const DEGREE: usize = 16;

    #[test]
    fn add() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = &p + &q;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.add_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p + &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.add_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (ai, bi, modulus) in
                izip!(a.chunks_exact_mut(DEGREE), b.chunks_exact(DEGREE), MODULI)
            {
                Modulus::new(*modulus).unwrap().add_vec(ai, bi);
            }
            let r = &p + &q;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);

            // The by-value operator must agree with the by-reference one.
            let r = p + q;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn sub() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = &p - &q;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.sub_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p - &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.sub_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (ai, bi, modulus) in
                izip!(a.chunks_exact_mut(DEGREE), b.chunks_exact(DEGREE), MODULI)
            {
                Modulus::new(*modulus).unwrap().sub_vec(ai, bi);
            }
            let r = &p - &q;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn mul() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p * &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.mul_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (ai, bi, modulus) in
                izip!(a.chunks_exact_mut(DEGREE), b.chunks_exact(DEGREE), MODULI)
            {
                Modulus::new(*modulus).unwrap().mul_vec(ai, bi);
            }
            let r = &p * &q;
            assert_eq!(r.representation, Representation::Ntt);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn mul_shoup() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::NttShoup, &mut rng);
                let r = &p * &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.mul_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let q = Poly::random(&ctx, Representation::NttShoup, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (ai, bi, modulus) in
                izip!(a.chunks_exact_mut(DEGREE), b.chunks_exact(DEGREE), MODULI)
            {
                Modulus::new(*modulus).unwrap().mul_vec(ai, bi);
            }
            let r = &p * &q;
            assert_eq!(r.representation, Representation::Ntt);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn neg() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = -&p;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.neg_vec(&mut a);
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = -&p;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.neg_vec(&mut a);
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            for (ai, modulus) in izip!(a.chunks_exact_mut(DEGREE), MODULI) {
                Modulus::new(*modulus).unwrap().neg_vec(ai);
            }
            let r = -&p;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);

            let r = -p;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn test_dot_product() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..20 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);

                for len in 1..50 {
                    let p = (0..len)
                        .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                        .collect_vec();
                    let q = (0..len)
                        .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                        .collect_vec();
                    let r = dot_product(p.iter(), q.iter())?;

                    let mut expected = Poly::zero(&ctx, Representation::Ntt);
                    izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi));
                    assert_eq!(r, expected);
                }
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            for len in 1..50 {
                let p = (0..len)
                    .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                    .collect_vec();
                let q = (0..len)
                    .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                    .collect_vec();
                let r = dot_product(p.iter(), q.iter())?;

                let mut expected = Poly::zero(&ctx, Representation::Ntt);
                izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi));
                assert_eq!(r, expected);
            }
        }
        Ok(())
    }

    #[test]
    fn mul_scalar() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let scalar = BigUint::from(42u64);
                let r = &p * &scalar;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut expected = Vec::<u64>::from(&p);
                m.scalar_mul_vec(&mut expected, 42u64);
                assert_eq!(Vec::<u64>::from(&r), expected);
                // Scalar-on-the-left must give the same result.
                assert_eq!(Vec::<u64>::from(&(&scalar * &p)), expected);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let scalar = BigUint::from(123u64);
                let r = &p * &scalar;
                assert_eq!(r.representation, Representation::Ntt);
                let mut expected = Vec::<u64>::from(&p);
                m.scalar_mul_vec(&mut expected, 123u64);
                assert_eq!(Vec::<u64>::from(&r), expected);
                assert_eq!(Vec::<u64>::from(&(&scalar * &p)), expected);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);

            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let scalar = BigUint::from(99u64);
            let r = &p * &scalar;
            assert_eq!(r.representation, Representation::PowerBasis);
            let mut expected = Vec::<u64>::from(&p);
            for (ei, modulus) in izip!(expected.chunks_exact_mut(DEGREE), MODULI) {
                Modulus::new(*modulus).unwrap().scalar_mul_vec(ei, 99u64);
            }
            assert_eq!(Vec::<u64>::from(&r), expected);

            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let scalar = BigUint::from(77u64);
            let r = &p * &scalar;
            assert_eq!(r.representation, Representation::Ntt);
            let mut expected = Vec::<u64>::from(&p);
            for (ei, modulus) in izip!(expected.chunks_exact_mut(DEGREE), MODULI) {
                Modulus::new(*modulus).unwrap().scalar_mul_vec(ei, 77u64);
            }
            assert_eq!(Vec::<u64>::from(&r), expected);
        }
        Ok(())
    }

    #[test]
    fn mul_scalar_large_crt() -> Result<(), Box<dyn Error>> {
        let ctx = Arc::new(Context::new(MODULI, DEGREE)?);

        // A scalar larger than the product of all moduli, so the RNS
        // projection has to reduce it.
        let q_prod = MODULI.iter().fold(BigUint::from(1u64), |acc, &m| acc * m);
        let large_scalar = &q_prod + BigUint::from(12345u64);

        let p = Poly::random(&ctx, Representation::Ntt, &mut rng());
        let r = &p * &large_scalar;
        assert_eq!(r.representation, Representation::Ntt);

        let mut expected = Vec::<u64>::from(&p);
        for (ei, modulus) in izip!(expected.chunks_exact_mut(DEGREE), MODULI) {
            let m = Modulus::new(*modulus).unwrap();
            // `to_u64_digits` returns no digits for zero, so fall back to 0
            // instead of indexing.
            let scalar_mod_qi = (&large_scalar % *modulus)
                .to_u64_digits()
                .first()
                .copied()
                .unwrap_or(0);
            m.scalar_mul_vec(ei, scalar_mod_qi);
        }
        assert_eq!(Vec::<u64>::from(&r), expected);

        Ok(())
    }

    #[test]
    #[should_panic(
        expected = "Cannot multiply a polynomial in NttShoup representation by a scalar"
    )]
    fn mul_scalar_ntt_shoup_panic() {
        let ctx = Arc::new(Context::new(MODULI, DEGREE).unwrap());
        let mut p = Poly::random(&ctx, Representation::NttShoup, &mut rng());
        let scalar = BigUint::from(42u64);

        p *= &scalar;
    }
}