//! Information Geometry operations for statistical manifolds
//!
//! This crate implements the foundational concepts of Information Geometry,
//! including Fisher metrics, α-connections, Bregman divergences, and the
//! Amari-Chentsov tensor structure.

use amari_core::Multivector;
use num_traits::{Float, Zero};
use thiserror::Error;

// GPU acceleration exports
#[cfg(feature = "gpu")]
pub use gpu::{
    GpuBregmanData, GpuFisherData, GpuStatisticalManifold, InfoGeomGpuConfig, InfoGeomGpuError,
    InfoGeomGpuOps, InfoGeomGpuResult,
};

#[cfg(test)]
pub mod comprehensive_tests;
#[cfg(feature = "gpu")]
pub mod gpu;
pub mod verified_contracts;

// pub mod fisher;
// pub mod connections;
// pub mod divergences;
// pub mod manifolds;

/// Errors arising from information-geometric computations.
#[derive(Error, Debug)]
pub enum InfoGeomError {
    /// A computation produced unstable or non-finite intermediate values.
    #[error("Numerical instability in computation")]
    NumericalInstability,

    /// A parameter vector's dimension did not match what the operation expects.
    #[error("Invalid parameter dimension: expected {expected}, got {actual}")]
    InvalidDimension { expected: usize, actual: usize },

    /// A parameter value fell outside its valid range.
    #[error("Parameter out of valid range")]
    ParameterOutOfRange,
}

/// Trait for objects that can be used as parameters on statistical manifolds
///
/// Implementors expose their coordinates as an indexed sequence of scalars so
/// that metric and connection computations can treat them uniformly.
pub trait Parameter {
    /// Scalar type in which the parameter coordinates are expressed.
    type Scalar: Float;

    /// Dimension of the parameter space
    fn dimension(&self) -> usize;

    /// Get parameter component by index
    fn get_component(&self, index: usize) -> Self::Scalar;

    /// Set parameter component by index
    fn set_component(&mut self, index: usize, value: Self::Scalar);
}

// Multivectors act as manifold parameters: one f64 coordinate per basis blade,
// delegating component access to the multivector's own get/set.
impl<const P: usize, const Q: usize, const R: usize> Parameter for Multivector<P, Q, R> {
    type Scalar = f64;

    fn dimension(&self) -> usize {
        // One coordinate per basis blade of the algebra.
        Self::BASIS_COUNT
    }

    fn get_component(&self, index: usize) -> f64 {
        self.get(index)
    }

    fn set_component(&mut self, index: usize, value: f64) {
        self.set(index, value);
    }
}

71/// Fisher Information Metric for statistical manifolds
72pub trait FisherMetric<T: Parameter> {
73    /// Compute the Fisher information matrix at a point
74    fn fisher_matrix(&self, point: &T) -> Result<Vec<Vec<T::Scalar>>, InfoGeomError>;
75
76    /// Compute inner product using Fisher metric
77    fn fisher_inner_product(&self, point: &T, v1: &T, v2: &T) -> Result<T::Scalar, InfoGeomError> {
78        let g = self.fisher_matrix(point)?;
79        let mut result = T::Scalar::zero();
80
81        for i in 0..v1.dimension() {
82            for j in 0..v2.dimension() {
83                if i < g.len() && j < g[i].len() {
84                    result = result + g[i][j] * v1.get_component(i) * v2.get_component(j);
85                }
86            }
87        }
88
89        Ok(result)
90    }
91}
92
/// α-connection on a statistical manifold
pub trait AlphaConnection<T: Parameter> {
    /// The α parameter defining this connection (-1 ≤ α ≤ 1)
    fn alpha(&self) -> f64;

    /// Christoffel symbols for the α-connection
    ///
    /// Returns a dim × dim × dim table indexed as `symbols[i][j][k]`.
    fn christoffel_symbols(&self, point: &T) -> Result<Vec<Vec<Vec<T::Scalar>>>, InfoGeomError>;

