mod block;
mod block_stream;
mod codes;
mod scoring;
#[cfg(target_os = "linux")]
mod scoring_gpu;
mod stream;
mod update;
#[cfg(test)]
mod tests;
pub use block::{
BlockSparseConfig, BlockSparseFit, block_gates, block_projections_row,
block_sparse_dictionary_transform, fit_block_sparse_dictionary, reconstruct_row,
route_row_blocks, row_loss,
};
pub use block_stream::{
BlockEpochStats, BlockShardStats, BlockSparseStreamArtifact, BlockSparseStreamState,
};
pub use codes::SparseCode;
pub use scoring::{ScoreRoutePath, ScoreRouteResult, ScoreRouteStats, TileScorer, top_s_online};
#[cfg(target_os = "linux")]
pub use scoring_gpu::{
DEVICE_SCORE_BLOCK_MIN_ELEMS, ScoreBlockPath, score_block_cpu, score_block_required,
};
pub use stream::{EpochStats, ShardStats, SparseDictArtifact, SparseDictStreamState};
use ndarray::{Array2, ArrayView2};
/// Settings for [`fit_sparse_dictionary`].
///
/// Build one with [`SparseDictConfig::new`] (sets `n_atoms`, everything else
/// from `Default`) or with struct-update syntax over `SparseDictConfig::default()`.
/// Most fields are consumed by the fitting loop in `update::run`; their exact
/// semantics live there.
#[derive(Clone, Copy, Debug)]
pub struct SparseDictConfig {
    /// Requested dictionary size (number of atoms). Presumably the row count of
    /// the learned decoder — confirm in `update::run`.
    pub n_atoms: usize,
    /// Number of atoms each row may use. The standalone transform clamps its own
    /// `active` argument to `1..=decoder.nrows()`; whether the fitter clamps the
    /// same way is decided in `update::run` — TODO confirm.
    pub active: usize,
    /// Rows per minibatch during fitting (used by `update::run`).
    pub minibatch: usize,
    /// Upper bound on training epochs (used by `update::run`).
    pub max_epochs: usize,
    /// Tile width handed to the atom scorer (`TileScorer`); the transform path
    /// raises 0 to 1.
    pub score_tile: usize,
    /// Ridge penalty for the per-row code solve (`codes::solve_row_codes` on the
    /// transform path).
    pub code_ridge: f32,
    /// Ridge penalty for the decoder update (used by `update::run`).
    pub decoder_ridge: f32,
    /// Convergence tolerance (used by `update::run`; exact criterion not visible here).
    pub tolerance: f64,
    /// Execution mode forwarded to the scorer's routing step.
    pub score_mode: gam_gpu::GpuMode,
}
impl SparseDictConfig {
    /// Creates a configuration for a dictionary of `n_atoms` atoms.
    ///
    /// Every other field takes its value from [`SparseDictConfig::default`].
    pub fn new(n_atoms: usize) -> Self {
        Self {
            n_atoms,
            ..Default::default()
        }
    }
}

