1#![allow(clippy::needless_range_loop)]
2#![allow(dead_code)]
12
13#[derive(Debug, Clone)]
22pub struct StatisticalManifold {
23 pub dim: usize,
25}
26
27impl StatisticalManifold {
28 pub fn new(dim: usize) -> Self {
30 Self { dim }
31 }
32
33 pub fn fisher_metric(&self, params: &[f64]) -> Vec<Vec<f64>> {
38 let n = self.dim;
39 let h = 1e-5;
40 let mut g = vec![vec![0.0f64; n]; n];
41 for i in 0..n {
42 for j in i..n {
43 let mut pp = params.to_vec();
45 let mut pm = params.to_vec();
46 let mut mp = params.to_vec();
47 let mut mm = params.to_vec();
48 pp[i] += h;
49 pp[j] += h;
50 pm[i] += h;
51 pm[j] -= h;
52 mp[i] -= h;
53 mp[j] += h;
54 mm[i] -= h;
55 mm[j] -= h;
56 let val = (log_likelihood_approx(&pp)
57 - log_likelihood_approx(&pm)
58 - log_likelihood_approx(&mp)
59 + log_likelihood_approx(&mm))
60 / (4.0 * h * h);
61 g[i][j] = -val;
62 g[j][i] = -val;
63 }
64 }
65 g
66 }
67
68 pub fn geodesic(&self, p: &[f64], q: &[f64], t: f64) -> Vec<f64> {
72 let g = self.fisher_metric(p);
73 let g_inv = invert_matrix(&g);
74 let v: Vec<f64> = p.iter().zip(q.iter()).map(|(pi, qi)| qi - pi).collect();
75 let gamma = self.christoffel_symbols(p);
77 let n = self.dim;
78 let mut correction = vec![0.0f64; n];
79 for k in 0..n {
80 let mut acc = 0.0f64;
81 for i in 0..n {
82 for j in 0..n {
83 acc += gamma[k][i][j] * v[i] * v[j];
84 }
85 }
86 correction[k] = acc;
87 }
88 let corr_raised: Vec<f64> = mat_vec_mul(&g_inv, &correction);
90 p.iter()
91 .zip(v.iter())
92 .zip(corr_raised.iter())
93 .map(|((pi, vi), ci)| pi + t * vi - 0.5 * t * t * ci)
94 .collect()
95 }
96
97 pub fn christoffel_symbols(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
102 let n = self.dim;
103 let h = 1e-5;
104 let g = self.fisher_metric(params);
105 let g_inv = invert_matrix(&g);
106 let mut dg = vec![vec![vec![0.0f64; n]; n]; n];
108 for k in 0..n {
109 let mut pk = params.to_vec();
110 let mut mk = params.to_vec();
111 pk[k] += h;
112 mk[k] -= h;
113 let gp = self.fisher_metric(&pk);
114 let gm = self.fisher_metric(&mk);
115 for i in 0..n {
116 for j in 0..n {
117 dg[k][i][j] = (gp[i][j] - gm[i][j]) / (2.0 * h);
118 }
119 }
120 }
121 let mut gamma = vec![vec![vec![0.0f64; n]; n]; n];
123 for l in 0..n {
124 for i in 0..n {
125 for j in 0..n {
126 let mut acc = 0.0f64;
127 for k in 0..n {
128 acc += g_inv[l][k] * (dg[i][j][k] + dg[j][i][k] - dg[k][i][j]);
129 }
130 gamma[l][i][j] = 0.5 * acc;
131 }
132 }
133 }
134 gamma
135 }
136}
137
/// An exponential family p(x|θ) = exp(θ·T(x) − A(θ)) described by its
/// sufficient statistics T and log-partition function A.
pub struct ExponentialFamily {
    /// Sufficient statistic functions T_i(x).
    pub sufficient_stats: Vec<fn(&[f64]) -> f64>,
    /// Log-partition (cumulant) function A(θ).
    pub log_partition: fn(&[f64]) -> f64,
}

impl ExponentialFamily {
    /// Bundles the sufficient statistics and log-partition function.
    pub fn new(sufficient_stats: Vec<fn(&[f64]) -> f64>, log_partition: fn(&[f64]) -> f64) -> Self {
        ExponentialFamily {
            sufficient_stats,
            log_partition,
        }
    }

    /// Natural parameters; this family is already in natural coordinates,
    /// so the input is returned as-is.
    pub fn natural_params(&self, theta: &[f64]) -> Vec<f64> {
        Vec::from(theta)
    }

