1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 CharFunctionData, CharacteristicFunction, ConcentrationBound, Coupling, DirichletProcess,
9 DiscreteDistribution, EmpiricalCdf, ErgodicTheoremData, ExponentialDistribution,
10 GaussianDistribution, GaussianProcess, GaussianProcessRegression, HawkesProcess,
11 KernelDensityEstimator, LargeDeviations, Lcg, MarkovChain, PoissonProcess, RenewalProcess,
12 StoppingTime, WelfordEstimator,
13};
14
15pub fn app(f: Expr, a: Expr) -> Expr {
16 Expr::App(Box::new(f), Box::new(a))
17}
18pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
19 app(app(f, a), b)
20}
21pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
22 app(app2(f, a, b), c)
23}
24pub fn cst(s: &str) -> Expr {
25 Expr::Const(Name::str(s), vec![])
26}
27pub fn prop() -> Expr {
28 Expr::Sort(Level::zero())
29}
30pub fn type0() -> Expr {
31 Expr::Sort(Level::succ(Level::zero()))
32}
33pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
34 Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
35}
36pub fn arrow(a: Expr, b: Expr) -> Expr {
37 pi(BinderInfo::Default, "_", a, b)
38}
39pub fn bvar(n: u32) -> Expr {
40 Expr::BVar(n)
41}
42pub fn nat_ty() -> Expr {
43 cst("Nat")
44}
45pub fn real_ty() -> Expr {
46 cst("Real")
47}
48pub fn prob_space_ty() -> Expr {
50 type0()
51}
52pub fn event_ty() -> Expr {
54 arrow(type0(), prop())
55}
56pub fn random_var_ty() -> Expr {
58 arrow(type0(), arrow(type0(), type0()))
59}
60pub fn distribution_ty() -> Expr {
62 arrow(type0(), type0())
63}
64pub fn markov_chain_ty() -> Expr {
66 arrow(type0(), type0())
67}
68pub fn stochastic_process_ty() -> Expr {
70 arrow(nat_ty(), arrow(type0(), type0()))
71}
72pub fn sigma_algebra_ty() -> Expr {
75 arrow(type0(), type0())
76}
77pub fn measurable_ty() -> Expr {
80 arrow(arrow(type0(), type0()), prop())
81}
82pub fn prob_measure_ty() -> Expr {
84 arrow(sigma_algebra_ty(), type0())
85}
86pub fn covariance_ty() -> Expr {
89 arrow(
90 arrow(type0(), real_ty()),
91 arrow(arrow(type0(), real_ty()), real_ty()),
92 )
93}
94pub fn mutual_independence_ty() -> Expr {
97 arrow(app(cst("List"), event_ty()), prop())
98}
99pub fn pairwise_independence_ty() -> Expr {
102 arrow(app(cst("List"), event_ty()), prop())
103}
104pub fn conditional_expectation_ty() -> Expr {
107 arrow(
108 arrow(type0(), real_ty()),
109 arrow(sigma_algebra_ty(), arrow(type0(), real_ty())),
110 )
111}
112pub fn characteristic_fn_ty() -> Expr {
115 arrow(arrow(type0(), real_ty()), arrow(real_ty(), cst("Complex")))
116}
117pub fn quantile_ty() -> Expr {
120 arrow(distribution_ty(), arrow(real_ty(), type0()))
121}
122pub fn entropy_ty() -> Expr {
125 arrow(distribution_ty(), real_ty())
126}
127pub fn kl_divergence_ty() -> Expr {
130 arrow(distribution_ty(), arrow(distribution_ty(), real_ty()))
131}
132pub fn stopping_time_ty() -> Expr {
135 arrow(arrow(nat_ty(), event_ty()), prop())
136}
137pub fn martingale_ty() -> Expr {
140 arrow(stochastic_process_ty(), prop())
141}
142pub fn sub_gaussian_ty() -> Expr {
145 arrow(arrow(type0(), real_ty()), arrow(real_ty(), prop()))
146}
147pub fn sub_exponential_ty() -> Expr {
150 arrow(
151 arrow(type0(), real_ty()),
152 arrow(real_ty(), arrow(real_ty(), prop())),
153 )
154}
155pub fn rate_function_ty() -> Expr {
158 arrow(distribution_ty(), arrow(arrow(type0(), real_ty()), prop()))
159}
160pub fn renewal_process_ty() -> Expr {
163 arrow(arrow(nat_ty(), real_ty()), prop())
164}
165pub fn mixing_time_ty() -> Expr {
168 arrow(markov_chain_ty(), arrow(real_ty(), nat_ty()))
169}
170pub fn total_variation_dist_ty() -> Expr {
173 arrow(distribution_ty(), arrow(distribution_ty(), real_ty()))
174}
175pub fn coupling_ty() -> Expr {
178 arrow(distribution_ty(), arrow(distribution_ty(), type0()))
179}
180pub fn empirical_measure_ty() -> Expr {
183 arrow(nat_ty(), arrow(arrow(nat_ty(), type0()), distribution_ty()))
184}
185pub fn weak_lln_ty() -> Expr {
188 prop()
189}
190pub fn strong_lln_ty() -> Expr {
193 prop()
194}
195pub fn lindeberg_clt_ty() -> Expr {
198 prop()
199}
200pub fn lyapunov_clt_ty() -> Expr {
203 prop()
204}
205pub fn berry_esseen_ty() -> Expr {
208 prop()
209}
210pub fn hoeffding_inequality_ty() -> Expr {
213 prop()
214}
215pub fn bernstein_inequality_ty() -> Expr {
218 prop()
219}
220pub fn chernoff_bound_ty() -> Expr {
223 prop()
224}
225pub fn cramer_ldp_ty() -> Expr {
228 prop()
229}
230pub fn sanov_ldp_ty() -> Expr {
233 prop()
234}
235pub fn doob_optional_sampling_ty() -> Expr {
238 prop()
239}
240pub fn azuma_hoeffding_ty() -> Expr {
243 prop()
244}
245pub fn renewal_reward_ty() -> Expr {
248 prop()
249}
250pub fn coupling_lemma_ty() -> Expr {
253 prop()
254}
255pub fn law_of_large_numbers_ty() -> Expr {
257 prop()
258}
259pub fn central_limit_theorem_ty() -> Expr {
261 prop()
262}
263pub fn bayes_theorem_ty() -> Expr {
265 pi(
266 BinderInfo::Default,
267 "P_A",
268 real_ty(),
269 pi(
270 BinderInfo::Default,
271 "P_B",
272 real_ty(),
273 pi(BinderInfo::Default, "P_B_given_A", real_ty(), real_ty()),
274 ),
275 )
276}
277pub fn markov_inequality_ty() -> Expr {
279 prop()
280}
281pub fn chebyshev_inequality_ty() -> Expr {
283 prop()
284}
285pub fn kolmogorov_axioms_ty() -> Expr {
287 prop()
288}
289pub fn build_probability_env(env: &mut Environment) -> Result<(), String> {
291 let type_decls: &[(&str, Expr)] = &[
292 ("ProbSpace", prob_space_ty()),
293 ("Event", event_ty()),
294 ("RandomVar", random_var_ty()),
295 ("Distribution", distribution_ty()),
296 ("MarkovChain", markov_chain_ty()),
297 ("StochasticProcess", stochastic_process_ty()),
298 ];
299 for (name, ty) in type_decls {
300 env.add(Declaration::Axiom {
301 name: Name::str(*name),
302 univ_params: vec![],
303 ty: ty.clone(),
304 })
305 .ok();
306 }
307 let new_type_decls: &[(&str, Expr)] = &[
308 ("SigmaAlgebra", sigma_algebra_ty()),
309 ("ProbMeasure", prob_measure_ty()),
310 ("Coupling", coupling_ty()),
311 ];
312 for (name, ty) in new_type_decls {
313 env.add(Declaration::Axiom {
314 name: Name::str(*name),
315 univ_params: vec![],
316 ty: ty.clone(),
317 })
318 .ok();
319 }
320 let theorem_decls: &[(&str, Expr)] = &[
321 ("LawOfLargeNumbers", law_of_large_numbers_ty()),
322 ("CentralLimitTheorem", central_limit_theorem_ty()),
323 ("BayesTheorem", bayes_theorem_ty()),
324 ("MarkovInequality", markov_inequality_ty()),
325 ("ChebyshevInequality", chebyshev_inequality_ty()),
326 ("KolmogorovAxioms", kolmogorov_axioms_ty()),
327 ];
328 for (name, ty) in theorem_decls {
329 env.add(Declaration::Axiom {
330 name: Name::str(*name),
331 univ_params: vec![],
332 ty: ty.clone(),
333 })
334 .ok();
335 }
336 let new_theorem_decls: &[(&str, Expr)] = &[
337 ("WeakLawOfLargeNumbers", weak_lln_ty()),
338 ("StrongLawOfLargeNumbers", strong_lln_ty()),
339 ("LindebergCLT", lindeberg_clt_ty()),
340 ("LyapunovCLT", lyapunov_clt_ty()),
341 ("BerryEsseenBound", berry_esseen_ty()),
342 ("HoeffdingInequality", hoeffding_inequality_ty()),
343 ("BernsteinInequality", bernstein_inequality_ty()),
344 ("ChernoffBound", chernoff_bound_ty()),
345 ("CramerLDP", cramer_ldp_ty()),
346 ("SanovLDP", sanov_ldp_ty()),
347 ("DoobOptionalSampling", doob_optional_sampling_ty()),
348 ("AzumaHoeffding", azuma_hoeffding_ty()),
349 ("RenewalReward", renewal_reward_ty()),
350 ("CouplingLemma", coupling_lemma_ty()),
351 ];
352 for (name, ty) in new_theorem_decls {
353 env.add(Declaration::Axiom {
354 name: Name::str(*name),
355 univ_params: vec![],
356 ty: ty.clone(),
357 })
358 .ok();
359 }
360 let extra: &[(&str, Expr)] = &[
361 ("Prob", arrow(event_ty(), real_ty())),
362 ("Expectation", arrow(arrow(type0(), real_ty()), real_ty())),
363 ("Variance", arrow(arrow(type0(), real_ty()), real_ty())),
364 (
365 "Conditional",
366 arrow(event_ty(), arrow(event_ty(), real_ty())),
367 ),
368 ("Independence", arrow(event_ty(), arrow(event_ty(), prop()))),
369 (
370 "StationaryDist",
371 arrow(markov_chain_ty(), distribution_ty()),
372 ),
373 (
374 "MomentGenerating",
375 arrow(arrow(type0(), real_ty()), arrow(real_ty(), real_ty())),
376 ),
377 ("Measurable", measurable_ty()),
378 ("Cov", covariance_ty()),
379 ("MutualIndep", mutual_independence_ty()),
380 ("PairwiseIndep", pairwise_independence_ty()),
381 ("CondExpectation", conditional_expectation_ty()),
382 ("CharFn", characteristic_fn_ty()),
383 ("Quantile", quantile_ty()),
384 ("Entropy", entropy_ty()),
385 ("KLDiv", kl_divergence_ty()),
386 ("StoppingTime", stopping_time_ty()),
387 ("IsMartingale", martingale_ty()),
388 ("IsSubGaussian", sub_gaussian_ty()),
389 ("IsSubExponential", sub_exponential_ty()),
390 ("RateFunction", rate_function_ty()),
391 ("IsRenewalProcess", renewal_process_ty()),
392 ("MixingTime", mixing_time_ty()),
393 ("TVDist", total_variation_dist_ty()),
394 ("EmpiricalMeasure", empirical_measure_ty()),
395 ];
396 for (name, ty) in extra {
397 env.add(Declaration::Axiom {
398 name: Name::str(*name),
399 univ_params: vec![],
400 ty: ty.clone(),
401 })
402 .ok();
403 }
404 build_advanced_probability_env(env)?;
405 Ok(())
406}
/// Probability vector of the uniform distribution over `n` outcomes.
///
/// Every entry is `1/n`. For `n == 0` an empty vector is returned.
pub fn discrete_uniform(n: usize) -> Vec<f64> {
    match n {
        0 => Vec::new(),
        outcomes => vec![1.0 / outcomes as f64; outcomes],
    }
}
/// Probability of exactly `k` successes in `n` independent trials with success probability `p`.
///
/// The result is `C(n, k) · p^k · (1 − p)^(n − k)`. It returns `0.0` when `k > n` or when
/// `p` lies outside `[0, 1]` (this also covers NaN).
///
/// The value is computed in log space. The binomial coefficient therefore never has to be
/// materialised as an integer, so large `n` (e.g. `n = 2000`) neither overflows nor loses
/// the result to `p^k` underflow.
pub fn binomial_pmf(n: u32, k: u32, p: f64) -> f64 {
    if k > n || !(0.0..=1.0).contains(&p) {
        return 0.0;
    }
    // Degenerate success probabilities: ln(0) would otherwise produce 0 · (−∞) = NaN.
    if p == 0.0 {
        return if k == 0 { 1.0 } else { 0.0 };
    }
    if p == 1.0 {
        return if k == n { 1.0 } else { 0.0 };
    }
    // ln C(n, k) via the multiplicative formula, using the smaller of k and n − k.
    let k_small = k.min(n - k);
    let ln_coeff: f64 = (0..k_small)
        .map(|i| (f64::from(n - i) / f64::from(i + 1)).ln())
        .sum();
    // ln_1p(−p) is a more accurate ln(1 − p) for small p.
    (ln_coeff + f64::from(k) * p.ln() + f64::from(n - k) * (-p).ln_1p()).exp()
}
/// Poisson probability mass `P(K = k) = λ^k e^{−λ} / k!` for rate `lambda`.
///
/// It returns `0.0` for a negative rate. For `λ == 0` the mass is concentrated at `k == 0`.
///
/// The value is evaluated in log space, with `ln k!` accumulated as a sum of logarithms.
/// This avoids the `u64` overflow of `k!` for `k > 20` and the overflow of `λ^k`.
pub fn poisson_pmf(lambda: f64, k: u32) -> f64 {
    if lambda < 0.0 {
        return 0.0;
    }
    if lambda == 0.0 {
        // ln(0) would give 0 · (−∞) = NaN for k == 0.
        return if k == 0 { 1.0 } else { 0.0 };
    }
    let ln_k_factorial: f64 = (2..=k).map(|i| f64::from(i).ln()).sum();
    (f64::from(k) * lambda.ln() - lambda - ln_k_factorial).exp()
}
/// Density of the normal distribution `N(mu, sigma²)` evaluated at `x`.
///
/// It returns `0.0` when `sigma <= 0.0`.
pub fn normal_pdf(x: f64, mu: f64, sigma: f64) -> f64 {
    if sigma <= 0.0 {
        return 0.0;
    }
    let standardized = (x - mu) / sigma;
    let normalizer = sigma * (2.0 * std::f64::consts::PI).sqrt();
    (-0.5 * standardized * standardized).exp() / normalizer
}
/// Arithmetic mean of `data`, or `0.0` for an empty slice.
pub fn sample_mean(data: &[f64]) -> f64 {
    match data.len() {
        0 => 0.0,
        len => data.iter().sum::<f64>() / len as f64,
    }
}

