use irithyll_core::rng::xorshift64_f64;
use irithyll_core::simd::simd_dot as dot;
#[inline]
fn mat_vec(mat: &[f64], v: &[f64], n: usize) -> Vec<f64> {
    // Dense row-major n x n matrix times length-n vector: result[i] = mat[i, :] . v.
    debug_assert_eq!(mat.len(), n * n);
    debug_assert_eq!(v.len(), n);
    (0..n)
        .map(|i| {
            // Row i occupies the contiguous slice [i*n, (i+1)*n).
            let row = &mat[i * n..(i + 1) * n];
            row.iter().zip(v).map(|(&a, &b)| a * b).sum()
        })
        .collect()
}
#[inline]
fn outer_subtract_scaled(p: &mut [f64], g: &[f64], h: &[f64], lambda: f64, n: usize) {
    // In-place rank-1 update of a row-major n x n matrix: P <- (P - g h^T) / lambda.
    debug_assert_eq!(p.len(), n * n);
    debug_assert_eq!(g.len(), n);
    debug_assert_eq!(h.len(), n);
    // Multiply by the reciprocal instead of dividing in the inner loop.
    let scale = lambda.recip();
    for (i, &gi) in g.iter().enumerate() {
        let base = i * n;
        for (j, &hj) in h.iter().enumerate() {
            let cell = &mut p[base + j];
            *cell = (*cell - gi * hj) * scale;
        }
    }
}
#[inline]
fn mat_t_vec_col_major(w: &[f64], x: &[f64], d_in: usize, rank: usize) -> Vec<f64> {
    // Computes y = W^T x where W is d_in x rank stored column-major,
    // i.e. y[col] is the dot product of column `col` with x.
    debug_assert_eq!(w.len(), d_in * rank);
    debug_assert_eq!(x.len(), d_in);
    (0..rank)
        .map(|col| {
            // Column `col` is the contiguous slice [col*d_in, (col+1)*d_in).
            let column = &w[col * d_in..(col + 1) * d_in];
            column.iter().zip(x).map(|(&wi, &xi)| wi * xi).sum()
        })
        .collect()
}
#[inline]
fn mat_vec_col_major(w: &[f64], y: &[f64], d_in: usize, rank: usize) -> Vec<f64> {
    // Computes the reconstruction W y where W is d_in x rank stored column-major,
    // accumulating one scaled column per coefficient of y.
    debug_assert_eq!(w.len(), d_in * rank);
    debug_assert_eq!(y.len(), rank);
    let mut out = vec![0.0; d_in];
    for (col, &coef) in y.iter().enumerate() {
        let column = &w[col * d_in..(col + 1) * d_in];
        for (slot, &wi) in out.iter_mut().zip(column) {
            *slot += wi * coef;
        }
    }
    out
}
/// Online low-rank subspace tracker driven by a PAST-style recursive
/// least-squares recursion (see `past_update`), with an optional supervised
/// gradient variant (`supervised_update`).
pub struct SubspaceTracker {
    // Basis matrix W, d_in x rank, column-major: column j lives at
    // w[j * d_in .. (j + 1) * d_in].
    w: Vec<f64>,
    // Inverse-correlation matrix P of the RLS recursion, rank x rank, row-major.
    p: Vec<f64>,
    // RLS forgetting factor (values near 1.0 used throughout the tests).
    lambda: f64,
    // Initial diagonal value of P; P = delta * I on construction and reset.
    delta: f64,
    // Input dimensionality.
    d_in: usize,
    // Number of tracked subspace directions (columns of W); rank <= d_in.
    rank: usize,
    // Count of update calls since construction or the last reset.
    n_samples: u64,
    // Current xorshift64 state, advanced while initializing W.
    rng_state: u64,
    // Original seed, kept so reset() reproduces the initial W exactly.
    seed: u64,
}
impl SubspaceTracker {
    /// Builds a tracker with a random Xavier-style basis `W` (d_in x rank,
    /// column-major) and `P = delta * I` for the RLS recursion.
    ///
    /// # Panics
    /// Panics if `d_in == 0`, `rank == 0`, `rank > d_in`, or `seed == 0`
    /// (xorshift64 cannot leave the all-zero state).
    pub fn new(d_in: usize, rank: usize, lambda: f64, delta: f64, seed: u64) -> Self {
        assert!(d_in > 0, "d_in must be positive");
        assert!(rank > 0, "rank must be positive");
        assert!(rank <= d_in, "rank must be <= d_in");
        assert!(seed != 0, "seed must be non-zero for xorshift64");
        let mut state = seed;
        let w = xavier_init_col_major(d_in, rank, &mut state);
        Self {
            w,
            p: delta_identity(rank, delta),
            lambda,
            delta,
            d_in,
            rank,
            n_samples: 0,
            rng_state: state,
            seed,
        }
    }

