use ndarray::{Array3, ArrayViewD, IxDyn};
/// Checks one `dim x dim` block, read through `at(r, c)`, for symmetry and a non-negative
/// diagonal.
///
/// Symmetry uses the relative tolerance `|a - b| <= 1e-10 * (1 + |a| + |b|)`. Because that
/// test is symmetric in `a` and `b`, only the strict upper triangle is visited. Each row's
/// off-diagonal pairs are checked before that row's diagonal entry.
///
/// # Errors
/// Returns the symmetry or negative-diagonal error message on the first violation found.
fn check_fisher_rao_block(dim: usize, at: impl Fn(usize, usize) -> f64) -> Result<(), String> {
    for r in 0..dim {
        for c in (r + 1)..dim {
            let (a, b) = (at(r, c), at(c, r));
            if (a - b).abs() > 1.0e-10 * (1.0 + a.abs() + b.abs()) {
                return Err("fisher_rao_w must be symmetric in every row block".to_string());
            }
        }
        if at(r, r) < 0.0 {
            return Err("fisher_rao_w diagonal entries must be non-negative".to_string());
        }
    }
    Ok(())
}

/// Validates a `fisher_rao_w` array and expands it to one `dim x dim` block per row.
///
/// Accepted layouts:
/// - 1-D, length `n_rows`: entry `i` is placed on the diagonal of block `i`, and the
///   off-diagonal entries are zero.
/// - 2-D, shape `(dim, dim)`: the matrix is copied into every one of the `n_rows` blocks.
/// - 3-D, shape `(n_rows, dim, dim)`: copied block-for-block.
///
/// The returned array always has shape `(n_rows, dim, dim)`.
///
/// # Errors
/// Returns `Err` with a descriptive message when any of the following holds:
/// - any value is non-finite;
/// - the rank is not 1, 2, or 3;
/// - the shape does not match `n_rows` / `dim`;
/// - a block is not symmetric (relative tolerance `1e-10`);
/// - a diagonal entry is negative.
///
/// The supplied weights themselves are validated. A bad shared 2-D matrix or a negative
/// 1-D weight is therefore rejected even when `n_rows == 0` or `dim == 0`, where no output
/// block would otherwise expose it.
pub fn normalize_fisher_rao_blocks(
    arr: ArrayViewD<'_, f64>,
    n_rows: usize,
    dim: usize,
) -> Result<Array3<f64>, String> {
    if !arr.iter().all(|v| v.is_finite()) {
        return Err("fisher_rao_w must contain only finite values".to_string());
    }
    match *arr.shape() {
        [len] => {
            if len != n_rows {
                return Err(format!(
                    "fisher_rao_w vector must have length {n_rows}; got {len}"
                ));
            }
            // Each block is `w * I`: off-diagonals are zero, so only the sign of `w` can be
            // invalid. Check it on the input so it is caught even when `dim == 0`.
            if arr.iter().any(|&w| w < 0.0) {
                return Err("fisher_rao_w diagonal entries must be non-negative".to_string());
            }
            let mut out = Array3::<f64>::zeros((n_rows, dim, dim));
            for (mut block, &w) in out.outer_iter_mut().zip(arr.iter()) {
                for v in block.diag_mut() {
                    *v = w;
                }
            }
            Ok(out)
        }
        [rows, cols] => {
            if rows != dim || cols != dim {
                return Err(format!(
                    "fisher_rao_w matrix must have shape ({dim}, {dim}); got ({rows}, {cols})"
                ));
            }
            // Validate the single shared matrix once, before copying. Checking each copy would
            // repeat identical work and would skip validation entirely when `n_rows == 0`.
            check_fisher_rao_block(dim, |r, c| arr[IxDyn(&[r, c])])?;
            let mut out = Array3::<f64>::zeros((n_rows, dim, dim));
            // `assign` broadcasts the `(dim, dim)` matrix across the leading row axis.
            out.assign(&arr);
            Ok(out)
        }
        [rows, r_dim, c_dim] => {
            if rows != n_rows || r_dim != dim || c_dim != dim {
                return Err(format!(
                    "fisher_rao_w must have shape ({n_rows}, {dim}, {dim}); got ({rows}, {r_dim}, {c_dim})"
                ));
            }
            let mut out = Array3::<f64>::zeros((n_rows, dim, dim));
            // Shapes match exactly here, so `assign` is a plain element-wise copy.
            out.assign(&arr);
            for block in out.outer_iter() {
                check_fisher_rao_block(dim, |r, c| block[[r, c]])?;
            }
            Ok(out)
        }
        _ => Err("fisher_rao_w must be a 1-D, 2-D, or 3-D numeric array".to_string()),
    }
}