1use amari_core::Multivector;
8use num_traits::{Float, Zero};
9use thiserror::Error;
10
11#[cfg(feature = "gpu")]
13pub use gpu::{
14 GpuBregmanData, GpuFisherData, GpuStatisticalManifold, InfoGeomGpuConfig, InfoGeomGpuError,
15 InfoGeomGpuOps, InfoGeomGpuResult,
16};
17
18#[cfg(test)]
19pub mod comprehensive_tests;
20#[cfg(feature = "gpu")]
21pub mod gpu;
22pub mod verified_contracts;
23
24#[derive(Error, Debug)]
30pub enum InfoGeomError {
31 #[error("Numerical instability in computation")]
32 NumericalInstability,
33
34 #[error("Invalid parameter dimension: expected {expected}, got {actual}")]
35 InvalidDimension { expected: usize, actual: usize },
36
37 #[error("Parameter out of valid range")]
38 ParameterOutOfRange,
39}
40
41pub trait Parameter {
43 type Scalar: Float;
44
45 fn dimension(&self) -> usize;
47
48 fn get_component(&self, index: usize) -> Self::Scalar;
50
51 fn set_component(&mut self, index: usize, value: Self::Scalar);
53}
54
55impl<const P: usize, const Q: usize, const R: usize> Parameter for Multivector<P, Q, R> {
56 type Scalar = f64;
57
58 fn dimension(&self) -> usize {
59 Self::BASIS_COUNT
60 }
61
62 fn get_component(&self, index: usize) -> f64 {
63 self.get(index)
64 }
65
66 fn set_component(&mut self, index: usize, value: f64) {
67 self.set(index, value);
68 }
69}
70
71pub trait FisherMetric<T: Parameter> {
73 fn fisher_matrix(&self, point: &T) -> Result<Vec<Vec<T::Scalar>>, InfoGeomError>;
75
76 fn fisher_inner_product(&self, point: &T, v1: &T, v2: &T) -> Result<T::Scalar, InfoGeomError> {
78 let g = self.fisher_matrix(point)?;
79 let mut result = T::Scalar::zero();
80
81 for i in 0..v1.dimension() {
82 for j in 0..v2.dimension() {
83 if i < g.len() && j < g[i].len() {
84 result = result + g[i][j] * v1.get_component(i) * v2.get_component(j);
85 }
86 }
87 }
88
89 Ok(result)
90 }
91}
92
93pub trait AlphaConnection<T: Parameter> {
95 fn alpha(&self) -> f64;
97
98 fn christoffel_symbols(&self, point: &T) -> Result<Vec<Vec<Vec<T::Scalar>>>, InfoGeomError>;
100
101 fn covariant_derivative(
103 &self,
104 point: &T,
105 vector: &T,
106 direction: &T,
107 ) -> Result<T, InfoGeomError>;
108}
109
110#[derive(Clone, Debug)]
112pub struct DuallyFlatManifold {
113 dimension: usize,
114 #[allow(dead_code)]
115 alpha: f64,
116}
117
118impl DuallyFlatManifold {
119 pub fn new(dimension: usize, alpha: f64) -> Self {
121 Self { dimension, alpha }
122 }
123
124 pub fn fisher_metric_at(&self, point: &[f64]) -> FisherInformationMatrix {
126 let mut matrix = vec![vec![0.0; self.dimension]; self.dimension];
128
129 #[allow(clippy::needless_range_loop)]
131 for i in 0..self.dimension {
132 for j in 0..self.dimension {
133 if i == j && i < point.len() {
134 matrix[i][j] = if point[i] > 1e-12 {
136 1.0 / point[i]
137 } else {
138 1e12
139 };
140 } else {
141 matrix[i][j] = 0.0;
143 }
144 }
145 }
146
147 FisherInformationMatrix { matrix }
148 }
149
150 pub fn bregman_divergence(&self, p: &[f64], q: &[f64]) -> f64 {
152 let mut divergence = 0.0;
154
155 for i in 0..p.len().min(q.len()) {
156 if p[i] > 1e-12 && q[i] > 1e-12 {
157 divergence += p[i] * (p[i] / q[i]).ln();
158 }
159 }
160
161 divergence
162 }
163}
164
/// A Fisher information matrix `g_ij`, stored densely as one `Vec` per row.
#[derive(Clone, Debug)]
pub struct FisherInformationMatrix {
    matrix: Vec<Vec<f64>>,
}