/// Unbiased sample variance (divisor `n − 1`) of `data`.
///
/// It returns `0.0` when fewer than two values are given.
pub fn sample_variance(data: &[f64]) -> f64 {
    let len = data.len();
    if len < 2 {
        return 0.0;
    }
    let center = sample_mean(data);
    let squared_deviations: f64 = data.iter().map(|&v| (v - center).powi(2)).sum();
    squared_deviations / (len - 1) as f64
}

/// Unbiased sample covariance (divisor `n − 1`) of the paired values of `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` pairs are used. It returns `0.0` when fewer
/// than two pairs are available.
pub fn covariance(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }
    let (xs, ys) = (&x[..len], &y[..len]);
    let (x_bar, y_bar) = (sample_mean(xs), sample_mean(ys));
    let cross: f64 = xs
        .iter()
        .zip(ys)
        .map(|(a, b)| (a - x_bar) * (b - y_bar))
        .sum();
    cross / (len - 1) as f64
}

/// Pearson correlation coefficient of the paired values of `x` and `y`.
///
/// Only the first `min(x.len(), y.len())` pairs are used. It returns `0.0` when fewer than
/// two pairs are available or when either series has zero standard deviation.
pub fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len().min(y.len());
    if len < 2 {
        return 0.0;
    }
    let sd_x = sample_variance(&x[..len]).sqrt();
    let sd_y = sample_variance(&y[..len]).sqrt();
    let degenerate = sd_x == 0.0 || sd_y == 0.0;
    if degenerate {
        0.0
    } else {
        covariance(x, y) / (sd_x * sd_y)
    }
}
/// Approximate CDF of the standard normal distribution.
///
/// This uses the Abramowitz & Stegun polynomial approximation 26.2.17 (absolute error below
/// about 7.5e-8). The upper-half value is computed for `|z|`, and `1 − Φ(|z|)` is returned
/// for negative inputs.
pub fn standard_normal_cdf(z: f64) -> f64 {
    const P: f64 = 0.2316419;
    const B: [f64; 5] = [
        0.319_381_530,
        -0.356_563_782,
        1.781_477_937,
        -1.821_255_978,
        1.330_274_429,
    ];
    let is_negative = z < 0.0;
    let a = z.abs();
    let t = 1.0 / (1.0 + P * a);
    // Horner evaluation of b1 + b2 t + ... + b5 t^4, then multiplied by t.
    let poly = t * B.iter().rev().fold(0.0, |acc, &b| b + t * acc);
    // Standard normal density at `a` (identical to `normal_pdf(a, 0.0, 1.0)`).
    let density = (-0.5 * a * a).exp() / (2.0 * std::f64::consts::PI).sqrt();
    let upper = 1.0 - density * poly;
    if is_negative {
        1.0 - upper
    } else {
        upper
    }
}
/// Density `λ e^{−λx}` of the exponential distribution with rate `lambda`.
///
/// It returns `0.0` for `x < 0` or `lambda <= 0`.
pub fn exponential_pdf(x: f64, lambda: f64) -> f64 {
    let outside_support = x < 0.0 || lambda <= 0.0;
    if outside_support {
        0.0
    } else {
        lambda * (-lambda * x).exp()
    }
}

