fhe_math/rq/
ops.rs

1//! Implementation of operations over polynomials.
2
3use super::{Poly, Representation};
4use crate::{Error, Result};
5use itertools::{izip, Itertools};
6use ndarray::Array2;
7use num_bigint::BigUint;
8use std::{
9    cmp::min,
10    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12use zeroize::Zeroize;
13
14impl AddAssign<&Poly> for Poly {
15    fn add_assign(&mut self, p: &Poly) {
16        assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
17        assert_ne!(
18            self.representation,
19            Representation::NttShoup,
20            "Cannot add to a polynomial in NttShoup representation"
21        );
22        assert_eq!(
23            self.representation, p.representation,
24            "Incompatible representations"
25        );
26        debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
27        self.allow_variable_time_computations |= p.allow_variable_time_computations;
28        if self.allow_variable_time_computations {
29            izip!(
30                self.coefficients.outer_iter_mut(),
31                p.coefficients.outer_iter(),
32                self.ctx.q.iter()
33            )
34            .for_each(|(mut v1, v2, qi)| unsafe {
35                qi.add_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
36            });
37        } else {
38            izip!(
39                self.coefficients.outer_iter_mut(),
40                p.coefficients.outer_iter(),
41                self.ctx.q.iter()
42            )
43            .for_each(|(mut v1, v2, qi)| {
44                qi.add_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
45            });
46        }
47    }
48}
49
50impl Add<&Poly> for &Poly {
51    type Output = Poly;
52    fn add(self, p: &Poly) -> Poly {
53        let mut q = self.clone();
54        q += p;
55        q
56    }
57}
58
59impl Add for Poly {
60    type Output = Poly;
61    fn add(self, mut p: Poly) -> Poly {
62        p += &self;
63        p
64    }
65}
66
67impl SubAssign<&Poly> for Poly {
68    fn sub_assign(&mut self, p: &Poly) {
69        assert!(!self.has_lazy_coefficients && !p.has_lazy_coefficients);
70        assert_ne!(
71            self.representation,
72            Representation::NttShoup,
73            "Cannot subtract from a polynomial in NttShoup representation"
74        );
75        assert_eq!(
76            self.representation, p.representation,
77            "Incompatible representations"
78        );
79        debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
80        self.allow_variable_time_computations |= p.allow_variable_time_computations;
81        if self.allow_variable_time_computations {
82            izip!(
83                self.coefficients.outer_iter_mut(),
84                p.coefficients.outer_iter(),
85                self.ctx.q.iter()
86            )
87            .for_each(|(mut v1, v2, qi)| unsafe {
88                qi.sub_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
89            });
90        } else {
91            izip!(
92                self.coefficients.outer_iter_mut(),
93                p.coefficients.outer_iter(),
94                self.ctx.q.iter()
95            )
96            .for_each(|(mut v1, v2, qi)| {
97                qi.sub_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
98            });
99        }
100    }
101}
102
103impl Sub<&Poly> for &Poly {
104    type Output = Poly;
105    fn sub(self, p: &Poly) -> Poly {
106        let mut q = self.clone();
107        q -= p;
108        q
109    }
110}
111
112impl MulAssign<&Poly> for Poly {
113    fn mul_assign(&mut self, p: &Poly) {
114        assert!(!p.has_lazy_coefficients);
115        assert_ne!(
116            self.representation,
117            Representation::NttShoup,
118            "Cannot multiply to a polynomial in NttShoup representation"
119        );
120        if self.has_lazy_coefficients && self.representation == Representation::Ntt {
121            assert!(
122				p.representation == Representation::NttShoup,
123				"Can only multiply a polynomial with lazy coefficients by an NttShoup representation."
124			);
125        } else {
126            assert_eq!(
127                self.representation,
128                Representation::Ntt,
129                "Multiplication requires an Ntt representation."
130            );
131        }
132        debug_assert_eq!(self.ctx, p.ctx, "Incompatible contexts");
133        self.allow_variable_time_computations |= p.allow_variable_time_computations;
134
135        match p.representation {
136            Representation::Ntt => {
137                if self.allow_variable_time_computations {
138                    unsafe {
139                        izip!(
140                            self.coefficients.outer_iter_mut(),
141                            p.coefficients.outer_iter(),
142                            self.ctx.q.iter()
143                        )
144                        .for_each(|(mut v1, v2, qi)| {
145                            qi.mul_vec_vt(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap());
146                        });
147                    }
148                } else {
149                    izip!(
150                        self.coefficients.outer_iter_mut(),
151                        p.coefficients.outer_iter(),
152                        self.ctx.q.iter()
153                    )
154                    .for_each(|(mut v1, v2, qi)| {
155                        qi.mul_vec(v1.as_slice_mut().unwrap(), v2.as_slice().unwrap())
156                    });
157                }
158            }
159            Representation::NttShoup => {
160                if self.allow_variable_time_computations {
161                    izip!(
162                        self.coefficients.outer_iter_mut(),
163                        p.coefficients.outer_iter(),
164                        p.coefficients_shoup.as_ref().unwrap().outer_iter(),
165                        self.ctx.q.iter()
166                    )
167                    .for_each(|(mut v1, v2, v2_shoup, qi)| unsafe {
168                        qi.mul_shoup_vec_vt(
169                            v1.as_slice_mut().unwrap(),
170                            v2.as_slice().unwrap(),
171                            v2_shoup.as_slice().unwrap(),
172                        )
173                    });
174                } else {
175                    izip!(
176                        self.coefficients.outer_iter_mut(),
177                        p.coefficients.outer_iter(),
178                        p.coefficients_shoup.as_ref().unwrap().outer_iter(),
179                        self.ctx.q.iter()
180                    )
181                    .for_each(|(mut v1, v2, v2_shoup, qi)| {
182                        qi.mul_shoup_vec(
183                            v1.as_slice_mut().unwrap(),
184                            v2.as_slice().unwrap(),
185                            v2_shoup.as_slice().unwrap(),
186                        )
187                    });
188                }
189                self.has_lazy_coefficients = false
190            }
191            _ => {
192                panic!("Multiplication requires a multipliand in Ntt or NttShoup representation.")
193            }
194        }
195    }
196}
197
198impl MulAssign<&BigUint> for Poly {
199    fn mul_assign(&mut self, p: &BigUint) {
200        assert_ne!(
201            self.representation,
202            Representation::NttShoup,
203            "Cannot multiply a polynomial in NttShoup representation by a scalar"
204        );
205
206        // Project the scalar into its CRT representation (reduced modulo each prime)
207        let scalar_crt = self.ctx.rns.project(p);
208
209        if self.allow_variable_time_computations {
210            unsafe {
211                izip!(
212                    self.coefficients.outer_iter_mut(),
213                    scalar_crt.iter(),
214                    self.ctx.q.iter()
215                )
216                .for_each(|(mut v1, scalar_qi, qi)| {
217                    qi.scalar_mul_vec_vt(v1.as_slice_mut().unwrap(), *scalar_qi)
218                });
219            }
220        } else {
221            izip!(
222                self.coefficients.outer_iter_mut(),
223                scalar_crt.iter(),
224                self.ctx.q.iter()
225            )
226            .for_each(|(mut v1, scalar_qi, qi)| {
227                qi.scalar_mul_vec(v1.as_slice_mut().unwrap(), *scalar_qi)
228            });
229        }
230    }
231}
232
233impl Mul<&Poly> for &Poly {
234    type Output = Poly;
235    fn mul(self, p: &Poly) -> Poly {
236        match self.representation {
237            Representation::NttShoup => {
238                // TODO: To test, and do the same thing for add, sub, and neg
239                let mut q = p.clone();
240                if q.representation == Representation::NttShoup {
241                    q.coefficients_shoup
242                        .as_mut()
243                        .unwrap()
244                        .as_slice_mut()
245                        .unwrap()
246                        .zeroize();
247                    unsafe { q.override_representation(Representation::Ntt) }
248                }
249                q *= self;
250                q
251            }
252            _ => {
253                let mut q = self.clone();
254                q *= p;
255                q
256            }
257        }
258    }
259}
260
261impl Mul<&BigUint> for &Poly {
262    type Output = Poly;
263    fn mul(self, p: &BigUint) -> Poly {
264        let mut q = self.clone();
265        q *= p;
266        q
267    }
268}
269
270impl Mul<&Poly> for &BigUint {
271    type Output = Poly;
272    fn mul(self, p: &Poly) -> Poly {
273        p * self
274    }
275}
276
277impl Neg for &Poly {
278    type Output = Poly;
279
280    fn neg(self) -> Poly {
281        assert!(!self.has_lazy_coefficients);
282        let mut out = self.clone();
283        if self.allow_variable_time_computations {
284            izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter())
285                .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) });
286        } else {
287            izip!(out.coefficients.outer_iter_mut(), out.ctx.q.iter())
288                .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap()));
289        }
290        out
291    }
292}
293
294impl Neg for Poly {
295    type Output = Poly;
296
297    fn neg(mut self) -> Poly {
298        assert!(!self.has_lazy_coefficients);
299        if self.allow_variable_time_computations {
300            izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter())
301                .for_each(|(mut v1, qi)| unsafe { qi.neg_vec_vt(v1.as_slice_mut().unwrap()) });
302        } else {
303            izip!(self.coefficients.outer_iter_mut(), self.ctx.q.iter())
304                .for_each(|(mut v1, qi)| qi.neg_vec(v1.as_slice_mut().unwrap()));
305        }
306        self
307    }
308}
309
/// Computes the fused multiply-add `out[i] += x[i] * y[i]` for every index,
/// widening each 64-bit product to 128 bits before accumulating it.
///
/// The bulk of the work is done 16 elements at a time (manually unrolled);
/// the remaining `len % 16` elements are processed one by one.
///
/// # Safety
///
/// The length checks below run before any unchecked access, so every index
/// touched is in bounds. The caller is responsible for making sure that the
/// accumulation does not overflow a `u128`.
///
/// # Panics
///
/// Panics if `x` or `y` does not have the same length as `out`.
unsafe fn fma(out: &mut [u128], x: &[u64], y: &[u64]) {
    let len = out.len();
    assert_eq!(x.len(), len);
    assert_eq!(y.len(), len);

    // `out[k] += x[k] * y[k]` without bounds checks; every `k` passed below is
    // strictly smaller than `len`.
    macro_rules! mul_acc {
        ($k:expr) => {{
            let k = $k;
            *out.get_unchecked_mut(k) +=
                u128::from(*x.get_unchecked(k)) * u128::from(*y.get_unchecked(k));
        }};
    }

    let unrolled_end = len - len % 16;
    let mut base = 0;
    while base < unrolled_end {
        mul_acc!(base);
        mul_acc!(base + 1);
        mul_acc!(base + 2);
        mul_acc!(base + 3);
        mul_acc!(base + 4);
        mul_acc!(base + 5);
        mul_acc!(base + 6);
        mul_acc!(base + 7);
        mul_acc!(base + 8);
        mul_acc!(base + 9);
        mul_acc!(base + 10);
        mul_acc!(base + 11);
        mul_acc!(base + 12);
        mul_acc!(base + 13);
        mul_acc!(base + 14);
        mul_acc!(base + 15);
        base += 16;
    }

    for k in unrolled_end..len {
        mul_acc!(k);
    }
}
347
348/// Compute the dot product between two iterators of polynomials.
349/// Returna an error if the iterator counts are 0, or if any of the polynomial
350/// is not in Ntt or NttShoup representation.
351pub fn dot_product<'a, 'b, I, J>(p: I, q: J) -> Result<Poly>
352where
353    I: Iterator<Item = &'a Poly> + Clone,
354    J: Iterator<Item = &'b Poly> + Clone,
355{
356    debug_assert!(!p
357        .clone()
358        .any(|pi| pi.representation == Representation::PowerBasis));
359    debug_assert!(!q
360        .clone()
361        .any(|qi| qi.representation == Representation::PowerBasis));
362
363    let count = min(p.clone().count(), q.clone().count());
364    if count == 0 {
365        return Err(Error::Default("At least one iterator is empty".to_string()));
366    }
367
368    let p_first = p.clone().next().unwrap();
369
370    // Initialize the accumulator
371    let mut acc: Array2<u128> = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree));
372    let acc_ptr = acc.as_mut_ptr();
373
374    // Current number of products accumulated
375    let mut num_acc = vec![1u128; p_first.ctx.q.len()];
376    let num_acc_ptr = num_acc.as_mut_ptr();
377
378    // Maximum number of products that can be accumulated
379    let max_acc = p_first
380        .ctx
381        .q
382        .iter()
383        .map(|qi| 1u128 << (2 * (*qi).leading_zeros()))
384        .collect_vec();
385    let max_acc_ptr = max_acc.as_ptr();
386
387    let q_ptr = p_first.ctx.q.as_ptr();
388    let degree = p_first.ctx.degree as isize;
389
390    let min_of_max = max_acc.iter().min().unwrap();
391
392    let out_slice = acc.as_slice_mut().unwrap();
393    if count as u128 > *min_of_max {
394        for (pi, qi) in izip!(p, q) {
395            let pij = pi.coefficients();
396            let qij = qi.coefficients();
397            let pi_slice = pij.as_slice().unwrap();
398            let qi_slice = qij.as_slice().unwrap();
399            unsafe {
400                fma(out_slice, pi_slice, qi_slice);
401
402                for j in 0..p_first.ctx.q.len() as isize {
403                    let qj = &*q_ptr.offset(j);
404                    *num_acc_ptr.offset(j) += 1;
405                    if *num_acc_ptr.offset(j) == *max_acc_ptr.offset(j) {
406                        if p_first.allow_variable_time_computations {
407                            for i in j * degree..(j + 1) * degree {
408                                *acc_ptr.offset(i) = qj.reduce_u128_vt(*acc_ptr.offset(i)) as u128;
409                            }
410                        } else {
411                            for i in j * degree..(j + 1) * degree {
412                                *acc_ptr.offset(i) = qj.reduce_u128(*acc_ptr.offset(i)) as u128;
413                            }
414                        }
415                        *num_acc_ptr.offset(j) = 1;
416                    }
417                }
418            }
419        }
420    } else {
421        // We don't need to check the condition on the max, it should shave off a few
422        // cycles.
423        for (pi, qi) in izip!(p, q) {
424            let pij = pi.coefficients();
425            let qij = qi.coefficients();
426            let pi_slice = pij.as_slice().unwrap();
427            let qi_slice = qij.as_slice().unwrap();
428            unsafe { fma(out_slice, pi_slice, qi_slice) }
429        }
430    }
431    // Last reduction to create the coefficients
432    let mut coeffs: Array2<u64> = Array2::zeros((p_first.ctx.q.len(), p_first.ctx.degree));
433    izip!(
434        coeffs.outer_iter_mut(),
435        acc.outer_iter(),
436        p_first.ctx.q.iter()
437    )
438    .for_each(|(mut coeffsj, accj, m)| {
439        if p_first.allow_variable_time_computations {
440            izip!(coeffsj.iter_mut(), accj.iter())
441                .for_each(|(cj, accjk)| *cj = unsafe { m.reduce_u128_vt(*accjk) });
442        } else {
443            izip!(coeffsj.iter_mut(), accj.iter())
444                .for_each(|(cj, accjk)| *cj = m.reduce_u128(*accjk));
445        }
446    });
447
448    Ok(Poly {
449        ctx: p_first.ctx.clone(),
450        representation: Representation::Ntt,
451        allow_variable_time_computations: p_first.allow_variable_time_computations,
452        coefficients: coeffs,
453        coefficients_shoup: None,
454        has_lazy_coefficients: false,
455    })
456}
457
#[cfg(test)]
mod tests {
    use itertools::{izip, Itertools};
    use num_bigint::BigUint;
    use rand::rng;

