use oxilean_kernel::{BinderInfo, Declaration, Environment, Expr, Level, Name};
use super::types::{
AlphaDivMid, AlphaDivergence, BayesianEstimation, BeliefPropagation, BregmanDivergence,
ConstantCurvatureManifold, DualConnection, ExpectationPropagation, ExponentialFamily,
ExponentialFamilyDistrib, FisherInformationMetric, GaussianProcess, GeodesicOfDistributions,
JeffreysPrior, LegendreTransform, MirrorDescent, MomentParameter, NatGradExt, NatGradMid,
NaturalParameter, QuantumInfoGeometry, ReferenceAnalysis, SchroedingerBridge,
SlicedWasserstein, StatManiExt, StatManiMid, StatisticalManifold, WassersteinGeometry,
};
/// Builds the application term `f a` (`Expr::App`).
pub fn app(f: Expr, a: Expr) -> Expr {
    let (head, arg) = (Box::new(f), Box::new(a));
    Expr::App(head, arg)
}
/// Builds the curried application `f a b`, i.e. `(f a) b`.
pub fn app2(f: Expr, a: Expr, b: Expr) -> Expr {
    let partial = app(f, a);
    app(partial, b)
}
/// Builds the curried application `f a b c`, i.e. `((f a) b) c`.
pub fn app3(f: Expr, a: Expr, b: Expr, c: Expr) -> Expr {
    let partial = app2(f, a, b);
    app(partial, c)
}
/// Builds a reference to the global constant named `s`, with no universe arguments.
pub fn cst(s: &str) -> Expr {
    let name = Name::str(s);
    Expr::Const(name, vec![])
}
/// The sort `Prop`, i.e. `Sort 0`.
pub fn prop() -> Expr {
    Expr::Sort(Level::zero())
}
/// The sort `Type`, i.e. `Sort 1` (the successor of level zero).
pub fn type0() -> Expr {
    let one = Level::succ(Level::zero());
    Expr::Sort(one)
}
/// Builds the Pi type `Π (name : dom), body` with binder info `bi`.
pub fn pi(bi: BinderInfo, name: &str, dom: Expr, body: Expr) -> Expr {
    let binder = Name::str(name);
    Expr::Pi(bi, binder, Box::new(dom), Box::new(body))
}
/// Builds `Π (_ : a), b` with default binder info; used throughout this file
/// to spell the function type `a → b`.
pub fn arrow(a: Expr, b: Expr) -> Expr {
    pi(BinderInfo::Default, "_", a, b)
}
/// Builds the bound-variable term `Expr::BVar(n)`.
pub fn bvar(n: u32) -> Expr {
    Expr::BVar(n)
}
/// The constant `Nat`.
pub fn nat_ty() -> Expr {
    cst("Nat")
}
/// The constant `Real`.
pub fn real_ty() -> Expr {
    cst("Real")
}
/// The type `List elem`: the constant `List` applied to `elem`.
pub fn list_ty(elem: Expr) -> Expr {
    let list = cst("List");
    app(list, elem)
}
/// Type of `StatisticalManifold`: `Nat → Type`.
pub fn statistical_manifold_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `FisherInformationMetric`: `Nat → Type`.
pub fn fisher_information_metric_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `RiemannianMetric`: `Nat → Type`.
pub fn riemannian_metric_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `GeodesicOfDistributions`: `Type → Type → Type`.
pub fn geodesic_of_distributions_ty() -> Expr {
    let codomain = arrow(type0(), type0());
    arrow(type0(), codomain)
}
/// Type of `chentsov_theorem`: the sort `Prop`.
pub fn chentsov_theorem_ty() -> Expr {
    prop()
}
/// Type of `geodesic_distance_formula`: `Π (n : Nat), Prop`.
///
/// The body never mentions `n`, but the binder name is part of the term.
pub fn geodesic_distance_formula_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `sectional_curvature`: `Π (n : Nat), Real`.
pub fn sectional_curvature_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, real_ty())
}
/// Type of `christoffel_symbols`: `Nat → Nat → Type`.
pub fn christoffel_symbols_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `ExponentialFamily`: `Nat → Type`.
pub fn exponential_family_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `NaturalParameter`: `Nat → Type`.
pub fn natural_parameter_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `MomentParameter`: `Nat → Type`.
pub fn moment_parameter_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `LegendreTransform`: `(List Real → Real) → (List Real → Real)`.
pub fn legendre_transform_ty() -> Expr {
    // Both sides are the same shape; build a fresh copy for each.
    let potential = || arrow(list_ty(real_ty()), real_ty());
    arrow(potential(), potential())
}
/// Type of `LogPartitionFunction`: `List Real → Real`.
pub fn log_partition_function_ty() -> Expr {
    let params = list_ty(real_ty());
    arrow(params, real_ty())
}
/// Type of `natural_to_moment`: `Π (d : Nat), Prop`.
pub fn natural_to_moment_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `bregman_divergence`: `Π (d : Nat), Prop`.
pub fn bregman_divergence_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `fisher_as_hessian`: `Π (d : Nat), Prop`.
pub fn fisher_as_hessian_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `kl_equals_bregman`: the sort `Prop`.
pub fn kl_equals_bregman_ty() -> Expr {
    prop()
}
/// Type of `AlphaConnection`: `Real → Nat → Type`.
pub fn alpha_connection_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(real_ty(), family)
}
/// Type of `AlphaDivergence`: `Real → List Real → List Real → Real`.
pub fn alpha_divergence_ty() -> Expr {
    let vec_ty = || list_ty(real_ty());
    let divergence = arrow(vec_ty(), arrow(vec_ty(), real_ty()));
    arrow(real_ty(), divergence)
}
/// Type of `DualConnection`: `Nat → Type`.
pub fn dual_connection_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `ConstantCurvatureManifold`: `Real → Nat → Type`.
pub fn constant_curvature_manifold_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(real_ty(), family)
}
/// Type of `alpha_duality_theorem`: `Π (alpha : Real) (n : Nat), Prop`.
pub fn alpha_duality_theorem_ty() -> Expr {
    let inner = pi(BinderInfo::Default, "n", nat_ty(), prop());
    pi(BinderInfo::Default, "alpha", real_ty(), inner)
}
/// Type of `generalized_pythagoras`: `Π (n : Nat), Prop`.
pub fn generalized_pythagoras_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `alpha_divergence_limits`: the sort `Prop`.
pub fn alpha_divergence_limits_ty() -> Expr {
    prop()
}
/// Type of `curvature_formula`: `Π (alpha : Real), Real`.
pub fn curvature_formula_ty() -> Expr {
    let domain = real_ty();
    pi(BinderInfo::Default, "alpha", domain, real_ty())
}
pub fn bayesian_estimation_ty() -> Expr {
arrow(
arrow(real_ty(), real_ty()),
arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty())),
)
}
pub fn jeffreys_prior_ty() -> Expr {
arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
}
pub fn reference_analysis_ty() -> Expr {
arrow(arrow(real_ty(), real_ty()), arrow(real_ty(), real_ty()))
}
pub fn expectation_propagation_ty() -> Expr {
arrow(nat_ty(), type0())
}
pub fn jeffreys_invariance_ty() -> Expr {
prop()
}
pub fn bernstein_von_mises_ty() -> Expr {
pi(BinderInfo::Default, "n", nat_ty(), prop())
}
pub fn ep_fixed_point_ty() -> Expr {
prop()
}
pub fn laplace_approximation_ty() -> Expr {
pi(BinderInfo::Default, "n", nat_ty(), prop())
}
/// Type of `FisherRaoMetric`: `Nat → Type`.
pub fn fisher_rao_metric_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `EConnection`: `Nat → Type`.
pub fn e_connection_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `MConnection`: `Nat → Type`.
pub fn m_connection_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `EProjection`: `Nat → Type → Type`.
pub fn e_projection_ty() -> Expr {
    let endo = arrow(type0(), type0());
    arrow(nat_ty(), endo)
}
/// Type of `MProjection`: `Nat → Type → Type`.
pub fn m_projection_ty() -> Expr {
    let endo = arrow(type0(), type0());
    arrow(nat_ty(), endo)
}
/// Type of `pythagorean_theorem_ig`: `Π (n : Nat), Prop`.
pub fn pythagorean_theorem_ig_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `e_flat_exponential_family`: `Π (d : Nat), Prop`.
pub fn e_flat_exponential_family_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `m_flat_mixture_family`: `Π (d : Nat), Prop`.
pub fn m_flat_mixture_family_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `legendre_duality`: `Π (d : Nat), Prop`.
pub fn legendre_duality_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
pub fn f_divergence_ty() -> Expr {
arrow(
arrow(real_ty(), real_ty()),
arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
)
}
pub fn bregman_divergence_gen_ty() -> Expr {
arrow(
arrow(list_ty(real_ty()), real_ty()),
arrow(list_ty(real_ty()), arrow(list_ty(real_ty()), real_ty())),
)
}
pub fn wasserstein_metric_ty() -> Expr {
arrow(real_ty(), arrow(nat_ty(), type0()))
}
pub fn f_div_is_bregman_on_exp_ty() -> Expr {
prop()
}
pub fn chentsov_uniqueness_f_div_ty() -> Expr {
prop()
}
pub fn wasserstein_vs_fisher_rao_ty() -> Expr {
pi(BinderInfo::Default, "n", nat_ty(), prop())
}
pub fn pinsker_inequality_ty() -> Expr {
prop()
}
/// Type of `NaturalGradientDescent`: `Nat → Type`.
pub fn natural_gradient_descent_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `MirrorDescent`: `Nat → Type`.
pub fn mirror_descent_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `EMAlgorithm`: `Nat → Nat → Type`.
pub fn em_algorithm_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `natural_gradient_convergence`: `Π (d : Nat), Prop`.
pub fn natural_gradient_convergence_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `mirror_descent_eq_natural_gradient`: the sort `Prop`.
pub fn mirror_descent_eq_natural_gradient_ty() -> Expr {
    prop()
}
/// Type of `em_monotone_convergence`: `Π (n : Nat), Prop`.
pub fn em_monotone_convergence_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `em_as_alternating_projection`: the sort `Prop`.
pub fn em_as_alternating_projection_ty() -> Expr {
    prop()
}
/// Type of `BeliefPropagation`: `Nat → Nat → Type`.
pub fn belief_propagation_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `TreeReweightedBP`: `Nat → Type`.
pub fn tree_reweighted_bp_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `bp_fixed_point_bethe`: `Π (n : Nat), Prop`.
pub fn bp_fixed_point_bethe_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `bp_exact_on_tree`: `Π (n : Nat), Prop`.
pub fn bp_exact_on_tree_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `SanovTheorem`: `Nat → Type`.
pub fn sanov_theorem_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `RateFunction`: `List Real → Real`.
pub fn rate_function_ty() -> Expr {
    let points = list_ty(real_ty());
    arrow(points, real_ty())
}
/// Type of `sanov_kl_rate_function`: `Π (n : Nat), Prop`.
pub fn sanov_kl_rate_function_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `contraction_principle`: `Π (n : Nat), Prop`.
pub fn contraction_principle_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `QuantumStatisticalManifold`: `Nat → Nat → Type`.
pub fn quantum_statistical_manifold_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `SLDMetric`: `Nat → Nat → Type`.
pub fn sld_metric_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `RLDMetric`: `Nat → Nat → Type`.
pub fn rld_metric_ty() -> Expr {
    let family = arrow(nat_ty(), type0());
    arrow(nat_ty(), family)
}
/// Type of `QuantumRelativeEntropy`: `Nat → Type`.
pub fn quantum_relative_entropy_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `quantum_cramer_rao`: `Π (d : Nat), Prop`.
pub fn quantum_cramer_rao_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `sld_monotonicity`: the sort `Prop`.
pub fn sld_monotonicity_ty() -> Expr {
    prop()
}
/// Type of `uhlmann_theorem`: `Π (n : Nat), Prop`.
pub fn uhlmann_theorem_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "n", domain, prop())
}
/// Type of `quantum_stein_lemma`: the sort `Prop`.
pub fn quantum_stein_lemma_ty() -> Expr {
    prop()
}
/// Type of `ItoGirsanovIG`: `Nat → Type`.
pub fn ito_girsanov_ig_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `FokkerPlanckIG`: `Nat → Type`.
pub fn fokker_planck_ig_ty() -> Expr {
    let index = nat_ty();
    arrow(index, type0())
}
/// Type of `girsanov_e_geodesic`: `Π (d : Nat), Prop`.
pub fn girsanov_e_geodesic_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
/// Type of `otto_calculus_gradient_flow`: `Π (d : Nat), Prop`.
pub fn otto_calculus_gradient_flow_ty() -> Expr {
    let domain = nat_ty();
    pi(BinderInfo::Default, "d", domain, prop())
}
pub fn build_env(env: &mut Environment) -> Result<(), String> {
let axioms: &[(&str, Expr)] = &[
("StatisticalManifold", statistical_manifold_ty()),
("FisherInformationMetric", fisher_information_metric_ty()),
("RiemannianMetric", riemannian_metric_ty()),
("GeodesicOfDistributions", geodesic_of_distributions_ty()),
("chentsov_theorem", chentsov_theorem_ty()),
("geodesic_distance_formula", geodesic_distance_formula_ty()),
("sectional_curvature", sectional_curvature_ty()),
("christoffel_symbols", christoffel_symbols_ty()),
("ExponentialFamily", exponential_family_ty()),
("NaturalParameter", natural_parameter_ty()),
("MomentParameter", moment_parameter_ty()),
("LegendreTransform", legendre_transform_ty()),
("LogPartitionFunction", log_partition_function_ty()),
("natural_to_moment", natural_to_moment_ty()),
("bregman_divergence", bregman_divergence_ty()),
("fisher_as_hessian", fisher_as_hessian_ty()),
("kl_equals_bregman", kl_equals_bregman_ty()),
("AlphaConnection", alpha_connection_ty()),
("AlphaDivergence", alpha_divergence_ty()),
("DualConnection", dual_connection_ty()),
(
"ConstantCurvatureManifold",
constant_curvature_manifold_ty(),
),
("alpha_duality_theorem", alpha_duality_theorem_ty()),
("generalized_pythagoras", generalized_pythagoras_ty()),
("alpha_divergence_limits", alpha_divergence_limits_ty()),
("curvature_formula", curvature_formula_ty()),
("BayesianEstimation", bayesian_estimation_ty()),
("JeffreysPrior", jeffreys_prior_ty()),
("ReferenceAnalysis", reference_analysis_ty()),
("ExpectationPropagation", expectation_propagation_ty()),
("jeffreys_invariance", jeffreys_invariance_ty()),
("bernstein_von_mises", bernstein_von_mises_ty()),
("ep_fixed_point", ep_fixed_point_ty()),
("laplace_approximation", laplace_approximation_ty()),
("FisherRaoMetric", fisher_rao_metric_ty()),
("EConnection", e_connection_ty()),
("MConnection", m_connection_ty()),
("EProjection", e_projection_ty()),
("MProjection", m_projection_ty()),
("pythagorean_theorem_ig", pythagorean_theorem_ig_ty()),
("e_flat_exponential_family", e_flat_exponential_family_ty()),
("m_flat_mixture_family", m_flat_mixture_family_ty()),
("legendre_duality", legendre_duality_ty()),
("FDivergence", f_divergence_ty()),
("BregmanDivergenceGen", bregman_divergence_gen_ty()),
("WassersteinMetric", wasserstein_metric_ty()),
("f_div_is_bregman_on_exp", f_div_is_bregman_on_exp_ty()),
("chentsov_uniqueness_f_div", chentsov_uniqueness_f_div_ty()),
("wasserstein_vs_fisher_rao", wasserstein_vs_fisher_rao_ty()),
("pinsker_inequality", pinsker_inequality_ty()),
("NaturalGradientDescent", natural_gradient_descent_ty()),
("MirrorDescent", mirror_descent_ty()),
("EMAlgorithm", em_algorithm_ty()),
(
"natural_gradient_convergence",
natural_gradient_convergence_ty(),
),
(
"mirror_descent_eq_natural_gradient",
mirror_descent_eq_natural_gradient_ty(),
),
("em_monotone_convergence", em_monotone_convergence_ty()),
(
"em_as_alternating_projection",
em_as_alternating_projection_ty(),
),
("BeliefPropagation", belief_propagation_ty()),
("TreeReweightedBP", tree_reweighted_bp_ty()),
("bp_fixed_point_bethe", bp_fixed_point_bethe_ty()),
("bp_exact_on_tree", bp_exact_on_tree_ty()),
("SanovTheorem", sanov_theorem_ty()),
("RateFunction", rate_function_ty()),
("sanov_kl_rate_function", sanov_kl_rate_function_ty()),
("contraction_principle", contraction_principle_ty()),
(
"QuantumStatisticalManifold",
quantum_statistical_manifold_ty(),
),
("SLDMetric", sld_metric_ty()),
("RLDMetric", rld_metric_ty()),
("QuantumRelativeEntropy", quantum_relative_entropy_ty()),
("quantum_cramer_rao", quantum_cramer_rao_ty()),
("sld_monotonicity", sld_monotonicity_ty()),
("uhlmann_theorem", uhlmann_theorem_ty()),
("quantum_stein_lemma", quantum_stein_lemma_ty()),
("ItoGirsanovIG", ito_girsanov_ig_ty()),
("FokkerPlanckIG", fokker_planck_ig_ty()),
("girsanov_e_geodesic", girsanov_e_geodesic_ty()),
(
"otto_calculus_gradient_flow",
otto_calculus_gradient_flow_ty(),
),
];
for (name, ty) in axioms {
env.add(Declaration::Axiom {
name: Name::str(*name),
univ_params: vec![],
ty: ty.clone(),
})
.ok();
}
Ok(())
}
/// Returns the inner product `Σ aᵢ·bᵢ`.
///
/// Entries are paired by position. If the slices differ in length, the surplus
/// entries of the longer one are ignored. Empty input yields zero.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let pairs = a.iter().zip(b);
    pairs.map(|(&x, &y)| x * y).sum::<f64>()
}
/// Multiplies the matrix `a`, given as a list of rows, by the vector `v`.
///
/// The result has one entry per row of `a`: entry `i` is `dot_product(&a[i], v)`.
pub fn mat_vec(a: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(a.len());
    for row in a {
        out.push(dot_product(row, v));
    }
    out
}
/// Solves the square linear system `a · x = b` by Gaussian elimination with
/// partial (row) pivoting, followed by back-substitution.
///
/// Only the leading `b.len() × b.len()` block of `a` takes part in the solve.
///
/// A pivot whose magnitude is at most `1e-14` times the largest absolute entry
/// of that block is treated as zero. Elimination skips that column, and
/// back-substitution sets the corresponding unknown to `0.0`. A singular
/// system therefore yields a finite best-effort vector instead of `inf`/`NaN`.
///
/// Because the threshold is relative, uniformly rescaling `a` does not change
/// which pivots count as zero. For example, `1e-15 · I` is solved exactly
/// rather than being reported as singular.
///
/// Returns a vector of length `b.len()`; an empty `b` yields an empty vector.
///
/// # Panics
///
/// Panics if `a` has fewer than `b.len()` rows, or if any of its first
/// `b.len()` rows has fewer than `b.len()` entries.
pub fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    // Relative pivot tolerance; multiplied by the matrix scale below.
    const PIVOT_RTOL: f64 = 1e-14;
    let d = b.len();
    let mut mat: Vec<Vec<f64>> = a.to_vec();
    let mut rhs: Vec<f64> = b.to_vec();
    // Largest |entry| of the active d×d block. A fixed absolute cutoff would
    // make the solver scale-dependent.
    let scale = a
        .iter()
        .take(d)
        .flat_map(|row| row.iter().take(d))
        .fold(0.0_f64, |acc, v| acc.max(v.abs()));
    let tol = PIVOT_RTOL * scale;
    for col in 0..d {
        // Partial pivoting: move the remaining row with the largest |entry| in
        // this column onto the diagonal.
        let pivot = (col..d)
            .max_by(|&i, &j| {
                mat[i][col]
                    .abs()
                    .partial_cmp(&mat[j][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);
        mat.swap(col, pivot);
        rhs.swap(col, pivot);
        let diag = mat[col][col];
        // `<=` rather than `<`, so an all-zero matrix (tol == 0) never divides by zero.
        if diag.abs() <= tol {
            continue;
        }
        for row in (col + 1)..d {
            let factor = mat[row][col] / diag;
            for k in col..d {
                let val = mat[col][k];
                mat[row][k] -= factor * val;
            }
            rhs[row] -= factor * rhs[col];
        }
    }
    let mut x = vec![0.0f64; d];
    for i in (0..d).rev() {
        let mut s = rhs[i];
        for j in (i + 1)..d {
            s -= mat[i][j] * x[j];
        }
        // Unknowns attached to a (numerically) zero pivot are pinned to 0.
        x[i] = if mat[i][i].abs() <= tol {
            0.0
        } else {
            s / mat[i][i]
        };
    }
    x
}
#[cfg(test)]
mod ig_ext_tests {
    // Smoke tests for StatManiMid, NatGradMid, AlphaDivMid, BregmanDivergence
    // and WassersteinGeometry: each check only that the helpers report the
    // expected flags and return non-empty descriptions.
    use super::*;

    #[test]
    fn test_statistical_manifold() {
        let manifold = StatManiMid::exponential_family("Normal", 2);
        assert!(manifold.is_dually_flat());
        assert!(!manifold.alpha_divergence_description().is_empty());
    }

    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradMid::new(10, 0.01);
        assert!(!optimizer.update_rule().is_empty());
        assert!(!optimizer.invariance_property().is_empty());
    }

    #[test]
    fn test_alpha_divergence() {
        assert!(AlphaDivMid::kl_divergence("p", "q").is_kl());
    }

    #[test]
    fn test_bregman_divergence() {
        let divergence = BregmanDivergence::squared_euclidean();
        assert!(!divergence.definition().is_empty());
        assert!(!divergence.three_point_property().is_empty());
    }

    #[test]
    fn test_wasserstein() {
        let geometry = WassersteinGeometry::new(2, "R^d");
        assert!(!geometry.w2_distance_description().is_empty());
        assert!(!geometry.benamou_brenier_description().is_empty());
    }
}
#[cfg(test)]
mod gp_expfam_tests {
    // Smoke tests for GaussianProcess and ExponentialFamilyDistrib.
    use super::*;

    #[test]
    fn test_gaussian_process() {
        let process = GaussianProcess::rbf(1.0);
        assert!(process.is_stationary);
        assert!(!process.posterior_description().is_empty());
    }

    #[test]
    fn test_exponential_family() {
        let family = ExponentialFamilyDistrib::gaussian(2);
        assert!(family.mle_equals_moment_matching());
        assert!(!family.natural_to_moment_params().is_empty());
    }
}
#[cfg(test)]
mod tests_info_geom_ext {
    // Tests for NatGradExt, StatManiExt, SlicedWasserstein, SchroedingerBridge
    // and QuantumInfoGeometry: descriptive strings must mention their key
    // terms, and the Bures distance is checked at two sample points.
    use super::*;

    #[test]
    fn test_natural_gradient() {
        let optimizer = NatGradExt::new(10);
        assert!(optimizer.update_rule(0.01).contains("Natural gradient"));
        assert!(optimizer.fisher_rao_distance().contains("Fisher-Rao"));
        assert!(optimizer.amari_dual_connection().contains("α-connection"));
        assert!(optimizer.invariance_property().contains("Fisher-Rao"));
    }

    #[test]
    fn test_statistical_manifold() {
        let family = StatManiExt::gaussian_family();
        assert!(family.is_dually_flat);
        assert_eq!(family.dimension, 2);
        assert!(family.pythagorean_theorem().contains("Pythagoras"));
        assert!(family.bregman_divergence_connection().contains("Bregman"));
    }

    #[test]
    fn test_sliced_wasserstein() {
        let sliced = SlicedWasserstein::new(10, 100);
        assert!(sliced.complexity_description().contains("Sliced"));
        assert!(sliced
            .bonneel_et_al_description()
            .contains("sliced Wasserstein"));
    }

    #[test]
    fn test_schroedinger_bridge() {
        let bridge = SchroedingerBridge::new("P", "Q", "BM", 0.01);
        assert!(bridge.sinkhorn_algorithm().contains("Sinkhorn"));
        assert!(bridge.ipfp_iteration().contains("IPFP"));
        assert!(bridge.connection_to_diffusion_models().contains("diffusion"));
    }

    #[test]
    fn test_quantum_info_geom() {
        let metric = QuantumInfoGeometry::bures_metric(4);
        assert!(metric.is_monotone_metric);
        assert!(metric.petz_classification().contains("Petz"));
        assert!(metric.quantum_cramer_rao().contains("Cramér-Rao"));
        assert!(metric.holevo_bound().contains("Holevo"));
        // Expected values: bures_distance(1.0) == 0 and bures_distance(0.0) == √2.
        assert!(metric.bures_distance(1.0).abs() < 1e-10);
        assert!((metric.bures_distance(0.0) - 2.0_f64.sqrt()).abs() < 1e-10);
    }
}