use super::BlockSparseConfig;
use super::block::{gram_schmidt_rows, route_and_code_all, seed_frames, stable_rank_symmetric};
use crate::frames::GrassmannFrame;
use ndarray::{Array2, ArrayView2};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Squared-norm threshold at or below which a residual row is not kept for reviving blocks.
const DEAD_DENOM: f64 = 1e-12;
/// Summary returned by `BlockSparseStreamState::partial_fit` for one shard.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BlockShardStats {
    /// Number of rows coded from the shard (`0` for an empty shard).
    pub rows: usize,
    /// Sum over the shard's rows of the squared reconstruction-residual norm.
    pub rss: f64,
    /// Number of distinct blocks selected by at least one row so far in the current epoch.
    pub alive_blocks: usize,
}
/// Summary returned by `BlockSparseStreamState::end_epoch` for one completed epoch.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BlockEpochStats {
    /// Explained variance `1 - rss / tss` of the epoch. It is measured with the decoder and
    /// gain that were in effect while the epoch's rows were streamed.
    pub explained_variance: f64,
    /// Number of unused blocks re-seeded from the residual reservoir.
    pub revived: usize,
    /// Number of blocks that no row selected during the epoch (counted before reviving).
    pub dead: usize,
    /// Gain in effect after the epoch's update.
    pub gamma: f32,
    /// Whether the convergence criterion was met at this epoch.
    pub converged: bool,
    /// 1-based index of the epoch just completed.
    pub epoch: usize,
}
/// A residual row retained by `ResidualReservoir`, ranked by its squared norm.
struct ResidRow {
    /// Squared Euclidean norm of `residual`, accumulated in `f64`.
    norm2: f64,
    /// Position of the row in the current epoch's stream; a deterministic tie-break.
    global_index: u64,
    /// The residual vector itself.
    residual: Vec<f32>,
}

// Equality is defined through `Ord` so that the two traits always agree. Comparing `f64` with
// `==` would treat `0.0` and `-0.0` as equal and NaN as unequal to itself, while `total_cmp`
// does the opposite in both cases. That would violate the Eq/Ord consistency `BinaryHeap`
// relies on.
impl PartialEq for ResidRow {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for ResidRow {}

/// A row with a *larger* `norm2` compares as *less*; equal norms fall back to ascending
/// `global_index`. The top of a max-`BinaryHeap` is therefore the weakest row (smallest norm,
/// latest index), and an ascending sort lists the strongest rows first.
impl Ord for ResidRow {
    fn cmp(&self, other: &Self) -> Ordering {
        match other.norm2.total_cmp(&self.norm2) {
            Ordering::Equal => self.global_index.cmp(&other.global_index),
            ord => ord,
        }
    }
}

impl PartialOrd for ResidRow {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Bounded keep-best store of the largest-norm residual rows offered during an epoch.
///
/// The heap's top is always the weakest retained row (see `ResidRow`'s ordering). Once the
/// reservoir is full, a new row only has to beat that one to take its place.
struct ResidualReservoir {
    /// Maximum number of rows retained (never less than one).
    cap: usize,
    /// Retained rows; `peek` yields the weakest of them.
    heap: BinaryHeap<ResidRow>,
}

impl ResidualReservoir {
    /// Creates an empty reservoir that holds at most `max(cap, 1)` rows.
    fn new(cap: usize) -> Self {
        Self {
            heap: BinaryHeap::new(),
            cap: std::cmp::max(cap, 1),
        }
    }

    /// Considers one residual row for retention.
    ///
    /// Rows with `norm2 <= DEAD_DENOM` are ignored. Below capacity, every other row is kept.
    /// At capacity, the row replaces the weakest kept row only if it ranks strictly ahead of it.
    fn offer(&mut self, norm2: f64, global_index: u64, residual: Vec<f32>) {
        if norm2 <= DEAD_DENOM {
            return;
        }
        let candidate = ResidRow {
            norm2,
            global_index,
            residual,
        };
        if self.heap.len() >= self.cap {
            let beats_worst = self
                .heap
                .peek()
                .map_or(false, |worst| candidate.cmp(worst) == Ordering::Less);
            if !beats_worst {
                return;
            }
            self.heap.pop();
        }
        self.heap.push(candidate);
    }

    /// Drops every retained row.
    fn clear(&mut self) {
        self.heap.clear();
    }

