1#![warn(missing_docs, unused_imports)]
2
3mod context;
7mod convert;
8mod ops;
9mod serialize;
10
11pub mod scaler;
12pub mod switcher;
13pub mod traits;
14use self::{scaler::Scaler, switcher::Switcher, traits::TryConvertFrom};
15use crate::{zq::Modulus, Error, Result};
16pub use context::Context;
17use fhe_util::sample_vec_cbd;
18use itertools::{izip, Itertools};
19use ndarray::{s, Array2, ArrayView2, Axis};
20pub use ops::dot_product;
21use rand::{CryptoRng, RngCore, SeedableRng};
22use rand_chacha::ChaCha8Rng;
23use sha2::{Digest, Sha256};
24use std::sync::Arc;
25use zeroize::{Zeroize, Zeroizing};
26
/// The possible representations of a [`Poly`].
///
/// The variant order is kept stable; [`Representation::PowerBasis`] is the
/// default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum Representation {
    /// Coefficients with respect to the power basis `1, X, ..., X^(n-1)`.
    /// This is the default representation.
    #[default]
    PowerBasis,
    /// Coefficients obtained by applying the forward NTT to the power-basis
    /// coefficients.
    Ntt,
    /// Same coefficients as [`Representation::Ntt`], plus a table of Shoup
    /// precomputations of those coefficients (presumably used to speed up
    /// multiplications).
    NttShoup,
}
41
42#[derive(Debug, PartialEq, Eq)]
44pub struct SubstitutionExponent {
45 pub exponent: usize,
47
48 ctx: Arc<Context>,
49 power_bitrev: Vec<usize>,
50}
51
52impl SubstitutionExponent {
53 pub fn new(ctx: &Arc<Context>, exponent: usize) -> Result<Self> {
56 let exponent = exponent % (2 * ctx.degree);
57 if exponent & 1 == 0 {
58 return Err(Error::Default(
59 "The exponent should be odd modulo 2 * degree".to_string(),
60 ));
61 }
62 let mut power = (exponent - 1) / 2;
63 let mask = ctx.degree - 1;
64 let power_bitrev = (0..ctx.degree)
65 .map(|_| {
66 let r = (power & mask).reverse_bits() >> (ctx.degree.leading_zeros() + 1);
67 power += exponent;
68 r
69 })
70 .collect_vec();
71 Ok(Self {
72 ctx: ctx.clone(),
73 exponent,
74 power_bitrev,
75 })
76 }
77}
78
79#[derive(Default, Debug, Clone, PartialEq, Eq)]
81pub struct Poly {
82 ctx: Arc<Context>,
83 representation: Representation,
84 has_lazy_coefficients: bool,
85 allow_variable_time_computations: bool,
86 coefficients: Array2<u64>,
87 coefficients_shoup: Option<Array2<u64>>,
88}
89
90impl Zeroize for Poly {
92 fn zeroize(&mut self) {
93 if let Some(coeffs) = self.coefficients.as_slice_mut() {
94 coeffs.zeroize()
95 }
96 self.zeroize_shoup()
97 }
98}
99
100impl AsRef<Poly> for Poly {
101 fn as_ref(&self) -> &Poly {
102 self
103 }
104}
105
106impl AsMut<Poly> for Poly {
107 fn as_mut(&mut self) -> &mut Poly {
108 self
109 }
110}
111
112impl Poly {
113 pub fn zero(ctx: &Arc<Context>, representation: Representation) -> Self {
115 Self {
116 ctx: ctx.clone(),
117 representation,
118 allow_variable_time_computations: false,
119 has_lazy_coefficients: false,
120 coefficients: Array2::zeros((ctx.q.len(), ctx.degree)),
121 coefficients_shoup: if representation == Representation::NttShoup {
122 Some(Array2::zeros((ctx.q.len(), ctx.degree)))
123 } else {
124 None
125 },
126 }
127 }
128
129 pub unsafe fn allow_variable_time_computations(&mut self) {
136 self.allow_variable_time_computations = true
137 }
138
139 pub fn disallow_variable_time_computations(&mut self) {
141 self.allow_variable_time_computations = false
142 }
143
144 pub const fn representation(&self) -> &Representation {
146 &self.representation
147 }
148
149 fn zeroize_shoup(&mut self) {
151 if let Some(coeffs_shoup) = self
152 .coefficients_shoup
153 .as_mut()
154 .and_then(|f| f.as_slice_mut())
155 {
156 coeffs_shoup.zeroize()
157 }
158 }
159
160 pub fn change_representation(&mut self, to: Representation) {
162 if self.representation == to {
163 return;
164 }
165
166 match (&self.representation, &to) {
167 (Representation::PowerBasis, Representation::Ntt) => self.ntt_forward(),
168 (Representation::PowerBasis, Representation::NttShoup) => {
169 self.ntt_forward();
170 self.compute_coefficients_shoup()
171 }
172 (Representation::Ntt, Representation::PowerBasis) => self.ntt_backward(),
173 (Representation::Ntt, Representation::NttShoup) => self.compute_coefficients_shoup(),
174 (Representation::NttShoup, Representation::PowerBasis) => {
175 self.zeroize_shoup();
176 self.coefficients_shoup = None;
177 self.ntt_backward()
178 }
179 (Representation::NttShoup, Representation::Ntt) => {
180 self.zeroize_shoup();
181 self.coefficients_shoup = None;
182 }
183 _ => unreachable!(),
184 }
185
186 self.representation = to;
187 }
188
189 fn compute_coefficients_shoup(&mut self) {
191 let mut coefficients_shoup = Array2::zeros((self.ctx.q.len(), self.ctx.degree));
192 izip!(
193 coefficients_shoup.outer_iter_mut(),
194 self.coefficients.outer_iter(),
195 self.ctx.q.iter()
196 )
197 .for_each(|(mut v_shoup, v, qi)| {
198 v_shoup
199 .as_slice_mut()
200 .unwrap()
201 .copy_from_slice(&qi.shoup_vec(v.as_slice().unwrap()))
202 });
203 self.coefficients_shoup = Some(coefficients_shoup)
204 }
205
206 pub unsafe fn override_representation(&mut self, to: Representation) {
216 if self.coefficients_shoup.is_some() {
217 self.zeroize_shoup();
218 self.coefficients_shoup = None
219 }
220 if to == Representation::NttShoup {
221 self.compute_coefficients_shoup()
222 }
223 self.representation = to;
224 }
225
226 pub fn random<R: RngCore + CryptoRng>(
228 ctx: &Arc<Context>,
229 representation: Representation,
230 rng: &mut R,
231 ) -> Self {
232 let mut p = Poly::zero(ctx, representation);
233 izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| {
234 v.as_slice_mut()
235 .unwrap()
236 .copy_from_slice(&qi.random_vec(ctx.degree, rng))
237 });
238 if p.representation == Representation::NttShoup {
239 p.compute_coefficients_shoup()
240 }
241 p
242 }
243
244 pub fn random_from_seed(
246 ctx: &Arc<Context>,
247 representation: Representation,
248 seed: <ChaCha8Rng as SeedableRng>::Seed,
249 ) -> Self {
250 let mut hasher = Sha256::new();
252 hasher.update(seed);
253 let mut prng =
254 ChaCha8Rng::from_seed(<ChaCha8Rng as SeedableRng>::Seed::from(hasher.finalize()));
255 let mut p = Poly::zero(ctx, representation);
256 izip!(p.coefficients.outer_iter_mut(), ctx.q.iter()).for_each(|(mut v, qi)| {
257 v.as_slice_mut()
258 .unwrap()
259 .copy_from_slice(&qi.random_vec(ctx.degree, &mut prng))
260 });
261 if p.representation == Representation::NttShoup {
262 p.compute_coefficients_shoup()
263 }
264 p
265 }
266
267 pub fn small<T: RngCore + CryptoRng>(
272 ctx: &Arc<Context>,
273 representation: Representation,
274 variance: usize,
275 rng: &mut T,
276 ) -> Result<Self> {
277 if !(1..=16).contains(&variance) {
278 return Err(Error::Default(
279 "The variance should be an integer between 1 and 16".to_string(),
280 ));
281 }
282
283 let coeffs = Zeroizing::new(
284 sample_vec_cbd(ctx.degree, variance, rng).map_err(|e| Error::Default(e.to_string()))?,
285 );
286 let mut p = Poly::try_convert_from(
287 coeffs.as_ref() as &[i64],
288 ctx,
289 false,
290 Representation::PowerBasis,
291 )?;
292 if representation != Representation::PowerBasis {
293 p.change_representation(representation);
294 }
295 Ok(p)
296 }
297
298 pub fn coefficients(&self) -> ArrayView2<'_, u64> {
300 self.coefficients.view()
301 }
302
303 fn ntt_forward(&mut self) {
305 if self.allow_variable_time_computations {
306 izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
307 .for_each(|(mut v, op)| unsafe { op.forward_vt(v.as_mut_ptr()) });
308 } else {
309 izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
310 .for_each(|(mut v, op)| op.forward(v.as_slice_mut().unwrap()));
311 }
312 }
313
314 fn ntt_backward(&mut self) {
316 if self.allow_variable_time_computations {
317 izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
318 .for_each(|(mut v, op)| unsafe { op.backward_vt(v.as_mut_ptr()) });
319 } else {
320 izip!(self.coefficients.outer_iter_mut(), self.ctx.ops.iter())
321 .for_each(|(mut v, op)| op.backward(v.as_slice_mut().unwrap()));
322 }
323 }
324
325 pub fn substitute(&self, i: &SubstitutionExponent) -> Result<Poly> {
330 let mut q = Poly::zero(&self.ctx, self.representation);
331 if self.allow_variable_time_computations {
332 unsafe { q.allow_variable_time_computations() }
333 }
334 match self.representation {
335 Representation::Ntt | Representation::NttShoup => {
336 izip!(
337 q.coefficients.outer_iter_mut(),
338 self.coefficients.outer_iter()
339 )
340 .for_each(|(mut q_row, p_row)| {
341 for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
342 q_row[*j] = p_row[*k]
343 }
344 });
345 if self.representation == Representation::NttShoup {
346 izip!(
347 q.coefficients_shoup.as_mut().unwrap().outer_iter_mut(),
348 self.coefficients_shoup.as_ref().unwrap().outer_iter()
349 )
350 .for_each(|(mut q_row, p_row)| {
351 for (j, k) in izip!(self.ctx.bitrev.iter(), i.power_bitrev.iter()) {
352 q_row[*j] = p_row[*k]
353 }
354 });
355 }
356 }
357 Representation::PowerBasis => {
358 let mut power = 0usize;
359 let mask = self.ctx.degree - 1;
360 for j in 0..self.ctx.degree {
361 izip!(
362 self.ctx.q.iter(),
363 q.coefficients.slice_mut(s![.., power & mask]),
364 self.coefficients.slice(s![.., j])
365 )
366 .for_each(|(qi, qij, pij)| {
367 if power & self.ctx.degree != 0 {
368 *qij = qi.sub(*qij, *pij)
369 } else {
370 *qij = qi.add(*qij, *pij)
371 }
372 });
373 power += i.exponent
374 }
375 }
376 }
377
378 Ok(q)
379 }
380
381 pub unsafe fn create_constant_ntt_polynomial_with_lazy_coefficients_and_variable_time(
388 power_basis_coefficients: &[u64],
389 ctx: &Arc<Context>,
390 ) -> Self {
391 let mut coefficients = Array2::zeros((ctx.q.len(), ctx.degree));
392 izip!(coefficients.outer_iter_mut(), ctx.q.iter(), ctx.ops.iter()).for_each(
393 |(mut p, qi, op)| {
394 p.as_slice_mut()
395 .unwrap()
396 .clone_from_slice(power_basis_coefficients);
397 qi.lazy_reduce_vec(p.as_slice_mut().unwrap());
398 op.forward_vt_lazy(p.as_mut_ptr());
399 },
400 );
401 Self {
402 ctx: ctx.clone(),
403 representation: Representation::Ntt,
404 allow_variable_time_computations: true,
405 coefficients,
406 coefficients_shoup: None,
407 has_lazy_coefficients: true,
408 }
409 }
410
411 pub fn switch_down(&mut self) -> Result<()> {
418 if self.ctx.next_context.is_none() {
419 return Err(Error::NoMoreContext);
420 }
421
422 if self.representation != Representation::PowerBasis {
423 return Err(Error::IncorrectRepresentation(
424 self.representation,
425 Representation::PowerBasis,
426 ));
427 }
428
429 let next_context = self.ctx.next_context.as_ref().unwrap();
431
432 let q_len = self.ctx.q.len();
433 let q_last = self.ctx.q.last().unwrap();
434 let q_last_div_2 = (**q_last) / 2;
435
436 let (mut q_new_polys, mut q_last_poly) =
438 self.coefficients.view_mut().split_at(Axis(0), q_len - 1);
439
440 let add: fn(&Modulus, u64, u64) -> u64 = if self.allow_variable_time_computations {
441 |qi, a, b| unsafe { qi.add_vt(a, b) }
442 } else {
443 |qi, a, b| qi.add(a, b)
444 };
445 let reduce: unsafe fn(&Modulus, u64) -> u64 = if self.allow_variable_time_computations {
446 |qi, a| unsafe { qi.reduce_vt(a) }
447 } else {
448 |qi, a| qi.reduce(a)
449 };
450
451 q_last_poly
452 .iter_mut()
453 .for_each(|coeff| *coeff = add(q_last, *coeff, q_last_div_2));
454 izip!(
455 q_new_polys.outer_iter_mut(),
456 self.ctx.q.iter(),
457 self.ctx.inv_last_qi_mod_qj.iter(),
458 self.ctx.inv_last_qi_mod_qj_shoup.iter(),
459 )
460 .for_each(|(coeffs, qi, inv, inv_shoup)| {
461 let q_last_div_2_mod_qi = **qi - unsafe { reduce(qi, q_last_div_2) }; for (coeff, q_last_coeff) in izip!(coeffs, q_last_poly.iter()) {
463 let tmp = qi.lazy_reduce(*q_last_coeff) + q_last_div_2_mod_qi; *coeff += 3 * (**qi) - tmp; *coeff = qi.mul_shoup(*coeff, *inv, *inv_shoup);
472 }
473 });
474
475 if !self.allow_variable_time_computations {
477 q_last_poly.as_slice_mut().unwrap().zeroize();
478 }
479 self.coefficients.remove_index(Axis(0), q_len - 1);
480 self.ctx = next_context.clone();
481
482 Ok(())
483 }
484
485 pub fn switch_down_to(&mut self, context: &Arc<Context>) -> Result<()> {
491 let niterations = self.ctx.niterations_to(context)?;
492 for _ in 0..niterations {
493 self.switch_down()?;
494 }
495 assert_eq!(&self.ctx, context);
496 Ok(())
497 }
498
499 pub fn switch(&self, switcher: &Switcher) -> Result<Poly> {
502 switcher.switch(self)
503 }
504
505 pub fn scale(&self, scaler: &Scaler) -> Result<Poly> {
507 scaler.scale(self)
508 }
509
510 pub fn ctx(&self) -> &Arc<Context> {
512 &self.ctx
513 }
514
515 pub fn multiply_inverse_power_of_x(&mut self, power: usize) -> Result<()> {
517 if self.representation != Representation::PowerBasis {
518 return Err(Error::IncorrectRepresentation(
519 self.representation,
520 Representation::PowerBasis,
521 ));
522 }
523
524 let shift = ((self.ctx.degree << 1) - power) % (self.ctx.degree << 1);
525 let mask = self.ctx.degree - 1;
526 let mut new_coefficients = Array2::zeros((self.ctx.q.len(), self.ctx.degree));
527 izip!(
528 new_coefficients.outer_iter_mut(),
529 self.coefficients.outer_iter(),
530 self.ctx.q.iter()
531 )
532 .for_each(|(mut new_coeffs, orig_coeffs, qi)| {
533 for k in 0..self.ctx.degree {
534 let index = shift + k;
535 if index & self.ctx.degree == 0 {
536 new_coeffs[index & mask] = orig_coeffs[k];
537 } else {
538 new_coeffs[index & mask] = qi.neg(orig_coeffs[k]);
539 }
540 }
541 });
542 self.coefficients = new_coefficients;
543 Ok(())
544 }
545}
546
// Unit tests for `Poly`, `Representation` and `SubstitutionExponent`.
#[cfg(test)]
mod tests {
    use super::{switcher::Switcher, Context, Poly, Representation};
    use crate::{rq::SubstitutionExponent, zq::Modulus};
    use fhe_util::variance;
    use itertools::Itertools;
    use num_bigint::BigUint;
    use num_traits::{One, Zero};
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaCha8Rng;
    use std::{error::Error, sync::Arc};

