use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
/// Sparse code for a single row, stored as two parallel vectors.
///
/// `codes[k]` is the coefficient attached to the decoder row (atom) `indices[k]`.
/// `solve_row_codes` fills both vectors to the same length and pads unused
/// slots with a zero coefficient.
#[derive(Clone, Debug)]
pub struct SparseCode {
/// Decoder row index for each slot. Padding slots written by
/// `solve_row_codes` repeat an existing index (or `0` when nothing is active).
pub indices: Vec<u32>,
/// Coefficient for each slot; padding slots hold `0.0`.
pub codes: Vec<f32>,
}
/// Refits the coefficients of one row on its active decoder atoms by solving
/// the ridge-regularised normal equations.
///
/// For the kept atoms `d_0..d_{m-1}` (rows of `decoder`), this builds, in `f64`,
/// the Gram matrix `G[i][j] = <d_i, d_j>` and the right-hand side
/// `b[i] = <d_i, row>`, adds `ridge` to the diagonal of `G`, and solves
/// `G x = b` with `solve_spd`.
///
/// # Parameters
/// * `row` - the signal to encode; its length sets how many columns of each
///   decoder row take part in the dot products.
/// * `decoder` - one atom per row.
/// * `active` - `(atom index, score)` pairs; only the index is read here, the
///   second component is ignored.
/// * `s` - number of slots in the returned code.
/// * `ridge` - value added to every diagonal entry of the Gram matrix.
///
/// # Returns
/// A [`SparseCode`] whose `indices` and `codes` both have length exactly `s`.
/// If `active` holds more than `s` entries, only the first `s` are used. If it
/// holds fewer, the remaining slots repeat the first kept index with a `0.0`
/// coefficient. If nothing is kept, every slot is index `0` with code `0.0`.
///
/// # Panics
/// Panics if a kept atom index is out of range for `decoder`, or if `decoder`
/// has fewer columns than `row.len()`.
pub fn solve_row_codes(
    row: ArrayView1<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    active: &[(u32, f32)],
    s: usize,
    ridge: f32,
) -> SparseCode {
    /// Dot product of the first `len` entries of `a` and `b`, accumulated in `f64`.
    fn dot_f64(a: &ArrayView1<'_, f32>, b: &ArrayView1<'_, f32>, len: usize) -> f64 {
        (0..len).fold(0.0f64, |acc, c| acc + f64::from(a[c]) * f64::from(b[c]))
    }

    // Only the first `s` atoms can be emitted, so the system is restricted to
    // them before solving: coefficients taken from a joint solve that includes
    // atoms which are then dropped are not the ridge fit for the atoms that
    // remain, and the larger Gram matrix would be wasted work.
    let kept = &active[..active.len().min(s)];
    if kept.is_empty() {
        return SparseCode {
            indices: vec![0u32; s],
            codes: vec![0.0f32; s],
        };
    }

    let m = kept.len();
    let p = row.len();

    // Look up each atom's row view once instead of once per Gram entry.
    let atoms: Vec<_> = kept
        .iter()
        .map(|&(idx, _)| decoder.row(idx as usize))
        .collect();

    let mut gram = Array2::<f64>::zeros((m, m));
    let mut rhs = Array1::<f64>::zeros(m);
    for (i, di) in atoms.iter().enumerate() {
        rhs[i] = dot_f64(di, &row, p);
        // The Gram matrix is symmetric: compute the upper triangle and mirror it.
        for (j, dj) in atoms.iter().enumerate().skip(i) {
            let g = dot_f64(di, dj, p);
            gram[[i, j]] = g;
            gram[[j, i]] = g;
        }
        gram[[i, i]] += f64::from(ridge);
    }

    let solution = solve_spd(&gram, &rhs);

    let mut indices = Vec::with_capacity(s);
    let mut codes = Vec::with_capacity(s);
    indices.extend(kept.iter().map(|&(idx, _)| idx));
    // Narrowing back to the storage precision of `SparseCode`.
    codes.extend((0..m).map(|i| solution[i] as f32));

    // Pad to the fixed width with a valid atom index and a zero coefficient so
    // the padding contributes nothing to a reconstruction.
    indices.resize(s, kept[0].0);
    codes.resize(s, 0.0f32);

    SparseCode { indices, codes }
}
/// Solves `gram * x = rhs` for a matrix that is expected to be symmetric
/// positive definite, retrying with diagonal jitter when factorisation fails.
///
/// The factorisation is attempted on `gram` as given, then on `gram` with a
/// constant added to its diagonal: `1e-8` first, multiplied by 16 on each
/// further retry, for six attempts in total. The jitter is always applied to a
/// fresh copy of `gram`, so it never accumulates across retries.
///
/// If every attempt fails, the result falls back to a diagonal approximation
/// `x[i] = rhs[i] / max(gram[i][i], 1e-12)`.
///
/// NOTE(review): `cholesky` / `solvevec` come from `crate::faer_ndarray`; they
/// are presumably a lower Cholesky factorisation and its triangular solve.
fn solve_spd(gram: &Array2<f64>, rhs: &Array1<f64>) -> Array1<f64> {
    use crate::faer_ndarray::FaerCholesky;
    use faer::Side;

    const MAX_ATTEMPTS: usize = 6;
    const FIRST_JITTER: f64 = 1.0e-8;
    const JITTER_GROWTH: f64 = 16.0;
    const MIN_DIAGONAL: f64 = 1.0e-12;

    let dim = rhs.len();

    let mut jitter = 0.0f64;
    for _ in 0..MAX_ATTEMPTS {
        // Start every attempt from the untouched matrix.
        let mut trial = gram.clone();
        if jitter != 0.0 {
            for k in 0..dim {
                trial[[k, k]] += jitter;
            }
        }

        if let Ok(factor) = trial.cholesky(Side::Lower) {
            return factor.solvevec(rhs);
        }

        jitter = if jitter == 0.0 {
            FIRST_JITTER
        } else {
            jitter * JITTER_GROWTH
        };
    }

    // Last resort: treat the system as diagonal, guarding against tiny or
    // non-positive pivots.
    let mut fallback = Array1::<f64>::zeros(dim);
    for k in 0..dim {
        let pivot = gram[[k, k]].max(MIN_DIAGONAL);
        fallback[k] = rhs[k] / pivot;
    }
    fallback
}