    /// Returns the retained rows from largest to smallest `norm2`, with ties broken by
    /// ascending `global_index`.
    fn ranked(&self) -> Vec<&ResidRow> {
        let mut ordered = Vec::with_capacity(self.heap.len());
        ordered.extend(self.heap.iter());
        // `ResidRow`'s `Ord` is exactly "larger norm2 first, then smaller global_index".
        ordered.sort();
        ordered
    }
}
/// Incremental (shard-by-shard) fitting state for a block-sparse decoder.
///
/// Call `partial_fit` on each shard of an epoch, then `end_epoch`; `finalize` returns an owned
/// snapshot of the current model.
pub struct BlockSparseStreamState {
    /// Copy of the validated configuration.
    config: BlockSparseConfig,
    /// Number of blocks (`n_blocks`).
    g: usize,
    /// Rows per block (`block_size`).
    b: usize,
    /// Blocks coded per row: `block_topk` clamped to `1..=g`.
    k: usize,
    /// Feature dimension, fixed by the seed sample.
    p: usize,
    /// Decoder with `p` columns; rows `gg * b .. (gg + 1) * b` belong to block `gg`.
    decoder: Array2<f32>,
    /// Global gain applied to the summed block projections.
    gamma: f32,
    /// Per block, the `p × b` cross-moment of the row residual (with this block's own
    /// contribution added back) against the block's codes.
    cross: Vec<Array2<f64>>,
    /// Per block, the `b × b` second moment of its codes.
    second: Vec<Array2<f64>>,
    /// Per block, the number of rows this epoch that selected it.
    usage: Vec<usize>,
    /// Per block, whether any row selected it this epoch.
    touched: Vec<bool>,
    /// Number of `true` entries in `touched`.
    alive_count: usize,
    /// Numerator of the least-squares gain estimate: sum of `<x, P>`, where `P` is a row's
    /// summed block projections.
    gamma_num: f64,
    /// Denominator of the least-squares gain estimate: sum of `<P, P>`.
    gamma_den: f64,
    /// Per-column sum of the epoch's inputs (for the total sum of squares).
    col_sum: Vec<f64>,
    /// Per-column sum of squared inputs.
    col_sumsq: Vec<f64>,
    /// Residual sum of squares accumulated this epoch.
    rss: f64,
    /// Rows streamed this epoch.
    row_count: usize,
    /// Largest residuals of the epoch, used to re-seed unused blocks.
    reservoir: ResidualReservoir,
    /// Explained variance of the previous epoch (`NEG_INFINITY` before the first).
    prev_ev: f64,
    /// Explained variance of the last completed epoch, reported by `finalize`.
    last_ev: f64,
    /// Number of completed epochs.
    epochs_run: usize,
    /// Blocks revived at the last completed epoch.
    last_revived: usize,
    /// Whether the last completed epoch met the convergence criterion.
    converged: bool,
    /// Per-block utilization from the last completed epoch.
    last_util: Vec<f32>,
    /// Per-block code stable rank from the last completed epoch.
    last_stable: Vec<f32>,
}
impl BlockSparseStreamState {
    /// Starts a streaming block-sparse fit from a seed sample.
    ///
    /// Setup:
    /// - The column count of `seed` fixes the feature dimension `P`.
    /// - `seed` is passed to `seed_frames` to build the initial decoder. In that decoder, rows
    ///   `g * block_size .. (g + 1) * block_size` hold block `g`.
    /// - The number of blocks coded per row is `block_topk` clamped to `1..=n_blocks`.
    /// - The gain starts at `1.0`.
    /// - The residual reservoir keeps up to `aux_k * block_size` rows (at least one).
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `config` fails validation (see `validate_config`);
    /// - `seed` has no rows or no columns;
    /// - any seed value is non-finite;
    /// - `block_size` exceeds `P`.
    pub fn new(seed: ArrayView2<'_, f32>, config: &BlockSparseConfig) -> Result<Self, String> {
        validate_config(config)?;
        if seed.nrows() == 0 || seed.ncols() == 0 {
            return Err(
                "BlockSparseStream requires a non-empty seed sample (N×P) to fix P and the initial \
                 block frames"
                    .to_string(),
            );
        }
        if !seed.iter().all(|v| v.is_finite()) {
            return Err("BlockSparseStream seed sample must be finite".to_string());
        }
        let p = seed.ncols();
        if config.block_size > p {
            return Err(format!(
                "BlockSparseStream block_size b={} cannot exceed P={p} (a block's b orthonormal \
                 rows must fit in ℝ^P)",
                config.block_size
            ));
        }
        let g = config.n_blocks;
        let b = config.block_size;
        let k = config.block_topk.min(g).max(1);
        let decoder = seed_frames(seed, g, b);
        let cap = config.aux_k.saturating_mul(b).max(1);
        Ok(Self {
            config: *config,
            g,
            b,
            k,
            p,
            decoder,
            gamma: 1.0,
            cross: (0..g).map(|_| Array2::<f64>::zeros((p, b))).collect(),
            second: (0..g).map(|_| Array2::<f64>::zeros((b, b))).collect(),
            usage: vec![0; g],
            touched: vec![false; g],
            alive_count: 0,
            gamma_num: 0.0,
            gamma_den: 0.0,
            col_sum: vec![0.0; p],
            col_sumsq: vec![0.0; p],
            rss: 0.0,
            row_count: 0,
            reservoir: ResidualReservoir::new(cap),
            prev_ev: f64::NEG_INFINITY,
            last_ev: f64::NEG_INFINITY,
            epochs_run: 0,
            last_revived: 0,
            converged: false,
            last_util: vec![0.0; g],
            last_stable: vec![0.0; g],
        })
    }

