1use faer::Side;
19use gam_linalg::faer_ndarray::FaerEigh;
20use ndarray::{Array2, Array3, ArrayViewD, IxDyn};
21
/// Definiteness requirement enforced on every Fisher–Rao row block during
/// normalisation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FisherRaoDefiniteness {
    /// The smallest eigenvalue may be zero (up to tolerance); enough for the
    /// quadratic form `rᵀ W r ≥ 0` to be a valid squared residual.
    PositiveSemidefinite,
    /// The smallest eigenvalue must be strictly positive (beyond tolerance), as
    /// needed for the block to admit a Cholesky factor used for whitening.
    PositiveDefinite,
}
32
33pub fn normalize_fisher_rao_blocks(
40 arr: ArrayViewD<'_, f64>,
41 n_rows: usize,
42 dim: usize,
43) -> Result<Array3<f64>, String> {
44 normalize_fisher_rao_blocks_with(
45 arr,
46 n_rows,
47 dim,
48 FisherRaoDefiniteness::PositiveSemidefinite,
49 )
50}
51
52pub fn normalize_fisher_rao_blocks_pd(
58 arr: ArrayViewD<'_, f64>,
59 n_rows: usize,
60 dim: usize,
61) -> Result<Array3<f64>, String> {
62 normalize_fisher_rao_blocks_with(arr, n_rows, dim, FisherRaoDefiniteness::PositiveDefinite)
63}
64
65fn normalize_fisher_rao_blocks_with(
66 arr: ArrayViewD<'_, f64>,
67 n_rows: usize,
68 dim: usize,
69 definiteness: FisherRaoDefiniteness,
70) -> Result<Array3<f64>, String> {
71 if !arr.iter().all(|v| v.is_finite()) {
72 return Err("fisher_rao_w must contain only finite values".to_string());
73 }
74 let shape = arr.shape().to_vec();
75 let out: Array3<f64> = match arr.ndim() {
76 1 => {
77 if shape[0] != n_rows {
78 return Err(format!(
79 "fisher_rao_w vector must have length {n_rows}; got {}",
80 shape[0]
81 ));
82 }
83 let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
84 for row in 0..n_rows {
85 let value = arr[IxDyn(&[row])];
86 for d in 0..dim {
87 block[[row, d, d]] = value;
88 }
89 }
90 block
91 }
92 2 => {
93 if shape[0] != dim || shape[1] != dim {
94 return Err(format!(
95 "fisher_rao_w matrix must have shape ({dim}, {dim}); got ({}, {})",
96 shape[0], shape[1]
97 ));
98 }
99 let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
100 for row in 0..n_rows {
101 for r in 0..dim {
102 for c in 0..dim {
103 block[[row, r, c]] = arr[IxDyn(&[r, c])];
104 }
105 }
106 }
107 block
108 }
109 3 => {
110 if shape[0] != n_rows || shape[1] != dim || shape[2] != dim {
111 return Err(format!(
112 "fisher_rao_w must have shape ({n_rows}, {dim}, {dim}); got ({}, {}, {})",
113 shape[0], shape[1], shape[2]
114 ));
115 }
116 let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
117 for row in 0..n_rows {
118 for r in 0..dim {
119 for c in 0..dim {
120 block[[row, r, c]] = arr[IxDyn(&[row, r, c])];
121 }
122 }
123 }
124 block
125 }
126 _ => return Err("fisher_rao_w must be a 1-D, 2-D, or 3-D numeric array".to_string()),
127 };
128 for row in 0..n_rows {
129 for r in 0..dim {
130 for c in 0..dim {
131 let a = out[[row, r, c]];
132 let b = out[[row, c, r]];
133 if (a - b).abs() > 1.0e-10 * (1.0 + a.abs() + b.abs()) {
134 return Err("fisher_rao_w must be symmetric in every row block".to_string());
135 }
136 }
137 if out[[row, r, r]] < 0.0 {
138 return Err("fisher_rao_w diagonal entries must be non-negative".to_string());
139 }
140 }
141 validate_block_definiteness(out.index_axis(ndarray::Axis(0), row), row, definiteness)?;
142 }
143 Ok(out)
144}
145
146fn validate_block_definiteness(
153 block: ndarray::ArrayView2<'_, f64>,
154 row: usize,
155 definiteness: FisherRaoDefiniteness,
156) -> Result<(), String> {
157 if block.nrows() == 0 {
158 return Ok(());
159 }
160 let mut symmetric = Array2::<f64>::zeros((block.nrows(), block.ncols()));
163 for i in 0..block.nrows() {
164 for j in 0..block.ncols() {
165 symmetric[[i, j]] = 0.5 * (block[[i, j]] + block[[j, i]]);
166 }
167 }
168 let (eigenvalues, _) = symmetric.eigh(Side::Lower).map_err(|err| {
169 format!("fisher_rao_w row {row} eigendecomposition for definiteness check failed: {err}")
170 })?;
171 let spectral_scale = eigenvalues
172 .iter()
173 .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
174 .max(1.0);
175 let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
176 let tol = 1.0e-10 * spectral_scale;
180 match definiteness {
181 FisherRaoDefiniteness::PositiveSemidefinite => {
182 if min_eigenvalue < -tol {
183 return Err(format!(
184 "fisher_rao_w row {row} must be positive semidefinite (a precision metric \
185 induces the squared residual rᵀ W r ≥ 0); smallest eigenvalue {min_eigenvalue} \
186 is negative"
187 ));
188 }
189 }
190 FisherRaoDefiniteness::PositiveDefinite => {
191 if min_eigenvalue <= tol {
192 return Err(format!(
193 "fisher_rao_w row {row} must be positive definite for Cholesky whitening; \
194 smallest eigenvalue {min_eigenvalue} is not strictly positive"
195 ));
196 }
197 }
198 }
199 Ok(())
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build a 2 × 2 matrix from row-major literal values.
    fn block_2x2(values: [[f64; 2]; 2]) -> Array2<f64> {
        Array2::from_shape_fn((2, 2), |(r, c)| values[r][c])
    }

    #[test]
    fn indefinite_block_symmetric_nonneg_diagonal_is_rejected_as_not_psd() {
        // Eigenvalues 3 and -1: symmetric with a non-negative diagonal, yet indefinite.
        let block = block_2x2([[1.0, 2.0], [2.0, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 4, 2)
            .expect_err("indefinite block must be rejected by the PSD metric API");
        assert!(
            err.contains("positive semidefinite"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn shared_indefinite_block_is_reported_at_row_zero() {
        let block = block_2x2([[1.0, 2.0], [2.0, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 5, 2)
            .expect_err("a shared indefinite block must be rejected");
        assert!(err.contains("row 0"), "unexpected error message: {err}");
    }

    #[test]
    fn psd_block_is_accepted_by_metric_api_and_broadcast() {
        // Eigenvalues 3 and 1.
        let block = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        let n_rows = 3;
        let out = normalize_fisher_rao_blocks(block.view().into_dyn(), n_rows, 2)
            .expect("a genuinely PSD block must be accepted");
        assert_eq!(out.shape(), &[n_rows, 2, 2]);
        for row in 0..n_rows {
            assert_eq!(out[[row, 0, 0]], 2.0);
            assert_eq!(out[[row, 1, 0]], 1.0);
            assert_eq!(out[[row, 0, 1]], 1.0);
            assert_eq!(out[[row, 1, 1]], 2.0);
        }
    }

    #[test]
    fn pd_block_passes_the_cholesky_path() {
        let block = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        normalize_fisher_rao_blocks_pd(block.view().into_dyn(), 2, 2)
            .expect("a positive-definite block must pass the Cholesky (PD) path");
    }

    #[test]
    fn psd_singular_block_passes_metric_api_but_is_rejected_on_cholesky_path() {
        // Eigenvalues 2 and 0: PSD but singular.
        let block = block_2x2([[1.0, 1.0], [1.0, 1.0]]);
        normalize_fisher_rao_blocks(block.view().into_dyn(), 2, 2)
            .expect("a PSD-singular block must be accepted by the metric API");
        let err = normalize_fisher_rao_blocks_pd(block.view().into_dyn(), 2, 2)
            .expect_err("a singular block has no Cholesky factor and must be rejected");
        assert!(
            err.contains("positive definite"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn isotropic_scale_vector_remains_accepted() {
        let scales = ndarray::Array1::from(vec![0.5_f64, 2.0, 1.0]);
        let out = normalize_fisher_rao_blocks(scales.view().into_dyn(), 3, 2)
            .expect("non-negative isotropic scales are PSD");
        assert_eq!(out[[1, 0, 0]], 2.0);
        assert_eq!(out[[1, 1, 1]], 2.0);
        assert_eq!(out[[1, 0, 1]], 0.0);
    }

    #[test]
    fn negative_isotropic_scale_is_rejected() {
        let scales = ndarray::Array1::from(vec![1.0_f64, -0.5]);
        let err = normalize_fisher_rao_blocks(scales.view().into_dyn(), 2, 2)
            .expect_err("a negative isotropic scale must be rejected");
        assert!(err.contains("non-negative"), "unexpected error message: {err}");
    }

    #[test]
    fn isotropic_vector_of_wrong_length_is_rejected() {
        let scales = ndarray::Array1::from(vec![1.0_f64, 2.0]);
        let err = normalize_fisher_rao_blocks(scales.view().into_dyn(), 3, 2)
            .expect_err("a length-2 vector cannot describe 3 rows");
        assert!(err.contains("length 3"), "unexpected error message: {err}");
    }

    #[test]
    fn non_finite_entries_are_rejected() {
        let scales = ndarray::Array1::from(vec![1.0_f64, f64::NAN]);
        let err = normalize_fisher_rao_blocks(scales.view().into_dyn(), 2, 2)
            .expect_err("NaN weights must be rejected");
        assert!(err.contains("finite"), "unexpected error message: {err}");
    }

    #[test]
    fn asymmetric_block_is_rejected() {
        let block = block_2x2([[1.0, 0.5], [0.0, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 2, 2)
            .expect_err("an asymmetric block must be rejected");
        assert!(err.contains("symmetric"), "unexpected error message: {err}");
    }

    #[test]
    fn per_row_indefinite_block_is_rejected_with_its_row_index() {
        // Row 0 is 2·I (PD); row 1 is [[2, 3], [3, 2]] with eigenvalues 5 and -1.
        let mut stack = ndarray::Array3::<f64>::zeros((2, 2, 2));
        for row in 0..2 {
            stack[[row, 0, 0]] = 2.0;
            stack[[row, 1, 1]] = 2.0;
        }
        stack[[1, 0, 1]] = 3.0;
        stack[[1, 1, 0]] = 3.0;
        let err = normalize_fisher_rao_blocks(stack.view().into_dyn(), 2, 2)
            .expect_err("the indefinite row block must be rejected");
        assert!(err.contains("row 1"), "unexpected error message: {err}");
    }

    #[test]
    fn non_square_dynamic_input_is_still_rejected_by_shape_check() {
        let block = Array2::<f64>::zeros((3, 2));
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 4, 2)
            .expect_err("a (3, 2) matrix is not a valid (2, 2) shared block");
        assert!(err.contains("shape"), "unexpected error message: {err}");
    }

    #[test]
    fn unsupported_rank_is_rejected() {
        let arr = ndarray::Array4::<f64>::zeros((1, 1, 1, 1));
        let err = normalize_fisher_rao_blocks(arr.view().into_dyn(), 1, 1)
            .expect_err("a 4-D input is not a supported form");
        assert!(
            err.contains("1-D, 2-D, or 3-D"),
            "unexpected error message: {err}"
        );
    }
}