fhe_math/rq/
mod.rs

1#![warn(missing_docs, unused_imports)]
2
3//! Polynomials in R_q\[x\] = (ZZ_q1 x ... x ZZ_qn)\[x\] where the qi's are
4//! prime moduli in zq.
5
6mod context;
7mod convert;
8mod ops;
9mod serialize;
10
11pub mod scaler;
12pub mod switcher;
13pub mod traits;
14use self::{scaler::Scaler, switcher::Switcher, traits::TryConvertFrom};
15use crate::{zq::Modulus, Error, Result};
16pub use context::Context;
17use fhe_util::sample_vec_cbd;
18use itertools::{izip, Itertools};
19use ndarray::{s, Array2, ArrayView2, Axis};
20pub use ops::dot_product;
21use rand::{CryptoRng, RngCore, SeedableRng};
22use rand_chacha::ChaCha8Rng;
23use sha2::{Digest, Sha256};
24use std::sync::Arc;
25use zeroize::{Zeroize, Zeroizing};
26
/// Possible representations of the underlying polynomial.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Representation {
    /// The list of coefficients `c_i`, such that the polynomial is
    /// `c_0 + c_1 * x + ... + c_(degree - 1) * x^(degree - 1)`.
    ///
    /// This is the default representation.
    #[default]
    PowerBasis,
    /// The NTT representation of the [`Representation::PowerBasis`]
    /// representation.
    Ntt,
    /// A "Shoup" representation of the [`Representation::Ntt`]
    /// representation, used for faster multiplication.
    NttShoup,
}
41
42/// An exponent for a substitution.
43#[derive(Debug, PartialEq, Eq)]
44pub struct SubstitutionExponent {
45    /// The value of the exponent.
46    pub exponent: usize,
47
48    ctx: Arc<Context>,
49    power_bitrev: Vec<usize>,
50}
51
52impl SubstitutionExponent {
53    /// Creates a substitution element from an exponent.
54    /// Returns an error if the exponent is even modulo 2 * degree.
55    pub fn new(ctx: &Arc<Context>, exponent: usize) -> Result<Self> {
56        let exponent = exponent % (2 * ctx.degree);
57        if exponent & 1 == 0 {
58            return Err(Error::Default(
59                "The exponent should be odd modulo 2 * degree".to_string(),
60            ));
61        }
62        let mut power = (exponent - 1) / 2;
63        let mask = ctx.degree - 1;
64        let power_bitrev = (0..ctx.degree)
65            .map(|_| {
66                let r = (power & mask).reverse_bits() >> (ctx.degree.leading_zeros() + 1);
67                power += exponent;
68                r
69            })
70            .collect_vec();
71        Ok(Self {
72            ctx: ctx.clone(),
73            exponent,
74            power_bitrev,
75        })
76    }
77}
78
79/// Struct that holds a polynomial for a specific context.
80#[derive(Default, Debug, Clone, PartialEq, Eq)]
81pub struct Poly {
82    ctx: Arc<Context>,
83    representation: Representation,
84    has_lazy_coefficients: bool,
85    allow_variable_time_computations: bool,
86    coefficients: Array2<u64>,
87    coefficients_shoup: Option<Array2<u64>>,
88}
89
90// Implements zeroization of polynomials
91impl Zeroize for Poly {
92    fn zeroize(&mut self) {
93        if let Some(coeffs) = self.coefficients.as_slice_mut() {
94            coeffs.zeroize()
95        }
96        self.zeroize_shoup()
97    }
98}
99
100impl AsRef<Poly> for Poly {
101    fn as_ref(&self) -> &Poly {
102        self
103    }
104}
105
106impl AsMut<Poly> for Poly {
107    fn as_mut(&mut self) -> &mut Poly {
108        self
109    }
110}
111
112impl Poly {
113    /// Creates a polynomial holding the constant 0.
114    pub fn zero(ctx: &Arc<Context>, representation: Representation) -> Self {
115        Self {
116            ctx: ctx.clone(),
117            representation,
118            allow_variable_time_computations: false,
119            has_lazy_coefficients: false,
120            coefficients: Array2::zeros((ctx.q.len(), ctx.degree)),
121            coefficients_shoup: if representation == Representation::NttShoup {
122                Some(Array2::zeros((ctx.q.len(), ctx.degree)))
123            } else {
124                None
125            },
126        }
127    }
128
129    /// Enable variable time computations when this polynomial is involved.
130    ///
131    /// # Safety
132    ///
133    /// By default, this is marked as unsafe, but is usually safe when only
134    /// public data is processed.
135    pub unsafe fn allow_variable_time_computations(&mut self) {
136        self.allow_variable_time_computations = true
137    }
138
139    /// Disable variable time computations when this polynomial is involved.
140    pub fn disallow_variable_time_computations(&mut self) {
141        self.allow_variable_time_computations = false
142    }
143
144    /// Current representation of the polynomial.
145    pub const fn representation(&self) -> &Representation {
146        &self.representation
147    }
148
149    /// Zeroize the shoup coefficients
150    fn zeroize_shoup(&mut self) {
151        if let Some(coeffs_shoup) = self
152            .coefficients_shoup
153            .as_mut()
154            .and_then(|f| f.as_slice_mut())
155        {
156            coeffs_shoup.zeroize()
157        }
158    }
159
160    /// Change the representation of the underlying polynomial.
161    pub fn change_representation(&mut self, to: Representation) {
162        if self.representation == to {
163            return;
164        }
165
166        match (&self.representation, &to) {
167            (Representation::PowerBasis, Representation::Ntt) => self.ntt_forward(),
168            (Representation::PowerBasis, Representation::NttShoup) => {
169                self.ntt_forward();
170                self.compute_coefficients_shoup()
171            }
172            (Representation::Ntt, Representation::PowerBasis) => self.ntt_backward(),
173            (Representation::Ntt, Representation::NttShoup) => self.compute_coefficients_shoup(),
174            (Representation::NttShoup, Representation::PowerBasis) => {
175                self.zeroize_shoup();
176                self.coefficients_shoup = None;
177                self.ntt_backward()
178            }
179            (Representation::NttShoup, Representation::Ntt) => {
180                self.zeroize_shoup();
181                self.coefficients_shoup = None;
182            }
183            _ => unreachable!(),
184        }
185
186        self.representation = to;
187    }
188
189    /// Compute the Shoup representation of the coefficients.
190    fn compute_coefficients_shoup(&mut self) {
191        let mut coefficients_shoup = Array2::zeros((self.ctx.q.len(), self.ctx.degree));
192        izip!(
193            coefficients_shoup.outer_iter_mut(),
194            self.coefficients.outer_iter(),
195            self.ctx.q.iter()
196        )
197        .for_each(|(mut v_shoup, v, qi)| {
198            v_shoup
199                .as_slice_mut()
200                .unwrap()
201                .copy_from_slice(&qi.shoup_vec(v.as_slice().unwrap()))
202        });
203        self.coefficients_shoup = Some(coefficients_shoup)
204    }
205
206    /// Override the internal representation to a given representation.
207    ///
208    /// # Safety
209    ///
210    /// Prefer the `change_representation` function to safely modify the
211    /// polynomial representation. If the `to` representation is NttShoup, the
212    /// coefficients are still computed correctly to avoid being in an unstable
213    /// state. If we override a polynomial with Shoup coefficients, we zeroize
214    /// them.
215    pub unsafe fn override_representation(&mut self, to: Representation) {
216        if self.coefficients_shoup.is_some() {
217            self.zeroize_shoup();
218            self.coefficients_shoup = None
219        }
220        if to == Representation::NttShoup {
221            self.compute_coefficients_shoup()
222        }
223        self.representation = to;
224    }
225
226    /// Generate a random polynomial.
227    pub fn random<R: RngCore + CryptoRng>(
228        ctx: &Arc<Context>,
229        representation: Representation,
230        rng: &mut R,
231    ) -> Self {
232        let mut p = Poly::zero(ctx, representation);
233        izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| {
234            v.as_slice_mut()
235                .unwrap()
236                .copy_from_slice(&qi.random_vec(ctx.degree, rng))
237        });
238        if p.representation == Representation::NttShoup {
239            p.compute_coefficients_shoup()
240        }
241        p
242    }
243
244    /// Generate a random polynomial deterministically from a seed.
245    pub fn random_from_seed(
246        ctx: &Arc<Context>,
247        representation: Representation,
248        seed: <ChaCha8Rng as SeedableRng>::Seed,
249    ) -> Self {
250        // Let's hash the seed into a ChaCha8Rng seed.
251        let mut hasher = Sha256::new();
252        hasher.update(seed);
253        let mut prng =
254            ChaCha8Rng::from_seed(<ChaCha8Rng as SeedableRng>::Seed::from(hasher.finalize()));
255        let mut p = Poly::zero(ctx, representation);
256        izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| {
257            v.as_slice_mut()
258                .unwrap()
259                .copy_from_slice(&qi.random_vec(ctx.degree, &mut prng))
260        });
261        if p.representation == Representation::NttShoup {
262            p.compute_coefficients_shoup()
263        }
264        p
265    }
266
267    /// Generate a small polynomial and convert into the specified
268    /// representation.
269    ///
270    /// Returns an error if the variance does not belong to [1, ..., 16].
271    pub fn small<T: RngCore + CryptoRng>(
272        ctx: &Arc<Context>,
273        representation: Representation,
274        variance: usize,
275        rng: &mut T,
276    ) -> Result<Self> {
277        if !(1..=16).contains(&variance) {
278            return Err(Error::Default(
279                "The variance should be an integer between 1 and 16".to_string(),
280            ));
281        }
282
283        let coeffs = Zeroizing::new(
284            sample_vec_cbd(ctx.degree, variance, rng).map_err(|e| Error::Default(e.to_string()))?,
285        );
286        let mut p = Poly::try_convert_from(
287            coeffs.as_ref() as &[i64],
288            ctx,
289            false,
290            Representation::PowerBasis,
291        )?;
292        if representation != Representation::PowerBasis {
293            p.change_representation(representation);
294        }
295        Ok(p)
296    }
297
298    /// Access the polynomial coefficients in RNS representation.
299    pub fn coefficients(&self) -> ArrayView2<'_, u64> {
300        self.coefficients.view()
301    }
302
303    /// Computes the forward Ntt on the coefficients
304    fn ntt_forward(&mut self) {
305        if self.allow_variable_time_computations {
306            izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
307                .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
308        } else {
309            izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
310                .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
311        }
312    }
313
314    /// Computes the backward Ntt on the coefficients
315    fn ntt_backward(&mut self) {
316        if self.allow_variable_time_computations {
317            izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
318                .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
319        } else {
320            izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
321                .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
322        }
323    }
324
325    /// Substitute x by x^i in a polynomial.
326    /// In PowerBasis representation, i can be any integer that is not a
327    /// multiple of 2 * degree. In Ntt and NttShoup representation, i can be any
328    /// odd integer that is not a multiple of 2 * degree.
329    pub fn substitute(&self, i: &SubstitutionExponent) -> Result<Poly> {
330        let mut q = Poly::zero(&self.ctx, self.representation);
331        if self.allow_variable_time_computations {
332            unsafe { q.allow_variable_time_computations() }
333        }
334        match self.representation {
335            Representation::Ntt | Representation::NttShoup => {
336                izip!(
337                    q.coefficients.outer_iter_mut(),
338                    self.coefficients.outer_iter()
339                )
340                .for_each(|(mut q_row, p_row)| {
341                    for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
342                        q_row[*j] = p_row[*k]
343                    }
344                });
345                if self.representation == Representation::NttShoup {
346                    izip!(
347                        q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(),
348                        self.coefficients_shoup.as_ref().unwrap().outer_iter()
349                    )
350                    .for_each(|(mut q_row, p_row)| {
351                        for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
352                            q_row[*j] = p_row[*k]
353                        }
354                    });
355                }
356            }
357            Representation::PowerBasis => {
358                let mut power = 0usize;
359                let mask = self.ctx.degree - 1;
360                for j in 0..self.ctx.degree {
361                    izip!(
362                        self.ctx.q.iter(),
363                        q.coefficients.slice_mut(s![.., power & mask]),
364                        self.coefficients.slice(s![.., j])
365                    )
366                    .for_each(|(qi, qij, pij)| {
367                        if power & self.ctx.degree != 0 {
368                            *qij = qi.sub(*qij, *pij)
369                        } else {
370                            *qij = qi.add(*qij, *pij)
371                        }
372                    });
373                    power += i.exponent
374                }
375            }
376        }
377
378        Ok(q)
379    }
380
381    /// Create a polynomial which can only be multiplied by a polynomial in
382    /// NttShoup representation. All other operations may panic.
383    ///
384    /// # Safety
385    /// This operation also creates a polynomial that allows variable time
386    /// operations.
387    pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
388        power_basis_coefficients: &[u64],
389        ctx: &Arc<Context>,
390    ) -> Self {
391        let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree));
392        izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each(
393            |(mut p, qi, op)| {
394                p.as_slice_mut()
395                    .unwrap()
396                    .clone_from_slice(power_basis_coefficients);
397                qi.lazy_reduce_vec(p.as_slice_mut().unwrap());
398                op.forward_vt_lazy(p.as_mut_ptr());
399            },
400        );
401        Self {
402            ctx: ctx.clone(),
403            representation: Representation::Ntt,
404            allow_variable_time_computations: true,
405            coefficients,
406            coefficients_shoup: None,
407            has_lazy_coefficients: true,
408        }
409    }
410
411    /// Modulus switch down the polynomial by dividing and rounding each
412    /// coefficient by the last modulus in the chain, then drops the last
413    /// modulus, as described in Algorithm 2 of <https://eprint.iacr.org/2018/931.pdf>.
414    ///
415    /// Returns an error if there is no next context or if the representation
416    /// is not PowerBasis.
417    pub fn switch_down(&mut self) -> Result<()> {
418        if self.ctx.next_context.is_none() {
419            return Err(Error::NoMoreContext);
420        }
421
422        if self.representation != Representation::PowerBasis {
423            return Err(Error::IncorrectRepresentation(
424                self.representation,
425                Representation::PowerBasis,
426            ));
427        }
428
429        // Unwrap the next_context.
430        let next_context = self.ctx.next_context.as_ref().unwrap();
431
432        let q_len = self.ctx.q.len();
433        let q_last = self.ctx.q.last().unwrap();
434        let q_last_div_2 = (**q_last) / 2;
435
436        // Add (q_last - 1) / 2 to change from flooring to rounding
437        let (mut q_new_polys, mut q_last_poly) =
438            self.coefficients.view_mut().split_at(Axis(0), q_len - 1);
439
440        let add: fn(&Modulus, u64, u64) -> u64 = if self.allow_variable_time_computations {
441            |qi, a, b| unsafe { qi.add_vt(a, b) }
442        } else {
443            |qi, a, b| qi.add(a, b)
444        };
445        let reduce: unsafe fn(&Modulus, u64) -> u64 = if self.allow_variable_time_computations {
446            |qi, a| unsafe { qi.reduce_vt(a) }
447        } else {
448            |qi, a| qi.reduce(a)
449        };
450
451        q_last_poly
452            .iter_mut()
453            .for_each(|coeff| *coeff = add(q_last, *coeff, q_last_div_2));
454        izip!(
455            q_new_polys.outer_iter_mut(),
456            self.ctx.q.iter(),
457            self.ctx.inv_last_qi_mod_qj.iter(),
458            self.ctx.inv_last_qi_mod_qj_shoup.iter(),
459        )
460        .for_each(|(coeffs, qi, inv, inv_shoup)| {
461            let q_last_div_2_mod_qi = **qi - unsafe { reduce(qi, q_last_div_2) }; // Up to qi.modulus()
462            for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) {
463                // (x mod q_last - q_L/2) mod q_i
464                let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; // Up to 3 * qi.modulus()
465
466                // ((x mod q_i) - (x mod q_last) + (q_L/2 mod q_i)) mod q_i
467                // = (x - x mod q_last + q_L/2) mod q_i
468                *coeff += 3 * (**qi) - tmp; // Up to 4 * qi.modulus()
469
470                // q_last^{-1} * (x - x mod q_last) mod q_i
471                *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup);
472            }
473        });
474
475        // Remove the last row, and update the context.
476        if !self.allow_variable_time_computations {
477            q_last_poly.as_slice_mut().unwrap().zeroize();
478        }
479        self.coefficients.remove_index(Axis(0), q_len - 1);
480        self.ctx = next_context.clone();
481
482        Ok(())
483    }
484
485    /// Modulo switch down to a smaller context.
486    ///
487    /// Returns an error if there is the provided context is not a child of the
488    /// current context, or if the polynomial is not in PowerBasis
489    /// representation.
490    pub fn switch_down_to(&mut self, context: &Arc<Context>) -> Result<()> {
491        let niterations = self.ctx.niterations_to(context)?;
492        for _ in 0..niterations {
493            self.switch_down()?;
494        }
495        assert_eq!(&self.ctx, context);
496        Ok(())
497    }
498
499    /// Modulo switch to another context. The target context needs not to be
500    /// related to the current context.
501    pub fn switch(&self, switcher: &Switcher) -> Result<Poly> {
502        switcher.switch(self)
503    }
504
505    /// Scale a polynomial using a scaler.
506    pub fn scale(&self, scaler: &Scaler) -> Result<Poly> {
507        scaler.scale(self)
508    }
509
510    /// Returns the context of the underlying polynomial
511    pub fn ctx(&self) -> &Arc<Context> {
512        &self.ctx
513    }
514
515    /// Multiplies a polynomial in PowerBasis representation by x^(-power).
516    pub fn multiply_inverse_power_of_x(&mut self, power: usize) -> Result<()> {
517        if self.representation != Representation::PowerBasis {
518            return Err(Error::IncorrectRepresentation(
519                self.representation,
520                Representation::PowerBasis,
521            ));
522        }
523
524        let shift = ((self.ctx.degree << 1) - power) % (self.ctx.degree << 1);
525        let mask = self.ctx.degree - 1;
526        let mut new_coefficients = Array2::zeros((self.ctx.q.len(), self.ctx.degree));
527        izip!(
528            new_coefficients.outer_iter_mut(),
529            self.coefficients.outer_iter(),
530            self.ctx.q.iter()
531        )
532        .for_each(|(mut new_coeffs, orig_coeffs, qi)| {
533            for k in 0..self.ctx.degree {
534                let index = shift + k;
535                if index & self.ctx.degree == 0 {
536                    new_coeffs[index & mask] = orig_coeffs[k];
537                } else {
538                    new_coeffs[index & mask] = qi.neg(orig_coeffs[k]);
539                }
540            }
541        });
542        self.coefficients = new_coefficients;
543        Ok(())
544    }
545}
546
547#[cfg(test)]
548mod tests {
549    use super::{switcher::Switcher, Context, Poly, Representation};
550    use crate::{rq::SubstitutionExponent, zq::Modulus};
551    use fhe_util::variance;
552    use itertools::Itertools;
553    use num_bigint::BigUint;
554    use num_traits::{One, Zero};
555    use rand::{Rng, SeedableRng};
556    use rand_chacha::ChaCha8Rng;
557    use std::{error::Error, sync::Arc};
558
559    // Moduli to be used in tests.
560    const MODULI: &[u64; 5] = &[
561        1153,
562        4611686018326724609,
563        4611686018309947393,
564        4611686018232352769,
565        4611686018171535361,
566    ];
567
568    #[test]
569    fn poly_zero() -> Result<(), Box<dyn Error>> {
570        let reference = &[
571            BigUint::zero(),
572            BigUint::zero(),
573            BigUint::zero(),
574            BigUint::zero(),
575            BigUint::zero(),
576            BigUint::zero(),
577            BigUint::zero(),
578            BigUint::zero(),
579            BigUint::zero(),
580            BigUint::zero(),
581            BigUint::zero(),
582            BigUint::zero(),
583            BigUint::zero(),
584            BigUint::zero(),
585            BigUint::zero(),
586            BigUint::zero(),
587        ];
588
589        for modulus in MODULI {
590            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
591            let p = Poly::zero(&ctx, Representation::PowerBasis);
592            let q = Poly::zero(&ctx, Representation::Ntt);
593            assert_ne!(p, q);
594            assert_eq!(Vec::<u64>::from(&p), &[0; 16]);
595            assert_eq!(Vec::<u64>::from(&q), &[0; 16]);
596        }
597
598        let ctx = Arc::new(Context::new(MODULI, 16)?);
599        let p = Poly::zero(&ctx, Representation::PowerBasis);
600        let q = Poly::zero(&ctx, Representation::Ntt);
601        assert_ne!(p, q);
602        assert_eq!(Vec::<u64>::from(&p), [0; 16 * MODULI.len()]);
603        assert_eq!(Vec::<u64>::from(&q), [0; 16 * MODULI.len()]);
604        assert_eq!(Vec::<BigUint>::from(&p), reference);
605        assert_eq!(Vec::<BigUint>::from(&q), reference);
606
607        Ok(())
608    }
609
610    #[test]
611    fn ctx() -> Result<(), Box<dyn Error>> {
612        for modulus in MODULI {
613            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
614            let p = Poly::zero(&ctx, Representation::PowerBasis);
615            assert_eq!(p.ctx(), &ctx);
616        }
617
618        let ctx = Arc::new(Context::new(MODULI, 16)?);
619        let p = Poly::zero(&ctx, Representation::PowerBasis);
620        assert_eq!(p.ctx(), &ctx);
621
622        Ok(())
623    }
624
625    #[test]
626    fn random() -> Result<(), Box<dyn Error>> {
627        let mut rng = rand::rng();
628        for _ in 0..100 {
629            let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
630            rand::rng().fill(&mut seed);
631
632            for modulus in MODULI {
633                let ctx = Arc::new(Context::new(&[*modulus], 16)?);
634                let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
635                let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
636                assert_eq!(p, q);
637            }
638
639            let ctx = Arc::new(Context::new(MODULI, 16)?);
640            let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
641            let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
642            assert_eq!(p, q);
643
644            rand::rng().fill(&mut seed);
645            let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
646            assert_ne!(p, q);
647
648            let r = Poly::random(&ctx, Representation::Ntt, &mut rng);
649            assert_ne!(p, r);
650            assert_ne!(q, r);
651        }
652        Ok(())
653    }
654
655    #[test]
656    fn coefficients() -> Result<(), Box<dyn Error>> {
657        let mut rng = rand::rng();
658        for _ in 0..50 {
659            for modulus in MODULI {
660                let ctx = Arc::new(Context::new(&[*modulus], 16)?);
661                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
662                let p_coefficients = Vec::<u64>::from(&p);
663                assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap())
664            }
665
666            let ctx = Arc::new(Context::new(MODULI, 16)?);
667            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
668            let p_coefficients = Vec::<u64>::from(&p);
669            assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap())
670        }
671        Ok(())
672    }
673
674    #[test]
675    fn modulus() -> Result<(), Box<dyn Error>> {
676        for modulus in MODULI {
677            let modulus_biguint = BigUint::from(*modulus);
678            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
679            assert_eq!(ctx.modulus(), &modulus_biguint)
680        }
681
682        let mut modulus_biguint = BigUint::one();
683        MODULI.iter().for_each(|m| modulus_biguint *= *m);
684        let ctx = Arc::new(Context::new(MODULI, 16)?);
685        assert_eq!(ctx.modulus(), &modulus_biguint);
686
687        Ok(())
688    }
689
690    #[test]
691    fn allow_variable_time_computations() -> Result<(), Box<dyn Error>> {
692        let mut rng = rand::rng();
693        for modulus in MODULI {
694            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
695            let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
696            assert!(!p.allow_variable_time_computations);
697
698            unsafe { p.allow_variable_time_computations() }
699            assert!(p.allow_variable_time_computations);
700
701            let q = p.clone();
702            assert!(q.allow_variable_time_computations);
703
704            p.disallow_variable_time_computations();
705            assert!(!p.allow_variable_time_computations);
706        }
707
708        let ctx = Arc::new(Context::new(MODULI, 16)?);
709        let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
710        assert!(!p.allow_variable_time_computations);
711
712        unsafe { p.allow_variable_time_computations() }
713        assert!(p.allow_variable_time_computations);
714
715        let q = p.clone();
716        assert!(q.allow_variable_time_computations);
717
718        // Allowing variable time propagates.
719        let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
720        unsafe { p.allow_variable_time_computations() }
721        let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);
722
723        assert!(!q.allow_variable_time_computations);
724        q *= &p;
725        assert!(q.allow_variable_time_computations);
726
727        q.disallow_variable_time_computations();
728        q += &p;
729        assert!(q.allow_variable_time_computations);
730
731        q.disallow_variable_time_computations();
732        q -= &p;
733        assert!(q.allow_variable_time_computations);
734
735        q = -&p;
736        assert!(q.allow_variable_time_computations);
737
738        Ok(())
739    }
740
741    #[test]
742    fn change_representation() -> Result<(), Box<dyn Error>> {
743        let mut rng = rand::rng();
744        let ctx = Arc::new(Context::new(MODULI, 16)?);
745
746        let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
747        assert_eq!(p.representation, Representation::default());
748        assert_eq!(p.representation(), &Representation::default());
749
750        p.change_representation(Representation::PowerBasis);
751        assert_eq!(p.representation, Representation::PowerBasis);
752        assert_eq!(p.representation(), &Representation::PowerBasis);
753        assert!(p.coefficients_shoup.is_none());
754        let q = p.clone();
755
756        p.change_representation(Representation::Ntt);
757        assert_eq!(p.representation, Representation::Ntt);
758        assert_eq!(p.representation(), &Representation::Ntt);
759        assert_ne!(p.coefficients, q.coefficients);
760        assert!(p.coefficients_shoup.is_none());
761        let q_ntt = p.clone();
762
763        p.change_representation(Representation::NttShoup);
764        assert_eq!(p.representation, Representation::NttShoup);
765        assert_eq!(p.representation(), &Representation::NttShoup);
766        assert_ne!(p.coefficients, q.coefficients);
767        assert!(p.coefficients_shoup.is_some());
768        let q_ntt_shoup = p.clone();
769
770        p.change_representation(Representation::PowerBasis);
771        assert_eq!(p, q);
772
773        p.change_representation(Representation::NttShoup);
774        assert_eq!(p, q_ntt_shoup);
775
776        p.change_representation(Representation::Ntt);
777        assert_eq!(p, q_ntt);
778
779        p.change_representation(Representation::PowerBasis);
780        assert_eq!(p, q);
781
782        Ok(())
783    }
784
785    #[test]
786    fn override_representation() -> Result<(), Box<dyn Error>> {
787        let mut rng = rand::rng();
788        let ctx = Arc::new(Context::new(MODULI, 16)?);
789
790        let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
791        assert_eq!(p.representation(), &p.representation);
792        let q = p.clone();
793
794        unsafe { p.override_representation(Representation::Ntt) }
795        assert_eq!(p.representation, Representation::Ntt);
796        assert_eq!(p.representation(), &p.representation);
797        assert_eq!(p.coefficients, q.coefficients);
798        assert!(p.coefficients_shoup.is_none());
799
800        unsafe { p.override_representation(Representation::NttShoup) }
801        assert_eq!(p.representation, Representation::NttShoup);
802        assert_eq!(p.representation(), &p.representation);
803        assert_eq!(p.coefficients, q.coefficients);
804        assert!(p.coefficients_shoup.is_some());
805
806        unsafe { p.override_representation(Representation::PowerBasis) }
807        assert_eq!(p, q);
808
809        unsafe { p.override_representation(Representation::NttShoup) }
810        assert!(p.coefficients_shoup.is_some());
811
812        unsafe { p.override_representation(Representation::Ntt) }
813        assert!(p.coefficients_shoup.is_none());
814
815        Ok(())
816    }
817
818    #[test]
819    fn small() -> Result<(), Box<dyn Error>> {
820        let mut rng = rand::rng();
821        for modulus in MODULI {
822            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
823            let q = Modulus::new(*modulus).unwrap();
824
825            let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng);
826            assert!(e.is_err());
827            assert_eq!(
828                e.unwrap_err().to_string(),
829                "The variance should be an integer between 1 and 16"
830            );
831            let e = Poly::small(&ctx, Representation::PowerBasis, 17, &mut rng);
832            assert!(e.is_err());
833            assert_eq!(
834                e.unwrap_err().to_string(),
835                "The variance should be an integer between 1 and 16"
836            );
837
838            for i in 1..=16 {
839                let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?;
840                let coefficients = p.coefficients().to_slice().unwrap();
841                let v = unsafe { q.center_vec_vt(coefficients) };
842
843                assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * i as i64);
844            }
845        }
846
847        // Generate a very large polynomial to check the variance (here equal to 8).
848        let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?);
849        let q = Modulus::new(4611686018326724609).unwrap();
850        let mut rng = rand::rng();
851        let p = Poly::small(&ctx, Representation::PowerBasis, 16, &mut rng)?;
852        let coefficients = p.coefficients().to_slice().unwrap();
853        let v = unsafe { q.center_vec_vt(coefficients) };
854        assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 32);
855        assert_eq!(variance(&v).round(), 16.0);
856
857        Ok(())
858    }
859
860    #[test]
861    fn substitute() -> Result<(), Box<dyn Error>> {
862        let mut rng = rand::rng();
863        for modulus in MODULI {
864            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
865            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
866            let mut p_ntt = p.clone();
867            p_ntt.change_representation(Representation::Ntt);
868            let mut p_ntt_shoup = p.clone();
869            p_ntt_shoup.change_representation(Representation::NttShoup);
870            let p_coeffs = Vec::<u64>::from(&p);
871
872            // Substitution by a multiple of 2 * degree, or even numbers, should fail
873            assert!(SubstitutionExponent::new(&ctx, 0).is_err());
874            assert!(SubstitutionExponent::new(&ctx, 2).is_err());
875            assert!(SubstitutionExponent::new(&ctx, 16).is_err());
876
877            // Substitution by 1 leaves the polynomials unchanged
878            assert_eq!(p, p.substitute(&SubstitutionExponent::new(&ctx, 1)?)?);
879            assert_eq!(
880                p_ntt,
881                p_ntt.substitute(&SubstitutionExponent::new(&ctx, 1)?)?
882            );
883            assert_eq!(
884                p_ntt_shoup,
885                p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 1)?)?
886            );
887
888            // Substitution by 3
889            let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
890            let mut v = vec![0u64; 16];
891            for i in 0..16 {
892                v[(3 * i) % 16] = if ((3 * i) / 16) & 1 == 1 && p_coeffs[i] > 0 {
893                    *modulus - p_coeffs[i]
894                } else {
895                    p_coeffs[i]
896                };
897            }
898            assert_eq!(&Vec::<u64>::from(&q), &v);
899
900            let q_ntt = p_ntt.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
901            q.change_representation(Representation::Ntt);
902            assert_eq!(q, q_ntt);
903
904            let q_ntt_shoup = p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
905            q.change_representation(Representation::NttShoup);
906            assert_eq!(q, q_ntt_shoup);
907
908            // 11 = 3^(-1) % 16
909            assert_eq!(
910                p,
911                p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?
912                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
913            );
914            assert_eq!(
915                p_ntt,
916                p_ntt
917                    .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
918                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
919            );
920            assert_eq!(
921                p_ntt_shoup,
922                p_ntt_shoup
923                    .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
924                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
925            );
926        }
927
928        let ctx = Arc::new(Context::new(MODULI, 16)?);
929        let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
930        let mut p_ntt = p.clone();
931        p_ntt.change_representation(Representation::Ntt);
932        let mut p_ntt_shoup = p.clone();
933        p_ntt_shoup.change_representation(Representation::NttShoup);
934
935        assert_eq!(
936            p,
937            p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?
938                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
939        );
940        assert_eq!(
941            p_ntt,
942            p_ntt
943                .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
944                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
945        );
946        assert_eq!(
947            p_ntt_shoup,
948            p_ntt_shoup
949                .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
950                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
951        );
952
953        Ok(())
954    }
955
956    #[test]
957    fn switch_down() -> Result<(), Box<dyn Error>> {
958        let mut rng = rand::rng();
959        let ntests = 100;
960        let ctx = Arc::new(Context::new(MODULI, 16)?);
961
962        for _ in 0..ntests {
963            // If the polynomial has incorrect representation, an error is returned
964            let e = Poly::random(&ctx, Representation::Ntt, &mut rng).switch_down();
965            assert!(e.is_err());
966            assert_eq!(
967                e.unwrap_err(),
968                crate::Error::IncorrectRepresentation(
969                    Representation::Ntt,
970                    Representation::PowerBasis
971                )
972            );
973
974            // Otherwise, no error happens and the coefficients evolve as expected.
975            let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
976            let mut reference = Vec::<BigUint>::from(&p);
977            let mut current_ctx = ctx.clone();
978            assert_eq!(p.ctx, current_ctx);
979            while current_ctx.next_context.is_some() {
980                let denominator = current_ctx.modulus().clone();
981                current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
982                let numerator = current_ctx.modulus().clone();
983                assert!(p.switch_down().is_ok());
984                assert_eq!(p.ctx, current_ctx);
985                let p_biguint = Vec::<BigUint>::from(&p);
986                assert_eq!(
987                    p_biguint,
988                    reference
989                        .iter()
990                        .map(
991                            |b| (((b * &numerator) + (&denominator >> 1)) / &denominator)
992                                % current_ctx.modulus()
993                        )
994                        .collect_vec()
995                );
996                reference.clone_from(&p_biguint);
997            }
998        }
999        Ok(())
1000    }
1001
1002    #[test]
1003    fn switch_down_to() -> Result<(), Box<dyn Error>> {
1004        let mut rng = rand::rng();
1005        let ntests = 100;
1006        let ctx1 = Arc::new(Context::new(MODULI, 16)?);
1007        let ctx2 = Arc::new(Context::new(&MODULI[..2], 16)?);
1008
1009        for _ in 0..ntests {
1010            let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng);
1011            let reference = Vec::<BigUint>::from(&p);
1012
1013            p.switch_down_to(&ctx2)?;
1014
1015            assert_eq!(p.ctx, ctx2);
1016            assert_eq!(
1017                Vec::<BigUint>::from(&p),
1018                reference
1019                    .iter()
1020                    .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus())
1021                    .collect_vec()
1022            );
1023        }
1024
1025        Ok(())
1026    }
1027
1028    #[test]
1029    fn switch() -> Result<(), Box<dyn Error>> {
1030        let mut rng = rand::rng();
1031        let ntests = 100;
1032        let ctx1 = Arc::new(Context::new(&MODULI[..2], 16)?);
1033        let ctx2 = Arc::new(Context::new(&MODULI[3..], 16)?);
1034        let switcher = Switcher::new(&ctx1, &ctx2)?;
1035        for _ in 0..ntests {
1036            let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng);
1037            let reference = Vec::<BigUint>::from(&p);
1038
1039            let q = p.switch(&switcher)?;
1040
1041            assert_eq!(q.ctx, ctx2);
1042            assert_eq!(
1043                Vec::<BigUint>::from(&q),
1044                reference
1045                    .iter()
1046                    .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus())
1047                    .collect_vec()
1048            );
1049        }
1050        Ok(())
1051    }
1052
1053    #[test]
1054    fn mul_x_power() -> Result<(), Box<dyn Error>> {
1055        let mut rng = rand::rng();
1056        let ctx = Arc::new(Context::new(MODULI, 16)?);
1057        let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1);
1058        assert!(e.is_err());
1059        assert_eq!(
1060            e.unwrap_err(),
1061            crate::Error::IncorrectRepresentation(Representation::Ntt, Representation::PowerBasis)
1062        );
1063
1064        let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
1065        let q = p.clone();
1066
1067        p.multiply_inverse_power_of_x(0)?;
1068        assert_eq!(p, q);
1069
1070        p.multiply_inverse_power_of_x(1)?;
1071        assert_ne!(p, q);
1072
1073        p.multiply_inverse_power_of_x(2 * ctx.degree - 1)?;
1074        assert_eq!(p, q);
1075
1076        p.multiply_inverse_power_of_x(ctx.degree)?;
1077        assert_eq!(
1078            Vec::<BigUint>::from(&p)
1079                .iter()
1080                .map(|c| ctx.modulus() - c)
1081                .collect_vec(),
1082            Vec::<BigUint>::from(&q)
1083        );
1084
1085        Ok(())
1086    }
1087}