use crate::estimate::EstimationError;
use ndarray::{Array1, Array2, ArrayView2};
/// Observation-noise model for an `N × M` vector response.
///
/// The variants do NOT share units. `Isotropic` and `Diagonal` hold σ values that
/// [`VectorNoise::diag_precision`] converts to precision `1/σ²` (i.e. σ acts as a
/// standard deviation). `LowRank` holds precision terms directly: `diag` is returned
/// unchanged as the diagonal precision, and `factor` (F) adds `F Fᵀ` to the per-row
/// quadratic form evaluated by `GaussianVectorLikelihood`.
#[derive(Clone, Debug)]
pub enum VectorNoise {
/// One σ shared by every response column; must be finite and > 0.
Isotropic(f64),
/// One σ per response column; length must equal `M`, each entry finite and > 0.
Diagonal(Array1<f64>),
/// Precision of the form `diag(diag) + factor · factorᵀ`.
LowRank {
/// Diagonal precision entries (not σ); length `M`, each finite and > 0.
diag: Array1<f64>,
/// Low-rank precision factor with `M` rows (any number of columns); entries must be finite.
factor: Array2<f64>,
},
}
impl VectorNoise {
    /// Returns the diagonal of the noise precision for `m` response components.
    ///
    /// * `Isotropic(σ)` → `m` copies of `1/σ²`.
    /// * `Diagonal(σ)` → `1/σ_j²` for every component `j`.
    /// * `LowRank { diag, .. }` → `diag` itself, because it already stores precision.
    ///   The low-rank factor is NOT folded in here.
    ///
    /// # Errors
    /// Returns `EstimationError::InvalidInput` in two cases:
    /// * a length differs from `m`;
    /// * an entry of σ / `diag` is non-finite or ≤ 0. Entries are checked in index
    ///   order and the first offending one is reported.
    pub fn diag_precision(&self, m: usize) -> Result<Array1<f64>, EstimationError> {
        match self {
            Self::Isotropic(sigma) => {
                let sigma = *sigma;
                if !sigma.is_finite() || sigma <= 0.0 {
                    return Err(EstimationError::InvalidInput(format!(
                        "VectorNoise::Isotropic: σ must be > 0 and finite (got {sigma})",
                    )));
                }
                Ok(Array1::from_elem(m, 1.0 / (sigma * sigma)))
            }
            Self::Diagonal(sigma) => {
                if sigma.len() != m {
                    return Err(EstimationError::InvalidInput(format!(
                        "VectorNoise::Diagonal: σ length {} ≠ M={m}",
                        sigma.len()
                    )));
                }
                // Collecting into `Result` stops at the first invalid σ.
                sigma
                    .iter()
                    .copied()
                    .enumerate()
                    .map(|(j, s)| {
                        if s.is_finite() && s > 0.0 {
                            Ok(1.0 / (s * s))
                        } else {
                            Err(EstimationError::InvalidInput(format!(
                                "VectorNoise::Diagonal: σ[{j}] must be > 0 and finite (got {s})",
                            )))
                        }
                    })
                    .collect()
            }
            Self::LowRank { diag, .. } => {
                if diag.len() != m {
                    return Err(EstimationError::InvalidInput(format!(
                        "VectorNoise::LowRank: diag length {} ≠ M={m}",
                        diag.len()
                    )));
                }
                // `diag` is already a precision, so valid entries pass through unchanged.
                // The message names the finiteness requirement because `inf` and NaN are
                // rejected as well.
                diag.iter()
                    .copied()
                    .enumerate()
                    .map(|(j, d)| {
                        if d.is_finite() && d > 0.0 {
                            Ok(d)
                        } else {
                            Err(EstimationError::InvalidInput(format!(
                                "VectorNoise::LowRank: diag[{j}] must be > 0 and finite (got {d})",
                            )))
                        }
                    })
                    .collect()
            }
        }
    }
}
/// A vector-valued response together with its noise model and optional row weights.
///
/// `y` is indexed `[row, column]`: `N = y.nrows()` observations of `M = y.ncols()`
/// components. The fields are public, so `row_weights` can be assigned without going
/// through [`VectorResponseTarget::with_row_weights`]. In that case nothing validates
/// them until `GaussianVectorLikelihood::from_target`, which re-checks them.
#[derive(Clone, Debug)]
pub struct VectorResponseTarget {
/// Observed responses, `N × M`.
pub y: Array2<f64>,
/// Noise model used to derive the per-component precision.
pub noise: VectorNoise,
/// Optional per-row weights (length `N`, finite, ≥ 0); `None` means every row has weight 1.
pub row_weights: Option<Array1<f64>>,
}
impl VectorResponseTarget {
    /// Creates an unweighted target from an `N × M` response matrix and its noise model.
    pub fn new(y: Array2<f64>, noise: VectorNoise) -> Self {
        let row_weights = None;
        Self {
            y,
            noise,
            row_weights,
        }
    }