    /// Covariant derivative along a curve
    ///
    /// Differentiates `vector` in the given `direction` at `point`.
    fn covariant_derivative(
        &self,
        point: &T,
        vector: &T,
        direction: &T,
    ) -> Result<T, InfoGeomError>;
}

/// Dually flat manifold with e-connection and m-connection
#[derive(Clone, Debug)]
pub struct DuallyFlatManifold {
    dimension: usize,
    #[allow(dead_code)]
    alpha: f64,
}

impl DuallyFlatManifold {
    /// Create new dually flat manifold with given dimension and alpha parameter
    pub fn new(dimension: usize, alpha: f64) -> Self {
        Self { dimension, alpha }
    }

    /// Compute Fisher information metric at a point
    ///
    /// Uses the simplified diagonal metric of the probability simplex
    /// (1/p_i on the diagonal); near-zero components are clamped so that
    /// every entry stays finite. For exponential families this corresponds
    /// to the Hessian of the log partition function.
    pub fn fisher_metric_at(&self, point: &[f64]) -> FisherInformationMatrix {
        let mut matrix = vec![vec![0.0; self.dimension]; self.dimension];

        // Only the diagonal is populated; off-diagonal entries stay zero
        // because the components are treated as independent.
        for (i, row) in matrix.iter_mut().enumerate() {
            if let Some(&p_i) = point.get(i) {
                row[i] = if p_i > 1e-12 { 1.0 / p_i } else { 1e12 };
            }
        }

        FisherInformationMatrix { matrix }
    }

    /// Compute Bregman divergence between two points
    ///
    /// For probability distributions this is the KL divergence
    /// D_KL(p||q) = Σ p_i ln(p_i / q_i); terms whose entries are at or
    /// below the numerical floor are skipped.
    pub fn bregman_divergence(&self, p: &[f64], q: &[f64]) -> f64 {
        p.iter()
            .copied()
            .zip(q.iter().copied())
            .filter(|&(pi, qi)| pi > 1e-12 && qi > 1e-12)
            .map(|(pi, qi)| pi * (pi / qi).ln())
            .sum()
    }
}

/// Fisher Information Matrix
#[derive(Clone, Debug)]
pub struct FisherInformationMatrix {
    matrix: Vec<Vec<f64>>,
}

impl FisherInformationMatrix {
    /// Compute eigenvalues to check positive definiteness
    ///
    /// Simplified for testing: assumes the matrix is diagonal, in which case
    /// the eigenvalues are exactly the diagonal entries. A production
    /// implementation would use a proper linear algebra library.
    pub fn eigenvalues(&self) -> Vec<f64> {
        self.matrix
            .iter()
            .enumerate()
            .filter_map(|(i, row)| row.get(i).copied())
            .collect()
    }
}

/// Simplified AlphaConnection implementation for tests
#[derive(Clone, Debug)]
pub struct SimpleAlphaConnection {
    alpha: f64,
}

impl SimpleAlphaConnection {
    /// Build a connection carrying the given α value.
    pub fn new(alpha: f64) -> Self {
        SimpleAlphaConnection { alpha }
    }

    /// The α parameter of this connection.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }
}

// For backwards compatibility, we expose both the trait and struct versions