    /// Moment (expectation) parameters μ = ∇A(θ), by central differences.
    pub fn moment_params(&self, theta: &[f64]) -> Vec<f64> {
        let step = 1e-5;
        let log_part = self.log_partition;
        let mut mu = Vec::with_capacity(theta.len());
        for idx in 0..theta.len() {
            let mut up = theta.to_vec();
            let mut down = theta.to_vec();
            up[idx] += step;
            down[idx] -= step;
            mu.push((log_part(&up) - log_part(&down)) / (2.0 * step));
        }
        mu
    }

    /// KL(p_θ1 || p_θ2) as the Bregman divergence of A:
    /// A(θ2) − A(θ1) − ⟨θ2 − θ1, ∇A(θ1)⟩.
    pub fn kl_divergence(&self, theta1: &[f64], theta2: &[f64]) -> f64 {
        let log_part = self.log_partition;
        let mu1 = self.moment_params(theta1);
        let gap = log_part(theta2) - log_part(theta1);
        let mut inner = 0.0f64;
        // Same pairing as zip(theta2, theta1, mu1): shortest length wins.
        let len = theta2.len().min(theta1.len()).min(mu1.len());
        for i in 0..len {
            inner += (theta2[i] - theta1[i]) * mu1[i];
        }
        gap - inner
    }

    /// Fisher information matrix I(θ) = ∇²A(θ), by central second differences.
    pub fn fisher_info(&self, theta: &[f64]) -> Vec<Vec<f64>> {
        let dim = theta.len();
        let step = 1e-4;
        let log_part = self.log_partition;
        let mut info = vec![vec![0.0f64; dim]; dim];
        for r in 0..dim {
            for c in r..dim {
                // Evaluate A at θ with coordinates r and c shifted.
                let shifted = |dr: f64, dc: f64| {
                    let mut t = theta.to_vec();
                    t[r] += dr;
                    t[c] += dc;
                    log_part(&t)
                };
                let second = (shifted(step, step) - shifted(step, -step)
                    - shifted(-step, step)
                    + shifted(-step, -step))
                    / (4.0 * step * step);
                info[r][c] = second;
                info[c][r] = second;
            }
        }
        info
    }
}
226
/// Geometry helpers for the manifold of univariate Gaussians N(μ, σ²)
/// parameterised by (μ, σ).
#[derive(Debug, Clone)]
pub struct GaussianManifold;

impl GaussianManifold {
    /// Constructs the (stateless) Gaussian manifold helper.
    pub fn new() -> Self {
        GaussianManifold
    }

    /// Fisher metric of N(μ, σ²) in (μ, σ) coordinates: diag(1/σ², 2/σ²).
    /// The metric is independent of μ, hence the ignored first argument.
    pub fn fisher_metric(&self, _mu: f64, sigma: f64) -> [[f64; 2]; 2] {
        let var = sigma * sigma;
        [[1.0 / var, 0.0], [0.0, 2.0 / var]]
    }

    /// Fisher–Rao geodesic distance between two Gaussians, computed via the
    /// Poincaré half-plane model with coordinates (μ, √2·σ).
    pub fn geodesic_distance(&self, mu1: f64, sigma1: f64, mu2: f64, sigma2: f64) -> f64 {
        let (x1, y1) = (mu1, sigma1 * std::f64::consts::SQRT_2);
        let (x2, y2) = (mu2, sigma2 * std::f64::consts::SQRT_2);
        let sq = (x2 - x1).powi(2) + (y2 - y1).powi(2);
        let denom = 2.0 * y1 * y2;
        if denom <= 0.0 {
            // Degenerate (σ ≤ 0) points are infinitely far away.
            return f64::INFINITY;
        }
        let c = 1.0 + sq / denom;
        // acosh(c) written out explicitly; max(0) guards rounding below 1.
        (c + (c * c - 1.0).max(0.0).sqrt()).ln()
    }

    /// First-order exponential map: a Euclidean step along (v_mu, v_sigma)
    /// scaled by `t`, with σ clamped to stay strictly positive.
    pub fn exponential_map(
        &self,
        mu: f64,
        sigma: f64,
        v_mu: f64,
        v_sigma: f64,
        t: f64,
    ) -> (f64, f64) {
        (mu + t * v_mu, (sigma + t * v_sigma).max(1e-12))
    }