/// CDF `1 − e^{−λx}` of the exponential distribution with rate `lambda`.
///
/// It returns `0.0` for `x < 0` or `lambda <= 0`.
pub fn exponential_cdf(x: f64, lambda: f64) -> f64 {
    let outside_support = x < 0.0 || lambda <= 0.0;
    if outside_support {
        0.0
    } else {
        1.0 - (-lambda * x).exp()
    }
}
/// Geometric probability that the first success occurs on trial `k` (1-based).
///
/// The value is `(1 − p)^(k − 1) · p`. It returns `0.0` for `k == 0` or `p` outside `(0, 1]`.
pub fn geometric_pmf(k: u32, p: f64) -> f64 {
    match k {
        0 => 0.0,
        _ if p <= 0.0 || p > 1.0 => 0.0,
        trial => p * (1.0 - p).powi((trial - 1) as i32),
    }
}
/// Negative-binomial mass `C(k + r − 1, k) · p^r · (1 − p)^k`.
///
/// This is the probability of `k` failures before the `r`-th success, where each trial
/// succeeds with probability `p`. It returns `0.0` for `p` outside `(0, 1]`.
///
/// Edge cases and overflow behaviour:
/// - `r == 0` is handled explicitly as the point mass at `k == 0`. The previous `k + r − 1`
///   underflowed `u32` there.
/// - The coefficient is evaluated in log space, so large `k`/`r` no longer overflow the
///   integer binomial coefficient.
pub fn negative_binomial_pmf(k: u32, r: u32, p: f64) -> f64 {
    if p <= 0.0 || p > 1.0 {
        return 0.0;
    }
    if r == 0 {
        // Zero required successes: no failures can occur.
        return if k == 0 { 1.0 } else { 0.0 };
    }
    if p == 1.0 {
        // Every trial succeeds, so there are never any failures (avoids ln(0)).
        return if k == 0 { 1.0 } else { 0.0 };
    }
    // ln C(n, k) with n = k + r − 1, summed over the smaller of k and n − k = r − 1.
    let n = u64::from(k) + u64::from(r) - 1;
    let m = u64::from(k).min(u64::from(r) - 1);
    let ln_coeff: f64 = (0..m)
        .map(|i| ((n - i) as f64 / (i + 1) as f64).ln())
        .sum();
    (ln_coeff + f64::from(r) * p.ln() + f64::from(k) * (-p).ln_1p()).exp()
}
/// Gamma density with shape `alpha` and scale `beta`.
///
/// The density is `x^(α−1) e^(−x/β) / (Γ(α) β^α)`. It returns `0.0` unless `x`, `alpha`
/// and `beta` are all positive.
pub fn gamma_pdf(x: f64, alpha: f64, beta: f64) -> f64 {
    if [x, alpha, beta].iter().any(|&v| v <= 0.0) {
        return 0.0;
    }
    let log_density = (alpha - 1.0) * x.ln() - x / beta - log_gamma(alpha) - alpha * beta.ln();
    log_density.exp()
}

