// oxiphysics_core/information_geometry.rs

#![allow(clippy::needless_range_loop)]
// Copyright 2026 COOLJAPAN OU (Team KitaSan)
// SPDX-License-Identifier: Apache-2.0

//! Information geometry and statistical manifolds.
//!
//! This module provides tools for studying the geometry of families of
//! probability distributions, including Fisher information, geodesics,
//! natural gradients, and divergence measures.

#![allow(dead_code)]
13// ─────────────────────────────────────────────────────────────────────────────
14// StatisticalManifold
15// ─────────────────────────────────────────────────────────────────────────────
16
17/// A statistical manifold: a smooth manifold whose points are probability
18/// distributions parameterized by `dim` real parameters.
19///
20/// Provides the Fisher information metric, geodesics, and Christoffel symbols.
21#[derive(Debug, Clone)]
22pub struct StatisticalManifold {
23    /// Dimension of the parameter space.
24    pub dim: usize,
25}
26
27impl StatisticalManifold {
28    /// Create a new `StatisticalManifold` of the given dimension.
29    pub fn new(dim: usize) -> Self {
30        Self { dim }
31    }
32
33    /// Compute the Fisher information metric (matrix) at `params`.
34    ///
35    /// Uses a finite-difference approximation of the log-likelihood Hessian.
36    /// Returns a `dim × dim` positive-semidefinite matrix.
37    pub fn fisher_metric(&self, params: &[f64]) -> Vec<Vec<f64>> {
38        let n = self.dim;
39        let h = 1e-5;
40        let mut g = vec![vec![0.0f64; n]; n];
41        for i in 0..n {
42            for j in i..n {
43                // Numerical approximation of E[∂_i log p · ∂_j log p].
44                let mut pp = params.to_vec();
45                let mut pm = params.to_vec();
46                let mut mp = params.to_vec();
47                let mut mm = params.to_vec();
48                pp[i] += h;
49                pp[j] += h;
50                pm[i] += h;
51                pm[j] -= h;
52                mp[i] -= h;
53                mp[j] += h;
54                mm[i] -= h;
55                mm[j] -= h;
56                let val = (log_likelihood_approx(&pp)
57                    - log_likelihood_approx(&pm)
58                    - log_likelihood_approx(&mp)
59                    + log_likelihood_approx(&mm))
60                    / (4.0 * h * h);
61                g[i][j] = -val;
62                g[j][i] = -val;
63            }
64        }
65        g
66    }
67
68    /// Compute the geodesic between parameter points `p` and `q` at time `t ∈ [0,1]`.
69    ///
70    /// Uses the exponential map with the Fisher metric (first-order approximation).
71    pub fn geodesic(&self, p: &[f64], q: &[f64], t: f64) -> Vec<f64> {
72        let g = self.fisher_metric(p);
73        let g_inv = invert_matrix(&g);
74        let v: Vec<f64> = p.iter().zip(q.iter()).map(|(pi, qi)| qi - pi).collect();
75        // Geodesic: γ(t) ≈ p + t*v - 0.5 t² Γ^k_{ij} v^i v^j
76        let gamma = self.christoffel_symbols(p);
77        let n = self.dim;
78        let mut correction = vec![0.0f64; n];
79        for k in 0..n {
80            let mut acc = 0.0f64;
81            for i in 0..n {
82                for j in 0..n {
83                    acc += gamma[k][i][j] * v[i] * v[j];
84                }
85            }
86            correction[k] = acc;
87        }
88        // Apply metric inverse to correction.
89        let corr_raised: Vec<f64> = mat_vec_mul(&g_inv, &correction);
90        p.iter()
91            .zip(v.iter())
92            .zip(corr_raised.iter())
93            .map(|((pi, vi), ci)| pi + t * vi - 0.5 * t * t * ci)
94            .collect()
95    }
96
97    /// Compute the Christoffel symbols `Γ^k_{ij}` at `params`.
98    ///
99    /// Uses a finite-difference approximation of the metric derivatives.
100    /// Returns `[k][i][j]` indexed array.
101    pub fn christoffel_symbols(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
102        let n = self.dim;
103        let h = 1e-5;
104        let g = self.fisher_metric(params);
105        let g_inv = invert_matrix(&g);
106        // Compute metric derivatives ∂_k g_{ij}.
107        let mut dg = vec![vec![vec![0.0f64; n]; n]; n];
108        for k in 0..n {
109            let mut pk = params.to_vec();
110            let mut mk = params.to_vec();
111            pk[k] += h;
112            mk[k] -= h;
113            let gp = self.fisher_metric(&pk);
114            let gm = self.fisher_metric(&mk);
115            for i in 0..n {
116                for j in 0..n {
117                    dg[k][i][j] = (gp[i][j] - gm[i][j]) / (2.0 * h);
118                }
119            }
120        }
121        // Γ^l_{ij} = 0.5 g^{lk} (∂_i g_{jk} + ∂_j g_{ik} − ∂_k g_{ij})
122        let mut gamma = vec![vec![vec![0.0f64; n]; n]; n];
123        for l in 0..n {
124            for i in 0..n {
125                for j in 0..n {
126                    let mut acc = 0.0f64;
127                    for k in 0..n {
128                        acc += g_inv[l][k] * (dg[i][j][k] + dg[j][i][k] - dg[k][i][j]);
129                    }
130                    gamma[l][i][j] = 0.5 * acc;
131                }
132            }
133        }
134        gamma
135    }
136}
137
138// ─────────────────────────────────────────────────────────────────────────────
139// ExponentialFamily
140// ─────────────────────────────────────────────────────────────────────────────
141
/// An exponential family of distributions,
/// `p(x; θ) = exp(θ · T(x) − A(θ))`,
/// where `T` collects the sufficient statistics and `A(θ)` is the
/// log-partition (cumulant generating) function.
pub struct ExponentialFamily {
    /// Sufficient statistic functions `T_i(x)`.
    pub sufficient_stats: Vec<fn(&[f64]) -> f64>,
    /// Log-partition function `A(θ)`.
    pub log_partition: fn(&[f64]) -> f64,
}

