1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use num::traits::{int::PrimInt, Unsigned};

/// Precomputed tables of factorials and inverse factorials reduced modulo `modulo`.
///
/// Once built, permutation and combination queries are answered with a couple of
/// table lookups and modular multiplications.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ModFact<T> {
    /// `fact[i]` is `i!` reduced modulo `modulo`.
    fact: Vec<T>,
    /// `inv[i]` is the modular inverse of `i` (only meaningful when `modulo` is prime and
    /// `i < modulo`); `inv[0]` is a placeholder zero.
    inv: Vec<T>,
    /// `factinv[i]` is the running product `inv[1] * ... * inv[i]`, i.e. the inverse of `i!`.
    factinv: Vec<T>,
    /// Modulus every table entry is reduced by.
    modulo: T,
}

impl<T> ModFact<T>
where
    T: PrimInt + Unsigned,
{
    /// Builds the tables `i!`, `i^-1` and `(i!)^-1` modulo `modulo` for every `0 <= i <= n`.
    ///
    /// The inverse recurrence `inv[i] = -(modulo / i) * inv[modulo % i]` is only valid when
    /// `modulo` is prime and `i < modulo`. Because `fact[i]` is 0 for every `i >= modulo`,
    /// [`perm`](Self::perm) and [`comb`](Self::comb) return 0 whenever `n >= modulo`.
    ///
    /// All products are reduced without intermediate overflow, so any modulus representable
    /// in `T` works (e.g. `998_244_353` with `u32`).
    ///
    /// # Panics
    ///
    /// Panics if `modulo` is zero or if `n` does not fit in `usize`.
    pub fn new(n: T, modulo: T) -> Self {
        let zero = T::zero();
        let one = T::one();
        assert!(modulo > zero, "ModFact::new: modulo must be non-zero");

        let vec_size = n.to_usize().expect("ModFact::new: n must fit in usize") + 1usize;

        let mut fact: Vec<T> = vec![zero; vec_size];
        fact[0] = one;
        for i in 1..vec_size {
            let i_t = T::from(i).expect("i <= n, so i fits in T");
            fact[i] = Self::mul_mod(i_t, fact[i - 1], modulo);
        }

        let mut inv: Vec<T> = vec![zero; vec_size];
        inv[0] = zero;
        // With n == 0 the table has a single slot, so index 1 must not be touched.
        if vec_size > 1 {
            inv[1] = one;
        }
        for i in 2..vec_size {
            let i_t = T::from(i).expect("i <= n, so i fits in T");
            // Reduce in T first: the remainder is < i, so it always fits in usize even when
            // `modulo` itself does not.
            let rem = (modulo % i_t)
                .to_usize()
                .expect("modulo % i < i fits in usize");
            inv[i] = Self::mul_mod(inv[rem], modulo - modulo / i_t, modulo);
        }

        let mut factinv: Vec<T> = vec![zero; vec_size];
        factinv[0] = one;
        for i in 1..vec_size {
            factinv[i] = Self::mul_mod(factinv[i - 1], inv[i], modulo);
        }

        ModFact {
            fact,
            inv,
            factinv,
            modulo,
        }
    }

    /// Number of ordered selections of `r` items out of `n`: `n! / (n - r)!` mod `modulo`.
    ///
    /// Returns 0 when `n < r`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= r` and `n` is larger than the `n` the table was built with.
    pub fn perm(&self, n: T, r: T) -> T {
        let n = n.to_usize().unwrap();
        let r = r.to_usize().unwrap();
        if n < r {
            return T::zero();
        }
        Self::mul_mod(self.fact[n], self.factinv[n - r], self.modulo)
    }

    /// Binomial coefficient `C(n, r) = n! / (r! (n - r)!)` mod `modulo`.
    ///
    /// Returns 0 when `n < r`.
    ///
    /// # Panics
    ///
    /// Panics if `n >= r` and `n` is larger than the `n` the table was built with.
    pub fn comb(&self, n: T, r: T) -> T {
        let n = n.to_usize().unwrap();
        let r = r.to_usize().unwrap();
        if n < r {
            return T::zero();
        }
        let partial = Self::mul_mod(self.fact[n], self.factinv[r], self.modulo);
        Self::mul_mod(partial, self.factinv[n - r], self.modulo)
    }

    /// Computes `(a * b) % m` without overflowing `T`.
    ///
    /// Uses the native product when it fits; otherwise falls back to double-and-add
    /// multiplication in which every intermediate value stays below `m`.
    fn mul_mod(a: T, b: T, m: T) -> T {
        if let Some(product) = a.checked_mul(&b) {
            return product % m;
        }
        let zero = T::zero();
        let one = T::one();
        let mut base = a % m;
        let mut rest = b % m;
        let mut acc = zero;
        while rest > zero {
            if (rest & one) == one {
                acc = Self::add_mod(acc, base, m);
            }
            base = Self::add_mod(base, base, m);
            rest = rest >> 1usize;
        }
        acc
    }

    /// Computes `(x + y) % m` for `x, y < m` without overflowing `T`.
    fn add_mod(x: T, y: T, m: T) -> T {
        // `x + y >= m` is equivalent to `x >= m - y`; `m - y` cannot underflow since `y < m`.
        let gap = m - y;
        if x >= gap {
            x - gap
        } else {
            x + y
        }
    }
}

#[cfg(test)]
mod test {

    use super::*;

    // 91 = 7 * 13 is composite, so these only check that construction does not panic
    // for each unsigned width; the inverse tables are not meaningful for this modulus.
    #[test]
    fn test_u16_factinv() {
        let _table = ModFact::new(10u16, 91u16);
    }
    #[test]
    fn test_u32_factinv() {
        let _table = ModFact::new(10u32, 91u32);
    }
    #[test]
    fn test_u64_factinv() {
        let _table = ModFact::new(10u64, 91u64);
    }
    #[test]
    fn test_usize_factinv() {
        let _table = ModFact::new(10usize, 91usize);
    }

    /// Spot-checks perm/comb against hand-computed values with a prime modulus.
    #[test]
    fn test_known_values_prime_modulus() {
        let table = ModFact::new(10u64, 13u64);
        assert_eq!(table.comb(5, 2), 10);
        assert_eq!(table.perm(5, 2), 7); // 20 mod 13
        assert_eq!(table.comb(10, 3), 3); // 120 mod 13
        assert_eq!(table.perm(10, 10), 6); // 10! = 3_628_800 ≡ 6 (mod 13)
        assert_eq!(table.comb(10, 0), 1);
        assert_eq!(table.comb(10, 10), 1);
        assert_eq!(table.comb(3, 5), 0);
        assert_eq!(table.perm(3, 5), 0);
    }

    /// Every C(n, r) for n <= 10 must agree with Pascal's triangle reduced mod 13.
    #[test]
    fn test_comb_matches_pascal_triangle() {
        let table = ModFact::new(10u64, 13u64);
        let mut row: Vec<u64> = vec![1];
        for n in 0..=10u64 {
            for (r, &expected) in row.iter().enumerate() {
                assert_eq!(table.comb(n, r as u64), expected, "C({}, {})", n, r);
            }
            let mut next = Vec::with_capacity(row.len() + 1);
            next.push(1);
            next.extend(row.windows(2).map(|w| (w[0] + w[1]) % 13));
            next.push(1);
            row = next;
        }
    }
}