1use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8 AlphaDivMid, AlphaDivergence, BayesianEstimation, BeliefPropagation, BregmanDivergence,
9 ConstantCurvatureManifold, DualConnection, ExpectationPropagation, ExponentialFamily,
10 ExponentialFamilyDistrib, FisherInformationMetric, GaussianProcess, GeodesicOfDistributions,
11 JeffreysPrior, LegendreTransform, MirrorDescent, MomentParameter, NatGradExt, NatGradMid,
12 NaturalParameter, QuantumInfoGeometry, ReferenceAnalysis, SchroedingerBridge,
13 SlicedWasserstein, StatManiExt, StatManiMid, StatisticalManifold, WassersteinGeometry,
14};
15
16pub fn app(f: Expr, a: Expr) -> Expr {
17 Expr::App(Box::new(f), Box::new(a))
18}
19pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
20 app(app(f, a), b)
21}
22pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
23 app(app2(f, a, b), c)
24}
25pub fn cst(s: &str) -> Expr {
26 Expr::Const(Name::str(s), vec![])
27}
28pub fn prop() -> Expr {
29 Expr::Sort(Level::zero())
30}
31pub fn type0() -> Expr {
32 Expr::Sort(Level::succ(Level::zero()))
33}
34pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
35 Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
36}
37pub fn arrow(a: Expr, b: Expr) -> Expr {
38 pi(BinderInfo::Default, "_", a, b)
39}
40pub fn bvar(n: u32) -> Expr {
41 Expr::BVar(n)
42}
43pub fn nat_ty() -> Expr {
44 cst("Nat")
45}
46pub fn real_ty() -> Expr {
47 cst("Real")
48}
49pub fn list_ty(elem: Expr) -> Expr {
50 app(cst("List"), elem)
51}
52pub fn statistical_manifold_ty() -> Expr {
55 arrow(nat_ty(), type0())
56}
57pub fn fisher_information_metric_ty() -> Expr {
60 arrow(nat_ty(), type0())
61}
62pub fn riemannian_metric_ty() -> Expr {
65 arrow(nat_ty(), type0())
66}
67pub fn geodesic_of_distributions_ty() -> Expr {
70 arrow(type0(), arrow(type0(), type0()))
71}
72pub fn chentsov_theorem_ty() -> Expr {
76 prop()
77}
78pub fn geodesic_distance_formula_ty() -> Expr {
81 pi(BinderInfo::Default, "n", nat_ty(), prop())
82}
83pub fn sectional_curvature_ty() -> Expr {
86 pi(BinderInfo::Default, "n", nat_ty(), real_ty())
87}
88pub fn christoffel_symbols_ty() -> Expr {
91 arrow(nat_ty(), arrow(nat_ty(), type0()))
92}
93pub fn exponential_family_ty() -> Expr {
96 arrow(nat_ty(), type0())
97}
98pub fn natural_parameter_ty() -> Expr {
101 arrow(nat_ty(), type0())
102}
103pub fn moment_parameter_ty() -> Expr {
106 arrow(nat_ty(), type0())
107}
108pub fn legendre_transform_ty() -> Expr {
111 arrow(
112 arrow(list_ty(real_ty()), real_ty()),
113 arrow(list_ty(real_ty()), real_ty()),
114 )
115}
116pub fn log_partition_function_ty() -> Expr {
119 arrow(list_ty(real_ty()), real_ty())
120}
121pub fn natural_to_moment_ty() -> Expr {
124 pi(BinderInfo::Default, "d", nat_ty(), prop())
125}
126pub fn bregman_divergence_ty() -> Expr {
129 pi(BinderInfo::Default, "d", nat_ty(), prop())
130}
131pub fn fisher_as_hessian_ty() -> Expr {
134 pi(BinderInfo::Default, "d", nat_ty(), prop())
135}
136pub fn kl_equals_bregman_ty() -> Expr {
140 prop()
141}
142pub fn alpha_connection_ty() -> Expr {
146 arrow(real_ty(), arrow(nat_ty(), type0()))
147}
148pub fn alpha_divergence_ty() -> Expr {
151 arrow(
152 real_ty(),
153 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
154 )
155}
156pub fn dual_connection_ty() -> Expr {
159 arrow(nat_ty(), type0())
160}
161pub fn constant_curvature_manifold_ty() -> Expr {
165 arrow(real_ty(), arrow(nat_ty(), type0()))
166}
167pub fn alpha_duality_theorem_ty() -> Expr {
170 pi(
171 BinderInfo::Default,
172 "alpha",
173 real_ty(),
174 pi(BinderInfo::Default, "n", nat_ty(), prop()),
175 )
176}
177pub fn generalized_pythagoras_ty() -> Expr {
180 pi(BinderInfo::Default, "n", nat_ty(), prop())
181}
182pub fn alpha_divergence_limits_ty() -> Expr {
185 prop()
186}
187pub fn curvature_formula_ty() -> Expr {
190 pi(BinderInfo::Default, "alpha", real_ty(), real_ty())
191}
192pub fn bayesian_estimation_ty() -> Expr {
195 arrow(
196 arrow(real_ty(), real_ty()),
197 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty())),
198 )
199}
200pub fn jeffreys_prior_ty() -> Expr {
203 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
204}
205pub fn reference_analysis_ty() -> Expr {
208 arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
209}
210pub fn expectation_propagation_ty() -> Expr {
213 arrow(nat_ty(), type0())
214}
215pub fn jeffreys_invariance_ty() -> Expr {
218 prop()
219}
220pub fn bernstein_von_mises_ty() -> Expr {
223 pi(BinderInfo::Default, "n", nat_ty(), prop())
224}
225pub fn ep_fixed_point_ty() -> Expr {
228 prop()
229}
230pub fn laplace_approximation_ty() -> Expr {
233 pi(BinderInfo::Default, "n", nat_ty(), prop())
234}
235pub fn fisher_rao_metric_ty() -> Expr {
239 arrow(nat_ty(), type0())
240}
241pub fn e_connection_ty() -> Expr {
245 arrow(nat_ty(), type0())
246}
247pub fn m_connection_ty() -> Expr {
251 arrow(nat_ty(), type0())
252}
253pub fn e_projection_ty() -> Expr {
257 arrow(nat_ty(), arrow(type0(), type0()))
258}
259pub fn m_projection_ty() -> Expr {
263 arrow(nat_ty(), arrow(type0(), type0()))
264}
265pub fn pythagorean_theorem_ig_ty() -> Expr {
270 pi(BinderInfo::Default, "n", nat_ty(), prop())
271}
272pub fn e_flat_exponential_family_ty() -> Expr {
275 pi(BinderInfo::Default, "d", nat_ty(), prop())
276}
277pub fn m_flat_mixture_family_ty() -> Expr {
280 pi(BinderInfo::Default, "d", nat_ty(), prop())
281}
282pub fn legendre_duality_ty() -> Expr {
285 pi(BinderInfo::Default, "d", nat_ty(), prop())
286}
287pub fn f_divergence_ty() -> Expr {
290 arrow(
291 arrow(real_ty(), real_ty()),
292 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
293 )
294}
295pub fn bregman_divergence_gen_ty() -> Expr {
299 arrow(
300 arrow(list_ty(real_ty()), real_ty()),
301 arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
302 )
303}
304pub fn wasserstein_metric_ty() -> Expr {
307 arrow(real_ty(), arrow(nat_ty(), type0()))
308}
309pub fn f_div_is_bregman_on_exp_ty() -> Expr {
312 prop()
313}
314pub fn chentsov_uniqueness_f_div_ty() -> Expr {
318 prop()
319}
320pub fn wasserstein_vs_fisher_rao_ty() -> Expr {
324 pi(BinderInfo::Default, "n", nat_ty(), prop())
325}
326pub fn pinsker_inequality_ty() -> Expr {
329 prop()
330}
331pub fn natural_gradient_descent_ty() -> Expr {
335 arrow(nat_ty(), type0())
336}
337pub fn mirror_descent_ty() -> Expr {
341 arrow(nat_ty(), type0())
342}
343pub fn em_algorithm_ty() -> Expr {
347 arrow(nat_ty(), arrow(nat_ty(), type0()))
348}
349pub fn natural_gradient_convergence_ty() -> Expr {
353 pi(BinderInfo::Default, "d", nat_ty(), prop())
354}
355pub fn mirror_descent_eq_natural_gradient_ty() -> Expr {
359 prop()
360}
361pub fn em_monotone_convergence_ty() -> Expr {
364 pi(BinderInfo::Default, "n", nat_ty(), prop())
365}
366pub fn em_as_alternating_projection_ty() -> Expr {
369 prop()
370}
371pub fn belief_propagation_ty() -> Expr {
375 arrow(nat_ty(), arrow(nat_ty(), type0()))
376}
377pub fn tree_reweighted_bp_ty() -> Expr {
380 arrow(nat_ty(), type0())
381}
382pub fn bp_fixed_point_bethe_ty() -> Expr {
385 pi(BinderInfo::Default, "n", nat_ty(), prop())
386}
387pub fn bp_exact_on_tree_ty() -> Expr {
390 pi(BinderInfo::Default, "n", nat_ty(), prop())
391}
392pub fn sanov_theorem_ty() -> Expr {
396 arrow(nat_ty(), type0())
397}
398pub fn rate_function_ty() -> Expr {
401 arrow(list_ty(real_ty()), real_ty())
402}
403pub fn sanov_kl_rate_function_ty() -> Expr {
406 pi(BinderInfo::Default, "n", nat_ty(), prop())
407}
408pub fn contraction_principle_ty() -> Expr {
411 pi(BinderInfo::Default, "n", nat_ty(), prop())
412}
413pub fn quantum_statistical_manifold_ty() -> Expr {
416 arrow(nat_ty(), arrow(nat_ty(), type0()))
417}
418pub fn sld_metric_ty() -> Expr {
422 arrow(nat_ty(), arrow(nat_ty(), type0()))
423}
424pub fn rld_metric_ty() -> Expr {
427 arrow(nat_ty(), arrow(nat_ty(), type0()))
428}
429pub fn quantum_relative_entropy_ty() -> Expr {
432 arrow(nat_ty(), type0())
433}
434pub fn quantum_cramer_rao_ty() -> Expr {
437 pi(BinderInfo::Default, "d", nat_ty(), prop())
438}
439pub fn sld_monotonicity_ty() -> Expr {
442 prop()
443}
444pub fn uhlmann_theorem_ty() -> Expr {
447 pi(BinderInfo::Default, "n", nat_ty(), prop())
448}
449pub fn quantum_stein_lemma_ty() -> Expr {
452 prop()
453}
454pub fn ito_girsanov_ig_ty() -> Expr {
458 arrow(nat_ty(), type0())
459}
460pub fn fokker_planck_ig_ty() -> Expr {
464 arrow(nat_ty(), type0())
465}
466pub fn girsanov_e_geodesic_ty() -> Expr {
470 pi(BinderInfo::Default, "d", nat_ty(), prop())
471}
472pub fn otto_calculus_gradient_flow_ty() -> Expr {
475 pi(BinderInfo::Default, "d", nat_ty(), prop())
476}
477pub fn build_env(env: &mut Environment) -> Result<(), String> {
479 let axioms: &[(&str, Expr)] = &[
480 ("StatisticalManifold", statistical_manifold_ty()),
481 ("FisherInformationMetric", fisher_information_metric_ty()),
482 ("RiemannianMetric", riemannian_metric_ty()),
483 ("GeodesicOfDistributions", geodesic_of_distributions_ty()),
484 ("chentsov_theorem", chentsov_theorem_ty()),
485 ("geodesic_distance_formula", geodesic_distance_formula_ty()),
486 ("sectional_curvature", sectional_curvature_ty()),
487 ("christoffel_symbols", christoffel_symbols_ty()),
488 ("ExponentialFamily", exponential_family_ty()),
489 ("NaturalParameter", natural_parameter_ty()),
490 ("MomentParameter", moment_parameter_ty()),
491 ("LegendreTransform", legendre_transform_ty()),
492 ("LogPartitionFunction", log_partition_function_ty()),
493 ("natural_to_moment", natural_to_moment_ty()),
494 ("bregman_divergence", bregman_divergence_ty()),
495 ("fisher_as_hessian", fisher_as_hessian_ty()),
496 ("kl_equals_bregman", kl_equals_bregman_ty()),
497 ("AlphaConnection", alpha_connection_ty()),
498 ("AlphaDivergence", alpha_divergence_ty()),
499 ("DualConnection", dual_connection_ty()),
500 (
501 "ConstantCurvatureManifold",
502 constant_curvature_manifold_ty(),
503 ),
504 ("alpha_duality_theorem", alpha_duality_theorem_ty()),
505 ("generalized_pythagoras", generalized_pythagoras_ty()),
506 ("alpha_divergence_limits", alpha_divergence_limits_ty()),
507 ("curvature_formula", curvature_formula_ty()),
508 ("BayesianEstimation", bayesian_estimation_ty()),
509 ("JeffreysPrior", jeffreys_prior_ty()),
510 ("ReferenceAnalysis", reference_analysis_ty()),
511 ("ExpectationPropagation", expectation_propagation_ty()),
512 ("jeffreys_invariance", jeffreys_invariance_ty()),
513 ("bernstein_von_mises", bernstein_von_mises_ty()),
514 ("ep_fixed_point", ep_fixed_point_ty()),
515 ("laplace_approximation", laplace_approximation_ty()),
516 ("FisherRaoMetric", fisher_rao_metric_ty()),
517 ("EConnection", e_connection_ty()),
518 ("MConnection", m_connection_ty()),
519 ("EProjection", e_projection_ty()),
520 ("MProjection", m_projection_ty()),
521 ("pythagorean_theorem_ig", pythagorean_theorem_ig_ty()),
522 ("e_flat_exponential_family", e_flat_exponential_family_ty()),
523 ("m_flat_mixture_family", m_flat_mixture_family_ty()),
524 ("legendre_duality", legendre_duality_ty()),
525 ("FDivergence", f_divergence_ty()),
526 ("BregmanDivergenceGen", bregman_divergence_gen_ty()),
527 ("WassersteinMetric", wasserstein_metric_ty()),
528 ("f_div_is_bregman_on_exp", f_div_is_bregman_on_exp_ty()),
529 ("chentsov_uniqueness_f_div", chentsov_uniqueness_f_div_ty()),
530 ("wasserstein_vs_fisher_rao", wasserstein_vs_fisher_rao_ty()),
531 ("pinsker_inequality", pinsker_inequality_ty()),
532 ("NaturalGradientDescent", natural_gradient_descent_ty()),
533 ("MirrorDescent", mirror_descent_ty()),
534 ("EMAlgorithm", em_algorithm_ty()),
535 (
536 "natural_gradient_convergence",
537 natural_gradient_convergence_ty(),
538 ),
539 (
540 "mirror_descent_eq_natural_gradient",
541 mirror_descent_eq_natural_gradient_ty(),
542 ),
543 ("em_monotone_convergence", em_monotone_convergence_ty()),
544 (
545 "em_as_alternating_projection",
546 em_as_alternating_projection_ty(),
547 ),
548 ("BeliefPropagation", belief_propagation_ty()),
549 ("TreeReweightedBP", tree_reweighted_bp_ty()),
550 ("bp_fixed_point_bethe", bp_fixed_point_bethe_ty()),
551 ("bp_exact_on_tree", bp_exact_on_tree_ty()),
552 ("SanovTheorem", sanov_theorem_ty()),
553 ("RateFunction", rate_function_ty()),
554 ("sanov_kl_rate_function", sanov_kl_rate_function_ty()),
555 ("contraction_principle", contraction_principle_ty()),
556 (
557 "QuantumStatisticalManifold",
558 quantum_statistical_manifold_ty(),
559 ),
560 ("SLDMetric", sld_metric_ty()),
561 ("RLDMetric", rld_metric_ty()),
562 ("QuantumRelativeEntropy", quantum_relative_entropy_ty()),
563 ("quantum_cramer_rao", quantum_cramer_rao_ty()),
564 ("sld_monotonicity", sld_monotonicity_ty()),
565 ("uhlmann_theorem", uhlmann_theorem_ty()),
566 ("quantum_stein_lemma", quantum_stein_lemma_ty()),
567 ("ItoGirsanovIG", ito_girsanov_ig_ty()),
568 ("FokkerPlanckIG", fokker_planck_ig_ty()),
569 ("girsanov_e_geodesic", girsanov_e_geodesic_ty()),
570 (
571 "otto_calculus_gradient_flow",
572 otto_calculus_gradient_flow_ty(),
573 ),
574 ];
575 for (name, ty) in axioms {
576 env.add(Declaration::Axiom {
577 name: Name::str(*name),
578 univ_params: vec![],
579 ty: ty.clone(),
580 })
581 .ok();
582 }
583 Ok(())
584}
/// Returns the sum of element-wise products of `a` and `b`.
///
/// The slices are paired with `zip`, so when their lengths differ the extra
/// elements of the longer one are ignored.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
}
589pub fn mat_vec(a: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
591 a.iter().map(|row| dot_product(row, v)).collect()
592}
/// Solves the linear system `a * x = b` by Gaussian elimination with partial pivoting.
///
/// The system size is `b.len()`; `a` must supply at least that many rows and
/// columns, otherwise indexing panics. A pivot whose magnitude is below
/// `1e-14` is treated as zero: elimination for that column is skipped and the
/// matching unknown is reported as `0.0`, so a singular system yields a
/// best-effort vector rather than an error.
pub fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    const PIVOT_EPS: f64 = 1e-14;

    let n = b.len();
    let mut m: Vec<Vec<f64>> = a.to_vec();
    let mut y: Vec<f64> = b.to_vec();

    // Forward elimination to upper-triangular form.
    for col in 0..n {
        // Partial pivoting: choose the row at or below `col` with the largest
        // magnitude in this column; incomparable values are treated as equal.
        let best_row = (col..n)
            .max_by(|&r1, &r2| {
                let (lhs, rhs) = (m[r1][col].abs(), m[r2][col].abs());
                lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);
        m.swap(col, best_row);
        y.swap(col, best_row);

        let pivot = m[col][col];
        if pivot.abs() < PIVOT_EPS {
            continue;
        }

        // Split so the pivot row can be read while the rows below are updated.
        let (upper, lower) = m.split_at_mut(col + 1);
        let pivot_row = &upper[col];
        let (y_upper, y_lower) = y.split_at_mut(col + 1);
        let y_pivot = y_upper[col];
        // `y_lower` has exactly `n - col - 1` entries, which bounds the rows visited.
        for (row, y_entry) in lower.iter_mut().zip(y_lower.iter_mut()) {
            let factor = row[col] / pivot;
            for (dst, &src) in row[col..n].iter_mut().zip(&pivot_row[col..n]) {
                *dst -= factor * src;
            }
            *y_entry -= factor * y_pivot;
        }
    }

    // Back substitution, from the last unknown to the first.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let residual = ((i + 1)..n).fold(y[i], |acc, j| acc - m[i][j] * x[j]);
        let pivot = m[i][i];
        x[i] = if pivot.abs() < PIVOT_EPS {
            0.0
        } else {
            residual / pivot
        };
    }
    x
}
#[cfg(test)]
mod ig_ext_tests {
    use super::*;