/// Beta density with shape parameters `alpha` and `beta` on the open interval `(0, 1)`.
///
/// It returns `0.0` outside `(0, 1)` or when either shape parameter is not positive.
pub fn beta_pdf(x: f64, alpha: f64, beta: f64) -> f64 {
    let outside = x <= 0.0 || x >= 1.0;
    if outside || alpha <= 0.0 || beta <= 0.0 {
        return 0.0;
    }
    // ln B(α, β) = ln Γ(α) + ln Γ(β) − ln Γ(α + β)
    let log_beta_fn = log_gamma(alpha) + log_gamma(beta) - log_gamma(alpha + beta);
    let log_density = (alpha - 1.0) * x.ln() + (beta - 1.0) * (1.0 - x).ln() - log_beta_fn;
    log_density.exp()
}

/// Natural logarithm of the Gamma function for `x > 0`.
///
/// It uses the Lanczos approximation (g = 7, nine coefficients). For `0 < x < 0.5` it applies
/// the reflection formula `ln Γ(x) = ln π − ln sin(πx) − ln Γ(1 − x)`. It returns NaN for
/// `x <= 0`.
pub fn log_gamma(x: f64) -> f64 {
    const LANCZOS_G: f64 = 7.0;
    const LANCZOS_COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1_259.139_216_722_403,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_311_6e-7,
    ];
    if x <= 0.0 {
        return f64::NAN;
    }
    if x < 0.5 {
        // Reflection keeps the series in its accurate region.
        return std::f64::consts::PI.ln()
            - (std::f64::consts::PI * x).sin().ln()
            - log_gamma(1.0 - x);
    }
    let shifted = x - 1.0;
    let series = LANCZOS_COEFFS[1..]
        .iter()
        .enumerate()
        .fold(LANCZOS_COEFFS[0], |acc, (idx, &c)| {
            acc + c / (shifted + (idx + 1) as f64)
        });
    let t = shifted + LANCZOS_G + 0.5;
    0.5 * (2.0 * std::f64::consts::PI).ln() + (shifted + 0.5) * t.ln() - t + series.ln()
}
/// Bayesian posterior over a discrete hypothesis set.
///
/// Each `prior[i]` is multiplied by `likelihoods[i]` and the products are normalised to sum
/// to one. Only the first `min(prior.len(), likelihoods.len())` entries are used. If the
/// total evidence is not positive, the unnormalised products are returned unchanged.
pub fn bayes_update(prior: &[f64], likelihoods: &[f64]) -> Vec<f64> {
    let unnormalized: Vec<f64> = prior
        .iter()
        .zip(likelihoods)
        .map(|(p, l)| p * l)
        .collect();
    let evidence: f64 = unnormalized.iter().sum();
    if evidence > 0.0 {
        unnormalized.into_iter().map(|w| w / evidence).collect()
    } else {
        unnormalized
    }
}
/// Kullback–Leibler divergence `D(p ‖ q) = Σ p_i ln(p_i / q_i)`, in nats.
///
/// Only the first `min(p.len(), q.len())` entries are compared. Terms with `p_i <= 0`
/// contribute nothing, by the convention `0 · ln 0 = 0`.
///
/// If some `p_i > 0` has `q_i <= 0`, then `p` is not absolutely continuous with respect to
/// `q` and the divergence is `f64::INFINITY`. Previously such terms were silently dropped,
/// which could yield a finite or even negative result.
pub fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    let mut divergence = 0.0;
    for (&pi, &qi) in p.iter().zip(q) {
        if pi > 0.0 && qi <= 0.0 {
            return f64::INFINITY;
        }
        if pi > 0.0 && qi > 0.0 {
            divergence += pi * (pi / qi).ln();
        }
    }
    divergence
}
/// Total-variation distance `½ Σ |p_i − q_i|` between two probability vectors.
///
/// Only the first `min(p.len(), q.len())` entries are compared.
pub fn total_variation_distance(p: &[f64], q: &[f64]) -> f64 {
    let l1_distance: f64 = p.iter().zip(q).map(|(a, b)| (a - b).abs()).sum();
    0.5 * l1_distance
}
/// Raw empirical moments `(1/n) Σ x^k` for every order `k` in `0..=max_order`.
///
/// The returned vector has `max_order + 1` entries. For empty `data`, every moment is `0.0`.
pub fn empirical_moments(data: &[f64], max_order: u32) -> Vec<f64> {
    let count = data.len() as f64;
    let raw_moment = |order: u32| -> f64 {
        if data.is_empty() {
            return 0.0;
        }
        data.iter().map(|v| v.powi(order as i32)).sum::<f64>() / count
    };
    (0..=max_order).map(raw_moment).collect()
}
/// Binomial coefficient `C(n, k)`, or `0` when `k > n`.
///
/// The multiplicative formula runs on a `u128` accumulator, so the intermediate product
/// `C(n, i) · (n − i)` cannot overflow while the running value still fits in `u64`. The
/// previous `u64` accumulator overflowed for results as small as `C(66, 33)`.
///
/// If the true coefficient exceeds `u64::MAX`, the result saturates at `u64::MAX`.
pub fn binomial_coeff(n: u32, k: u32) -> u64 {
    if k > n {
        return 0;
    }
    let k = k.min(n - k);
    let mut acc: u128 = 1;
    for i in 0..k {
        // acc == C(n, i) <= u64::MAX and (n − i) < 2^32, so the product fits in u128.
        // C(n, i) · (n − i) == C(n, i + 1) · (i + 1), so the division is exact.
        acc = acc * u128::from(n - i) / u128::from(i + 1);
        if acc > u128::from(u64::MAX) {
            // C(n, j) increases for j <= n/2, so the final value is at least as large.
            return u64::MAX;
        }
    }
    acc as u64
}
/// `n!` as a `u64`, saturating at `u64::MAX` for `n > 20`.
///
/// `21!` no longer fits in 64 bits. The previous plain product panicked in debug builds
/// and silently wrapped in release builds.
pub fn factorial(n: u32) -> u64 {
    (2..=u64::from(n)).fold(1u64, |acc, i| acc.saturating_mul(i))
}
649pub fn brownian_motion_ty() -> Expr {
653 arrow(arrow(real_ty(), arrow(type0(), real_ty())), prop())
654}
655pub fn levy_process_ty() -> Expr {
658 arrow(arrow(real_ty(), arrow(type0(), real_ty())), prop())
659}
660pub fn ito_integral_ty() -> Expr {
663 arrow(
664 arrow(real_ty(), arrow(type0(), real_ty())),
665 arrow(
666 arrow(real_ty(), arrow(type0(), real_ty())),
667 arrow(type0(), real_ty()),
668 ),
669 )
670}
671pub fn ito_formula_ty() -> Expr {
674 prop()
675}
676pub fn sde_ty() -> Expr {
679 arrow(
680 arrow(real_ty(), arrow(type0(), real_ty())),
681 arrow(
682 arrow(real_ty(), arrow(real_ty(), real_ty())),
683 arrow(arrow(real_ty(), arrow(real_ty(), real_ty())), prop()),
684 ),
685 )
686}
687pub fn strong_solution_ty() -> Expr {
690 arrow(sde_ty(), prop())
691}
692pub fn weak_solution_ty() -> Expr {
695 arrow(sde_ty(), prop())
696}
697pub fn girsanov_thm_ty() -> Expr {
701 prop()
702}
703pub fn quadratic_variation_ty() -> Expr {
706 arrow(
707 arrow(real_ty(), arrow(type0(), real_ty())),
708 arrow(arrow(type0(), real_ty()), prop()),
709 )
710}
711pub fn mcdiarmid_inequality_ty() -> Expr {
715 prop()
716}
717pub fn azuma_inequality_ty() -> Expr {
720 prop()
721}
722pub fn ldp_ty() -> Expr {
725 arrow(
726 distribution_ty(),
727 arrow(arrow(real_ty(), real_ty()), prop()),
728 )
729}
730pub fn gartner_ellis_ty() -> Expr {
733 prop()
734}
735pub fn log_mgf_ty() -> Expr {
738 arrow(arrow(type0(), real_ty()), arrow(real_ty(), real_ty()))
739}
740pub fn fenchel_legendre_ty() -> Expr {
743 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
744}
745pub fn random_walk_ty() -> Expr {
748 arrow(arrow(nat_ty(), arrow(type0(), real_ty())), prop())
749}
750pub fn green_function_ty() -> Expr {
753 arrow(
754 markov_chain_ty(),
755 arrow(nat_ty(), arrow(nat_ty(), real_ty())),
756 )
757}
758pub fn hitting_time_ty() -> Expr {
761 arrow(markov_chain_ty(), arrow(nat_ty(), arrow(type0(), nat_ty())))
762}
763pub fn spectral_gap_ty() -> Expr {
766 arrow(markov_chain_ty(), real_ty())
767}
768pub fn reversible_chain_ty() -> Expr {
771 arrow(markov_chain_ty(), arrow(distribution_ty(), prop()))
772}
773pub fn gev_distribution_ty() -> Expr {
776 arrow(real_ty(), arrow(real_ty(), arrow(real_ty(), type0())))
777}
778pub fn gpd_distribution_ty() -> Expr {
781 arrow(real_ty(), arrow(real_ty(), type0()))
782}
783pub fn pickands_balkema_de_haan_ty() -> Expr {
787 prop()
788}
789pub fn fisher_tippett_gnedenko_ty() -> Expr {
792 prop()
793}
794pub fn gaussian_process_ty() -> Expr {
797 arrow(
798 arrow(
799 arrow(type0(), real_ty()),
800 arrow(arrow(type0(), real_ty()), real_ty()),
801 ),
802 prop(),
803 )
804}
805pub fn dirichlet_process_ty() -> Expr {
808 arrow(real_ty(), arrow(distribution_ty(), distribution_ty()))
809}
810pub fn crp_ty() -> Expr {
813 arrow(real_ty(), arrow(nat_ty(), distribution_ty()))
814}
815pub fn donsker_thm_ty() -> Expr {
818 prop()
819}
820pub fn vc_dimension_ty() -> Expr {
823 arrow(arrow(arrow(type0(), prop()), prop()), nat_ty())
824}
825pub fn rademacher_complexity_ty() -> Expr {
828 arrow(
829 arrow(arrow(type0(), real_ty()), prop()),
830 arrow(real_ty(), real_ty()),
831 )
832}
833pub fn markov_blanket_ty() -> Expr {
836 arrow(
837 arrow(nat_ty(), prop()),
838 arrow(nat_ty(), arrow(nat_ty(), prop())),
839 )
840}
841pub fn d_separation_ty() -> Expr {
844 arrow(
845 arrow(nat_ty(), arrow(nat_ty(), prop())),
846 arrow(
847 nat_ty(),
848 arrow(nat_ty(), arrow(arrow(nat_ty(), prop()), prop())),
849 ),
850 )
851}
852pub fn faithfulness_ty() -> Expr {
855 arrow(
856 arrow(nat_ty(), arrow(nat_ty(), prop())),
857 arrow(distribution_ty(), prop()),
858 )
859}
860pub fn free_probability_space_ty() -> Expr {
863 type0()
864}
865pub fn free_independence_ty() -> Expr {
868 arrow(
869 free_probability_space_ty(),
870 arrow(free_probability_space_ty(), prop()),
871 )
872}
873pub fn free_convolution_ty() -> Expr {
876 arrow(
877 distribution_ty(),
878 arrow(distribution_ty(), distribution_ty()),
879 )
880}
881pub fn quantum_prob_space_ty() -> Expr {
884 type0()
885}
886pub fn branching_process_ty() -> Expr {
889 arrow(
890 arrow(nat_ty(), distribution_ty()),
891 arrow(arrow(nat_ty(), nat_ty()), prop()),
892 )
893}
894pub fn extinction_prob_ty() -> Expr {
897 arrow(branching_process_ty(), real_ty())
898}
899pub fn random_tree_ty() -> Expr {
902 arrow(nat_ty(), type0())
903}
904pub fn continuum_random_tree_ty() -> Expr {
907 type0()
908}
909pub fn build_advanced_probability_env(env: &mut Environment) -> Result<(), String> {
912 let advanced_type_decls: &[(&str, Expr)] = &[
913 ("BrownianMotion", brownian_motion_ty()),
914 ("LevyProcess", levy_process_ty()),
915 ("ItoIntegral", ito_integral_ty()),
916 ("QuadraticVariation", quadratic_variation_ty()),
917 ("SDE", sde_ty()),
918 ("GEVDistribution", gev_distribution_ty()),
919 ("GPDDistribution", gpd_distribution_ty()),
920 ("DirichletProcess", dirichlet_process_ty()),
921 ("CRP", crp_ty()),
922 ("GreenFunction", green_function_ty()),
923 ("HittingTime", hitting_time_ty()),
924 ("RandomWalk", random_walk_ty()),
925 ("RandomTree", random_tree_ty()),
926 ("ContinuumRandomTree", continuum_random_tree_ty()),
927 ("FreeProbabilitySpace", free_probability_space_ty()),
928 ("QuantumProbSpace", quantum_prob_space_ty()),
929 ];
930 for (name, ty) in advanced_type_decls {
931 env.add(Declaration::Axiom {
932 name: Name::str(*name),
933 univ_params: vec![],
934 ty: ty.clone(),
935 })
936 .ok();
937 }
938 let advanced_fn_decls: &[(&str, Expr)] = &[
939 ("LogMGF", log_mgf_ty()),
940 ("FenchelLegendre", fenchel_legendre_ty()),
941 ("SpectralGap", spectral_gap_ty()),
942 ("ExtinctionProb", extinction_prob_ty()),
943 ("FreeConvolution", free_convolution_ty()),
944 ("RademacherComplexity", rademacher_complexity_ty()),
945 ("VCDimension", vc_dimension_ty()),
946 ];
947 for (name, ty) in advanced_fn_decls {
948 env.add(Declaration::Axiom {
949 name: Name::str(*name),
950 univ_params: vec![],
951 ty: ty.clone(),
952 })
953 .ok();
954 }
955 let advanced_pred_decls: &[(&str, Expr)] = &[
956 ("ItoFormula", ito_formula_ty()),
957 ("StrongSolution", strong_solution_ty()),
958 ("WeakSolution", weak_solution_ty()),
959 ("GirsanovThm", girsanov_thm_ty()),
960 ("McDiarmidInequality", mcdiarmid_inequality_ty()),
961 ("AzumaInequality", azuma_inequality_ty()),
962 ("LDP", ldp_ty()),
963 ("GartnerEllis", gartner_ellis_ty()),
964 ("PickandsBalkemaDeHaan", pickands_balkema_de_haan_ty()),
965 ("FisherTippettGnedenko", fisher_tippett_gnedenko_ty()),
966 ("GaussianProcess", gaussian_process_ty()),
967 ("DonskerThm", donsker_thm_ty()),
968 ("MarkovBlanket", markov_blanket_ty()),
969 ("DSeparation", d_separation_ty()),
970 ("Faithfulness", faithfulness_ty()),
971 ("FreeIndependence", free_independence_ty()),
972 ("BranchingProcess", branching_process_ty()),
973 ("ReversibleChain", reversible_chain_ty()),
974 ];
975 for (name, ty) in advanced_pred_decls {
976 env.add(Declaration::Axiom {
977 name: Name::str(*name),
978 univ_params: vec![],
979 ty: ty.clone(),
980 })
981 .ok();
982 }
983 Ok(())
984}
#[cfg(test)]
mod tests {
    use super::*;
    use oxilean_kernel::Environment;