    use super::dot_product;
    use crate::{
        rq::{Context, Poly, Representation},
        zq::Modulus,
    };
    use std::{error::Error, sync::Arc};

    static MODULI: &[u64; 3] = &[1153, 4611686018326724609, 4611686018309947393];

    /// Degree of every polynomial context built by these tests.
    const DEGREE: usize = 16;

    /// Returns `x mod modulus` as a `u64`.
    ///
    /// `BigUint::to_u64_digits` returns an empty vector for zero, so indexing
    /// its first digit directly would panic whenever the residue is zero.
    fn residue(x: &BigUint, modulus: u64) -> u64 {
        (x % modulus).to_u64_digits().first().copied().unwrap_or(0)
    }

    #[test]
    fn add() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = &p + &q;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.add_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p + &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.add_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.add_vec(
                    &mut a[i * DEGREE..(i + 1) * DEGREE],
                    &b[i * DEGREE..(i + 1) * DEGREE],
                )
            }
            let r = &p + &q;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn sub() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = &p - &q;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.sub_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p - &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.sub_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let q = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.sub_vec(
                    &mut a[i * DEGREE..(i + 1) * DEGREE],
                    &b[i * DEGREE..(i + 1) * DEGREE],
                )
            }
            let r = &p - &q;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn mul() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = &p * &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.mul_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let q = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.mul_vec(
                    &mut a[i * DEGREE..(i + 1) * DEGREE],
                    &b[i * DEGREE..(i + 1) * DEGREE],
                )
            }
            let r = &p * &q;
            assert_eq!(r.representation, Representation::Ntt);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn mul_shoup() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let q = Poly::random(&ctx, Representation::NttShoup, &mut rng);
                let r = &p * &q;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.mul_vec(&mut a, &Vec::<u64>::from(&q));
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let q = Poly::random(&ctx, Representation::NttShoup, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            let b = Vec::<u64>::from(&q);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.mul_vec(
                    &mut a[i * DEGREE..(i + 1) * DEGREE],
                    &b[i * DEGREE..(i + 1) * DEGREE],
                )
            }
            let r = &p * &q;
            assert_eq!(r.representation, Representation::Ntt);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn neg() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let r = -&p;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut a = Vec::<u64>::from(&p);
                m.neg_vec(&mut a);
                assert_eq!(Vec::<u64>::from(&r), a);

                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let r = -&p;
                assert_eq!(r.representation, Representation::Ntt);
                let mut a = Vec::<u64>::from(&p);
                m.neg_vec(&mut a);
                assert_eq!(Vec::<u64>::from(&r), a);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut a = Vec::<u64>::from(&p);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.neg_vec(&mut a[i * DEGREE..(i + 1) * DEGREE])
            }
            let r = -&p;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);

            let r = -p;
            assert_eq!(r.representation, Representation::PowerBasis);
            assert_eq!(Vec::<u64>::from(&r), a);
        }
        Ok(())
    }

    #[test]
    fn test_dot_product() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..20 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);

                for len in 1..50 {
                    let p = (0..len)
                        .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                        .collect_vec();
                    let q = (0..len)
                        .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                        .collect_vec();
                    let r = dot_product(p.iter(), q.iter())?;

                    let mut expected = Poly::zero(&ctx, Representation::Ntt);
                    izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi));
                    assert_eq!(r, expected);
                }
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);
            for len in 1..50 {
                let p = (0..len)
                    .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                    .collect_vec();
                let q = (0..len)
                    .map(|_| Poly::random(&ctx, Representation::Ntt, &mut rng))
                    .collect_vec();
                let r = dot_product(p.iter(), q.iter())?;

                let mut expected = Poly::zero(&ctx, Representation::Ntt);
                izip!(&p, &q).for_each(|(pi, qi)| expected += &(pi * qi));
                assert_eq!(r, expected);
            }
        }
        Ok(())
    }

    #[test]
    fn mul_scalar() -> Result<(), Box<dyn Error>> {
        let mut rng = rng();
        for _ in 0..100 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], DEGREE)?);
                let m = Modulus::new(*modulus).unwrap();

                // Test with PowerBasis representation
                let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
                let scalar = BigUint::from(42u64);
                let r = &p * &scalar;
                assert_eq!(r.representation, Representation::PowerBasis);
                let mut expected = Vec::<u64>::from(&p);
                m.scalar_mul_vec(&mut expected, 42u64);
                assert_eq!(Vec::<u64>::from(&r), expected);

                // Test with NTT representation
                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let scalar = BigUint::from(123u64);
                let r = &p * &scalar;
                assert_eq!(r.representation, Representation::Ntt);
                let mut expected = Vec::<u64>::from(&p);
                m.scalar_mul_vec(&mut expected, 123u64);
                assert_eq!(Vec::<u64>::from(&r), expected);
            }

            let ctx = Arc::new(Context::new(MODULI, DEGREE)?);

            // Test with PowerBasis representation
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let scalar = BigUint::from(99u64);
            let r = &p * &scalar;
            assert_eq!(r.representation, Representation::PowerBasis);
            let mut expected = Vec::<u64>::from(&p);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.scalar_mul_vec(&mut expected[i * DEGREE..(i + 1) * DEGREE], 99u64)
            }
            assert_eq!(Vec::<u64>::from(&r), expected);

            // Test with NTT representation
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let scalar = BigUint::from(77u64);
            let r = &p * &scalar;
            assert_eq!(r.representation, Representation::Ntt);
            let mut expected = Vec::<u64>::from(&p);
            for (i, modulus) in MODULI.iter().enumerate() {
                let m = Modulus::new(*modulus).unwrap();
                m.scalar_mul_vec(&mut expected[i * DEGREE..(i + 1) * DEGREE], 77u64)
            }
            assert_eq!(Vec::<u64>::from(&r), expected);
        }
        Ok(())
    }

    #[test]
    fn mul_scalar_large_crt() -> Result<(), Box<dyn Error>> {
        let ctx = Arc::new(Context::new(MODULI, DEGREE)?);

        // Create a large scalar that exceeds the product of all moduli.
        let q_prod = MODULI.iter().fold(BigUint::from(1u64), |acc, &m| acc * m);
        let large_scalar = &q_prod + BigUint::from(12345u64);

        let p = Poly::random(&ctx, Representation::Ntt, &mut rng());
        let r = &p * &large_scalar;
        assert_eq!(r.representation, Representation::Ntt);

        // Verify by computing the expected result manually for each modulus.
        let mut expected = Vec::<u64>::from(&p);
        for (i, modulus) in MODULI.iter().enumerate() {
            let m = Modulus::new(*modulus).unwrap();
            // Reduce the large scalar modulo this prime.
            let scalar_mod_qi = residue(&large_scalar, *modulus);
            m.scalar_mul_vec(&mut expected[i * DEGREE..(i + 1) * DEGREE], scalar_mod_qi)
        }
        assert_eq!(Vec::<u64>::from(&r), expected);

        Ok(())
    }

    #[test]
    #[should_panic(
        expected = "Cannot multiply a polynomial in NttShoup representation by a scalar"
    )]
    fn mul_scalar_ntt_shoup_panic() {
        let ctx = Arc::new(Context::new(MODULI, DEGREE).unwrap());
        let mut p = Poly::random(&ctx, Representation::NttShoup, &mut rng());
        let scalar = BigUint::from(42u64);

        // This should panic with the assertion message.
        p *= &scalar;
    }
}