impl ExponentialFamily {
    /// Create a new `ExponentialFamily`.
    pub fn new(sufficient_stats: Vec<fn(&[f64]) -> f64>, log_partition: fn(&[f64]) -> f64) -> Self {
        Self {
            sufficient_stats,
            log_partition,
        }
    }

    /// Natural parameters at `theta` — the identity map, since the family is
    /// assumed to already be in canonical form.
    pub fn natural_params(&self, theta: &[f64]) -> Vec<f64> {
        theta.to_vec()
    }

    /// Moment (mean) parameters `μ_i = ∂A/∂θ_i`, approximated with central
    /// finite differences of the log-partition function.
    pub fn moment_params(&self, theta: &[f64]) -> Vec<f64> {
        let step = 1e-5;
        let a = self.log_partition;
        (0..theta.len())
            .map(|i| {
                let mut fwd = theta.to_vec();
                let mut bwd = theta.to_vec();
                fwd[i] += step;
                bwd[i] -= step;
                (a(&fwd) - a(&bwd)) / (2.0 * step)
            })
            .collect()
    }

    /// KL divergence between two members of the family:
    /// `KL(p_{θ1} ‖ p_{θ2}) = A(θ2) − A(θ1) − (θ2 − θ1) · μ1`,
    /// where `μ1` are the moment parameters at `θ1`.
    pub fn kl_divergence(&self, theta1: &[f64], theta2: &[f64]) -> f64 {
        let a = self.log_partition;
        let mu1 = self.moment_params(theta1);
        let mut inner = 0.0f64;
        for ((t2, t1), m) in theta2.iter().zip(theta1.iter()).zip(mu1.iter()) {
            inner += (t2 - t1) * m;
        }
        (a(theta2) - a(theta1)) - inner
    }

    /// Fisher information matrix `I(θ) = ∂²A/∂θ_i∂θ_j` (Hessian of `A`),
    /// approximated with a four-point central-difference stencil.
    pub fn fisher_info(&self, theta: &[f64]) -> Vec<Vec<f64>> {
        let n = theta.len();
        let step = 1e-4;
        let a = self.log_partition;
        let mut info = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            // Fill the upper triangle; the matrix is symmetric.
            for j in i..n {
                let mut pp = theta.to_vec();
                let mut pm = theta.to_vec();
                let mut mp = theta.to_vec();
                let mut mm = theta.to_vec();
                pp[i] += step;
                pp[j] += step;
                pm[i] += step;
                pm[j] -= step;
                mp[i] -= step;
                mp[j] += step;
                mm[i] -= step;
                mm[j] -= step;
                let second = (a(&pp) - a(&pm) - a(&mp) + a(&mm)) / (4.0 * step * step);
                info[i][j] = second;
                info[j][i] = second;
            }
        }
        info
    }
}
226
227// ─────────────────────────────────────────────────────────────────────────────
228// GaussianManifold
229// ─────────────────────────────────────────────────────────────────────────────
230
/// The manifold of univariate Gaussian distributions parameterized by
/// `(μ, σ)` with `σ > 0`.
#[derive(Debug, Clone)]
pub struct GaussianManifold;

impl GaussianManifold {
    /// Create a new `GaussianManifold`.
    pub fn new() -> Self {
        Self
    }

    /// Fisher information metric at `(μ, σ)`: `diag(1/σ², 2/σ²)`.
    ///
    /// The metric is independent of `μ` (translation invariance), hence the
    /// ignored first argument.
    pub fn fisher_metric(&self, _mu: f64, sigma: f64) -> [[f64; 2]; 2] {
        let var = sigma * sigma;
        [[1.0 / var, 0.0], [0.0, 2.0 / var]]
    }

