use super::scoring::TileScorer;
use super::update::{
DEAD_DENOM, DecoderNormalEq, route_and_code_all, seed_decoder, solve_decoder, unit_norm_rows,
};
use super::{ScoreRouteStats, SparseDictConfig};
use ndarray::{Array2, ArrayView2};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Per-shard summary returned by `SparseDictStreamState::partial_fit`.
#[derive(Clone, Copy, Debug)]
pub struct ShardStats {
/// Number of rows coded from this shard (0 for an empty shard).
pub rows: usize,
/// Sum over this shard's rows of the squared residual norm, measured against the
/// decoder in effect during the current epoch.
pub rss: f64,
/// Number of distinct atoms that have received a non-zero code so far in the current
/// epoch (cumulative across the epoch's shards, not only this one).
pub alive_atoms: usize,
/// Score-routing statistics gathered while coding this shard only.
pub score_route_stats: ScoreRouteStats,
}
/// Summary returned by `SparseDictStreamState::end_epoch`.
#[derive(Clone, Copy, Debug)]
pub struct EpochStats {
/// `1 - RSS/TSS` over the rows streamed this epoch, measured with the decoder that was
/// in effect while they were streamed (i.e. before this epoch's decoder update).
pub explained_variance: f64,
/// Number of dead atoms re-initialised from high-residual rows at the end of the epoch.
pub revived: usize,
/// Number of atoms that received no non-zero code during the epoch.
pub dead: usize,
/// True when no atom was revived and the EV change versus the previous epoch is within
/// the configured `tolerance`; never true for the first epoch.
pub converged: bool,
/// 1-based count of completed epochs, including this one.
pub epoch: usize,
}
/// One candidate row for atom revival: a streamed row's residual after sparse coding.
struct ResidRow {
    /// Squared Euclidean norm of `residual`, accumulated in `f64`.
    norm2: f64,
    /// Position of the source row among the rows streamed so far in the current epoch;
    /// unique within an epoch and used as a deterministic tie-breaker.
    global_index: u64,
    /// Reconstruction residual of the source row (one entry per column).
    residual: Vec<f32>,
}

// Equality is defined through `Ord` so the two traits agree, as `Eq`/`Ord` require.
// A plain `f64 ==` would make a NaN `norm2` unequal to itself and would treat `0.0`
// and `-0.0` as equal, whereas `total_cmp` (used by `Ord`) does neither.
impl PartialEq for ResidRow {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ResidRow {}

/// Ranks rows best-first for revival: a larger `norm2` compares as `Less`, and equal
/// norms fall back to the smaller `global_index`. In a `BinaryHeap` (a max-heap) the top
/// element is therefore the weakest candidate, i.e. the one to evict first.
impl Ord for ResidRow {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .norm2
            .total_cmp(&self.norm2)
            .then_with(|| self.global_index.cmp(&other.global_index))
    }
}