    /// First-order logarithmic map: the coordinate difference from
    /// (mu, sigma) to (mu2, sigma2).
    pub fn logarithmic_map(&self, mu: f64, sigma: f64, mu2: f64, sigma2: f64) -> (f64, f64) {
        (mu2 - mu, sigma2 - sigma)
    }
}

impl Default for GaussianManifold {
    /// Equivalent to `GaussianManifold::new()`.
    fn default() -> Self {
        GaussianManifold::new()
    }
}
299
/// KSG-style k-nearest-neighbour mutual information estimate between paired
/// samples `x` and `y` (Chebyshev metric), clamped to be non-negative.
/// Returns 0.0 when there are not more than `k` paired samples.
pub fn mutual_information_estimator(x: &[f64], y: &[f64], k: usize) -> f64 {
    let n = x.len().min(y.len());
    if n <= k {
        return 0.0;
    }
    let k = k.max(1);
    // Crude digamma approximation: ψ(z) ≈ ln z − 1/(2z).
    let psi = |z: f64| z.ln() - 1.0 / (2.0 * z);
    let mut sum_x = 0.0f64;
    let mut sum_y = 0.0f64;
    for i in 0..n {
        let (cx, cy) = (x[i], y[i]);
        // Chebyshev distance from sample i to every other paired sample.
        let mut dists = Vec::with_capacity(n - 1);
        for j in 0..n {
            if j == i {
                continue;
            }
            dists.push((cx - x[j]).abs().max((cy - y[j]).abs()));
        }
        dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        // Radius to the k-th neighbour (0 if unavailable, e.g. duplicates).
        let eps = if dists.len() >= k { dists[k - 1] } else { 0.0 };
        // Marginal neighbour counts strictly within eps (self included).
        let cnt_x = x.iter().filter(|&&v| (v - cx).abs() < eps).count();
        let cnt_y = y.iter().filter(|&&v| (v - cy).abs() < eps).count();
        sum_x += psi(cnt_x.max(1) as f64);
        sum_y += psi(cnt_y.max(1) as f64);
    }
    let estimate = psi(k as f64) - (sum_x + sum_y) / n as f64 + psi(n as f64);
    estimate.max(0.0)
}
340
/// Resubstitution estimate of differential entropy using a Gaussian kernel
/// density estimate with the given `bandwidth` (floored at 1e-10).
/// Returns 0.0 for an empty sample.
pub fn differential_entropy(samples: &[f64], bandwidth: f64) -> f64 {
    let count = samples.len();
    if count == 0 {
        return 0.0;
    }
    let bw = bandwidth.max(1e-10);
    // KDE normalizing constant: 1 / (n · h · √(2π)).
    let scale = 1.0 / (count as f64 * bw * (2.0 * std::f64::consts::PI).sqrt());
    let mut total = 0.0f64;
    for &center in samples {
        // Unnormalized kernel sum for the density at `center`.
        let mut kernel_sum = 0.0f64;
        for &other in samples {
            let z = (center - other) / bw;
            kernel_sum += (-0.5 * z * z).exp();
        }
        let density = kernel_sum * scale;
        // Guard against log of (near-)zero densities.
        total += if density > 1e-300 { -density.ln() } else { 0.0 };
    }
    total / count as f64
}
373
374#[derive(Debug, Clone)]
384pub struct AlphaGeometry {
385 pub alpha: f64,
387}
388
389impl AlphaGeometry {
390 pub fn new(alpha: f64) -> Self {
392 Self { alpha }
393 }
394
395 pub fn alpha_connection(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
401 let n = params.len();
402 let manifold = StatisticalManifold::new(n);
403 let gamma0 = manifold.christoffel_symbols(params);
404 let h = 1e-4;
407 let mut t = vec![vec![vec![0.0f64; n]; n]; n];
408 for i in 0..n {
409 let mut pi = params.to_vec();
410 let mut mi = params.to_vec();
411 pi[i] += h;
412 mi[i] -= h;
413 let gp = manifold.fisher_metric(&pi);
414 let gm = manifold.fisher_metric(&mi);
415 for j in 0..n {
416 for k in 0..n {
417 t[i][j][k] = (gp[j][k] - gm[j][k]) / (2.0 * h);
418 }
419 }
420 }
421 let mut gamma_alpha = gamma0;
422 for i in 0..n {
423 for j in 0..n {
424 for k in 0..n {
425 gamma_alpha[i][j][k] -= (self.alpha / 2.0) * t[i][j][k];
426 }
427 }
428 }
429 gamma_alpha
430 }
431
432 pub fn dual_connection(&self, params: &[f64]) -> Vec<Vec<Vec<f64>>> {
434 let dual = AlphaGeometry::new(-self.alpha);
435 dual.alpha_connection(params)
436 }
437
438 pub fn curvature_tensor(&self, params: &[f64]) -> Vec<Vec<Vec<Vec<f64>>>> {
442 let n = params.len();
443 let h = 1e-4;
444 let gamma = self.alpha_connection(params);
445 let mut dgamma = vec![vec![vec![vec![0.0f64; n]; n]; n]; n];
447 for m in 0..n {
448 let mut pm = params.to_vec();
449 let mut mm = params.to_vec();
450 pm[m] += h;
451 mm[m] -= h;
452 let gp = self.alpha_connection(&pm);
453 let gm_c = self.alpha_connection(&mm);
454 for l in 0..n {
455 for i in 0..n {
456 for j in 0..n {
457 dgamma[m][l][i][j] = (gp[l][i][j] - gm_c[l][i][j]) / (2.0 * h);
458 }
459 }
460 }
461 }
462 let mut r = vec![vec![vec![vec![0.0f64; n]; n]; n]; n];
464 for l in 0..n {
465 for k in 0..n {
466 for i in 0..n {
467 for j in 0..n {
468 let term1 = dgamma[i][l][j][k];
469 let term2 = dgamma[j][l][i][k];
470 let mut term3 = 0.0f64;
471 let mut term4 = 0.0f64;
472 for mm in 0..n {
473 term3 += gamma[l][i][mm] * gamma[mm][j][k];
474 term4 += gamma[l][j][mm] * gamma[mm][i][k];
475 }
476 r[l][k][i][j] = term1 - term2 + term3 - term4;
477 }
478 }
479 }
480 }
481 r
482 }
483}
484
485pub fn natural_gradient(fisher: &[Vec<f64>], grad: &[f64]) -> Vec<f64> {
495 let f_inv = invert_matrix(fisher);
496 mat_vec_mul(&f_inv, grad)
497}
498
499pub struct InformationProjection {
508 pub target_family: Vec<fn(&[f64]) -> f64>,
510}
511
512impl InformationProjection {
513 pub fn new(target_family: Vec<fn(&[f64]) -> f64>) -> Self {
515 Self { target_family }
516 }
517
518 pub fn project(&self, p: &[f64]) -> Vec<f64> {
523 let n = p.len();
524 let k = self.target_family.len();
525 if n == 0 || k == 0 {
526 return vec![0.0; k];
527 }
528 let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
530 let mut moments = vec![0.0f64; k];
531 for i in 0..k {
532 let xi: Vec<f64> = vec![x[i % n]];
533 moments[i] = (self.target_family[i])(&xi);
534 }
535 moments
536 }
537
538 pub fn reverse_kl_projection(&self, p: &[f64], init_theta: &[f64]) -> Vec<f64> {
543 let _p = p; let k = self.target_family.len();
545 let mut theta = init_theta.to_vec();
546 let lr = 0.01;
547 let steps = 50;
548 for _ in 0..steps {
549 let grad = kl_gradient(&theta, k, lr);
550 for i in 0..k {
551 theta[i] -= lr * grad[i];
552 }
553 }
554 theta
555 }
556}
557
/// Toy Gaussian log-likelihood surrogate evaluated at x = 0 with
/// `params = [mu, sigma, ...]` (up to an additive constant).
/// Inputs with fewer than two entries yield 0.0; sigma is clamped away
/// from zero via its absolute value.
fn log_likelihood_approx(params: &[f64]) -> f64 {
    match params {
        [mu, sigma_raw, ..] => {
            let sigma = sigma_raw.abs().max(1e-12);
            -(sigma.ln()) - mu * mu / (2.0 * sigma * sigma)
        }
        _ => 0.0,
    }
}
571
572fn kl_gradient(theta: &[f64], _k: usize, h: f64) -> Vec<f64> {
574 let n = theta.len();
575 let mut grad = vec![0.0f64; n];
576 for i in 0..n {
577 let mut tp = theta.to_vec();
578 let mut tm = theta.to_vec();
579 tp[i] += h;
580 tm[i] -= h;
581 let kl_p = log_likelihood_approx(&tp).abs();
583 let kl_m = log_likelihood_approx(&tm).abs();
584 grad[i] = (kl_p - kl_m) / (2.0 * h);
585 }
586 grad
587}
588
/// Inverts a square matrix with Gauss-Jordan elimination and partial
/// pivoting. A numerically singular input (pivot < 1e-14) falls back to
/// returning the identity matrix; an empty input returns an empty matrix.
fn invert_matrix(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = m.len();
    if n == 0 {
        return Vec::new();
    }
    // Build the augmented matrix [M | I].
    let mut work: Vec<Vec<f64>> = Vec::with_capacity(n);
    for (i, row) in m.iter().enumerate() {
        let mut r = row.clone();
        r.extend((0..n).map(|j| if i == j { 1.0 } else { 0.0 }));
        work.push(r);
    }
    for col in 0..n {
        // Partial pivoting: move the largest-magnitude entry into place.
        let mut best = col;
        for row in (col + 1)..n {
            if work[row][col].abs() > work[best][col].abs() {
                best = row;
            }
        }
        work.swap(col, best);
        let pivot = work[col][col];
        if pivot.abs() < 1e-14 {
            // Singular: fall back to the identity matrix.
            return (0..n)
                .map(|i| (0..n).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
                .collect();
        }
        // Normalize the pivot row (both halves of the augmented matrix).
        for entry in work[col].iter_mut() {
            *entry /= pivot;
        }
        // Eliminate this column from every other row.
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = work[row][col];
            for c in 0..(2 * n) {
                let delta = factor * work[col][c];
                work[row][c] -= delta;
            }
        }
    }
    // The right half now holds M⁻¹.
    work.into_iter().map(|row| row[n..].to_vec()).collect()
}
648
/// Matrix-vector product: each output entry is the dot product of a row of
/// `m` with `v` (truncated to the shorter of the two lengths).
fn mat_vec_mul(m: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(m.len());
    for row in m {
        let mut dot = 0.0f64;
        for (a, b) in row.iter().zip(v) {
            dot += a * b;
        }
        out.push(dot);
    }
    out
}
655
/// Naive square matrix product; all dimensions are taken from `a.len()`,
/// so both inputs must be at least n x n.
fn mat_mul(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n = a.len();
    let mut prod = vec![vec![0.0f64; n]; n];
    for (i, row) in prod.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            for k in 0..n {
                *cell += a[i][k] * b[k][j];
            }
        }
    }
    prod
}
669
#[cfg(test)]
mod tests {
    // Unit tests for the information-geometry utilities.
    //
    // Fix: the alpha-geometry tests contained mojibake — `&params` had been
    // corrupted into the `&para;` HTML-entity glyph ("¶ms"), which does not
    // compile. The references are restored below.
    use super::*;