    /// Streams one shard of rows into the current epoch's accumulators.
    ///
    /// Rows are coded by `route_and_code_all` using the decoder and gain as they stood after
    /// the last `end_epoch` (or `new`).
    ///
    /// For each row:
    /// - Every selected block with a non-zero gate contributes the projection of the row onto
    ///   that block's decoder rows.
    /// - The reconstruction is `gamma` times the sum of these projections.
    /// - The row then updates:
    ///   - the per-column totals used for the total sum of squares;
    ///   - the residual sum of squares;
    ///   - the gain's least-squares sums;
    ///   - each selected block's usage, cross-moment and code second moment;
    ///   - when `aux_k > 0`, the residual reservoir.
    ///
    /// A shard with zero rows is a no-op and is accepted without checking its column count.
    ///
    /// # Errors
    ///
    /// Returns an error, before touching any accumulator, if:
    /// - the shard's column count differs from the seed's `P`;
    /// - any value in the shard is non-finite.
    pub fn partial_fit(&mut self, shard: ArrayView2<'_, f32>) -> Result<BlockShardStats, String> {
        if shard.nrows() == 0 {
            return Ok(BlockShardStats {
                rows: 0,
                rss: 0.0,
                alive_blocks: self.alive_count,
            });
        }
        if shard.ncols() != self.p {
            return Err(format!(
                "BlockSparseStream.partial_fit: shard has P={} columns but the fit was begun with \
                 P={}",
                shard.ncols(),
                self.p
            ));
        }
        if !shard.iter().all(|v| v.is_finite()) {
            return Err("BlockSparseStream.partial_fit shard must be finite".to_string());
        }
        let p = self.p;
        let b = self.b;
        let gamma = self.gamma;
        let aux_on = self.config.aux_k > 0;
        let codes = route_and_code_all(
            shard,
            self.decoder.view(),
            gamma,
            self.g,
            b,
            self.k,
            self.config.minibatch,
            self.config.block_tile,
        );
        // Index of each row within the epoch; the reservoir uses it as a tie-break.
        let base_index = self.row_count as u64;
        let mut shard_rss = 0.0f64;
        for (r, code) in codes.iter().enumerate() {
            let xi = shard.row(r);
            for c in 0..p {
                let v = xi[c] as f64;
                self.col_sum[c] += v;
                self.col_sumsq[c] += v * v;
            }
            // Per selected block: (block index, gamma-scaled code, unscaled projection).
            let mut sel: Vec<(usize, Vec<f32>, Vec<f32>)> = Vec::with_capacity(self.k);
            let mut xhat = vec![0.0f32; p];
            let mut proj_sum = vec![0.0f32; p];
            for j in 0..code.blocks.len() {
                if code.gates[j] == 0.0 {
                    continue;
                }
                let gg = code.blocks[j] as usize;
                // Inner products of the row with each of the block's decoder rows.
                let mut w = vec![0.0f32; b];
                for (rr, wr) in w.iter_mut().enumerate() {
                    let atom = self.decoder.row(gg * b + rr);
                    let mut acc = 0.0f32;
                    for (xc, ac) in xi.iter().zip(atom.iter()) {
                        acc += *xc * *ac;
                    }
                    *wr = acc;
                }
                let mut proj = vec![0.0f32; p];
                for (rr, &wr) in w.iter().enumerate() {
                    if wr == 0.0 {
                        continue;
                    }
                    let atom = self.decoder.row(gg * b + rr);
                    for c in 0..p {
                        proj[c] += wr * atom[c];
                    }
                }
                for c in 0..p {
                    xhat[c] += gamma * proj[c];
                    proj_sum[c] += proj[c];
                }
                let z: Vec<f32> = w.iter().map(|v| gamma * v).collect();
                sel.push((gg, z, proj));
            }
            let mut residual = vec![0.0f32; p];
            let mut norm2 = 0.0f64;
            for c in 0..p {
                residual[c] = xi[c] - xhat[c];
                norm2 += residual[c] as f64 * residual[c] as f64;
            }
            shard_rss += norm2;
            // Normal equations for the gain: gamma = <x, P> / <P, P> with P = proj_sum.
            for c in 0..p {
                self.gamma_num += xi[c] as f64 * proj_sum[c] as f64;
                self.gamma_den += proj_sum[c] as f64 * proj_sum[c] as f64;
            }
            for (gg, z, proj) in sel.iter() {
                let gg = *gg;
                if !self.touched[gg] {
                    self.touched[gg] = true;
                    self.alive_count += 1;
                }
                self.usage[gg] += 1;
                let mg = &mut self.cross[gg];
                for c in 0..p {
                    // Residual with this block's own contribution added back.
                    let r_ig_c = residual[c] as f64 + (gamma * proj[c]) as f64;
                    for (rr, &zr) in z.iter().enumerate() {
                        mg[[c, rr]] += r_ig_c * zr as f64;
                    }
                }
                let sg = &mut self.second[gg];
                for r1 in 0..b {
                    for r2 in 0..b {
                        sg[[r1, r2]] += z[r1] as f64 * z[r2] as f64;
                    }
                }
            }
            // Offered last, so the residual is moved into the reservoir instead of cloned.
            if aux_on {
                self.reservoir.offer(norm2, base_index + r as u64, residual);
            }
        }
        self.rss += shard_rss;
        self.row_count += codes.len();
        Ok(BlockShardStats {
            rows: codes.len(),
            rss: shard_rss,
            alive_blocks: self.alive_count,
        })
    }