impl PartialOrd for ResidRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
struct ResidualReservoir {
cap: usize,
heap: BinaryHeap<ResidRow>,
}
impl ResidualReservoir {
fn new(cap: usize) -> Self {
Self {
cap: cap.max(1),
heap: BinaryHeap::new(),
}
}
fn offer(&mut self, norm2: f64, global_index: u64, residual: Vec<f32>) {
if norm2 <= DEAD_DENOM {
return;
}
let row = ResidRow {
norm2,
global_index,
residual,
};
if self.heap.len() < self.cap {
self.heap.push(row);
return;
}
if let Some(worst_kept) = self.heap.peek() {
if row.cmp(worst_kept) == Ordering::Less {
self.heap.pop();
self.heap.push(row);
}
}
}
fn clear(&mut self) {
self.heap.clear();
}
fn ranked(&self) -> Vec<&ResidRow> {
let mut rows: Vec<&ResidRow> = self.heap.iter().collect();
rows.sort_by(|a, b| {
b.norm2
.total_cmp(&a.norm2)
.then_with(|| a.global_index.cmp(&b.global_index))
});
rows
}
}
/// Incremental (shard-by-shard, epoch-by-epoch) sparse dictionary fit.
///
/// Call `partial_fit` for every shard of an epoch, then `end_epoch` to update the decoder;
/// repeat until converged and call `finalize` for the artifact.
pub struct SparseDictStreamState {
/// Validated configuration captured at construction.
config: SparseDictConfig,
/// Atoms used per row: `config.active` clamped to `[1, n_atoms]`.
s: usize,
/// Column count fixed by the seed sample; every shard must match it.
p: usize,
/// Atom rows (one row per atom, `p` columns), renormalised after every update.
decoder: Array2<f32>,
/// Scoring helper passed to `route_and_code_all`.
scorer: TileScorer,
/// Decoder normal-equation accumulator for the current epoch (reset each epoch).
eq: DecoderNormalEq,
/// Per-atom flag: has the atom received a non-zero code this epoch?
alive: Vec<bool>,
/// Number of `true` entries in `alive`.
alive_count: usize,
/// Per-column sum of the epoch's rows (for the total sum of squares).
col_sum: Vec<f64>,
/// Per-column sum of squares of the epoch's rows.
col_sumsq: Vec<f64>,
/// Residual sum of squares accumulated over the epoch's rows.
rss: f64,
/// Rows streamed so far this epoch; also the base of epoch-local row ids.
row_count: usize,
/// Highest-residual rows of the epoch, used to revive dead atoms.
reservoir: ResidualReservoir,
/// EV of the previous epoch, used for the convergence test (starts at -inf).
prev_ev: f64,
/// EV reported by the most recent `end_epoch` (starts at -inf).
last_ev: f64,
/// Number of completed epochs.
epochs_run: usize,
/// Atoms revived by the most recent `end_epoch`; not read in the code visible here.
last_revived: usize,
/// Convergence flag from the most recent `end_epoch`.
converged: bool,
/// Routing statistics accumulated over every shard since construction (never reset).
score_route_stats: ScoreRouteStats,
}
impl SparseDictStreamState {
    /// Begins a streaming fit from a seed sample.
    ///
    /// `seed` fixes the column count `P` and is passed to `seed_decoder` to produce the
    /// initial `K = config.n_atoms` atom rows, which are then rescaled by
    /// `unit_norm_rows`. Each streamed row is later coded with `min(config.active, K)`
    /// atoms.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `config` fails `validate_config`, if `seed` has no rows or no
    /// columns, or if `seed` contains a non-finite value.
    pub fn new(seed: ArrayView2<'_, f32>, config: &SparseDictConfig) -> Result<Self, String> {
        validate_config(config)?;
        if seed.nrows() == 0 || seed.ncols() == 0 {
            return Err(
                "SparseDictStream requires a non-empty seed sample (N×P) to fix P and the initial \
                 atom directions"
                    .to_string(),
            );
        }
        if !seed.iter().all(|v| v.is_finite()) {
            return Err("SparseDictStream seed sample must be finite".to_string());
        }
        let k = config.n_atoms;
        let p = seed.ncols();
        // `validate_config` already guarantees `active >= 1`; cap it at the atom count.
        let s = config.active.min(k).max(1);
        let mut decoder = seed_decoder(seed, k);
        unit_norm_rows(&mut decoder);
        let scorer = TileScorer::new(s, config.score_tile);
        Ok(Self {
            config: *config,
            s,
            p,
            decoder,
            scorer,
            eq: DecoderNormalEq::zeros(k, p),
            alive: vec![false; k],
            alive_count: 0,
            col_sum: vec![0.0; p],
            col_sumsq: vec![0.0; p],
            rss: 0.0,
            row_count: 0,
            reservoir: ResidualReservoir::new(k),
            prev_ev: f64::NEG_INFINITY,
            last_ev: f64::NEG_INFINITY,
            epochs_run: 0,
            last_revived: 0,
            converged: false,
            score_route_stats: ScoreRouteStats::default(),
        })
    }

