use crate::inference::smooth_test::{SmoothTestInput, SmoothTestScale, wood_smooth_test};
use crate::inference::structure_evidence::{e_benjamini_hochberg, log_e_from_p_calibrator};
use crate::linalg::faer_ndarray::FaerSvd;
use crate::solver::row_measure::RowSubsampleMask;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Likelihood family of the auxiliary outcome attached to the latent space.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuxOutcomeFamily {
    /// Two-class outcome driven by a single linear predictor.
    Binomial,
    /// Categorical outcome with `n_classes` levels.
    Multinomial { n_classes: usize },
}

impl AuxOutcomeFamily {
    /// Number of linear-predictor channels the family needs.
    ///
    /// `Binomial` uses one channel; `Multinomial` uses `n_classes - 1`,
    /// saturating at zero when `n_classes` is zero.
    pub fn n_eta_channels(&self) -> usize {
        match *self {
            Self::Binomial => 1,
            Self::Multinomial { n_classes } => n_classes.saturating_sub(1),
        }
    }

    /// Dimension of the behavioral subspace; identical to
    /// [`AuxOutcomeFamily::n_eta_channels`].
    pub fn behavioral_subspace_dim(&self) -> usize {
        Self::n_eta_channels(self)
    }
}
/// Supervised head over latent coordinates: an outcome family, one label per
/// row, and one weight per row.
///
/// Fields are private; `new` validates them before construction.
#[derive(Debug, Clone)]
pub struct BehavioralHead {
/// Outcome family; determines the number of linear-predictor channels.
family: AuxOutcomeFamily,
/// Labels stored as floats: 0/1 for `Binomial`, integer class indices for
/// `Multinomial` (checked in `new`).
y: Array1<f64>,
/// Per-row weights, checked in `new` to be finite and ≥ 0. Rows with weight
/// exactly 0 are skipped by `neg_loglik_and_grad`.
w_row: Array1<f64>,
}
impl BehavioralHead {
/// Builds a head after validating the row weights and the labels.
///
/// # Errors
/// Returns a message when, in this order:
/// - `w_row` and `y` have different lengths;
/// - any weight is non-finite or negative;
/// - (`Binomial`) a label is neither `0.0` nor `1.0`;
/// - (`Multinomial`) `n_classes < 2`, or a label is not an integer class
///   index in `0..n_classes`.
pub fn new(
    family: AuxOutcomeFamily,
    y: Array1<f64>,
    w_row: Array1<f64>,
) -> Result<Self, String> {
    let n = y.len();
    if w_row.len() != n {
        return Err(format!(
            "BehavioralHead: w_row length {} != labels length {n}",
            w_row.len()
        ));
    }

    // NaN and ±inf fail `is_finite`, so the sign test only sees finite values.
    if let Some(&v) = w_row.iter().find(|&&w| !w.is_finite() || w < 0.0) {
        return Err(format!(
            "BehavioralHead: row weights must be finite and ≥ 0, got {v}"
        ));
    }

    match family {
        AuxOutcomeFamily::Binomial => {
            let offending = y
                .iter()
                .enumerate()
                .find(|&(_, &label)| !(label == 0.0 || label == 1.0));
            if let Some((i, &label)) = offending {
                return Err(format!(
                    "BehavioralHead(Binomial): label[{i}] = {label} is not 0/1"
                ));
            }
        }
        AuxOutcomeFamily::Multinomial { n_classes } => {
            if n_classes < 2 {
                return Err(format!(
                    "BehavioralHead(Multinomial): need ≥ 2 classes, got {n_classes}"
                ));
            }
            // The float must round-trip through `usize` unchanged (rejects
            // fractions, negatives, NaN, ±inf) and land inside the class range.
            let is_class_index = |label: f64| {
                let k = label as usize;
                k as f64 == label && k < n_classes
            };
            let offending = y
                .iter()
                .enumerate()
                .find(|&(_, &label)| !is_class_index(label));
            if let Some((i, &label)) = offending {
                return Err(format!(
                    "BehavioralHead(Multinomial): label[{i}] = {label} not an \
                     integer class index in 0..{n_classes}"
                ));
            }
        }
    }

    Ok(Self { family, y, w_row })
}
/// Builds a head in which every row carries weight `1.0`.
///
/// # Errors
/// Propagates any validation error from [`BehavioralHead::new`].
pub fn fully_supervised(family: AuxOutcomeFamily, y: Array1<f64>) -> Result<Self, String> {
    let unit_weights = Array1::from_elem(y.len(), 1.0);
    Self::new(family, y, unit_weights)
}
/// Builds a head whose row weights come from a row subsample measure.
///
/// Rows not named by the measure keep weight `0.0`. Indices `>= y.len()` are
/// ignored.
///
/// NOTE(review): this assumes `indices_and_weights` returns two parallel
/// sequences, i.e. the k-th weight belongs to the k-th index. The previous
/// code read `weights[idx]`, which under that layout picks the wrong weight
/// (or panics out of bounds) whenever the measure is a proper subsample.
/// Confirm against `RowSubsampleMask`.
///
/// # Errors
/// Propagates any validation error from [`BehavioralHead::new`].
pub fn with_row_measure(
    family: AuxOutcomeFamily,
    y: Array1<f64>,
    measure: &RowSubsampleMask,
) -> Result<Self, String> {
    let n = y.len();
    let (indices, weights) = measure.indices_and_weights(n);
    let mut w_row = Array1::<f64>::zeros(n);
    // Pair each sampled row index with its own weight by position.
    for (&idx, &w) in indices.iter().zip(weights.iter()) {
        if idx < n {
            w_row[idx] = w;
        }
    }
    Self::new(family, y, w_row)
}
/// Returns the outcome family this head was built with.
pub fn family(&self) -> AuxOutcomeFamily {
self.family
}
/// Number of rows, taken as the length of the stored label vector.
pub fn n_obs(&self) -> usize {
self.y.len()
}
/// Total coefficient count: each channel has one intercept plus one slope per
/// latent axis.
pub fn n_coeffs(&self, latent_dim: usize) -> usize {
    let per_channel = 1 + latent_dim;
    self.family.n_eta_channels() * per_channel
}
/// Sum of all row weights.
pub fn effective_labeled_count(&self) -> f64 {
    self.w_row.iter().copied().sum()
}
/// Evaluates the linear predictors for every row and channel.
///
/// `coeffs` is read as consecutive per-channel blocks of length
/// `1 + t.ncols()`: the intercept first, then one slope per latent axis.
/// Returns an `(n_rows, n_channels)` matrix. No length checks are done here;
/// out-of-range coefficient access panics.
fn eta(&self, t: ArrayView2<'_, f64>, coeffs: ArrayView1<'_, f64>) -> Array2<f64> {
    let (n_rows, latent_dim) = t.dim();
    let channels = self.family.n_eta_channels();
    let stride = 1 + latent_dim;
    let mut out = Array2::<f64>::zeros((n_rows, channels));
    for ch in 0..channels {
        let offset = ch * stride;
        let intercept = coeffs[offset];
        for row in 0..n_rows {
            // Start from the intercept and add slope terms in axis order.
            out[[row, ch]] = (0..latent_dim).fold(intercept, |acc, axis| {
                acc + t[[row, axis]] * coeffs[offset + 1 + axis]
            });
        }
    }
    out
}
/// Weighted negative log-likelihood of the labels and its gradients.
///
/// Returns `(nll, grad_coeffs, grad_t)`, where `grad_coeffs` has the same
/// layout as `coeffs` (per-channel blocks of intercept + slopes) and `grad_t`
/// has the shape of `t`. Rows whose weight is exactly `0.0` contribute
/// nothing and are skipped.
///
/// # Errors
/// Returns a message when `t` has a different number of rows than the stored
/// labels, or when `coeffs.len() != n_eta * (1 + t.ncols())`.
pub fn neg_loglik_and_grad(
    &self,
    t: ArrayView2<'_, f64>,
    coeffs: ArrayView1<'_, f64>,
) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
    let (n, d) = t.dim();
    let n_labels = self.y.len();
    if n != n_labels {
        return Err(format!(
            "BehavioralHead: latent rows {n} != labels {}",
            n_labels
        ));
    }
    let n_eta = self.family.n_eta_channels();
    let stride = 1 + d;
    let expected = n_eta * stride;
    if coeffs.len() != expected {
        return Err(format!(
            "BehavioralHead: coeffs length {} != n_eta·(1+d) = {}",
            coeffs.len(),
            expected
        ));
    }