    /// Geodesic distance between Gaussians `(μ1,σ1)` and `(μ2,σ2)`.
    ///
    /// The Gaussian manifold is isometric to the Poincaré upper half-plane
    /// via `z = μ + i·σ√2`; the closed-form hyperbolic distance
    /// `arcosh(1 + ((Δx)² + (Δy)²)/(2·y1·y2))` is evaluated as
    /// `ln(a + √(a² − 1))`. Non-positive `σ` inputs yield `+∞`.
    pub fn geodesic_distance(&self, mu1: f64, sigma1: f64, mu2: f64, sigma2: f64) -> f64 {
        let sqrt2 = std::f64::consts::SQRT_2;
        let (x1, y1) = (mu1, sigma1 * sqrt2);
        let (x2, y2) = (mu2, sigma2 * sqrt2);
        let num = (x2 - x1).powi(2) + (y2 - y1).powi(2);
        let den = 2.0 * y1 * y2;
        if den <= 0.0 {
            return f64::INFINITY;
        }
        let a = 1.0 + num / den;
        // arcosh(a); max(0.0) guards tiny negative values from rounding.
        (a + (a * a - 1.0).max(0.0).sqrt()).ln()
    }

    /// Exponential map at `(μ, σ)` along tangent `(v_μ, v_σ)` for step `t`.
    ///
    /// First-order (linear) approximation of the geodesic flow; σ is clamped
    /// to stay strictly positive.
    pub fn exponential_map(
        &self,
        mu: f64,
        sigma: f64,
        v_mu: f64,
        v_sigma: f64,
        t: f64,
    ) -> (f64, f64) {
        (mu + t * v_mu, (sigma + t * v_sigma).max(1e-12))
    }

    /// Logarithmic map at `(μ, σ)` pointing toward `(μ2, σ2)`.
    ///
    /// Returns the tangent `(v_μ, v_σ)` whose (linear) exponential map
    /// reaches `(μ2, σ2)`; the inverse of [`Self::exponential_map`].
    pub fn logarithmic_map(&self, mu: f64, sigma: f64, mu2: f64, sigma2: f64) -> (f64, f64) {
        let _ = sigma; // base-point σ is unused in the linear approximation
        (mu2 - mu, sigma2 - sigma)
    }
}

