1use crate::distributions::{
62 beta::Beta,
63 binomial_distribution::Binomial,
64 chi_squared::ChiSquared,
65 f_distribution::FDistribution,
66 gamma_distribution::Gamma,
67 geometric::Geometric,
68 lognormal::LogNormal,
69 negative_binomial::NegativeBinomial,
70 normal_distribution::Normal,
71 poisson_distribution::Poisson,
72 student_t::StudentT,
73 traits::{DiscreteDistribution, Distribution},
74 uniform_distribution::Uniform,
75 weibull::Weibull,
76};
77use crate::error::{StatsError, StatsResult};
78
/// Coarse classification of a numeric sample, used to choose which family of
/// distributions to fit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DataKind {
    /// Every observation is a finite, non-negative whole number.
    Discrete,
    /// At least one observation is negative, fractional, or non-finite.
    Continuous,
}
89
90pub fn detect_data_type(data: &[f64]) -> DataKind {
100 if data
101 .iter()
102 .all(|&x| x >= 0.0 && x.fract() == 0.0 && x.is_finite())
103 {
104 DataKind::Discrete
105 } else {
106 DataKind::Continuous
107 }
108}
109
/// Outcome of a Kolmogorov–Smirnov goodness-of-fit test.
#[derive(Clone, Copy, Debug)]
pub struct KsResult {
    /// Largest absolute gap found between the empirical and the model CDF.
    pub statistic: f64,
    /// Approximate p-value associated with `statistic`, in `[0, 1]`.
    pub p_value: f64,
}
120
121pub fn ks_test(data: &[f64], cdf: impl Fn(f64) -> f64) -> KsResult {
125 let n = data.len();
126 if n == 0 {
127 return KsResult {
128 statistic: 0.0,
129 p_value: 1.0,
130 };
131 }
132 let mut sorted = data.to_vec();
133 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
134
135 let nf = n as f64;
136 let mut d = 0.0_f64;
137 for (i, &x) in sorted.iter().enumerate() {
138 let f = cdf(x);
139 let upper = (i + 1) as f64 / nf;
140 let lower = i as f64 / nf;
141 d = d.max((upper - f).abs()).max((f - lower).abs());
142 }
143
144 let p_value = kolmogorov_p(((nf).sqrt() + 0.12 + 0.11 / nf.sqrt()) * d);
145
146 KsResult {
147 statistic: d,
148 p_value,
149 }
150}
151
152pub fn ks_test_discrete(data: &[f64], cdf: impl Fn(u64) -> f64) -> KsResult {
154 let n = data.len();
155 if n == 0 {
156 return KsResult {
157 statistic: 0.0,
158 p_value: 1.0,
159 };
160 }
161 let mut sorted = data.to_vec();
162 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
163
164 let nf = n as f64;
165 let mut d = 0.0_f64;
166 for (i, &x) in sorted.iter().enumerate() {
167 let k = x.round() as u64;
168 let f = cdf(k);
169 let upper = (i + 1) as f64 / nf;
170 let lower = i as f64 / nf;
171 d = d.max((upper - f).abs()).max((f - lower).abs());
172 }
173
174 let p_value = kolmogorov_p(((nf).sqrt() + 0.12 + 0.11 / nf.sqrt()) * d);
175
176 KsResult {
177 statistic: d,
178 p_value,
179 }
180}
181
/// Evaluates the Kolmogorov survival series
/// `Q(x) = 2 * sum_{j >= 1} (-1)^(j - 1) * exp(-2 * j^2 * x^2)`,
/// clamped to `[0, 1]`.
///
/// * `x <= 0` returns `1.0`.
/// * NaN is propagated unchanged.
/// * At most 100 terms are summed; summation stops once a term drops below
///   `1e-15`. If that never happens (very small `x`, roughly below `0.042`)
///   the truncated alternating sum is not a usable approximation, so `1.0` is
///   returned instead — `Q(x)` equals `1` to double precision in that range.
fn kolmogorov_p(x: f64) -> f64 {
    if x.is_nan() {
        return x;
    }
    if x <= 0.0 {
        return 1.0;
    }

    let x_sq = x * x;
    let mut sum = 0.0_f64;
    let mut converged = false;
    for j in 1_u32..=100 {
        let term = (-(2.0 * f64::from(j).powi(2) * x_sq)).exp();
        if j % 2 == 1 {
            sum += term;
        } else {
            sum -= term;
        }
        if term < 1e-15 {
            converged = true;
            break;
        }
    }

    if !converged {
        // Terms are still large after 100 iterations: the partial sum of this
        // slowly alternating series would under-report the probability.
        return 1.0;
    }
    (2.0 * sum).clamp(0.0, 1.0)
}
202
/// Summary of how well one candidate distribution fits a sample.
#[derive(Clone, Debug)]
pub struct FitResult {
    /// Name reported by the fitted distribution.
    pub name: String,
    /// Akaike information criterion of the fit (the fitting functions in this
    /// module rank candidates by ascending AIC).
    pub aic: f64,
    /// Bayesian information criterion of the fit.
    pub bic: f64,
    /// Kolmogorov–Smirnov statistic of the sample against the fitted CDF.
    pub ks_statistic: f64,
    /// Approximate p-value belonging to `ks_statistic`.
    pub ks_p_value: f64,
}
219
220pub fn fit_all(data: &[f64]) -> StatsResult<Vec<FitResult>> {
226 if data.is_empty() {
227 return Err(StatsError::InvalidInput {
228 message: "fit_all: data must not be empty".to_string(),
229 });
230 }
231
232 let mut results: Vec<FitResult> = Vec::new();
233
234 macro_rules! try_fit {
235 ($dist_type:ty, $fit_expr:expr) => {
236 if let Ok(dist) = $fit_expr {
237 if let (Ok(aic), Ok(bic)) = (dist.aic(data), dist.bic(data)) {
238 if aic.is_finite() && bic.is_finite() {
239 let ks = ks_test(data, |x| dist.cdf(x).unwrap_or(0.0));
240 results.push(FitResult {
241 name: dist.name().to_string(),
242 aic,
243 bic,
244 ks_statistic: ks.statistic,
245 ks_p_value: ks.p_value,
246 });
247 }
248 }
249 }
250 };
251 }
252
253 try_fit!(Normal, Normal::fit(data));
254 try_fit!(
255 Exponential,
256 crate::distributions::exponential_distribution::Exponential::fit(data)
257 );
258 try_fit!(Uniform, Uniform::fit(data));
259 try_fit!(Gamma, Gamma::fit(data));
260 try_fit!(LogNormal, LogNormal::fit(data));
261 try_fit!(Weibull, Weibull::fit(data));
262 try_fit!(Beta, Beta::fit(data));
263 try_fit!(StudentT, StudentT::fit(data));
264 try_fit!(FDistribution, FDistribution::fit(data));
265 try_fit!(ChiSquared, ChiSquared::fit(data));
266
267 if results.is_empty() {
268 return Err(StatsError::InvalidInput {
269 message: "fit_all: no distribution could be fitted to the data".to_string(),
270 });
271 }
272
273 results.sort_by(|a, b| {
274 a.aic
275 .partial_cmp(&b.aic)
276 .unwrap_or(std::cmp::Ordering::Equal)
277 });
278 Ok(results)
279}
280
281pub fn fit_best(data: &[f64]) -> StatsResult<FitResult> {
283 let mut all = fit_all(data)?;
284 Ok(all.remove(0))
285}
286
287pub fn fit_all_discrete(data: &[f64]) -> StatsResult<Vec<FitResult>> {
293 if data.is_empty() {
294 return Err(StatsError::InvalidInput {
295 message: "fit_all_discrete: data must not be empty".to_string(),
296 });
297 }
298
299 let int_data: Vec<u64> = data.iter().map(|&x| x.round() as u64).collect();
301
302 let mut results: Vec<FitResult> = Vec::new();
303
304 macro_rules! try_fit_disc {
305 ($fit_expr:expr) => {
306 if let Ok(dist) = $fit_expr {
307 if let (Ok(aic), Ok(bic)) = (dist.aic(&int_data), dist.bic(&int_data)) {
308 if aic.is_finite() && bic.is_finite() {
309 let ks = ks_test_discrete(data, |k| dist.cdf(k).unwrap_or(0.0));
310 results.push(FitResult {
311 name: dist.name().to_string(),
312 aic,
313 bic,
314 ks_statistic: ks.statistic,
315 ks_p_value: ks.p_value,
316 });
317 }
318 }
319 }
320 };
321 }
322
323 try_fit_disc!(Poisson::fit(data));
324 try_fit_disc!(Geometric::fit(data));
325 try_fit_disc!(NegativeBinomial::fit(data));
326 try_fit_disc!(Binomial::fit(data));
327
328 if results.is_empty() {
329 return Err(StatsError::InvalidInput {
330 message: "fit_all_discrete: no distribution could be fitted to the data".to_string(),
331 });
332 }
333
334 results.sort_by(|a, b| {
335 a.aic
336 .partial_cmp(&b.aic)
337 .unwrap_or(std::cmp::Ordering::Equal)
338 });
339 Ok(results)
340}
341
342pub fn fit_best_discrete(data: &[f64]) -> StatsResult<FitResult> {
344 let mut all = fit_all_discrete(data)?;
345 Ok(all.remove(0))
346}
347
348pub fn auto_fit(data: &[f64]) -> StatsResult<FitResult> {
362 match detect_data_type(data) {
363 DataKind::Discrete => fit_best_discrete(data),
364 DataKind::Continuous => fit_best(data),
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-negative whole numbers are classified as discrete.
    #[test]
    fn test_detect_data_type_discrete() {
        let samples: [&[f64]; 2] = [&[0.0, 1.0, 2.0, 3.0], &[0.0, 0.0, 1.0]];
        for sample in samples.iter() {
            assert_eq!(detect_data_type(sample), DataKind::Discrete);
        }
    }

    /// Fractional or negative values make the sample continuous.
    #[test]
    fn test_detect_data_type_continuous() {
        let samples: [&[f64]; 3] = [&[0.5, 1.5, 2.3], &[-1.0, 0.0, 1.0], &[1.0, 2.5, 3.0]];
        for sample in samples.iter() {
            assert_eq!(detect_data_type(sample), DataKind::Continuous);
        }
    }

    /// An evenly spaced grid on [0, 1) sits close to the uniform CDF.
    #[test]
    fn test_ks_test_uniform() {
        let grid: Vec<f64> = (0..20).map(|step| f64::from(step) / 20.0).collect();
        let outcome = ks_test(&grid, |x| x.clamp(0.0, 1.0));
        assert!(outcome.statistic < 0.15);
    }

    /// `fit_all` yields at least one fit and orders fits by ascending AIC.
    #[test]
    fn test_fit_all_returns_results() {
        let sample: Vec<f64> = (0..50).map(|step| f64::from(step) * 0.1 + 0.5).collect();
        let fits = fit_all(&sample).unwrap();
        assert!(!fits.is_empty());
        for pair in fits.windows(2) {
            assert!(pair[1].aic >= pair[0].aic);
        }
    }

    /// The best fit for a bell-shaped sample has a finite AIC.
    #[test]
    fn test_fit_best_normal_data() {
        let sample = [
            4.1, 5.2, 5.8, 4.7, 5.3, 4.9, 6.1, 4.5, 5.5, 5.0, 4.8, 5.1, 4.3, 5.7, 4.6, 5.4, 4.2,
            5.9, 5.2, 4.4,
        ];
        let winner = fit_best(&sample).unwrap();
        assert!(winner.aic.is_finite());
    }

    /// Count data produces at least one discrete fit.
    #[test]
    fn test_fit_all_discrete() {
        let counts = [0.0, 1.0, 2.0, 3.0, 1.0, 0.0, 2.0, 1.0, 0.0, 4.0];
        let fits = fit_all_discrete(&counts).unwrap();
        assert!(!fits.is_empty());
    }

    /// `auto_fit` returns a named result for fractional data.
    #[test]
    fn test_auto_fit_continuous() {
        let sample = [1.5, 2.3, 1.8, 2.1, 2.7, 1.9, 2.4, 2.0];
        let winner = auto_fit(&sample).unwrap();
        assert!(!winner.name.is_empty());
    }

    /// `auto_fit` returns a named result for count data.
    #[test]
    fn test_auto_fit_discrete() {
        let counts = [0.0, 1.0, 2.0, 1.0, 0.0, 3.0, 1.0, 2.0];
        let winner = auto_fit(&counts).unwrap();
        assert!(!winner.name.is_empty());
    }

    /// An empty sample is rejected.
    #[test]
    fn test_fit_all_empty_data() {
        assert!(fit_all(&[]).is_err());
    }
}