    /// Default absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-6;

    /// Returns `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_discrete_uniform() {
        let probs = discrete_uniform(4);
        assert_eq!(probs.len(), 4);
        for p in &probs {
            assert!(approx_eq(*p, 0.25, EPS));
        }
        let sum: f64 = probs.iter().sum();
        assert!(approx_eq(sum, 1.0, EPS));
    }

    #[test]
    fn test_binomial_pmf() {
        let p = binomial_pmf(10, 5, 0.5);
        assert!(approx_eq(p, 0.24609375, 1e-8));
    }

    #[test]
    fn test_poisson_pmf() {
        let p = poisson_pmf(2.0, 2);
        assert!(approx_eq(p, 2.0 * (-2.0f64).exp(), 1e-9));
        assert!(approx_eq(p, 0.27067, 1e-4));
    }

    #[test]
    fn test_normal_pdf() {
        let p = normal_pdf(0.0, 0.0, 1.0);
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!(approx_eq(p, expected, EPS));
        assert!(approx_eq(p, 0.3989422804, 1e-9));
    }

    #[test]
    fn test_sample_stats() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mean = sample_mean(&data);
        assert!(approx_eq(mean, 3.0, EPS));
        let var = sample_variance(&data);
        assert!(approx_eq(var, 2.5, EPS));
    }

    #[test]
    fn test_pearson() {
        let x: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let y: Vec<f64> = x.iter().map(|xi| 2.0 * xi + 1.0).collect();
        let r = pearson_correlation(&x, &y);
        assert!(approx_eq(r, 1.0, EPS));
    }

    #[test]
    fn test_markov_chain() {
        let transition = vec![vec![0.7, 0.3], vec![0.4, 0.6]];
        let chain = MarkovChain::new(transition);
        let stat = chain.stationary_distribution();
        assert_eq!(stat.len(), 2);
        assert!(approx_eq(stat[0], 4.0 / 7.0, 1e-6));
        assert!(approx_eq(stat[1], 3.0 / 7.0, 1e-6));
        assert!(chain.is_ergodic());
    }

    #[test]
    fn test_bayes_update() {
        let prior = [0.5, 0.5];
        let likelihoods = [0.8, 0.4];
        let posterior = bayes_update(&prior, &likelihoods);
        assert_eq!(posterior.len(), 2);
        assert!(approx_eq(posterior[0], 2.0 / 3.0, EPS));
        assert!(approx_eq(posterior[1], 1.0 / 3.0, EPS));
    }

    #[test]
    fn test_build_env() {
        let mut env = Environment::new();
        let result = build_probability_env(&mut env);
        assert!(result.is_ok());
    }

    #[test]
    fn test_discrete_distribution() {
        let weights = [1.0, 2.0, 3.0, 4.0];
        let dist = DiscreteDistribution::from_weights(&weights);
        assert_eq!(dist.pmf.len(), 4);
        let sum: f64 = dist.pmf.iter().sum();
        assert!(approx_eq(sum, 1.0, EPS));
        assert!(approx_eq(dist.prob(0), 0.1, EPS));
        assert!(approx_eq(dist.prob(3), 0.4, EPS));
        assert!(approx_eq(dist.mean(), 2.0, EPS));
    }

    #[test]
    fn test_gaussian_cdf() {
        let g = GaussianDistribution::new(0.0, 1.0);
        assert!(approx_eq(g.cdf(0.0), 0.5, 1e-4));
        assert!(approx_eq(g.cdf(1.96), 0.975, 1e-3));
    }

    #[test]
    fn test_concentration_bounds() {
        let intervals: Vec<(f64, f64)> = vec![(0.0, 1.0); 10];
        let b = ConcentrationBound::hoeffding(1.0, &intervals);
        assert!(approx_eq(b, (-0.2f64).exp(), 1e-6));
        let m = ConcentrationBound::markov(2.0, 4.0);
        assert!(approx_eq(m, 0.5, EPS));
        let c = ConcentrationBound::chebyshev(2.0);
        assert!(approx_eq(c, 0.25, EPS));
    }

    #[test]
    fn test_characteristic_function() {
        let pmf = vec![0.25; 4];
        let cf = CharacteristicFunction::new(pmf);
        assert!(approx_eq(cf.real_part(0.0), 1.0, EPS));
        assert!(approx_eq(cf.imag_part(0.0), 0.0, EPS));
        assert!(approx_eq(cf.moment(1), 1.5, EPS));
    }

    #[test]
    fn test_exponential_dist() {
        assert!(approx_eq(exponential_pdf(0.0, 1.0), 1.0, EPS));
        assert!(approx_eq(
            exponential_cdf(1.0, 1.0),
            1.0 - (-1.0f64).exp(),
            EPS
        ));
    }

    #[test]
    fn test_kl_divergence() {
        let p = [0.5, 0.5];
        assert!(approx_eq(kl_divergence(&p, &p), 0.0, EPS));
        let q = [0.5, 0.5];
        let p2 = [1.0, 0.0];
        let kl = kl_divergence(&p2, &q);
        assert!(approx_eq(kl, 2.0f64.ln(), EPS));
    }

    #[test]
    fn test_total_variation() {
        let p = [0.5, 0.5];
        let q = [0.25, 0.75];
        let tv = total_variation_distance(&p, &q);
        assert!(approx_eq(tv, 0.25, EPS));
    }

    #[test]
    fn test_geometric_pmf() {
        assert!(approx_eq(geometric_pmf(1, 0.5), 0.5, EPS));
        assert!(approx_eq(geometric_pmf(2, 0.5), 0.25, EPS));
    }

    #[test]
    fn test_lcg() {
        let mut lcg = Lcg::new(42);
        for _ in 0..100 {
            let v = lcg.next_f64();
            assert!((0.0..1.0).contains(&v));
        }
    }

    #[test]
    fn test_mixing_time() {
        let transition = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
        let chain = MarkovChain::new(transition);
        let t = chain.mixing_time(0.01);
        assert!(t <= 5);
    }

    #[test]
    fn test_empirical_moments() {
        let data = [1.0, 2.0, 3.0];
        let moments = empirical_moments(&data, 2);
        assert!(approx_eq(moments[0], 1.0, EPS));
        assert!(approx_eq(moments[1], 2.0, EPS));
        assert!(approx_eq(moments[2], 14.0 / 3.0, EPS));
    }

    #[test]
    fn test_gaussian_mgf() {
        let g = GaussianDistribution::new(0.0, 1.0);
        assert!(approx_eq(g.mgf(1.0), (0.5f64).exp(), EPS));
        assert!(approx_eq(g.mgf(0.0), 1.0, EPS));
    }

    #[test]
    fn test_gamma_pdf_exponential() {
        let g = gamma_pdf(1.0, 1.0, 1.0);
        assert!(approx_eq(g, (-1.0f64).exp(), 1e-6));
    }

    #[test]
    fn test_exponential_distribution_struct() {
        let exp = ExponentialDistribution::new(2.0);
        assert!(approx_eq(exp.pdf(0.0), 2.0, EPS));
        assert!(approx_eq(exp.cdf(1.0), 1.0 - (-2.0f64).exp(), EPS));
        assert!(approx_eq(exp.mean(), 0.5, EPS));
        assert!(approx_eq(exp.variance(), 0.25, EPS));
        assert!(approx_eq(exp.quantile(0.0), 0.0, EPS));
        assert!(approx_eq(exp.mgf(1.0), 2.0, EPS));
    }

    #[test]
    fn test_welford_estimator() {
        let mut est = WelfordEstimator::new();
        for x in [1.0, 2.0, 3.0, 4.0, 5.0] {
            est.update(x);
        }
        assert_eq!(est.count(), 5);
        assert!(approx_eq(est.mean(), 3.0, EPS));
        assert!(approx_eq(est.variance(), 2.5, EPS));
    }

    #[test]
    fn test_welford_merge() {
        let mut est1 = WelfordEstimator::new();
        let mut est2 = WelfordEstimator::new();
        for x in [1.0, 2.0, 3.0] {
            est1.update(x);
        }
        for x in [4.0, 5.0] {
            est2.update(x);
        }
        est1.merge(&est2);
        assert_eq!(est1.count(), 5);
        assert!(approx_eq(est1.mean(), 3.0, 1e-10));
    }

    #[test]
    fn test_kde_density() {
        let kde = KernelDensityEstimator::with_bandwidth(vec![0.0], 1.0);
        let d = kde.density(0.0);
        let expected = 1.0 / (2.0 * std::f64::consts::PI).sqrt();
        assert!(approx_eq(d, expected, 1e-9));
        assert!(kde.density(100.0) < 1e-10);
    }

    #[test]
    fn test_empirical_cdf() {
        let ecdf = EmpiricalCdf::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(ecdf.len(), 5);
        assert!(approx_eq(ecdf.eval(0.0), 0.0, EPS));
        assert!(approx_eq(ecdf.eval(3.0), 0.6, EPS));
        assert!(approx_eq(ecdf.eval(10.0), 1.0, EPS));
        assert!(approx_eq(ecdf.quantile(0.5), 3.0, EPS));
    }

    #[test]
    fn test_poisson_process() {
        let pp = PoissonProcess::new(3.0);
        assert!(approx_eq(pp.expected_count(1.0), 3.0, EPS));
        assert!(approx_eq(pp.variance_count(2.0), 6.0, EPS));
        assert!(approx_eq(pp.count_pmf(1.0, 0), (-3.0f64).exp(), 1e-9));
        assert!(approx_eq(pp.compound_expected(2.0, 4.0), 24.0, EPS));
    }

    #[test]
    fn test_poisson_process_simulation() {
        let pp = PoissonProcess::new(10.0);
        let mut lcg = Lcg::new(12345);
        let arrivals = pp.simulate_arrivals(1.0, &mut lcg);
        // Every simulated arrival must fall inside the horizon (0, 1]. A former
        // `!empty || empty` assertion here was a tautology and has been removed.
        for &t in &arrivals {
            assert!(t > 0.0 && t <= 1.0);
        }
    }

    #[test]
    fn test_build_advanced_env() {
        // Exercise the advanced builder directly; `test_build_env` already covers the
        // combined `build_probability_env` entry point.
        let mut env = Environment::new();
        let result = build_advanced_probability_env(&mut env);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ks_statistic() {
        let data: Vec<f64> = (1..=10).map(|i| i as f64 / 10.0).collect();
        let ecdf = EmpiricalCdf::new(data);
        let ks = ecdf.ks_statistic(|x| x.clamp(0.0, 1.0));
        assert!(ks <= 0.1 + EPS);
    }
}
/// Tests for the descriptive helper types: characteristic functions, large deviations,
/// ergodic theorems, stopping times and couplings.
#[cfg(test)]
mod extended_prob_tests {
    use super::*;

