use rsomics_common::{Result, RsomicsError};
/// Result of `array_weights`: the estimated per-array weights plus iteration diagnostics.
///
/// Derives `Debug`/`Clone`/`PartialEq` so fits can be logged, stored and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct Fit {
    /// One weight per array (row of the design matrix), rescaled to geometric mean 1.
    pub weights: Vec<f64>,
    /// Number of scoring iterations actually performed.
    pub iters: usize,
    /// `true` when the largest update to the log-variance factors fell below `tol`
    /// before `maxiter` iterations were exhausted.
    pub converged: bool,
}
/// Estimates one quality weight per array by Fisher scoring on log-variance factors.
///
/// Array `j` is modelled with variance proportional to `exp(gamma[j])`, i.e. weight
/// `exp(-gamma[j])`. Each iteration fits every gene by weighted least squares against `x`,
/// standardises each gene's squared residuals by that gene's residual variance, and takes a
/// clamped, mean-centred scoring step on `gamma`. The returned weights are rescaled to have
/// geometric mean 1.
///
/// * `y` — expression values, one row per gene; each row holds one value per array.
/// * `x` — design matrix, one row per array; all rows have the same number `p` of coefficients.
/// * `prior_n` — strength of a penalty whose score vanishes at weight 1, shrinking weights
///   toward 1; `0.0` disables it.
/// * `maxiter` — maximum number of scoring iterations (`0` returns unit weights, not converged).
/// * `tol` — iteration stops once the largest absolute (clamped, centred) step is below this.
///
/// Genes whose weighted residual sum of squares is zero or not finite (constant / all-zero rows,
/// rows containing NaN or inf) carry no information about relative array variability and are
/// skipped. Including them would compute `0/0 = NaN` and poison every weight.
///
/// # Errors
/// Returns `RsomicsError::InvalidInput` when the design is empty, has rows of differing length,
/// is saturated (`arrays <= coefficients`), a gene row does not have one value per array, the
/// weighted design is rank deficient, no gene is informative while `prior_n <= 0`, or the
/// scoring system cannot be solved.
pub fn array_weights(
    y: &[Vec<f64>],
    x: &[Vec<f64>],
    prior_n: f64,
    maxiter: usize,
    tol: f64,
) -> Result<Fit> {
    let n = x.len();
    if n == 0 {
        return Err(RsomicsError::InvalidInput(
            "design matrix has no arrays".into(),
        ));
    }
    let p = x[0].len();
    if x.iter().any(|row| row.len() != p) {
        return Err(RsomicsError::InvalidInput(format!(
            "design matrix rows must all have {p} coefficients"
        )));
    }
    if n <= p {
        return Err(RsomicsError::InvalidInput(format!(
            "design is saturated: {n} arrays, {p} coefficients leave no residual df"
        )));
    }
    if let Some(g) = y.iter().position(|row| row.len() != n) {
        return Err(RsomicsError::InvalidInput(format!(
            "expression row {g} has {} values but the design has {n} arrays",
            y[g].len()
        )));
    }
    // Residual degrees of freedom; `n > p` was checked above, so this cannot underflow.
    let df = (n - p) as f64;

    let mut gamma = vec![0.0f64; n];
    let mut converged = false;
    let mut iters = 0;
    // Per-gene scratch buffer, fully overwritten for every gene.
    let mut rw2 = vec![0.0f64; n];
    for it in 1..=maxiter {
        iters = it;
        let w: Vec<f64> = gamma.iter().map(|g| (-g).exp()).collect();
        let hat = weighted_hat(x, &w)?;
        let h: Vec<f64> = (0..n).map(|i| hat[i][i]).collect();
        let sw: Vec<f64> = w.iter().map(|wj| wj.sqrt()).collect();

        // a = W^{1/2} (I - H) W^{-1/2}: the residual projector expressed in the weighted scale.
        let mut a = vec![vec![0.0f64; n]; n];
        for i in 0..n {
            for k in 0..n {
                let identity = if i == k { 1.0 } else { 0.0 };
                let proj = identity - hat[i][k];
                a[i][k] = sw[i] * proj / sw[k];
            }
        }

        // Sum, over informative genes, of each array's standardised squared residual.
        let mut std_sumsq = vec![0.0f64; n];
        let mut n_used = 0usize;
        for row in y {
            let r = residual(row, &hat);
            let mut rss = 0.0;
            for j in 0..n {
                let v = w[j] * r[j] * r[j];
                rw2[j] = v;
                rss += v;
            }
            let s2 = rss / df;
            // Zero / NaN / inf residual variance: skip the gene instead of dividing by it.
            if !(s2 > 0.0 && s2.is_finite()) {
                continue;
            }
            n_used += 1;
            for j in 0..n {
                std_sumsq[j] += rw2[j] / s2;
            }
        }
        if n_used == 0 && prior_n <= 0.0 {
            return Err(RsomicsError::InvalidInput(
                "no gene has a positive, finite residual variance; array weights cannot be \
                 estimated without a prior"
                    .into(),
            ));
        }

        // Only informative genes contribute to the score and information.
        let ng = n_used as f64;
        let mut score = vec![0.0f64; n];
        let mut info = vec![vec![0.0f64; n]; n];
        for j in 0..n {
            for (info_jk, &a_jk) in info[j].iter_mut().zip(&a[j]) {
                *info_jk = 0.5 * ng * a_jk * a_jk;
            }
            let e = (-gamma[j]).exp();
            score[j] = 0.5 * (std_sumsq[j] - ng * (1.0 - h[j])) + 0.5 * prior_n * (e - 1.0);
            info[j][j] += 0.5 * prior_n * e;
        }
        // Diagonal ridge (1e-4 x mean diagonal) keeps `info` invertible when the projector term
        // is low-rank, e.g. a single residual df with three or more arrays and `prior_n == 0`.
        // It must go on the diagonal only: adding it to every entry is a rank-one update that
        // leaves such a matrix singular.
        let ridge = mean_diag(&info) * 1e-4;
        for (j, row) in info.iter_mut().enumerate() {
            row[j] += ridge;
        }

        let mut step = solve(&info, &score)?;
        // `gamma` is only identified up to a common shift, so remove the step's mean and clamp
        // each component to +/-1 to keep the iteration stable.
        let mean_step = step.iter().sum::<f64>() / n as f64;
        for s in &mut step {
            *s = (*s - mean_step).clamp(-1.0, 1.0);
        }
        let mut max_abs = 0.0f64;
        for j in 0..n {
            gamma[j] += step[j];
            max_abs = max_abs.max(step[j].abs());
        }
        let mean_gamma = gamma.iter().sum::<f64>() / n as f64;
        for g in &mut gamma {
            *g -= mean_gamma;
        }
        if max_abs < tol {
            converged = true;
            break;
        }
    }

    // Rescale to geometric mean 1.
    let mut weights: Vec<f64> = gamma.iter().map(|g| (-g).exp()).collect();
    let log_mean = weights.iter().map(|w| w.ln()).sum::<f64>() / n as f64;
    let scale = log_mean.exp();
    for w in &mut weights {
        *w /= scale;
    }
    Ok(Fit {
        weights,
        iters,
        converged,
    })
}
/// Builds the `n x n` weighted hat matrix `H = X (Xᵀ W X)⁻¹ Xᵀ W` with `W = diag(w)`.
///
/// `x` has one row per observation; the number of columns `p` is taken from its first row, and
/// `w` supplies one weight per row of `x`. Entry `[i][k]` equals `x_i (XᵀWX)⁻¹ x_kᵀ · w[k]`.
///
/// # Errors
/// Propagates the error from `invert` when `XᵀWX` cannot be inverted.
fn weighted_hat(x: &[Vec<f64>], w: &[f64]) -> Result<Vec<Vec<f64>>> {
    let p = x[0].len();
    let w = &w[..x.len()];

    // Accumulate the p x p cross-product XᵀWX one observation at a time.
    let mut gram = vec![vec![0.0f64; p]; p];
    for (xi, &wi) in x.iter().zip(w) {
        let xi = &xi[..p];
        for (gram_row, &xia) in gram.iter_mut().zip(xi) {
            let scaled = xia * wi;
            for (cell, &xib) in gram_row.iter_mut().zip(xi) {
                *cell += scaled * xib;
            }
        }
    }
    let gram_inv = invert(&gram)?;

    let hat = x
        .iter()
        .map(|xi| {
            let xi = &xi[..p];
            // Row vector x_i (XᵀWX)⁻¹.
            let lhs: Vec<f64> = (0..p)
                .map(|col| {
                    xi.iter()
                        .zip(&gram_inv)
                        .fold(0.0, |acc, (&xib, inv_row)| acc + xib * inv_row[col])
                })
                .collect();
            x.iter()
                .zip(w)
                .map(|(xk, &wk)| {
                    let dot = lhs
                        .iter()
                        .zip(&xk[..p])
                        .fold(0.0, |acc, (&l, &xka)| acc + l * xka);
                    dot * wk
                })
                .collect::<Vec<f64>>()
        })
        .collect();
    Ok(hat)
}
/// Residuals `y - H·y` of a single response vector under the projection/hat matrix `hat`.
///
/// Only the leading `y.len() x y.len()` block of `hat` is read; a smaller `hat` panics.
fn residual(y: &[f64], hat: &[Vec<f64>]) -> Vec<f64> {
    let len = y.len();
    hat[..len]
        .iter()
        .zip(y)
        .map(|(hat_row, &observed)| {
            let fitted = hat_row[..len]
                .iter()
                .zip(y)
                .fold(0.0, |acc, (&hik, &yk)| acc + hik * yk);
            observed - fitted
        })
        .collect()
}
/// Mean of the main-diagonal entries of the square matrix `m` (its trace divided by its order).
///
/// An empty matrix yields `NaN` (0 / 0).
fn mean_diag(m: &[Vec<f64>]) -> f64 {
    let trace: f64 = m.iter().enumerate().map(|(i, row)| row[i]).sum();
    trace / m.len() as f64
}
/// Inverts the square matrix `m` by Gauss–Jordan elimination with partial pivoting.
///
/// Works on the augmented matrix `[m | I]`: once the left half has been reduced to the identity,
/// the right half holds `m⁻¹`. Only the leading `m.len()` entries of each row of `m` are read.
///
/// # Errors
/// Returns `RsomicsError::InvalidInput` ("design matrix is rank deficient") when the best
/// available pivot in a column has magnitude below `1e-12`. The message is the same whichever
/// matrix is being inverted.
// NOTE(review): the pivot threshold is absolute, so matrices with very small or very large
// entries may be misclassified; a scale-relative tolerance may be worth considering.
fn invert(m: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let n = m.len();
    let mut aug: Vec<Vec<f64>> = m
        .iter()
        .enumerate()
        .map(|(i, row)| {
            let mut wide = Vec::with_capacity(2 * n);
            wide.extend_from_slice(&row[..n]);
            wide.extend((0..n).map(|j| if i == j { 1.0 } else { 0.0 }));
            wide
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: first row at or below `col` with the largest |entry| in this column.
        let pivot = ((col + 1)..n).fold(col, |best, r| {
            if aug[r][col].abs() > aug[best][col].abs() {
                r
            } else {
                best
            }
        });
        if aug[pivot][col].abs() < 1e-12 {
            return Err(RsomicsError::InvalidInput(
                "design matrix is rank deficient".into(),
            ));
        }
        aug.swap(col, pivot);

        // Borrow the pivot row separately from every other row so it can drive the elimination.
        let (above, tail) = aug.split_at_mut(col);
        let (pivot_row, below) = tail
            .split_first_mut()
            .expect("pivot column index is below the matrix order");
        let d = pivot_row[col];
        for v in pivot_row.iter_mut() {
            *v /= d;
        }

        for row in above.iter_mut().chain(below.iter_mut()) {
            let f = row[col];
            if f == 0.0 {
                continue;
            }
            for (v, &pv) in row.iter_mut().zip(pivot_row.iter()) {
                *v -= f * pv;
            }
        }
    }

    Ok(aug.into_iter().map(|mut row| row.split_off(n)).collect())
}
fn solve(m: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>> {
let inv = invert(m)?;
let n = b.len();
let mut x = vec![0.0f64; n];
for i in 0..n {
let mut acc = 0.0;
for j in 0..n {
acc += inv[i][j] * b[j];
}
x[i] = acc;
}
Ok(x)
}