    /// `StatManiMid::exponential_family("Normal", 2)` is dually flat and describes its alpha-divergence.
    #[test]
    fn test_statistical_manifold() {
        let manifold = StatManiMid::exponential_family("Normal", 2);
        assert!(manifold.is_dually_flat());
        let description = manifold.alpha_divergence_description();
        assert!(!description.is_empty());
    }

    /// `NatGradMid::new(10, 0.01)` yields non-empty update-rule and invariance texts.
    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradMid::new(10, 0.01);
        let rule = optimizer.update_rule();
        assert!(!rule.is_empty());
        let invariance = optimizer.invariance_property();
        assert!(!invariance.is_empty());
    }

    /// `AlphaDivMid::kl_divergence` reports itself as KL.
    #[test]
    fn test_alpha_divergence() {
        let divergence = AlphaDivMid::kl_divergence("p", "q");
        assert!(divergence.is_kl());
    }

    /// `BregmanDivergence::squared_euclidean()` has non-empty definition and three-point texts.
    #[test]
    fn test_bregman_divergence() {
        let divergence = BregmanDivergence::squared_euclidean();
        let definition = divergence.definition();
        assert!(!definition.is_empty());
        let three_point = divergence.three_point_property();
        assert!(!three_point.is_empty());
    }

    /// `WassersteinGeometry::new(2, "R^d")` has non-empty W2 and Benamou-Brenier descriptions.
    #[test]
    fn test_wasserstein() {
        let geometry = WassersteinGeometry::new(2, "R^d");
        let w2 = geometry.w2_distance_description();
        assert!(!w2.is_empty());
        let benamou_brenier = geometry.benamou_brenier_description();
        assert!(!benamou_brenier.is_empty());
    }
}
#[cfg(test)]
mod gp_expfam_tests {
    use super::*;