    /// Closes the current epoch and updates the model.
    ///
    /// In order, this:
    /// 1. Computes the epoch's explained variance `1 - rss / tss`. When the total sum of
    ///    squares is near zero, the value is `1.0` if the residual is also near zero and `0.0`
    ///    otherwise.
    /// 2. Estimates the new gain as `gamma_num / gamma_den`, keeping the old gain when the
    ///    denominator is near zero.
    /// 3. Updates every block selected this epoch:
    ///    - if `frame_ridge > 0`, adds `frame_ridge` times the block's current rows to its
    ///      cross-moment;
    ///    - replaces the block's rows with the columns of the frame from
    ///      `GrassmannFrame::polar_update`, but only if that frame has `b` columns and reports
    ///      `b` singular values all above `1e-9`.
    /// 4. Re-seeds unused blocks from the residual reservoir (see `revive`).
    /// 5. Records per-block utilization and code stable rank, applies the new gain, evaluates
    ///    convergence, and clears every per-epoch accumulator.
    ///
    /// The epoch counts as converged when all of these hold:
    /// - no block was revived;
    /// - it is not the first epoch;
    /// - the explained variance changed by at most `tolerance`.
    ///
    /// # Errors
    ///
    /// Returns an error if no rows were streamed since the previous epoch ended.
    pub fn end_epoch(&mut self) -> Result<BlockEpochStats, String> {
        if self.row_count == 0 {
            return Err(
                "BlockSparseStream.end_epoch: no rows were streamed this epoch (call partial_fit \
                 with at least one shard first)"
                    .to_string(),
            );
        }
        let p = self.p;
        let b = self.b;
        let n = self.row_count as f64;
        // Total sum of squares about the column means: sum_c (sum x^2 - (sum x)^2 / n).
        let mut tss = 0.0f64;
        for c in 0..p {
            tss += self.col_sumsq[c] - self.col_sum[c] * self.col_sum[c] / n;
        }
        let ev = if tss <= 1.0e-24 {
            if self.rss <= 1.0e-24 {
                1.0
            } else {
                0.0
            }
        } else {
            1.0 - self.rss / tss
        };
        let gamma_new = if self.gamma_den <= 1.0e-24 {
            self.gamma
        } else {
            (self.gamma_num / self.gamma_den) as f32
        };
        let ridge = self.config.frame_ridge;
        for gg in 0..self.g {
            // Unselected blocks have no statistics this epoch; `revive` handles them.
            if !self.touched[gg] {
                continue;
            }
            if ridge > 0.0 {
                // Adds a multiple of the current rows, biasing the new frame toward the old one.
                for rr in 0..b {
                    for c in 0..p {
                        self.cross[gg][[c, rr]] += ridge * self.decoder[[gg * b + rr, c]] as f64;
                    }
                }
            }
            if let Ok(frame) = GrassmannFrame::polar_update(self.cross[gg].view()) {
                let u = frame.frame();
                let sv = frame.gauge_singular_values();
                // Only a rank-b frame is accepted; otherwise the block keeps its current rows.
                let full_rank = sv.len() == b && sv.iter().all(|&s| s > 1.0e-9);
                if full_rank && u.ncols() == b {
                    for rr in 0..b {
                        for c in 0..p {
                            self.decoder[[gg * b + rr, c]] = u[[c, rr]] as f32;
                        }
                    }
                }
            }
        }
        let dead: usize = self.usage.iter().filter(|&&u| u == 0).count();
        let revived = self.revive();
        for gg in 0..self.g {
            self.last_util[gg] = self.usage[gg] as f32 / self.row_count.max(1) as f32;
            self.last_stable[gg] = stable_rank_symmetric(self.second[gg].view());
        }
        self.gamma = gamma_new;
        let improve = ev - self.prev_ev;
        let converged =
            revived == 0 && improve.abs() <= self.config.tolerance && self.epochs_run > 0;
        self.prev_ev = ev;
        self.last_ev = ev;
        self.last_revived = revived;
        self.converged = converged;
        self.epochs_run += 1;
        let epoch = self.epochs_run;
        self.reset_epoch();
        Ok(BlockEpochStats {
            explained_variance: ev,
            revived,
            dead,
            gamma: self.gamma,
            converged,
            epoch,
        })
    }