    /// Streams one shard of rows into the current epoch.
    ///
    /// Rows are routed and sparse-coded against the decoder as it stood at the start of
    /// the epoch, since the decoder only changes in `end_epoch`. The call then:
    /// - accumulates the decoder normal equations and the per-column sums behind the
    ///   total sum of squares;
    /// - adds to the residual sum of squares;
    /// - marks every atom that received a non-zero code as alive;
    /// - offers each row's residual to the revival reservoir.
    ///
    /// An empty shard (zero rows) returns zeroed stats without checking its column count.
    ///
    /// # Errors
    ///
    /// Returns `Err` if the shard's column count differs from the seed's `P`, if it
    /// contains a non-finite value, or if routing/coding fails. All checks happen before
    /// any state is modified.
    pub fn partial_fit(&mut self, shard: ArrayView2<'_, f32>) -> Result<ShardStats, String> {
        if shard.nrows() == 0 {
            return Ok(ShardStats {
                rows: 0,
                rss: 0.0,
                alive_atoms: self.alive_count,
                score_route_stats: ScoreRouteStats::default(),
            });
        }
        if shard.ncols() != self.p {
            return Err(format!(
                "SparseDictStream.partial_fit: shard has P={} columns but the fit was begun with \
                 P={}",
                shard.ncols(),
                self.p
            ));
        }
        if !shard.iter().all(|v| v.is_finite()) {
            return Err("SparseDictStream.partial_fit shard must be finite".to_string());
        }
        let mut shard_route_stats = ScoreRouteStats::default();
        let codes = route_and_code_all(
            shard,
            self.decoder.view(),
            &self.scorer,
            self.s,
            self.config.code_ridge,
            self.config.minibatch,
            self.config.score_mode,
            Some(&mut shard_route_stats),
        )?;
        self.score_route_stats.absorb(shard_route_stats);
        self.eq.accumulate(shard, &codes);
        // Row ids are epoch-local: `row_count` restarts at zero in `reset_epoch`.
        let base_index = self.row_count as u64;
        let mut shard_rss = 0.0f64;
        for (r, code) in codes.iter().enumerate() {
            let xi = shard.row(r);
            for ((sum, sumsq), &x) in self
                .col_sum
                .iter_mut()
                .zip(self.col_sumsq.iter_mut())
                .zip(xi.iter())
            {
                let v = f64::from(x);
                *sum += v;
                *sumsq += v * v;
            }
            // residual = x - Σ_j code_j · decoder[atom_j]. Zero coefficients are skipped
            // and do not count as using the atom.
            let mut residual = xi.to_vec();
            for j in 0..code.indices.len() {
                let cj = code.codes[j];
                if cj == 0.0 {
                    continue;
                }
                let atom = code.indices[j] as usize;
                self.alive_mark(atom);
                let drow = self.decoder.row(atom);
                for (rc, &d) in residual.iter_mut().zip(drow.iter()) {
                    *rc -= cj * d;
                }
            }
            let norm2 = residual
                .iter()
                .fold(0.0f64, |acc, &v| acc + f64::from(v) * f64::from(v));
            shard_rss += norm2;
            self.reservoir.offer(norm2, base_index + r as u64, residual);
        }
        self.rss += shard_rss;
        self.row_count += codes.len();
        Ok(ShardStats {
            rows: codes.len(),
            rss: shard_rss,
            alive_atoms: self.alive_count,
            score_route_stats: shard_route_stats,
        })
    }

    /// Marks `atom` as used this epoch, counting it once.
    #[inline]
    fn alive_mark(&mut self, atom: usize) {
        if !self.alive[atom] {
            self.alive[atom] = true;
            self.alive_count += 1;
        }
    }