    #[test]
    fn test_characteristic_function() {
        let gaussian_cf = CharFunctionData::gaussian(0.0, 1.0);
        assert!(gaussian_cf.is_integrable);
        assert!(gaussian_cf.levy_cramer_applies());
        assert!(gaussian_cf.formula.contains("exp"));
    }

    #[test]
    fn test_large_deviations() {
        let cramer = LargeDeviations::cramer("X");
        assert!(cramer.is_good);
        assert!(cramer.ldp_description().contains("LDP"));
        let sanov_principle = LargeDeviations::sanov();
        assert!(sanov_principle.rate_function.contains("KL"));
    }

    #[test]
    fn test_ergodic_theorem() {
        let birkhoff = ErgodicTheoremData::birkhoff("T");
        assert_eq!(birkhoff.theorem_name, "Birkhoff");
        assert!(birkhoff.convergence_type.contains("L1"));
    }

    #[test]
    fn test_stopping_time() {
        let hitting = StoppingTime::first_hitting("A", "F_t");
        assert!(hitting.optional_stopping_description().contains("tau"));
    }

    #[test]
    fn test_coupling() {
        let maximal = Coupling::maximal("mu", "nu", 0.2);
        assert!(maximal.maximal_coupling_property().contains("P(X != Y)"));
        let transport = Coupling::optimal_transport("mu", "nu");
        assert!(transport.tv_bound.is_none());
    }
}
/// Tests for Gaussian processes, GP regression, renewal, Hawkes and Dirichlet processes.
#[cfg(test)]
mod tests_prob_ext {
    use super::*;

