1use flint_sys::{self, deps, fmpz::*, fmpz_mod::*, fmpz_mod_poly::*};
2use rug::Integer;
3use rug_fft;
4use serde::de::Deserializer;
5use serde::ser::Serializer;
6use serde::{Deserialize, Serialize};
7
8use std::cmp::*;
9use std::fmt::{self, Debug, Display, Formatter};
10use std::mem::MaybeUninit;
11use std::ops::*;
12
13mod flint_rug_bridge;
14
15pub struct ModPoly {
27 raw: fmpz_mod_poly_struct,
28 ctx: fmpz_mod_ctx,
29 modulus: Integer,
30}
31
32impl ModPoly {
33 pub fn new(modulus: Integer) -> Self {
35 unsafe {
36 let mut raw = MaybeUninit::uninit();
37 let mut ctx = MaybeUninit::uninit();
38 let mut flint_modulus = flint_rug_bridge::int_to_fmpz(&modulus);
39 fmpz_mod_ctx_init(ctx.as_mut_ptr(), &flint_modulus);
40 fmpz_clear(&mut flint_modulus);
41 let ctx = ctx.assume_init();
42 fmpz_mod_poly_init(raw.as_mut_ptr(), &ctx as *const _ as *mut _);
43 ModPoly {
44 raw: raw.assume_init(),
45 ctx,
46 modulus,
47 }
48 }
49 }
50
51 pub fn from_int(modulus: Integer, mut constant: Integer) -> Self {
53 constant %= &modulus;
54 let mut this = ModPoly::new(modulus);
55 this.set_coefficient(0, &constant);
56 this
57 }
58
59 pub fn with_capacity(modulus: Integer, n: usize) -> Self {
61 unsafe {
62 let mut raw = MaybeUninit::uninit();
63 let mut flint_modulus = flint_rug_bridge::int_to_fmpz(&modulus);
64 let mut ctx = MaybeUninit::uninit();
65 fmpz_mod_ctx_init(ctx.as_mut_ptr(), &flint_modulus);
66 fmpz_clear(&mut flint_modulus);
67 let ctx = ctx.assume_init();
68 fmpz_mod_poly_init2(
69 raw.as_mut_ptr(),
70 n as deps::slong,
71 &ctx as *const _ as *mut _,
72 );
73 ModPoly {
74 raw: raw.assume_init(),
75 modulus,
76 ctx,
77 }
78 }
79 }
80
81 pub fn interpolate_from_mul_subgroup(mut ys: Vec<Integer>, m: Integer, w: &Integer) -> Self {
108 rug_fft::cooley_tukey_radix_2_intt(&mut ys, &m, w);
109 let mut p = ModPoly::with_capacity(m, ys.len());
110 for (i, c) in ys.iter().enumerate() {
111 p.set_coefficient(i, c);
112 }
113 p
114 }
115
116 pub fn evaluate_over_mul_subgroup(&self, w: &Integer, n: usize) -> Vec<Integer> {
141 let mut cs: Vec<Integer> = (0..n)
142 .into_iter()
143 .map(|i| self.get_coefficient(i))
144 .collect();
145 rug_fft::cooley_tukey_radix_2_ntt(&mut cs, &self.modulus, w);
146 cs
147 }
148
149 pub fn with_roots(xs: impl IntoIterator<Item = Integer>, m: &Integer) -> Self {
164 let mut ps = xs
165 .into_iter()
166 .map(|x| {
167 let mut p = ModPoly::new(m.clone());
168 p.set_coefficient_ui(1, 1);
169 p.set_coefficient(0, &-x);
170 p
171 })
172 .collect::<Vec<_>>();
173 while ps.len() > 1 {
174 for i in 0..(ps.len() / 2) {
175 let back = ps.pop().unwrap();
176 ps[i] *= &back;
177 }
178 }
179 ps.pop().unwrap_or_else(|| {
180 let mut p = ModPoly::new(m.clone());
181 p.set_coefficient_ui(0, 1);
182 p
183 })
184 }
185
186 pub fn reserve(&mut self, n: usize) {
189 unsafe {
190 fmpz_mod_poly_realloc(&mut self.raw, n as deps::slong, &mut self.ctx);
191 }
192 }
193
194 pub fn evaluate(&self, i: &Integer) -> Integer {
207 unsafe {
208 let mut in_ = flint_rug_bridge::int_to_fmpz(i);
209
210 let mut out = fmpz::default();
211 fmpz_init(&mut out);
212 fmpz_mod_poly_evaluate_fmpz(
213 &mut out,
214 &self.raw as *const _ as *mut _,
215 &mut in_,
216 &self.ctx as *const _ as *mut _,
217 );
218
219 let out_rug = flint_rug_bridge::fmpz_to_int(&out);
220 fmpz_clear(&mut in_);
221 fmpz_clear(&mut out);
222 out_rug
223 }
224 }
225
226 pub fn modulus(&self) -> &Integer {
228 &self.modulus
229 }
230
231 pub fn get_coefficient(&self, i: usize) -> Integer {
233 unsafe {
234 let mut c = fmpz::default();
235 fmpz_init(&mut c);
236 fmpz_mod_poly_get_coeff_fmpz(
237 &mut c,
238 &self.raw as *const _ as *mut _,
239 i as deps::slong,
240 &self.ctx as *const _ as *mut _,
241 );
242 let c_gmp = flint_rug_bridge::fmpz_to_int(&c);
243 fmpz_clear(&mut c);
244 c_gmp % &self.modulus
245 }
246 }
247
248 pub fn set_coefficient(&mut self, i: usize, c: &Integer) {
250 unsafe {
251 let mut c_flint = flint_rug_bridge::int_to_fmpz(c);
252 fmpz_mod_poly_set_coeff_fmpz(
253 &mut self.raw,
254 i as deps::slong,
255 &mut c_flint,
256 &mut self.ctx,
257 );
258 fmpz_clear(&mut c_flint);
259 }
260 }
261
262 pub fn set_coefficient_ui(&mut self, i: usize, c: usize) {
264 unsafe {
265 fmpz_mod_poly_set_coeff_ui(
266 &mut self.raw,
267 i as deps::slong,
268 c as deps::ulong,
269 &mut self.ctx,
270 );
271 }
272 }
273
274 pub fn len(&self) -> usize {
276 unsafe {
277 fmpz_mod_poly_length(
278 &self.raw as *const _ as *mut _,
279 &self.ctx as *const _ as *mut _,
280 ) as usize
281 }
282 }
283
284 pub fn neg(&mut self) {
286 unsafe {
287 fmpz_mod_poly_neg(&mut self.raw, &mut self.raw, &mut self.ctx);
288 }
289 }
290
291 pub fn add(&mut self, other: &Self) {
293 assert_eq!(self.modulus, other.modulus);
294 unsafe {
295 fmpz_mod_poly_add(
296 &mut self.raw,
297 &mut self.raw,
298 &other.raw as *const _ as *mut _,
299 &mut self.ctx,
300 );
301 }
302 }
303
304 pub fn sub(&mut self, other: &Self) {
306 assert_eq!(self.modulus, other.modulus);
307 unsafe {
308 fmpz_mod_poly_sub(
309 &mut self.raw,
310 &mut self.raw,
311 &other.raw as *const _ as *mut _,
312 &mut self.ctx,
313 );
314 }
315 }
316
317 pub fn sub_from(&mut self, other: &Self) {
319 assert_eq!(self.modulus, other.modulus);
320 unsafe {
321 fmpz_mod_poly_sub(
322 &mut self.raw,
323 &other.raw as *const _ as *mut _,
324 &mut self.raw,
325 &mut self.ctx,
326 );
327 }
328 }
329
330 pub fn mul(&mut self, other: &Self) {
332 assert_eq!(self.modulus, other.modulus);
333 unsafe {
334 fmpz_mod_poly_mul(
335 &mut self.raw,
336 &mut self.raw,
337 &other.raw as *const _ as *mut _,
338 &mut self.ctx,
339 );
340 }
341 }
342
343 pub fn divrem(&self, other: &Self) -> (ModPoly, ModPoly) {
349 assert_eq!(self.modulus, other.modulus);
350 let mut q = ModPoly::new(self.modulus.clone());
351 let mut r = ModPoly::new(self.modulus.clone());
352 unsafe {
353 fmpz_mod_poly_divrem(
354 &mut q.raw,
355 &mut r.raw,
356 &self.raw as *const _ as *mut _,
357 &other.raw as *const _ as *mut _,
358 &self.ctx as *const _ as *mut _,
359 );
360 }
361 (q, r)
362 }
363
364 pub fn div(&mut self, other: &Self) {
366 assert_eq!(self.modulus, other.modulus);
367 let mut r = ModPoly::new(self.modulus.clone());
368 unsafe {
369 fmpz_mod_poly_divrem(
370 &mut self.raw,
371 &mut r.raw,
372 &mut self.raw,
373 &other.raw as *const _ as *mut _,
374 &mut self.ctx,
375 );
376 }
377 }
378
379 pub fn div_from(&mut self, other: &Self) {
381 assert_eq!(self.modulus, other.modulus);
382 let mut r = ModPoly::new(self.modulus.clone());
383 unsafe {
384 fmpz_mod_poly_divrem(
385 &mut self.raw,
386 &mut r.raw,
387 &other.raw as *const _ as *mut _,
388 &mut self.raw,
389 &mut self.ctx,
390 );
391 }
392 }
393
394 pub fn rem(&mut self, other: &Self) {
396 assert_eq!(self.modulus, other.modulus);
397 let mut q = ModPoly::new(self.modulus.clone());
398 unsafe {
399 fmpz_mod_poly_divrem(
400 &mut q.raw,
401 &mut self.raw,
402 &mut self.raw,
403 &other.raw as *const _ as *mut _,
404 &mut self.ctx,
405 );
406 }
407 }
408
409 pub fn rem_from(&mut self, other: &Self) {
411 assert_eq!(self.modulus, other.modulus);
412 let mut q = ModPoly::new(self.modulus.clone());
413 unsafe {
414 fmpz_mod_poly_divrem(
415 &mut q.raw,
416 &mut self.raw,
417 &other.raw as *const _ as *mut _,
418 &mut self.raw,
419 &mut self.ctx,
420 );
421 }
422 }
423
424 pub fn sqr(&mut self) {
426 unsafe {
427 fmpz_mod_poly_sqr(&mut self.raw, &mut self.raw, &mut self.ctx);
428 }
429 }
430
431 pub fn xgcd(&self, other: &Self) -> (Self, Self, Self) {
433 assert_eq!(self.modulus, other.modulus);
434 let mut g = ModPoly::new(self.modulus.clone());
435 let mut s = ModPoly::new(self.modulus.clone());
436 let mut t = ModPoly::new(self.modulus.clone());
437 unsafe {
438 fmpz_mod_poly_xgcd(
439 &mut g.raw,
440 &mut s.raw,
441 &mut t.raw,
442 &self.raw as *const _ as *mut _,
443 &other.raw as *const _ as *mut _,
444 &self.ctx as *const _ as *mut _,
445 );
446 }
447 (g, s, t)
448 }
449
450 pub fn derivative(&self) -> Self {
452 let mut d_self = ModPoly::new(self.modulus.clone());
453 unsafe {
454 fmpz_mod_poly_derivative(
455 &mut d_self.raw,
456 &self.raw as *const _ as *mut _,
457 &self.ctx as *const _ as *mut _,
458 );
459 }
460 d_self
461 }
462}
463
464impl Clone for ModPoly {
465 fn clone(&self) -> Self {
466 let mut this = ModPoly::new(self.modulus.clone());
467 unsafe {
468 fmpz_mod_poly_set(
469 &mut this.raw,
470 &self.raw as *const _ as *mut _,
471 &self.ctx as *const _ as *mut _,
472 );
473 }
474 this
475 }
476}
477
478impl Drop for ModPoly {
479 fn drop(&mut self) {
480 unsafe {
481 fmpz_mod_poly_clear(&mut self.raw, &mut self.ctx);
482 fmpz_mod_ctx_clear(&mut self.ctx);
483 }
484 }
485}
486
487impl PartialEq<ModPoly> for ModPoly {
488 fn eq(&self, other: &ModPoly) -> bool {
489 unsafe {
490 fmpz_mod_poly_equal(
491 &self.raw as *const _ as *mut _,
492 &other.raw as *const _ as *mut _,
493 &self.ctx as *const _ as *mut _,
494 ) != 0
495 }
496 }
497}
498impl Eq for ModPoly {}
499
500impl Debug for ModPoly {
501 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
502 f.debug_struct("ModPoly")
503 .field("modulus", &self.modulus)
504 .field(
505 "coefficients",
506 &(0..self.len())
507 .map(|i| self.get_coefficient(i))
508 .collect::<Vec<_>>(),
509 )
510 .finish()
511 }
512}
513
514impl Display for ModPoly {
515 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
516 let n = self.len();
517 let mut first = true;
518 for i in 0..n {
519 let j = n - i - 1;
520 let c = self.get_coefficient(j);
521 if c != 0 {
522 if !first {
523 write!(f, " + ")?;
524 }
525 write!(f, "{}", c)?;
526 if j != 0 {
527 write!(f, "x^{}", j)?;
528 }
529 first = false;
530 }
531 }
532 if first {
533 write!(f, "0")?;
534 }
535 Ok(())
536 }
537}
538
// Implements a binary operator trait and its compound-assignment twin for a
// type with in-place inherent methods:
//
// * `$func(&mut self, rhs: &Self)` must compute `self = self op rhs`;
// * `$from_func(&mut self, lhs: &Self)` must compute `self = lhs op self`,
//   which lets `&a op b` reuse `b`'s storage instead of cloning `a`.
//
// Generated: `T op T`, `T op &T`, `&T op T`, `&T op &T`, `T op= T`, `T op= &T`.
// Only `&T op &T` clones (its left operand), since neither side can be consumed.
macro_rules! impl_self_binary {
    ($Big:ty,
     $func:ident,
     $from_func:ident,
     $Trait:ident { $method:ident },
     $TraitAssign:ident { $method_assign:ident }
    ) => {
        impl $Trait<$Big> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: $Big) -> $Big {
                self.$method_assign(rhs);
                self
            }
        }
        impl $Trait<&$Big> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: &$Big) -> $Big {
                self.$method_assign(rhs);
                self
            }
        }
        impl $Trait<$Big> for &$Big {
            type Output = $Big;
            #[inline]
            fn $method(self, mut rhs: $Big) -> $Big {
                // Store `self op rhs` into the owned right operand.
                <$Big>::$from_func(&mut rhs, self);
                rhs
            }
        }
        impl $Trait<&$Big> for &$Big {
            type Output = $Big;
            #[inline]
            fn $method(self, rhs: &$Big) -> $Big {
                // Both operands are borrowed, so the result needs fresh storage.
                let mut out = <$Big as Clone>::clone(self);
                <$Big>::$func(&mut out, rhs);
                out
            }
        }
        impl $TraitAssign<$Big> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: $Big) {
                <$Big>::$func(self, &rhs)
            }
        }
        impl $TraitAssign<&$Big> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: &$Big) {
                <$Big>::$func(self, rhs)
            }
        }
    };
}
589
// Implements mixed "big op base" operators by lifting the `$Base` operand into
// a `$Big` with `$lift_func(modulus, value)`. `$func` and `$from_func` have the
// same in-place meaning as in `impl_self_binary!`, and the modulus is taken
// from whichever operand is the `$Big`.
macro_rules! impl_int_binary {
    ($Big:ty,
     $Base:ty,
     $func:ident,
     $from_func:ident,
     $lift_func:ident,
     $Trait:ident { $method:ident },
     $TraitAssign:ident { $method_assign:ident }
    ) => {
        // `big op= base`
        impl $TraitAssign<$Base> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: $Base) {
                let lifted = <$Big>::$lift_func(self.modulus.clone(), rhs);
                <$Big>::$func(self, &lifted)
            }
        }
        // `big op base`: the same steps as `op=`, on an owned left operand.
        impl $Trait<$Base> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: $Base) -> $Big {
                <$Big as $TraitAssign<$Base>>::$method_assign(&mut self, rhs);
                self
            }
        }
        // `base op big`: the lifted scalar is the left operand.
        impl $Trait<$Big> for $Base {
            type Output = $Big;
            #[inline]
            fn $method(self, mut rhs: $Big) -> $Big {
                let lifted = <$Big>::$lift_func(rhs.modulus.clone(), self);
                <$Big>::$from_func(&mut rhs, &lifted);
                rhs
            }
        }
    };
}
629
630impl_self_binary!(ModPoly, add, add, Add { add }, AddAssign { add_assign });
631impl_int_binary!(
632 ModPoly,
633 Integer,
634 add,
635 add,
636 from_int,
637 Add { add },
638 AddAssign { add_assign }
639);
640impl_self_binary!(
641 ModPoly,
642 sub,
643 sub_from,
644 Sub { sub },
645 SubAssign { sub_assign }
646);
647impl_int_binary!(
648 ModPoly,
649 Integer,
650 sub,
651 sub_from,
652 from_int,
653 Sub { sub },
654 SubAssign { sub_assign }
655);
656impl_self_binary!(ModPoly, mul, mul, Mul { mul }, MulAssign { mul_assign });
657impl_int_binary!(
658 ModPoly,
659 Integer,
660 mul,
661 mul,
662 from_int,
663 Mul { mul },
664 MulAssign { mul_assign }
665);
666impl_self_binary!(
667 ModPoly,
668 div,
669 div_from,
670 Div { div },
671 DivAssign { div_assign }
672);
673impl_int_binary!(
674 ModPoly,
675 Integer,
676 div,
677 div_from,
678 from_int,
679 Div { div },
680 DivAssign { div_assign }
681);
682impl_self_binary!(
683 ModPoly,
684 rem,
685 rem_from,
686 Rem { rem },
687 RemAssign { rem_assign }
688);
689impl_int_binary!(
690 ModPoly,
691 Integer,
692 rem,
693 rem_from,
694 from_int,
695 Rem { rem },
696 RemAssign { rem_assign }
697);
698
699use std::convert::From;
700
701#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
703pub struct ModPolySer {
704 pub modulus: Integer,
705 pub coefficients: Vec<Integer>,
706}
707
708impl From<ModPolySer> for ModPoly {
709 fn from(other: ModPolySer) -> ModPoly {
710 let mut inner = ModPoly::new(other.modulus.clone());
711 for (i, c) in other.coefficients.into_iter().enumerate() {
712 inner.set_coefficient(i, &c);
713 }
714 inner
715 }
716}
717
718impl From<&ModPoly> for ModPolySer {
719 fn from(other: &ModPoly) -> ModPolySer {
720 let modulus = other.modulus().clone();
721 let coefficients = (0..(other.len()))
722 .into_iter()
723 .map(|i| other.get_coefficient(i).clone())
724 .collect();
725 ModPolySer {
726 modulus,
727 coefficients,
728 }
729 }
730}
731
732impl Serialize for ModPoly {
733 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
734 where
735 S: Serializer,
736 {
737 ModPolySer::from(self).serialize(serializer)
738 }
739}
740
741impl<'de> Deserialize<'de> for ModPoly {
742 fn deserialize<D>(deserializer: D) -> Result<ModPoly, D::Error>
743 where
744 D: Deserializer<'de>,
745 {
746 ModPolySer::deserialize(deserializer).map(ModPoly::from)
747 }
748}
749
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck;
    use quickcheck_macros;

    #[test]
    fn init() {
        let p = Integer::from(17);
        let f = ModPoly::new(p);
        assert_eq!(f.len(), 0);
    }

    #[test]
    fn from_const() {
        let p = Integer::from(17);
        let f = ModPoly::from_int(p, Integer::from(0));
        assert_eq!(f.len(), 0);
        assert_eq!(f.evaluate(&Integer::from(0)), Integer::from(0));
    }

    #[test]
    fn just_set() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p);
        f.set_coefficient_ui(0, 1);
        assert_eq!(f.len(), 1);
        f.set_coefficient_ui(5, 1);
        assert_eq!(f.len(), 6);
        // Zeroing the leading coefficient shrinks the length back down.
        f.set_coefficient_ui(5, 0);
        assert_eq!(f.len(), 1);
    }

    #[test]
    fn set_get() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p);
        f.set_coefficient_ui(0, 1);
        assert_eq!(f.get_coefficient(0), Integer::from(1));
        f.set_coefficient(5, &Integer::from(5));
        for i in 1..5 {
            assert_eq!(f.get_coefficient(i), Integer::from(0));
        }
        assert_eq!(f.get_coefficient(5), Integer::from(5));
    }

    #[test]
    fn add() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p.clone());
        f.set_coefficient_ui(0, 1);
        let mut g = ModPoly::new(p);
        g.set_coefficient_ui(3, 1);
        let h = f.clone() + g.clone();
        assert_eq!(h.get_coefficient(0), Integer::from(1));
        assert_eq!(h.get_coefficient(1), Integer::from(0));
        assert_eq!(h.get_coefficient(2), Integer::from(0));
        assert_eq!(h.get_coefficient(3), Integer::from(1));
        assert_eq!(h.len(), 4);
        assert_eq!(h, f.clone() + &g);
        assert_eq!(h, &f + g.clone());
        assert_eq!(h, g.clone() + Integer::from(1));
        assert_eq!(h, Integer::from(1) + g.clone());
    }

    #[test]
    fn sub() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p.clone());
        f.set_coefficient_ui(0, 1);
        let mut g = ModPoly::new(p);
        g.set_coefficient_ui(3, 1);
        // 1 - x^3: the x^3 coefficient is -1 = 16 (mod 17).
        let h = f.clone() - g.clone();
        assert_eq!(h.get_coefficient(0), Integer::from(1));
        assert_eq!(h.get_coefficient(1), Integer::from(0));
        assert_eq!(h.get_coefficient(2), Integer::from(0));
        assert_eq!(h.get_coefficient(3), Integer::from(16));
        assert_eq!(h.len(), 4);
        assert_eq!(h, f.clone() - &g);
        assert_eq!(h, &f - g.clone());
        assert_eq!(h, Integer::from(1) - g.clone());
    }

    #[test]
    fn mul() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p.clone());
        f.set_coefficient_ui(1, 2);
        let mut g = ModPoly::new(p);
        g.set_coefficient_ui(3, 1);
        let h = f.clone() * g.clone();
        assert_eq!(h.get_coefficient(0), Integer::from(0));
        assert_eq!(h.get_coefficient(1), Integer::from(0));
        assert_eq!(h.get_coefficient(2), Integer::from(0));
        assert_eq!(h.get_coefficient(3), Integer::from(0));
        assert_eq!(h.get_coefficient(4), Integer::from(2));
        assert_eq!(h.len(), 5);
        assert_eq!(h, f.clone() * &g);
        assert_eq!(h, &f * g.clone());
        assert_eq!(h, h.clone() * Integer::from(1));
        assert_eq!(h, Integer::from(1) * h.clone());
    }

    #[test]
    fn mul_wrap() {
        // (x^3 + 5) * 4 = 4x^3 + 20 = 4x^3 + 3 (mod 17).
        let p = Integer::from(17);
        let mut g = ModPoly::new(p);
        g.set_coefficient_ui(3, 1);
        g.set_coefficient_ui(0, 5);
        let h = g.clone() * Integer::from(4);
        assert_eq!(h.get_coefficient(0), Integer::from(3));
        assert_eq!(h.get_coefficient(1), Integer::from(0));
        assert_eq!(h.get_coefficient(2), Integer::from(0));
        assert_eq!(h.get_coefficient(3), Integer::from(4));
        assert_eq!(h.len(), 4);
    }

    #[test]
    fn div() {
        let p = Integer::from(17);
        let mut f = ModPoly::new(p.clone());
        f.set_coefficient_ui(1, 1);
        let mut g = ModPoly::new(p);
        g.set_coefficient_ui(3, 1);
        let h = g.clone() / f.clone();
        assert_eq!(h.get_coefficient(0), Integer::from(0));
        assert_eq!(h.get_coefficient(1), Integer::from(0));
        assert_eq!(h.get_coefficient(2), Integer::from(1));
        assert_eq!(h.len(), 3);
        assert_eq!(h, g.clone() / &f);
        assert_eq!(h, &g / f.clone());
        assert_eq!(h, h.clone() / Integer::from(1));
    }

    /// Interpolates `ys` over the subgroup generated by `w` mod `m` and checks
    /// the first `ys.len()` coefficients against `expected_cs`.
    fn test_interpolate_from_mul_subgroup(
        ys: Vec<isize>,
        m: usize,
        w: usize,
        expected_cs: Vec<isize>,
    ) {
        let n = ys.len();
        let p = ModPoly::interpolate_from_mul_subgroup(
            ys.into_iter().map(Integer::from).collect(),
            Integer::from(m),
            &Integer::from(w),
        );
        for i in 0..n {
            assert_eq!(
                p.get_coefficient(i),
                expected_cs[i],
                "Difference in coefficient {}: expected {} but got {}",
                i,
                expected_cs[i],
                p.get_coefficient(i)
            );
        }
    }

    #[test]
    fn interpolate_zero_mod_5() {
        test_interpolate_from_mul_subgroup(vec![0, 0, 0, 0], 5, 2, vec![0, 0, 0, 0]);
    }
    #[test]
    fn interpolate_const_mod_5() {
        test_interpolate_from_mul_subgroup(vec![3, 3, 3, 3], 5, 2, vec![3, 0, 0, 0]);
    }
    #[test]
    fn interpolate_line_mod_5() {
        test_interpolate_from_mul_subgroup(vec![1, 0, 3, 4], 5, 2, vec![2, 4, 0, 0]);
    }
    #[test]
    fn interpolate_poly_mod_5() {
        test_interpolate_from_mul_subgroup(vec![4, 0, 0, 0], 5, 2, vec![1, 1, 1, 1]);
    }

    /// Sixteen arbitrary `u32`s, used as evaluations over a 16-element subgroup.
    #[derive(Debug, Clone)]
    struct U32x16([u32; 16]);

    impl quickcheck::Arbitrary for U32x16 {
        fn arbitrary<G: quickcheck::Gen>(g: &mut G) -> Self {
            let mut a = [0u32; 16];
            for i in &mut a {
                *i = g.next_u32();
            }
            U32x16(a)
        }
    }

    /// Interpolating values over the order-16 subgroup generated by 3 mod 17
    /// and evaluating back must reproduce the original values.
    #[quickcheck_macros::quickcheck]
    fn test_interpolate_roundtrip(ys: U32x16) -> bool {
        let m = Integer::from(17);
        let w = Integer::from(3);
        let U32x16(mut ys) = ys;
        for i in &mut ys {
            *i %= 17;
        }
        let ys: Vec<Integer> = ys.iter().cloned().map(Integer::from).collect();
        let p = ModPoly::interpolate_from_mul_subgroup(ys.clone(), m.clone(), &w);
        let ys2 = p.evaluate_over_mul_subgroup(&w, 16);
        ys == ys2
    }

    /// For a polynomial with distinct roots, `gcd(p, p') = 1`, and the Bezout
    /// coefficients must reconstruct the gcd.
    fn test_derivative_xgcd(roots: Vec<isize>, m: Integer) {
        let p = ModPoly::with_roots(roots.into_iter().map(Integer::from), &m);
        let dp = p.derivative();
        let (g, s, t) = p.xgcd(&dp);
        assert_eq!(g.len(), 1);
        assert_eq!(g, p * s + dp * t);
    }

    #[test]
    fn test_xgcd() {
        test_derivative_xgcd(vec![0], Integer::from(17));
        test_derivative_xgcd(vec![0, 1], Integer::from(17));
        test_derivative_xgcd(vec![0, 1, 2], Integer::from(17));
        test_derivative_xgcd(vec![0, 4, 5], Integer::from(17));
    }

    #[test]
    #[ignore]
    fn bench_xgcd() {
        let bls_12_381_r = Integer::from_str_radix(
            "52435875175126190479447740508185965837690552500527637822603658699938581184513",
            10,
        )
        .unwrap();
        for log_n in 4..16 {
            let n = 1 << log_n;
            let roots: Vec<usize> = (0..n).collect();
            let p = ModPoly::with_roots(roots.into_iter().map(Integer::from), &bls_12_381_r);
            let dp = p.derivative();
            let start = std::time::Instant::now();
            let (g, _s, _t) = p.xgcd(&dp);
            let duration = start.elapsed();
            let nanos_per = duration.as_nanos() / n as u128;
            println!("{log_n:>2}: {n:>8}: {duration:>8.1?} {nanos_per}ns/deg");
            assert_eq!(g.len(), 1);
        }
    }
}