Skip to main content

oxilean_std/information_geometry/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
6
7use super::types::{
8    AlphaDivMid, AlphaDivergence, BayesianEstimation, BeliefPropagation, BregmanDivergence,
9    ConstantCurvatureManifold, DualConnection, ExpectationPropagation, ExponentialFamily,
10    ExponentialFamilyDistrib, FisherInformationMetric, GaussianProcess, GeodesicOfDistributions,
11    JeffreysPrior, LegendreTransform, MirrorDescent, MomentParameter, NatGradExt, NatGradMid,
12    NaturalParameter, QuantumInfoGeometry, ReferenceAnalysis, SchroedingerBridge,
13    SlicedWasserstein, StatManiExt, StatManiMid, StatisticalManifold, WassersteinGeometry,
14};
15
16pub fn app(f: Expr, a: Expr) -> Expr {
17    Expr::App(Box::new(f), Box::new(a))
18}
19pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
20    app(app(f, a), b)
21}
22pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
23    app(app2(f, a, b), c)
24}
25pub fn cst(s: &str) -> Expr {
26    Expr::Const(Name::str(s), vec![])
27}
28pub fn prop() -> Expr {
29    Expr::Sort(Level::zero())
30}
31pub fn type0() -> Expr {
32    Expr::Sort(Level::succ(Level::zero()))
33}
34pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
35    Expr::Pi(bi, Name::str(name), Box::new(dom), Box::new(body))
36}
37pub fn arrow(a: Expr, b: Expr) -> Expr {
38    pi(BinderInfo::Default, "_", a, b)
39}
40pub fn bvar(n: u32) -> Expr {
41    Expr::BVar(n)
42}
43pub fn nat_ty() -> Expr {
44    cst("Nat")
45}
46pub fn real_ty() -> Expr {
47    cst("Real")
48}
49pub fn list_ty(elem: Expr) -> Expr {
50    app(cst("List"), elem)
51}
52/// `StatisticalManifold`: smooth manifold of probability distributions parametrized by θ ∈ Θ ⊆ ℝ^n
53/// Type: Nat → Type (dimension n → manifold)
54pub fn statistical_manifold_ty() -> Expr {
55    arrow(nat_ty(), type0())
56}
57/// `FisherInformationMetric`: g_{ij}(θ) = E[∂_i log p · ∂_j log p]
58/// Type: Nat → Type (dim n → n×n metric tensor field)
59pub fn fisher_information_metric_ty() -> Expr {
60    arrow(nat_ty(), type0())
61}
62/// `RiemannianMetric`: general Riemannian metric on the probability simplex
63/// Type: Nat → Type
64pub fn riemannian_metric_ty() -> Expr {
65    arrow(nat_ty(), type0())
66}
67/// `GeodesicOfDistributions`: shortest path between two distributions on the manifold
68/// Type: Type → Type → Type (start → end → geodesic path)
69pub fn geodesic_of_distributions_ty() -> Expr {
70    arrow(type0(), arrow(type0(), type0()))
71}
72/// Chentsov's theorem: the Fisher information metric is the unique (up to scale)
73/// Riemannian metric invariant under sufficient statistics
74/// Type: Prop
75pub fn chentsov_theorem_ty() -> Expr {
76    prop()
77}
78/// Geodesic distance formula: d(p,q) = 2 arccos(∫ √(p q) dμ) (Bhattacharyya arc length)
79/// Type: ∀ (n : Nat), Prop
80pub fn geodesic_distance_formula_ty() -> Expr {
81    pi(BinderInfo::Default, "n", nat_ty(), prop())
82}
83/// Sectional curvature of the statistical manifold under Fisher metric
84/// Type: ∀ (n : Nat), Real (returns curvature)
85pub fn sectional_curvature_ty() -> Expr {
86    pi(BinderInfo::Default, "n", nat_ty(), real_ty())
87}
88/// Christoffel symbols Γ^k_{ij} for the Fisher information metric
89/// Type: Nat → Nat → Type
90pub fn christoffel_symbols_ty() -> Expr {
91    arrow(nat_ty(), arrow(nat_ty(), type0()))
92}
93/// `ExponentialFamily`: p(x|θ) = exp(⟨θ, T(x)⟩ - A(θ)) h(x)
94/// Type: Nat → Type (sufficient statistic dimension → family)
95pub fn exponential_family_ty() -> Expr {
96    arrow(nat_ty(), type0())
97}
98/// `NaturalParameter`: θ ∈ Θ ⊆ ℝ^d (canonical/natural parameters)
99/// Type: Nat → Type
100pub fn natural_parameter_ty() -> Expr {
101    arrow(nat_ty(), type0())
102}
103/// `MomentParameter`: η = E_θ[T(x)] ∈ ℝ^d (mean/moment parameters)
104/// Type: Nat → Type
105pub fn moment_parameter_ty() -> Expr {
106    arrow(nat_ty(), type0())
107}
108/// `LegendreTransform`: A*(η) = sup_θ {⟨θ,η⟩ - A(θ)} (convex conjugate of log-partition)
109/// Type: (List Real → Real) → List Real → Real
110pub fn legendre_transform_ty() -> Expr {
111    arrow(
112        arrow(list_ty(real_ty()), real_ty()),
113        arrow(list_ty(real_ty()), real_ty()),
114    )
115}
116/// `LogPartitionFunction`: A(θ) = log ∫ exp(⟨θ, T(x)⟩) h(x) dμ(x)
117/// Type: List Real → Real
118pub fn log_partition_function_ty() -> Expr {
119    arrow(list_ty(real_ty()), real_ty())
120}
121/// Natural-to-moment parameter conversion: η = ∇A(θ)
122/// Type: ∀ (d : Nat), Prop
123pub fn natural_to_moment_ty() -> Expr {
124    pi(BinderInfo::Default, "d", nat_ty(), prop())
125}
126/// Bregman divergence from log-partition: D_A(η ‖ η') = A*(η) - A*(η') - ⟨∇A*(η'), η - η'⟩
127/// Type: ∀ (d : Nat), Prop
128pub fn bregman_divergence_ty() -> Expr {
129    pi(BinderInfo::Default, "d", nat_ty(), prop())
130}
131/// Fisher information as Hessian of log-partition: I(θ) = ∇²A(θ)
132/// Type: ∀ (d : Nat), Prop
133pub fn fisher_as_hessian_ty() -> Expr {
134    pi(BinderInfo::Default, "d", nat_ty(), prop())
135}
136/// KL divergence equals Bregman divergence for exponential families:
137/// D_KL(p_θ ‖ p_θ') = D_A(η ‖ η')
138/// Type: Prop
139pub fn kl_equals_bregman_ty() -> Expr {
140    prop()
141}
142/// `AlphaConnection`: Γ^(α)_{ij,k} = Γ^(0)_{ij,k} - (α/2) T_{ijk}
143/// (mixture of e-connection and m-connection)
144/// Type: Real → Nat → Type (α parameter → dimension → connection)
145pub fn alpha_connection_ty() -> Expr {
146    arrow(real_ty(), arrow(nat_ty(), type0()))
147}
148/// `AlphaDivergence`: D^(α)(P‖Q) = 4/(1-α²)(1 - ∫p^{(1+α)/2} q^{(1-α)/2} dμ)
149/// Type: Real → List Real → List Real → Real (α, P-dist, Q-dist → divergence)
150pub fn alpha_divergence_ty() -> Expr {
151    arrow(
152        real_ty(),
153        arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
154    )
155}
156/// `DualConnection`: (∇, ∇*) dual affine connections satisfying X⟨Y,Z⟩ = ⟨∇_X Y, Z⟩ + ⟨Y, ∇*_X Z⟩
157/// Type: Nat → Type (dimension → dual connection pair)
158pub fn dual_connection_ty() -> Expr {
159    arrow(nat_ty(), type0())
160}
161/// `ConstantCurvatureManifold`: statistical manifold with constant α-curvature
162/// (α = ±1 gives exponential/mixture families)
163/// Type: Real → Nat → Type (curvature α → dimension → manifold)
164pub fn constant_curvature_manifold_ty() -> Expr {
165    arrow(real_ty(), arrow(nat_ty(), type0()))
166}
167/// Duality theorem: (∇^(α))* = ∇^(-α)
168/// Type: ∀ (α : Real) (n : Nat), Prop
169pub fn alpha_duality_theorem_ty() -> Expr {
170    pi(
171        BinderInfo::Default,
172        "alpha",
173        real_ty(),
174        pi(BinderInfo::Default, "n", nat_ty(), prop()),
175    )
176}
177/// Generalized Pythagorean theorem for α-divergences on flat manifolds
178/// Type: ∀ (n : Nat), Prop
179pub fn generalized_pythagoras_ty() -> Expr {
180    pi(BinderInfo::Default, "n", nat_ty(), prop())
181}
182/// α-divergence reduction: α=1 gives KL, α=-1 gives reverse KL, α=0 gives Hellinger
183/// Type: Prop
184pub fn alpha_divergence_limits_ty() -> Expr {
185    prop()
186}
187/// Curvature formula: constant curvature = -1/4 for e/m-families
188/// Type: ∀ (α : Real), Real
189pub fn curvature_formula_ty() -> Expr {
190    pi(BinderInfo::Default, "alpha", real_ty(), real_ty())
191}
192/// `BayesianEstimation`: posterior p(θ|x) ∝ L(θ|x) · π(θ)
193/// Type: (Real → Real) → (Real → Real) → Real → Real (likelihood, prior, x → posterior)
194pub fn bayesian_estimation_ty() -> Expr {
195    arrow(
196        arrow(real_ty(), real_ty()),
197        arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty())),
198    )
199}
200/// `JeffreysPrior`: π(θ) ∝ √det(I(θ)) — invariant under reparametrization
201/// Type: (Real → Real) → Real → Real (log-density → θ → prior density)
202pub fn jeffreys_prior_ty() -> Expr {
203    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
204}
205/// `ReferenceAnalysis`: Bernardo's reference prior maximizing expected KL divergence
206/// Type: (Real → Real) → Real → Real
207pub fn reference_analysis_ty() -> Expr {
208    arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
209}
210/// `ExpectationPropagation`: EP approximation — project tilted distribution onto exponential family
211/// Type: Nat → Type (number of factors → EP state)
212pub fn expectation_propagation_ty() -> Expr {
213    arrow(nat_ty(), type0())
214}
215/// Jeffreys prior invariance: π̃(φ) = π(θ)|dθ/dφ| for reparametrization φ = g(θ)
216/// Type: Prop
217pub fn jeffreys_invariance_ty() -> Expr {
218    prop()
219}
220/// Bernstein-von Mises theorem: posterior concentrates at MLE as n → ∞
221/// Type: ∀ (n : Nat), Prop
222pub fn bernstein_von_mises_ty() -> Expr {
223    pi(BinderInfo::Default, "n", nat_ty(), prop())
224}
225/// EP fixed point: at convergence, q(θ) is the closest exponential family member to p(θ|x)
226/// Type: Prop
227pub fn ep_fixed_point_ty() -> Expr {
228    prop()
229}
230/// Laplace approximation: posterior ≈ N(θ_MAP, I(θ_MAP)^{-1}/n)
231/// Type: ∀ (n : Nat), Prop
232pub fn laplace_approximation_ty() -> Expr {
233    pi(BinderInfo::Default, "n", nat_ty(), prop())
234}
235/// `FisherRaoMetric`: Riemannian metric on the probability simplex induced by
236/// the Fisher information: ds² = Σ_{ij} g_{ij}(θ) dθ^i dθ^j
237/// Type: Nat → Type (dimension → metric)
238pub fn fisher_rao_metric_ty() -> Expr {
239    arrow(nat_ty(), type0())
240}
241/// `EConnection`: the (-1)-connection (exponential connection ∇^{(-1)}) on a
242/// statistical manifold; flat in exponential coordinates
243/// Type: Nat → Type
244pub fn e_connection_ty() -> Expr {
245    arrow(nat_ty(), type0())
246}
247/// `MConnection`: the (+1)-connection (mixture connection ∇^{(+1)}) on a
248/// statistical manifold; flat in mixture coordinates
249/// Type: Nat → Type
250pub fn m_connection_ty() -> Expr {
251    arrow(nat_ty(), type0())
252}
253/// `EProjection`: projection of a distribution onto an e-flat (exponential family) submanifold
254/// minimizing KL divergence: π_e(p) = argmin_{q ∈ E} D_KL(q ‖ p)
255/// Type: Nat → Type → Type (dim → family → projected dist)
256pub fn e_projection_ty() -> Expr {
257    arrow(nat_ty(), arrow(type0(), type0()))
258}
259/// `MProjection`: projection of a distribution onto an m-flat (mixture family) submanifold
260/// minimizing KL divergence: π_m(p) = argmin_{q ∈ M} D_KL(p ‖ q)
261/// Type: Nat → Type → Type (dim → family → projected dist)
262pub fn m_projection_ty() -> Expr {
263    arrow(nat_ty(), arrow(type0(), type0()))
264}
265/// Pythagorean theorem in information geometry:
266/// for e-geodesic p,r with m-projection q onto e-flat family:
267/// D_KL(p ‖ r) = D_KL(p ‖ q) + D_KL(q ‖ r)
268/// Type: ∀ (n : Nat), Prop
269pub fn pythagorean_theorem_ig_ty() -> Expr {
270    pi(BinderInfo::Default, "n", nat_ty(), prop())
271}
272/// e-geodesic closure: exponential families are e-flat (e-geodesically complete)
273/// Type: ∀ (d : Nat), Prop
274pub fn e_flat_exponential_family_ty() -> Expr {
275    pi(BinderInfo::Default, "d", nat_ty(), prop())
276}
277/// m-geodesic closure: mixture families are m-flat (m-geodesically complete)
278/// Type: ∀ (d : Nat), Prop
279pub fn m_flat_mixture_family_ty() -> Expr {
280    pi(BinderInfo::Default, "d", nat_ty(), prop())
281}
282/// Legendre duality: θ ↦ η is a bijection, and A**(θ) = A(θ) (double Legendre)
283/// Type: ∀ (d : Nat), Prop
284pub fn legendre_duality_ty() -> Expr {
285    pi(BinderInfo::Default, "d", nat_ty(), prop())
286}
287/// `FDivergence`: D_f(P ‖ Q) = ∫ f(dP/dQ) dQ for a convex f with f(1)=0
288/// Type: (Real → Real) → List Real → List Real → Real (generator f, P, Q → divergence)
289pub fn f_divergence_ty() -> Expr {
290    arrow(
291        arrow(real_ty(), real_ty()),
292        arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
293    )
294}
295/// `BregmanDivergenceGen`: generalized Bregman divergence D_φ(x ‖ y) = φ(x) - φ(y) - ⟨∇φ(y), x-y⟩
296/// for a strictly convex differentiable φ
297/// Type: (List Real → Real) → List Real → List Real → Real
298pub fn bregman_divergence_gen_ty() -> Expr {
299    arrow(
300        arrow(list_ty(real_ty()), real_ty()),
301        arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
302    )
303}
304/// `WassersteinMetric`: optimal transport distance W_p(μ,ν) = (inf_γ ∫|x-y|^p dγ)^{1/p}
305/// Type: Real → Nat → Type (p-parameter → dim → metric)
306pub fn wasserstein_metric_ty() -> Expr {
307    arrow(real_ty(), arrow(nat_ty(), type0()))
308}
309/// Every f-divergence is a Bregman divergence on exponential families
310/// Type: Prop
311pub fn f_div_is_bregman_on_exp_ty() -> Expr {
312    prop()
313}
314/// Chentsov's uniqueness theorem for f-divergences:
315/// Up to scaling, KL is the unique f-divergence invariant under sufficient statistics
316/// Type: Prop
317pub fn chentsov_uniqueness_f_div_ty() -> Expr {
318    prop()
319}
320/// Wasserstein vs Fisher-Rao: they induce different geodesics;
321/// Fisher-Rao is intrinsic, Wasserstein is extrinsic/optimal-transport
322/// Type: ∀ (n : Nat), Prop
323pub fn wasserstein_vs_fisher_rao_ty() -> Expr {
324    pi(BinderInfo::Default, "n", nat_ty(), prop())
325}
326/// Pinsker's inequality: D_KL(P ‖ Q) ≥ (1/2) ‖P - Q‖²_TV
327/// Type: Prop
328pub fn pinsker_inequality_ty() -> Expr {
329    prop()
330}
331/// `NaturalGradientDescent`: update rule θ ← θ - ε · G(θ)^{-1} ∇L(θ)
332/// where G(θ) is the Fisher information matrix
333/// Type: Nat → Type (dim → optimizer state)
334pub fn natural_gradient_descent_ty() -> Expr {
335    arrow(nat_ty(), type0())
336}
337/// `MirrorDescent`: generalized gradient descent using Bregman divergence:
338/// θ_{t+1} = argmin_{θ} {⟨∇L(θ_t), θ⟩ + (1/ε) D_φ(θ ‖ θ_t)}
339/// Type: Nat → Type
340pub fn mirror_descent_ty() -> Expr {
341    arrow(nat_ty(), type0())
342}
343/// `EMAlgorithm`: Expectation-Maximization as alternating m/e-projections:
344/// E-step: m-project posterior onto simplex; M-step: e-project onto exponential family
345/// Type: Nat → Nat → Type (latent dim → obs dim → EM state)
346pub fn em_algorithm_ty() -> Expr {
347    arrow(nat_ty(), arrow(nat_ty(), type0()))
348}
349/// Natural gradient descent converges to Fisher-efficient estimator:
350/// θ_t → θ_MLE at rate O(1/t) in Fisher metric
351/// Type: ∀ (d : Nat), Prop
352pub fn natural_gradient_convergence_ty() -> Expr {
353    pi(BinderInfo::Default, "d", nat_ty(), prop())
354}
355/// Mirror descent equals natural gradient for exponential family loss:
356/// Bregman mirror descent with φ=A* is equivalent to natural gradient on exp family
357/// Type: Prop
358pub fn mirror_descent_eq_natural_gradient_ty() -> Expr {
359    prop()
360}
361/// EM monotone convergence: log-likelihood L(θ^{(t+1)}) ≥ L(θ^{(t)})
362/// Type: ∀ (n : Nat), Prop
363pub fn em_monotone_convergence_ty() -> Expr {
364    pi(BinderInfo::Default, "n", nat_ty(), prop())
365}
366/// EM as alternating projection: E-step is m-projection, M-step is e-projection
367/// Type: Prop
368pub fn em_as_alternating_projection_ty() -> Expr {
369    prop()
370}
371/// `BeliefPropagation`: sum-product message passing on a factor graph
372/// corresponds to iterated e-projections onto local exponential families
373/// Type: Nat → Nat → Type (nodes → factors → BP state)
374pub fn belief_propagation_ty() -> Expr {
375    arrow(nat_ty(), arrow(nat_ty(), type0()))
376}
377/// `TreeReweightedBP`: TRW-BP minimizes a variational Bethe free energy
378/// Type: Nat → Type
379pub fn tree_reweighted_bp_ty() -> Expr {
380    arrow(nat_ty(), type0())
381}
382/// Belief propagation fixed point: BP fixed points are stationary points of Bethe free energy
383/// Type: ∀ (n : Nat), Prop
384pub fn bp_fixed_point_bethe_ty() -> Expr {
385    pi(BinderInfo::Default, "n", nat_ty(), prop())
386}
387/// On a tree, BP converges to exact marginals (equals e-projection)
388/// Type: ∀ (n : Nat), Prop
389pub fn bp_exact_on_tree_ty() -> Expr {
390    pi(BinderInfo::Default, "n", nat_ty(), prop())
391}
392/// `SanovTheorem`: rate function for empirical distribution is the KL divergence
393/// P(L_n ∈ E) ≈ exp(-n · inf_{q ∈ E} D_KL(q ‖ p))
394/// Type: Nat → Type (sample size → large-deviation event)
395pub fn sanov_theorem_ty() -> Expr {
396    arrow(nat_ty(), type0())
397}
398/// `RateFunction`: I(q) = D_KL(q ‖ p₀) for the Sanov rate
399/// Type: List Real → Real
400pub fn rate_function_ty() -> Expr {
401    arrow(list_ty(real_ty()), real_ty())
402}
403/// Sanov's theorem: D_KL is the unique rate function for empirical distributions
404/// Type: ∀ (n : Nat), Prop
405pub fn sanov_kl_rate_function_ty() -> Expr {
406    pi(BinderInfo::Default, "n", nat_ty(), prop())
407}
408/// Contraction principle: rate function of a smooth map φ is I ∘ φ^{-1}
409/// Type: ∀ (n : Nat), Prop
410pub fn contraction_principle_ty() -> Expr {
411    pi(BinderInfo::Default, "n", nat_ty(), prop())
412}
413/// `QuantumStatisticalManifold`: manifold of density matrices ρ(θ) on a Hilbert space H
414/// Type: Nat → Nat → Type (dim-H → param-dim → manifold)
415pub fn quantum_statistical_manifold_ty() -> Expr {
416    arrow(nat_ty(), arrow(nat_ty(), type0()))
417}
418/// `SLDMetric`: Symmetric Logarithmic Derivative (SLD) Fisher metric on quantum states;
419/// the quantum analogue of Fisher-Rao: g_{ij}^{SLD} = (1/2) Tr[ρ {L_i, L_j}]
420/// Type: Nat → Nat → Type (Hilbert-dim → param-dim → metric)
421pub fn sld_metric_ty() -> Expr {
422    arrow(nat_ty(), arrow(nat_ty(), type0()))
423}
424/// `RLDMetric`: Right Logarithmic Derivative metric on quantum states
425/// Type: Nat → Nat → Type
426pub fn rld_metric_ty() -> Expr {
427    arrow(nat_ty(), arrow(nat_ty(), type0()))
428}
429/// `QuantumRelativeEntropy`: S(ρ ‖ σ) = Tr[ρ (log ρ - log σ)] (von Neumann relative entropy)
430/// Type: Nat → Type (dim → relative-entropy operator)
431pub fn quantum_relative_entropy_ty() -> Expr {
432    arrow(nat_ty(), type0())
433}
434/// Quantum Cramér-Rao bound: Var(θ̂) ≥ 1 / (n · g^{SLD}(θ))
435/// Type: ∀ (d : Nat), Prop
436pub fn quantum_cramer_rao_ty() -> Expr {
437    pi(BinderInfo::Default, "d", nat_ty(), prop())
438}
439/// SLD metric contracts under quantum channels (monotonicity under CPTP maps)
440/// Type: Prop
441pub fn sld_monotonicity_ty() -> Expr {
442    prop()
443}
444/// Uhlmann's theorem: geometric phase = arc cos of fidelity F(ρ,σ) = Tr[√(√ρ σ √ρ)]
445/// Type: ∀ (n : Nat), Prop
446pub fn uhlmann_theorem_ty() -> Expr {
447    pi(BinderInfo::Default, "n", nat_ty(), prop())
448}
449/// Quantum Stein's lemma: optimal exponent for quantum hypothesis testing is D_KL(ρ ‖ σ)
450/// Type: Prop
451pub fn quantum_stein_lemma_ty() -> Expr {
452    prop()
453}
454/// `ItoGirsanovIG`: Girsanov's theorem viewed as a change of measure in IG:
455/// the Radon-Nikodym derivative exp(∫ h dW - (1/2) ∫ h² dt) is a path-space exponential family
456/// Type: Nat → Type (dim → process)
457pub fn ito_girsanov_ig_ty() -> Expr {
458    arrow(nat_ty(), type0())
459}
460/// `FokkerPlanckIG`: Fokker-Planck equation as a gradient flow on the manifold of densities
461/// under the Fisher-Rao metric
462/// Type: Nat → Type
463pub fn fokker_planck_ig_ty() -> Expr {
464    arrow(nat_ty(), type0())
465}
466/// Girsanov change-of-measure as e-geodesic in path space:
467/// p^h(x) = exp(∫ h dx - A(h)) p^0(x) is an e-family parametrized by h
468/// Type: ∀ (d : Nat), Prop
469pub fn girsanov_e_geodesic_ty() -> Expr {
470    pi(BinderInfo::Default, "d", nat_ty(), prop())
471}
472/// Otto calculus: Fokker-Planck is gradient flow of KL divergence in Wasserstein geometry
473/// Type: ∀ (d : Nat), Prop
474pub fn otto_calculus_gradient_flow_ty() -> Expr {
475    pi(BinderInfo::Default, "d", nat_ty(), prop())
476}
477/// Register all information geometry axioms and theorems in the kernel environment.
478pub fn build_env(env: &mut Environment) -> Result<(), String> {
479    let axioms: &[(&str, Expr)] = &[
480        ("StatisticalManifold", statistical_manifold_ty()),
481        ("FisherInformationMetric", fisher_information_metric_ty()),
482        ("RiemannianMetric", riemannian_metric_ty()),
483        ("GeodesicOfDistributions", geodesic_of_distributions_ty()),
484        ("chentsov_theorem", chentsov_theorem_ty()),
485        ("geodesic_distance_formula", geodesic_distance_formula_ty()),
486        ("sectional_curvature", sectional_curvature_ty()),
487        ("christoffel_symbols", christoffel_symbols_ty()),
488        ("ExponentialFamily", exponential_family_ty()),
489        ("NaturalParameter", natural_parameter_ty()),
490        ("MomentParameter", moment_parameter_ty()),
491        ("LegendreTransform", legendre_transform_ty()),
492        ("LogPartitionFunction", log_partition_function_ty()),
493        ("natural_to_moment", natural_to_moment_ty()),
494        ("bregman_divergence", bregman_divergence_ty()),
495        ("fisher_as_hessian", fisher_as_hessian_ty()),
496        ("kl_equals_bregman", kl_equals_bregman_ty()),
497        ("AlphaConnection", alpha_connection_ty()),
498        ("AlphaDivergence", alpha_divergence_ty()),
499        ("DualConnection", dual_connection_ty()),
500        (
501            "ConstantCurvatureManifold",
502            constant_curvature_manifold_ty(),
503        ),
504        ("alpha_duality_theorem", alpha_duality_theorem_ty()),
505        ("generalized_pythagoras", generalized_pythagoras_ty()),
506        ("alpha_divergence_limits", alpha_divergence_limits_ty()),
507        ("curvature_formula", curvature_formula_ty()),
508        ("BayesianEstimation", bayesian_estimation_ty()),
509        ("JeffreysPrior", jeffreys_prior_ty()),
510        ("ReferenceAnalysis", reference_analysis_ty()),
511        ("ExpectationPropagation", expectation_propagation_ty()),
512        ("jeffreys_invariance", jeffreys_invariance_ty()),
513        ("bernstein_von_mises", bernstein_von_mises_ty()),
514        ("ep_fixed_point", ep_fixed_point_ty()),
515        ("laplace_approximation", laplace_approximation_ty()),
516        ("FisherRaoMetric", fisher_rao_metric_ty()),
517        ("EConnection", e_connection_ty()),
518        ("MConnection", m_connection_ty()),
519        ("EProjection", e_projection_ty()),
520        ("MProjection", m_projection_ty()),
521        ("pythagorean_theorem_ig", pythagorean_theorem_ig_ty()),
522        ("e_flat_exponential_family", e_flat_exponential_family_ty()),
523        ("m_flat_mixture_family", m_flat_mixture_family_ty()),
524        ("legendre_duality", legendre_duality_ty()),
525        ("FDivergence", f_divergence_ty()),
526        ("BregmanDivergenceGen", bregman_divergence_gen_ty()),
527        ("WassersteinMetric", wasserstein_metric_ty()),
528        ("f_div_is_bregman_on_exp", f_div_is_bregman_on_exp_ty()),
529        ("chentsov_uniqueness_f_div", chentsov_uniqueness_f_div_ty()),
530        ("wasserstein_vs_fisher_rao", wasserstein_vs_fisher_rao_ty()),
531        ("pinsker_inequality", pinsker_inequality_ty()),
532        ("NaturalGradientDescent", natural_gradient_descent_ty()),
533        ("MirrorDescent", mirror_descent_ty()),
534        ("EMAlgorithm", em_algorithm_ty()),
535        (
536            "natural_gradient_convergence",
537            natural_gradient_convergence_ty(),
538        ),
539        (
540            "mirror_descent_eq_natural_gradient",
541            mirror_descent_eq_natural_gradient_ty(),
542        ),
543        ("em_monotone_convergence", em_monotone_convergence_ty()),
544        (
545            "em_as_alternating_projection",
546            em_as_alternating_projection_ty(),
547        ),
548        ("BeliefPropagation", belief_propagation_ty()),
549        ("TreeReweightedBP", tree_reweighted_bp_ty()),
550        ("bp_fixed_point_bethe", bp_fixed_point_bethe_ty()),
551        ("bp_exact_on_tree", bp_exact_on_tree_ty()),
552        ("SanovTheorem", sanov_theorem_ty()),
553        ("RateFunction", rate_function_ty()),
554        ("sanov_kl_rate_function", sanov_kl_rate_function_ty()),
555        ("contraction_principle", contraction_principle_ty()),
556        (
557            "QuantumStatisticalManifold",
558            quantum_statistical_manifold_ty(),
559        ),
560        ("SLDMetric", sld_metric_ty()),
561        ("RLDMetric", rld_metric_ty()),
562        ("QuantumRelativeEntropy", quantum_relative_entropy_ty()),
563        ("quantum_cramer_rao", quantum_cramer_rao_ty()),
564        ("sld_monotonicity", sld_monotonicity_ty()),
565        ("uhlmann_theorem", uhlmann_theorem_ty()),
566        ("quantum_stein_lemma", quantum_stein_lemma_ty()),
567        ("ItoGirsanovIG", ito_girsanov_ig_ty()),
568        ("FokkerPlanckIG", fokker_planck_ig_ty()),
569        ("girsanov_e_geodesic", girsanov_e_geodesic_ty()),
570        (
571            "otto_calculus_gradient_flow",
572            otto_calculus_gradient_flow_ty(),
573        ),
574    ];
575    for (name, ty) in axioms {
576        env.add(Declaration::Axiom {
577            name: Name::str(*name),
578            univ_params: vec![],
579            ty: ty.clone(),
580        })
581        .ok();
582    }
583    Ok(())
584}
/// Dot product `Σ a[i] * b[i]` of two slices.
///
/// Elements are paired with `zip`, so if the lengths differ, the trailing elements of the
/// longer slice are silently ignored. An empty input sums to zero.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b).map(|(&x, &y)| x * y);
    products.sum::<f64>()
}
/// Matrix–vector product `A v`, with `A` given as a slice of rows (row-major).
///
/// Output entry `i` is the dot product of row `i` with `v`, so the result has one entry
/// per row of `a`. Rows and `v` are paired with `zip`: a length mismatch truncates to the
/// shorter of the two instead of panicking.
pub fn mat_vec(a: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for row in a {
        out.push(row.iter().zip(v).map(|(&r, &x)| r * x).sum::<f64>());
    }
    out
}
/// Solves the d×d linear system `A x = b` by Gaussian elimination with partial pivoting.
///
/// The dimension `d` is taken from `b.len()`. `a` must supply at least `d` rows of at
/// least `d` entries each, otherwise indexing panics. Both inputs are copied, not mutated.
///
/// Singular or near-singular systems are not reported as errors. Any pivot whose
/// magnitude is below `1e-14` (an absolute, not scale-relative, threshold) is skipped
/// during elimination. The corresponding component of the result is then set to `0.0`
/// during back substitution.
pub fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    const PIVOT_EPS: f64 = 1e-14;
    let n = b.len();
    let mut m: Vec<Vec<f64>> = a.to_vec();
    let mut y: Vec<f64> = b.to_vec();

    // Forward elimination to upper-triangular form.
    for c in 0..n {
        // Partial pivoting: choose the row in c..n with the largest |m[row][c]|.
        // `max_by` returns the last of equal maxima; NaN comparisons count as equal.
        let best = (c..n)
            .max_by(|&p, &q| {
                let lhs = m[p][c].abs();
                let rhs = m[q][c].abs();
                lhs.partial_cmp(&rhs).unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(c);
        m.swap(c, best);
        y.swap(c, best);

        let pivot = m[c][c];
        if pivot.abs() < PIVOT_EPS {
            // Nothing usable to eliminate with in this column.
            continue;
        }

        // Split so the pivot row can be read while the rows below it are updated.
        let (upper, lower) = m.split_at_mut(c + 1);
        let pivot_row = &upper[c];
        for (offset, row) in lower.iter_mut().take(n - c - 1).enumerate() {
            let factor = row[c] / pivot;
            for (dst, &src) in row[c..n].iter_mut().zip(&pivot_row[c..n]) {
                *dst -= factor * src;
            }
            let r = c + 1 + offset;
            y[r] -= factor * y[c];
        }
    }

    // Back substitution; unusable diagonal entries yield a zero component.
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let acc = ((i + 1)..n).fold(y[i], |s, j| s - m[i][j] * x[j]);
        let diag = m[i][i];
        x[i] = if diag.abs() < PIVOT_EPS { 0.0 } else { acc / diag };
    }
    x
}
#[cfg(test)]
mod ig_ext_tests {
    use super::*;