    #[test]
    fn test_gaussian_process_sq_exp() {
        let process = GaussianProcess::with_sq_exp(1.0, 1.0, 2);
        assert!(process.is_stationary);
        let at_zero = process.kernel_value(0.0);
        assert!((at_zero - 1.0).abs() < 1e-10);
        let at_one = process.kernel_value(1.0);
        assert!(at_one < 1.0 && at_one > 0.0);
        let posterior = process.posterior_description(5);
        assert!(posterior.contains("GP posterior"));
        let mercer = process.mercer_representation();
        assert!(mercer.contains("Mercer"));
    }

    #[test]
    fn test_gaussian_process_matern() {
        let process = GaussianProcess::with_matern(1.5, 1.0, 3);
        let at_zero = process.kernel_value(0.0);
        assert!((at_zero - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_gp_regression() {
        let process = GaussianProcess::with_sq_exp(1.0, 1.0, 2);
        let mut regression = GaussianProcessRegression::new(process, 0.1);
        regression.n_training = 100;
        let exact = regression.complexity_exact();
        // NOTE(review): "Β³" looks like a mis-decoded "³"; confirm it matches the string
        // produced by `complexity_exact` before changing it.
        assert!(exact.contains("O(nΒ³)"));
        let sparse = regression.sparse_gp_complexity(10);
        assert!(sparse.contains("inducing"));
        let marginal = regression.log_marginal_likelihood();
        assert!(marginal.contains("log p"));
    }

    #[test]
    fn test_renewal_process() {
        let renewal = RenewalProcess::poisson_process(2.0);
        assert!((renewal.rate - 2.0).abs() < 1e-10);
        let elementary = renewal.elementary_renewal_theorem();
        assert!(elementary.contains("Elementary renewal"));
        let reward_rate = renewal.renewal_reward_theorem(1.0);
        assert!((reward_rate - 2.0).abs() < 1e-10);
        let blackwell = renewal.blackwell_renewal_theorem();
        assert!(blackwell.contains("Blackwell"));
    }

    #[test]
    fn test_hawkes_process() {
        let process = HawkesProcess::new(1.0, 0.5, 1.0);
        assert!(process.is_stationary);
        assert!((process.branching_ratio() - 0.5).abs() < 1e-10);
        let mean_rate = process.mean_intensity();
        assert!(mean_rate > process.base_intensity);
        let intensity = process.conditional_intensity(1.0, 0.5);
        assert!(intensity > process.base_intensity);
    }

    #[test]
    fn test_dirichlet_process() {
        let process = DirichletProcess::new(2.0, "N(0,1)");
        assert!(process.is_discrete);
        let clusters = process.expected_clusters_for_n(100);
        assert!(clusters > 0.0);
        let stick = process.stick_breaking_construction();
        assert!(stick.contains("Stick-breaking"));
        let restaurant = process.chinese_restaurant_process(100);
        assert!(restaurant.contains("CRP"));
        let updated = process.posterior_update(50);
        assert!((updated.concentration - 52.0).abs() < 1e-10);
    }
}
1349#[allow(dead_code)]
1351pub(super) fn lgamma_approx(x: f64) -> f64 {
1352 0.5 * std::f64::consts::TAU.ln() + (x - 0.5) * x.ln() - x
1353}