    /// Projects `x` into the tracked subspace: returns `y = W^T x` (length `rank`).
    pub fn project(&self, x: &[f64]) -> Vec<f64> {
        debug_assert_eq!(x.len(), self.d_in, "input dimension mismatch");
        mat_t_vec_col_major(&self.w, x, self.d_in, self.rank)
    }

    /// Unsupervised PAST update from one sample. `_residual` is accepted for
    /// interface compatibility but not used by this implementation.
    pub fn update(&mut self, x: &[f64], _residual: f64) {
        debug_assert_eq!(x.len(), self.d_in, "input dimension mismatch");
        self.past_update(x);
    }

    /// Same as `update` but taking a vector error signal, which is likewise
    /// ignored by this implementation.
    pub fn update_with_error(&mut self, x: &[f64], _error_signal: &[f64]) {
        debug_assert_eq!(x.len(), self.d_in, "input dimension mismatch");
        self.past_update(x);
    }

    /// Restores the exact state produced by `new` with the original seed:
    /// replays the RNG, re-initializes W and P, and clears the sample count.
    pub fn reset(&mut self) {
        self.rng_state = self.seed;
        self.w = xavier_init_col_major(self.d_in, self.rank, &mut self.rng_state);
        self.p = delta_identity(self.rank, self.delta);
        self.n_samples = 0;
    }

    /// Tracked subspace rank (columns of W).
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Input dimensionality (rows of W).
    pub fn d_in(&self) -> usize {
        self.d_in
    }