    /// Attaches per-row weights after checking them against the number of rows `N`.
    ///
    /// # Errors
    /// Returns `EstimationError::InvalidInput` in two cases:
    /// * `w.len() != N`;
    /// * any weight is negative, NaN, or infinite.
    ///
    /// On error the target is dropped.
    pub fn with_row_weights(self, w: Array1<f64>) -> Result<Self, EstimationError> {
        validate_row_weights(&w, self.n())?;
        Ok(Self {
            row_weights: Some(w),
            ..self
        })
    }

    /// Number of observations `N` (rows of `y`).
    pub fn n(&self) -> usize {
        let (rows, _) = self.y.dim();
        rows
    }

    /// Number of response components `M` (columns of `y`).
    pub fn m(&self) -> usize {
        let (_, cols) = self.y.dim();
        cols
    }
}
/// Checks that `weights` has exactly `n` entries, each finite and non-negative.
///
/// # Errors
/// Returns `EstimationError::InvalidInput` in two cases:
/// * the length differs from `n`;
/// * an entry is negative, NaN, or infinite. Only the first such entry (in index
///   order) is reported.
fn validate_row_weights(weights: &Array1<f64>, n: usize) -> Result<(), EstimationError> {
    let len = weights.len();
    if len != n {
        return Err(EstimationError::InvalidInput(format!(
            "row_weights length {len} ≠ N={n}"
        )));
    }
    // `!is_finite()` already covers NaN, so the remaining case is a negative weight.
    let first_bad = weights
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, weight)| !weight.is_finite() || weight < 0.0);
    match first_bad {
        None => Ok(()),
        Some((idx, weight)) => Err(EstimationError::InvalidInput(format!(
            "row_weights[{idx}] must be finite and non-negative (got {weight})"
        ))),
    }
}
/// Likelihood of a response matrix `y` given a matrix of linear predictors `eta`.
///
/// The implementation in this file (`GaussianVectorLikelihood`) asserts that `eta` and
/// `y` have the same shape. It returns per-entry matrices of that shape from
/// `grad_eta` and `hess_diag`.
///
/// NOTE(review): in `GaussianVectorLikelihood`, `hess_diag` returns the diagonal of the
/// *negative* Hessian of `log_lik`, i.e. `w · (diag(precision) + F Fᵀ)`. Confirm that any
/// other implementors follow the same sign convention.
pub trait VectorLikelihood {
/// Log-likelihood of `y` given `eta`, as a single scalar for the whole matrix.
fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64;
/// Partial derivatives of `log_lik` with respect to each entry of `eta`.
fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
/// Per-entry curvature of `log_lik` with respect to `eta` (see the sign note on the trait).
fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64>;
}
/// Gaussian likelihood whose per-row precision is `diag(precision) + factor · factorᵀ`.
///
/// Build it with [`GaussianVectorLikelihood::from_target`], which validates its inputs.
/// The fields are public, so a value built directly is NOT validated.
#[derive(Clone, Debug)]
pub struct GaussianVectorLikelihood {
/// Diagonal precision, one entry per response column (length `M`).
pub precision: Array1<f64>,
/// Optional low-rank precision factor with `M` rows; `from_target` sets it only for `VectorNoise::LowRank`.
pub factor: Option<Array2<f64>>,
/// Optional per-row weights scaling each row's contribution; `None` means weight 1.
pub row_weights: Option<Array1<f64>>,
}
impl GaussianVectorLikelihood {
pub fn from_target(target: &VectorResponseTarget) -> Result<Self, EstimationError> {
if let Some(weights) = target.row_weights.as_ref() {
validate_row_weights(weights, target.n())?;
}
let precision = target.noise.diag_precision(target.m())?;
let factor = match &target.noise {
VectorNoise::LowRank { factor, .. } => {
if factor.nrows() != target.m() {
return Err(EstimationError::InvalidInput(format!(
"VectorNoise::LowRank: factor has {} rows but M={}",
factor.nrows(),
target.m()
)));
}
for ((row, col), value) in factor.indexed_iter() {
if !value.is_finite() {
return Err(EstimationError::InvalidInput(format!(
"VectorNoise::LowRank: factor[{row},{col}] must be finite (got {value})"
)));
}
}
Some(factor.clone())
}
_ => None,
};
Ok(Self {
precision,
factor,
row_weights: target.row_weights.clone(),
})
}
#[inline]
fn row_weight(&self, n: usize) -> f64 {
self.row_weights.as_ref().map_or(1.0, |w| w[n])
}
}
impl VectorLikelihood for GaussianVectorLikelihood {
    /// Weighted Gaussian log-likelihood, up to an additive constant.
    ///
    /// Returns `-0.5 · Σ_n w_n · r_nᵀ (diag(precision) + F Fᵀ) r_n`, where:
    /// * `r_n = y[n, ..] − eta[n, ..]`;
    /// * `w_n` is the row weight (1 when absent);
    /// * `F` is the optional low-rank factor.
    ///
    /// Log-determinant and `2π` normalising terms are not included.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape, or their column count is not `precision.len()`.
    fn log_lik(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> f64 {
        assert_gaussian_dims(&self.precision, eta, y);
        let m = eta.ncols();
        let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
        let mut ftr = vec![0.0f64; rank];
        let mut acc = 0.0;
        for n in 0..eta.nrows() {
            let w = self.row_weight(n);
            let mut row_acc = 0.0;
            for j in 0..m {
                let r = y[[n, j]] - eta[[n, j]];
                row_acc += self.precision[j] * r * r;
            }
            if let Some(f) = self.factor.as_ref() {
                // rᵀ F Fᵀ r = ‖Fᵀ r‖²
                project_residual_onto_factor(f, eta, y, n, &mut ftr);
                for v in &ftr {
                    row_acc += v * v;
                }
            }
            acc += w * row_acc;
        }
        -0.5 * acc
    }

    /// Gradient of [`log_lik`](VectorLikelihood::log_lik) with respect to `eta`.
    ///
    /// Row `n` is `w_n · (diag(precision) + F Fᵀ) r_n`, with `r_n = y[n, ..] − eta[n, ..]`.
    /// The sign is positive because `∂r/∂eta = −I`.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape, or their column count is not `precision.len()`.
    fn grad_eta(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
        assert_gaussian_dims(&self.precision, eta, y);
        let (n_rows, n_cols) = eta.dim();
        let rank = self.factor.as_ref().map_or(0, |f| f.ncols());
        let mut out = Array2::<f64>::zeros((n_rows, n_cols));
        let mut ftr = vec![0.0f64; rank];
        for n in 0..n_rows {
            let w = self.row_weight(n);
            for j in 0..n_cols {
                out[[n, j]] = w * self.precision[j] * (y[[n, j]] - eta[[n, j]]);
            }
            if let Some(f) = self.factor.as_ref() {
                project_residual_onto_factor(f, eta, y, n, &mut ftr);
                // Add the low-rank term F (Fᵀ r), component by component.
                for j in 0..n_cols {
                    let mut s = 0.0;
                    for (k, t) in ftr.iter().enumerate() {
                        s += f[[j, k]] * t;
                    }
                    out[[n, j]] += w * s;
                }
            }
        }
        out
    }

    /// Diagonal of the NEGATIVE Hessian of `log_lik` with respect to `eta`.
    ///
    /// Entry `[n, j]` is `w_n · (precision[j] + Σ_k F[j, k]²)`. It does not depend on the
    /// values of `eta` or `y`, only on their shape.
    ///
    /// # Panics
    /// If `eta` and `y` differ in shape, or their column count is not `precision.len()`.
    fn hess_diag(&self, eta: ArrayView2<f64>, y: ArrayView2<f64>) -> Array2<f64> {
        assert_gaussian_dims(&self.precision, eta, y);
        let (n_rows, n_cols) = eta.dim();
        // Per-column diagonal of diag(precision) + F Fᵀ, shared by every row.
        let mut curvature = self.precision.clone();
        if let Some(f) = self.factor.as_ref() {
            for (j, c) in curvature.iter_mut().enumerate() {
                let mut sq = 0.0;
                for k in 0..f.ncols() {
                    let v = f[[j, k]];
                    sq += v * v;
                }
                *c += sq;
            }
        }
        Array2::from_shape_fn((n_rows, n_cols), |(n, j)| self.row_weight(n) * curvature[j])
    }
}

/// Panics unless `eta` and `y` share a shape whose column count equals `precision.len()`.
///
/// Without this check, a narrower `eta` would silently be evaluated against a truncated
/// precision/factor.
fn assert_gaussian_dims(precision: &Array1<f64>, eta: ArrayView2<f64>, y: ArrayView2<f64>) {
    assert_eq!(eta.dim(), y.dim());
    assert_eq!(eta.ncols(), precision.len());
}

/// Overwrites `ftr` with `Fᵀ r_n`, where `r_n = y[n, ..] − eta[n, ..]`.
///
/// `ftr.len()` must not exceed the number of columns of `factor`.
fn project_residual_onto_factor(
    factor: &Array2<f64>,
    eta: ArrayView2<f64>,
    y: ArrayView2<f64>,
    n: usize,
    ftr: &mut [f64],
) {
    ftr.fill(0.0);
    for j in 0..eta.ncols() {
        let r = y[[n, j]] - eta[[n, j]];
        for (k, acc) in ftr.iter_mut().enumerate() {
            *acc += factor[[j, k]] * r;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, array};

    /// Asserts that `$result` is `Err(EstimationError::InvalidInput(msg))` with `msg`
    /// containing `$needle`, and evaluates to `msg`.
    macro_rules! expect_invalid_input {
        ($result:expr, $needle:expr $(,)?) => {{
            let needle: &str = $needle;
            match $result {
                Ok(_) => {
                    panic!("expected EstimationError::InvalidInput containing `{needle}`, got Ok")
                }
                Err(EstimationError::InvalidInput(msg)) => {
                    assert!(
                        msg.contains(needle),
                        "InvalidInput message `{msg}` does not contain `{needle}`"
                    );
                    msg
                }
                Err(other) => panic!(
                    "expected EstimationError::InvalidInput containing `{needle}`, got {other:?}"
                ),
            }
        }};
    }

    /// `n × m` all-zero response with unit isotropic noise and no weights.
    fn dummy_target(n: usize, m: usize) -> VectorResponseTarget {
        VectorResponseTarget::new(Array2::<f64>::zeros((n, m)), VectorNoise::Isotropic(1.0))
    }

    #[test]
    fn with_row_weights_rejects_wrong_length() {
        let target = dummy_target(4, 2);
        let weights = Array1::from(vec![1.0, 1.0, 1.0]);
        expect_invalid_input!(target.with_row_weights(weights), "row_weights length");
    }

    #[test]
    fn with_row_weights_rejects_negative_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, -0.5, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    #[test]
    fn with_row_weights_rejects_nan_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, f64::NAN, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    #[test]
    fn with_row_weights_rejects_infinite_entry() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![1.0, f64::INFINITY, 2.0]);
        expect_invalid_input!(
            target.with_row_weights(weights),
            "must be finite and non-negative",
        );
    }

