Skip to main content

cyanea_stats/
combinatorics.rs

1//! Combinatorics utilities.
2//!
3//! Provides exact and log-space variants of factorials, binomial coefficients,
4//! permutations, and multinomial coefficients, plus a combinations iterator.
5
6use crate::distribution::ln_gamma;
7use cyanea_core::{CyaneaError, Result};
8
/// Computes `n!` exactly as a `u64`.
///
/// `20!` is the largest factorial that fits in 64 unsigned bits, so every
/// `n > 20` yields `None`. The empty product gives `0! == 1`.
pub fn factorial(n: u64) -> Option<u64> {
    match n {
        // Checked multiplication is kept as a second line of defence even
        // though the range guard already rules out overflow.
        0..=20 => (1..=n).try_fold(1u64, |acc, factor| acc.checked_mul(factor)),
        _ => None,
    }
}
20
21/// Log-factorial via `ln(Γ(n + 1))`.
22pub fn ln_factorial(n: u64) -> f64 {
23    ln_gamma(n as f64 + 1.0)
24}
25
/// Exact binomial coefficient C(n, k).
///
/// Returns `Some(0)` when `k > n`, and `None` when the coefficient itself
/// does not fit in a `u64`.
///
/// The multiplicative recurrence `C(n, i + 1) = C(n, i) * (n - i) / (i + 1)`
/// is evaluated with a 128-bit intermediate product. A result that fits in a
/// `u64` is therefore never rejected merely because the product *before* the
/// division would have overflowed 64 bits (as happens for `C(64, 32)`).
pub fn binomial(n: u64, k: u64) -> Option<u64> {
    if k > n {
        return Some(0);
    }
    // C(n, k) == C(n, n - k); iterate over the smaller of the two.
    let k = k.min(n - k);
    let mut result = 1u64;
    for i in 0..k {
        // Both factors are below 2^64, so the product cannot overflow u128.
        // The division is exact because the quotient is C(n, i + 1).
        let next = u128::from(result) * u128::from(n - i) / u128::from(i + 1);
        // With k <= n / 2 the partial coefficients never decrease, so a
        // partial value beyond u64 means the final coefficient is too.
        if next > u128::from(u64::MAX) {
            return None;
        }
        result = next as u64;
    }
    Some(result)
}
40
41/// Log-space binomial coefficient ln(C(n, k)).
42///
43/// # Errors
44///
45/// Returns an error if `k > n`.
46pub fn ln_binomial(n: u64, k: u64) -> Result<f64> {
47    if k > n {
48        return Err(CyaneaError::InvalidInput(
49            "ln_binomial: k must be <= n".into(),
50        ));
51    }
52    Ok(ln_gamma(n as f64 + 1.0) - ln_gamma(k as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0))
53}
54
/// Exact number of ordered selections P(n, k) = n · (n − 1) · … · (n − k + 1).
///
/// Returns `Some(0)` when `k > n`, `Some(1)` for `k == 0` (empty product),
/// and `None` as soon as the running product overflows a `u64`.
pub fn permutations(n: u64, k: u64) -> Option<u64> {
    if k > n {
        return Some(0);
    }
    // Falling factorial: multiply the k largest factors of n!.
    (0..k)
        .map(|taken| n - taken)
        .try_fold(1u64, |product, factor| product.checked_mul(factor))
}
66
67/// Log-space permutations ln(P(n, k)).
68///
69/// # Errors
70///
71/// Returns an error if `k > n`.
72pub fn ln_permutations(n: u64, k: u64) -> Result<f64> {
73    if k > n {
74        return Err(CyaneaError::InvalidInput(
75            "ln_permutations: k must be <= n".into(),
76        ));
77    }
78    Ok(ln_gamma(n as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0))
79}
80
81/// Exact multinomial coefficient n! / (c₁! · c₂! · ... · cₖ!).
82///
83/// Returns `None` on overflow. The counts must sum to `n`.
84pub fn multinomial(n: u64, counts: &[u64]) -> Option<u64> {
85    let sum: u64 = counts.iter().sum();
86    if sum != n {
87        return None;
88    }
89    let mut result = 1u64;
90    let mut remaining = n;
91    for &c in counts {
92        result = result.checked_mul(binomial(remaining, c)?)?;
93        remaining -= c;
94    }
95    Some(result)
96}
97
98/// Log-space multinomial coefficient.
99///
100/// # Errors
101///
102/// Returns an error if counts don't sum to `n`.
103pub fn ln_multinomial(n: u64, counts: &[u64]) -> Result<f64> {
104    let sum: u64 = counts.iter().sum();
105    if sum != n {
106        return Err(CyaneaError::InvalidInput(
107            "ln_multinomial: counts must sum to n".into(),
108        ));
109    }
110    let mut result = ln_gamma(n as f64 + 1.0);
111    for &c in counts {
112        result -= ln_gamma(c as f64 + 1.0);
113    }
114    Ok(result)
115}
116
117/// Iterator over all k-element combinations of indices `[0, n)`.
118///
119/// Yields combinations in lexicographic order. Each combination is a
120/// `Vec<usize>` of length `k` with strictly increasing indices.
121///
122/// # Example
123///
124/// ```
125/// use cyanea_stats::combinatorics::combinations;
126///
127/// let combos: Vec<Vec<usize>> = combinations(4, 2).collect();
128/// assert_eq!(combos.len(), 6); // C(4, 2) = 6
129/// assert_eq!(combos[0], vec![0, 1]);
130/// assert_eq!(combos[5], vec![2, 3]);
131/// ```
132pub fn combinations(n: usize, k: usize) -> Combinations {
133    let first = if k == 0 || k > n {
134        None
135    } else {
136        Some((0..k).collect())
137    };
138    Combinations { n, k, current: first }
139}
140
/// Lexicographic iterator over the `k`-element combinations of `[0, n)`.
///
/// `current` holds the combination that the next call to `next` will yield
/// and is `None` once the sequence is exhausted.
#[derive(Debug, Clone)]
pub struct Combinations {
    n: usize,
    k: usize,
    current: Option<Vec<usize>>,
}