    // Moduli used throughout the tests: a small modulus (1153) followed by
    // four 62-bit moduli.
    const MODULI: &[u64; 5] = &[
        1153,
        4611686018326724609,
        4611686018309947393,
        4611686018232352769,
        4611686018171535361,
    ];

    // Zero polynomials have all-zero coefficients in both representations,
    // but compare unequal because their representation tags differ.
    #[test]
    fn poly_zero() -> Result<(), Box<dyn Error>> {
        let reference = &[
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
            BigUint::zero(),
        ];

        for modulus in MODULI {
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            let p = Poly::zero(&ctx, Representation::PowerBasis);
            let q = Poly::zero(&ctx, Representation::Ntt);
            assert_ne!(p, q);
            assert_eq!(Vec::<u64>::from(&p), &[0; 16]);
            assert_eq!(Vec::<u64>::from(&q), &[0; 16]);
        }

        let ctx = Arc::new(Context::new(MODULI, 16)?);
        let p = Poly::zero(&ctx, Representation::PowerBasis);
        let q = Poly::zero(&ctx, Representation::Ntt);
        assert_ne!(p, q);
        assert_eq!(Vec::<u64>::from(&p), [0; 16 * MODULI.len()]);
        assert_eq!(Vec::<u64>::from(&q), [0; 16 * MODULI.len()]);
        assert_eq!(Vec::<BigUint>::from(&p), reference);
        assert_eq!(Vec::<BigUint>::from(&q), reference);

        Ok(())
    }