    /// The exponential-family constructor is dually flat and has an α-divergence description.
    #[test]
    fn test_statistical_manifold() {
        let family = StatManiMid::exponential_family("Normal", 2);
        assert!(family.is_dually_flat());
        assert!(!family.alpha_divergence_description().is_empty());
    }

    /// The natural-gradient helper reports an update rule and an invariance property.
    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradMid::new(10, 0.01);
        let rule = optimizer.update_rule();
        assert!(!rule.is_empty());
        let invariance = optimizer.invariance_property();
        assert!(!invariance.is_empty());
    }

    /// The KL constructor reports itself as KL.
    #[test]
    fn test_alpha_divergence() {
        assert!(AlphaDivMid::kl_divergence("p", "q").is_kl());
    }

    /// The squared-Euclidean Bregman divergence has a definition and a three-point property.
    #[test]
    fn test_bregman_divergence() {
        let divergence = BregmanDivergence::squared_euclidean();
        assert!(!divergence.definition().is_empty());
        assert!(!divergence.three_point_property().is_empty());
    }

    /// Wasserstein geometry provides W2 and Benamou–Brenier descriptions.
    #[test]
    fn test_wasserstein() {
        let geometry = WassersteinGeometry::new(2, "R^d");
        assert!(!geometry.w2_distance_description().is_empty());
        assert!(!geometry.benamou_brenier_description().is_empty());
    }
}
#[cfg(test)]
mod gp_expfam_tests {
    use super::*;

