rug_polynomial/
lib.rs

1use flint_sys::{self, deps, fmpz::*, fmpz_mod::*, fmpz_mod_poly::*};
2use rug::Integer;
3use rug_fft;
4use serde::de::Deserializer;
5use serde::ser::Serializer;
6use serde::{Deserialize, Serialize};
7
8use std::cmp::*;
9use std::fmt::{self, Debug, Display, Formatter};
10use std::mem::MaybeUninit;
11use std::ops::*;
12
13mod flint_rug_bridge;
14
15/// A polynomial with modular integer coefficients.
16///
17/// That is, a member of `Z/nZ[X]`.
18///
19/// # Examples
20///
21/// See constructors:
22///
23///    * [`new`](fn.ModPoly.new)
24///    * [`interpolate_from_mul_subgroup`](fn.ModPoly.interpolate_from_mul_subgroup)
25///
26pub struct ModPoly {
27    raw: fmpz_mod_poly_struct,
28    ctx: fmpz_mod_ctx,
29    modulus: Integer,
30}
31
32impl ModPoly {
33    /// A new polynomial, equal to zero.
34    pub fn new(modulus: Integer) -> Self {
35        unsafe {
36            let mut raw = MaybeUninit::uninit();
37            let mut ctx = MaybeUninit::uninit();
38            let mut flint_modulus = flint_rug_bridge::int_to_fmpz(&modulus);
39            fmpz_mod_ctx_init(ctx.as_mut_ptr(), &flint_modulus);
40            fmpz_clear(&mut flint_modulus);
41            let ctx = ctx.assume_init();
42            fmpz_mod_poly_init(raw.as_mut_ptr(), &ctx as *const _ as *mut _);
43            ModPoly {
44                raw: raw.assume_init(),
45                ctx,
46                modulus,
47            }
48        }
49    }
50
51    /// A new polynomial, equal to `constant`.
52    pub fn from_int(modulus: Integer, mut constant: Integer) -> Self {
53        constant %= &modulus;
54        let mut this = ModPoly::new(modulus);
55        this.set_coefficient(0, &constant);
56        this
57    }
58
59    /// A new polynomial, equal to zero, with room for `n` coefficients.
60    pub fn with_capacity(modulus: Integer, n: usize) -> Self {
61        unsafe {
62            let mut raw = MaybeUninit::uninit();
63            let mut flint_modulus = flint_rug_bridge::int_to_fmpz(&modulus);
64            let mut ctx = MaybeUninit::uninit();
65            fmpz_mod_ctx_init(ctx.as_mut_ptr(), &flint_modulus);
66            fmpz_clear(&mut flint_modulus);
67            let ctx = ctx.assume_init();
68            fmpz_mod_poly_init2(
69                raw.as_mut_ptr(),
70                n as deps::slong,
71                &ctx as *const _ as *mut _,
72            );
73            ModPoly {
74                raw: raw.assume_init(),
75                modulus,
76                ctx,
77            }
78        }
79    }
80
81    /// Interpolate a polynomial which agrees with the given values over a multiplicative subgroup
82    /// of the prime field with modulus `m`.
83    ///
84    /// Let `n` be a power of two and the order of the multiplicative subgroup generated by `w`
85    /// modulo `m`. Let `ys` be a vector of values at `1`, `w`, `w^2`, ...
86    ///
87    /// Returns a polynomial `f`, such that for `i` in `0..n`, `f(w^i) = ys[i] mod m`.
88    ///
89    /// # Panics
90    ///
91    /// If `n` is not a power of two, or if `w` does not generate a subgroup of order `n`.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// use rug_polynomial::*;
97    /// use rug::Integer;
98    ///
99    /// let m = Integer::from(5);
100    /// let w = Integer::from(2);
101    /// let ys: Vec<Integer> = vec![2, 3, 0, 4].into_iter().map(Integer::from).collect();
102    /// let p = ModPoly::interpolate_from_mul_subgroup(ys, m, &w);
103    /// debug_assert_eq!(p.len(), 2);
104    /// debug_assert_eq!(p.get_coefficient(0), Integer::from(1));
105    /// debug_assert_eq!(p.get_coefficient(1), Integer::from(1));
106    /// ```
107    pub fn interpolate_from_mul_subgroup(mut ys: Vec<Integer>, m: Integer, w: &Integer) -> Self {
108        rug_fft::cooley_tukey_radix_2_intt(&mut ys, &m, w);
109        let mut p = ModPoly::with_capacity(m, ys.len());
110        for (i, c) in ys.iter().enumerate() {
111            p.set_coefficient(i, c);
112        }
113        p
114    }
115
116    /// Evaluate this polynomial over the multiplicative subgroup generated by `w`, of size `n`.
117    ///
118    /// Returns list of evaluations, over `{1, w, w^2, ... w^(2^n-1)}`.
119    ///
120    /// # Panics
121    ///
122    /// If `n` is not a power of two, or if `w` does not generate a subgroup of order `n`.
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// use rug_polynomial::*;
128    /// use rug::Integer;
129    ///
130    /// let m = Integer::from(5);
131    /// let w = Integer::from(2);
132    /// let ys: Vec<Integer> = vec![2, 3, 0, 4].into_iter().map(Integer::from).collect();
133    /// let mut p = ModPoly::new(m);
134    /// p.set_coefficient_ui(0, 1);
135    /// p.set_coefficient_ui(1, 1);
136    /// let vs = p.evaluate_over_mul_subgroup(&Integer::from(2), 4);
137    /// let vs: Vec<usize> = vs.into_iter().map(|i| i.to_usize().unwrap()).collect();
138    /// debug_assert_eq!(vs, vec![2, 3, 0, 4]);
139    /// ```
140    pub fn evaluate_over_mul_subgroup(&self, w: &Integer, n: usize) -> Vec<Integer> {
141        let mut cs: Vec<Integer> = (0..n)
142            .into_iter()
143            .map(|i| self.get_coefficient(i))
144            .collect();
145        rug_fft::cooley_tukey_radix_2_ntt(&mut cs, &self.modulus, w);
146        cs
147    }
148
149    /// Returns the minimal-degree monic polynomial with the given roots.
150    ///
151    /// # Example
152    ///
153    /// ```
154    /// use rug_polynomial::*;
155    /// use rug::Integer;
156    ///
157    /// let p = ModPoly::with_roots(vec![0, 1].into_iter().map(Integer::from), &Integer::from(5));
158    /// debug_assert_eq!(p.len(), 3);
159    /// debug_assert_eq!(p.get_coefficient(0), Integer::from(0));
160    /// debug_assert_eq!(p.get_coefficient(1), Integer::from(4));
161    /// debug_assert_eq!(p.get_coefficient(2), Integer::from(1));
162    /// ```
163    pub fn with_roots(xs: impl IntoIterator<Item = Integer>, m: &Integer) -> Self {
164        let mut ps = xs
165            .into_iter()
166            .map(|x| {
167                let mut p = ModPoly::new(m.clone());
168                p.set_coefficient_ui(1, 1);
169                p.set_coefficient(0, &-x);
170                p
171            })
172            .collect::<Vec<_>>();
173        while ps.len() > 1 {
174            for i in 0..(ps.len() / 2) {
175                let back = ps.pop().unwrap();
176                ps[i] *= &back;
177            }
178        }
179        ps.pop().unwrap_or_else(|| {
180            let mut p = ModPoly::new(m.clone());
181            p.set_coefficient_ui(0, 1);
182            p
183        })
184    }
185
186    /// Reallocates the polynomial to have room for `n` coefficients. Truncates the polynomial if
187    /// it has more than `n` coefficients.
188    pub fn reserve(&mut self, n: usize) {
189        unsafe {
190            fmpz_mod_poly_realloc(&mut self.raw, n as deps::slong, &mut self.ctx);
191        }
192    }
193
194    /// Evaluate the polynomial at the given input.
195    ///
196    /// # Example
197    ///
198    /// ```
199    /// use rug_polynomial::*;
200    /// use rug::Integer;
201    ///
202    /// let p = ModPoly::with_roots(vec![0, 1].into_iter().map(Integer::from), &Integer::from(5));
203    /// let y = p.evaluate(&Integer::from(3));
204    /// debug_assert_eq!(y, Integer::from(1));
205    /// ```
206    pub fn evaluate(&self, i: &Integer) -> Integer {
207        unsafe {
208            let mut in_ = flint_rug_bridge::int_to_fmpz(i);
209
210            let mut out = fmpz::default();
211            fmpz_init(&mut out);
212            fmpz_mod_poly_evaluate_fmpz(
213                &mut out,
214                &self.raw as *const _ as *mut _,
215                &mut in_,
216                &self.ctx as *const _ as *mut _,
217            );
218
219            let out_rug = flint_rug_bridge::fmpz_to_int(&out);
220            fmpz_clear(&mut in_);
221            fmpz_clear(&mut out);
222            out_rug
223        }
224    }
225
226    /// Get the modulus of this polynomial.
227    pub fn modulus(&self) -> &Integer {
228        &self.modulus
229    }
230
231    /// Get the `i`th coefficient
232    pub fn get_coefficient(&self, i: usize) -> Integer {
233        unsafe {
234            let mut c = fmpz::default();
235            fmpz_init(&mut c);
236            fmpz_mod_poly_get_coeff_fmpz(
237                &mut c,
238                &self.raw as *const _ as *mut _,
239                i as deps::slong,
240                &self.ctx as *const _ as *mut _,
241            );
242            let c_gmp = flint_rug_bridge::fmpz_to_int(&c);
243            fmpz_clear(&mut c);
244            c_gmp % &self.modulus
245        }
246    }
247
248    /// Set the `i`th coefficient to be `c`
249    pub fn set_coefficient(&mut self, i: usize, c: &Integer) {
250        unsafe {
251            let mut c_flint = flint_rug_bridge::int_to_fmpz(c);
252            fmpz_mod_poly_set_coeff_fmpz(
253                &mut self.raw,
254                i as deps::slong,
255                &mut c_flint,
256                &mut self.ctx,
257            );
258            fmpz_clear(&mut c_flint);
259        }
260    }
261
262    /// Set the `i`th coefficient to be `c`
263    pub fn set_coefficient_ui(&mut self, i: usize, c: usize) {
264        unsafe {
265            fmpz_mod_poly_set_coeff_ui(
266                &mut self.raw,
267                i as deps::slong,
268                c as deps::ulong,
269                &mut self.ctx,
270            );
271        }
272    }
273
274    /// The number of coefficients in the polynomial. One more than the degree.
275    pub fn len(&self) -> usize {
276        unsafe {
277            fmpz_mod_poly_length(
278                &self.raw as *const _ as *mut _,
279                &self.ctx as *const _ as *mut _,
280            ) as usize
281        }
282    }
283
284    /// `self = -self`
285    pub fn neg(&mut self) {
286        unsafe {
287            fmpz_mod_poly_neg(&mut self.raw, &mut self.raw, &mut self.ctx);
288        }
289    }
290
291    /// `self = self + other`
292    pub fn add(&mut self, other: &Self) {
293        assert_eq!(self.modulus, other.modulus);
294        unsafe {
295            fmpz_mod_poly_add(
296                &mut self.raw,
297                &mut self.raw,
298                &other.raw as *const _ as *mut _,
299                &mut self.ctx,
300            );
301        }
302    }
303
304    /// `self = self - other`
305    pub fn sub(&mut self, other: &Self) {
306        assert_eq!(self.modulus, other.modulus);
307        unsafe {
308            fmpz_mod_poly_sub(
309                &mut self.raw,
310                &mut self.raw,
311                &other.raw as *const _ as *mut _,
312                &mut self.ctx,
313            );
314        }
315    }
316
317    /// `self = other - self`
318    pub fn sub_from(&mut self, other: &Self) {
319        assert_eq!(self.modulus, other.modulus);
320        unsafe {
321            fmpz_mod_poly_sub(
322                &mut self.raw,
323                &other.raw as *const _ as *mut _,
324                &mut self.raw,
325                &mut self.ctx,
326            );
327        }
328    }
329
330    /// `self = self * other`
331    pub fn mul(&mut self, other: &Self) {
332        assert_eq!(self.modulus, other.modulus);
333        unsafe {
334            fmpz_mod_poly_mul(
335                &mut self.raw,
336                &mut self.raw,
337                &other.raw as *const _ as *mut _,
338                &mut self.ctx,
339            );
340        }
341    }
342
343    /// Find `q` and `r` such that `self = other * q + r` and `r` has degree less than `other`.
344    ///
345    /// ## Returns
346    ///
347    /// `(q, r)`
348    pub fn divrem(&self, other: &Self) -> (ModPoly, ModPoly) {
349        assert_eq!(self.modulus, other.modulus);
350        let mut q = ModPoly::new(self.modulus.clone());
351        let mut r = ModPoly::new(self.modulus.clone());
352        unsafe {
353            fmpz_mod_poly_divrem(
354                &mut q.raw,
355                &mut r.raw,
356                &self.raw as *const _ as *mut _,
357                &other.raw as *const _ as *mut _,
358                &self.ctx as *const _ as *mut _,
359            );
360        }
361        (q, r)
362    }
363
364    /// `self = self / other`
365    pub fn div(&mut self, other: &Self) {
366        assert_eq!(self.modulus, other.modulus);
367        let mut r = ModPoly::new(self.modulus.clone());
368        unsafe {
369            fmpz_mod_poly_divrem(
370                &mut self.raw,
371                &mut r.raw,
372                &mut self.raw,
373                &other.raw as *const _ as *mut _,
374                &mut self.ctx,
375            );
376        }
377    }
378
379    /// `self = other / self`
380    pub fn div_from(&mut self, other: &Self) {
381        assert_eq!(self.modulus, other.modulus);
382        let mut r = ModPoly::new(self.modulus.clone());
383        unsafe {
384            fmpz_mod_poly_divrem(
385                &mut self.raw,
386                &mut r.raw,
387                &other.raw as *const _ as *mut _,
388                &mut self.raw,
389                &mut self.ctx,
390            );
391        }
392    }
393
394    /// `self = self % other`
395    pub fn rem(&mut self, other: &Self) {
396        assert_eq!(self.modulus, other.modulus);
397        let mut q = ModPoly::new(self.modulus.clone());
398        unsafe {
399            fmpz_mod_poly_divrem(
400                &mut q.raw,
401                &mut self.raw,
402                &mut self.raw,
403                &other.raw as *const _ as *mut _,
404                &mut self.ctx,
405            );
406        }
407    }
408
409    /// `self = other % self`
410    pub fn rem_from(&mut self, other: &Self) {
411        assert_eq!(self.modulus, other.modulus);
412        let mut q = ModPoly::new(self.modulus.clone());
413        unsafe {
414            fmpz_mod_poly_divrem(
415                &mut q.raw,
416                &mut self.raw,
417                &other.raw as *const _ as *mut _,
418                &mut self.raw,
419                &mut self.ctx,
420            );
421        }
422    }
423
424    /// `self = self * self`
425    pub fn sqr(&mut self) {
426        unsafe {
427            fmpz_mod_poly_sqr(&mut self.raw, &mut self.raw, &mut self.ctx);
428        }
429    }
430
431    /// From `(a, b)`, returns `(g, s, t)` such that `g | a`, `g | b` and `g = a*s + b*t`.
432    pub fn xgcd(&self, other: &Self) -> (Self, Self, Self) {
433        assert_eq!(self.modulus, other.modulus);
434        let mut g = ModPoly::new(self.modulus.clone());
435        let mut s = ModPoly::new(self.modulus.clone());
436        let mut t = ModPoly::new(self.modulus.clone());
437        unsafe {
438            fmpz_mod_poly_xgcd(
439                &mut g.raw,
440                &mut s.raw,
441                &mut t.raw,
442                &self.raw as *const _ as *mut _,
443                &other.raw as *const _ as *mut _,
444                &self.ctx as *const _ as *mut _,
445            );
446        }
447        (g, s, t)
448    }
449
450    /// Give the formal derivative of `self`.
451    pub fn derivative(&self) -> Self {
452        let mut d_self = ModPoly::new(self.modulus.clone());
453        unsafe {
454            fmpz_mod_poly_derivative(
455                &mut d_self.raw,
456                &self.raw as *const _ as *mut _,
457                &self.ctx as *const _ as *mut _,
458            );
459        }
460        d_self
461    }
462}
463
464impl Clone for ModPoly {
465    fn clone(&self) -> Self {
466        let mut this = ModPoly::new(self.modulus.clone());
467        unsafe {
468            fmpz_mod_poly_set(
469                &mut this.raw,
470                &self.raw as *const _ as *mut _,
471                &self.ctx as *const _ as *mut _,
472            );
473        }
474        this
475    }
476}
477
478impl Drop for ModPoly {
479    fn drop(&mut self) {
480        unsafe {
481            fmpz_mod_poly_clear(&mut self.raw, &mut self.ctx);
482            fmpz_mod_ctx_clear(&mut self.ctx);
483        }
484    }
485}
486
487impl PartialEq<ModPoly> for ModPoly {
488    fn eq(&self, other: &ModPoly) -> bool {
489        unsafe {
490            fmpz_mod_poly_equal(
491                &self.raw as *const _ as *mut _,
492                &other.raw as *const _ as *mut _,
493                &self.ctx as *const _ as *mut _,
494            ) != 0
495        }
496    }
497}
498impl Eq for ModPoly {}
499
500impl Debug for ModPoly {
501    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
502        f.debug_struct("ModPoly")
503            .field("modulus", &self.modulus)
504            .field(
505                "coefficients",
506                &(0..self.len())
507                    .map(|i| self.get_coefficient(i))
508                    .collect::<Vec<_>>(),
509            )
510            .finish()
511    }
512}
513
514impl Display for ModPoly {
515    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
516        let n = self.len();
517        let mut first = true;
518        for i in 0..n {
519            let j = n - i - 1;
520            let c = self.get_coefficient(j);
521            if c != 0 {
522                if !first {
523                    write!(f, " + ")?;
524                }
525                write!(f, "{}", c)?;
526                if j != 0 {
527                    write!(f, "x^{}", j)?;
528                }
529                first = false;
530            }
531        }
532        if first {
533            write!(f, "0")?;
534        }
535        Ok(())
536    }
537}
538
// Implements the binary operator `$Trait` and its compound-assignment form `$TraitAssign` for
// `$Big`, in terms of two inherent in-place methods on `$Big`:
//
// * `$func(&mut self, rhs: &Self)` computes `self = self op rhs`;
// * `$from_func(&mut self, lhs: &Self)` computes `self = lhs op self`, which lets `&a op b`
//   reuse `b`'s storage instead of cloning `a`.
//
// All four owned/borrowed operand combinations are covered; `&a op &b` clones `a`, so `$Big`
// must implement `Clone`.
macro_rules! impl_self_binary {
    ($Big:ty,
     $func:ident,
     $from_func:ident,
     $Trait:ident { $method:ident },
     $TraitAssign:ident { $method_assign:ident }
    ) => {
        // Big op Big
        impl $Trait<$Big> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: $Big) -> $Big {
                self.$method_assign(rhs);
                self
            }
        }
        // Big op &Big
        impl $Trait<&$Big> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: &$Big) -> $Big {
                self.$method_assign(rhs);
                self
            }
        }
        // &Big op Big: the result is written into `rhs`.
        impl $Trait<$Big> for &$Big {
            type Output = $Big;
            #[inline]
            fn $method(self, mut rhs: $Big) -> $Big {
                <$Big>::$from_func(&mut rhs, self);
                rhs
            }
        }
        // &Big op &Big: neither operand is consumed, so the left one is cloned.
        impl $Trait<&$Big> for &$Big {
            type Output = $Big;
            #[inline]
            fn $method(self, rhs: &$Big) -> $Big {
                let mut out = self.clone();
                <$Big>::$func(&mut out, rhs);
                out
            }
        }
        // Big op= Big
        impl $TraitAssign<$Big> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: $Big) {
                <$Big>::$func(self, &rhs)
            }
        }
        // Big op= &Big
        impl $TraitAssign<&$Big> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: &$Big) {
                <$Big>::$func(self, rhs)
            }
        }
    };
}
589
// Implements `$Trait` / `$TraitAssign` between `$Big` and a scalar `$Base`. The scalar is lifted
// into a `$Big` with `$lift_func(modulus, scalar)` (taking the modulus from the `$Big` operand),
// then combined with the in-place methods:
//
// * `$func(&mut self, rhs: &Self)` computes `self = self op rhs`;
// * `$from_func(&mut self, lhs: &Self)` computes `self = lhs op self`.
//
// Scalars are taken by value. `&Big op Base` clones the polynomial, so `$Big` must implement
// `Clone`.
macro_rules! impl_int_binary {
    ($Big:ty,
     $Base:ty,
     $func:ident,
     $from_func:ident,
     $lift_func:ident,
     $Trait:ident { $method:ident },
     $TraitAssign:ident { $method_assign:ident }
    ) => {
        // Big op Base
        impl $Trait<$Base> for $Big {
            type Output = $Big;
            #[inline]
            fn $method(mut self, rhs: $Base) -> $Big {
                let rhs = <$Big>::$lift_func(self.modulus.clone(), rhs);
                <$Big>::$func(&mut self, &rhs);
                self
            }
        }
        // &Big op Base
        impl $Trait<$Base> for &$Big {
            type Output = $Big;
            #[inline]
            fn $method(self, rhs: $Base) -> $Big {
                let rhs = <$Big>::$lift_func(self.modulus.clone(), rhs);
                let mut out = self.clone();
                <$Big>::$func(&mut out, &rhs);
                out
            }
        }
        // Base op Big: the result is written into `rhs`.
        impl $Trait<$Big> for $Base {
            type Output = $Big;
            #[inline]
            fn $method(self, mut rhs: $Big) -> $Big {
                let lhs = <$Big>::$lift_func(rhs.modulus.clone(), self);
                <$Big>::$from_func(&mut rhs, &lhs);
                rhs
            }
        }
        // Big op= Base
        impl $TraitAssign<$Base> for $Big {
            #[inline]
            fn $method_assign(&mut self, rhs: $Base) {
                let rhs = <$Big>::$lift_func(self.modulus.clone(), rhs);
                <$Big>::$func(self, &rhs)
            }
        }
    };
}
629
630impl_self_binary!(ModPoly, add, add, Add { add }, AddAssign { add_assign });
631impl_int_binary!(
632    ModPoly,
633    Integer,
634    add,
635    add,
636    from_int,
637    Add { add },
638    AddAssign { add_assign }
639);
640impl_self_binary!(
641    ModPoly,
642    sub,
643    sub_from,
644    Sub { sub },
645    SubAssign { sub_assign }
646);
647impl_int_binary!(
648    ModPoly,
649    Integer,
650    sub,
651    sub_from,
652    from_int,
653    Sub { sub },
654    SubAssign { sub_assign }
655);
656impl_self_binary!(ModPoly, mul, mul, Mul { mul }, MulAssign { mul_assign });
657impl_int_binary!(
658    ModPoly,
659    Integer,
660    mul,
661    mul,
662    from_int,
663    Mul { mul },
664    MulAssign { mul_assign }
665);
666impl_self_binary!(
667    ModPoly,
668    div,
669    div_from,
670    Div { div },
671    DivAssign { div_assign }
672);
673impl_int_binary!(
674    ModPoly,
675    Integer,
676    div,
677    div_from,
678    from_int,
679    Div { div },
680    DivAssign { div_assign }
681);
682impl_self_binary!(
683    ModPoly,
684    rem,
685    rem_from,
686    Rem { rem },
687    RemAssign { rem_assign }
688);
689impl_int_binary!(
690    ModPoly,
691    Integer,
692    rem,
693    rem_from,
694    from_int,
695    Rem { rem },
696    RemAssign { rem_assign }
697);
698
699use std::convert::From;
700
/// Serializable modular polynomial
///
/// Plain-data mirror of [`ModPoly`] used for (de)serialization: the modulus plus the coefficient
/// list. Converted to and from `ModPoly` via the `From` impls below.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ModPolySer {
    /// The modulus of the coefficient ring.
    pub modulus: Integer,
    /// Coefficients, lowest degree first: index `i` holds the coefficient of `x^i`.
    pub coefficients: Vec<Integer>,
}
707
708impl From<ModPolySer> for ModPoly {
709    fn from(other: ModPolySer) -> ModPoly {
710        let mut inner = ModPoly::new(other.modulus.clone());
711        for (i, c) in other.coefficients.into_iter().enumerate() {
712            inner.set_coefficient(i, &c);
713        }
714        inner
715    }
716}
717
718impl From<&ModPoly> for ModPolySer {
719    fn from(other: &ModPoly) -> ModPolySer {
720        let modulus = other.modulus().clone();
721        let coefficients = (0..(other.len()))
722            .into_iter()
723            .map(|i| other.get_coefficient(i).clone())
724            .collect();
725        ModPolySer {
726            modulus,
727            coefficients,
728        }
729    }
730}
731
732impl Serialize for ModPoly {
733    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
734    where
735        S: Serializer,
736    {
737        ModPolySer::from(self).serialize(serializer)
738    }
739}
740
741impl<'de> Deserialize<'de> for ModPoly {
742    fn deserialize<D>(deserializer: D) -> Result<ModPoly, D::Error>
743    where
744        D: Deserializer<'de>,
745    {
746        ModPolySer::deserialize(deserializer).map(ModPoly::from)
747    }
748}
749
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck;
    use quickcheck_macros;

    /// The prime modulus most tests below work over.
    fn p17() -> Integer {
        Integer::from(17)
    }

    /// Builds a polynomial modulo `m` from `(index, coefficient)` pairs.
    fn poly_from_terms(m: &Integer, terms: &[(usize, usize)]) -> ModPoly {
        let mut out = ModPoly::new(m.clone());
        for &(index, coefficient) in terms {
            out.set_coefficient_ui(index, coefficient);
        }
        out
    }

    /// The first `count` coefficients of `poly`, lowest degree first.
    fn first_coefficients(poly: &ModPoly, count: usize) -> Vec<Integer> {
        (0..count).map(|i| poly.get_coefficient(i)).collect()
    }

    /// Lifts machine integers into `rug` integers.
    fn to_integers(values: &[isize]) -> Vec<Integer> {
        values.iter().map(|&v| Integer::from(v)).collect()
    }

    #[test]
    fn init() {
        let zero = ModPoly::new(p17());
        assert_eq!(zero.len(), 0);
    }

    #[test]
    fn from_const() {
        let zero = ModPoly::from_int(p17(), Integer::from(0));
        assert_eq!(zero.len(), 0);
        assert_eq!(zero.evaluate(&Integer::from(0)), Integer::from(0));
    }

    #[test]
    fn just_set() {
        let mut f = ModPoly::new(p17());
        // (index, value, expected length afterwards)
        let steps: [(usize, usize, usize); 3] = [(0, 1, 1), (5, 1, 6), (5, 0, 1)];
        for &(index, value, expected_len) in steps.iter() {
            f.set_coefficient_ui(index, value);
            assert_eq!(f.len(), expected_len);
        }
    }

    #[test]
    fn set_get() {
        let mut f = ModPoly::new(p17());
        f.set_coefficient_ui(0, 1);
        assert_eq!(f.get_coefficient(0), Integer::from(1));
        f.set_coefficient(5, &Integer::from(5));
        for gap in 1..5 {
            assert_eq!(f.get_coefficient(gap), Integer::from(0));
        }
        assert_eq!(f.get_coefficient(5), Integer::from(5));
    }

    #[test]
    fn add() {
        let m = p17();
        let f = poly_from_terms(&m, &[(0, 1)]);
        let g = poly_from_terms(&m, &[(3, 1)]);
        let h = f.clone() + g.clone();
        assert_eq!(first_coefficients(&h, 4), to_integers(&[1, 0, 0, 1]));
        assert_eq!(h.len(), 4);
        assert_eq!(h, f.clone() + &g);
        assert_eq!(h, &f + g.clone());
        assert_eq!(h, g.clone() + Integer::from(1));
        assert_eq!(h, Integer::from(1) + g.clone());
    }

    #[test]
    fn sub() {
        let m = p17();
        let f = poly_from_terms(&m, &[(0, 1)]);
        let g = poly_from_terms(&m, &[(3, 1)]);
        let h = f.clone() - g.clone();
        assert_eq!(first_coefficients(&h, 4), to_integers(&[1, 0, 0, 16]));
        assert_eq!(h.len(), 4);
        assert_eq!(h, f.clone() - &g);
        assert_eq!(h, &f - g.clone());
        assert_eq!(h, Integer::from(1) - g.clone());
    }

    #[test]
    fn mul() {
        let m = p17();
        let f = poly_from_terms(&m, &[(1, 2)]);
        let g = poly_from_terms(&m, &[(3, 1)]);
        let h = f.clone() * g.clone();
        assert_eq!(first_coefficients(&h, 5), to_integers(&[0, 0, 0, 0, 2]));
        assert_eq!(h.len(), 5);
        assert_eq!(h, f.clone() * &g);
        assert_eq!(h, &f * g.clone());
        assert_eq!(h, h.clone() * Integer::from(1));
        assert_eq!(h, Integer::from(1) * h.clone());
    }

    #[test]
    fn mul_wrap() {
        let g = poly_from_terms(&p17(), &[(3, 1), (0, 5)]);
        let h = g.clone() * Integer::from(4);
        assert_eq!(first_coefficients(&h, 4), to_integers(&[3, 0, 0, 4]));
        assert_eq!(h.len(), 4);
    }

    #[test]
    fn div() {
        let m = p17();
        let f = poly_from_terms(&m, &[(1, 1)]);
        let g = poly_from_terms(&m, &[(3, 1)]);
        let h = g.clone() / f.clone();
        assert_eq!(first_coefficients(&h, 3), to_integers(&[0, 0, 1]));
        assert_eq!(h.len(), 3);
        assert_eq!(h, g.clone() / &f);
        assert_eq!(h, &g / f.clone());
        assert_eq!(h, h.clone() / Integer::from(1));
    }

    fn test_interpolate_from_mul_subgroup(
        ys: Vec<isize>,
        m: usize,
        w: usize,
        expected_cs: Vec<isize>,
    ) {
        let n = ys.len();
        let p = ModPoly::interpolate_from_mul_subgroup(
            ys.into_iter().map(Integer::from).collect(),
            Integer::from(m),
            &Integer::from(w),
        );
        for (i, &expected) in expected_cs[..n].iter().enumerate() {
            let actual = p.get_coefficient(i);
            assert_eq!(
                actual, expected,
                "Difference in coefficient {}: expected {} but got {}",
                i, expected, actual
            );
        }
    }

    #[test]
    fn interpolate_zero_mod_5() {
        test_interpolate_from_mul_subgroup(vec![0, 0, 0, 0], 5, 2, vec![0, 0, 0, 0]);
    }
    #[test]
    fn interpolate_const_mod_5() {
        test_interpolate_from_mul_subgroup(vec![3, 3, 3, 3], 5, 2, vec![3, 0, 0, 0]);
    }
    #[test]
    fn interpolate_line_mod_5() {
        test_interpolate_from_mul_subgroup(vec![1, 0, 3, 4], 5, 2, vec![2, 4, 0, 0]);
    }
    #[test]
    fn interpolate_poly_mod_5() {
        test_interpolate_from_mul_subgroup(vec![4, 0, 0, 0], 5, 2, vec![1, 1, 1, 1]);
    }

    /// Sixteen random 32-bit words, reduced mod 17 by the property below.
    #[derive(Debug, Clone)]
    struct Words16([u32; 16]);

    impl quickcheck::Arbitrary for Words16 {
        fn arbitrary<G: quickcheck::Gen>(g: &mut G) -> Self {
            let mut words = [0u32; 16];
            words.iter_mut().for_each(|word| *word = g.next_u32());
            Words16(words)
        }
    }

    /// Interpolating then evaluating over the order-16 subgroup generated by 3 (mod 17) must
    /// give back the original values.
    #[quickcheck_macros::quickcheck]
    fn test_interpolate_roundtrip(words: Words16) -> bool {
        let m = Integer::from(17);
        let w = Integer::from(3);
        let Words16(raw) = words;
        let ys: Vec<Integer> = raw.iter().map(|&y| Integer::from(y % 17)).collect();
        let p = ModPoly::interpolate_from_mul_subgroup(ys.clone(), m, &w);
        p.evaluate_over_mul_subgroup(&w, 16) == ys
    }

    /// A square-free polynomial is coprime to its derivative, and the Bezout identity holds.
    fn test_derivative_xgcd(roots: Vec<isize>, m: Integer) {
        let poly = ModPoly::with_roots(roots.into_iter().map(Integer::from), &m);
        let d_poly = poly.derivative();
        let (gcd, s, t) = poly.xgcd(&d_poly);
        assert_eq!(gcd.len(), 1);
        assert_eq!(gcd, poly * s + d_poly * t);
    }

    #[test]
    fn test_xgcd() {
        let cases: [&[isize]; 4] = [&[0], &[0, 1], &[0, 1, 2], &[0, 4, 5]];
        for roots in cases.iter() {
            test_derivative_xgcd(roots.to_vec(), Integer::from(17));
        }
    }

    #[test]
    #[ignore]
    fn bench_xgcd() {
        let bls_12_381_r = Integer::from_str_radix(
            "52435875175126190479447740508185965837690552500527637822603658699938581184513",
            10,
        )
        .unwrap();
        for log_n in 4..16 {
            let n: usize = 1 << log_n;
            let p = ModPoly::with_roots((0..n).map(Integer::from), &bls_12_381_r);
            let dp = p.derivative();
            let start = std::time::Instant::now();
            let (g, _s, _t) = p.xgcd(&dp);
            let duration = start.elapsed();
            let nanos_per = duration.as_nanos() / n as u128;
            println!("{log_n:>2}: {n:>8}: {duration:>8.1?} {nanos_per}ns/deg");
            assert_eq!(g.len(), 1);
        }
    }
}