    #[test]
    fn test_statistical_manifold_new() {
        let m = StatisticalManifold::new(2);
        assert_eq!(m.dim, 2);
    }

    #[test]
    fn test_fisher_metric_symmetry() {
        let m = StatisticalManifold::new(2);
        let g = m.fisher_metric(&[0.5, 1.0]);
        assert_eq!(g.len(), 2);
        assert!(
            (g[0][1] - g[1][0]).abs() < 1e-6,
            "Fisher metric should be symmetric"
        );
    }

    #[test]
    fn test_fisher_metric_positive_diagonal() {
        let m = StatisticalManifold::new(2);
        let g = m.fisher_metric(&[0.5, 1.0]);
        assert!(g.len() == 2, "Fisher metric should be 2x2");
        for row in &g {
            for &val in row {
                assert!(val.is_finite(), "Fisher metric entries should be finite");
            }
        }
    }

    #[test]
    fn test_geodesic_endpoints() {
        let m = StatisticalManifold::new(2);
        let p = vec![0.0, 1.0];
        let q = vec![1.0, 2.0];
        let g0 = m.geodesic(&p, &q, 0.0);
        let g1 = m.geodesic(&p, &q, 1.0);
        assert!(
            (g0[0] - p[0]).abs() < 1e-3,
            "geodesic at t=0 should start at p"
        );
        assert!(
            g1.iter().all(|x| x.is_finite()),
            "geodesic at t=1 should be finite"
        );
    }

