use crate::faer_ndarray::{
FaerArrayView, factorize_symmetricwith_fallback, fast_ab, fast_abt, fast_xt_diag_x,
fast_xt_diag_y,
};
use crate::families::transformation_normal::{
TRANSFORMATION_MONOTONICITY_EPS, TransformationNormalFitResult, transformation_normal_pit_score,
};
use crate::inference::model::TRANSFORMATION_SCORE_PIT_CLIP_EPS;
use crate::matrix::FactorizedSystem;
use crate::probability::{normal_cdf, normal_pdf, standard_normal_quantile};
use crate::smooth::build_term_collection_design;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView2};
/// Fixed log-scale penalty weight for the influence absorber (`exp(0.0) == 1.0`).
pub(crate) const INFLUENCE_ABSORBER_FIXED_LOG_LAMBDA: f64 = 0.0_f64;
/// Fixed log-scale penalty weight for the marginal null-space ridge; the value is
/// `ln(0.01)`, i.e. a weight of `0.01` on the linear scale.
pub(crate) const MARGINAL_NULLSPACE_RIDGE_FIXED_LOG_LAMBDA: f64 = -4.605170185988091_f64;
/// Fixed log-scale penalty weight for the marginal/log-slope overlap term (`exp(0.0) == 1.0`).
pub(crate) const MARGINAL_LOGSLOPE_OVERLAP_FIXED_LOG_LAMBDA: f64 = 0.0_f64;
/// Per-row z scores of a fitted transformation-normal model together with the
/// derivatives of those scores with respect to the model's coefficient vector.
///
/// Derives `Debug` (diagnostics) and `Clone` (callers can keep an independent copy);
/// both are cheap to provide because every field already implements them.
#[derive(Debug, Clone)]
pub struct ScoreInfluenceJacobian {
    /// Derivative matrix with one row per observation. When produced by
    /// `score_influence_jacobian` it is `n x (p_resp * p_cov)`, and entry
    /// `[i, k * p_cov + j]` is `dz_i / d beta[k, j]`.
    pub columns: Array2<f64>,
    /// Per-observation z scores (length `n` when produced by `score_influence_jacobian`).
    pub z: Array1<f64>,
}
/// Re-evaluates a fitted transformation-normal model on (`response`, `covariate_data`)
/// and returns each row's PIT z score together with its gradient with respect to the
/// coefficients of the fit's first block.
///
/// Outline of what the body does:
/// 1. Reshapes `beta` (length `p_resp * p_cov`) row-major into a `p_resp x p_cov`
///    matrix, so `beta[k * p_cov + j]` is `beta_mat[[k, j]]`.
/// 2. Rebuilds the response value basis and the covariate design. It then forms
///    `gamma = fast_abt(x_cov, beta_mat)`, which has shape `n x p_resp`; presumably
///    `X_cov * beta_matᵀ`, consistent with the shape check.
/// 3. Per row, builds the transform value `h` and the support endpoints `L`, `U`:
///    - the first coefficient enters linearly and the rest squared,
///      i.e. `c_0 * g_0 + sum_{k>=1} c_k * g_k^2`;
///    - floor offsets are added to `L` and `U`;
///    - `h` also gets `TRANSFORMATION_MONOTONICITY_EPS * (y_i - median)`.
/// 4. Gets `z` from `transformation_normal_pit_score(h, L, U, clip_eps)`.
/// 5. Fills Jacobian column `k * p_cov + j` with `du / phi(z)`. Here `du` is the
///    derivative of `(Phi(h) - Phi(L)) / (Phi(U) - Phi(L))`, with `Phi(z)`
///    substituted for the ratio itself.
///    NOTE(review): this presumes the PIT score is `Phi^{-1}` of that (clipped)
///    truncated-normal ratio — confirm against `transformation_normal_pit_score`.
///
/// # Errors
/// Returns `Err(String)` when:
/// - row counts disagree, the input is empty, or `p_resp * p_cov` overflows;
/// - the fit has no block states, or `beta` has the wrong length or cannot be reshaped;
/// - the response basis or covariate design cannot be built, or has the wrong shape;
/// - the clip quantile cannot be computed;
/// - a row has non-finite `h`/`L`/`U`, has `U <= L`, fails the PIT score, or has
///   non-positive / non-finite endpoint mass `Phi(U) - Phi(L)`;
/// - any output entry is non-finite.
pub fn score_influence_jacobian(
    fit: &TransformationNormalFitResult,
    response: &Array1<f64>,
    covariate_data: ArrayView2<f64>,
) -> Result<ScoreInfluenceJacobian, String> {
    let family = &fit.family;
    let n = response.len();
    if covariate_data.nrows() != n {
        return Err(format!(
            "score_influence_jacobian: covariate rows ({}) != response rows ({n})",
            covariate_data.nrows()
        ));
    }
    if n == 0 {
        return Err("score_influence_jacobian: empty input rows".to_string());
    }
    let p_resp = family.p_resp();
    let p_cov = family.p_cov();
    // Total coefficient count; checked to avoid a silent usize wrap.
    let p1 = p_resp.checked_mul(p_cov).ok_or_else(|| {
        format!("score_influence_jacobian: p_resp({p_resp}) * p_cov({p_cov}) overflowed")
    })?;
    // Only the first block's coefficients are differentiated.
    let beta = &fit
        .fit
        .block_states
        .first()
        .ok_or_else(|| "score_influence_jacobian: fitted CTN has no block states".to_string())?
        .beta;
    if beta.len() != p1 {
        return Err(format!(
            "score_influence_jacobian: beta length {} != p_resp({p_resp}) * p_cov({p_cov})",
            beta.len()
        ));
    }
    // Row-major reshape: beta[k * p_cov + j] == beta_mat[[k, j]]. The Jacobian column
    // layout below (`k * p_cov + j`) relies on this ordering.
    let beta_mat = beta
        .view()
        .into_shape_with_order((p_resp, p_cov))
        .map_err(|e| format!("score_influence_jacobian: beta reshape failed: {e}"))?;
    let resp_val = family.evaluate_response_value_basis(response.view())?;
    if resp_val.nrows() != n || resp_val.ncols() != p_resp {
        return Err(format!(
            "score_influence_jacobian: response basis shape {}x{} != {n}x{p_resp}",
            resp_val.nrows(),
            resp_val.ncols()
        ));
    }
    // Rebuild the covariate design from the resolved spec stored on the fit.
    let cov_design = build_term_collection_design(covariate_data, &fit.covariate_spec_resolved)
        .map_err(|e| format!("score_influence_jacobian: covariate design build failed: {e}"))?;
    if cov_design.design.ncols() != p_cov {
        return Err(format!(
            "score_influence_jacobian: rebuilt covariate design has {} columns, fitted p_cov is {p_cov}",
            cov_design.design.ncols()
        ));
    }
    // Materialize all n rows as a dense matrix.
    let x_cov = cov_design.design.try_row_chunk(0..n).map_err(|e| {
        format!("score_influence_jacobian: covariate design materialization failed: {e}")
    })?;
    // Per-row response coefficients; expected n x p_resp (checked below).
    let gamma = fast_abt(&x_cov, &beta_mat);
    if gamma.nrows() != n || gamma.ncols() != p_resp {
        return Err(format!(
            "score_influence_jacobian: gamma shape {}x{} != {n}x{p_resp}",
            gamma.nrows(),
            gamma.ncols()
        ));
    }
    let lower_basis = family.response_lower_basis();
    let upper_basis = family.response_upper_basis();
    let lower_floor = family.response_lower_floor_offset();
    let upper_floor = family.response_upper_floor_offset();
    let median = family.response_median();
    // Lower bound for phi(z): the normal density at the clip quantile. Flooring the
    // divisor keeps dz = du / phi(z) bounded when z sits at the clip boundary.
    let pdf_z_floor = normal_pdf(
        standard_normal_quantile(TRANSFORMATION_SCORE_PIT_CLIP_EPS)
            .map_err(|e| format!("score_influence_jacobian: clip quantile failed: {e}"))?,
    );
    let mut columns = Array2::<f64>::zeros((n, p1));
    let mut z_scores = Array1::<f64>::zeros(n);
    for i in 0..n {
        let gamma_row = gamma.row(i);
        let val_row = resp_val.row(i);
        let x_row = x_cov.row(i);
        // Transform value h and endpoints L, U: coefficient 0 enters linearly and
        // coefficients k >= 1 enter squared.
        let g0 = gamma_row[0];
        let value_floor = TRANSFORMATION_MONOTONICITY_EPS * (response[i] - median);
        let mut h = val_row[0] * g0 + value_floor;
        let mut l = lower_basis[0] * g0 + lower_floor;
        let mut u = upper_basis[0] * g0 + upper_floor;
        for k in 1..p_resp {
            let gk = gamma_row[k];
            let gk_sq = gk * gk;
            h += val_row[k] * gk_sq;
            l += lower_basis[k] * gk_sq;
            u += upper_basis[k] * gk_sq;
        }
        if !(h.is_finite() && l.is_finite() && u.is_finite()) {
            return Err(format!(
                "score_influence_jacobian: non-finite transform geometry at row {i}: h={h}, L={l}, U={u}"
            ));
        }
        // Also guarantees `h.clamp(l, u)` below cannot panic (clamp requires l <= u).
        if u <= l {
            return Err(format!(
                "score_influence_jacobian: support order violated at row {i}: L={l:.6e} >= U={u:.6e}"
            ));
        }
        let z = transformation_normal_pit_score(h, l, u, TRANSFORMATION_SCORE_PIT_CLIP_EPS)
            .map_err(|e| format!("score_influence_jacobian: PIT score failed at row {i}: {e}"))?;
        z_scores[i] = z;
        let phi_l = normal_cdf(l);
        let phi_u = normal_cdf(u);
        // Denominator of the PIT ratio; can underflow to 0 when L and U are far in a tail.
        let denom_mass = phi_u - phi_l;
        if !(denom_mass.is_finite() && denom_mass > 0.0) {
            return Err(format!(
                "score_influence_jacobian: endpoint mass not resolvable at row {i}: Φ(U)−Φ(L)={denom_mass:.6e}"
            ));
        }
        // Phi(z) stands in for the (clipped) PIT value in the quotient-rule term.
        let u_pit = normal_cdf(z);
        // Density at h, evaluated with h clamped into [L, U].
        let pdf_h = normal_pdf(h.clamp(l, u));
        let pdf_l = normal_pdf(l);
        let pdf_u = normal_pdf(u);
        let pdf_z = normal_pdf(z).max(pdf_z_floor);
        let mut row = columns.row_mut(i);
        for k in 0..p_resp {
            // d/dg_k of the k-th term: c_0 for the linear term (k == 0) and
            // 2 * c_k * g_k for the squared terms (k >= 1).
            let (dh_scalar, dl_scalar, du_scalar) = if k == 0 {
                (val_row[0], lower_basis[0], upper_basis[0])
            } else {
                let two_gk = 2.0 * gamma_row[k];
                (
                    val_row[k] * two_gk,
                    lower_basis[k] * two_gk,
                    upper_basis[k] * two_gk,
                )
            };
            let base = k * p_cov;
            for j in 0..p_cov {
                // Chain rule through gamma: d gamma[i, k] / d beta[k, j] = x[i, j].
                let xij = x_row[j];
                let dh = dh_scalar * xij;
                let dl = dl_scalar * xij;
                let du = du_scalar * xij;
                // Quotient rule for (Phi(h) - Phi(L)) / (Phi(U) - Phi(L)).
                let du_pit =
                    (pdf_h * dh - u_pit * (pdf_u * du - pdf_l * dl) - pdf_l * dl) / denom_mass;
                // Convert the PIT derivative to a z derivative: dz = du / phi(z).
                row[base + j] = du_pit / pdf_z;
            }
        }
    }
    // Final guard: reject any non-finite output rather than propagating it downstream.
    if columns.iter().any(|v| !v.is_finite()) {
        return Err("score_influence_jacobian: produced non-finite Jacobian entries".to_string());
    }
    if z_scores.iter().any(|v| !v.is_finite()) {
        return Err("score_influence_jacobian: produced non-finite z scores".to_string());
    }
    Ok(ScoreInfluenceJacobian {
        columns,
        z: z_scores,
    })
}
pub fn influence_block_design(
jac: &ScoreInfluenceJacobian,
pilot_beta0: &Array1<f64>,
s_f: f64,
) -> Array2<f64> {
let n = jac.columns.nrows();
assert_eq!(
pilot_beta0.len(),
n,
"influence_block_design: pilot_beta0 length must equal Jacobian rows"
);
let mut out = jac.columns.clone();
for (i, mut row) in out.axis_iter_mut(ndarray::Axis(0)).enumerate() {
let scale = s_f * pilot_beta0[i];
row.mapv_inplace(|v| v * scale);
}
out
}
/// Removes from each column of `z_infl` its `w_metric`-weighted, ridge-stabilised
/// projection onto the column space of `marginal_design`.
///
/// Write `X = marginal_design`, `W = diag(w_metric)` and `Z = z_infl`. The result is
/// `Z - X (XᵀWX + eps·I)⁻¹ XᵀWZ`. When `X` has no columns there is nothing to project
/// out, and a copy of `z_infl` is returned.
///
/// # Panics
/// Panics if:
/// - `z_infl` or `w_metric` do not have one entry per row of `marginal_design`;
/// - the regularised Gram matrix cannot be factorized;
/// - the projection solve fails.
pub(crate) fn residualize_influence_columns(
    z_infl: &Array2<f64>,
    marginal_design: ArrayView2<f64>,
    w_metric: &Array1<f64>,
    eps: f64,
) -> Array2<f64> {
    let row_count = marginal_design.nrows();
    assert_eq!(
        z_infl.nrows(),
        row_count,
        "residualize_influence_columns: Z_infl rows must equal marginal design rows"
    );
    assert_eq!(
        w_metric.len(),
        row_count,
        "residualize_influence_columns: row metric length must equal marginal design rows"
    );
    if marginal_design.ncols() == 0 {
        return z_infl.clone();
    }

    // Normal equations of the weighted ridge regression: (XᵀWX + eps·I) C = XᵀWZ.
    let mut normal_matrix = fast_xt_diag_x(&marginal_design, w_metric);
    normal_matrix.diag_mut().mapv_inplace(|d| d + eps);
    let rhs = fast_xt_diag_y(&marginal_design, w_metric, z_infl);

    let normal_view = FaerArrayView::new(&normal_matrix);
    let solution = factorize_symmetricwith_fallback(normal_view.as_ref(), Side::Lower)
        .expect("residualize_influence_columns: weighted marginal Gram factorization failed")
        .solvemulti(&rhs)
        .expect("residualize_influence_columns: marginal projection solve failed");

    z_infl - &fast_ab(&marginal_design, &solution)
}
/// Ridge for the influence projection, relative to the largest diagonal entry of the
/// weighted marginal Gram matrix.
pub(crate) const INFLUENCE_PROJECTION_RELATIVE_RIDGE: f64 = 1e-10;
/// Absolute lower bound on that ridge, used when the relative value is smaller.
pub(crate) const INFLUENCE_PROJECTION_RIDGE_FLOOR: f64 = 1e-12;
pub(crate) fn residualized_influence_block(
raw_jac: &Array2<f64>,
oof_z: &Array1<f64>,
pilot_beta0: &Array1<f64>,
s_f: f64,
marginal_design: ArrayView2<f64>,
w_metric: &Array1<f64>,
) -> Result<Array2<f64>, String> {
let jac = ScoreInfluenceJacobian {
columns: raw_jac.clone(),
z: oof_z.clone(),
};
let z_infl = influence_block_design(&jac, pilot_beta0, s_f);
let p_m = marginal_design.ncols();
let gram = fast_xt_diag_x(&marginal_design, w_metric);
let gram_scale = (0..p_m).map(|i| gram[[i, i]]).fold(0.0_f64, f64::max);
let eps =
(gram_scale * INFLUENCE_PROJECTION_RELATIVE_RIDGE).max(INFLUENCE_PROJECTION_RIDGE_FLOOR);
let residualized = residualize_influence_columns(&z_infl, marginal_design, w_metric, eps);
if residualized.iter().any(|v| !v.is_finite()) {
return Err(
"residualized_influence_block: residualized influence columns contain non-finite entries"
.to_string(),
);
}
Ok(residualized)
}