    #[test]
    fn with_row_weights_accepts_zero_and_positive() {
        let target = dummy_target(3, 2);
        let weights = Array1::from(vec![0.0, 1.5, 3.0]);
        let weighted = target
            .with_row_weights(weights)
            .expect("zero / positive weights should be accepted");
        assert!(weighted.row_weights.is_some());
    }

    #[test]
    fn diag_precision_inverts_squared_standard_deviations() {
        let iso = VectorNoise::Isotropic(2.0)
            .diag_precision(3)
            .expect("positive σ should be accepted");
        assert_eq!(iso, array![0.25, 0.25, 0.25]);
        let diag = VectorNoise::Diagonal(array![0.5, 4.0])
            .diag_precision(2)
            .expect("positive σ vector should be accepted");
        assert_eq!(diag, array![4.0, 0.0625]);
    }

    #[test]
    fn diag_precision_rejects_bad_lengths_and_entries() {
        expect_invalid_input!(
            VectorNoise::Diagonal(array![1.0]).diag_precision(2),
            "σ length 1 ≠ M=2",
        );
        let noise = VectorNoise::LowRank {
            diag: array![1.0, 0.0],
            factor: Array2::<f64>::zeros((2, 1)),
        };
        expect_invalid_input!(noise.diag_precision(2), "diag[1] must be > 0");
    }