impl FisherInformationMatrix {
    /// Eigenvalues of the matrix, in diagonal order (not sorted).
    ///
    /// For an `n × n` matrix, the spectrum of its symmetric part `(G + Gᵀ) / 2` is computed with
    /// the cyclic Jacobi method. Fisher information matrices are symmetric, so for well-formed
    /// input this is the spectrum of `G` itself. A matrix whose off-diagonal entries are all
    /// zero is left untouched, so diagonal metrics yield their diagonal entries exactly and in
    /// order. If an off-diagonal entry is NaN the iteration stops and the current diagonal is
    /// returned.
    ///
    /// If the storage is not square (some row's length differs from the number of rows), no
    /// spectrum is defined; the available diagonal entries `matrix[i][i]` are returned instead,
    /// skipping rows too short to have one.
    pub fn eigenvalues(&self) -> Vec<f64> {
        /// In-place cyclic Jacobi iteration on a square symmetric matrix.
        ///
        /// Each plane rotation zeroes one off-diagonal pair; sweeps repeat until the
        /// off-diagonal mass is negligible relative to the whole matrix (or `MAX_SWEEPS` is
        /// hit), leaving the eigenvalues on the diagonal.
        // Index loops mirror the textbook rotation formulas more clearly than iterators do.
        #[allow(clippy::needless_range_loop)]
        fn jacobi_diagonalize(a: &mut [Vec<f64>]) {
            const MAX_SWEEPS: usize = 100;
            let n = a.len();

            for _ in 0..MAX_SWEEPS {
                let mut off = 0.0_f64;
                let mut total = 0.0_f64;
                for i in 0..n {
                    for j in 0..n {
                        let sq = a[i][j] * a[i][j];
                        total += sq;
                        if i != j {
                            off += sq;
                        }
                    }
                }
                // Exits immediately for an already-diagonal matrix, and bails out on NaN
                // instead of rotating garbage for every remaining sweep.
                if off.is_nan() || off <= f64::EPSILON * f64::EPSILON * total {
                    return;
                }

                for p in 0..n {
                    for q in (p + 1)..n {
                        let apq = a[p][q];
                        if apq == 0.0 {
                            continue;
                        }
                        // tan of the rotation angle: the smaller root of t² + 2θt − 1 = 0,
                        // which keeps |angle| ≤ π/4 for numerical stability.
                        let theta = (a[q][q] - a[p][p]) / (2.0 * apq);
                        let t = theta.signum() / (theta.abs() + theta.hypot(1.0));
                        let c = 1.0 / t.hypot(1.0);
                        let s = t * c;

                        // A ← A·J: mix columns p and q.
                        for k in 0..n {
                            let (akp, akq) = (a[k][p], a[k][q]);
                            a[k][p] = c * akp - s * akq;
                            a[k][q] = s * akp + c * akq;
                        }
                        // A ← Jᵀ·A: mix rows p and q.
                        for k in 0..n {
                            let (apk, aqk) = (a[p][k], a[q][k]);
                            a[p][k] = c * apk - s * aqk;
                            a[q][k] = s * apk + c * aqk;
                        }
                    }
                }
            }
        }

        let n = self.matrix.len();

        if self.matrix.iter().any(|row| row.len() != n) {
            return self
                .matrix
                .iter()
                .enumerate()
                .filter_map(|(i, row)| row.get(i).copied())
                .collect();
        }

        // Symmetrise the off-diagonal entries so small numerical asymmetries cannot stall the
        // rotations; diagonal entries are copied verbatim.
        let m = &self.matrix;
        let mut a: Vec<Vec<f64>> = (0..n)
            .map(|i| {
                (0..n)
                    .map(|j| {
                        if i == j {
                            m[i][i]
                        } else {
                            0.5 * (m[i][j] + m[j][i])
                        }
                    })
                    .collect()
            })
            .collect();

        jacobi_diagonalize(&mut a);

        a.iter().enumerate().map(|(i, row)| row[i]).collect()
    }
}
188
/// A bare holder for an α-connection parameter.
#[derive(Clone, Debug)]
pub struct SimpleAlphaConnection {
    alpha: f64,
}

impl SimpleAlphaConnection {
    /// Wraps `alpha` without validation.
    pub fn new(alpha: f64) -> Self {
        SimpleAlphaConnection { alpha }
    }