    #[test]
    fn test_christoffel_symbols_shape() {
        let m = StatisticalManifold::new(2);
        let gamma = m.christoffel_symbols(&[0.5, 1.0]);
        assert_eq!(gamma.len(), 2);
        assert_eq!(gamma[0].len(), 2);
        assert_eq!(gamma[0][0].len(), 2);
    }

    // Log-partition of a unit-variance Gaussian in natural coordinates.
    fn gaussian_log_partition(theta: &[f64]) -> f64 {
        if theta.is_empty() {
            return 0.0;
        }
        0.5 * theta[0] * theta[0]
    }

    // Sufficient statistic T(x) = x (first component).
    fn identity_stat(x: &[f64]) -> f64 {
        x.first().copied().unwrap_or(0.0)
    }

    #[test]
    fn test_exponential_family_moment_params() {
        let ef = ExponentialFamily::new(
            vec![identity_stat as fn(&[f64]) -> f64],
            gaussian_log_partition,
        );
        let theta = vec![2.0f64];
        let mu = ef.moment_params(&theta);
        assert!((mu[0] - 2.0).abs() < 1e-3);
    }

    #[test]
    fn test_exponential_family_kl_nonneg() {
        let ef = ExponentialFamily::new(
            vec![identity_stat as fn(&[f64]) -> f64],
            gaussian_log_partition,
        );
        let theta1 = vec![1.0f64];
        let theta2 = vec![2.0f64];
        let kl = ef.kl_divergence(&theta1, &theta2);
        assert!(kl >= 0.0, "KL divergence must be non-negative");
    }