    /// `GaussianProcess::rbf(1.0)` is flagged stationary and has a non-empty posterior description.
    #[test]
    fn test_gaussian_process() {
        let process = GaussianProcess::rbf(1.0);
        assert!(process.is_stationary);
        let posterior = process.posterior_description();
        assert!(!posterior.is_empty());
    }

    /// `ExponentialFamilyDistrib::gaussian(2)` reports MLE = moment matching and a non-empty parameter map.
    #[test]
    fn test_exponential_family() {
        let family = ExponentialFamilyDistrib::gaussian(2);
        assert!(family.mle_equals_moment_matching());
        let mapping = family.natural_to_moment_params();
        assert!(!mapping.is_empty());
    }
}
#[cfg(test)]
mod tests_info_geom_ext {
    use super::*;

    /// Checks the keywords present in each descriptive string of `NatGradExt::new(10)`.
    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradExt::new(10);
        assert!(optimizer.update_rule(0.01).contains("Natural gradient"));
        assert!(optimizer.fisher_rao_distance().contains("Fisher-Rao"));
        assert!(optimizer.amari_dual_connection().contains("α-connection"));
        assert!(optimizer.invariance_property().contains("Fisher-Rao"));
    }

    /// Checks the flags and descriptive strings of `StatManiExt::gaussian_family()`.
    #[test]
    fn test_statistical_manifold() {
        let family = StatManiExt::gaussian_family();
        assert!(family.is_dually_flat);
        assert_eq!(family.dimension, 2);
        assert!(family.pythagorean_theorem().contains("Pythagoras"));
        assert!(family.bregman_divergence_connection().contains("Bregman"));
    }

    /// Checks the descriptive strings of `SlicedWasserstein::new(10, 100)`.
    #[test]
    fn test_sliced_wasserstein() {
        let sliced = SlicedWasserstein::new(10, 100);
        assert!(sliced.complexity_description().contains("Sliced"));
        assert!(sliced
            .bonneel_et_al_description()
            .contains("sliced Wasserstein"));
    }

    /// Checks the descriptive strings of a `SchroedingerBridge` built from `"P"`, `"Q"`, `"BM"`, `0.01`.
    #[test]
    fn test_schroedinger_bridge() {
        let bridge = SchroedingerBridge::new("P", "Q", "BM", 0.01);
        assert!(bridge.sinkhorn_algorithm().contains("Sinkhorn"));
        assert!(bridge.ipfp_iteration().contains("IPFP"));
        assert!(bridge.connection_to_diffusion_models().contains("diffusion"));
    }

    /// Checks the strings and the two `bures_distance` values of `QuantumInfoGeometry::bures_metric(4)`.
    #[test]
    fn test_quantum_info_geom() {
        let geometry = QuantumInfoGeometry::bures_metric(4);
        assert!(geometry.is_monotone_metric);
        assert!(geometry.petz_classification().contains("Petz"));
        assert!(geometry.quantum_cramer_rao().contains("Cramér-Rao"));
        assert!(geometry.holevo_bound().contains("Holevo"));
        // An argument of 1.0 is expected to give distance 0.
        let at_one = geometry.bures_distance(1.0);
        assert!((at_one - 0.0).abs() < 1e-10);
        // An argument of 0.0 is expected to give distance sqrt(2).
        let at_zero = geometry.bures_distance(0.0);
        assert!((at_zero - 2.0_f64.sqrt()).abs() < 1e-10);
    }
}
743}