use crate::faer_ndarray::FaerEigh;
use faer::Side;
use ndarray::{Array2, Array3, ArrayViewD, IxDyn};
/// Which spectral condition every `fisher_rao_w` block has to satisfy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FisherRaoDefiniteness {
    /// Smallest eigenvalue may be zero (within tolerance) but not negative.
    PositiveSemidefinite,
    /// Smallest eigenvalue must be strictly positive (beyond tolerance).
    PositiveDefinite,
}
/// Normalizes `fisher_rao_w` into `(n_rows, dim, dim)` blocks, requiring each block to be
/// positive semidefinite.
///
/// See `normalize_fisher_rao_blocks_with` for the accepted input ranks and the checks made.
///
/// # Errors
/// Returns a descriptive message if the input is malformed or any block is not PSD.
pub fn normalize_fisher_rao_blocks(
    arr: ArrayViewD<'_, f64>,
    n_rows: usize,
    dim: usize,
) -> Result<Array3<f64>, String> {
    let mode = FisherRaoDefiniteness::PositiveSemidefinite;
    normalize_fisher_rao_blocks_with(arr, n_rows, dim, mode)
}
/// Normalizes `fisher_rao_w` into `(n_rows, dim, dim)` blocks, requiring each block to be
/// strictly positive definite (as needed for Cholesky whitening).
///
/// See `normalize_fisher_rao_blocks_with` for the accepted input ranks and the checks made.
///
/// # Errors
/// Returns a descriptive message if the input is malformed or any block is not PD.
pub fn normalize_fisher_rao_blocks_pd(
    arr: ArrayViewD<'_, f64>,
    n_rows: usize,
    dim: usize,
) -> Result<Array3<f64>, String> {
    let mode = FisherRaoDefiniteness::PositiveDefinite;
    normalize_fisher_rao_blocks_with(arr, n_rows, dim, mode)
}
/// Expands `fisher_rao_w` into one `(dim, dim)` block per row and validates every block.
///
/// Accepted inputs:
/// * 1-D, length `n_rows`: row `i` becomes the isotropic block `arr[i] * I`.
/// * 2-D, shape `(dim, dim)`: a single block copied to every row.
/// * 3-D, shape `(n_rows, dim, dim)`: one block per row, copied verbatim.
///
/// Each block must be symmetric (relative tolerance `1e-10`), have a non-negative
/// diagonal, and satisfy `definiteness` as checked by `validate_block_definiteness`.
///
/// # Errors
/// Returns a human-readable message when the input has non-finite values, an
/// unsupported rank, or a mismatched shape, or when any block fails validation.
fn normalize_fisher_rao_blocks_with(
    arr: ArrayViewD<'_, f64>,
    n_rows: usize,
    dim: usize,
    definiteness: FisherRaoDefiniteness,
) -> Result<Array3<f64>, String> {
    if !arr.iter().all(|v| v.is_finite()) {
        return Err("fisher_rao_w must contain only finite values".to_string());
    }
    let is_shared_block = arr.ndim() == 2;
    let out = expand_fisher_rao_input(arr, n_rows, dim)?;
    // Every row of a broadcast 2-D input is the same block, so validating row 0 yields the
    // same verdict (and the same error) as validating all rows, without repeating the
    // eigendecomposition `n_rows` times. When `n_rows == 0` nothing is validated.
    let rows_to_validate = if is_shared_block { n_rows.min(1) } else { n_rows };
    for row in 0..rows_to_validate {
        let block = out.index_axis(ndarray::Axis(0), row);
        validate_block_symmetry_and_diagonal(block, row)?;
        validate_block_definiteness(block, row, definiteness)?;
    }
    Ok(out)
}

/// Copies a 1-D, 2-D, or 3-D `fisher_rao_w` input into `(n_rows, dim, dim)` blocks,
/// rejecting any other rank or a shape that does not match `n_rows` / `dim`.
fn expand_fisher_rao_input(
    arr: ArrayViewD<'_, f64>,
    n_rows: usize,
    dim: usize,
) -> Result<Array3<f64>, String> {
    let out_shape = (n_rows, dim, dim);
    match *arr.shape() {
        [len] => {
            if len != n_rows {
                return Err(format!(
                    "fisher_rao_w vector must have length {n_rows}; got {len}"
                ));
            }
            // Isotropic per-row scale: arr[row] on the diagonal, zeros elsewhere.
            Ok(Array3::from_shape_fn(out_shape, |(row, r, c)| {
                if r == c {
                    arr[IxDyn(&[row])]
                } else {
                    0.0
                }
            }))
        }
        [rows, cols] => {
            if rows != dim || cols != dim {
                return Err(format!(
                    "fisher_rao_w matrix must have shape ({dim}, {dim}); got ({rows}, {cols})"
                ));
            }
            Ok(Array3::from_shape_fn(out_shape, |(_, r, c)| arr[IxDyn(&[r, c])]))
        }
        [rows, r_len, c_len] => {
            if rows != n_rows || r_len != dim || c_len != dim {
                return Err(format!(
                    "fisher_rao_w must have shape ({n_rows}, {dim}, {dim}); got ({rows}, {r_len}, {c_len})"
                ));
            }
            Ok(Array3::from_shape_fn(out_shape, |(row, r, c)| arr[IxDyn(&[row, r, c])]))
        }
        _ => Err("fisher_rao_w must be a 1-D, 2-D, or 3-D numeric array".to_string()),
    }
}