    /// Re-seeds blocks that no row selected this epoch.
    ///
    /// Walks the reservoir's residuals from largest to smallest squared norm. It assigns `b`
    /// consecutive residuals to each unused block, in block-index order, for at most `aux_k`
    /// blocks. Each group of `b` residuals is passed through `gram_schmidt_rows` (presumably
    /// orthonormalising them; confirm against its definition). The result is written over the
    /// block's decoder rows.
    ///
    /// Stops early when fewer than `b` residuals remain. Returns the number of blocks
    /// re-seeded; this is always `0` when `aux_k == 0`.
    fn revive(&mut self) -> usize {
        if self.config.aux_k == 0 {
            return 0;
        }
        let ranked = self.reservoir.ranked();
        if ranked.is_empty() {
            return 0;
        }
        let b = self.b;
        let p = self.p;
        let dead_blocks: Vec<usize> = (0..self.g).filter(|&gg| self.usage[gg] == 0).collect();
        let mut revived = 0usize;
        let mut cursor = 0usize;
        for &gg in dead_blocks.iter().take(self.config.aux_k) {
            if cursor + b > ranked.len() {
                break;
            }
            // Defensive: the reservoir's `offer` already rejects rows at or below DEAD_DENOM.
            if ranked[cursor].norm2 <= DEAD_DENOM {
                break;
            }
            let mut seed = Array2::<f32>::zeros((b, p));
            for rr in 0..b {
                let src = &ranked[cursor + rr].residual;
                for c in 0..p {
                    seed[[rr, c]] = src[c];
                }
            }
            cursor += b;
            gram_schmidt_rows(&mut seed);
            for rr in 0..b {
                for c in 0..p {
                    self.decoder[[gg * b + rr, c]] = seed[[rr, c]];
                }
            }
            revived += 1;
        }
        revived
    }