    /// An RBF Gaussian process is flagged stationary and describes its posterior.
    #[test]
    fn test_gaussian_process() {
        let process = GaussianProcess::rbf(1.0);
        assert!(process.is_stationary);
        let posterior = process.posterior_description();
        assert!(!posterior.is_empty());
    }

    /// For the Gaussian family, MLE equals moment matching and the parameter map is described.
    #[test]
    fn test_exponential_family() {
        let family = ExponentialFamilyDistrib::gaussian(2);
        assert!(family.mle_equals_moment_matching());
        let mapping = family.natural_to_moment_params();
        assert!(!mapping.is_empty());
    }
}
#[cfg(test)]
mod tests_info_geom_ext {
    use super::*;

    /// Each `NatGradExt` description mentions its key concept.
    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradExt::new(10);
        assert!(optimizer.update_rule(0.01).contains("Natural gradient"));
        assert!(optimizer.fisher_rao_distance().contains("Fisher-Rao"));
        assert!(optimizer.amari_dual_connection().contains("α-connection"));
        assert!(optimizer.invariance_property().contains("Fisher-Rao"));
    }

    /// The Gaussian family is a dually flat manifold of dimension 2.
    #[test]
    fn test_statistical_manifold() {
        let family = StatManiExt::gaussian_family();
        assert!(family.is_dually_flat);
        assert_eq!(family.dimension, 2);
        assert!(family.pythagorean_theorem().contains("Pythagoras"));
        assert!(family.bregman_divergence_connection().contains("Bregman"));
    }

    /// Sliced-Wasserstein descriptions mention the method by name.
    #[test]
    fn test_sliced_wasserstein() {
        let sliced = SlicedWasserstein::new(10, 100);
        assert!(sliced.complexity_description().contains("Sliced"));
        assert!(sliced.bonneel_et_al_description().contains("sliced Wasserstein"));
    }

    /// Schrödinger-bridge descriptions mention Sinkhorn, IPFP and diffusion models.
    #[test]
    fn test_schroedinger_bridge() {
        let bridge = SchroedingerBridge::new("P", "Q", "BM", 0.01);
        assert!(bridge.sinkhorn_algorithm().contains("Sinkhorn"));
        assert!(bridge.ipfp_iteration().contains("IPFP"));
        assert!(bridge.connection_to_diffusion_models().contains("diffusion"));
    }

    /// Bures metric: monotone; bures_distance(1.0) ≈ 0 and bures_distance(0.0) ≈ √2.
    #[test]
    fn test_quantum_info_geom() {
        let bures = QuantumInfoGeometry::bures_metric(4);
        assert!(bures.is_monotone_metric);
        assert!(bures.petz_classification().contains("Petz"));
        assert!(bures.quantum_cramer_rao().contains("Cramér-Rao"));
        assert!(bures.holevo_bound().contains("Holevo"));
        let at_one = bures.bures_distance(1.0);
        assert!((at_one - 0.0).abs() < 1e-10);
        let at_zero = bures.bures_distance(0.0);
        assert!((at_zero - 2.0_f64.sqrt()).abs() < 1e-10);
    }
}