    /// Closes the current epoch and updates the decoder.
    ///
    /// Steps:
    /// 1. Computes `EV = 1 - RSS/TSS` over the epoch's rows. RSS was measured against the
    ///    decoder in effect during the epoch.
    /// 2. Solves the accumulated normal equations (ridge `decoder_ridge`) for a new
    ///    decoder and renormalises its rows.
    /// 3. Revives atoms that received no non-zero code from the highest-residual rows.
    /// 4. Resets all per-epoch accumulators.
    ///
    /// The fit is reported converged when nothing was revived, `|EV - previous EV|` is
    /// within `tolerance`, and this is not the first epoch.
    ///
    /// # Errors
    ///
    /// Returns `Err` if no rows were streamed since the last epoch boundary.
    pub fn end_epoch(&mut self) -> Result<EpochStats, String> {
        if self.row_count == 0 {
            return Err(
                "SparseDictStream.end_epoch: no rows were streamed this epoch (call partial_fit \
                 with at least one shard first)"
                    .to_string(),
            );
        }
        let n = self.row_count as f64;
        // One-pass centred sum of squares per column: Σx² − (Σx)²/n. Each column's term
        // is mathematically non-negative, but floating-point cancellation (a large mean
        // relative to the spread) can push it below zero. Clamp it so one column cannot
        // subtract from the others' contributions.
        let tss = self
            .col_sum
            .iter()
            .zip(&self.col_sumsq)
            .fold(0.0f64, |acc, (&sum, &sumsq)| {
                acc + (sumsq - sum * sum / n).max(0.0)
            });
        let ev = if tss <= 1.0e-24 {
            // Constant data: perfect if nothing is left over, otherwise nothing explained.
            if self.rss <= 1.0e-24 { 1.0 } else { 0.0 }
        } else {
            1.0 - self.rss / tss
        };
        solve_decoder(
            &mut self.decoder,
            &self.eq,
            self.config.decoder_ridge as f64,
        );
        unit_norm_rows(&mut self.decoder);
        let dead = self.alive.iter().filter(|&&used| !used).count();
        let revived = self.revive(dead);
        if revived > 0 {
            // Revived rows are raw residuals; bring them back to unit norm.
            unit_norm_rows(&mut self.decoder);
        }
        let improve = ev - self.prev_ev;
        let converged =
            revived == 0 && improve.abs() <= self.config.tolerance && self.epochs_run > 0;
        self.prev_ev = ev;
        self.last_ev = ev;
        self.last_revived = revived;
        self.converged = converged;
        self.epochs_run += 1;
        let epoch = self.epochs_run;
        self.reset_epoch();
        Ok(EpochStats {
            explained_variance: ev,
            revived,
            dead,
            converged,
            epoch,
        })
    }

    /// Overwrites dead atoms with the epoch's highest-norm residuals.
    ///
    /// Dead atoms, in ascending atom index, are paired with the reservoir's rows in ranked
    /// order. Stops when either list runs out or a row's norm is at most `DEAD_DENOM`.
    /// Rewritten rows are left unnormalised, so the caller must renormalise them.
    /// Returns the number of atoms rewritten.
    fn revive(&mut self, dead: usize) -> usize {
        if dead == 0 {
            return 0;
        }
        let ranked = self.reservoir.ranked();
        let dead_atoms: Vec<usize> = (0..self.decoder.nrows())
            .filter(|&a| !self.alive[a])
            .collect();
        let mut revived = 0usize;
        for (&atom, src) in dead_atoms.iter().zip(ranked.iter()) {
            if src.norm2 <= DEAD_DENOM {
                break;
            }
            let mut dst = self.decoder.row_mut(atom);
            for (d, &v) in dst.iter_mut().zip(&src.residual) {
                *d = v;
            }
            revived += 1;
        }
        revived
    }

    /// Clears every per-epoch accumulator. The decoder, epoch counters, EV history and
    /// cumulative routing stats are kept.
    fn reset_epoch(&mut self) {
        let k = self.decoder.nrows();
        self.eq = DecoderNormalEq::zeros(k, self.p);
        self.alive.fill(false);
        self.alive_count = 0;
        self.col_sum.fill(0.0);
        self.col_sumsq.fill(0.0);
        self.rss = 0.0;
        self.row_count = 0;
        self.reservoir.clear();
    }