    #[test]
    fn test_exponential_family_kl_self_zero() {
        let ef = ExponentialFamily::new(
            vec![identity_stat as fn(&[f64]) -> f64],
            gaussian_log_partition,
        );
        let theta = vec![1.5f64];
        let kl = ef.kl_divergence(&theta, &theta);
        assert!(kl.abs() < 1e-6, "KL(p||p) should be 0");
    }

    #[test]
    fn test_exponential_family_fisher_info_positive() {
        let ef = ExponentialFamily::new(
            vec![identity_stat as fn(&[f64]) -> f64],
            gaussian_log_partition,
        );
        let theta = vec![1.0f64];
        let fi = ef.fisher_info(&theta);
        assert!(fi[0][0] > 0.0, "Fisher information must be positive");
    }

    #[test]
    fn test_gaussian_manifold_fisher_metric() {
        let gm = GaussianManifold::new();
        let g = gm.fisher_metric(0.0, 1.0);
        assert!((g[0][0] - 1.0).abs() < 1e-10);
        assert!((g[1][1] - 2.0).abs() < 1e-10);
        assert!((g[0][1]).abs() < 1e-10);
    }

    #[test]
    fn test_gaussian_geodesic_distance_zero() {
        let gm = GaussianManifold::new();
        let d = gm.geodesic_distance(0.0, 1.0, 0.0, 1.0);
        assert!(d < 1e-6, "Distance from a point to itself should be 0");
    }

    #[test]
    fn test_gaussian_geodesic_distance_positive() {
        let gm = GaussianManifold::new();
        let d = gm.geodesic_distance(0.0, 1.0, 1.0, 2.0);
        assert!(
            d > 0.0,
            "Distance between different Gaussians should be positive"
        );
    }

    #[test]
    fn test_gaussian_exponential_map() {
        let gm = GaussianManifold::new();
        let (mu2, sigma2) = gm.exponential_map(0.0, 1.0, 1.0, 0.5, 1.0);
        assert!((mu2 - 1.0).abs() < 1e-9);
        assert!((sigma2 - 1.5).abs() < 1e-9);
    }

    #[test]
    fn test_gaussian_logarithmic_map() {
        let gm = GaussianManifold::new();
        let (vmu, vsigma) = gm.logarithmic_map(0.0, 1.0, 2.0, 3.0);
        assert!((vmu - 2.0).abs() < 1e-9);
        assert!((vsigma - 2.0).abs() < 1e-9);
    }

    #[test]
    fn test_mutual_information_independent() {
        let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let y: Vec<f64> = (0..20).map(|i| (19 - i) as f64).collect();
        let mi = mutual_information_estimator(&x, &y, 3);
        assert!(mi >= 0.0);
    }

    #[test]
    fn test_mutual_information_identical() {
        let x: Vec<f64> = (0..30).map(|i| i as f64 * 0.1).collect();
        let mi = mutual_information_estimator(&x, &x, 3);
        assert!(mi >= 0.0);
    }

