use super::loop_driver::max_symmetric_asymmetry;
use super::{
FIXED_STABILIZATION_RIDGE, PirlsPenalty, PirlsWorkspace, SparseXtWxCache, StablePLSResult,
WorkingReparamTransform, calculate_edf_from_sparse_factor,
calculate_edfwithworkspace_from_factor, ensure_sparse_positive_definitewithridge,
solve_sparse_spd,
};
use crate::estimate::EstimationError;
use crate::faer_ndarray::{FaerLinalgError, array1_to_col_matmut};
use crate::linalg::utils::{StableSolver, array_is_finite};
use crate::matrix::{DesignMatrix, LinearOperator, SymmetricMatrix};
use crate::types::{Coefficients, LinkFunction};
use faer::sparse::SparseColMat;
use ndarray::{Array1, Array2, ArrayView1, ShapeBuilder};
use std::sync::Arc;
/// Precomputed quantities that `solve_penalized_least_squares_implicit` can
/// reuse instead of rebuilding them from the design matrix on every call.
///
/// All dense entries are expressed in the original (untransformed) coefficient
/// basis; the solver applies any reparameterization transform itself.
#[derive(Debug)]
pub struct GaussianFixedCache {
/// Dense XᵀWX. The solver rejects the cache unless this is p×p, where p is
/// the number of design columns.
pub xtwx_orig: Array2<f64>,
/// XᵀW(y − offset). The solver requires its length to equal the number of
/// design columns.
pub xtwy_orig: Array1<f64>,
/// NOTE(review): not read anywhere in this part of the file; the name
/// suggests a weighted, centered sum of squares of y — confirm against the
/// code that populates and consumes it.
pub centered_weighted_y_sq: f64,
/// Optional shared sparse XᵀWX. When present it is handed to the sparse
/// penalized-Hessian assembly on the sparse-native solve path.
pub xtwx_sparse_orig: Option<Arc<SparseXtwxPrecomputed>>,
}
/// Owned snapshot of a sparse XᵀWX product: its symbolic structure (column
/// pointers and row indices) plus the numeric values.
///
/// Instances are produced by [`SparseXtwxPrecomputed::build`], which copies
/// these three pieces out of a `SparseXtWxCache`.
#[derive(Debug, Clone)]
pub struct SparseXtwxPrecomputed {
/// Column pointers of the symbolic XᵀWX pattern.
pub xtwx_symbolic_col_ptr: Vec<usize>,
/// Row indices of the symbolic XᵀWX pattern.
pub xtwx_symbolic_row_idx: Vec<usize>,
/// Numeric values of XᵀWX.
/// NOTE(review): presumably ordered to match the symbolic pattern above —
/// confirm against `SparseXtWxCache`.
pub xtwxvalues: Vec<f64>,
}
impl SparseXtwxPrecomputed {
    /// Computes XᵀWX for the sparse design `x` under `weights` and captures the
    /// result as an owned snapshot.
    ///
    /// A fresh `SparseXtWxCache` is created for `x`, its numeric values are
    /// computed for `weights`, and the symbolic column pointers, row indices
    /// and numeric values are then copied/moved out of that cache.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while creating the cache or while
    /// computing its numeric values.
    pub fn build(
        x: &SparseColMat<usize, f64>,
        weights: &Array1<f64>,
    ) -> Result<Self, EstimationError> {
        let mut scratch = SparseXtWxCache::new(x)?;
        scratch.compute_numeric(x, weights)?;

        // Copy the symbolic pattern first; the numeric values are moved out of
        // the scratch cache afterwards, so no second allocation is needed.
        let col_ptr = scratch.xtwx_symbolic.col_ptr().to_vec();
        let row_idx = scratch.xtwx_symbolic.row_idx().to_vec();
        let values = scratch.xtwxvalues;

        Ok(Self {
            xtwx_symbolic_col_ptr: col_ptr,
            xtwx_symbolic_row_idx: row_idx,
            xtwxvalues: values,
        })
    }
}
pub(super) fn solve_penalized_least_squares_implicit(
x_original: &DesignMatrix,
transform: Option<&WorkingReparamTransform>,
z: ArrayView1<f64>,
weights: ArrayView1<f64>,
offset: ArrayView1<f64>,
penalty: &PirlsPenalty,
workspace: &mut PirlsWorkspace,
y: ArrayView1<f64>,
link_function: LinkFunction,
gaussian_fixed_cache: Option<&GaussianFixedCache>,
) -> Result<(StablePLSResult, usize), EstimationError> {
let p_dim = penalty.dim();
if transform.is_none()
&& let Some(x_sparse) = x_original.as_sparse()
{
let PirlsPenalty::Dense { s_transformed, .. } = penalty else {
crate::bail_invalid_estim!(
"sparse-native PIRLS requires a dense transformed penalty matrix"
);
};
let weights_owned = weights.to_owned();
let precomputed_xtwx =
gaussian_fixed_cache.and_then(|c| c.xtwx_sparse_orig.as_ref().map(|arc| arc.as_ref()));
let (h_sparse, factor, ridge_used) = ensure_sparse_positive_definitewithridge(|ridge| {
let ridge = if ridge == 0.0 {
FIXED_STABILIZATION_RIDGE
} else {
ridge
};
workspace.assemble_sparse_penalized_hessian(
x_sparse,
&weights_owned,
s_transformed,
ridge,
precomputed_xtwx,
)
})?;
let mut wz = z.to_owned();
wz -= &offset;
wz *= &weights_owned;
let mut rhs = x_original.transpose_vector_multiply(&wz);
rhs += penalty.linear_shift();
if ridge_used > 0.0 {
let prior_mean_target = penalty.prior_mean_target();
if prior_mean_target.len() == rhs.len() {
rhs.scaled_add(ridge_used, prior_mean_target);
}
}
let betavec = solve_sparse_spd(&factor, &rhs)?;
let h_sym = SymmetricMatrix::Sparse(h_sparse);
let edf = calculate_edf_from_sparse_factor(&factor, penalty)?;
let fitted_vals = {
let xb = x_original.apply(&betavec);
let mut f = xb;
f += &offset;
f
};
let standard_deviation = match link_function {
LinkFunction::Identity => {
let residuals = &y - &fitted_vals;
let weighted_rss: f64 = weights
.iter()
.zip(residuals.iter())
.map(|(&w, &r)| w * r * r)
.sum();
let effective_n = y.len() as f64;
(weighted_rss / (effective_n - edf).max(1.0)).sqrt()
}
_ => 1.0,
};
return Ok((
StablePLSResult {
beta: Coefficients::new(betavec),
penalized_hessian: h_sym,
edf,
standard_deviation,
ridge_used,
},
p_dim,
));
}
if workspace.wz.len() != z.len() {
workspace.wz = Array1::zeros(z.len());
}
workspace.wz.assign(&z);
workspace.wz -= &offset;
workspace.wz *= &weights;
let weights_owned = weights.to_owned();
let xtwx_orig = if let Some(cache) = gaussian_fixed_cache {
let p = x_original.ncols();
if cache.xtwx_orig.nrows() != p || cache.xtwx_orig.ncols() != p {
return Err(EstimationError::InvalidInput(format!(
"GaussianFixedCache XᵀWX shape {}×{} does not match design p={}",
cache.xtwx_orig.nrows(),
cache.xtwx_orig.ncols(),
p,
)));
}
cache.xtwx_orig.clone()
} else {
match x_original {
DesignMatrix::Dense(x_dense) if x_dense.is_materialized_dense() => {
let p = x_dense.ncols();
let x_dense = x_dense.to_dense_arc();
if workspace.hessian_buf.nrows() != p || workspace.hessian_buf.ncols() != p {
workspace.hessian_buf = Array2::zeros((p, p).f());
} else {
workspace.hessian_buf.fill(0.0);
}
PirlsWorkspace::add_dense_xtwx_signed(
&weights_owned,
&mut workspace.weighted_x_chunk,
x_dense.as_ref(),
&mut workspace.hessian_buf,
);
std::mem::take(&mut workspace.hessian_buf)
}
_ => {
crate::matrix::xt_diag_x_signed(
x_original,
crate::matrix::SignedWeightsView::from_array(&weights_owned),
)
.map(|h| h.to_dense())
.map_err(EstimationError::InvalidInput)?
}
}
};
let xtwx_orig_asym = max_symmetric_asymmetry(&xtwx_orig);
let xtwx_transformed = if let Some(transform) = transform {
transform.conjugate_matrix(&xtwx_orig)
} else {
xtwx_orig
};
let mut penalized_hessian = xtwx_transformed.clone();
penalty.add_to_hessian(&mut penalized_hessian);
let xtwy_orig = if let Some(cache) = gaussian_fixed_cache {
assert_eq!(
cache.xtwy_orig.len(),
x_original.ncols(),
"GaussianFixedCache XᵀW(y−offset) length must match design p"
);
cache.xtwy_orig.clone()
} else {
x_original.transpose_vector_multiply(&workspace.wz)
};
if workspace.vec_buf_p.len() != p_dim {
workspace.vec_buf_p = Array1::zeros(p_dim);
}
if let Some(transform) = transform {
workspace
.vec_buf_p
.assign(&transform.apply_transpose(&xtwy_orig));
} else {
workspace.vec_buf_p.assign(&xtwy_orig);
}
workspace.vec_buf_p += penalty.linear_shift();
{
let xtwx_asym = max_symmetric_asymmetry(&xtwx_transformed);
let penalty_asym = match penalty {
PirlsPenalty::Dense { s_transformed, .. } => max_symmetric_asymmetry(s_transformed),
PirlsPenalty::Diagonal { .. } => 0.0,
};
let total_asym = max_symmetric_asymmetry(&penalized_hessian);
assert!(
total_asym <= 1e-8,
"implicit PLS penalized Hessian asymmetry too large: total={total_asym:.3e}, xtwx_orig={xtwx_orig_asym:.3e}, xtwx={xtwx_asym:.3e}, penalty={penalty_asym:.3e}, tol={:.3e}",
1e-8
);
}
let nugget = FIXED_STABILIZATION_RIDGE;
let mut regularizedhessian = penalized_hessian.clone();
if nugget > 0.0 {
for i in 0..p_dim {
regularizedhessian[[i, i]] += nugget;
}
}
let ridge_used = nugget;
if workspace.rhs_full.len() != p_dim {
workspace.rhs_full = Array1::zeros(p_dim);
}
workspace.rhs_full.assign(&workspace.vec_buf_p);
if nugget > 0.0 {
let prior_mean_target = penalty.prior_mean_target();
if prior_mean_target.len() == p_dim {
workspace.rhs_full.scaled_add(nugget, prior_mean_target);
}
}
let factor = StableSolver::new("pirls implicit pls")
.factorize(®ularizedhessian)
.map_err(EstimationError::LinearSystemSolveFailed)?;
let mut rhsview = array1_to_col_matmut(&mut workspace.rhs_full);
factor.solve_in_place(rhsview.as_mut());
if !array_is_finite(&workspace.rhs_full) {
return Err(EstimationError::LinearSystemSolveFailed(
FaerLinalgError::FactorizationFailed {
context: "PIRLS implicit PLS non-finite solve",
},
));
}
let betavec = workspace.rhs_full.clone();
let edf = calculate_edfwithworkspace_from_factor(&factor, penalty, workspace)?;
let qbeta = if let Some(transform) = transform {
transform.apply(&betavec)
} else {
betavec.clone()
};
let xqbeta = x_original.apply(&qbeta);
let mut fitted = xqbeta;
fitted += &offset;
let standard_deviation = match link_function {
LinkFunction::Identity => {
let residuals = &y - &fitted;
let weighted_rss: f64 = weights
.iter()
.zip(residuals.iter())
.map(|(&w, &r)| w * r * r)
.sum();
let effective_n = y.len() as f64;
(weighted_rss / (effective_n - edf).max(1.0)).sqrt()
}
_ => 1.0,
};
Ok((
StablePLSResult {
beta: Coefficients::new(betavec),
penalized_hessian: SymmetricMatrix::Dense(penalized_hessian),
edf,
standard_deviation,
ridge_used,
},
p_dim,
))
}