Skip to main content

gam_problem/
fisher_rao.rs

//! Fisher–Rao precision-weight normalization for response-geometry REML.
//!
//! A caller may supply the per-observation Fisher–Rao precision metric that
//! couples the tangent-coordinate residuals as a 1-D vector (a shared isotropic
//! scale per row), a single 2-D `(dim, dim)` matrix (shared across all rows), or
//! a full 3-D `(n_rows, dim, dim)` stack. This module broadcasts any of those to
//! the canonical per-row block form and validates finiteness, per-row symmetry,
//! and positive semidefiniteness. A precision metric induces the squared
//! residual `rᵀ W r`, which must be non-negative for every residual `r`; that
//! holds iff each block is PSD (all eigenvalues `≥ 0`). Symmetry plus a
//! non-negative diagonal is **not** sufficient — e.g. `[[1, 2], [2, 1]]` is
//! symmetric with a non-negative diagonal, yet `z = (1, −1)` gives
//! `zᵀ W z = −2 < 0`. The Cholesky-whitening REML consumer needs the stronger
//! positive-definite condition (all eigenvalues `> 0`), exposed via
//! [`normalize_fisher_rao_blocks_pd`]. This module is the single source of
//! truth shared by the `response_geometry_normalize_fisher_rao` FFI shim and
//! any core consumer.
17
18use faer::Side;
19use gam_linalg::faer_ndarray::FaerEigh;
20use ndarray::{Array2, Array3, ArrayViewD, IxDyn};
21
/// Definiteness that every per-row Fisher–Rao precision block must satisfy.
///
/// The two variants differ only in how a zero eigenvalue is treated: it is
/// allowed for metric use and forbidden for Cholesky whitening.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FisherRaoDefiniteness {
    /// Metric / semi-metric use: each block must satisfy `rᵀ W r ≥ 0` for every
    /// `r`. Singular (rank-deficient) blocks pass; only indefinite ones fail.
    PositiveSemidefinite,
    /// Cholesky-whitening use: factoring `W = L Lᵀ` needs strictly positive
    /// eigenvalues, so a singular block is rejected as well.
    PositiveDefinite,
}
32
33/// Broadcast and validate a Fisher–Rao weight array into `(n_rows, dim, dim)`
34/// **positive-semidefinite** precision blocks (the general metric API). Accepts
35/// a 1-D `(n_rows,)` isotropic scale, a 2-D `(dim, dim)` shared matrix, or a 3-D
36/// `(n_rows, dim, dim)` stack. Rank-deficient (PSD-singular) blocks are
37/// accepted; indefinite blocks are rejected. Use
38/// [`normalize_fisher_rao_blocks_pd`] for the Cholesky-whitening path.
39pub fn normalize_fisher_rao_blocks(
40    arr: ArrayViewD<'_, f64>,
41    n_rows: usize,
42    dim: usize,
43) -> Result<Array3<f64>, String> {
44    normalize_fisher_rao_blocks_with(
45        arr,
46        n_rows,
47        dim,
48        FisherRaoDefiniteness::PositiveSemidefinite,
49    )
50}
51
52/// Broadcast and validate Fisher–Rao weight blocks requiring each block to be
53/// **positive-definite**, as the Cholesky-whitening REML path needs (a singular
54/// block has no `L Lᵀ = W` factor). Same broadcasting rules as
55/// [`normalize_fisher_rao_blocks`]; the difference is the per-block spectrum
56/// must be strictly positive rather than merely non-negative.
57pub fn normalize_fisher_rao_blocks_pd(
58    arr: ArrayViewD<'_, f64>,
59    n_rows: usize,
60    dim: usize,
61) -> Result<Array3<f64>, String> {
62    normalize_fisher_rao_blocks_with(arr, n_rows, dim, FisherRaoDefiniteness::PositiveDefinite)
63}
64
65fn normalize_fisher_rao_blocks_with(
66    arr: ArrayViewD<'_, f64>,
67    n_rows: usize,
68    dim: usize,
69    definiteness: FisherRaoDefiniteness,
70) -> Result<Array3<f64>, String> {
71    if !arr.iter().all(|v| v.is_finite()) {
72        return Err("fisher_rao_w must contain only finite values".to_string());
73    }
74    let shape = arr.shape().to_vec();
75    let out: Array3<f64> = match arr.ndim() {
76        1 => {
77            if shape[0] != n_rows {
78                return Err(format!(
79                    "fisher_rao_w vector must have length {n_rows}; got {}",
80                    shape[0]
81                ));
82            }
83            let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
84            for row in 0..n_rows {
85                let value = arr[IxDyn(&[row])];
86                for d in 0..dim {
87                    block[[row, d, d]] = value;
88                }
89            }
90            block
91        }
92        2 => {
93            if shape[0] != dim || shape[1] != dim {
94                return Err(format!(
95                    "fisher_rao_w matrix must have shape ({dim}, {dim}); got ({}, {})",
96                    shape[0], shape[1]
97                ));
98            }
99            let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
100            for row in 0..n_rows {
101                for r in 0..dim {
102                    for c in 0..dim {
103                        block[[row, r, c]] = arr[IxDyn(&[r, c])];
104                    }
105                }
106            }
107            block
108        }
109        3 => {
110            if shape[0] != n_rows || shape[1] != dim || shape[2] != dim {
111                return Err(format!(
112                    "fisher_rao_w must have shape ({n_rows}, {dim}, {dim}); got ({}, {}, {})",
113                    shape[0], shape[1], shape[2]
114                ));
115            }
116            let mut block = Array3::<f64>::zeros((n_rows, dim, dim));
117            for row in 0..n_rows {
118                for r in 0..dim {
119                    for c in 0..dim {
120                        block[[row, r, c]] = arr[IxDyn(&[row, r, c])];
121                    }
122                }
123            }
124            block
125        }
126        _ => return Err("fisher_rao_w must be a 1-D, 2-D, or 3-D numeric array".to_string()),
127    };
128    for row in 0..n_rows {
129        for r in 0..dim {
130            for c in 0..dim {
131                let a = out[[row, r, c]];
132                let b = out[[row, c, r]];
133                if (a - b).abs() > 1.0e-10 * (1.0 + a.abs() + b.abs()) {
134                    return Err("fisher_rao_w must be symmetric in every row block".to_string());
135                }
136            }
137            if out[[row, r, r]] < 0.0 {
138                return Err("fisher_rao_w diagonal entries must be non-negative".to_string());
139            }
140        }
141        validate_block_definiteness(out.index_axis(ndarray::Axis(0), row), row, definiteness)?;
142    }
143    Ok(out)
144}
145
146/// Validate that a single symmetric `(dim, dim)` precision block has the
147/// required definiteness by checking its eigenvalue spectrum. A precision
148/// metric must be PSD so that the induced squared residual `rᵀ W r` is never
149/// negative; the Cholesky-whitening path additionally needs PD. The threshold
150/// is relative to the block's spectral scale (its largest eigenvalue magnitude)
151/// so that the check is invariant to the units of the metric.
152fn validate_block_definiteness(
153    block: ndarray::ArrayView2<'_, f64>,
154    row: usize,
155    definiteness: FisherRaoDefiniteness,
156) -> Result<(), String> {
157    if block.nrows() == 0 {
158        return Ok(());
159    }
160    // Symmetrize before the eigensolve so the spectrum is exactly real; the
161    // off-diagonals already match to within the symmetry tolerance checked above.
162    let mut symmetric = Array2::<f64>::zeros((block.nrows(), block.ncols()));
163    for i in 0..block.nrows() {
164        for j in 0..block.ncols() {
165            symmetric[[i, j]] = 0.5 * (block[[i, j]] + block[[j, i]]);
166        }
167    }
168    let (eigenvalues, _) = symmetric.eigh(Side::Lower).map_err(|err| {
169        format!("fisher_rao_w row {row} eigendecomposition for definiteness check failed: {err}")
170    })?;
171    let spectral_scale = eigenvalues
172        .iter()
173        .fold(0.0_f64, |acc, &value| acc.max(value.abs()))
174        .max(1.0);
175    let min_eigenvalue = eigenvalues.iter().copied().fold(f64::INFINITY, f64::min);
176    // Relative spectral tolerance: a block is treated as PSD when its smallest
177    // eigenvalue is no more negative than this fraction of its spectral scale,
178    // absorbing the rounding of the symmetric eigensolve.
179    let tol = 1.0e-10 * spectral_scale;
180    match definiteness {
181        FisherRaoDefiniteness::PositiveSemidefinite => {
182            if min_eigenvalue < -tol {
183                return Err(format!(
184                    "fisher_rao_w row {row} must be positive semidefinite (a precision metric \
185                     induces the squared residual rᵀ W r ≥ 0); smallest eigenvalue {min_eigenvalue} \
186                     is negative"
187                ));
188            }
189        }
190        FisherRaoDefiniteness::PositiveDefinite => {
191            if min_eigenvalue <= tol {
192                return Err(format!(
193                    "fisher_rao_w row {row} must be positive definite for Cholesky whitening; \
194                     smallest eigenvalue {min_eigenvalue} is not strictly positive"
195                ));
196            }
197        }
198    }
199    Ok(())
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Build a 2×2 matrix from row-major nested arrays.
    fn block_2x2(values: [[f64; 2]; 2]) -> Array2<f64> {
        Array2::from_shape_fn((2, 2), |(r, c)| values[r][c])
    }

    #[test]
    fn indefinite_block_symmetric_nonneg_diagonal_is_rejected_as_not_psd() {
        // Symmetric with a non-negative diagonal, yet its eigenvalues are
        // {3, -1}. With z = (1, -1), zᵀ W z = -2 < 0, so this is not a valid
        // precision metric, although symmetry + diagonal checks alone pass it.
        let w = block_2x2([[1.0, 2.0], [2.0, 1.0]]);
        let message = normalize_fisher_rao_blocks(w.view().into_dyn(), 4, 2)
            .expect_err("indefinite block must be rejected by the PSD metric API");
        assert!(
            message.contains("positive semidefinite"),
            "unexpected error message: {message}"
        );
    }

    #[test]
    fn psd_block_is_accepted_by_metric_api_and_broadcast() {
        // Eigenvalues {1, 3}: PSD (in fact PD), so every row receives a copy.
        let w = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        let n_rows = 3;
        let blocks = normalize_fisher_rao_blocks(w.view().into_dyn(), n_rows, 2)
            .expect("a genuinely PSD block must be accepted");
        assert_eq!(blocks.shape(), &[n_rows, 2, 2]);
        for per_row in blocks.outer_iter() {
            assert_eq!(per_row, w.view());
        }
    }

    #[test]
    fn pd_block_passes_the_cholesky_path() {
        // Both eigenvalues {1, 3} are strictly positive: fine for whitening.
        let w = block_2x2([[2.0, 1.0], [1.0, 2.0]]);
        normalize_fisher_rao_blocks_pd(w.view().into_dyn(), 2, 2)
            .expect("a positive-definite block must pass the Cholesky (PD) path");
    }

    #[test]
    fn psd_singular_block_passes_metric_api_but_is_rejected_on_cholesky_path() {
        // Eigenvalues {0, 2}: PSD but singular. The metric API accepts it
        // (rᵀ W r ≥ 0); the Cholesky path demands strict PD and rejects it.
        let w = block_2x2([[1.0, 1.0], [1.0, 1.0]]);
        let input = w.view().into_dyn();
        normalize_fisher_rao_blocks(input.clone(), 2, 2)
            .expect("a PSD-singular block must be accepted by the metric API");
        let message = normalize_fisher_rao_blocks_pd(input, 2, 2)
            .expect_err("a singular block has no Cholesky factor and must be rejected");
        assert!(
            message.contains("positive definite"),
            "unexpected error message: {message}"
        );
    }

    #[test]
    fn isotropic_scale_vector_remains_accepted() {
        // A 1-D scale becomes value * I per row, which is PSD for value ≥ 0.
        let scales = ndarray::Array1::from(vec![0.5_f64, 2.0, 1.0]);
        let blocks = normalize_fisher_rao_blocks(scales.view().into_dyn(), 3, 2)
            .expect("non-negative isotropic scales are PSD");
        assert_eq!(blocks[[1, 0, 0]], 2.0);
        assert_eq!(blocks[[1, 1, 1]], 2.0);
        assert_eq!(blocks[[1, 0, 1]], 0.0);
    }

    #[test]
    fn per_row_indefinite_block_is_rejected_with_its_row_index() {
        // Both rows start as 2·I. Row 1 additionally gets off-diagonals of 3,
        // which gives eigenvalues {-1, 5}. The error must name that row,
        // confirming the check runs per row block.
        let stack = ndarray::Array3::<f64>::from_shape_fn((2, 2, 2), |(row, r, c)| {
            if r == c {
                2.0
            } else if row == 1 {
                3.0
            } else {
                0.0
            }
        });
        let message = normalize_fisher_rao_blocks(stack.view().into_dyn(), 2, 2)
            .expect_err("the indefinite row block must be rejected");
        assert!(message.contains("row 1"), "unexpected error message: {message}");
    }

    #[test]
    fn non_square_dynamic_input_is_still_rejected_by_shape_check() {
        // The spectral check must not bypass shape validation for a malformed
        // 2-D matrix.
        let malformed = Array2::<f64>::zeros((3, 2));
        let message = normalize_fisher_rao_blocks(malformed.view().into_dyn(), 4, 2)
            .expect_err("a (3, 2) matrix is not a valid (2, 2) shared block");
        assert!(message.contains("shape"), "unexpected error message: {message}");
    }
}