    // `ctx()` returns the context the polynomial was created with.
    #[test]
    fn ctx() -> Result<(), Box<dyn Error>> {
        for modulus in MODULI {
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            let p = Poly::zero(&ctx, Representation::PowerBasis);
            assert_eq!(p.ctx(), &ctx);
        }

        let ctx = Arc::new(Context::new(MODULI, 16)?);
        let p = Poly::zero(&ctx, Representation::PowerBasis);
        assert_eq!(p.ctx(), &ctx);

        Ok(())
    }

    // `random_from_seed` is deterministic in the seed. Different seeds, and
    // `random`, give different polynomials.
    #[test]
    fn random() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        for _ in 0..100 {
            let mut seed = <ChaCha8Rng as SeedableRng>::Seed::default();
            rand::rng().fill(&mut seed);

            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], 16)?);
                let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
                let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
                assert_eq!(p, q);
            }

            let ctx = Arc::new(Context::new(MODULI, 16)?);
            let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
            let q = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
            assert_eq!(p, q);

            rand::rng().fill(&mut seed);
            let p = Poly::random_from_seed(&ctx, Representation::Ntt, seed);
            assert_ne!(p, q);

            let r = Poly::random(&ctx, Representation::Ntt, &mut rng);
            assert_ne!(p, r);
            assert_ne!(q, r);
        }
        Ok(())
    }

    // The coefficient view matches the flattened `Vec<u64>` conversion.
    #[test]
    fn coefficients() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        for _ in 0..50 {
            for modulus in MODULI {
                let ctx = Arc::new(Context::new(&[*modulus], 16)?);
                let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
                let p_coefficients = Vec::<u64>::from(&p);
                assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap())
            }

            let ctx = Arc::new(Context::new(MODULI, 16)?);
            let p = Poly::random(&ctx, Representation::Ntt, &mut rng);
            let p_coefficients = Vec::<u64>::from(&p);
            assert_eq!(p_coefficients, p.coefficients().as_slice().unwrap())
        }
        Ok(())
    }

    // The context modulus is the product of its moduli.
    #[test]
    fn modulus() -> Result<(), Box<dyn Error>> {
        for modulus in MODULI {
            let modulus_biguint = BigUint::from(*modulus);
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            assert_eq!(ctx.modulus(), &modulus_biguint)
        }

        let mut modulus_biguint = BigUint::one();
        MODULI.iter().for_each(|m| modulus_biguint *= *m);
        let ctx = Arc::new(Context::new(MODULI, 16)?);
        assert_eq!(ctx.modulus(), &modulus_biguint);

        Ok(())
    }

    // The variable-time flag can be toggled, survives cloning, and is set on
    // the result of arithmetic whenever the other operand allows it.
    #[test]
    fn allow_variable_time_computations() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        for modulus in MODULI {
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
            assert!(!p.allow_variable_time_computations);

            unsafe { p.allow_variable_time_computations() }
            assert!(p.allow_variable_time_computations);

            let q = p.clone();
            assert!(q.allow_variable_time_computations);

            p.disallow_variable_time_computations();
            assert!(!p.allow_variable_time_computations);
        }

        let ctx = Arc::new(Context::new(MODULI, 16)?);
        let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
        assert!(!p.allow_variable_time_computations);

        unsafe { p.allow_variable_time_computations() }
        assert!(p.allow_variable_time_computations);

        let q = p.clone();
        assert!(q.allow_variable_time_computations);

        // The flag propagates through *=, += and -= and negation.
        let mut p = Poly::random(&ctx, Representation::Ntt, &mut rng);
        unsafe { p.allow_variable_time_computations() }
        let mut q = Poly::random(&ctx, Representation::Ntt, &mut rng);

        assert!(!q.allow_variable_time_computations);
        q *= &p;
        assert!(q.allow_variable_time_computations);

        q.disallow_variable_time_computations();
        q += &p;
        assert!(q.allow_variable_time_computations);

        q.disallow_variable_time_computations();
        q -= &p;
        assert!(q.allow_variable_time_computations);

        q = -&p;
        assert!(q.allow_variable_time_computations);

        Ok(())
    }

    // Round trips through every representation restore the original
    // polynomial, and the Shoup table exists only in NttShoup.
    #[test]
    fn change_representation() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ctx = Arc::new(Context::new(MODULI, 16)?);

        let mut p = Poly::random(&ctx, Representation::default(), &mut rng);
        assert_eq!(p.representation, Representation::default());
        assert_eq!(p.representation(), &Representation::default());

        p.change_representation(Representation::PowerBasis);
        assert_eq!(p.representation, Representation::PowerBasis);
        assert_eq!(p.representation(), &Representation::PowerBasis);
        assert!(p.coefficients_shoup.is_none());
        let q = p.clone();

        p.change_representation(Representation::Ntt);
        assert_eq!(p.representation, Representation::Ntt);
        assert_eq!(p.representation(), &Representation::Ntt);
        assert_ne!(p.coefficients, q.coefficients);
        assert!(p.coefficients_shoup.is_none());
        let q_ntt = p.clone();

        p.change_representation(Representation::NttShoup);
        assert_eq!(p.representation, Representation::NttShoup);
        assert_eq!(p.representation(), &Representation::NttShoup);
        assert_ne!(p.coefficients, q.coefficients);
        assert!(p.coefficients_shoup.is_some());
        let q_ntt_shoup = p.clone();

        p.change_representation(Representation::PowerBasis);
        assert_eq!(p, q);

        p.change_representation(Representation::NttShoup);
        assert_eq!(p, q_ntt_shoup);

        p.change_representation(Representation::Ntt);
        assert_eq!(p, q_ntt);

        p.change_representation(Representation::PowerBasis);
        assert_eq!(p, q);

        Ok(())
    }

    // Overriding changes only the tag (coefficients untouched) while keeping
    // the Shoup table consistent with the new tag.
    #[test]
    fn override_representation() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ctx = Arc::new(Context::new(MODULI, 16)?);

        let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
        assert_eq!(p.representation(), &p.representation);
        let q = p.clone();

        unsafe { p.override_representation(Representation::Ntt) }
        assert_eq!(p.representation, Representation::Ntt);
        assert_eq!(p.representation(), &p.representation);
        assert_eq!(p.coefficients, q.coefficients);
        assert!(p.coefficients_shoup.is_none());

        unsafe { p.override_representation(Representation::NttShoup) }
        assert_eq!(p.representation, Representation::NttShoup);
        assert_eq!(p.representation(), &p.representation);
        assert_eq!(p.coefficients, q.coefficients);
        assert!(p.coefficients_shoup.is_some());

        unsafe { p.override_representation(Representation::PowerBasis) }
        assert_eq!(p, q);

        unsafe { p.override_representation(Representation::NttShoup) }
        assert!(p.coefficients_shoup.is_some());

        unsafe { p.override_representation(Representation::Ntt) }
        assert!(p.coefficients_shoup.is_none());

        Ok(())
    }

    // Variances outside 1..=16 are rejected. Sampled coefficients are bounded
    // by 2 * variance in absolute value, and on a large degree the empirical
    // variance matches the requested one.
    #[test]
    fn small() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        for modulus in MODULI {
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            let q = Modulus::new(*modulus).unwrap();

            let e = Poly::small(&ctx, Representation::PowerBasis, 0, &mut rng);
            assert!(e.is_err());
            assert_eq!(
                e.unwrap_err().to_string(),
                "The variance should be an integer between 1 and 16"
            );
            let e = Poly::small(&ctx, Representation::PowerBasis, 17, &mut rng);
            assert!(e.is_err());
            assert_eq!(
                e.unwrap_err().to_string(),
                "The variance should be an integer between 1 and 16"
            );

            for i in 1..=16 {
                let p = Poly::small(&ctx, Representation::PowerBasis, i, &mut rng)?;
                let coefficients = p.coefficients().to_slice().unwrap();
                let v = unsafe { q.center_vec_vt(coefficients) };

                assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 2 * i as i64);
            }
        }

        // Large degree (2^18) so that the empirical variance is reliable.
        let ctx = Arc::new(Context::new(&[4611686018326724609], 1 << 18)?);
        let q = Modulus::new(4611686018326724609).unwrap();
        let mut rng = rand::rng();
        let p = Poly::small(&ctx, Representation::PowerBasis, 16, &mut rng)?;
        let coefficients = p.coefficients().to_slice().unwrap();
        let v = unsafe { q.center_vec_vt(coefficients) };
        assert!(v.iter().map(|vi| vi.abs()).max().unwrap() <= 32);
        assert_eq!(variance(&v).round(), 16.0);

        Ok(())
    }

    // Substitution X -> X^k agrees across representations, matches the
    // explicit power-basis formula, and composes (3 * 11 = 33 = 1 mod 32).
    #[test]
    fn substitute() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        for modulus in MODULI {
            let ctx = Arc::new(Context::new(&[*modulus], 16)?);
            let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut p_ntt = p.clone();
            p_ntt.change_representation(Representation::Ntt);
            let mut p_ntt_shoup = p.clone();
            p_ntt_shoup.change_representation(Representation::NttShoup);
            let p_coeffs = Vec::<u64>::from(&p);

            // Exponents that are even modulo 2 * degree = 32 are rejected.
            assert!(SubstitutionExponent::new(&ctx, 0).is_err());
            assert!(SubstitutionExponent::new(&ctx, 2).is_err());
            assert!(SubstitutionExponent::new(&ctx, 16).is_err());

            // The exponent 1 is the identity in every representation.
            assert_eq!(p, p.substitute(&SubstitutionExponent::new(&ctx, 1)?)?);
            assert_eq!(
                p_ntt,
                p_ntt.substitute(&SubstitutionExponent::new(&ctx, 1)?)?
            );
            assert_eq!(
                p_ntt_shoup,
                p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 1)?)?
            );

            // In the power basis, coefficient i moves to 3i mod 16 and is
            // negated when 3i mod 32 >= 16.
            let mut q = p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
            let mut v = vec![0u64; 16];
            for i in 0..16 {
                v[(3 * i) % 16] = if ((3 * i) / 16) & 1 == 1 && p_coeffs[i] > 0 {
                    *modulus - p_coeffs[i]
                } else {
                    p_coeffs[i]
                };
            }
            assert_eq!(&Vec::<u64>::from(&q), &v);

            // The NTT-domain substitutions agree with the power-basis one.
            let q_ntt = p_ntt.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
            q.change_representation(Representation::Ntt);
            assert_eq!(q, q_ntt);

            let q_ntt_shoup = p_ntt_shoup.substitute(&SubstitutionExponent::new(&ctx, 3)?)?;
            q.change_representation(Representation::NttShoup);
            assert_eq!(q, q_ntt_shoup);

            // 3 * 11 = 1 mod 32, so substituting by 3 then 11 is the identity.
            assert_eq!(
                p,
                p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
            );
            assert_eq!(
                p_ntt,
                p_ntt
                    .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
            );
            assert_eq!(
                p_ntt_shoup,
                p_ntt_shoup
                    .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                    .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
            );
        }

        // Same composition check with multiple moduli.
        let ctx = Arc::new(Context::new(MODULI, 16)?);
        let p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
        let mut p_ntt = p.clone();
        p_ntt.change_representation(Representation::Ntt);
        let mut p_ntt_shoup = p.clone();
        p_ntt_shoup.change_representation(Representation::NttShoup);

        assert_eq!(
            p,
            p.substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
        );
        assert_eq!(
            p_ntt,
            p_ntt
                .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
        );
        assert_eq!(
            p_ntt_shoup,
            p_ntt_shoup
                .substitute(&SubstitutionExponent::new(&ctx, 3)?)?
                .substitute(&SubstitutionExponent::new(&ctx, 11)?)?
        );

        Ok(())
    }

    // `switch_down` rejects non-power-basis input. Each step matches
    // round(b * q' / q) mod q' computed with big integers.
    #[test]
    fn switch_down() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ntests = 100;
        let ctx = Arc::new(Context::new(MODULI, 16)?);

        for _ in 0..ntests {
            let e = Poly::random(&ctx, Representation::Ntt, &mut rng).switch_down();
            assert!(e.is_err());
            assert_eq!(
                e.unwrap_err(),
                crate::Error::IncorrectRepresentation(
                    Representation::Ntt,
                    Representation::PowerBasis
                )
            );

            let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
            let mut reference = Vec::<BigUint>::from(&p);
            let mut current_ctx = ctx.clone();
            assert_eq!(p.ctx, current_ctx);
            while current_ctx.next_context.is_some() {
                let denominator = current_ctx.modulus().clone();
                current_ctx = current_ctx.next_context.as_ref().unwrap().clone();
                let numerator = current_ctx.modulus().clone();
                assert!(p.switch_down().is_ok());
                assert_eq!(p.ctx, current_ctx);
                let p_biguint = Vec::<BigUint>::from(&p);
                assert_eq!(
                    p_biguint,
                    reference
                        .iter()
                        .map(
                            |b| (((b * &numerator) + (&denominator >> 1)) / &denominator)
                                % current_ctx.modulus()
                        )
                        .collect_vec()
                );
                reference.clone_from(&p_biguint);
            }
        }
        Ok(())
    }

    // Switching down several levels at once matches a single rounding
    // round(b * q2 / q1).
    #[test]
    fn switch_down_to() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ntests = 100;
        let ctx1 = Arc::new(Context::new(MODULI, 16)?);
        let ctx2 = Arc::new(Context::new(&MODULI[..2], 16)?);

        for _ in 0..ntests {
            let mut p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng);
            let reference = Vec::<BigUint>::from(&p);

            p.switch_down_to(&ctx2)?;

            assert_eq!(p.ctx, ctx2);
            assert_eq!(
                Vec::<BigUint>::from(&p),
                reference
                    .iter()
                    .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus())
                    .collect_vec()
            );
        }

        Ok(())
    }

    // Switching between unrelated contexts with a `Switcher` matches
    // round(b * q2 / q1).
    #[test]
    fn switch() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ntests = 100;
        let ctx1 = Arc::new(Context::new(&MODULI[..2], 16)?);
        let ctx2 = Arc::new(Context::new(&MODULI[3..], 16)?);
        let switcher = Switcher::new(&ctx1, &ctx2)?;
        for _ in 0..ntests {
            let p = Poly::random(&ctx1, Representation::PowerBasis, &mut rng);
            let reference = Vec::<BigUint>::from(&p);

            let q = p.switch(&switcher)?;

            assert_eq!(q.ctx, ctx2);
            assert_eq!(
                Vec::<BigUint>::from(&q),
                reference
                    .iter()
                    .map(|b| ((b * ctx2.modulus()) + (ctx1.modulus() >> 1)) / ctx1.modulus())
                    .collect_vec()
            );
        }
        Ok(())
    }

    // Multiplying by X^(-1) then X^(-(2n - 1)) is the identity (X^(2n) = 1).
    // Multiplying by X^(-n) negates every coefficient.
    #[test]
    fn mul_x_power() -> Result<(), Box<dyn Error>> {
        let mut rng = rand::rng();
        let ctx = Arc::new(Context::new(MODULI, 16)?);
        let e = Poly::random(&ctx, Representation::Ntt, &mut rng).multiply_inverse_power_of_x(1);
        assert!(e.is_err());
        assert_eq!(
            e.unwrap_err(),
            crate::Error::IncorrectRepresentation(Representation::Ntt, Representation::PowerBasis)
        );

        let mut p = Poly::random(&ctx, Representation::PowerBasis, &mut rng);
        let q = p.clone();

        p.multiply_inverse_power_of_x(0)?;
        assert_eq!(p, q);

        p.multiply_inverse_power_of_x(1)?;
        assert_ne!(p, q);

        p.multiply_inverse_power_of_x(2 * ctx.degree - 1)?;
        assert_eq!(p, q);

        p.multiply_inverse_power_of_x(ctx.degree)?;
        assert_eq!(
            Vec::<BigUint>::from(&p)
                .iter()
                .map(|c| ctx.modulus() - c)
                .collect_vec(),
            Vec::<BigUint>::from(&q)
        );

        Ok(())
    }
}