207/// Compute the Bregman divergence between two points
208pub fn bregman_divergence<F>(
209    phi: F,
210    p: &Multivector<3, 0, 0>,
211    q: &Multivector<3, 0, 0>,
212) -> Result<f64, InfoGeomError>
213where
214    F: Fn(&Multivector<3, 0, 0>) -> f64,
215{
216    let phi_p = phi(p);
217    let phi_q = phi(q);
218
219    // Approximate gradient using finite differences
220    let eps = 1e-8;
221    let mut grad_phi_q = Multivector::zero();
222
223    for i in 0..8 {
224        let mut q_plus = q.clone();
225        q_plus.set(i, q.get(i) + eps);
226        let phi_plus = phi(&q_plus);
227
228        let mut q_minus = q.clone();
229        q_minus.set(i, q.get(i) - eps);
230        let phi_minus = phi(&q_minus);
231
232        let derivative = (phi_plus - phi_minus) / (2.0 * eps);
233        grad_phi_q.set(i, derivative);
234    }
235
236    let diff = p - q;
237    let inner_product = diff.scalar_product(&grad_phi_q);
238
239    Ok(phi_p - phi_q - inner_product)
240}
241
242/// Compute KL divergence using natural and expectation parameters
243pub fn kl_divergence(
244    eta_p: &Multivector<3, 0, 0>, // Natural parameters for p
245    eta_q: &Multivector<3, 0, 0>, // Natural parameters for q
246    mu_p: &Multivector<3, 0, 0>,  // Expectation parameters for p
247) -> f64 {
248    // KL(p||q) = <η_p - η_q, μ_p> - ψ(η_p) + ψ(η_q)
249    // where ψ is the log partition function
250
251    let eta_diff = eta_p - eta_q;
252
253    // For simplicity, assume log partition functions cancel in relative computation
254    eta_diff.scalar_product(mu_p)
255}
256
257/// Compute the Amari-Chentsov tensor at a point
258pub fn amari_chentsov_tensor(
259    x: &Multivector<3, 0, 0>,
260    y: &Multivector<3, 0, 0>,
261    z: &Multivector<3, 0, 0>,
262) -> f64 {
263    // The Amari-Chentsov tensor is the unique (up to scaling) symmetric 3-tensor
264    // that is invariant under sufficient statistics transformations.
265    //
266    // T(X,Y,Z) = ∂³ψ/∂θ^i∂θ^j∂θ^k X^i Y^j Z^k
267    // For a proper implementation, we use the symmetrized trilinear form:
268    // T(X,Y,Z) = (1/6)[X·(Y×Z) + Y·(Z×X) + Z·(X×Y) + cyclic permutations]
269
270    // Extract vector components for the computation
271    let x_vec = [
272        x.vector_component(0),
273        x.vector_component(1),
274        x.vector_component(2),
275    ];
276    let y_vec = [
277        y.vector_component(0),
278        y.vector_component(1),
279        y.vector_component(2),
280    ];
281    let z_vec = [
282        z.vector_component(0),
283        z.vector_component(1),
284        z.vector_component(2),
285    ];
286
287    // Compute the symmetric trilinear form
288    // For 3D Euclidean space, this is related to the scalar triple product
289    x_vec[0] * y_vec[1] * z_vec[2] + x_vec[1] * y_vec[2] * z_vec[0] + x_vec[2] * y_vec[0] * z_vec[1]
290        - x_vec[2] * y_vec[1] * z_vec[0]
291        - x_vec[1] * y_vec[0] * z_vec[2]
292        - x_vec[0] * y_vec[2] * z_vec[1]
293}
294
/// α-connection factory
pub struct AlphaConnectionFactory;

impl AlphaConnectionFactory {
    /// Create an α-connection for the given α value
    ///
    /// Returns a boxed trait object backed by the standard (flat)
    /// implementation; the `'static` bound is required by `Box<dyn Trait>`.
    pub fn create<T: Parameter + Clone + 'static>(alpha: f64) -> Box<dyn AlphaConnection<T>> {
        Box::new(StandardAlphaConnection::new(alpha))
    }
}

/// Standard implementation of α-connection
///
/// Trait bounds are placed on the impl blocks that need them rather than on
/// the struct definition (per Rust API guidelines). `PhantomData` records the
/// parameter type without storing a value of it.
struct StandardAlphaConnection<T> {
    alpha: f64,
    _phantom: std::marker::PhantomData<T>,
}

impl<T> StandardAlphaConnection<T> {
    /// Construct a connection for the given α value.
    fn new(alpha: f64) -> Self {
        Self {
            alpha,
            _phantom: std::marker::PhantomData,
        }
    }
}