    let eta = self.eta(t, coeffs);
    let mut nll = 0.0_f64;
    let mut grad_coeffs = Array1::<f64>::zeros(expected);
    let mut grad_t = Array2::<f64>::zeros((n, d));

    match self.family {
        AuxOutcomeFamily::Binomial => {
            for row in (0..n).filter(|&r| self.w_row[r] != 0.0) {
                let weight = self.w_row[row];
                let z = eta[[row, 0]];
                // log(1 + exp(z)), branching on the sign so exp never overflows.
                let softplus = if z > 0.0 {
                    z + (-z).exp().ln_1p()
                } else {
                    z.exp().ln_1p()
                };
                let target = self.y[row];
                nll += weight * (softplus - target * z);

                let prob = 1.0 / (1.0 + (-z).exp());
                let resid = weight * (prob - target);
                grad_coeffs[0] += resid;
                for axis in 0..d {
                    grad_coeffs[1 + axis] += resid * t[[row, axis]];
                    grad_t[[row, axis]] += resid * coeffs[1 + axis];
                }
            }
        }
        AuxOutcomeFamily::Multinomial { .. } => {
            for row in (0..n).filter(|&r| self.w_row[r] != 0.0) {
                let weight = self.w_row[row];

                // The reference class has predictor 0, so the shift starts at 0.
                let mut shift = 0.0_f64;
                for c in 0..n_eta {
                    let candidate = eta[[row, c]];
                    if candidate > shift {
                        shift = candidate;
                    }
                }
                // Shifted log-sum-exp over the reference class and all channels.
                let denom = (0..n_eta).fold((0.0 - shift).exp(), |acc, c| {
                    acc + (eta[[row, c]] - shift).exp()
                });
                let lse = shift + denom.ln();

                let class = self.y[row] as usize;
                let eta_observed = match class {
                    0 => 0.0,
                    k => eta[[row, k - 1]],
                };
                nll += weight * (lse - eta_observed);

                for c in 0..n_eta {
                    let prob = (eta[[row, c]] - lse).exp();
                    let hit = if class == c + 1 { 1.0 } else { 0.0 };
                    let resid = weight * (prob - hit);
                    let offset = c * stride;
                    grad_coeffs[offset] += resid;
                    for axis in 0..d {
                        grad_coeffs[offset + 1 + axis] += resid * t[[row, axis]];
                        grad_t[[row, axis]] += resid * coeffs[offset + 1 + axis];
                    }
                }
            }
        }
    }

