use std::sync::Arc;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use crate::inference::row_metric::RowMetric;
use crate::linalg::faer_ndarray::{FaerCholesky, FaerEigh};
use faer::Side;
/// Number of alternating update sweeps run by `StructuredResidualModel::fit_fixed_rank`.
const ALTERNATION_SWEEPS: usize = 8;
/// Number of equal-width activity bins over which the row scale is estimated; it is
/// also counted as a parameter block in the evidence penalty.
const ACTIVITY_SCALE_BINS: usize = 8;
/// Lower bound on every diagonal variance, expressed as a fraction of the mean squared
/// residual entry.
const DIAGONAL_REL_FLOOR: f64 = 1e-6;
/// Lower bound on the per-bin scale, expressed as a fraction of the pooled (all-bin)
/// scale estimate.
const SCALE_REL_FLOOR: f64 = 1e-4;
/// Fitted residual covariance model: a low-rank factor, per-column diagonal variances
/// and a per-row scale, together with the evidence score of the fit.
///
/// As used by `row_precision` and `penalized_log_evidence`, row `i` is modelled with
/// covariance `diag(diagonal) + row_scale[i] * lambda * lambdaᵀ`.
#[derive(Clone, Debug)]
pub struct StructuredResidualModel {
/// Number of residual columns.
p: usize,
/// Number of factor columns in `lambda`.
factor_rank: usize,
/// Factor loadings, shape `(p, factor_rank)`.
lambda: Array2<f64>,
/// Per-column diagonal variances, length `p`; floored to stay strictly positive.
diagonal: Array1<f64>,
/// Per-row multiplier applied to the factor term, one entry per fitted residual row.
row_scale: Array1<f64>,
/// Penalized log-evidence of this fit; larger is better in the rank ladder.
log_evidence: f64,
}
/// Borrowed inputs for `StructuredResidualModel::fit`.
pub struct ResidualFactorInput<'a> {
/// Residual matrix with one observation per row; must be non-empty and finite.
pub residuals: ArrayView2<'a, f64>,
/// One finite activity value per residual row; its range is split into bins that
/// share a row scale.
pub activity: ArrayView1<'a, f64>,
/// Largest factor rank to try; `fit` additionally caps it at `ncols - 1`.
pub max_factor_rank: usize,
}
impl StructuredResidualModel {
/// Fits the structured residual model, choosing the factor rank by penalized evidence.
///
/// Rows are first assigned to `ACTIVITY_SCALE_BINS` equal-width bins over the observed
/// range of `input.activity` (a constant activity puts every row in bin 0). One model is
/// then fitted for each rank in `0..=min(max_factor_rank, p - 1)` and the one with the
/// largest `log_evidence` is returned; on ties the lower rank is kept.
///
/// # Errors
///
/// Returns an error when the residual matrix is empty, when `activity` does not have one
/// entry per residual row, when either input contains a non-finite value, or when any
/// fixed-rank fit fails.
pub fn fit(input: ResidualFactorInput<'_>) -> Result<Self, String> {
    let r = input.residuals;
    let z = input.activity;
    let n = r.nrows();
    let p = r.ncols();
    if n == 0 || p == 0 {
        return Err(format!(
            "StructuredResidualModel::fit: residuals must be non-empty; got ({n}, {p})"
        ));
    }
    if z.len() != n {
        return Err(format!(
            "StructuredResidualModel::fit: activity length {} != residual rows {n}",
            z.len()
        ));
    }
    if !r.iter().all(|v| v.is_finite()) {
        return Err("StructuredResidualModel::fit: residuals must be finite".to_string());
    }
    if !z.iter().all(|v| v.is_finite()) {
        return Err("StructuredResidualModel::fit: activity must be finite".to_string());
    }

    // Equal-width binning of the activity covariate.
    let bins = ACTIVITY_SCALE_BINS.max(1);
    let z_min = z.iter().copied().fold(f64::INFINITY, f64::min);
    let z_max = z.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let z_span = z_max - z_min;
    let row_bin: Vec<usize> = z
        .iter()
        .map(|&zi| {
            if z_span <= 0.0 {
                0
            } else {
                let frac = (zi - z_min) / z_span;
                let idx = (frac * bins as f64).floor() as isize;
                // `frac == 1.0` lands on `bins`; clamp it into the last bin.
                idx.clamp(0, bins as isize - 1) as usize
            }
        })
        .collect();

    // Evidence ladder over candidate ranks.
    let max_rank = input.max_factor_rank.min(p.saturating_sub(1));
    let mut best: Option<StructuredResidualModel> = None;
    for rank in 0..=max_rank {
        let candidate = Self::fit_fixed_rank(r, &row_bin, bins, rank)?;
        let improves = match &best {
            None => true,
            Some(incumbent) => {
                // Every ordered comparison against NaN is false, so a NaN incumbent
                // would otherwise win the ladder permanently; let any non-NaN
                // candidate displace it.
                candidate.log_evidence > incumbent.log_evidence
                    || (incumbent.log_evidence.is_nan() && !candidate.log_evidence.is_nan())
            }
        };
        if improves {
            best = Some(candidate);
        }
    }
    best.ok_or_else(|| "StructuredResidualModel::fit: evidence ladder empty".to_string())
}
/// Fits the model at one fixed factor rank by `ALTERNATION_SWEEPS` alternating updates.
///
/// Each sweep:
/// 1. eigendecomposes the row-scale-weighted second moment of `r` and sets factor `k`
///    to the `k`-th leading eigenvector scaled by `sqrt(max(eigenvalue - mean_diag, 0))`;
/// 2. resets each diagonal entry to the raw column second moment minus the factor
///    contribution, floored at `DIAGONAL_REL_FLOOR` times the mean squared residual;
/// 3. (rank > 0 only) re-estimates a per-bin scale from the mean squared factor
///    coordinate of the rows in each bin, smooths it over neighbouring bins, normalises
///    it to mean one across bins and copies it to the rows.
///
/// `row_bin[i]` must be the bin index (`< bins`) of row `i`.
///
/// If the resulting parameters are not all finite, the factor is zeroed, the row scale
/// reset to one, and the evidence recomputed for those fallback parameters (or set to
/// negative infinity if it is still not finite) so the stored score always describes
/// the stored parameters.
///
/// # Errors
///
/// Returns an error when `rank` exceeds the number of residual columns, or when the
/// eigendecomposition or the factor-coordinate solve fails.
fn fit_fixed_rank(
    r: ArrayView2<'_, f64>,
    row_bin: &[usize],
    bins: usize,
    rank: usize,
) -> Result<Self, String> {
    let n = r.nrows();
    let p = r.ncols();
    // `p - 1 - k` below would underflow for k >= p; there are only p eigenpairs.
    if rank > p {
        return Err(format!(
            "StructuredResidualModel::fit_fixed_rank: rank {rank} exceeds residual columns {p}"
        ));
    }

    let mut total_var = 0.0_f64;
    for i in 0..n {
        for j in 0..p {
            total_var += r[[i, j]] * r[[i, j]];
        }
    }
    let mean_var = (total_var / (n as f64 * p as f64)).max(f64::MIN_POSITIVE);
    let diag_floor = DIAGONAL_REL_FLOOR * mean_var;

    let mut row_scale = Array1::<f64>::ones(n);
    let mut bin_scale = Array1::<f64>::ones(bins);
    let raw_diag = column_variances(r);
    let mut diagonal = raw_diag.mapv(|v| v.max(diag_floor));
    let mut lambda = Array2::<f64>::zeros((p, rank));

    for _sweep in 0..ALTERNATION_SWEEPS {
        // --- Factor update -------------------------------------------------------
        let s = scaled_second_moment(r, &row_scale);
        let (evals, evecs) = symmetric_eig_ascending(&s)?;
        // The noise level is the same for every factor in this sweep: `diagonal` is
        // only rewritten after all factor columns have been set.
        let mean_diag = diagonal.iter().copied().sum::<f64>() / p as f64;
        for k in 0..rank {
            // Eigenvalues are consumed from the end, i.e. largest first.
            let col = p - 1 - k;
            let amp = (evals[col] - mean_diag).max(0.0).sqrt();
            for row in 0..p {
                lambda[[row, k]] = amp * evecs[[row, col]];
            }
        }

        // --- Diagonal update -----------------------------------------------------
        for j in 0..p {
            let mut factor_var = 0.0_f64;
            for k in 0..rank {
                factor_var += lambda[[j, k]] * lambda[[j, k]];
            }
            diagonal[j] = (raw_diag[j] - factor_var).max(diag_floor);
        }

        // --- Row-scale update ----------------------------------------------------
        if rank > 0 {
            let mut bin_num = Array1::<f64>::zeros(bins);
            let mut bin_den = Array1::<f64>::zeros(bins);
            let coords = factor_coordinates(&lambda, &diagonal, r)?;
            for i in 0..n {
                let mut energy = 0.0_f64;
                for k in 0..rank {
                    energy += coords[[i, k]] * coords[[i, k]];
                }
                let b = row_bin[i];
                bin_num[b] += energy;
                bin_den[b] += rank as f64;
            }
            let global = {
                let num: f64 = bin_num.iter().sum();
                let den: f64 = bin_den.iter().sum();
                if den > 0.0 { num / den } else { 1.0 }
            };
            // Empty bins borrow the pooled estimate.
            for b in 0..bins {
                bin_scale[b] = if bin_den[b] > 0.0 {
                    bin_num[b] / bin_den[b]
                } else {
                    global
                };
            }
            // This floor is in the raw units of the bin energies, so it is only
            // applied before normalisation.
            let scale_floor = SCALE_REL_FLOOR * global.max(f64::MIN_POSITIVE);
            let smoothed = moving_average_3(&bin_scale);
            for b in 0..bins {
                bin_scale[b] = smoothed[b].max(scale_floor);
            }
            let mean_scale = bin_scale.iter().copied().sum::<f64>() / bins as f64;
            if mean_scale > 0.0 {
                bin_scale.mapv_inplace(|v| v / mean_scale);
            }
            // After normalisation the scales are relative (mean one across bins), so the
            // guard must be the relative floor. Re-using the raw-unit `scale_floor` here
            // would clamp every row to the same value whenever the pooled energy is
            // large, erasing the fitted activity profile.
            for i in 0..n {
                row_scale[i] = bin_scale[row_bin[i]].max(SCALE_REL_FLOOR);
            }
        }
    }

    let log_evidence = penalized_log_evidence(r, &lambda, &diagonal, &row_scale, rank);
    let mut model = Self {
        p,
        factor_rank: rank,
        lambda,
        diagonal,
        row_scale,
        log_evidence,
    };
    if !model.is_finite() {
        // Degrade to a diagonal-only model and re-score it, so the stored evidence
        // describes the stored parameters rather than the discarded ones.
        model.lambda = Array2::<f64>::zeros((p, rank));
        model.row_scale = Array1::<f64>::ones(n);
        let fallback_evidence =
            penalized_log_evidence(r, &model.lambda, &model.diagonal, &model.row_scale, rank);
        model.log_evidence = if fallback_evidence.is_finite() {
            fallback_evidence
        } else {
            // Never NaN: keeps the rank ladder's `>` comparison well defined and makes
            // this candidate lose against any finite one.
            f64::NEG_INFINITY
        };
    }
    Ok(model)
}
/// Reports whether every stored parameter is usable: the loadings and the evidence
/// must be finite, and every diagonal variance and row scale must be finite and
/// strictly positive.
fn is_finite(&self) -> bool {
    if !self.log_evidence.is_finite() {
        return false;
    }
    if self.lambda.iter().any(|value| !value.is_finite()) {
        return false;
    }
    let finite_and_positive = |value: &f64| value.is_finite() && *value > 0.0;
    self.diagonal.iter().all(finite_and_positive)
        && self.row_scale.iter().all(finite_and_positive)
}
/// Returns the factor rank this model was fitted with.
pub fn factor_rank(&self) -> usize {
self.factor_rank
}
/// Returns a view of the factor loadings, shape `(p, factor_rank)`.
pub fn factor(&self) -> ArrayView2<'_, f64> {
self.lambda.view()
}
/// Returns a view of the per-column diagonal variances.
pub fn diagonal(&self) -> ArrayView1<'_, f64> {
self.diagonal.view()
}
/// Returns a view of the per-row scale, one entry per fitted residual row.
pub fn row_scale(&self) -> ArrayView1<'_, f64> {
self.row_scale.view()
}
/// Returns the penalized log-evidence recorded for this fit.
pub fn log_evidence(&self) -> f64 {
self.log_evidence
}
/// Builds the per-row whitening metric for the fitted rows.
///
/// For every row the `p x p` precision matrix is factorised and its lower Cholesky
/// factor is stored row-major in one row of an `(n_rows, p * p)` matrix, which is then
/// handed to `RowMetric::whitened_structured`.
///
/// # Errors
///
/// Returns an error when `n_rows` differs from the number of rows the model was fitted
/// on, or when a row's precision or its factorisation cannot be computed.
pub fn row_metric(&self, n_rows: usize) -> Result<RowMetric, String> {
    let fitted_rows = self.row_scale.len();
    if n_rows != fitted_rows {
        return Err(format!(
            "StructuredResidualModel::row_metric: requested {n_rows} rows but model has {}",
            fitted_rows
        ));
    }
    let dim = self.p;
    let mut packed = Array2::<f64>::zeros((n_rows, dim * dim));
    for row in 0..n_rows {
        let row_precision = self.row_precision(row)?;
        let chol_factor = lower_cholesky_psd(&row_precision)?;
        for i in 0..dim {
            let offset = i * dim;
            for k in 0..dim {
                packed[[row, offset + k]] = chol_factor[[i, k]];
            }
        }
    }
    RowMetric::whitened_structured(Arc::new(packed), dim, dim)
}
/// Computes the precision (inverse covariance) matrix of one row.
///
/// With `D = diag(diagonal)`, `L = lambda` and `c` the row's scale, this evaluates the
/// Woodbury form `D⁻¹ - D⁻¹ L (I / c + Lᵀ D⁻¹ L)⁻¹ Lᵀ D⁻¹`, i.e. the inverse of
/// `D + c L Lᵀ`, and symmetrises the result. For rank 0 it is simply `D⁻¹`.
///
/// # Errors
///
/// Returns an error when the capacitance matrix `I / c + Lᵀ D⁻¹ L` cannot be
/// Cholesky-factorised.
fn row_precision(&self, row: usize) -> Result<Array2<f64>, String> {
    let dim = self.p;
    let rank = self.factor_rank;
    let scale = self.row_scale[row].max(f64::MIN_POSITIVE);
    let inv_diag: Vec<f64> = (0..dim).map(|i| 1.0 / self.diagonal[i]).collect();

    // Start from D⁻¹.
    let mut precision = Array2::<f64>::zeros((dim, dim));
    for (i, &weight) in inv_diag.iter().enumerate() {
        precision[[i, i]] = weight;
    }
    if rank == 0 {
        return Ok(precision);
    }

    // scaled_loadings = D⁻¹ L
    let mut scaled_loadings = Array2::<f64>::zeros((dim, rank));
    for i in 0..dim {
        for k in 0..rank {
            scaled_loadings[[i, k]] = inv_diag[i] * self.lambda[[i, k]];
        }
    }

    // capacitance = Lᵀ D⁻¹ L + I / c
    let mut capacitance = Array2::<f64>::zeros((rank, rank));
    for a in 0..rank {
        for b in 0..rank {
            capacitance[[a, b]] = (0..dim).fold(0.0_f64, |acc, i| {
                acc + self.lambda[[i, a]] * scaled_loadings[[i, b]]
            });
        }
        capacitance[[a, a]] += 1.0 / scale;
    }
    let chol = capacitance
        .cholesky(Side::Lower)
        .map_err(|e| format!("StructuredResidualModel::row_precision capacitance: {e:?}"))?;

    // Solve capacitance · X = (D⁻¹ L)ᵀ for X.
    let mut scaled_loadings_t = Array2::<f64>::zeros((rank, dim));
    for k in 0..rank {
        for i in 0..dim {
            scaled_loadings_t[[k, i]] = scaled_loadings[[i, k]];
        }
    }
    let solved = chol.solve_mat(&scaled_loadings_t);

    // precision = D⁻¹ - (D⁻¹ L) X
    for i in 0..dim {
        for j in 0..dim {
            let correction =
                (0..rank).fold(0.0_f64, |acc, k| acc + scaled_loadings[[i, k]] * solved[[k, j]]);
            precision[[i, j]] -= correction;
        }
    }

    // Remove rounding asymmetry by averaging mirrored entries.
    for upper_row in 0..dim {
        for upper_col in (upper_row + 1)..dim {
            let mean = 0.5
                * (precision[[upper_row, upper_col]] + precision[[upper_col, upper_row]]);
            precision[[upper_row, upper_col]] = mean;
            precision[[upper_col, upper_row]] = mean;
        }
    }
    Ok(precision)
}
}
/// Returns the uncentred second moment of every column of `r`: the mean of the squared
/// entries, with no mean subtraction.
fn column_variances(r: ArrayView2<'_, f64>) -> Array1<f64> {
    let rows = r.nrows();
    let count = rows as f64;
    Array1::from_shape_fn(r.ncols(), |j| {
        let sum_of_squares = (0..rows).fold(0.0_f64, |acc, i| acc + r[[i, j]] * r[[i, j]]);
        sum_of_squares / count
    })
}
/// Returns the `p x p` second-moment matrix of the rows of `r`, with row `i` weighted by
/// `1 / row_scale[i]`, averaged over the number of rows and symmetrised.
fn scaled_second_moment(r: ArrayView2<'_, f64>, row_scale: &Array1<f64>) -> Array2<f64> {
    let rows = r.nrows();
    let cols = r.ncols();
    let mut moment = Array2::<f64>::zeros((cols, cols));

    for i in 0..rows {
        // Guard the reciprocal against a zero scale.
        let weight = 1.0 / row_scale[i].max(f64::MIN_POSITIVE);
        let sample = r.row(i);
        for (a, &left) in sample.iter().enumerate() {
            let weighted_left = weight * left;
            for (b, &right) in sample.iter().enumerate() {
                moment[[a, b]] += weighted_left * right;
            }
        }
    }

    let count = rows as f64;
    for entry in moment.iter_mut() {
        *entry /= count;
    }

    // (w·x)·y and (w·y)·x can round differently, so average the mirrored entries.
    for upper_row in 0..cols {
        for upper_col in (upper_row + 1)..cols {
            let mean = 0.5 * (moment[[upper_row, upper_col]] + moment[[upper_col, upper_row]]);
            moment[[upper_row, upper_col]] = mean;
            moment[[upper_col, upper_row]] = mean;
        }
    }
    moment
}
/// Computes weighted least-squares factor coordinates for every row of `r`.
///
/// With `L = lambda` and `D = diag(diagonal)`, row `i` receives the solution of
/// `(Lᵀ D⁻¹ L + ridge · I) g = Lᵀ D⁻¹ rᵢ`, where the ridge is `1e-10` times the larger
/// of the normal-matrix trace and 1, divided by the rank.
///
/// Returns an `(n, rank)` matrix of coordinates.
///
/// # Errors
///
/// Returns an error when the ridge-regularised normal matrix cannot be
/// Cholesky-factorised.
fn factor_coordinates(
    lambda: &Array2<f64>,
    diagonal: &Array1<f64>,
    r: ArrayView2<'_, f64>,
) -> Result<Array2<f64>, String> {
    let p = lambda.nrows();
    let rank = lambda.ncols();
    let n = r.nrows();

    // weighted[[j, a]] = lambda[[j, a]] / diagonal[j]; shared by the normal matrix and
    // by every right-hand side.
    let mut weighted = Array2::<f64>::zeros((p, rank));
    for j in 0..p {
        let inv_var = 1.0 / diagonal[j];
        for a in 0..rank {
            weighted[[j, a]] = lambda[[j, a]] * inv_var;
        }
    }

    let mut normal = Array2::<f64>::zeros((rank, rank));
    for a in 0..rank {
        for b in 0..rank {
            normal[[a, b]] =
                (0..p).fold(0.0_f64, |acc, j| acc + weighted[[j, a]] * lambda[[j, b]]);
        }
    }

    // A small ridge keeps the factorisation defined when a loading column is zero.
    let trace = (0..rank).map(|k| normal[[k, k]]).sum::<f64>().max(1.0);
    let ridge = 1e-10 * trace / rank.max(1) as f64;
    for k in 0..rank {
        normal[[k, k]] += ridge;
    }
    let chol = normal
        .cholesky(Side::Lower)
        .map_err(|e| format!("factor_coordinates normal solve: {e:?}"))?;

    let mut coords = Array2::<f64>::zeros((n, rank));
    let mut rhs = Array1::<f64>::zeros(rank);
    for i in 0..n {
        for a in 0..rank {
            rhs[a] = (0..p).fold(0.0_f64, |acc, j| acc + weighted[[j, a]] * r[[i, j]]);
        }
        let solution = chol.solvevec(&rhs);
        for a in 0..rank {
            coords[[i, a]] = solution[a];
        }
    }
    Ok(coords)
}
/// Smooths `v` with a centred three-point moving average.
///
/// The window is truncated at both ends, so the first and last outputs average two
/// entries (or one, for a single-element input). An empty input yields an empty output.
fn moving_average_3(v: &Array1<f64>) -> Array1<f64> {
    let len = v.len();
    Array1::from_shape_fn(len, |center| {
        let first = center.saturating_sub(1);
        let last = (center + 1).min(len - 1);
        let window_sum = (first..=last).fold(0.0_f64, |acc, j| acc + v[j]);
        let window_len = (last - first + 1) as f64;
        window_sum / window_len
    })
}
/// Eigendecomposes the symmetric matrix `m`, reading its lower triangle, and returns
/// `(eigenvalues, eigenvectors)` with any failure rendered as a `String`.
///
/// NOTE(review): callers in this file take the largest eigenvalue from the last
/// position, i.e. they rely on `eigh` returning eigenvalues in ascending order —
/// confirm against the `FaerEigh` implementation.
fn symmetric_eig_ascending(m: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>), String> {
    match m.eigh(Side::Lower) {
        Ok(decomposition) => Ok(decomposition),
        Err(e) => Err(format!("symmetric_eig: {e:?}")),
    }
}
/// Returns a lower-triangular factor `L` with `L Lᵀ ≈ a` for a symmetric matrix that is
/// positive definite up to rounding.
///
/// The plain Cholesky factorisation is tried first. If it fails, `a` is repaired by
/// clipping its eigenvalues from below at `1e-10 * max(largest eigenvalue, 1)`, the
/// repaired matrix is rebuilt from its eigenpairs, and that matrix is factorised.
///
/// The fallback factorises the repaired matrix itself, not its matrix square root, so
/// both paths satisfy the same `L Lᵀ` contract.
///
/// # Errors
///
/// Returns an error when the eigendecomposition fails or when even the repaired matrix
/// cannot be Cholesky-factorised.
fn lower_cholesky_psd(a: &Array2<f64>) -> Result<Array2<f64>, String> {
    if let Ok(chol) = a.cholesky(Side::Lower) {
        return Ok(chol.lower_triangular());
    }

    let (evals, evecs) = symmetric_eig_ascending(a)?;
    let max_ev = evals.iter().copied().fold(0.0_f64, f64::max).max(1.0);
    let floor = 1e-10 * max_ev;
    let p = a.nrows();

    // repaired = V · diag(max(ev, floor)) · Vᵀ. Only the upper triangle is accumulated
    // and then mirrored, so the result is exactly symmetric.
    let mut repaired = Array2::<f64>::zeros((p, p));
    for i in 0..p {
        for j in i..p {
            let mut acc = 0.0_f64;
            for k in 0..p {
                let ev = evals[k].max(floor);
                acc += evecs[[i, k]] * ev * evecs[[j, k]];
            }
            repaired[[i, j]] = acc;
            repaired[[j, i]] = acc;
        }
    }

    repaired
        .cholesky(Side::Lower)
        .map(|c| c.lower_triangular())
        .map_err(|e| format!("lower_cholesky_psd eigen-repair: {e:?}"))
}
/// Scores a fitted model: Gaussian log-likelihood of the rows of `r` minus a BIC-style
/// complexity penalty.
///
/// With `D = diag(diagonal)`, `L = lambda` and `cᵢ = row_scale[i]`, row `i` is scored
/// under a zero-mean Gaussian with covariance `D + cᵢ L Lᵀ`. The determinant and the
/// quadratic form are evaluated through the `rank x rank` matrix
/// `M = I / cᵢ + Lᵀ D⁻¹ L`:
/// `log det = log det D + log det M + rank · ln cᵢ` and
/// `quad = rᵀ D⁻¹ r - wᵀ M⁻¹ w` with `w = Lᵀ D⁻¹ r`.
///
/// The penalty is `0.5 · k · ln(max(n, 2))` with
/// `k = p · rank + p + ACTIVITY_SCALE_BINS`.
///
/// If `M` cannot be factorised for a row, that row is scored with the diagonal part
/// only (best effort; no error is raised).
fn penalized_log_evidence(
    r: ArrayView2<'_, f64>,
    lambda: &Array2<f64>,
    diagonal: &Array1<f64>,
    row_scale: &Array1<f64>,
    rank: usize,
) -> f64 {
    let n = r.nrows();
    let p = r.ncols();
    let d_inv: Vec<f64> = (0..p).map(|i| 1.0 / diagonal[i]).collect();
    let log_det_d: f64 = diagonal.iter().map(|&d| d.ln()).sum();
    let two_pi_ln = (2.0 * std::f64::consts::PI).ln();

    // Row-independent pieces, computed once instead of once per row:
    // weighted = D⁻¹ L (stored as lambda[[j, a]] * d_inv[j]) and gram = Lᵀ D⁻¹ L.
    let mut weighted = Array2::<f64>::zeros((p, rank));
    for j in 0..p {
        for a in 0..rank {
            weighted[[j, a]] = lambda[[j, a]] * d_inv[j];
        }
    }
    let mut gram = Array2::<f64>::zeros((rank, rank));
    for a in 0..rank {
        for b in 0..rank {
            let mut acc = 0.0_f64;
            for j in 0..p {
                acc += weighted[[j, a]] * lambda[[j, b]];
            }
            gram[[a, b]] = acc;
        }
    }

    let mut log_lik = 0.0_f64;
    let mut w = Array1::<f64>::zeros(rank);
    for i in 0..n {
        let c = row_scale[i].max(f64::MIN_POSITIVE);

        // Diagonal part of the quadratic form.
        let mut quad = 0.0_f64;
        for j in 0..p {
            quad += r[[i, j]] * d_inv[j] * r[[i, j]];
        }

        let mut log_det = log_det_d;
        if rank > 0 {
            // Only the diagonal of M depends on the row.
            let mut m = gram.clone();
            for a in 0..rank {
                m[[a, a]] += 1.0 / c;
            }
            for a in 0..rank {
                let mut wa = 0.0_f64;
                for j in 0..p {
                    wa += weighted[[j, a]] * r[[i, j]];
                }
                w[a] = wa;
            }
            // On factorisation failure the row keeps the diagonal-only terms.
            if let Ok(chol) = m.cholesky(Side::Lower) {
                let y = chol.solvevec(&w);
                let mut wy = 0.0_f64;
                for a in 0..rank {
                    wy += w[a] * y[a];
                }
                quad -= wy;
                let diag = chol.diag();
                // det M is the product of the squared Cholesky diagonal.
                let log_det_m: f64 = diag.iter().map(|&l| (l * l).ln()).sum();
                log_det = log_det_d + log_det_m + rank as f64 * c.ln();
            }
        }
        log_lik += -0.5 * (log_det + quad + p as f64 * two_pi_ln);
    }

    let k_params = (p * rank + p + ACTIVITY_SCALE_BINS) as f64;
    log_lik - 0.5 * k_params * (n.max(2) as f64).ln()
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::{Array1, Array2};
/// Advances a 64-bit linear congruential generator in place and returns a uniform
/// draw in `[0, 1)` built from the top 53 bits of the new state.
fn lcg_uniform(state: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;
    let top_bits = next >> 11;
    top_bits as f64 / (1u64 << 53) as f64
}
/// Draws one standard-normal sample by the Box–Muller transform, consuming two uniform
/// draws from `state`. The first draw is floored at `1e-12` so its logarithm is finite.
fn lcg_normal(state: &mut u64) -> f64 {
    let radial_draw = lcg_uniform(state).max(1e-12);
    let angular_draw = lcg_uniform(state);
    let radius = (-2.0 * radial_draw.ln()).sqrt();
    let angle = std::f64::consts::TAU * angular_draw;
    radius * angle.cos()
}
/// Simulates rank-1 factor residuals whose factor variance grows as `exp(slope * z)`
/// with the activity `z`, fits ranks 0, 1 and 2 directly with `fit_fixed_rank`, and
/// checks that rank 1 has the strictly largest penalized evidence.
#[test]
fn evidence_ladder_prefers_planted_rank_one() {
let n = 5000usize;
let p = 4usize;
let lambda0 = ndarray::array![[1.5], [1.2], [-0.4], [0.3]];
let sigma_eps = 0.2_f64;
let slope = 1.3_f64;
let mut seed = 0xD1B54A32D192ED03_u64;
let mut residuals = Array2::<f64>::zeros((n, p));
let mut activity = Array1::<f64>::zeros(n);
// Activity sweeps linearly over [0, 1]; one shared factor draw per row plus
// independent noise per column.
for row in 0..n {
let z = (row as f64) / (n as f64 - 1.0);
activity[row] = z;
let amp = (slope * z).exp().sqrt();
let f = lcg_normal(&mut seed);
for i in 0..p {
residuals[[row, i]] = amp * lambda0[[i, 0]] * f + sigma_eps * lcg_normal(&mut seed);
}
}
// Activity already lies in [0, 1], so it is binned directly without rescaling.
let bins = ACTIVITY_SCALE_BINS.max(1);
let row_bin: Vec<usize> = (0..n)
.map(|i| {
let frac = activity[i];
(frac * bins as f64).floor().clamp(0.0, bins as f64 - 1.0) as usize
})
.collect();
// `report` collects a per-rank breakdown that is only shown if the assertion fails.
let mut report = String::new();
let mut ev = Vec::new();
for rank in 0..=2usize {
let m = StructuredResidualModel::fit_fixed_rank(residuals.view(), &row_bin, bins, rank)
.expect("fixed-rank fit");
// Same parameter count as `penalized_log_evidence`; adding the penalty back
// recovers the unpenalized log-likelihood for the report.
let k_params = (p * rank + p + ACTIVITY_SCALE_BINS) as f64;
let log_lik = m.log_evidence() + 0.5 * k_params * (n as f64).ln();
let col_norms: Vec<f64> = (0..rank)
.map(|k| {
m.factor()
.column(k)
.iter()
.map(|v| v * v)
.sum::<f64>()
.sqrt()
})
.collect();
report.push_str(&format!(
"rank {rank}: evidence={:.3} loglik={:.3} penalty={:.3} col_norms={:?} diag={:?}\n",
m.log_evidence(),
log_lik,
0.5 * k_params * (n as f64).ln(),
col_norms,
m.diagonal()
.iter()
.map(|v| (v * 1e4).round() / 1e4)
.collect::<Vec<_>>()
));
ev.push(m.log_evidence());
}
assert!(
ev[1] > ev[0] && ev[1] > ev[2],
"evidence ladder must prefer the planted rank 1; breakdown:\n{report}"
);
}
/// Orthonormalises the columns of `m` by Gram–Schmidt, processing them left to right
/// and dropping any column whose remainder has norm at most `1e-10`.
fn orthonormal_columns(m: ArrayView2<'_, f64>) -> Vec<Array1<f64>> {
    let mut accepted: Vec<Array1<f64>> = Vec::new();
    for column_index in 0..m.ncols() {
        let mut candidate = m.column(column_index).to_owned();
        // Remove the component along every direction accepted so far.
        for direction in accepted.iter() {
            let overlap = candidate.dot(direction);
            candidate = &candidate - &(direction * overlap);
        }
        let length = candidate.dot(&candidate).sqrt();
        if length > 1e-10 {
            accepted.push(candidate / length);
        }
    }
    accepted
}
/// Returns the sum of squared inner products of `v` with each vector in `basis`; for an
/// orthonormal basis this is the squared length of the projection of `v` onto its span.
fn projection_energy(v: &Array1<f64>, basis: &[Array1<f64>]) -> f64 {
    let squared_overlaps = basis.iter().map(|direction| {
        let overlap = v.dot(direction);
        overlap.powi(2)
    });
    squared_overlaps.sum()
}
/// Simulates residuals driven by two orthogonal unit directions plus isotropic noise,
/// with constant activity, and checks that `fit` selects rank 2 and that both planted
/// directions lie almost entirely in the span of the fitted factor columns.
#[test]
fn factor_recovers_planted_interference_subspace() {
let n = 6000usize;
let p = 6usize;
// Two orthogonal directions: constant sign and alternating sign.
let raw1: Array1<f64> = ndarray::array![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
let raw2: Array1<f64> = ndarray::array![1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
let v1 = &raw1 / raw1.dot(&raw1).sqrt();
let v2 = &raw2 / raw2.dot(&raw2).sqrt();
let (amp1, amp2) = (1.4_f64, 0.9_f64);
let sigma_eps = 0.15_f64;
let mut seed = 0x9E3779B97F4A7C15_u64;
let mut residuals = Array2::<f64>::zeros((n, p));
// Activity is identically zero, so `fit` places every row in a single bin.
let activity = Array1::<f64>::zeros(n); for row in 0..n {
let f1 = amp1 * lcg_normal(&mut seed);
let f2 = amp2 * lcg_normal(&mut seed);
for i in 0..p {
residuals[[row, i]] = f1 * v1[i] + f2 * v2[i] + sigma_eps * lcg_normal(&mut seed);
}
}
let model = StructuredResidualModel::fit(ResidualFactorInput {
residuals: residuals.view(),
activity: activity.view(),
max_factor_rank: 4,
})
.expect("fit");
assert_eq!(
model.factor_rank(),
2,
"ladder must select the planted rank 2 (got {}, evidence {:.3})",
model.factor_rank(),
model.log_evidence()
);
// Compare subspaces rather than individual columns: the fitted columns may be any
// basis of the planted plane.
let basis = orthonormal_columns(model.factor());
assert_eq!(basis.len(), 2, "fitted factor must span 2 directions");
let e1 = projection_energy(&v1, &basis);
let e2 = projection_energy(&v2, &basis);
assert!(
e1 > 0.95 && e2 > 0.95,
"planted directions must lie in range(Λ̂): cos² = ({e1:.4}, {e2:.4})"
);
}
/// Simulates rank-1 residuals whose factor variance follows `exp(slope * z)` in the
/// activity `z`, and checks that `fit` selects rank 1, that the log of the fitted row
/// scale correlates strongly with the planted log-law, and that the fitted dynamic
/// range between the low and high ends of the activity range is of the planted size.
#[test]
fn fitted_scale_recovers_planted_activity_law() {
let n = 6000usize;
let p = 4usize;
let lambda0 = ndarray::array![1.5, 1.2, -0.4, 0.3];
let sigma_eps = 0.2_f64;
let slope = 1.3_f64;
let mut seed = 0xD1B54A32D192ED03_u64;
let mut residuals = Array2::<f64>::zeros((n, p));
let mut activity = Array1::<f64>::zeros(n);
// Activity sweeps linearly over [0, 1]; the factor amplitude is sqrt(exp(slope * z)).
for row in 0..n {
let z = (row as f64) / (n as f64 - 1.0);
activity[row] = z;
let amp = (slope * z).exp().sqrt();
let f = lcg_normal(&mut seed);
for i in 0..p {
residuals[[row, i]] = amp * lambda0[i] * f + sigma_eps * lcg_normal(&mut seed);
}
}
let model = StructuredResidualModel::fit(ResidualFactorInput {
residuals: residuals.view(),
activity: activity.view(),
max_factor_rank: 2,
})
.expect("fit");
assert_eq!(model.factor_rank(), 1, "planted rank is 1");
// Pearson correlation between the fitted and planted log-scales; it is invariant to
// the overall normalisation of the fitted scale.
let fitted_log: Vec<f64> = model.row_scale().iter().map(|c| c.ln()).collect();
let planted_log: Vec<f64> = activity.iter().map(|z| slope * z).collect();
let mean_f = fitted_log.iter().sum::<f64>() / n as f64;
let mean_p = planted_log.iter().sum::<f64>() / n as f64;
let mut cov = 0.0_f64;
let mut var_f = 0.0_f64;
let mut var_p = 0.0_f64;
for i in 0..n {
let df = fitted_log[i] - mean_f;
let dp = planted_log[i] - mean_p;
cov += df * dp;
var_f += df * df;
var_p += dp * dp;
}
let corr = cov / (var_f.sqrt() * var_p.sqrt());
assert!(
corr > 0.9,
"fitted activity law must track the planted exp({slope}·z): corr = {corr:.4}"
);
// Rows n/16 and n-1-n/16 sit about 1/16 and 15/16 of the way through the activity
// range, so the planted ratio is roughly exp(1.3 * 0.875).
let lo = model.row_scale()[n / 16]; let hi = model.row_scale()[n - 1 - n / 16]; let ratio = hi / lo;
assert!(
ratio > 1.8 && ratio < 5.5,
"fitted dynamic range {ratio:.3} must bracket the planted ≈3.1"
);
}
}