competitive-programming-rs 41.0.0

Competitive Programming Library in Rust
Documentation
/// Number-theoretic transform (FFT over `Z / modulo Z`) used to convolve
/// integer sequences modulo a prime.
///
/// Build one with `FastFourierTransform::new(modulo)` and reuse it for many
/// `convolution` calls; the twiddle tables are computed only once.
#[derive(Debug, Clone)]
pub struct FastFourierTransform {
    /// The prime modulus all arithmetic is performed under.
    modulo: i64,
    /// Twiddle-advance factors consumed by the forward butterfly.
    sum_e: [i64; 30],
    /// Twiddle-advance factors consumed by the inverse butterfly.
    sum_ie: [i64; 30],
}

impl FastFourierTransform {
    /// Creates a number-theoretic transform context for the prime `modulo`.
    ///
    /// Precomputes the twiddle-advance tables used by the butterflies. The
    /// largest supported transform length is `2^k`, where `2^k` is the largest
    /// power of two dividing `modulo - 1`.
    ///
    /// # Panics
    ///
    /// Panics if `modulo < 2`. `modulo` is additionally assumed (not checked)
    /// to be prime, since inverses use Fermat's little theorem, and small
    /// enough that `(modulo - 1)^2` fits in an `i64`.
    pub fn new(modulo: i64) -> Self {
        assert!(modulo >= 2, "modulo must be a prime, got {}", modulo);
        let count2 = (modulo - 1).trailing_zeros();
        let (es, ies) = Self::roots_of_unity(modulo, count2);
        // The forward and inverse tables are built from the same roots with
        // the roles of the roots and their inverses swapped.
        let sum_e = Self::rotation_table(&es, &ies, count2, modulo);
        let sum_ie = Self::rotation_table(&ies, &es, count2, modulo);
        Self {
            modulo,
            sum_e,
            sum_ie,
        }
    }

    /// Returns `(es, ies)` where `es[i]` is a primitive `2^(i + 2)`-th root of
    /// unity modulo `modulo` (for `i + 2 <= count2`) and `ies[i]` is its
    /// inverse. Unused trailing entries are left as 0.
    fn roots_of_unity(modulo: i64, count2: u32) -> ([i64; 30], [i64; 30]) {
        let g = primitive_root(modulo);
        let mut es = [0; 30];
        let mut ies = [0; 30];
        // Start from a primitive 2^count2-th root and square it down.
        let mut e = mod_pow(g, (modulo - 1) >> count2, modulo);
        let mut ie = mod_inv(e, modulo);
        let top = count2 as usize;
        for i in (2..=top).rev() {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e = (e * e) % modulo;
            ie = (ie * ie) % modulo;
        }
        (es, ies)
    }

    /// Builds `table[i] = roots[i] * inv_roots[0] * ... * inv_roots[i - 1]`
    /// (mod `modulo`) for `i` in `0..count2 - 1`.
    fn rotation_table(
        roots: &[i64; 30],
        inv_roots: &[i64; 30],
        count2: u32,
        modulo: i64,
    ) -> [i64; 30] {
        let mut table = [0; 30];
        let mut now = 1;
        // `saturating_sub` keeps moduli whose `modulo - 1` has fewer than two
        // factors of two (e.g. 2 or 3) from underflowing; the table is then
        // simply empty.
        for i in 0..(count2 as usize).saturating_sub(1) {
            table[i] = (roots[i] * now) % modulo;
            now = (now * inv_roots[i]) % modulo;
        }
        table
    }

    /// In-place forward transform of `a`.
    ///
    /// `a.len()` must be a power of two no larger than the maximum supported
    /// by the modulus. Elements must already lie in `[0, modulo)`. The result
    /// is intended to be combined pointwise and fed to `butterfly_inv`, as
    /// `convolution` does.
    fn butterfly(&self, a: &mut [i64]) {
        let h = a.len().next_power_of_two().trailing_zeros();
        for ph in 1..=h {
            let w = 1 << (ph - 1);
            let p = 1 << (h - ph);
            let mut now = 1;
            for s in 0..w {
                let offset = s << (h - ph + 1);
                for i in 0..p {
                    let l = a[i + offset];
                    let r = (a[i + offset + p] * now) % self.modulo;

                    a[i + offset] = l + r;
                    if a[i + offset] >= self.modulo {
                        a[i + offset] -= self.modulo;
                    }

                    a[i + offset + p] = l + self.modulo - r;
                    if a[i + offset + p] >= self.modulo {
                        a[i + offset + p] -= self.modulo;
                    }
                }

                // (!s).trailing_zeros() is the number of trailing one bits of
                // `s`, which selects the factor advancing to the next block.
                now = (self.sum_e[(!s).trailing_zeros() as usize] * now) % self.modulo;
            }
        }
    }

