mod codes;
mod scoring;
#[cfg(target_os = "linux")]
mod scoring_gpu;
mod update;
#[cfg(test)]
mod tests;
pub use codes::SparseCode;
pub use scoring::{TileScorer, top_s_online};
#[cfg(target_os = "linux")]
pub use scoring_gpu::{
DEVICE_SCORE_BLOCK_MIN_ELEMS, ScoreBlockPath, score_block_cpu, score_block_required,
};
use ndarray::{Array2, ArrayView2};
/// Hyper-parameters for [`fit_sparse_dictionary`].
///
/// Build one with [`SparseDictConfig::new`] (overrides only `n_atoms`) or with
/// struct-update syntax over [`SparseDictConfig::default`].
#[derive(Clone, Copy, Debug)]
pub struct SparseDictConfig {
    /// Number of dictionary atoms to learn (presumably the row count of the fitted decoder;
    /// consumed by `update::run`).
    pub n_atoms: usize,
    /// Requested number of atoms per encoded row (presumably the sparsity level used by
    /// `update::run`, mirroring the `active` argument of `sparse_dictionary_transform`).
    pub active: usize,
    /// Minibatch size used during fitting (NOTE(review): consumed in `update`, not visible here).
    pub minibatch: usize,
    /// Epoch limit for fitting (NOTE(review): exact use lives in `update`).
    pub max_epochs: usize,
    /// Tile width for atom scoring (NOTE(review): exact use lives in `update`).
    pub score_tile: usize,
    /// Ridge regularizer for per-row code solves (NOTE(review): exact use lives in `update`).
    pub code_ridge: f32,
    /// Ridge regularizer for decoder updates (NOTE(review): exact use lives in `update`).
    pub decoder_ridge: f32,
    /// Convergence tolerance (NOTE(review): the stopping criterion is defined in `update`).
    pub tolerance: f64,
}

impl SparseDictConfig {
    /// Returns [`SparseDictConfig::default`] with `n_atoms` replaced by the given value.
    pub fn new(n_atoms: usize) -> Self {
        let defaults = Self::default();
        Self { n_atoms, ..defaults }
    }
}

impl Default for SparseDictConfig {
    /// One atom, one active slot, 512-row minibatches, 30 epochs, score tiles of 4096,
    /// and `1e-6` for both ridge terms and the tolerance.
    fn default() -> Self {
        const RIDGE: f32 = 1e-6;
        Self {
            n_atoms: 1,
            active: 1,
            minibatch: 512,
            max_epochs: 30,
            score_tile: 4096,
            code_ridge: RIDGE,
            decoder_ridge: RIDGE,
            tolerance: 1e-6,
        }
    }
}
/// Result of [`fit_sparse_dictionary`]: a learned dictionary plus a sparse code per row.
#[derive(Clone, Debug)]
pub struct SparseDictFit {
    /// Dictionary matrix; each row is one atom and the column count is the feature width P.
    pub decoder: Array2<f32>,
    /// Atom indices per sample: entry `[i, j]` is a row index into `decoder`.
    pub indices: Array2<u32>,
    /// Coefficients paired element-wise with `indices`.
    pub codes: Array2<f32>,
    /// Explained-variance score reported by the fit (NOTE(review): computed in `update`;
    /// confirm its exact definition there).
    pub explained_variance: f64,
    /// Number of epochs the fit reports having run (set in `update`).
    pub epochs: usize,
    /// Whether the fit reports convergence (criterion defined in `update`).
    pub converged: bool,
    /// Number of leading code slots per row that [`SparseDictFit::reconstruct`] reads.
    pub active: usize,
}

