1use crate::distribution::ln_gamma;
7use cyanea_core::{CyaneaError, Result};
8
/// Computes `n!` exactly.
///
/// Returns `None` when `n > 20`, because `20!` is the largest factorial
/// that is representable in a `u64`.
pub fn factorial(n: u64) -> Option<u64> {
    match n {
        // The checked multiply is kept as a safety net even though every
        // product in this range fits.
        0..=20 => (2..=n).try_fold(1u64, |acc, factor| acc.checked_mul(factor)),
        _ => None,
    }
}
20
21pub fn ln_factorial(n: u64) -> f64 {
23 ln_gamma(n as f64 + 1.0)
24}
25
/// Computes the binomial coefficient `C(n, k)` exactly.
///
/// Returns `Some(0)` when `k > n`, and `None` only when the coefficient
/// itself does not fit in a `u64`.
pub fn binomial(n: u64, k: u64) -> Option<u64> {
    if k > n {
        return Some(0);
    }
    // C(n, k) == C(n, n - k); iterating over the smaller index keeps the loop
    // short and makes the running value non-decreasing.
    let k = k.min(n - k);
    let mut result = 1u64;
    for i in 0..k {
        // Invariant: `result == C(n, i)`. The product `C(n, i) * (n - i)` can
        // exceed `u64` even when `C(n, i + 1)` does not, so multiply and divide
        // in `u128` (two `u64` factors can never overflow it). The division is
        // exact because the quotient is `C(n, i + 1)`.
        let wide = u128::from(result) * u128::from(n - i) / u128::from(i + 1);
        // Intermediate coefficients never exceed the final one (i + 1 <= n / 2),
        // so an out-of-range value here means the final result overflows too.
        if wide > u128::from(u64::MAX) {
            return None;
        }
        result = wide as u64;
    }
    Some(result)
}
40
41pub fn ln_binomial(n: u64, k: u64) -> Result<f64> {
47 if k > n {
48 return Err(CyaneaError::InvalidInput(
49 "ln_binomial: k must be <= n".into(),
50 ));
51 }
52 Ok(ln_gamma(n as f64 + 1.0) - ln_gamma(k as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0))
53}
54
/// Computes the number of ordered selections of `k` items out of `n`,
/// i.e. `n * (n - 1) * ... * (n - k + 1)`.
///
/// Returns `Some(0)` when `k > n` and `None` when the product overflows `u64`.
pub fn permutations(n: u64, k: u64) -> Option<u64> {
    if k > n {
        return Some(0);
    }
    // Falling factorial: multiply the k factors n, n - 1, ..., stopping at the
    // first overflow.
    (0..k)
        .map(|offset| n - offset)
        .try_fold(1u64, |acc, factor| acc.checked_mul(factor))
}
66
67pub fn ln_permutations(n: u64, k: u64) -> Result<f64> {
73 if k > n {
74 return Err(CyaneaError::InvalidInput(
75 "ln_permutations: k must be <= n".into(),
76 ));
77 }
78 Ok(ln_gamma(n as f64 + 1.0) - ln_gamma((n - k) as f64 + 1.0))
79}
80
81pub fn multinomial(n: u64, counts: &[u64]) -> Option<u64> {
85 let sum: u64 = counts.iter().sum();
86 if sum != n {
87 return None;
88 }
89 let mut result = 1u64;
90 let mut remaining = n;
91 for &c in counts {
92 result = result.checked_mul(binomial(remaining, c)?)?;
93 remaining -= c;
94 }
95 Some(result)
96}
97
98pub fn ln_multinomial(n: u64, counts: &[u64]) -> Result<f64> {
104 let sum: u64 = counts.iter().sum();
105 if sum != n {
106 return Err(CyaneaError::InvalidInput(
107 "ln_multinomial: counts must sum to n".into(),
108 ));
109 }
110 let mut result = ln_gamma(n as f64 + 1.0);
111 for &c in counts {
112 result -= ln_gamma(c as f64 + 1.0);
113 }
114 Ok(result)
115}
116
117pub fn combinations(n: usize, k: usize) -> Combinations {
133 let first = if k == 0 || k > n {
134 None
135 } else {
136 Some((0..k).collect())
137 };
138 Combinations { n, k, current: first }
139}
140
/// Iterator yielding `k`-element index combinations drawn from `0..n`, each
/// as a strictly increasing `Vec<usize>`, in lexicographic order.
#[derive(Debug, Clone)]
pub struct Combinations {
    /// Size of the index pool (`0..n`).
    n: usize,
    /// Number of indices in each combination.
    k: usize,
    /// Combination to hand out on the next call, or `None` once exhausted.
    current: Option<Vec<usize>>,
}

impl Iterator for Combinations {
    type Item = Vec<usize>;

    /// Returns the pending combination and prepares its lexicographic
    /// successor, or returns `None` once every combination has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        // Taking the state leaves `None` behind, which is already the right
        // value if no successor exists.
        let emitted = self.current.take()?;
        let (n, k) = (self.n, self.k);

        // Rightmost position that can still be advanced: slot `i` may hold at
        // most `n - k + i` while leaving room for the slots after it.
        let pivot = (0..k).rev().find(|&i| emitted[i] < n - k + i);

        if let Some(i) = pivot {
            let mut successor = emitted.clone();
            successor[i] += 1;
            // Everything after the pivot restarts as a consecutive run.
            for j in (i + 1)..k {
                successor[j] = successor[j - 1] + 1;
            }
            self.current = Some(successor);
        }