320impl<T: Parameter + Clone> AlphaConnection<T> for StandardAlphaConnection<T> {
321    fn alpha(&self) -> f64 {
322        self.alpha
323    }
324
325    fn christoffel_symbols(&self, _point: &T) -> Result<Vec<Vec<Vec<T::Scalar>>>, InfoGeomError> {
326        // Simplified implementation - would need proper computation based on the metric
327        let dim = _point.dimension();
328        let mut symbols = Vec::new();
329        for _ in 0..dim {
330            let mut dim2 = Vec::new();
331            for _ in 0..dim {
332                let mut dim3 = Vec::new();
333                for _ in 0..dim {
334                    dim3.push(T::Scalar::zero());
335                }
336                dim2.push(dim3);
337            }
338            symbols.push(dim2);
339        }
340
341        // For now, return zero symbols (flat connection)
342        // In practice, this would involve computing derivatives of the metric
343
344        Ok(symbols)
345    }
346
347    fn covariant_derivative(
348        &self,
349        _point: &T,
350        vector: &T,
351        _direction: &T,
352    ) -> Result<T, InfoGeomError> {
353        // Simplified: in flat space, covariant derivative equals ordinary derivative
354        Ok(vector.clone())
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use amari_core::basis::MultivectorBuilder;

    // Bregman divergence of a convex potential must be non-negative.
    #[test]
    fn test_bregman_divergence() {
        let p = MultivectorBuilder::<3, 0, 0>::new()
            .scalar(1.0)
            .e(1, 2.0)
            .build();

        let q = MultivectorBuilder::<3, 0, 0>::new()
            .scalar(1.5)
            .e(1, 1.5)
            .build();

        // Simple quadratic potential
        let phi = |mv: &Multivector<3, 0, 0>| mv.norm_squared();

        let divergence = bregman_divergence(phi, &p, &q).unwrap();
        assert!(divergence >= 0.0); // Bregman divergences are non-negative
    }

    // With the ψ terms assumed to cancel, KL reduces to <η_p - η_q, μ_p>.
    #[test]
    fn test_kl_divergence() {
        let eta_p = MultivectorBuilder::<3, 0, 0>::new().scalar(1.0).build();

        let eta_q = MultivectorBuilder::<3, 0, 0>::new().scalar(0.5).build();

        let mu_p = MultivectorBuilder::<3, 0, 0>::new().scalar(2.0).build();

        let kl = kl_divergence(&eta_p, &eta_q, &mu_p);
        assert_eq!(kl, 1.0); // (1.0 - 0.5) * 2.0 = 1.0
    }

    // The tensor on basis vectors equals the scalar triple product, and
    // swapping two arguments flips its sign.
    #[test]
    fn test_amari_chentsov_tensor() {
        // Create three linearly independent vectors to ensure non-zero tensor value
        // Test with e1, e2, e3 which should give determinant = 1
        let x = MultivectorBuilder::<3, 0, 0>::new()
            .e(1, 1.0) // e1
            .build();

        let y = MultivectorBuilder::<3, 0, 0>::new()
            .e(2, 1.0) // e2
            .build();

        let z = MultivectorBuilder::<3, 0, 0>::new()
            .e(3, 1.0) // e3
            .build();

        let tensor_value = amari_chentsov_tensor(&x, &y, &z);

        // For x = e1, y = e2, z = e3, the scalar triple product should be 1
        // T(e1, e2, e3) = det([1,0,0; 0,1,0; 0,0,1]) = 1
        assert!(
            (tensor_value - 1.0).abs() < 1e-10,
            "Expected 1.0, got {}",
            tensor_value
        );

        // Test with different ordering to verify anti-symmetry
        let tensor_value_reversed = amari_chentsov_tensor(&y, &x, &z);
        assert!(
            (tensor_value_reversed + 1.0).abs() < 1e-10,
            "Should be -1.0 due to swap"
        );
    }
}