    /// Returns the stored α value.
    pub fn alpha(&self) -> f64 {
        let Self { alpha } = self;
        *alpha
    }
}
204
205pub fn bregman_divergence<F>(
209 phi: F,
210 p: &Multivector<3, 0, 0>,
211 q: &Multivector<3, 0, 0>,
212) -> Result<f64, InfoGeomError>
213where
214 F: Fn(&Multivector<3, 0, 0>) -> f64,
215{
216 let phi_p = phi(p);
217 let phi_q = phi(q);
218
219 let eps = 1e-8;
221 let mut grad_phi_q = Multivector::zero();
222
223 for i in 0..8 {
224 let mut q_plus = q.clone();
225 q_plus.set(i, q.get(i) + eps);
226 let phi_plus = phi(&q_plus);
227
228 let mut q_minus = q.clone();
229 q_minus.set(i, q.get(i) - eps);
230 let phi_minus = phi(&q_minus);
231
232 let derivative = (phi_plus - phi_minus) / (2.0 * eps);
233 grad_phi_q.set(i, derivative);
234 }
235
236 let diff = p - q;
237 let inner_product = diff.scalar_product(&grad_phi_q);
238
239 Ok(phi_p - phi_q - inner_product)
240}
241
242pub fn kl_divergence(
244 eta_p: &Multivector<3, 0, 0>, eta_q: &Multivector<3, 0, 0>, mu_p: &Multivector<3, 0, 0>, ) -> f64 {
248 let eta_diff = eta_p - eta_q;
252
253 eta_diff.scalar_product(mu_p)
255}
256
257pub fn amari_chentsov_tensor(
259 x: &Multivector<3, 0, 0>,
260 y: &Multivector<3, 0, 0>,
261 z: &Multivector<3, 0, 0>,
262) -> f64 {
263 let x_vec = [
272 x.vector_component(0),
273 x.vector_component(1),
274 x.vector_component(2),
275 ];
276 let y_vec = [
277 y.vector_component(0),
278 y.vector_component(1),
279 y.vector_component(2),
280 ];
281 let z_vec = [
282 z.vector_component(0),
283 z.vector_component(1),
284 z.vector_component(2),
285 ];
286
287 x_vec[0] * y_vec[1] * z_vec[2] + x_vec[1] * y_vec[2] * z_vec[0] + x_vec[2] * y_vec[0] * z_vec[1]
290 - x_vec[2] * y_vec[1] * z_vec[0]
291 - x_vec[1] * y_vec[0] * z_vec[2]
292 - x_vec[0] * y_vec[2] * z_vec[1]
293}
294
295pub struct AlphaConnectionFactory;
297
298impl AlphaConnectionFactory {
299 pub fn create<T: Parameter + Clone + 'static>(alpha: f64) -> Box<dyn AlphaConnection<T>> {
301 Box::new(StandardAlphaConnection::new(alpha))
302 }
303}
304
305struct StandardAlphaConnection<T: Parameter> {
307 alpha: f64,
308 _phantom: std::marker::PhantomData<T>,
309}
310
311impl<T: Parameter> StandardAlphaConnection<T> {
312 fn new(alpha: f64) -> Self {
313 Self {
314 alpha,
315 _phantom: std::marker::PhantomData,
316 }
317 }
318}
319
320impl<T: Parameter + Clone> AlphaConnection<T> for StandardAlphaConnection<T> {
321 fn alpha(&self) -> f64 {
322 self.alpha
323 }
324
325 fn christoffel_symbols(&self, _point: &T) -> Result<Vec<Vec<Vec<T::Scalar>>>, InfoGeomError> {
326 let dim = _point.dimension();
328 let mut symbols = Vec::new();
329 for _ in 0..dim {
330 let mut dim2 = Vec::new();
331 for _ in 0..dim {
332 let mut dim3 = Vec::new();
333 for _ in 0..dim {
334 dim3.push(T::Scalar::zero());
335 }
336 dim2.push(dim3);
337 }
338 symbols.push(dim2);
339 }
340
341 Ok(symbols)
345 }
346
347 fn covariant_derivative(
348 &self,
349 _point: &T,
350 vector: &T,
351 _direction: &T,
352 ) -> Result<T, InfoGeomError> {
353 Ok(vector.clone())
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use amari_core::basis::MultivectorBuilder;

    #[test]
    fn test_bregman_divergence() {
        let p = MultivectorBuilder::<3, 0, 0>::new()
            .scalar(1.0)
            .e(1, 2.0)
            .build();

        let q = MultivectorBuilder::<3, 0, 0>::new()
            .scalar(1.5)
            .e(1, 1.5)
            .build();

        // The squared norm is convex, so its Bregman divergence must be non-negative.
        let phi = |mv: &Multivector<3, 0, 0>| mv.norm_squared();

        let divergence = bregman_divergence(phi, &p, &q).unwrap();
        assert!(divergence >= 0.0);
    }

    #[test]
    fn test_bregman_divergence_of_point_with_itself_is_zero() {
        let p = MultivectorBuilder::<3, 0, 0>::new()
            .scalar(0.3)
            .e(2, -1.2)
            .build();

        let phi = |mv: &Multivector<3, 0, 0>| mv.norm_squared();

        let divergence = bregman_divergence(phi, &p, &p).unwrap();
        assert!(divergence.abs() < 1e-9);
    }

    #[test]
    fn test_kl_divergence() {
        let eta_p = MultivectorBuilder::<3, 0, 0>::new().scalar(1.0).build();

        let eta_q = MultivectorBuilder::<3, 0, 0>::new().scalar(0.5).build();

        let mu_p = MultivectorBuilder::<3, 0, 0>::new().scalar(2.0).build();

        let kl = kl_divergence(&eta_p, &eta_q, &mu_p);
        assert_eq!(kl, 1.0);
    }

    #[test]
    fn test_amari_chentsov_tensor() {
        let x = MultivectorBuilder::<3, 0, 0>::new().e(1, 1.0).build();

        let y = MultivectorBuilder::<3, 0, 0>::new().e(2, 1.0).build();

        let z = MultivectorBuilder::<3, 0, 0>::new().e(3, 1.0).build();

        let tensor_value = amari_chentsov_tensor(&x, &y, &z);

        assert!(
            (tensor_value - 1.0).abs() < 1e-10,
            "Expected 1.0, got {}",
            tensor_value
        );

        let tensor_value_reversed = amari_chentsov_tensor(&y, &x, &z);
        assert!(
            (tensor_value_reversed + 1.0).abs() < 1e-10,
            "Should be -1.0 due to swap"
        );
    }

    #[test]
    fn test_fisher_metric_is_inverse_of_coordinates() {
        let manifold = DuallyFlatManifold::new(3, 0.0);
        // The point has no third coordinate, so that diagonal entry stays zero.
        let fim = manifold.fisher_metric_at(&[0.5, 0.25]);
        assert_eq!(fim.eigenvalues(), vec![2.0, 4.0, 0.0]);
    }

    #[test]
    fn test_fisher_metric_clamps_non_positive_coordinates() {
        let manifold = DuallyFlatManifold::new(2, 0.0);
        let fim = manifold.fisher_metric_at(&[0.0, -3.0]);
        assert_eq!(fim.eigenvalues(), vec![1e12, 1e12]);
    }

    #[test]
    fn test_dually_flat_bregman_divergence() {
        let manifold = DuallyFlatManifold::new(2, 0.0);
        assert_eq!(manifold.bregman_divergence(&[0.5, 0.5], &[0.5, 0.5]), 0.0);

        let d = manifold.bregman_divergence(&[1.0, 0.0], &[0.5, 0.5]);
        assert!((d - std::f64::consts::LN_2).abs() < 1e-12);
    }

    #[test]
    fn test_simple_alpha_connection() {
        assert_eq!(SimpleAlphaConnection::new(-1.0).alpha(), -1.0);
    }

    #[test]
    fn test_factory_connection_is_flat() {
        let point = MultivectorBuilder::<3, 0, 0>::new().scalar(1.0).build();
        let connection = AlphaConnectionFactory::create::<Multivector<3, 0, 0>>(0.5);
        assert_eq!(connection.alpha(), 0.5);

        let gamma = connection.christoffel_symbols(&point).unwrap();
        // Fully qualified so an inherent `dimension` on Multivector cannot shadow the trait.
        assert_eq!(gamma.len(), Parameter::dimension(&point));
        assert!(gamma.iter().flatten().flatten().all(|&g| g == 0.0));
    }
}
427}