    /// Clears every per-epoch accumulator: moments, usage, gain sums, column totals, RSS, row
    /// count and reservoir. The decoder, gain and last-epoch summaries are kept.
    fn reset_epoch(&mut self) {
        for mg in self.cross.iter_mut() {
            mg.fill(0.0);
        }
        for sg in self.second.iter_mut() {
            sg.fill(0.0);
        }
        self.usage.fill(0);
        self.touched.fill(false);
        self.alive_count = 0;
        self.gamma_num = 0.0;
        self.gamma_den = 0.0;
        self.col_sum.fill(0.0);
        self.col_sumsq.fill(0.0);
        self.rss = 0.0;
        self.row_count = 0;
        self.reservoir.clear();
    }

    /// Returns an owned snapshot of the decoder, the gain and the last completed epoch's
    /// diagnostics.
    ///
    /// Before any epoch has ended, `explained_variance` is `f64::NEG_INFINITY`, and the
    /// per-block utilization and stable rank are all zero.
    pub fn finalize(&self) -> BlockSparseStreamArtifact {
        BlockSparseStreamArtifact {
            decoder: self.decoder.clone(),
            gamma: self.gamma,
            block_topk: self.k,
            block_size: self.b,
            block_utilization: self.last_util.clone(),
            block_stable_rank: self.last_stable.clone(),
            epochs: self.epochs_run,
            explained_variance: self.last_ev,
            converged: self.converged,
        }
    }

    /// Borrowed view of the current decoder; rows `g * b .. (g + 1) * b` belong to block `g`.
    pub fn decoder(&self) -> ArrayView2<'_, f32> {
        self.decoder.view()
    }

    /// Current global gain.
    pub fn gamma(&self) -> f32 {
        self.gamma
    }

    /// Effective blocks coded per row (`block_topk` clamped to `1..=n_blocks`).
    pub fn block_topk(&self) -> usize {
        self.k
    }

    /// Rows per block.
    pub fn block_size(&self) -> usize {
        self.b
    }

    /// Number of completed epochs.
    pub fn epochs_run(&self) -> usize {
        self.epochs_run
    }
}
/// Owned snapshot of a streaming fit, produced by `BlockSparseStreamState::finalize`.
#[derive(Clone, Debug)]
pub struct BlockSparseStreamArtifact {
    /// Decoder with `P` columns; rows `g * block_size .. (g + 1) * block_size` form block `g`.
    pub decoder: Array2<f32>,
    /// Global gain applied to the sum of the selected blocks' projections.
    pub gamma: f32,
    /// Blocks coded per row: the configured `block_topk` clamped to `1..=n_blocks`.
    pub block_topk: usize,
    /// Rows per block.
    pub block_size: usize,
    /// Per block, the fraction of the last completed epoch's rows that selected it (all zero
    /// before the first epoch ends).
    pub block_utilization: Vec<f32>,
    /// Per block, `stable_rank_symmetric` of its code second-moment matrix from the last
    /// completed epoch (all zero before the first epoch ends).
    pub block_stable_rank: Vec<f32>,
    /// Number of completed epochs.
    pub epochs: usize,
    /// Explained variance of the last completed epoch (`f64::NEG_INFINITY` before the first).
    pub explained_variance: f64,
    /// Whether the last completed epoch met the convergence criterion.
    pub converged: bool,
}
/// Checks the configuration fields this module relies on and reports the first violation.
///
/// Requirements:
/// - `n_blocks`, `block_size`, `block_topk` and `max_epochs` must each be at least one;
/// - `frame_ridge` must be finite and non-negative;
/// - `tolerance` must be finite.
fn validate_config(config: &BlockSparseConfig) -> Result<(), String> {
    let ensure = |ok: bool, msg: &str| -> Result<(), String> {
        if ok {
            Ok(())
        } else {
            Err(msg.to_string())
        }
    };
    ensure(config.n_blocks != 0, "BlockSparseStream requires n_blocks >= 1")?;
    ensure(config.block_size != 0, "BlockSparseStream requires block_size >= 1")?;
    ensure(config.block_topk != 0, "BlockSparseStream requires block_topk >= 1")?;
    ensure(config.max_epochs != 0, "BlockSparseStream requires max_epochs >= 1")?;
    ensure(
        config.frame_ridge.is_finite() && config.frame_ridge >= 0.0,
        "BlockSparseStream frame_ridge must be finite and non-negative",
    )?;
    ensure(
        config.tolerance.is_finite(),
        "BlockSparseStream tolerance must be finite",
    )
}
// Unit tests for this module, loaded from `block_stream_tests.rs` via `#[path]` and compiled
// only in test builds.
#[cfg(test)]
#[path = "block_stream_tests.rs"]
mod block_stream_tests;