    /// Snapshots the current fit.
    ///
    /// `explained_variance` is the value reported by the most recent `end_epoch`, or
    /// negative infinity if no epoch has completed. `score_route_stats` covers every shard
    /// streamed since construction.
    pub fn finalize(&self) -> SparseDictArtifact {
        SparseDictArtifact {
            decoder: self.decoder.clone(),
            active: self.s,
            epochs: self.epochs_run,
            explained_variance: self.last_ev,
            converged: self.converged,
            score_route_stats: self.score_route_stats,
        }
    }

    /// Borrowed view of the current decoder (one unit-norm row per atom).
    pub fn decoder(&self) -> ArrayView2<'_, f32> {
        self.decoder.view()
    }

    /// Number of atoms each row is coded with.
    pub fn active(&self) -> usize {
        self.s
    }

    /// Number of completed epochs.
    pub fn epochs_run(&self) -> usize {
        self.epochs_run
    }
}
/// Result of a streaming fit, produced by `SparseDictStreamState::finalize`.
#[derive(Clone, Debug)]
pub struct SparseDictArtifact {
/// Final decoder: one row per atom, `P` columns.
pub decoder: Array2<f32>,
/// Number of atoms each row is coded with.
pub active: usize,
/// Number of completed epochs.
pub epochs: usize,
/// EV reported by the last completed epoch (negative infinity if none completed).
pub explained_variance: f64,
/// Convergence flag from the last completed epoch.
pub converged: bool,
/// Routing statistics accumulated over every streamed shard.
pub score_route_stats: ScoreRouteStats,
}
/// Rejects configurations the streaming fit cannot run with.
///
/// Requirements are checked in this order:
/// - `n_atoms`, `active` and `max_epochs` are non-zero;
/// - `code_ridge` and `decoder_ridge` are finite and non-negative;
/// - `tolerance` is finite.
///
/// The message of the first unmet requirement is returned.
fn validate_config(config: &SparseDictConfig) -> Result<(), String> {
    let code_ridge_ok = config.code_ridge.is_finite() && config.code_ridge >= 0.0;
    let decoder_ridge_ok = config.decoder_ridge.is_finite() && config.decoder_ridge >= 0.0;
    // (requirement holds?, message if it does not)
    let requirements: [(bool, &str); 6] = [
        (config.n_atoms != 0, "SparseDictStream requires K >= 1"),
        (
            config.active != 0,
            "SparseDictStream requires active (top_s) >= 1",
        ),
        (
            config.max_epochs != 0,
            "SparseDictStream requires max_epochs >= 1",
        ),
        (
            code_ridge_ok,
            "SparseDictStream code_ridge must be finite and non-negative",
        ),
        (
            decoder_ridge_ok,
            "SparseDictStream decoder_ridge must be finite and non-negative",
        ),
        (
            config.tolerance.is_finite(),
            "SparseDictStream tolerance must be finite",
        ),
    ];
    requirements
        .iter()
        .find(|(satisfied, _)| !*satisfied)
        .map_or(Ok(()), |&(_, message)| Err(message.to_owned()))
}
// Tests for the streaming fit: agreement with the one-shot fit, warm-start across
// epochs, and residual-driven revival of dead atoms.
#[cfg(test)]
mod stream_tests {
use super::{SparseDictConfig, SparseDictStreamState, TileScorer, route_and_code_all};
use crate::sparse_dict::fit_sparse_dictionary;
use ndarray::{Array2, ArrayView2};
// Builds an n×p corpus where row `i` has weight `scale` on column `(i % k) % p` and
// `0.2 * scale` on the next atom's column; `scale` grows slowly with `i / k`.
fn planted(n: usize, k: usize, p: usize) -> Array2<f32> {
let mut x = Array2::<f32>::zeros((n, p));
for row in 0..n {
let primary = row % k;
let secondary = (primary + 1) % k;
let scale = 0.7 + 0.01 * ((row / k) as f32);
x[[row, primary % p]] += scale;
x[[row, secondary % p]] += 0.2 * scale;
}
x
}
// Re-routes `x` through `decoder` with a fresh scorer and returns `1 - RSS/TSS`,
// using per-column means for TSS and the same degenerate-TSS convention as end_epoch.
fn routed_ev(
x: ArrayView2<'_, f32>,
decoder: &Array2<f32>,
s: usize,
config: &SparseDictConfig,
) -> f64 {
let scorer = TileScorer::new(s, config.score_tile);
let codes = route_and_code_all(
x,
decoder.view(),
&scorer,
s,
config.code_ridge,
config.minibatch,
config.score_mode,
None,
)
.expect("fresh route");
let n = x.nrows();
let p = x.ncols();
let mut means = vec![0.0f64; p];
for i in 0..n {
for c in 0..p {
means[c] += x[[i, c]] as f64;
}
}
for m in means.iter_mut() {
*m /= n as f64;
}
let mut rss = 0.0f64;
let mut tss = 0.0f64;
for (i, code) in codes.iter().enumerate() {
// Reconstruct row i in f64 from its sparse code.
let mut recon = vec![0.0f64; p];
for j in 0..code.indices.len() {
let cj = code.codes[j] as f64;
if cj == 0.0 {
continue;
}
let drow = decoder.row(code.indices[j] as usize);
for c in 0..p {
recon[c] += cj * drow[c] as f64;
}
}
for c in 0..p {
let r = x[[i, c]] as f64 - recon[c];
rss += r * r;
let t = x[[i, c]] as f64 - means[c];
tss += t * t;
}
}
if tss <= 1.0e-24 {
if rss <= 1.0e-24 { 1.0 } else { 0.0 }
} else {
1.0 - rss / tss
}
}
// Drives the streaming API over `shards` for up to `max_epochs`, stopping early once
// an epoch reports convergence; returns the final decoder and active count.
fn stream_fit(
seed: ArrayView2<'_, f32>,
shards: &[ArrayView2<'_, f32>],
config: &SparseDictConfig,
) -> (Array2<f32>, usize) {
let mut state = SparseDictStreamState::new(seed, config).expect("fit_begin");
for _ in 0..config.max_epochs {
for shard in shards {
state.partial_fit(*shard).expect("partial_fit");
}
let stats = state.end_epoch().expect("end_epoch");
if stats.converged {
break;
}
}
let artifact = state.finalize();
(artifact.decoder, artifact.active)
}
// Streaming the planted corpus in four shards (last shard takes the remainder) must
// reach the one-shot fit's EV within 1e-3, and the fit must be good (EV > 0.9).
#[test]
fn streaming_over_shards_matches_one_shot_on_concatenation() {
let (n, k, p) = (240usize, 6usize, 8usize);
let x = planted(n, k, p);
let config = SparseDictConfig {
n_atoms: k,
active: 1,
minibatch: 32,
max_epochs: 40,
score_tile: 16,
code_ridge: 1.0e-6,
decoder_ridge: 1.0e-6,
tolerance: 1.0e-9,
score_mode: gam_gpu::GpuMode::Off,
};
let one_shot = fit_sparse_dictionary(x.view(), &config).expect("one-shot fit");
let chunk = n / 4;
let shards: Vec<ArrayView2<'_, f32>> = (0..4)
.map(|i| {
let start = i * chunk;
let end = if i == 3 { n } else { start + chunk };
x.slice(ndarray::s![start..end, ..])
})
.collect();
let (stream_decoder, s) = stream_fit(x.view(), &shards, &config);
assert_eq!(
stream_decoder.shape(),
one_shot.decoder.shape(),
"decoder shapes must match"
);
let ev_stream = routed_ev(x.view(), &stream_decoder, s, &config);
assert!(
(ev_stream - one_shot.explained_variance).abs() < 1.0e-3,
"streamed EV {ev_stream} must match one-shot EV {} within 1e-3",
one_shot.explained_variance
);
assert!(
ev_stream > 0.9,
"planted corpus should fit well, got EV {ev_stream}"
);
}
// Deterministic hash-derived values in [-1.0, 0.999]; no RNG dependency.
fn pseudo_random(n: usize, p: usize) -> Array2<f32> {
let mut x = Array2::<f32>::zeros((n, p));
for i in 0..n {
for c in 0..p {
let h = (i.wrapping_mul(73_856_093) ^ c.wrapping_mul(19_349_663)) as u64;
let h = h.wrapping_mul(2_654_435_761) % 2_000;
x[[i, c]] = h as f32 / 1_000.0 - 1.0;
}
}
x
}
// The decoder updated in epoch 1 must be used in epoch 2: the EV measured in the
// second epoch has to beat the first by more than 1e-4.
#[test]
fn warm_start_persists_across_epochs() {
let (n, k, p) = (300usize, 8usize, 12usize);
let x = pseudo_random(n, p);
let config = SparseDictConfig {
n_atoms: k,
active: 1,
minibatch: 64,
max_epochs: 6,
score_tile: 16,
code_ridge: 1.0e-6,
decoder_ridge: 1.0e-6,
tolerance: 1.0e-12,
score_mode: gam_gpu::GpuMode::Off,
};
let mut state = SparseDictStreamState::new(x.view(), &config).expect("fit_begin");
let mut evs = Vec::new();
for _ in 0..config.max_epochs {
state.partial_fit(x.view()).expect("partial_fit");
evs.push(state.end_epoch().expect("end_epoch").explained_variance);
}
assert!(
evs[1] > evs[0] + 1.0e-4,
"second-epoch EV {} must improve on first-epoch EV {} (warm-start persisted)",
evs[1],
evs[0]
);
}
// The seed only spans e0 and e1, so no seeded atom points along e2. The shard adds one
// row along e2 that no atom can reconstruct. After one epoch a dead atom is expected to
// be revived from that row's residual, so some decoder row must align with e2. Rows are
// unit-norm, so |decoder[a, 2]| is |cos| with e2.
#[test]
fn revival_targets_worst_reconstructed_row_not_pcs() {
let p = 4usize;
let mut seed = Array2::<f32>::zeros((20, p));
for i in 0..10 {
seed[[i, 0]] = 3.0;
seed[[10 + i, 1]] = 3.0;
}
let mut shard = Array2::<f32>::zeros((21, p));
for i in 0..10 {
shard[[i, 0]] = 3.0;
shard[[10 + i, 1]] = 3.0;
}
shard[[20, 2]] = 2.0;
let config = SparseDictConfig {
n_atoms: 3,
active: 1,
minibatch: 64,
max_epochs: 5,
score_tile: 16,
code_ridge: 1.0e-6,
decoder_ridge: 1.0e-6,
tolerance: 0.0,
score_mode: gam_gpu::GpuMode::Off,
};
let mut state = SparseDictStreamState::new(seed.view(), &config).expect("fit_begin");
let pre_cos_e2 = (0..3)
.map(|a| state.decoder()[[a, 2]].abs())
.fold(0.0f32, f32::max);
assert!(
pre_cos_e2 < 1.0e-4,
"seed decoder must not span e2, got |cos|={pre_cos_e2}"
);
state.partial_fit(shard.view()).expect("partial_fit");
let stats = state.end_epoch().expect("end_epoch");
assert!(
stats.dead >= 1 && stats.revived >= 1,
"expected a dead atom revived; dead={} revived={}",
stats.dead,
stats.revived
);
let post_cos_e2 = (0..3)
.map(|a| state.decoder()[[a, 2]].abs())
.fold(0.0f32, f32::max);
assert!(
post_cos_e2 > 0.999,
"a revived atom must equal e2, got |cos|={post_cos_e2}"
);
}
}