impl SparseDictFit {
    /// Rebuilds the dense approximation of the data from the sparse codes.
    ///
    /// Row `i` of the result is `sum_j codes[[i, j]] * decoder.row(indices[[i, j]])`, taken
    /// over the first `active` code slots. Slots whose coefficient is exactly `0.0` are
    /// skipped. The output has `indices.nrows()` rows and `decoder.ncols()` columns.
    ///
    /// `active` is stored independently of the arrays, so it is clamped to the number of
    /// columns actually present in `indices` and `codes`. A fit whose `active` exceeds the
    /// stored width is reconstructed from the slots that exist instead of panicking.
    ///
    /// # Panics
    ///
    /// Panics if an entry of `indices` with a nonzero coefficient is not a valid row of
    /// `decoder`, or if `codes` has fewer rows than `indices`.
    pub fn reconstruct(&self) -> Array2<f32> {
        let n = self.indices.nrows();
        let p = self.decoder.ncols();
        // Never index past the columns that were actually stored.
        let slots = self
            .active
            .min(self.indices.ncols())
            .min(self.codes.ncols());
        let mut out = Array2::<f32>::zeros((n, p));
        for (i, mut out_row) in out.outer_iter_mut().enumerate() {
            for j in 0..slots {
                let code = self.codes[[i, j]];
                // A zero coefficient contributes nothing; skip reading the atom.
                if code == 0.0 {
                    continue;
                }
                let atom = self.indices[[i, j]] as usize;
                // out_row += code * decoder[atom, ..]
                out_row.scaled_add(code, &self.decoder.row(atom));
            }
        }
        out
    }
}
/// Encodes every row of `x` against a fixed dictionary `decoder` (one atom per row).
///
/// The support size is `active` clamped into `1..=decoder.nrows()`. Rows are first routed
/// with `TileScorer::route_minibatch` (tile width `score_tile`, raised to at least 1). The
/// per-row coefficients are then solved with `codes::solve_row_codes` using ridge term
/// `code_ridge`.
///
/// Returns `(indices, codes)`, both shaped `(x.nrows(), support)`.
///
/// # Errors
///
/// Returns `Err` if `decoder` has no rows, or if `x` and `decoder` disagree on the number
/// of columns. The atom-count check runs first.
pub fn sparse_dictionary_transform(
    x: ArrayView2<'_, f32>,
    decoder: ArrayView2<'_, f32>,
    active: usize,
    score_tile: usize,
    code_ridge: f32,
) -> Result<(Array2<u32>, Array2<f32>), String> {
    let n_atoms = decoder.nrows();
    if n_atoms == 0 {
        return Err("sparse_dictionary_transform: dictionary has no atoms".to_string());
    }
    let (x_cols, dict_cols) = (x.ncols(), decoder.ncols());
    if x_cols != dict_cols {
        return Err(format!(
            "sparse_dictionary_transform: X has P={} columns but the decoder has P={}",
            x_cols, dict_cols
        ));
    }

    // n_atoms >= 1 here, so the clamp bounds are ordered.
    let support = active.clamp(1, n_atoms);
    let tile_width = if score_tile == 0 { 1 } else { score_tile };
    let scorer = TileScorer::new(support, tile_width);
    let routed = scorer.route_minibatch(x, decoder);

    let n_rows = x.nrows();
    let mut out_indices = Array2::<u32>::zeros((n_rows, support));
    let mut out_codes = Array2::<f32>::zeros((n_rows, support));
    for (row_idx, active_pairs) in routed.iter().enumerate() {
        let solved =
            codes::solve_row_codes(x.row(row_idx), decoder, active_pairs, support, code_ridge);
        let mut idx_row = out_indices.row_mut(row_idx);
        let mut code_row = out_codes.row_mut(row_idx);
        // Both views have exactly `support` entries, so this visits slots 0..support.
        for (slot, (dst_idx, dst_code)) in
            idx_row.iter_mut().zip(code_row.iter_mut()).enumerate()
        {
            *dst_idx = solved.indices[slot];
            *dst_code = solved.codes[slot];
        }
    }
    Ok((out_indices, out_codes))
}
/// Learns a sparse dictionary for the rows of `x` using the hyper-parameters in `config`.
///
/// This is a thin public entry point; the optimisation itself lives in `update::run`.
///
/// # Errors
///
/// Propagates, unchanged, any error message produced by `update::run`.
pub fn fit_sparse_dictionary(
    x: ArrayView2<'_, f32>,
    config: &SparseDictConfig,
) -> Result<SparseDictFit, String> {
    let fit = update::run(x, config)?;
    Ok(fit)
}