// competitive_programming_rs/math/combination.rs
/// Precomputed factorial tables for evaluating binomial coefficients modulo `modulo`.
///
/// After the `O(max)` setup in [`Combination::new`], each [`Combination::get`] or
/// [`Combination::h`] query costs `O(1)`.
#[derive(Debug, Clone)]
pub struct Combination {
    /// `fact[i] ≡ i! (mod modulo)` for `i` in `0..=max`.
    fact: Vec<usize>,
    /// `inv_fact[i]` is the modular inverse of `i!` (meaningful while `i < modulo`
    /// for a prime `modulo`).
    inv_fact: Vec<usize>,
    /// Modulus every result is reduced by.
    modulo: usize,
}

impl Combination {
    /// Builds the factorial and inverse-factorial tables for `0..=max`.
    ///
    /// `modulo` should be a prime greater than `max`. Inverses come from the
    /// recurrence `inv(i) = -(modulo / i) * inv(modulo % i)`, which needs every `i`
    /// in `1..=max` to be invertible, so entries for `i >= modulo` are not
    /// meaningful. Products of two residues are formed in plain `usize`, so
    /// `(modulo - 1)^2` must fit in `usize` (e.g. `modulo < 2^32` on 64-bit targets).
    ///
    /// # Panics
    ///
    /// Panics if `modulo` is zero.
    pub fn new(max: usize, modulo: usize) -> Self {
        assert!(modulo > 0, "modulo must be positive");

        // inv[i] is the modular inverse of i; inv[0] is never a valid inverse.
        let mut inv = vec![0; max + 1];
        if max >= 1 {
            // Guarded so that `max == 0` (only C(0, 0) needed) does not index out of bounds.
            inv[1] = 1;
        }
        for i in 2..=max {
            // modulo = (modulo / i) * i + modulo % i, hence
            // i^{-1} ≡ -(modulo / i) * (modulo % i)^{-1}; `modulo % i < i` is already filled.
            inv[i] = inv[modulo % i] * (modulo - modulo / i) % modulo;
        }

        let mut fact = vec![1; max + 1];
        let mut inv_fact = vec![1; max + 1];
        for i in 1..=max {
            fact[i] = fact[i - 1] * i % modulo;
            inv_fact[i] = inv_fact[i - 1] * inv[i] % modulo;
        }

        Self {
            fact,
            inv_fact,
            modulo,
        }
    }

    /// Returns the binomial coefficient `C(x, y) = x! / (y! * (x - y)!)` modulo `modulo`.
    ///
    /// # Panics
    ///
    /// Panics if `x < y`, or if `x` exceeds the `max` passed to [`Combination::new`].
    pub fn get(&self, x: usize, y: usize) -> usize {
        assert!(x >= y);
        self.fact[x] * self.inv_fact[y] % self.modulo * self.inv_fact[x - y] % self.modulo
    }

    /// Returns the multiset coefficient `H(n, r) = C(n + r - 1, r)` modulo `modulo`:
    /// the number of ways to pick `r` items from `n` kinds with repetition allowed.
    ///
    /// With zero kinds, `H(0, 0) = 1` (the empty selection) and `H(0, r) = 0` for `r > 0`.
    ///
    /// # Panics
    ///
    /// Panics if `n + r - 1` exceeds the `max` passed to [`Combination::new`].
    pub fn h(&self, n: usize, r: usize) -> usize {
        if n == 0 {
            // `n + r - 1` would underflow (r == 0) or give x < y (r > 0).
            return if r == 0 { 1 % self.modulo } else { 0 };
        }
        self.get(n + r - 1, r)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Greatest common divisor of `a` and `b` (iterative Euclid).
    fn gcd(mut a: usize, mut b: usize) -> usize {
        while b != 0 {
            let rem = a % b;
            a = b;
            b = rem;
        }
        a
    }

    /// Cross-checks `Combination::get(n, m)` against an exact `C(n, m)`.
    ///
    /// The exact value is obtained by cancelling each divisor `1..=m` out of the
    /// falling factorial `n * (n-1) * ... * (n-m+1)` via gcds, then multiplying
    /// the remaining factors modulo `modulo`.
    #[test]
    fn random_combination() {
        let modulo = 1_000_000_007;

        for n in 100..200 {
            let comb = Combination::new(n, modulo);
            for m in 0..=n {
                // numerators[k] starts as n - k; divisors are cancelled out in place.
                let mut numerators: Vec<usize> = (0..m).map(|k| n - k).collect();
                for d in 1..=m {
                    // After the earlier cancellations, the product of the first `d`
                    // numerators is divisible by `d`, so peeling off gcds exhausts it.
                    let mut divisor = d;
                    for slot in numerators.iter_mut().take(d) {
                        if divisor == 1 {
                            break;
                        }
                        let g = gcd(divisor, *slot);
                        *slot /= g;
                        divisor /= g;
                    }
                }

                let expected = numerators
                    .iter()
                    .fold(1, |acc, &factor| acc * factor % modulo);
                assert_eq!(comb.get(n, m), expected);
            }
        }
    }
}