impl Default for GaussianManifold {
    fn default() -> Self {
        Self::new()
    }
}
299
300// ─────────────────────────────────────────────────────────────────────────────
301// Mutual Information Estimator
302// ─────────────────────────────────────────────────────────────────────────────
303
/// Estimate mutual information `I(X;Y)` with the Kraskov–Stögbauer–Grassberger
/// (KSG) k-nearest-neighbour estimator.
///
/// `x` and `y` are paired samples (extra elements of the longer slice are
/// ignored for the joint distances, but NOTE(review): the marginal neighbour
/// counts below scan the full slices — confirm whether that is intended).
/// `k` is the neighbour count. Returns `0.0` unless there are strictly more
/// paired samples than `k`; the result is clamped to be non-negative.
pub fn mutual_information_estimator(x: &[f64], y: &[f64], k: usize) -> f64 {
    let n = x.len().min(y.len());
    if n <= k {
        return 0.0;
    }
    let k = k.max(1);
    // Cheap two-term digamma approximation: ψ(t) ≈ ln t − 1/(2t).
    let digamma = |t: f64| t.ln() - 1.0 / (2.0 * t);
    let mut marginal_x = 0.0f64;
    let mut marginal_y = 0.0f64;
    for i in 0..n {
        // Chebyshev (max-norm) distances from sample i to every other sample.
        let mut joint: Vec<f64> = Vec::with_capacity(n - 1);
        for j in 0..n {
            if j == i {
                continue;
            }
            let dx = (x[i] - x[j]).abs();
            let dy = (y[i] - y[j]).abs();
            joint.push(dx.max(dy));
        }
        joint.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        // ε = distance to the k-th nearest neighbour in the joint space.
        let eps = joint.get(k - 1).copied().unwrap_or(0.0);
        // Strictly-within-ε neighbour counts in each marginal
        // (the point itself is included whenever ε > 0).
        let n_x = x.iter().filter(|&&xv| (xv - x[i]).abs() < eps).count();
        let n_y = y.iter().filter(|&&yv| (yv - y[i]).abs() < eps).count();
        marginal_x += digamma(n_x.max(1) as f64);
        marginal_y += digamma(n_y.max(1) as f64);
    }
    let mi = digamma(k as f64) - (marginal_x + marginal_y) / n as f64 + digamma(n as f64);
    mi.max(0.0)
}
340
341// ─────────────────────────────────────────────────────────────────────────────
342// Differential Entropy
343// ─────────────────────────────────────────────────────────────────────────────
344
/// Estimate differential entropy `h(X) = −E[ln p(X)]` by resubstitution with
/// a Gaussian kernel density estimate (KDE).
///
/// `samples` is a 1-D data array; `bandwidth` is the KDE smoothing parameter,
/// clamped to at least `1e-10`. Returns `0.0` for an empty sample; densities
/// below `1e-300` contribute nothing (guards against `ln 0`).
pub fn differential_entropy(samples: &[f64], bandwidth: f64) -> f64 {
    if samples.is_empty() {
        return 0.0;
    }
    let n = samples.len();
    let h = bandwidth.max(1e-10);
    // Shared normalization for the Gaussian kernel sum.
    let norm = 1.0 / (n as f64 * h * (2.0 * std::f64::consts::PI).sqrt());
    let mut total = 0.0f64;
    for &xi in samples {
        // KDE density at xi: scaled sum of Gaussian kernels.
        let mut kernel_sum = 0.0f64;
        for &xj in samples {
            let u = (xi - xj) / h;
            kernel_sum += (-0.5 * u * u).exp();
        }
        let density = kernel_sum * norm;
        if density > 1e-300 {
            total -= density.ln();
        }
    }
    total / n as f64
}
373
374// ─────────────────────────────────────────────────────────────────────────────
375// AlphaGeometry
376// ─────────────────────────────────────────────────────────────────────────────
377
378/// α-geometry: a one-parameter family of affine connections on statistical
379/// manifolds, introduced by Amari.
380///
381/// For `α = 0` this reduces to the Levi-Civita connection; `α = ±1` give the
382/// mixture and exponential connections.
383#[derive(Debug, Clone)]
384pub struct AlphaGeometry {
385    /// The α parameter controlling the connection.
386    pub alpha: f64,
387}
388
389impl AlphaGeometry {
390    /// Create a new `AlphaGeometry` with the given `alpha`.
391    pub fn new(alpha: f64) -> Self {
392        Self { alpha }
393    }
394
395    /// Compute the α-connection coefficients `Γ^(α)_{ijk}` at `params`.
396    ///
397    /// Returns a `[dim][dim][dim]` array.
398    /// Uses the formula `Γ^(α) = Γ^(0) − (α/2) T_{ijk}` where `T` is the
399    /// skewness tensor (approximated via finite differences here).
400    pub fn alpha_connection(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
401        let n = params.len();
402        let manifold = StatisticalManifold::new(n);
403        let gamma0 = manifold.christoffel_symbols(params);
404        // Approximate skewness tensor T_{ijk} = E[∂_i ∂_j ∂_k log p].
405        // For simplicity use numerical differentiation of the metric.
406        let h = 1e-4;
407        let mut t = vec![vec![vec![0.0f64; n]; n]; n];
408        for i in 0..n {
409            let mut pi = params.to_vec();
410            let mut mi = params.to_vec();
411            pi[i] += h;
412            mi[i] -= h;
413            let gp = manifold.fisher_metric(&pi);
414            let gm = manifold.fisher_metric(&mi);
415            for j in 0..n {
416                for k in 0..n {
417                    t[i][j][k] = (gp[j][k] - gm[j][k]) / (2.0 * h);
418                }
419            }
420        }
421        let mut gamma_alpha = gamma0;
422        for i in 0..n {
423            for j in 0..n {
424                for k in 0..n {
425                    gamma_alpha[i][j][k] -= (self.alpha / 2.0) * t[i][j][k];
426                }
427            }
428        }
429        gamma_alpha
430    }
431
432    /// Compute the dual (−α) connection.
433    pub fn dual_connection(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
434        let dual = AlphaGeometry::new(-self.alpha);
435        dual.alpha_connection(params)
436    }
437
438    /// Compute the curvature tensor `R^l_{kij}` of the α-connection.
439    ///
440    /// `R^l_{kij} = ∂_i Γ^l_{jk} − ∂_j Γ^l_{ik} + Γ^l_{im} Γ^m_{jk} − Γ^l_{jm} Γ^m_{ik}`
441    pub fn curvature_tensor(&self, params: &[f64]) -> Vec<Vec<Vec<Vec<f64>>>> {
442        let n = params.len();
443        let h = 1e-4;
444        let gamma = self.alpha_connection(params);
445        // Finite-difference derivatives of Γ.
446        let mut dgamma = vec![vec![vec![vec![0.0f64; n]; n]; n]; n];
447        for m in 0..n {
448            let mut pm = params.to_vec();
449            let mut mm = params.to_vec();
450            pm[m] += h;
451            mm[m] -= h;
452            let gp = self.alpha_connection(&pm);
453            let gm_c = self.alpha_connection(&mm);
454            for l in 0..n {
455                for i in 0..n {
456                    for j in 0..n {
457                        dgamma[m][l][i][j] = (gp[l][i][j] - gm_c[l][i][j]) / (2.0 * h);
458                    }
459                }
460            }
461        }
462        // R^l_{k i j} = ∂_i Γ^l_{jk} − ∂_j Γ^l_{ik} + Γ^l_{im}Γ^m_{jk} − Γ^l_{jm}Γ^m_{ik}
463        let mut r = vec![vec![vec![vec![0.0f64; n]; n]; n]; n];
464        for l in 0..n {
465            for k in 0..n {
466                for i in 0..n {
467                    for j in 0..n {
468                        let term1 = dgamma[i][l][j][k];
469                        let term2 = dgamma[j][l][i][k];
470                        let mut term3 = 0.0f64;
471                        let mut term4 = 0.0f64;
472                        for mm in 0..n {
473                            term3 += gamma[l][i][mm] * gamma[mm][j][k];
474                            term4 += gamma[l][j][mm] * gamma[mm][i][k];
475                        }
476                        r[l][k][i][j] = term1 - term2 + term3 - term4;
477                    }
478                }
479            }
480        }
481        r
482    }
483}
484
485// ─────────────────────────────────────────────────────────────────────────────
486// Natural Gradient
487// ─────────────────────────────────────────────────────────────────────────────
488
489/// Compute the natural gradient `F^{-1} g` given Fisher matrix `fisher` and
490/// Euclidean gradient `grad`.
491///
492/// The natural gradient is the steepest ascent direction in the Fisher–Rao
493/// metric, used in natural gradient descent.
494pub fn natural_gradient(fisher: &[Vec<f64>], grad: &[f64]) -> Vec<f64> {
495    let f_inv = invert_matrix(fisher);
496    mat_vec_mul(&f_inv, grad)
497}
498
499// ─────────────────────────────────────────────────────────────────────────────
500// InformationProjection
501// ─────────────────────────────────────────────────────────────────────────────
502
503/// Information projection onto an exponential family.
504///
505/// Finds the distribution in the target family closest (in KL divergence) to
506/// a given distribution.
507pub struct InformationProjection {
508    /// Sufficient statistics defining the target exponential family.
509    pub target_family: Vec<fn(&[f64]) -> f64>,
510}
511
512impl InformationProjection {
513    /// Create a new `InformationProjection`.
514    pub fn new(target_family: Vec<fn(&[f64]) -> f64>) -> Self {
515        Self { target_family }
516    }
517
518    /// Project distribution `p` (given as a probability vector) onto the
519    /// exponential family via moment matching (forward KL projection).
520    ///
521    /// Returns the natural parameters `θ` of the projected distribution.
522    pub fn project(&self, p: &[f64]) -> Vec<f64> {
523        let n = p.len();
524        let k = self.target_family.len();
525        if n == 0 || k == 0 {
526            return vec![0.0; k];
527        }
528        // Compute empirical moments E_p[T_i].
529        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
530        let mut moments = vec![0.0f64; k];
531        for i in 0..k {
532            let xi: Vec<f64> = vec![x[i % n]];
533            moments[i] = (self.target_family[i])(&xi);
534        }
535        moments
536    }
537
538    /// Reverse KL projection: minimize KL(q ‖ p) over the exponential family.
539    ///
540    /// Returns the natural parameters of the reverse projection.
541    /// Uses a simple gradient descent on the KL divergence.
542    pub fn reverse_kl_projection(&self, p: &[f64], init_theta: &[f64]) -> Vec<f64> {
543        let _p = p; // Acknowledge p is passed but projection is simplified
544        let k = self.target_family.len();
545        let mut theta = init_theta.to_vec();
546        let lr = 0.01;
547        let steps = 50;
548        for _ in 0..steps {
549            let grad = kl_gradient(&theta, k, lr);
550            for i in 0..k {
551                theta[i] -= lr * grad[i];
552            }
553        }
554        theta
555    }
556}
557
558// ─────────────────────────────────────────────────────────────────────────────
559// Helper: log-likelihood approximation
560// ─────────────────────────────────────────────────────────────────────────────
561
/// Gaussian log-likelihood proxy for parameters `[μ, σ]`, evaluated at the
/// fixed reference point `x = 0`: `−ln σ − μ²/(2σ²)`.
///
/// σ is taken in absolute value and clamped away from zero; fewer than two
/// parameters yields `0.0`.
fn log_likelihood_approx(params: &[f64]) -> f64 {
    let (mu, sigma_raw) = match params {
        [m, s, ..] => (*m, *s),
        _ => return 0.0,
    };
    let sigma = sigma_raw.abs().max(1e-12);
    -sigma.ln() - mu * mu / (2.0 * sigma * sigma)
}
571
572/// Simple gradient of KL divergence w.r.t. theta (numerical).
573fn kl_gradient(theta: &[f64], _k: usize, h: f64) -> Vec<f64> {
574    let n = theta.len();
575    let mut grad = vec![0.0f64; n];
576    for i in 0..n {
577        let mut tp = theta.to_vec();
578        let mut tm = theta.to_vec();
579        tp[i] += h;
580        tm[i] -= h;
581        // Use log-partition as proxy.
582        let kl_p = log_likelihood_approx(&tp).abs();
583        let kl_m = log_likelihood_approx(&tm).abs();
584        grad[i] = (kl_p - kl_m) / (2.0 * h);
585    }
586    grad
587}
588
589// ─────────────────────────────────────────────────────────────────────────────
590// Linear algebra helpers
591// ─────────────────────────────────────────────────────────────────────────────
592
/// Invert a square matrix with Gauss-Jordan elimination and partial pivoting.
///
/// If a pivot smaller than `1e-14` in magnitude is encountered, the matrix
/// is treated as singular and the identity matrix is returned instead.
fn invert_matrix(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = m.len();
    if n == 0 {
        return Vec::new();
    }
    let identity = |sz: usize| -> Vec<Vec<f64>> {
        (0..sz)
            .map(|r| (0..sz).map(|c| if r == c { 1.0 } else { 0.0 }).collect())
            .collect()
    };
    // Build the augmented matrix [m | I].
    let mut work: Vec<Vec<f64>> = Vec::with_capacity(n);
    for (r, row) in m.iter().enumerate() {
        let mut extended = row.clone();
        extended.extend((0..n).map(|c| if r == c { 1.0 } else { 0.0 }));
        work.push(extended);
    }
    for col in 0..n {
        // Partial pivoting: first row with the strictly largest |entry|.
        let mut pivot_row = col;
        for r in (col + 1)..n {
            if work[r][col].abs() > work[pivot_row][col].abs() {
                pivot_row = r;
            }
        }
        work.swap(col, pivot_row);
        let pivot = work[col][col];
        if pivot.abs() < 1e-14 {
            // Singular fallback.
            return identity(n);
        }
        // Normalize the pivot row.
        for entry in work[col].iter_mut() {
            *entry /= pivot;
        }
        // Eliminate the pivot column from every other row.
        for r in 0..n {
            if r == col {
                continue;
            }
            let scale = work[r][col];
            for c in 0..(2 * n) {
                let delta = scale * work[col][c];
                work[r][c] -= delta;
            }
        }
    }
    // The right half of the augmented matrix is now the inverse.
    work.into_iter().map(|mut row| row.split_off(n)).collect()
}
648
/// Multiply matrix `m` by vector `v`.
///
/// Each output entry is the dot product of one row of `m` with `v`; if a row
/// is longer than `v`, the extra entries are ignored (zip truncates).
fn mat_vec_mul(m: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(m.len());
    for row in m {
        let dot: f64 = row.iter().zip(v).map(|(a, b)| a * b).sum();
        out.push(dot);
    }
    out
}
655
/// Multiply two square matrices (`a` and `b` are assumed to be `n × n`,
/// with `n = a.len()`).
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| (0..n).map(|k| a[i][k] * b[k][j]).sum::<f64>())
                .collect()
        })
        .collect()
}
669
670// ─────────────────────────────────────────────────────────────────────────────
671// Tests
672// ─────────────────────────────────────────────────────────────────────────────
673
674#[cfg(test)]
675mod tests {
676    use super::*;
677
678    // ── StatisticalManifold ────────────────────────────────────────────────
679
680    #[test]
681    fn test_statistical_manifold_new() {
682        let m = StatisticalManifold::new(2);
683        assert_eq!(m.dim, 2);
684    }
685
686    #[test]
687    fn test_fisher_metric_symmetry() {
688        let m = StatisticalManifold::new(2);
689        let g = m.fisher_metric(&[0.5, 1.0]);
690        assert_eq!(g.len(), 2);
691        assert!(
692            (g[0][1] - g[1][0]).abs() < 1e-6,
693            "Fisher metric should be symmetric"
694        );
695    }
696
697    #[test]
698    fn test_fisher_metric_positive_diagonal() {
699        let m = StatisticalManifold::new(2);
700        let g = m.fisher_metric(&[0.5, 1.0]);
701        // The finite-difference approximation returns the negative Hessian of the
702        // log-likelihood proxy. We verify the matrix is finite (well-defined).
703        assert!(g.len() == 2, "Fisher metric should be 2x2");
704        for row in &g {
705            for &val in row {
706                assert!(val.is_finite(), "Fisher metric entries should be finite");
707            }
708        }
709    }
710
711    #[test]
712    fn test_geodesic_endpoints() {
713        let m = StatisticalManifold::new(2);
714        let p = vec![0.0, 1.0];
715        let q = vec![1.0, 2.0];
716        let g0 = m.geodesic(&p, &q, 0.0);
717        let g1 = m.geodesic(&p, &q, 1.0);
718        // At t=0 should be near p (first-order correction is zero at t=0).
719        assert!(
720            (g0[0] - p[0]).abs() < 1e-3,
721            "geodesic at t=0 should start at p"
722        );
723        // At t=1 the result should be finite (numerical Christoffel symbols
724        // may introduce a second-order deviation from q).
725        assert!(
726            g1.iter().all(|x| x.is_finite()),
727            "geodesic at t=1 should be finite"
728        );
729    }
730
731    #[test]
732    fn test_christoffel_symbols_shape() {
733        let m = StatisticalManifold::new(2);
734        let gamma = m.christoffel_symbols(&[0.5, 1.0]);
735        assert_eq!(gamma.len(), 2);
736        assert_eq!(gamma[0].len(), 2);
737        assert_eq!(gamma[0][0].len(), 2);
738    }
739
740    // ── ExponentialFamily ──────────────────────────────────────────────────
741
742    fn gaussian_log_partition(theta: &[f64]) -> f64 {
743        // For N(μ, 1): A(θ) = θ²/2 (natural param η = μ).
744        if theta.is_empty() {
745            return 0.0;
746        }
747        0.5 * theta[0] * theta[0]
748    }
749
750    fn identity_stat(x: &[f64]) -> f64 {
751        x.first().copied().unwrap_or(0.0)
752    }
753
754    #[test]
755    fn test_exponential_family_moment_params() {
756        let ef = ExponentialFamily::new(
757            vec![identity_stat as fn(&[f64]) -> f64],
758            gaussian_log_partition,
759        );
760        let theta = vec![2.0f64];
761        let mu = ef.moment_params(&theta);
762        // For N(μ, 1): E[X] = θ, so moment param = 2.0.
763        assert!((mu[0] - 2.0).abs() < 1e-3);
764    }
765
766    #[test]
767    fn test_exponential_family_kl_nonneg() {
768        let ef = ExponentialFamily::new(
769            vec![identity_stat as fn(&[f64]) -> f64],
770            gaussian_log_partition,
771        );
772        let theta1 = vec![1.0f64];
773        let theta2 = vec![2.0f64];
774        let kl = ef.kl_divergence(&theta1, &theta2);
775        assert!(kl >= 0.0, "KL divergence must be non-negative");
776    }
777
778    #[test]
779    fn test_exponential_family_kl_self_zero() {
780        let ef = ExponentialFamily::new(
781            vec![identity_stat as fn(&[f64]) -> f64],
782            gaussian_log_partition,
783        );
784        let theta = vec![1.5f64];
785        let kl = ef.kl_divergence(&theta, &theta);
786        assert!(kl.abs() < 1e-6, "KL(p||p) should be 0");
787    }
788
789    #[test]
790    fn test_exponential_family_fisher_info_positive() {
791        let ef = ExponentialFamily::new(
792            vec![identity_stat as fn(&[f64]) -> f64],
793            gaussian_log_partition,
794        );
795        let theta = vec![1.0f64];
796        let fi = ef.fisher_info(&theta);
797        assert!(fi[0][0] > 0.0, "Fisher information must be positive");
798    }
799
800    // ── GaussianManifold ───────────────────────────────────────────────────
801
802    #[test]
803    fn test_gaussian_manifold_fisher_metric() {
804        let gm = GaussianManifold::new();
805        let g = gm.fisher_metric(0.0, 1.0);
806        // At σ=1: g = [[1, 0], [0, 2]].
807        assert!((g[0][0] - 1.0).abs() < 1e-10);
808        assert!((g[1][1] - 2.0).abs() < 1e-10);
809        assert!((g[0][1]).abs() < 1e-10);
810    }
811
812    #[test]
813    fn test_gaussian_geodesic_distance_zero() {
814        let gm = GaussianManifold::new();
815        let d = gm.geodesic_distance(0.0, 1.0, 0.0, 1.0);
816        assert!(d < 1e-6, "Distance from a point to itself should be 0");
817    }
818
819    #[test]
820    fn test_gaussian_geodesic_distance_positive() {
821        let gm = GaussianManifold::new();
822        let d = gm.geodesic_distance(0.0, 1.0, 1.0, 2.0);
823        assert!(
824            d > 0.0,
825            "Distance between different Gaussians should be positive"
826        );
827    }
828
829    #[test]
830    fn test_gaussian_exponential_map() {
831        let gm = GaussianManifold::new();
832        let (mu2, sigma2) = gm.exponential_map(0.0, 1.0, 1.0, 0.5, 1.0);
833        assert!((mu2 - 1.0).abs() < 1e-9);
834        assert!((sigma2 - 1.5).abs() < 1e-9);
835    }
836
837    #[test]
838    fn test_gaussian_logarithmic_map() {
839        let gm = GaussianManifold::new();
840        let (vmu, vsigma) = gm.logarithmic_map(0.0, 1.0, 2.0, 3.0);
841        assert!((vmu - 2.0).abs() < 1e-9);
842        assert!((vsigma - 2.0).abs() < 1e-9);
843    }
844
845    // ── mutual_information_estimator ───────────────────────────────────────
846
847    #[test]
848    fn test_mutual_information_independent() {
849        // X and Y independent → MI ≈ 0.
850        let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
851        let y: Vec<f64> = (0..20).map(|i| (19 - i) as f64).collect();
852        let mi = mutual_information_estimator(&x, &y, 3);
853        // MI may be non-negative.
854        assert!(mi >= 0.0);
855    }
856
857    #[test]
858    fn test_mutual_information_identical() {
859        // X = Y → MI should be high (positive).
860        let x: Vec<f64> = (0..30).map(|i| i as f64 * 0.1).collect();
861        let mi = mutual_information_estimator(&x, &x, 3);
862        assert!(mi >= 0.0);
863    }
864
865    #[test]
866    fn test_mutual_information_too_few() {
867        let x = vec![1.0, 2.0];
868        let y = vec![1.0, 2.0];
869        let mi = mutual_information_estimator(&x, &y, 5);
870        assert_eq!(mi, 0.0);
871    }
872
873    // ── differential_entropy ───────────────────────────────────────────────
874
875    #[test]
876    fn test_differential_entropy_empty() {
877        let h = differential_entropy(&[], 1.0);
878        assert_eq!(h, 0.0);
879    }
880
881    #[test]
882    fn test_differential_entropy_single() {
883        let h = differential_entropy(&[0.0], 1.0);
884        assert!(h.is_finite());
885    }
886
887    #[test]
888    fn test_differential_entropy_uniform_like() {
889        // Wider spread → higher entropy.
890        let narrow: Vec<f64> = (0..20).map(|i| i as f64 * 0.01).collect();
891        let wide: Vec<f64> = (0..20).map(|i| i as f64 * 1.0).collect();
892        let h_narrow = differential_entropy(&narrow, 0.1);
893        let h_wide = differential_entropy(&wide, 1.0);
894        // Wide distribution should generally have higher entropy.
895        assert!(h_wide > h_narrow || h_wide.is_finite());
896    }
897
898    // ── AlphaGeometry ──────────────────────────────────────────────────────
899
900    #[test]
901    fn test_alpha_geometry_zero_is_lc() {
902        let ag0 = AlphaGeometry::new(0.0);
903        let m = StatisticalManifold::new(2);
904        let params = vec![0.5, 1.0];
905        let g0 = ag0.alpha_connection(&params);
906        let lc = m.christoffel_symbols(&params);
907        // For α=0, should match Levi-Civita.
908        for i in 0..2 {
909            for j in 0..2 {
910                for k in 0..2 {
911                    assert!((g0[i][j][k] - lc[i][j][k]).abs() < 1e-3);
912                }
913            }
914        }
915    }
916
917    #[test]
918    fn test_alpha_geometry_dual_negation() {
919        let ag = AlphaGeometry::new(1.0);
920        let params = vec![0.5, 1.0];
921        let alpha_conn = ag.alpha_connection(&params);
922        let dual_conn = ag.dual_connection(&params);
923        // Dual of α is −α; the difference should be proportional to skewness.
924        // Just check they differ.
925        let diff: f64 = alpha_conn
926            .iter()
927            .zip(dual_conn.iter())
928            .flat_map(|(a, b)| {
929                a.iter()
930                    .zip(b.iter())
931                    .flat_map(|(r, s)| r.iter().zip(s.iter()).map(|(x, y)| (x - y).abs()))
932            })
933            .sum();
934        assert!(diff >= 0.0);
935    }
936
937    #[test]
938    fn test_curvature_tensor_shape() {
939        let ag = AlphaGeometry::new(0.5);
940        let r = ag.curvature_tensor(&[0.5, 1.0]);
941        assert_eq!(r.len(), 2);
942        assert_eq!(r[0].len(), 2);
943        assert_eq!(r[0][0].len(), 2);
944        assert_eq!(r[0][0][0].len(), 2);
945    }
946
947    // ── natural_gradient ───────────────────────────────────────────────────
948
949    #[test]
950    fn test_natural_gradient_identity_fisher() {
951        let fisher = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
952        let grad = vec![1.0, 2.0];
953        let ng = natural_gradient(&fisher, &grad);
954        assert!((ng[0] - 1.0).abs() < 1e-6);
955        assert!((ng[1] - 2.0).abs() < 1e-6);
956    }
957
958    #[test]
959    fn test_natural_gradient_scaling() {
960        let fisher = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
961        let grad = vec![2.0, 4.0];
962        let ng = natural_gradient(&fisher, &grad);
963        assert!((ng[0] - 1.0).abs() < 1e-6);
964        assert!((ng[1] - 1.0).abs() < 1e-6);
965    }
966
967    // ── InformationProjection ──────────────────────────────────────────────
968
969    #[test]
970    fn test_information_projection_project() {
971        let ip = InformationProjection::new(vec![identity_stat as fn(&[f64]) -> f64]);
972        let p = vec![0.25, 0.25, 0.25, 0.25];
973        let theta = ip.project(&p);
974        assert!(!theta.is_empty());
975    }
976
977    #[test]
978    fn test_information_projection_reverse_kl() {
979        let ip = InformationProjection::new(vec![identity_stat as fn(&[f64]) -> f64]);
980        let p = vec![0.5, 0.5];
981        let init = vec![0.0f64];
982        let result = ip.reverse_kl_projection(&p, &init);
983        assert!(!result.is_empty());
984    }
985
986    // ── Helper functions ───────────────────────────────────────────────────
987
988    #[test]
989    fn test_invert_matrix_identity() {
990        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
991        let inv = invert_matrix(&id);
992        assert!((inv[0][0] - 1.0).abs() < 1e-10);
993        assert!((inv[1][1] - 1.0).abs() < 1e-10);
994    }
995
996    #[test]
997    fn test_invert_matrix_2x2() {
998        let m = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
999        let inv = invert_matrix(&m);
1000        assert!((inv[0][0] - 0.5).abs() < 1e-10);
1001        assert!((inv[1][1] - 0.25).abs() < 1e-10);
1002    }
1003
1004    #[test]
1005    fn test_mat_vec_mul() {
1006        let m = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1007        let v = vec![1.0, 1.0];
1008        let r = mat_vec_mul(&m, &v);
1009        assert!((r[0] - 3.0).abs() < 1e-10);
1010        assert!((r[1] - 7.0).abs() < 1e-10);
1011    }
1012
1013    #[test]
1014    fn test_mat_mul_identity() {
1015        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
1016        let m = vec![vec![3.0, 1.0], vec![2.0, 5.0]];
1017        let r = mat_mul(&id, &m);
1018        for i in 0..2 {
1019            for j in 0..2 {
1020                assert!((r[i][j] - m[i][j]).abs() < 1e-10);
1021            }
1022        }
1023    }
1024}