impl Default for SparseDictConfig {
    /// Baseline settings: one atom, one active atom per row, 512-row
    /// minibatches, at most 30 epochs, 4096-wide score tiles, `1e-6` for both
    /// ridge penalties and the tolerance, and automatic GPU-mode selection.
    fn default() -> Self {
        // The code and decoder ridges share one small default penalty.
        const RIDGE: f32 = 1.0e-6;
        Self {
            score_mode: gam_gpu::GpuMode::Auto,
            n_atoms: 1,
            active: 1,
            minibatch: 512,
            max_epochs: 30,
            score_tile: 4096,
            tolerance: 1.0e-6,
            code_ridge: RIDGE,
            decoder_ridge: RIDGE,
        }
    }
}
/// Result of [`fit_sparse_dictionary`]: the learned dictionary plus the sparse
/// codes of the training rows.
#[derive(Clone, Debug)]
pub struct SparseDictFit {
    /// Learned dictionary; each row is one atom and the column count matches
    /// the input feature count (see `reconstruct`).
    pub decoder: Array2<f32>,
    /// Per training row, the atom ids (rows of `decoder`) in its support.
    pub indices: Array2<u32>,
    /// Coefficients paired element-wise with `indices`.
    pub codes: Array2<f32>,
    /// Fraction of variance explained by the fit, as computed by the fitter —
    /// exact definition lives in `update::run`; TODO confirm.
    pub explained_variance: f64,
    /// Number of epochs the fitter ran.
    pub epochs: usize,
    /// Whether the fitter reported convergence before its epoch limit
    /// (criterion defined in `update::run`).
    pub converged: bool,
    /// Number of leading `indices`/`codes` columns that `reconstruct` reads.
    pub active: usize,
    /// Statistics collected from the scorer's routing steps.
    pub score_route_stats: ScoreRouteStats,
}
impl SparseDictFit {
    /// Rebuilds the dense approximation of the training rows from the sparse codes.
    ///
    /// Row `i` of the result is the sum, over support slots `j`, of
    /// `codes[[i, j]] * decoder.row(indices[[i, j]])`; slots with a zero code are
    /// skipped. The output has `indices.nrows()` rows and `decoder.ncols()` columns.
    ///
    /// Only the first `active` slots are read, further limited to the number of
    /// columns actually stored in `indices` and `codes`, so an `active` larger
    /// than the stored support no longer indexes out of bounds.
    ///
    /// # Panics
    /// If an atom id in `indices` is not a valid row of `decoder`, or if `codes`
    /// has fewer rows than `indices`.
    pub fn reconstruct(&self) -> Array2<f32> {
        let n = self.indices.nrows();
        let p = self.decoder.ncols();
        // Never read past the stored support, even if `active` disagrees with it.
        let width = self
            .active
            .min(self.indices.ncols())
            .min(self.codes.ncols());
        let mut out = Array2::<f32>::zeros((n, p));
        for (i, mut out_row) in out.outer_iter_mut().enumerate() {
            for j in 0..width {
                let code = self.codes[[i, j]];
                if code == 0.0 {
                    continue;
                }
                let atom = self.indices[[i, j]] as usize;
                // out_row += code * decoder[atom]
                out_row.scaled_add(code, &self.decoder.row(atom));
            }
        }
        out
    }
}
/// Output of [`sparse_dictionary_transform_with_mode`].
#[derive(Clone, Debug)]
pub struct SparseDictTransform {
    /// Per input row, the selected atom ids (rows of the decoder); one column
    /// per support slot.
    pub indices: Array2<u32>,
    /// Coefficients paired element-wise with `indices`.
    pub codes: Array2<f32>,
    /// Statistics recorded from the single routing call made by the transform.
    pub score_route_stats: ScoreRouteStats,
}
pub fn sparse_dictionary_transform(
x: ArrayView2<'_, f32>,
decoder: ArrayView2<'_, f32>,
active: usize,
score_tile: usize,
code_ridge: f32,
) -> Result<(Array2<u32>, Array2<f32>), String> {
let transform = sparse_dictionary_transform_with_mode(
x,
decoder,
active,
score_tile,
code_ridge,
gam_gpu::gpu_mode(),
)?;
Ok((transform.indices, transform.codes))
}
/// Sparse-codes every row of `x` against a fixed `decoder`, returning the
/// score-routing statistics alongside the codes.
///
/// Each row of `decoder` is one atom. A `TileScorer` routes the rows of `x`
/// (with `score_mode` forwarded to its routing call), and for each routed row
/// `codes::solve_row_codes` produces `s` (atom id, coefficient) pairs using the
/// ridge penalty `code_ridge`. Here `s` is `active` clamped to
/// `1..=decoder.nrows()`, and `score_tile` is raised to at least 1.
///
/// # Errors
/// - the dictionary has no atoms (`decoder.nrows() == 0`);
/// - `x` and `decoder` have different column counts;
/// - the scorer's routing step fails (its message is passed through).
pub fn sparse_dictionary_transform_with_mode(
    x: ArrayView2<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    active: usize,
    score_tile: usize,
    code_ridge: f32,
    score_mode: gam_gpu::GpuMode,
) -> Result<SparseDictTransform, String> {
    let n_atoms = decoder.nrows();
    if n_atoms == 0 {
        return Err("sparse_dictionary_transform: dictionary has no atoms".to_string());
    }
    let (x_cols, decoder_cols) = (x.ncols(), decoder.ncols());
    if x_cols != decoder_cols {
        return Err(format!(
            "sparse_dictionary_transform: X has P={} columns but the decoder has P={}",
            x_cols, decoder_cols
        ));
    }

    // `n_atoms >= 1` past the guard above, so the clamp bounds are ordered.
    let support = active.clamp(1, n_atoms);
    let scorer = TileScorer::new(support, score_tile.max(1));
    let routed = scorer.route_minibatch_with_mode(x, decoder, score_mode)?;

    let mut score_route_stats = ScoreRouteStats::default();
    score_route_stats.record_result(&routed);

    let n_rows = x.nrows();
    let mut atom_ids = Array2::<u32>::zeros((n_rows, support));
    let mut coefficients = Array2::<f32>::zeros((n_rows, support));
    for (row_idx, selection) in routed.selections.iter().enumerate() {
        let solved = codes::solve_row_codes(
            x.row(row_idx),
            decoder,
            selection,
            support,
            code_ridge,
        );
        for slot in 0..support {
            atom_ids[[row_idx, slot]] = solved.indices[slot];
            coefficients[[row_idx, slot]] = solved.codes[slot];
        }
    }

    Ok(SparseDictTransform {
        indices: atom_ids,
        codes: coefficients,
        score_route_stats,
    })
}
/// Learns a sparse dictionary for the rows of `x` using the settings in `config`.
///
/// Thin public entry point: the training loop itself lives in `update::run`.
///
/// # Errors
/// Returns whatever error message `update::run` produces (its conditions are
/// not visible from this module).
pub fn fit_sparse_dictionary(
    x: ArrayView2<'_, f32>,
    config: &SparseDictConfig,
) -> Result<SparseDictFit, String> {
    update::run(x, config)
}