use super::*;
use crate::linalg::low_rank_weight::LowRankWeight;
/// Forms `Xᵀ W X` for a weight of the form `W = diag(d) + U·Vᵀ`.
///
/// The diagonal contribution is assembled by `GamWorkingModel::compute_xtwx_blas`
/// using `workspace`. The low-rank contribution is then added in place by
/// `LowRankWeight::add_low_rank_xtwx_correction`. That step is skipped entirely
/// when the weight reports rank zero.
///
/// # Errors
/// Propagates any error from the diagonal assembly unchanged. A failure in the
/// low-rank correction is wrapped as `EstimationError::InvalidInput`.
pub fn compute_xtwx_low_rank(
    workspace: &mut PirlsWorkspace,
    design: &DesignMatrix,
    weight: &LowRankWeight<'_>,
) -> Result<Array2<f64>, EstimationError> {
    // The diagonal is copied to an owned array before the BLAS helper call.
    // NOTE(review): this allocation could be avoided if that helper accepted a view.
    let diagonal = weight.diag.to_owned();
    let mut gram = GamWorkingModel::compute_xtwx_blas(workspace, design, &diagonal)?;
    if !weight.is_rank_zero() {
        weight
            .add_low_rank_xtwx_correction(design, &mut gram)
            .map_err(EstimationError::InvalidInput)?;
    }
    Ok(gram)
}
/// Forms `Xᵀ W y` for a low-rank-plus-diagonal weight `W = diag(d) + U·Vᵀ`.
///
/// This delegates to `LowRankWeight::xtw_y`.
///
/// # Errors
/// Any failure reported by `xtw_y` is wrapped as `EstimationError::InvalidInput`.
pub fn compute_xtwy_low_rank(
    design: &DesignMatrix,
    weight: &LowRankWeight<'_>,
    y: &Array1<f64>,
) -> Result<Array1<f64>, EstimationError> {
    let response = y.view();
    match weight.xtw_y(design, response) {
        Ok(rhs) => Ok(rhs),
        Err(reason) => Err(EstimationError::InvalidInput(reason)),
    }
}
/// Builds the dense block Gram matrix `Σ_r w_r · (F_r ⊗ x_r x_rᵀ)` for a
/// multi-output model with one `p × p` Fisher block per row.
///
/// * `design` — `n × k` dense design; row `r` is `x_r`.
/// * `fisher_blocks` — `n × p × p`; `fisher_blocks[[r, a, b]]` is `F_r[a, b]`.
/// * `row_weights` — optional per-row multiplier `w_r`. A missing vector means
///   every row counts with weight `1.0`.
///
/// The result is `(k·p) × (k·p)`. Entry `(a·k + i, b·k + j)` accumulates
/// `w_r · F_r[a, b] · x_r[i] · x_r[j]` over all rows. Rows are reduced in
/// parallel with rayon. The grouping of rows is chosen at run time, so the
/// floating-point summation order, and hence the low-order bits, may vary
/// between runs. The final matrix is symmetrised by averaging mirrored entries.
///
/// # Errors
/// Fails via `crate::bail_invalid_estim!` in any of these cases:
/// * the Fisher array's leading dimension is not `n`, or its blocks are not square;
/// * `row_weights` has the wrong length;
/// * `row_weights` holds a negative or non-finite value;
/// * some `w_r · F_r[a, b]` is non-finite. The lexicographically smallest
///   offending `(row, a, b)` is reported.
pub fn dense_block_xtwx(
    design: ArrayView2<'_, f64>,
    fisher_blocks: ArrayView3<'_, f64>,
    row_weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array2<f64>, EstimationError> {
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    let (n, k) = design.dim();
    let shape = fisher_blocks.shape();
    if shape.len() != 3 || shape[0] != n || shape[1] != shape[2] {
        crate::bail_invalid_estim!(
            "dense block Fisher shape mismatch: expected ({n}, p, p), got {shape:?}"
        );
    }
    if let Some(weights) = row_weights.as_ref() {
        if weights.len() != n {
            crate::bail_invalid_estim!(
                "dense block row weight length mismatch: expected {n}, got {}",
                weights.len()
            );
        }
        let all_valid = weights.iter().all(|v| v.is_finite() && *v >= 0.0);
        if !all_valid {
            crate::bail_invalid_estim!("dense block row weights must be finite and non-negative");
        }
    }

    let p_out = shape[1];
    let dim = k * p_out;
    // Per-row scale factor; rows without explicit weights count once.
    let weight_of = |row: usize| row_weights.as_ref().map_or(1.0, |w| w[row]);

    // Locate the first (row, a, b) whose weighted Fisher entry is not finite.
    // `min` over tuples makes the reported entry independent of scheduling.
    let nonfinite = (0..n)
        .into_par_iter()
        .filter_map(|row| {
            let scale = weight_of(row);
            (0..p_out)
                .flat_map(|a| (0..p_out).map(move |b| (a, b)))
                .find(|&(a, b)| !(scale * fisher_blocks[[row, a, b]]).is_finite())
                .map(|(a, b)| (row, a, b))
        })
        .min();
    if let Some((row, a, b)) = nonfinite {
        crate::bail_invalid_estim!("dense block Fisher entry ({row},{a},{b}) is not finite");
    }

    let mut out = (0..n)
        .into_par_iter()
        .fold(
            || Array2::<f64>::zeros((dim, dim)),
            |mut acc, row| {
                let scale = weight_of(row);
                let x = design.row(row);
                for a in 0..p_out {
                    for b in 0..p_out {
                        let wab = scale * fisher_blocks[[row, a, b]];
                        if wab == 0.0 {
                            continue;
                        }
                        let col_offset = b * k;
                        // Zero design entries contribute nothing; skip their rows.
                        for (i, &xi) in x.iter().enumerate().filter(|&(_, &xi)| xi != 0.0) {
                            let scaled = wab * xi;
                            let mut target = acc.row_mut(a * k + i);
                            for (j, &xj) in x.iter().enumerate() {
                                target[col_offset + j] += scaled * xj;
                            }
                        }
                    }
                }
                acc
            },
        )
        .reduce(
            || Array2::<f64>::zeros((dim, dim)),
            |mut left, right| {
                left += &right;
                left
            },
        );

    // Average mirrored entries so the result is exactly symmetric. This absorbs
    // both asymmetric Fisher blocks and rounding differences between the two
    // independently accumulated halves.
    for i in 0..dim {
        for j in (i + 1)..dim {
            let mean = 0.5 * (out[[i, j]] + out[[j, i]]);
            out[[i, j]] = mean;
            out[[j, i]] = mean;
        }
    }
    Ok(out)
}
/// Builds the right-hand side `Σ_r w_r · x_r ⊗ (F_r y_r)` that pairs with
/// [`dense_block_xtwx`]. It uses the same `a * k + i` block layout.
///
/// * `design` — `n × k` dense design; row `r` is `x_r`.
/// * `fisher_blocks` — `n × p × p`; `fisher_blocks[[r, a, b]]` is `F_r[a, b]`.
/// * `response` — `n × p`; row `r` is `y_r`.
/// * `row_weights` — optional per-row multiplier `w_r`. A missing vector means
///   every row counts with weight `1.0`.
///
/// Entry `a * k + i` of the result is `Σ_r x_r[i] · Σ_b w_r · F_r[a, b] · y_r[b]`.
///
/// # Errors
/// Fails via `crate::bail_invalid_estim!` in any of these cases:
/// * the Fisher array's leading dimension is not `n`, or its blocks are not square;
/// * `response` is not `n × p`;
/// * `row_weights` has the wrong length;
/// * `row_weights` contains a negative or non-finite value. This is the same
///   rule [`dense_block_xtwx`] enforces, so both sides of a solve accept the
///   same inputs;
/// * a weighted Fisher entry is non-finite. The first one in `(row, a, b)`
///   order is reported.
pub fn dense_block_xtwy(
    design: ArrayView2<'_, f64>,
    fisher_blocks: ArrayView3<'_, f64>,
    response: ArrayView2<'_, f64>,
    row_weights: Option<ArrayView1<'_, f64>>,
) -> Result<Array1<f64>, EstimationError> {
    let n = design.nrows();
    let k = design.ncols();
    let shape = fisher_blocks.shape();
    if shape.len() != 3 || shape[0] != n || shape[1] != shape[2] {
        crate::bail_invalid_estim!(
            "dense block Fisher shape mismatch: expected ({n}, p, p), got {shape:?}"
        );
    }
    let p_out = shape[1];
    if response.dim() != (n, p_out) {
        crate::bail_invalid_estim!(
            "dense block response shape mismatch: expected ({n}, {p_out}), got {}x{}",
            response.nrows(),
            response.ncols()
        );
    }
    if let Some(w) = row_weights.as_ref() {
        if w.len() != n {
            crate::bail_invalid_estim!(
                "dense block row weight length mismatch: expected {n}, got {}",
                w.len()
            );
        }
        // Previously only the length was checked here. A negative weight was
        // silently accepted, while `dense_block_xtwx` rejected it for the same
        // system. This check keeps the two halves of X'WX β = X'Wy consistent.
        if w.iter().any(|v| !v.is_finite() || *v < 0.0) {
            crate::bail_invalid_estim!("dense block row weights must be finite and non-negative");
        }
    }
    let mut out = Array1::<f64>::zeros(k * p_out);
    for row in 0..n {
        let rw = row_weights.as_ref().map_or(1.0, |w| w[row]);
        let x_row = design.row(row);
        for a in 0..p_out {
            // wy = Σ_b w_r · F_r[a, b] · y_r[b]: the a-th output of the weighted response.
            let mut wy = 0.0_f64;
            for b in 0..p_out {
                let wab = rw * fisher_blocks[[row, a, b]];
                if !wab.is_finite() {
                    crate::bail_invalid_estim!(
                        "dense block Fisher entry ({row},{a},{b}) is not finite"
                    );
                }
                wy += wab * response[[row, b]];
            }
            for (i, &xi) in x_row.iter().enumerate() {
                out[a * k + i] += xi * wy;
            }
        }
    }
    Ok(out)
}
/// Returns the Woodbury capacitance matrix `I + V̂ᵀ · (A⁻¹Û)`.
///
/// `a_inv_uhat` is the already-solved `A⁻¹Û` and `vhat` is `V̂`. The work is
/// delegated to `LowRankWeight::gram_capacitance`.
///
/// # Errors
/// Any failure reported by `gram_capacitance` is wrapped as
/// `EstimationError::InvalidInput`.
pub fn woodbury_gram_capacitance(
    a_inv_uhat: &Array2<f64>,
    vhat: &Array2<f64>,
) -> Result<Array2<f64>, EstimationError> {
    match LowRankWeight::gram_capacitance(a_inv_uhat, vhat) {
        Ok(capacitance) => Ok(capacitance),
        Err(reason) => Err(EstimationError::InvalidInput(reason)),
    }
}
#[cfg(test)]
mod low_rank_weight_pirls_tests {
    use super::{
        DesignMatrix, LowRankWeight, PirlsWorkspace, compute_xtwx_low_rank, compute_xtwy_low_rank,
        woodbury_gram_capacitance,
    };
    use crate::linalg::matrix::{LinearOperator, SignedWeightsView};
    use ndarray::{Array2, array};

    /// Small fixed 5×3 dense design shared by every test in this module.
    fn tiny_design() -> DesignMatrix {
        let entries = array![
            [1.0, 0.5, -0.2],
            [0.3, 1.2, 0.4],
            [-0.1, 0.7, 1.0],
            [0.6, -0.3, 0.8],
            [0.2, 0.9, -0.5],
        ];
        DesignMatrix::Dense(crate::matrix::DenseDesignMatrix::from(entries))
    }

    /// Sum of absolute element-wise differences between two equally ordered sequences.
    fn l1_gap<'a>(
        lhs: impl IntoIterator<Item = &'a f64>,
        rhs: impl IntoIterator<Item = &'a f64>,
    ) -> f64 {
        lhs.into_iter().zip(rhs).map(|(x, y)| (x - y).abs()).sum()
    }

    /// With empty U/V factors, the low-rank path must reduce to plain `Xᵀ diag(d) X`.
    #[test]
    pub(crate) fn xtwx_low_rank_matches_diagonal_when_rank_zero() {
        let design = tiny_design();
        let diag = array![1.0, 2.0, 0.5, 1.5, 0.8];
        let empty_u = Array2::<f64>::zeros((5, 0));
        let empty_v = Array2::<f64>::zeros((5, 0));
        let weight = LowRankWeight::new(diag.view(), empty_u.view(), empty_v.view()).unwrap();
        let mut workspace = PirlsWorkspace::new(5, 3, 0, 0);

        let got = compute_xtwx_low_rank(&mut workspace, &design, &weight).unwrap();
        let reference = design
            .xt_diag_x_signed_op(SignedWeightsView::from_array(&diag))
            .unwrap();

        let gap = (&got - &reference).mapv(f64::abs).sum();
        assert!(gap < 1e-12, "rank-0 path diverged from diagonal: {}", gap);
    }

    /// `Xᵀ W y` must agree with an explicitly materialised `W = diag(d) + U·Vᵀ`.
    #[test]
    pub(crate) fn xtwy_low_rank_matches_dense_reference() {
        let design = tiny_design();
        let diag = array![1.0, 2.0, 0.5, 1.5, 0.8];
        let u = array![
            [0.1, -0.2],
            [0.4, 0.3],
            [-0.1, 0.5],
            [0.2, 0.1],
            [0.0, -0.3]
        ];
        let v = array![[0.2, 0.1], [0.0, 0.4], [0.3, -0.2], [-0.1, 0.6], [0.5, 0.0]];
        let weight = LowRankWeight::new(diag.view(), u.view(), v.view()).unwrap();
        let y = array![0.7, -1.2, 0.3, 0.9, -0.4];

        let got = compute_xtwy_low_rank(&design, &weight, &y).unwrap();

        let x_dense = design.as_dense().unwrap().to_owned();
        let mut dense_w = Array2::<f64>::zeros((5, 5));
        dense_w.diag_mut().assign(&diag);
        dense_w += &u.dot(&v.t());
        let want = x_dense.t().dot(&dense_w.dot(&y));

        let gap = l1_gap(got.iter(), want.iter());
        assert!(gap < 1e-10, "xtwy_low_rank diverged: {}", gap);
    }

    /// The capacitance helper must return `I + V̂ᵀÛ`.
    #[test]
    pub(crate) fn woodbury_capacitance_is_well_formed() {
        let uhat = array![[0.5, 0.1], [-0.2, 0.7], [0.3, -0.4]];
        let vhat = array![[0.1, 0.2], [0.6, -0.1], [-0.3, 0.4]];

        let cap = woodbury_gram_capacitance(&uhat, &vhat).unwrap();

        let mut want = vhat.t().dot(&uhat);
        want.diag_mut().mapv_inplace(|x| x + 1.0);

        let gap = l1_gap(cap.iter(), want.iter());
        assert!(gap < 1e-12);
    }
}