    Ok((nll, grad_coeffs, grad_t))
}
/// Per-row, per-channel curvature weights `w · p · (1 − p)`.
///
/// Returns an `(n_rows, n_eta)` matrix. For `Binomial`, `p` is the logistic
/// probability of the single channel; for `Multinomial`, `p` is the softmax
/// probability of each non-reference class.
///
/// # Errors
/// Returns a message when `t` has a different number of rows than the stored
/// labels, or when `coeffs.len() != n_eta * (1 + t.ncols())`. The row check
/// mirrors `neg_loglik_and_grad`: without it, a `t` with extra rows would
/// panic on the weight lookup and a shorter `t` would be silently truncated.
pub fn head_working_weights(
    &self,
    t: ArrayView2<'_, f64>,
    coeffs: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let (n, d) = t.dim();
    if n != self.y.len() {
        return Err(format!(
            "BehavioralHead::head_working_weights: latent rows {n} != labels {}",
            self.y.len()
        ));
    }
    let n_eta = self.family.n_eta_channels();
    if coeffs.len() != n_eta * (1 + d) {
        return Err("BehavioralHead::head_working_weights: coeff length mismatch".to_string());
    }
    let eta = self.eta(t, coeffs);
    let mut s = Array2::<f64>::zeros((n, n_eta));
    match self.family {
        AuxOutcomeFamily::Binomial => {
            for row in 0..n {
                let p = 1.0 / (1.0 + (-eta[[row, 0]]).exp());
                s[[row, 0]] = self.w_row[row] * p * (1.0 - p);
            }
        }
        AuxOutcomeFamily::Multinomial { .. } => {
            for row in 0..n {
                // The reference class has predictor 0, so the shift starts at 0.
                let mut max_eta = 0.0_f64;
                for c in 0..n_eta {
                    if eta[[row, c]] > max_eta {
                        max_eta = eta[[row, c]];
                    }
                }
                // Shifted log-sum-exp over the reference class and all channels.
                let mut denom = (0.0 - max_eta).exp();
                for c in 0..n_eta {
                    denom += (eta[[row, c]] - max_eta).exp();
                }
                let lse = max_eta + denom.ln();
                for c in 0..n_eta {
                    let p_c = (eta[[row, c]] - lse).exp();
                    s[[row, c]] = self.w_row[row] * p_c * (1.0 - p_c);
                }
            }
        }
    }
    Ok(s)
}
}
/// Holds a basis of latent directions extracted from a score-influence matrix.
#[derive(Debug, Clone)]
pub struct LeakageAbsorber {
/// `(latent_dim, rank)` matrix whose columns are the leading left singular
/// vectors of the Gram matrix built in `from_score_influence`; it is `(0, 0)`
/// when `latent_dim` is zero.
q: Array2<f64>,
}
impl LeakageAbsorber {
/// Extracts the absorbed basis from a score-influence matrix.
///
/// The columns of `score_influence` are read as consecutive blocks of width
/// `latent_dim`. Every block of every row contributes its outer product to a
/// `latent_dim × latent_dim` Gram matrix. The basis keeps the left singular
/// vectors of that Gram matrix whose singular value exceeds
/// `max_sv · latent_dim · f64::EPSILON`.
///
/// With `latent_dim == 0` the result is an empty `(0, 0)` basis.
///
/// # Errors
/// Returns a message when the column count is not a multiple of `latent_dim`,
/// when the SVD fails, or when the SVD does not return `U`.
pub fn from_score_influence(
    score_influence: ArrayView2<'_, f64>,
    latent_dim: usize,
) -> Result<Self, String> {
    // Must come first: it also protects the modulo below from dividing by zero.
    if latent_dim == 0 {
        return Ok(Self {
            q: Array2::<f64>::zeros((0, 0)),
        });
    }

    let (n_rows, n_cols) = score_influence.dim();
    if n_cols % latent_dim != 0 {
        let cols = n_cols;
        return Err(format!(
            "LeakageAbsorber: score_influence has {cols} columns, not a multiple of \
             latent_dim {latent_dim}"
        ));
    }

    let channels = n_cols / latent_dim;
    let mut gram = Array2::<f64>::zeros((latent_dim, latent_dim));
    for row in 0..n_rows {
        for offset in (0..channels).map(|c| c * latent_dim) {
            for i in 0..latent_dim {
                let lhs = score_influence[[row, offset + i]];
                for j in 0..latent_dim {
                    gram[[i, j]] += lhs * score_influence[[row, offset + j]];
                }
            }
        }
    }

    let (u_opt, sv, _) = gram
        .svd(true, false)
        .map_err(|e| format!("LeakageAbsorber: SVD of label-channel Gram failed: {e}"))?;
    let u = u_opt.ok_or_else(|| "LeakageAbsorber: SVD did not return U".to_string())?;

    let largest = sv.iter().copied().fold(0.0_f64, f64::max);
    let cutoff = largest * (latent_dim as f64) * f64::EPSILON;
    let rank = sv.iter().filter(|&&s| s > cutoff).count();

    // Keep the first `rank` columns of U.
    let q = Array2::from_shape_fn((latent_dim, rank), |(r, col)| u[[r, col]]);
    Ok(Self { q })
}
/// Number of columns in the stored basis.
pub fn rank(&self) -> usize {
self.q.ncols()
}
/// Borrowed view of the stored basis matrix.
pub fn basis(&self) -> ArrayView2<'_, f64> {
self.q.view()
}
/// Removes from each row of `delta_t` its component in the stored basis:
/// `delta_t − (delta_t · Q) · Qᵀ`.
///
/// When the basis has no columns, or its row count differs from
/// `delta_t.ncols()`, an unmodified copy of `delta_t` is returned.
pub fn orthogonalize_recon_update(&self, delta_t: ArrayView2<'_, f64>) -> Array2<f64> {
    let mut out = delta_t.to_owned();
    let applicable = self.q.ncols() != 0 && self.q.nrows() == delta_t.ncols();
    if applicable {
        let coords = delta_t.dot(&self.q);
        let projection = coords.dot(&self.q.t());
        out -= &projection;
    }
    out
}
/// Coordinates of each row of `t` in the stored basis (`t · Q`).
///
/// Returns an `(t.nrows(), 0)` matrix when the basis has no columns.
pub fn absorbed_design_block(&self, t: ArrayView2<'_, f64>) -> Array2<f64> {
    match self.q.ncols() {
        0 => Array2::<f64>::zeros((t.nrows(), 0)),
        _ => t.dot(&self.q),
    }
}
}
/// Per-latent-axis significance summary produced by
/// `head_feature_significance`.
#[derive(Debug, Clone)]
pub struct HeadFeatureSignificance {
/// One entry per latent axis: the test statistic of the channel with the
/// smallest p-value (0.0 when no channel beat the initial p of 1.0).
pub statistic: Vec<f64>,
/// One entry per latent axis: the smallest channel p-value multiplied by the
/// number of channels tested, capped at 1.0.
pub p_value: Vec<f64>,
/// Output of `e_benjamini_hochberg` on the per-axis log e-values;
/// presumably the indices of rejected axes — confirm against that helper.
pub fdr_rejected: Vec<usize>,
/// The `alpha` level passed to `head_feature_significance`.
pub alpha: f64,
}
/// Tests each latent axis of a fitted head for a non-zero slope.
///
/// For every axis, `wood_smooth_test` is called once per channel on that
/// axis's single slope coefficient. The smallest channel p-value is multiplied
/// by the number of channels that returned a result and capped at 1. The
/// adjusted p-value is floored at `f64::MIN_POSITIVE`, converted with
/// `log_e_from_p_calibrator`, and the per-axis values are passed to
/// `e_benjamini_hochberg` at level `alpha`.
///
/// `coeffs` uses per-channel blocks of length `1 + latent_dim` (intercept
/// first).
///
/// # Errors
/// Returns a message when `coeffs.len() != n_eta * (1 + latent_dim)`, when
/// `covariance` is not square with that size, or when the calibrator rejects
/// an axis p-value.
pub fn head_feature_significance(
    coeffs: ArrayView1<'_, f64>,
    covariance: &Array2<f64>,
    latent_dim: usize,
    n_eta: usize,
    residual_df: f64,
    alpha: f64,
) -> Result<HeadFeatureSignificance, String> {
    let block = 1 + latent_dim;
    let n_coeffs = coeffs.len();
    if n_coeffs != n_eta * block {
        return Err(format!(
            "head_feature_significance: coeffs length {} != n_eta·(1+d) = {}",
            n_coeffs,
            n_eta * block
        ));
    }
    if covariance.dim() != (n_coeffs, n_coeffs) {
        return Err(format!(
            "head_feature_significance: covariance must be {0}×{0}, got {1}×{2}",
            n_coeffs,
            covariance.nrows(),
            covariance.ncols()
        ));
    }

    let beta = coeffs.to_owned();
    let mut statistic = Vec::with_capacity(latent_dim);
    let mut p_value = Vec::with_capacity(latent_dim);
    let mut log_e_values = Vec::with_capacity(latent_dim);

    for axis in 0..latent_dim {
        // (p-value, statistic) of the strongest channel seen so far.
        let mut leader = (1.0_f64, 0.0_f64);
        let mut tested = 0_usize;
        for channel in 0..n_eta {
            let idx = channel * block + 1 + axis;
            let outcome = wood_smooth_test(SmoothTestInput {
                beta: beta.view(),
                covariance,
                influence_matrix: None,
                coeff_range: idx..idx + 1,
                edf: 1.0,
                nullspace_dim: 1,
                residual_df,
                scale: SmoothTestScale::Estimated,
            });
            if let Some(res) = outcome {
                tested += 1;
                if res.p_value < leader.0 {
                    leader = (res.p_value, res.statistic);
                }
            }
        }

        // Scale by the number of channels actually tested, capped at 1.
        let axis_p = match tested {
            0 => 1.0,
            k => (leader.0 * k as f64).min(1.0),
        };
        statistic.push(leader.1);
        p_value.push(axis_p);

        // Keep the calibrator input strictly positive.
        let floored_p = axis_p.max(f64::MIN_POSITIVE);
        let log_e = log_e_from_p_calibrator(floored_p).map_err(|err| {
            format!("head_feature_significance: invalid calibrated axis p-value: {err}")
        })?;
        log_e_values.push(log_e);
    }

    let fdr_rejected = e_benjamini_hochberg(&log_e_values, alpha);
    Ok(HeadFeatureSignificance {
        statistic,
        p_value,
        fdr_rejected,
        alpha,
    })
}
/// Basis for the row space of `jacobian`, taken from the SVD of `JᵀJ`.
///
/// Returns a `(d, rank)` matrix holding the left singular vectors of the
/// `d × d` Gram matrix whose singular value exceeds
/// `max_sv · d · f64::EPSILON`. An input with zero rows or zero columns
/// yields a `(d, 0)` matrix.
///
/// # Errors
/// Returns a message when the SVD fails or does not return `U`.
pub fn orthonormal_span(jacobian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
    let (n, d) = jacobian.dim();
    if n == 0 || d == 0 {
        return Ok(Array2::<f64>::zeros((d, 0)));
    }

    let gram = jacobian.t().dot(&jacobian);
    let (u_opt, sv, _) = gram
        .svd(true, false)
        .map_err(|e| format!("orthonormal_span: SVD failed: {e}"))?;
    let u = u_opt.ok_or_else(|| "orthonormal_span: SVD did not return U".to_string())?;

    let largest = sv.iter().copied().fold(0.0_f64, f64::max);
    let cutoff = largest * (d as f64) * f64::EPSILON;
    let rank = sv.iter().filter(|&&s| s > cutoff).count();

    // Keep the first `rank` columns of U.
    Ok(Array2::from_shape_fn((d, rank), |(r, col)| u[[r, col]]))
}