    /// Number of updates applied since construction or the last reset.
    pub fn n_samples(&self) -> u64 {
        self.n_samples
    }
}
impl SubspaceTracker {
    /// One PAST (Projection Approximation Subspace Tracking) RLS step:
    ///   y = W^T x                     (projection)
    ///   h = P y
    ///   g = h / (lambda + y^T h)      (RLS gain)
    ///   P <- (P - g h^T) / lambda     (inverse-correlation update)
    ///   W <- W + (x - W y) g^T        (basis update toward the sample)
    ///
    /// Fix vs. previous version: the reconstruction error e = x - W y was
    /// recomputed inside the column loop even though it does not depend on the
    /// column; it is now computed once and reused for every column of W.
    fn past_update(&mut self, x: &[f64]) {
        let d = self.d_in;
        let r = self.rank;
        let y = mat_t_vec_col_major(&self.w, x, d, r);
        let h = mat_vec(&self.p, &y, r);
        let denom = self.lambda + dot(&y, &h);
        let g: Vec<f64> = h.iter().map(|&hi| hi / denom).collect();
        outer_subtract_scaled(&mut self.p, &g, &h, self.lambda, r);
        // Error vector e = x - W y, hoisted out of the column loop below.
        let wy = mat_vec_col_major(&self.w, &y, d, r);
        let e: Vec<f64> = x.iter().zip(wy.iter()).map(|(&xi, &wi)| xi - wi).collect();
        // Rank-1 basis update: W <- W + e g^T (column-major layout).
        for (col, &gc) in g.iter().enumerate() {
            let col_start = col * d;
            for (row, &er) in e.iter().enumerate() {
                self.w[col_start + row] += er * gc;
            }
        }
        self.n_samples += 1;
    }
}
impl SubspaceTracker {
pub fn supervised_update(&mut self, x: &[f64], residual: f64, beta: &[f64], lr: f64) {
debug_assert_eq!(x.len(), self.d_in);
debug_assert_eq!(beta.len(), self.rank);
let d = self.d_in;
let r = self.rank;
for (j, &bj) in beta.iter().enumerate().take(r) {
let scale = lr * residual * bj;
let col_start = j * d;
for (i, &xi) in x.iter().enumerate().take(d) {
self.w[col_start + i] += scale * xi;
}
}
self.n_samples += 1;
if self.n_samples % 64 == 0 {
self.reorthogonalize();
}
}
fn reorthogonalize(&mut self) {
let d = self.d_in;
let r = self.rank;
for j in 0..r {
for k in 0..j {
let mut dot_val = 0.0;
for i in 0..d {
dot_val += self.w[j * d + i] * self.w[k * d + i];
}
for i in 0..d {
self.w[j * d + i] -= dot_val * self.w[k * d + i];
}
}
let mut norm_sq = 0.0;
for i in 0..d {
norm_sq += self.w[j * d + i] * self.w[j * d + i];
}
let norm = norm_sq.sqrt();
if norm > 1e-12 {
let inv_norm = 1.0 / norm;
for i in 0..d {
self.w[j * d + i] *= inv_norm;
}
}
}
}
}
/// Fills a d_in x rank column-major matrix with uniform values in
/// [-scale, scale], scale = sqrt(2 / (d_in + rank)). One RNG draw per entry,
/// in index order, so the layout and sequence match `reset()` replays.
/// NOTE(review): classic Glorot *uniform* uses sqrt(6 / (fan_in + fan_out));
/// this variant is smaller by sqrt(3) — presumably intentional, confirm.
fn xavier_init_col_major(d_in: usize, rank: usize, rng: &mut u64) -> Vec<f64> {
    let scale = (2.0 / (d_in + rank) as f64).sqrt();
    (0..d_in * rank)
        .map(|_| (2.0 * xorshift64_f64(rng) - 1.0) * scale)
        .collect()
}
/// Builds the row-major rank x rank matrix `delta * I`.
fn delta_identity(rank: usize, delta: f64) -> Vec<f64> {
    let mut p = vec![0.0; rank * rank];
    // Diagonal entries of a row-major square matrix sit every rank + 1 slots.
    for diag in p.iter_mut().step_by(rank + 1) {
        *diag = delta;
    }
    p
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Same seed must yield byte-identical initial bases, hence identical
    /// projections of the same input.
    #[test]
    fn determinism() {
        let tracker_a = SubspaceTracker::new(8, 3, 0.998, 100.0, 42);
        let tracker_b = SubspaceTracker::new(8, 3, 0.998, 100.0, 42);
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let ya = tracker_a.project(&x);
        let yb = tracker_b.project(&x);
        assert_eq!(ya, yb, "same seed must produce identical projections");
    }

    /// Trains on samples from a known rank-3 subspace of R^10 and checks that
    /// the average squared reconstruction error ||x - W W^T x||^2 drops by at
    /// least half after 1000 PAST updates.
    #[test]
    fn convergence_on_known_subspace() {
        let d_in = 10;
        let true_rank = 3;
        let mut tracker = SubspaceTracker::new(d_in, true_rank, 0.998, 100.0, 123);
        // Three fixed (non-orthonormal) spanning vectors of the true subspace.
        let basis: [[f64; 10]; 3] = [
            [1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.3, 0.0, 0.0, 0.1],
            [0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.3, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.3, 0.0],
        ];
        let mut rng: u64 = 777;
        // Baseline: mean squared reconstruction error of the untrained basis.
        let mut initial_error = 0.0;
        for _ in 0..50 {
            let x = random_subspace_sample(&basis, &mut rng, d_in);
            let y = tracker.project(&x);
            let recon = mat_vec_col_major(&tracker.w, &y, d_in, true_rank);
            let err: f64 = x
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            initial_error += err;
        }
        initial_error /= 50.0;
        // Train: replay the same stream (seed 777) for 1000 updates.
        rng = 777;
        for _ in 0..1000 {
            let x = random_subspace_sample(&basis, &mut rng, d_in);
            tracker.update(&x, 0.0);
        }
        // Evaluate on a fresh stream (seed 999) so the check is out-of-sample.
        rng = 999;
        let mut final_error = 0.0;
        for _ in 0..50 {
            let x = random_subspace_sample(&basis, &mut rng, d_in);
            let y = tracker.project(&x);
            let recon = mat_vec_col_major(&tracker.w, &y, d_in, true_rank);
            let err: f64 = x
                .iter()
                .zip(recon.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            final_error += err;
        }
        final_error /= 50.0;
        assert!(
            final_error < initial_error * 0.5,
            "reconstruction error should drop significantly after training: \
            initial={initial_error:.4}, final={final_error:.4}"
        );
    }

    /// After reset(), the tracker must behave exactly like a freshly
    /// constructed tracker with the same seed and hyperparameters.
    #[test]
    fn reset_restores_initial_state() {
        let fresh = SubspaceTracker::new(6, 2, 0.995, 50.0, 99);
        let mut tracker = SubspaceTracker::new(6, 2, 0.995, 50.0, 99);
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        for _ in 0..20 {
            tracker.update(&x, 0.5);
        }
        assert!(tracker.n_samples() > 0);
        tracker.reset();
        assert_eq!(tracker.n_samples(), 0, "reset should clear sample count");
        let y_fresh = fresh.project(&x);
        let y_reset = tracker.project(&x);
        assert_eq!(
            y_fresh, y_reset,
            "reset tracker must match a fresh tracker with the same seed"
        );
    }

    /// Sanity-checks the sizes of the projection output and internal buffers.
    #[test]
    fn dimensions_correct() {
        let d_in = 12;
        let rank = 4;
        let tracker = SubspaceTracker::new(d_in, rank, 0.999, 100.0, 7);
        let x = vec![0.5; d_in];
        let y = tracker.project(&x);
        assert_eq!(y.len(), rank, "projection output must have rank elements");
        assert_eq!(
            tracker.w.len(),
            d_in * rank,
            "W must have d_in * rank elements"
        );
        assert_eq!(tracker.p.len(), rank * rank, "P must have rank^2 elements");
        assert_eq!(tracker.d_in(), d_in);
        assert_eq!(tracker.rank(), rank);
    }

    /// Both update entry points must bump the sample counter by exactly one.
    #[test]
    fn update_increments_samples() {
        let mut tracker = SubspaceTracker::new(5, 2, 0.99, 10.0, 1);
        assert_eq!(tracker.n_samples(), 0);
        let x = vec![1.0; 5];
        tracker.update(&x, 1.0);
        assert_eq!(tracker.n_samples(), 1);
        tracker.update(&x, -0.5);
        assert_eq!(tracker.n_samples(), 2);
        tracker.update_with_error(&x, &[0.1, 0.2]);
        assert_eq!(tracker.n_samples(), 3);
    }

    /// Supervised updates with target y = x[0] + x[1] should pull the subspace
    /// toward the (e0 + e1)/sqrt(2) signal direction, measured by the fraction
    /// of signal energy captured by W W^T.
    #[test]
    fn supervised_update_finds_signal_direction() {
        let d_in = 6;
        let rank = 2;
        let mut tracker = SubspaceTracker::new(d_in, rank, 0.998, 100.0, 42);
        let mut rng: u64 = 1234;
        // Unit vector along the true signal direction.
        let signal_dir = [
            1.0 / 2.0_f64.sqrt(),
            1.0 / 2.0_f64.sqrt(),
            0.0,
            0.0,
            0.0,
            0.0,
        ];
        let initial_capture = subspace_capture(&tracker, &signal_dir);
        // Only the first subspace coordinate carries regression weight.
        let beta = [1.0, 0.0];
        for _ in 0..2000 {
            let mut x = vec![0.0; d_in];
            for xi in x.iter_mut() {
                *xi = xorshift64_f64(&mut rng) * 2.0 - 1.0;
            }
            let y = x[0] + x[1];
            let projected = tracker.project(&x);
            let pred = projected[0];
            let residual = y - pred;
            tracker.supervised_update(&x, residual, &beta, 0.01);
        }
        let final_capture = subspace_capture(&tracker, &signal_dir);
        assert!(
            final_capture > initial_capture,
            "supervised update should improve signal capture: \
            initial={:.4}, final={:.4}",
            initial_capture,
            final_capture
        );
        assert!(
            final_capture > 0.5,
            "subspace should capture at least half the signal energy: {:.4}",
            final_capture
        );
    }

    /// supervised_update must also bump the sample counter by one per call.
    #[test]
    fn supervised_update_increments_samples() {
        let mut tracker = SubspaceTracker::new(4, 2, 0.99, 10.0, 1);
        assert_eq!(tracker.n_samples(), 0);
        let x = vec![1.0; 4];
        let beta = vec![0.5, -0.3];
        tracker.supervised_update(&x, 1.0, &beta, 0.01);
        assert_eq!(tracker.n_samples(), 1);
        tracker.supervised_update(&x, -0.5, &beta, 0.01);
        assert_eq!(tracker.n_samples(), 2);
    }

    /// 128 supervised updates trigger reorthogonalization twice (every 64);
    /// afterwards the columns of W must be ~unit-norm and ~orthogonal.
    #[test]
    fn reorthogonalize_preserves_unit_columns() {
        let d_in = 8;
        let rank = 3;
        let mut tracker = SubspaceTracker::new(d_in, rank, 0.998, 100.0, 77);
        let mut rng: u64 = 555;
        let beta = vec![1.0, 0.5, -0.3];
        for _ in 0..128 {
            let x: Vec<f64> = (0..d_in)
                .map(|_| xorshift64_f64(&mut rng) * 2.0 - 1.0)
                .collect();
            tracker.supervised_update(&x, 0.5, &beta, 0.01);
        }
        // Each column should have unit Euclidean norm.
        for j in 0..rank {
            let mut norm_sq = 0.0;
            for i in 0..d_in {
                let val = tracker.w[j * d_in + i];
                norm_sq += val * val;
            }
            let norm = norm_sq.sqrt();
            assert!(
                (norm - 1.0).abs() < 0.01,
                "column {} norm should be ~1.0, got {:.6}",
                j,
                norm
            );
        }
        // Every pair of distinct columns should be near-orthogonal.
        for j in 0..rank {
            for k in (j + 1)..rank {
                let mut dot_val = 0.0;
                for i in 0..d_in {
                    dot_val += tracker.w[j * d_in + i] * tracker.w[k * d_in + i];
                }
                assert!(
                    dot_val.abs() < 0.05,
                    "columns {} and {} should be orthogonal, dot={:.6}",
                    j,
                    k,
                    dot_val
                );
            }
        }
    }

    /// Draws a random point in span(basis): a random coefficient in [-1, 1)
    /// per basis vector, accumulated into a length-d_in sample.
    fn random_subspace_sample(basis: &[[f64; 10]; 3], rng: &mut u64, d_in: usize) -> Vec<f64> {
        let mut x = vec![0.0; d_in];
        for b in basis.iter() {
            let coeff = xorshift64_f64(rng) * 2.0 - 1.0;
            for (xi, &bi) in x.iter_mut().zip(b.iter()) {
                *xi += coeff * bi;
            }
        }
        x
    }

    /// Fraction of `signal`'s energy captured by the tracked subspace:
    /// ||W W^T s||^2 / ||s||^2, with 0.0 returned for a near-zero signal.
    fn subspace_capture(tracker: &SubspaceTracker, signal: &[f64]) -> f64 {
        let d = tracker.d_in;
        let r = tracker.rank;
        let y = mat_t_vec_col_major(&tracker.w, signal, d, r);
        let recon = mat_vec_col_major(&tracker.w, &y, d, r);
        let recon_sq: f64 = recon.iter().map(|v| v * v).sum();
        let signal_sq: f64 = signal.iter().map(|v| v * v).sum();
        if signal_sq < 1e-15 {
            0.0
        } else {
            recon_sq / signal_sq
        }
    }
}