competitive_programming_rs/math/
combination.rs

/// Precomputed factorials and inverse factorials modulo `modulo`.
///
/// Once built, binomial coefficients `C(x, y)` and multiset coefficients
/// `H(n, r)` can be answered in O(1) for arguments up to `max`.
///
/// Results are only meaningful when `modulo` is prime and `max < modulo`, so
/// that every value in `1..=max` is invertible. Also, `modulo * modulo` must
/// fit in `usize`, because products of two residues are formed before being
/// reduced.
#[derive(Debug, Clone)]
pub struct Combination {
    /// `fact[i] = i! mod modulo` for `i` in `0..=max`.
    fact: Vec<usize>,
    /// `inv_fact[i] = (i!)^{-1} mod modulo` for `i` in `0..=max`.
    inv_fact: Vec<usize>,
    modulo: usize,
}

impl Combination {
    /// Builds the factorial and inverse-factorial tables for `0..=max`.
    ///
    /// Runs in O(max) time and memory. The inverses of `1..=max` come from the
    /// linear recurrence `inv[i] = -(modulo / i) * inv[modulo % i]`, so no
    /// modular exponentiation is needed.
    ///
    /// # Panics
    ///
    /// Panics if `modulo == 0`.
    pub fn new(max: usize, modulo: usize) -> Self {
        assert!(modulo > 0, "modulo must be positive");

        // inv[i] = i^{-1} mod modulo for i in 1..=max; inv[0] stays 0.
        let mut inv = vec![0; max + 1];
        if max >= 1 {
            // Guarded: with max == 0 the table has a single slot, and
            // writing inv[1] would be out of bounds.
            inv[1] = 1;
        }
        for i in 2..=max {
            // modulo = (modulo / i) * i + modulo % i
            //   =>  i^{-1} = -(modulo / i) * (modulo % i)^{-1}  (mod modulo)
            inv[i] = inv[modulo % i] * (modulo - modulo / i) % modulo;
        }

        // Index 0 keeps the initial value 1 (0! = 1, (0!)^{-1} = 1).
        let mut fact = vec![1; max + 1];
        let mut inv_fact = vec![1; max + 1];
        for i in 1..=max {
            fact[i] = fact[i - 1] * i % modulo;
            inv_fact[i] = inv_fact[i - 1] * inv[i] % modulo;
        }

        Self {
            fact,
            inv_fact,
            modulo,
        }
    }

    /// Returns the binomial coefficient `C(x, y) mod modulo`.
    ///
    /// # Panics
    ///
    /// Panics if `y > x`, or if `x` exceeds the `max` given to
    /// [`Combination::new`].
    pub fn get(&self, x: usize, y: usize) -> usize {
        assert!(x >= y);
        self.fact[x] * self.inv_fact[y] % self.modulo * self.inv_fact[x - y] % self.modulo
    }

    /// Returns the multiset coefficient `H(n, r) = C(n + r - 1, r) mod modulo`.
    ///
    /// This is the number of ways to pick `r` items from `n` kinds with
    /// repetition allowed.
    ///
    /// The case `n == 0` is handled directly: `H(0, 0) = 1` and `H(0, r) = 0`
    /// for `r > 0`. Otherwise, `n + r - 1` would underflow, or `C(r - 1, r)`
    /// would trip the assertion in [`Combination::get`].
    ///
    /// # Panics
    ///
    /// Panics if `n >= 1` and `n + r - 1` exceeds the `max` given to
    /// [`Combination::new`].
    pub fn h(&self, n: usize, r: usize) -> usize {
        if n == 0 {
            // Reduce 1 so the result agrees with `get` when modulo == 1.
            return if r == 0 { 1 % self.modulo } else { 0 };
        }
        self.get(n + r - 1, r)
    }
}
40
#[cfg(test)]
mod test {
    use super::*;

    /// Greatest common divisor via Euclid's algorithm, written iteratively.
    fn gcd(mut a: usize, mut b: usize) -> usize {
        while b != 0 {
            let r = a % b;
            a = b;
            b = r;
        }
        a
    }

    /// Cross-checks `Combination::get(n, m)` against a binomial computed
    /// without modular inverses.
    ///
    /// The numerator factors `n, n-1, ..., n-m+1` are kept as separate
    /// entries. Each denominator factor `d` in `1..=m` is cancelled, via gcds,
    /// out of the first `d` entries. The surviving factors are then multiplied
    /// together modulo `modulo`.
    #[test]
    fn random_combination() {
        let modulo = 1_000_000_007;

        for n in 100..200 {
            let comb = Combination::new(n, modulo);
            for m in 0..=n {
                let mut numer: Vec<usize> = (0..m).map(|k| n - k).collect();
                for d in 1..=m {
                    let mut rest = d;
                    for slot in numer.iter_mut().take(d) {
                        if rest == 1 {
                            break;
                        }
                        let g = gcd(rest, *slot);
                        *slot /= g;
                        rest /= g;
                    }
                }

                let expected = numer.iter().fold(1, |acc, &u| acc * u % modulo);
                assert_eq!(comb.get(n, m), expected);
            }
        }
    }
}