    /// In-place inverse of `butterfly`, *without* the final `1 / a.len()`
    /// scaling (the caller applies it).
    ///
    /// Same length precondition as `butterfly`.
    fn butterfly_inv(&self, a: &mut [i64]) {
        let h = a.len().next_power_of_two().trailing_zeros();
        for ph in (1..=h).rev() {
            let w = 1 << (ph - 1);
            let p = 1 << (h - ph);
            let mut inv_now = 1;
            for s in 0..w {
                let offset = s << (h - ph + 1);
                for i in 0..p {
                    let l = a[i + offset];
                    let r = a[i + offset + p];

                    a[i + offset] = l + r;
                    if a[i + offset] >= self.modulo {
                        a[i + offset] -= self.modulo;
                    }

                    a[i + offset + p] = ((l + self.modulo - r) * inv_now) % self.modulo;
                }

                inv_now = (self.sum_ie[(!s).trailing_zeros() as usize] * inv_now) % self.modulo;
            }
        }
    }

    /// Returns the convolution `c[k] = sum_{i + j = k} a[i] * b[j] (mod modulo)`.
    ///
    /// Inputs may be any `i64` values, including negative ones; they are
    /// reduced into `[0, modulo)` first. Every output coefficient lies in
    /// `[0, modulo)`. If either input is empty, an empty vector is returned.
    ///
    /// # Panics
    ///
    /// Panics if `a.len() + b.len() - 1` exceeds the largest power of two
    /// dividing `modulo - 1`. Such a transform cannot be represented
    /// modulo `modulo`.
    pub fn convolution(&self, a: &[i64], b: &[i64]) -> Vec<i64> {
        if a.is_empty() || b.is_empty() {
            return Vec::new();
        }

        let len = a.len() + b.len() - 1;
        let z = len.next_power_of_two();
        assert!(
            z.trailing_zeros() <= (self.modulo - 1).trailing_zeros(),
            "convolution length {} exceeds the largest NTT size supported by modulo {}",
            len,
            self.modulo
        );

        // Reduce into [0, modulo) (`rem_euclid` also handles negatives),
        // zero-pad to the transform length and transform.
        let transform = |v: &[i64]| -> Vec<i64> {
            let mut buf: Vec<i64> = v.iter().map(|&x| x.rem_euclid(self.modulo)).collect();
            buf.resize(z, 0);
            self.butterfly(&mut buf);
            buf
        };
        let mut fa = transform(a);
        let fb = transform(b);

        for (x, &y) in fa.iter_mut().zip(&fb) {
            *x = (*x * y) % self.modulo;
        }

        self.butterfly_inv(&mut fa);
        fa.truncate(len);

        // butterfly_inv leaves out the 1/z normalisation; apply it here.
        let iz = mod_inv(z as i64, self.modulo);
        for x in fa.iter_mut() {
            *x = (*x * iz) % self.modulo;
        }
        fa
    }
}

/// Multiplicative inverse of `x` modulo `m`, computed as `x^(m - 2) mod m`.
///
/// This is Fermat's little theorem, so it is only correct when `m` is prime
/// and `x` is not a multiple of `m`.
fn mod_inv(x: i64, m: i64) -> i64 {
    let fermat_exponent = m - 2;
    mod_pow(x, fermat_exponent, m)
}