impl Iterator for Combinations {
    type Item = Vec<usize>;

    /// Yields the stored combination and precomputes its lexicographic
    /// successor, or marks the iterator exhausted if there is none.
    fn next(&mut self) -> Option<Self::Item> {
        let emitted = self.current.take()?;
        let k = self.k;
        // Position `pos` may hold at most `n - k + pos`, otherwise the slots
        // to its right could not stay strictly increasing below `n`.
        let slack = self.n - k;

        // Rightmost position that can still be bumped; everything after it
        // restarts as a consecutive run.
        self.current = (0..k)
            .rev()
            .find(|&pos| emitted[pos] < slack + pos)
            .map(|pivot| {
                let mut successor = emitted.clone();
                successor[pivot] += 1;
                for pos in (pivot + 1)..k {
                    successor[pos] = successor[pos - 1] + 1;
                }
                successor
            });

        Some(emitted)
    }
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing a log-space result with `ln` of the
    /// exact value.
    const TOL: f64 = 1e-8;

    #[test]
    fn factorial_small() {
        let cases: [(u64, u64); 5] = [
            (0, 1),
            (1, 1),
            (5, 120),
            (10, 3_628_800),
            (20, 2_432_902_008_176_640_000),
        ];
        for &(n, expected) in cases.iter() {
            assert_eq!(factorial(n), Some(expected));
        }
    }

    #[test]
    fn factorial_overflow() {
        assert_eq!(factorial(21), None);
    }

    #[test]
    fn ln_factorial_matches_exact() {
        for n in 0..=20u64 {
            let expected = (factorial(n).unwrap() as f64).ln();
            let actual = ln_factorial(n);
            let close = (actual - expected).abs() < TOL;
            // ln(0!) is exactly 0, so also accept a tiny absolute value there.
            let zero_case = n == 0 && actual.abs() < 1e-10;
            assert!(
                close || zero_case,
                "ln_factorial({}) = {} but expected {}",
                n,
                actual,
                expected
            );
        }
    }

    #[test]
    fn binomial_known_values() {
        let cases: [(u64, u64, u64); 7] = [
            (5, 0, 1),
            (5, 1, 5),
            (5, 2, 10),
            (5, 3, 10),
            (5, 5, 1),
            (10, 3, 120),
            (20, 10, 184_756),
        ];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(binomial(n, k), Some(expected));
        }
    }

    #[test]
    fn binomial_k_greater_than_n() {
        assert_eq!(binomial(3, 5), Some(0));
    }

    #[test]
    fn ln_binomial_accuracy() {
        let expected = (120.0_f64).ln();
        let actual = ln_binomial(10, 3).unwrap();
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn ln_binomial_invalid() {
        assert!(ln_binomial(3, 5).is_err());
    }

    #[test]
    fn permutations_known() {
        // (n, k, P(n, k)); e.g. P(5, 3) = 5 * 4 * 3.
        let cases: [(u64, u64, u64); 3] = [(5, 3, 60), (5, 0, 1), (5, 5, 120)];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(permutations(n, k), Some(expected));
        }
    }

    #[test]
    fn permutations_k_greater_than_n() {
        assert_eq!(permutations(3, 5), Some(0));
    }

    #[test]
    fn ln_permutations_accuracy() {
        let expected = (60.0_f64).ln();
        let actual = ln_permutations(5, 3).unwrap();
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn multinomial_known() {
        // 4! / (2! * 1! * 1!) = 12
        let four = multinomial(4, &[2, 1, 1]);
        assert_eq!(four, Some(12));
        // 6! / (3! * 2! * 1!) = 60
        let six = multinomial(6, &[3, 2, 1]);
        assert_eq!(six, Some(60));
    }

    #[test]
    fn multinomial_bad_sum() {
        assert_eq!(multinomial(5, &[2, 1]), None);
    }

    #[test]
    fn ln_multinomial_accuracy() {
        let expected = (12.0_f64).ln();
        let actual = ln_multinomial(4, &[2, 1, 1]).unwrap();
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn ln_multinomial_invalid() {
        assert!(ln_multinomial(5, &[2, 1]).is_err());
    }

    #[test]
    fn combinations_count() {
        let combos: Vec<Vec<usize>> = combinations(5, 2).collect();
        assert_eq!(combos.len(), 10); // C(5,2) = 10
    }

    #[test]
    fn combinations_values() {
        let combos: Vec<Vec<usize>> = combinations(4, 2).collect();
        let expected: [[usize; 2]; 6] = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(combos[idx], want.to_vec());
        }
    }

    #[test]
    fn combinations_k_zero() {
        let combos: Vec<Vec<usize>> = combinations(5, 0).collect();
        assert!(combos.is_empty());
    }

    #[test]
    fn combinations_k_equals_n() {
        let combos: Vec<Vec<usize>> = combinations(3, 3).collect();
        assert_eq!(combos.len(), 1);
        assert_eq!(combos[0], vec![0, 1, 2]);
    }

    #[test]
    fn combinations_k_greater_than_n() {
        let combos: Vec<Vec<usize>> = combinations(2, 5).collect();
        assert!(combos.is_empty());
    }
}
311}