/// Checks one square block for symmetry (relative tolerance) and a non-negative diagonal;
/// `row` only labels the error message.
fn validate_block_symmetry_and_diagonal(
    block: ndarray::ArrayView2<'_, f64>,
    row: usize,
) -> Result<(), String> {
    let dim = block.nrows();
    for r in 0..dim {
        // Pairs with c < r were already compared when the outer index was c; c == r is
        // trivially symmetric.
        for c in (r + 1)..dim {
            let a = block[[r, c]];
            let b = block[[c, r]];
            if (a - b).abs() > 1.0e-10 * (1.0 + a.abs() + b.abs()) {
                return Err(format!(
                    "fisher_rao_w must be symmetric in every row block; row {row} entry \
                     ({r}, {c}) = {a} differs from ({c}, {r}) = {b}"
                ));
            }
        }
        let diag = block[[r, r]];
        if diag < 0.0 {
            return Err(format!(
                "fisher_rao_w diagonal entries must be non-negative; row {row} entry \
                 ({r}, {r}) = {diag}"
            ));
        }
    }
    Ok(())
}
fn validate_block_definiteness(
block: ndarray::ArrayView2<'_, f64>,
row: usize,
definiteness: FisherRaoDefiniteness,
) -> Result<(), String> {
if block.nrows() == 0 {
return Ok(());
}
let mut symmetric = Array2::<f64>::zeros((block.nrows(), block.ncols()));
for i in 0..block.nrows() {
for j in 0..block.ncols() {
symmetric[[i, j]] = 0.5 * (block[[i, j]] + block[[j, i]]);
}
}
let (eigenvalues, _) = symmetric.eigh(Side::Lower).map_err(|err| {
format!("fisher_rao_w row {row} eigendecomposition for definiteness check failed: {err}")
})?;
let spectral_scale = eigenvalues
.iter()
.fold(0.0_f64, |acc, &value| acc.max(value.abs()))
.max(1.0);
let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
let tol = 1.0e-10 * spectral_scale;
match definiteness {
FisherRaoDefiniteness::PositiveSemidefinite => {
if min_eigenvalue < -tol {
return Err(format!(
"fisher_rao_w row {row} must be positive semidefinite (a precision metric \
induces the squared residual rᵀ W r ≥ 0); smallest eigenvalue {min_eigenvalue} \
is negative"
));
}
}
FisherRaoDefiniteness::PositiveDefinite => {
if min_eigenvalue <= tol {
return Err(format!(
"fisher_rao_w row {row} must be positive definite for Cholesky whitening; \
smallest eigenvalue {min_eigenvalue} is not strictly positive"
));
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds a 2×2 block from a nested array literal.
    fn block_2x2(values: [[f64; 2]; 2]) -> Array2<f64> {
        Array2::from_shape_fn((2, 2), |(r, c)| values[r][c])
    }

    #[test]
    fn indefinite_block_symmetric_nonneg_diagonal_is_rejected_as_not_psd() {
        let block = block_2x2([[1.0, 2.0], [2.0, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 4, 2)
            .expect_err("indefinite block must be rejected by the PSD metric API");
        assert!(
            err.contains("positive semidefinite"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn psd_block_is_accepted_by_metric_api_and_broadcast() {
        let block = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        let n_rows = 3;
        let out = normalize_fisher_rao_blocks(block.view().into_dyn(), n_rows, 2)
            .expect("a genuinely PSD block must be accepted");
        assert_eq!(out.shape(), &[n_rows, 2, 2]);
        for row in 0..n_rows {
            assert_eq!(out[[row, 0, 0]], 2.0);
            assert_eq!(out[[row, 1, 0]], 1.0);
            assert_eq!(out[[row, 0, 1]], 1.0);
            assert_eq!(out[[row, 1, 1]], 2.0);
        }
    }

    #[test]
    fn pd_block_passes_the_cholesky_path() {
        let block = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        normalize_fisher_rao_blocks_pd(block.view().into_dyn(), 2, 2)
            .expect("a positive-definite block must pass the Cholesky (PD) path");
    }

    #[test]
    fn psd_singular_block_passes_metric_api_but_is_rejected_on_cholesky_path() {
        let block = block_2x2([[1.0, 1.0], [1.0, 1.0]]);
        normalize_fisher_rao_blocks(block.view().into_dyn(), 2, 2)
            .expect("a PSD-singular block must be accepted by the metric API");
        let err = normalize_fisher_rao_blocks_pd(block.view().into_dyn(), 2, 2)
            .expect_err("a singular block has no Cholesky factor and must be rejected");
        assert!(
            err.contains("positive definite"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn isotropic_scale_vector_remains_accepted() {
        let scales = ndarray::Array1::from(vec![0.5_f64, 2.0, 1.0]);
        let out = normalize_fisher_rao_blocks(scales.view().into_dyn(), 3, 2)
            .expect("non-negative isotropic scales are PSD");
        assert_eq!(out[[1, 0, 0]], 2.0);
        assert_eq!(out[[1, 1, 1]], 2.0);
        assert_eq!(out[[1, 0, 1]], 0.0);
    }

    #[test]
    fn per_row_indefinite_block_is_rejected_with_its_row_index() {
        let mut stack = ndarray::Array3::<f64>::zeros((2, 2, 2));
        for row in 0..2 {
            stack[[row, 0, 0]] = 2.0;
            stack[[row, 1, 1]] = 2.0;
        }
        stack[[1, 0, 1]] = 3.0;
        stack[[1, 1, 0]] = 3.0;
        let err = normalize_fisher_rao_blocks(stack.view().into_dyn(), 2, 2)
            .expect_err("the indefinite row block must be rejected");
        assert!(err.contains("row 1"), "unexpected error message: {err}");
    }

    #[test]
    fn non_square_dynamic_input_is_still_rejected_by_shape_check() {
        let block = Array2::<f64>::zeros((3, 2));
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 4, 2)
            .expect_err("a (3, 2) matrix is not a valid (2, 2) shared block");
        assert!(err.contains("shape"), "unexpected error message: {err}");
    }

    #[test]
    fn non_finite_entries_are_rejected() {
        let block = block_2x2([[1.0, f64::NAN], [f64::NAN, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 2, 2)
            .expect_err("NaN entries must be rejected");
        assert!(err.contains("finite"), "unexpected error message: {err}");
    }

    #[test]
    fn asymmetric_block_is_rejected() {
        let block = block_2x2([[1.0, 0.5], [0.0, 1.0]]);
        let err = normalize_fisher_rao_blocks(block.view().into_dyn(), 2, 2)
            .expect_err("an asymmetric block must be rejected");
        assert!(err.contains("symmetric"), "unexpected error message: {err}");
    }

    #[test]
    fn negative_isotropic_scale_is_rejected() {
        let scales = ndarray::Array1::from(vec![1.0_f64, -0.5]);
        let err = normalize_fisher_rao_blocks(scales.view().into_dyn(), 2, 2)
            .expect_err("a negative isotropic scale is not PSD");
        assert!(err.contains("non-negative"), "unexpected error message: {err}");
    }

    #[test]
    fn zero_isotropic_scale_is_psd_but_rejected_on_cholesky_path() {
        let scales = ndarray::Array1::from(vec![1.0_f64, 0.0]);
        normalize_fisher_rao_blocks(scales.view().into_dyn(), 2, 2)
            .expect("a zero isotropic scale is PSD");
        let err = normalize_fisher_rao_blocks_pd(scales.view().into_dyn(), 2, 2)
            .expect_err("a zero isotropic scale has no Cholesky factor");
        assert!(err.contains("row 1"), "unexpected error message: {err}");
    }

    #[test]
    fn scalar_input_is_rejected_as_unsupported_rank() {
        let scalar = ndarray::arr0(1.0_f64);
        let err = normalize_fisher_rao_blocks(scalar.view().into_dyn(), 2, 2)
            .expect_err("a 0-D input has no supported interpretation");
        assert!(
            err.contains("1-D, 2-D, or 3-D"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn wrong_vector_length_is_rejected() {
        let scales = ndarray::Array1::from(vec![1.0_f64, 2.0]);
        let err = normalize_fisher_rao_blocks(scales.view().into_dyn(), 3, 2)
            .expect_err("a length-2 vector cannot cover 3 rows");
        assert!(err.contains("length 3"), "unexpected error message: {err}");
    }

    #[test]
    fn per_row_stack_is_copied_verbatim() {
        let mut stack = ndarray::Array3::<f64>::zeros((2, 2, 2));
        stack[[0, 0, 0]] = 1.0;
        stack[[0, 1, 1]] = 1.0;
        stack[[1, 0, 0]] = 3.0;
        stack[[1, 1, 1]] = 4.0;
        stack[[1, 0, 1]] = 1.0;
        stack[[1, 1, 0]] = 1.0;
        let out = normalize_fisher_rao_blocks(stack.view().into_dyn(), 2, 2)
            .expect("both row blocks are positive definite");
        assert_eq!(out, stack);
    }
}