/// Computes `x^e mod m` by binary (square-and-multiply) exponentiation.
///
/// The base is first reduced into `[0, m)` with `rem_euclid`. Negative or
/// oversized bases therefore neither overflow nor yield negative results,
/// and the result always lies in `[0, m)`. A non-positive exponent yields
/// `1 % m`.
///
/// `m` must be positive and `(m - 1)^2` must fit in an `i64`.
fn mod_pow(x: i64, mut e: i64, m: i64) -> i64 {
    let mut base = x.rem_euclid(m);
    // `1 % m` makes the m == 1 case return 0 rather than 1.
    let mut result = 1 % m;
    while e > 0 {
        if e & 1 == 1 {
            result = (result * base) % m;
        }
        e >>= 1;
        base = (base * base) % m;
    }
    result
}

/// Returns a primitive root (generator of the multiplicative group) modulo `m`.
///
/// Well-known NTT primes are answered from a lookup. Otherwise the distinct
/// prime factors `q` of `m - 1` are found by trial division, and the smallest
/// `g >= 2` with `g^((m - 1) / q) != 1 (mod m)` for every `q` is returned.
/// `m` is assumed to be an odd prime on that path; 2 is always treated as a
/// factor of `m - 1`.
fn primitive_root(m: i64) -> i64 {
    match m {
        2 => return 1,
        167772161 | 469762049 | 998244353 => return 3,
        754974721 => return 11,
        _ => {}
    }

    // Distinct prime factors of m - 1, starting with the factor 2.
    let mut prime_factors = vec![2];
    let mut rest = (m - 1) / 2;
    while rest % 2 == 0 {
        rest /= 2
    }

    let mut p = 3;
    while p * p <= rest {
        if rest % p == 0 {
            prime_factors.push(p);
            while rest % p == 0 {
                rest /= p;
            }
        }
        p += 2;
    }
    if rest > 1 {
        prime_factors.push(rest);
    }

    (2..)
        .find(|&g| {
            prime_factors
                .iter()
                .all(|&q| mod_pow(g, (m - 1) / q, m) != 1)
        })
        .unwrap_or_else(|| unreachable!())
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};

    /// Small hand-checked convolution modulo 998244353.
    #[test]
    fn test_fft() {
        let fft = FastFourierTransform::new(998244353);
        let product = fft.convolution(&[1, 2, 3, 4], &[5, 6, 7, 8, 9]);
        assert_eq!(vec![5, 16, 34, 60, 70, 70, 59, 36], product);
    }

    /// Random inputs (including values >= modulo) compared against the naive
    /// O(n * m) convolution.
    #[test]
    fn test_fft_rand() {
        let modulo = 998244353;
        let fft = FastFourierTransform::new(modulo);
        let mut rng = thread_rng();

        for _ in 0..10 {
            let len_a: usize = 5000 + rng.gen_range(0, 5000);
            let len_b: usize = 5000 + rng.gen_range(0, 5000);
            let lhs: Vec<_> = (0..len_a)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect();
            let rhs: Vec<_> = (0..len_b)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect();
            let actual = fft.convolution(&lhs, &rhs);

            let mut expected = vec![0; len_a + len_b - 1];
            for (i, &x) in lhs.iter().enumerate() {
                for (j, &y) in rhs.iter().enumerate() {
                    expected[i + j] = (expected[i + j] + x * y) % modulo;
                }
            }

            assert_eq!(expected, actual);
        }
    }

    /// `primitive_root` must return a true generator for a mix of small
    /// primes, well-known NTT primes, and other primes.
    #[test]
    fn test_primitive_root() {
        let primes = [
            2, 3, 5, 7, 11, 998244353, 1000000007, 469762049, 167772161, 754974721,
            324013369, 831143041, 1685283601,
        ];
        for &p in primes.iter() {
            assert!(is_primitive_root(p, primitive_root(p)));
        }
    }

    /// Brute-force check: `g` generates the group mod `m` iff
    /// `g^((m - 1) / q) != 1` for every distinct prime factor `q` of `m - 1`.
    fn is_primitive_root(m: i64, g: i64) -> bool {
        let mut rest = m - 1;
        let mut primes = Vec::new();
        let mut d = 2;
        while d * d <= rest {
            if rest % d == 0 {
                primes.push(d);
                while rest % d == 0 {
                    rest /= d;
                }
            }
            d += 1;
        }
        if rest > 1 {
            primes.push(rest);
        }

        primes
            .into_iter()
            .all(|q| mod_pow(g, (m - 1) / q, m) != 1)
    }
}