    #[test]
    fn from_target_rejects_low_rank_factor_with_wrong_row_count() {
        let n = 4;
        let m = 3;
        let factor = Array2::from_shape_vec((2, 2), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let target = VectorResponseTarget::new(
            Array2::<f64>::zeros((n, m)),
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor,
            },
        );
        expect_invalid_input!(GaussianVectorLikelihood::from_target(&target), "factor has",);
    }

    #[test]
    fn from_target_rejects_non_finite_low_rank_factor_entry() {
        let n = 4;
        let m = 3;
        let mut factor = Array2::<f64>::zeros((m, 2));
        factor[[1, 0]] = f64::NAN;
        let target = VectorResponseTarget::new(
            Array2::<f64>::zeros((n, m)),
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor,
            },
        );
        expect_invalid_input!(
            GaussianVectorLikelihood::from_target(&target),
            "must be finite",
        );
    }

    #[test]
    fn from_target_accepts_well_formed_low_rank_factor() {
        let n = 2;
        let m = 3;
        let factor = Array2::from_shape_vec((m, 2), vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6]).unwrap();
        let target = VectorResponseTarget::new(
            array![[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            VectorNoise::LowRank {
                diag: Array1::from(vec![1.0; m]),
                factor: factor.clone(),
            },
        );
        let lik = GaussianVectorLikelihood::from_target(&target)
            .expect("well-formed low-rank factor should be accepted");
        let stored = lik.factor.expect("low-rank factor should be carried");
        assert_eq!(stored.dim(), (m, 2));
        for ((i, j), v) in stored.indexed_iter() {
            assert_eq!(*v, factor[[i, j]]);
        }
        // LowRank `diag` is already a precision, so it is carried through unchanged
        // (one entry per response column, not per row).
        assert_eq!(lik.precision, Array1::from(vec![1.0; m]));
        assert_eq!(target.n(), n);
    }

    #[test]
    fn from_target_propagates_row_weight_length_mismatch() {
        let n = 3;
        let m = 2;
        let target = VectorResponseTarget {
            y: Array2::<f64>::zeros((n, m)),
            noise: VectorNoise::Isotropic(1.0),
            row_weights: Some(Array1::from(vec![1.0, 1.0])),
        };
        expect_invalid_input!(
            GaussianVectorLikelihood::from_target(&target),
            "row_weights length",
        );
    }

    #[test]
    fn log_lik_matches_weighted_isotropic_quadratic_form() {
        let target = VectorResponseTarget::new(
            array![[1.0, 2.0], [10.0, -10.0]],
            VectorNoise::Isotropic(2.0),
        )
        .with_row_weights(array![3.0, 0.0])
        .expect("valid weights");
        let lik = GaussianVectorLikelihood::from_target(&target).expect("valid target");
        let eta = Array2::<f64>::zeros((2, 2));
        // Precision is 1/4. Row 0 contributes 3 · (1 + 4) / 4 = 3.75; row 1 has zero weight.
        assert_eq!(lik.log_lik(eta.view(), target.y.view()), -0.5 * 3.75);
    }

    #[test]
    fn gradient_and_hessian_diagonal_match_finite_differences() {
        let factor = array![[0.3, -0.1], [0.2, 0.4], [-0.5, 0.1]];
        let target = VectorResponseTarget::new(
            array![[1.0, -2.0, 0.5], [0.3, 0.7, -1.1]],
            VectorNoise::LowRank {
                diag: array![1.5, 0.8, 2.0],
                factor,
            },
        )
        .with_row_weights(array![0.5, 2.0])
        .expect("valid weights");
        let lik = GaussianVectorLikelihood::from_target(&target).expect("valid target");
        let eta = array![[0.2, -1.0, 0.0], [0.1, 0.4, -0.6]];
        let y = target.y.view();
        let grad = lik.grad_eta(eta.view(), y);
        let hess = lik.hess_diag(eta.view(), y);
        let l_0 = lik.log_lik(eta.view(), y);
        // log_lik is quadratic in eta, so central differences are exact up to rounding.
        let h = 1e-3;
        for ((n, j), &g) in grad.indexed_iter() {
            let mut plus = eta.clone();
            plus[[n, j]] += h;
            let mut minus = eta.clone();
            minus[[n, j]] -= h;
            let l_plus = lik.log_lik(plus.view(), y);
            let l_minus = lik.log_lik(minus.view(), y);
            let fd_grad = (l_plus - l_minus) / (2.0 * h);
            let fd_curv = (l_plus - 2.0 * l_0 + l_minus) / (h * h);
            assert!(
                (fd_grad - g).abs() < 1e-6,
                "grad[{n},{j}]: analytic {g} vs finite-difference {fd_grad}"
            );
            // hess_diag is the diagonal of the NEGATIVE Hessian.
            assert!(
                (fd_curv + hess[[n, j]]).abs() < 1e-5,
                "hess_diag[{n},{j}]: analytic {} vs finite-difference {}",
                hess[[n, j]],
                -fd_curv
            );
        }
    }
}