    #[test]
    fn test_mutual_information_too_few() {
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];
        let mi = mutual_information_estimator(&x, &y, 5);
        assert_eq!(mi, 0.0);
    }

    #[test]
    fn test_differential_entropy_empty() {
        let h = differential_entropy(&[], 1.0);
        assert_eq!(h, 0.0);
    }

    #[test]
    fn test_differential_entropy_single() {
        let h = differential_entropy(&[0.0], 1.0);
        assert!(h.is_finite());
    }

    #[test]
    fn test_differential_entropy_uniform_like() {
        let narrow: Vec<f64> = (0..20).map(|i| i as f64 * 0.01).collect();
        let wide: Vec<f64> = (0..20).map(|i| i as f64 * 1.0).collect();
        let h_narrow = differential_entropy(&narrow, 0.1);
        let h_wide = differential_entropy(&wide, 1.0);
        assert!(h_wide > h_narrow || h_wide.is_finite());
    }

    #[test]
    fn test_alpha_geometry_zero_is_lc() {
        let ag0 = AlphaGeometry::new(0.0);
        let m = StatisticalManifold::new(2);
        let params = vec![0.5, 1.0];
        let g0 = ag0.alpha_connection(&params);
        let lc = m.christoffel_symbols(&params);
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    assert!((g0[i][j][k] - lc[i][j][k]).abs() < 1e-3);
                }
            }
        }
    }

    #[test]
    fn test_alpha_geometry_dual_negation() {
        let ag = AlphaGeometry::new(1.0);
        let params = vec![0.5, 1.0];
        let alpha_conn = ag.alpha_connection(&params);
        let dual_conn = ag.dual_connection(&params);
        let diff: f64 = alpha_conn
            .iter()
            .zip(dual_conn.iter())
            .flat_map(|(a, b)| {
                a.iter()
                    .zip(b.iter())
                    .flat_map(|(r, s)| r.iter().zip(s.iter()).map(|(x, y)| (x - y).abs()))
            })
            .sum();
        assert!(diff >= 0.0);
    }

    #[test]
    fn test_curvature_tensor_shape() {
        let ag = AlphaGeometry::new(0.5);
        let r = ag.curvature_tensor(&[0.5, 1.0]);
        assert_eq!(r.len(), 2);
        assert_eq!(r[0].len(), 2);
        assert_eq!(r[0][0].len(), 2);
        assert_eq!(r[0][0][0].len(), 2);
    }

    #[test]
    fn test_natural_gradient_identity_fisher() {
        let fisher = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let grad = vec![1.0, 2.0];
        let ng = natural_gradient(&fisher, &grad);
        assert!((ng[0] - 1.0).abs() < 1e-6);
        assert!((ng[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_natural_gradient_scaling() {
        let fisher = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
        let grad = vec![2.0, 4.0];
        let ng = natural_gradient(&fisher, &grad);
        assert!((ng[0] - 1.0).abs() < 1e-6);
        assert!((ng[1] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_information_projection_project() {
        let ip = InformationProjection::new(vec![identity_stat as fn(&[f64]) -> f64]);
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let theta = ip.project(&p);
        assert!(!theta.is_empty());
    }

    #[test]
    fn test_information_projection_reverse_kl() {
        let ip = InformationProjection::new(vec![identity_stat as fn(&[f64]) -> f64]);
        let p = vec![0.5, 0.5];
        let init = vec![0.0f64];
        let result = ip.reverse_kl_projection(&p, &init);
        assert!(!result.is_empty());
    }

    #[test]
    fn test_invert_matrix_identity() {
        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let inv = invert_matrix(&id);
        assert!((inv[0][0] - 1.0).abs() < 1e-10);
        assert!((inv[1][1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_invert_matrix_2x2() {
        let m = vec![vec![2.0, 0.0], vec![0.0, 4.0]];
        let inv = invert_matrix(&m);
        assert!((inv[0][0] - 0.5).abs() < 1e-10);
        assert!((inv[1][1] - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul() {
        let m = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let v = vec![1.0, 1.0];
        let r = mat_vec_mul(&m, &v);
        assert!((r[0] - 3.0).abs() < 1e-10);
        assert!((r[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_mul_identity() {
        let id = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let m = vec![vec![3.0, 1.0], vec![2.0, 5.0]];
        let r = mat_mul(&id, &m);
        for i in 0..2 {
            for j in 0..2 {
                assert!((r[i][j] - m[i][j]).abs() < 1e-10);
            }
        }
    }
}