competitive_hpp/modulo/
mod_fact.rs

1use num::traits::{int::PrimInt, Unsigned};
2
/// Lookup tables for modular factorial arithmetic, filled in by `ModFact::new`.
///
/// Each table has one entry per index `0..=n`:
/// - `fact[i]` is `i!` reduced by `modulo`,
/// - `inv[i]` is the inverse of `i` modulo `modulo` (entry 0 is a zero placeholder),
/// - `factinv[i]` is the inverse of `i!` modulo `modulo`.
///
/// The inverse entries are only true modular inverses when `modulo` is prime and
/// larger than `n`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModFact<T> {
    /// `0!, 1!, ..., n!` modulo `modulo`.
    fact: Vec<T>,
    /// Inverses of `0, 1, ..., n`; index 0 holds zero.
    inv: Vec<T>,
    /// Inverses of `0!, 1!, ..., n!`.
    factinv: Vec<T>,
    /// Modulus that every table entry is reduced by.
    modulo: T,
}
10
11impl<T> ModFact<T>
12where
13    T: PrimInt + Unsigned,
14{
15    pub fn new(n: T, modulo: T) -> Self {
16        let zero = T::zero();
17        let one = T::one();
18        let vec_size = n.to_usize().unwrap() + 1usize;
19
20        let mut fact: Vec<T> = vec![zero; vec_size];
21        fact[0] = one;
22        for i in 1..vec_size {
23            fact[i] = T::from(i).unwrap() * fact[i - 1] % modulo;
24        }
25
26        let mut inv: Vec<T> = vec![zero; vec_size];
27        inv[0] = zero;
28        inv[1] = one;
29        for i in 2..vec_size {
30            inv[i] = inv[modulo.to_usize().unwrap() % i] * (modulo - modulo / T::from(i).unwrap())
31                % modulo;
32        }
33
34        let mut factinv: Vec<T> = vec![zero; vec_size];
35        factinv[0] = one;
36        for i in 1..vec_size {
37            factinv[i] = factinv[i - 1] * inv[i] % modulo;
38        }
39
40        ModFact {
41            fact,
42            inv,
43            factinv,
44            modulo,
45        }
46    }
47
48    pub fn perm(&self, n: T, r: T) -> T {
49        let n = n.to_usize().unwrap();
50        let r = r.to_usize().unwrap();
51        if n < r {
52            T::from(0).unwrap()
53        } else {
54            self.fact[n] * self.factinv[n - r] % self.modulo
55        }
56    }
57
58    pub fn comb(&self, n: T, r: T) -> T {
59        let n = n.to_usize().unwrap();
60        let r = r.to_usize().unwrap();
61        if n < r {
62            T::from(0).unwrap()
63        } else {
64            (self.fact[n] * self.factinv[r] % self.modulo) * self.factinv[n - r] % self.modulo
65        }
66    }
67}
68
#[cfg(test)]
mod test {

    use super::*;

    /// Checks a table built with `ModFact::new(10, 97)` against hand-computed values.
    ///
    /// 97 is prime and larger than 10, so every stored inverse is a true modular inverse.
    /// The previous modulus 91 = 7 * 13 is composite, so its inverse table was invalid.
    fn check_table_10_mod_97<T>(table: &ModFact<T>)
    where
        T: PrimInt + Unsigned + std::fmt::Debug,
    {
        let t = |x: u32| T::from(x).unwrap();
        let m = t(97);

        // inv[i] * i == 1 (mod 97) for every nonzero index.
        for i in 1..=10u32 {
            assert_eq!(table.inv[i as usize] * t(i) % m, t(1));
        }
        // fact[i] * factinv[i] == 1 (mod 97) for every index.
        for i in 0..=10usize {
            assert_eq!(table.fact[i] * table.factinv[i] % m, t(1));
        }

        assert_eq!(table.comb(t(10), t(0)), t(1));
        assert_eq!(table.comb(t(10), t(3)), t(23)); // 120 mod 97
        assert_eq!(table.comb(t(10), t(5)), t(58)); // 252 mod 97
        assert_eq!(table.comb(t(10), t(10)), t(1));
        assert_eq!(table.comb(t(3), t(5)), t(0));

        assert_eq!(table.perm(t(10), t(3)), t(41)); // 720 mod 97
        assert_eq!(table.perm(t(5), t(5)), t(23)); // 120 mod 97
        assert_eq!(table.perm(t(3), t(4)), t(0));
    }

    #[test]
    fn test_u16_factinv() {
        check_table_10_mod_97(&ModFact::new(10u16, 97u16));
    }
    #[test]
    fn test_u32_factinv() {
        check_table_10_mod_97(&ModFact::new(10u32, 97u32));
    }
    #[test]
    fn test_u64_factinv() {
        check_table_10_mod_97(&ModFact::new(10u64, 97u64));
    }
    #[test]
    fn test_usize_factinv() {
        check_table_10_mod_97(&ModFact::new(10usize, 97usize));
    }
}