        Some(emitted)
    }
}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing a log-space result with `ln` of the
    /// exact count.
    const LN_TOLERANCE: f64 = 1e-8;

    /// Gathers everything the `combinations(n, k)` iterator yields.
    fn collect_combinations(n: usize, k: usize) -> Vec<Vec<usize>> {
        combinations(n, k).collect()
    }

    /// Exact factorials for a handful of inputs up to the largest supported one.
    #[test]
    fn factorial_small() {
        let cases: [(u64, u64); 5] = [
            (0, 1),
            (1, 1),
            (5, 120),
            (10, 3_628_800),
            (20, 2_432_902_008_176_640_000),
        ];
        for &(n, expected) in cases.iter() {
            assert_eq!(factorial(n), Some(expected));
        }
    }

    /// 21! does not fit in a `u64`.
    #[test]
    fn factorial_overflow() {
        assert!(factorial(21).is_none());
    }

    /// The log-space factorial agrees with `ln` of the exact value for 0..=20.
    #[test]
    fn ln_factorial_matches_exact() {
        for n in 0..=20 {
            let reference = (factorial(n).unwrap() as f64).ln();
            let computed = ln_factorial(n);
            let close_enough = (computed - reference).abs() < 1e-8;
            let zero_case = n == 0 && computed.abs() < 1e-10;
            assert!(
                close_enough || zero_case,
                "ln_factorial({}) = {} but expected {}",
                n,
                computed,
                reference
            );
        }
    }

    /// Spot checks of `C(n, k)` as `(n, k, expected)` triples.
    #[test]
    fn binomial_known_values() {
        let cases: [(u64, u64, u64); 7] = [
            (5, 0, 1),
            (5, 1, 5),
            (5, 2, 10),
            (5, 3, 10),
            (5, 5, 1),
            (10, 3, 120),
            (20, 10, 184_756),
        ];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(binomial(n, k), Some(expected));
        }
    }

    /// Choosing more items than available gives zero, not an error.
    #[test]
    fn binomial_k_greater_than_n() {
        let value = binomial(3, 5);
        assert_eq!(value, Some(0));
    }

    /// ln C(10, 3) matches ln(120).
    #[test]
    fn ln_binomial_accuracy() {
        let computed = ln_binomial(10, 3).unwrap();
        let reference = 120.0_f64.ln();
        assert!((computed - reference).abs() < LN_TOLERANCE);
    }

    /// The log-space variant rejects `k > n`.
    #[test]
    fn ln_binomial_invalid() {
        let outcome = ln_binomial(3, 5);
        assert!(outcome.is_err());
    }

    /// Spot checks of `P(n, k)` as `(n, k, expected)` triples.
    #[test]
    fn permutations_known() {
        let cases: [(u64, u64, u64); 3] = [(5, 3, 60), (5, 0, 1), (5, 5, 120)];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(permutations(n, k), Some(expected));
        }
    }

    /// Arranging more items than available gives zero, not an error.
    #[test]
    fn permutations_k_greater_than_n() {
        let value = permutations(3, 5);
        assert_eq!(value, Some(0));
    }

    /// ln P(5, 3) matches ln(60).
    #[test]
    fn ln_permutations_accuracy() {
        let computed = ln_permutations(5, 3).unwrap();
        let reference = 60.0_f64.ln();
        assert!((computed - reference).abs() < LN_TOLERANCE);
    }

    /// Spot checks of multinomial coefficients.
    #[test]
    fn multinomial_known() {
        let four = multinomial(4, &[2, 1, 1]);
        let six = multinomial(6, &[3, 2, 1]);
        assert_eq!(four, Some(12));
        assert_eq!(six, Some(60));
    }

    /// Counts that do not add up to `n` are rejected.
    #[test]
    fn multinomial_bad_sum() {
        assert!(multinomial(5, &[2, 1]).is_none());
    }

    /// The log-space multinomial for (4; 2, 1, 1) matches ln(12).
    #[test]
    fn ln_multinomial_accuracy() {
        let computed = ln_multinomial(4, &[2, 1, 1]).unwrap();
        let reference = 12.0_f64.ln();
        assert!((computed - reference).abs() < LN_TOLERANCE);
    }

    /// The log-space variant rejects counts that do not add up to `n`.
    #[test]
    fn ln_multinomial_invalid() {
        let outcome = ln_multinomial(5, &[2, 1]);
        assert!(outcome.is_err());
    }

    /// `combinations(5, 2)` yields ten items.
    #[test]
    fn combinations_count() {
        assert_eq!(collect_combinations(5, 2).len(), 10);
    }

    /// The first six items of `combinations(4, 2)` come out in this order.
    #[test]
    fn combinations_values() {
        let combos = collect_combinations(4, 2);
        let expected: [[usize; 2]; 6] = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(combos[idx], want.to_vec());
        }
    }

    /// `k == 0` yields nothing.
    #[test]
    fn combinations_k_zero() {
        assert!(collect_combinations(5, 0).is_empty());
    }

    /// `k == n` yields exactly the full index set.
    #[test]
    fn combinations_k_equals_n() {
        let combos = collect_combinations(3, 3);
        assert_eq!(combos, vec![vec![0, 1, 2]]);
    }

    /// `k > n` yields nothing.
    #[test]
    fn combinations_k_greater_than_n() {
        assert!(collect_combinations(2, 5).is_empty());
    }
}