mod codes;
mod scoring;
#[cfg(target_os = "linux")]
mod scoring_gpu;
mod update;
#[cfg(test)]
mod tests;
pub use codes::SparseCode;
pub use scoring::{TileScorer, top_s_online};
#[cfg(target_os = "linux")]
pub use scoring_gpu::{
DEVICE_SCORE_BLOCK_MIN_ELEMS, ScoreBlockPath, score_block_cpu, score_block_required,
};
use ndarray::{Array2, ArrayView2};
/// Hyper-parameters for [`fit_sparse_dictionary`].
///
/// Start from [`SparseDictConfig::new`] (or [`Default::default`]) and override
/// individual fields with struct-update syntax, e.g.
/// `SparseDictConfig { active: 8, ..SparseDictConfig::new(256) }`.
///
/// NOTE(review): apart from the field types and default values, the field
/// descriptions below are inferred from the names; the authoritative semantics
/// live in the private `update` module — confirm there.
#[derive(Clone, Copy, Debug)]
pub struct SparseDictConfig {
    /// Number of dictionary atoms to learn.
    pub n_atoms: usize,
    /// Sparsity level: how many atoms each sample may use.
    pub active: usize,
    /// Number of samples per mini-batch update.
    pub minibatch: usize,
    /// Upper bound on the number of training epochs.
    pub max_epochs: usize,
    /// Tile size used when scoring candidate atoms.
    pub score_tile: usize,
    /// Ridge (L2) regularisation used when solving for codes.
    pub code_ridge: f32,
    /// Ridge (L2) regularisation used when updating the decoder.
    pub decoder_ridge: f32,
    /// Convergence tolerance that may stop training before `max_epochs`.
    pub tolerance: f64,
}

impl Default for SparseDictConfig {
    /// One atom, one active slot, and the crate's stock optimisation settings
    /// (`minibatch = 512`, `max_epochs = 30`, `score_tile = 4096`, both ridges
    /// and `tolerance` at `1e-6`).
    fn default() -> Self {
        Self {
            n_atoms: 1,
            active: 1,
            minibatch: 512,
            max_epochs: 30,
            score_tile: 4096,
            code_ridge: 1.0e-6,
            decoder_ridge: 1.0e-6,
            tolerance: 1.0e-6,
        }
    }
}

impl SparseDictConfig {
    /// Creates a configuration for `n_atoms` atoms, taking every other field
    /// from the [`Default`] implementation.
    pub fn new(n_atoms: usize) -> Self {
        let defaults = <Self as Default>::default();
        Self { n_atoms, ..defaults }
    }
}
/// Result of [`fit_sparse_dictionary`].
///
/// Each sample `i` is encoded by the first `active` entries of row `i` of
/// `indices` (atom ids) paired element-wise with the same entries of `codes`
/// (coefficients); [`SparseDictFit::reconstruct`] turns them back into a dense
/// matrix.
#[derive(Clone, Debug)]
pub struct SparseDictFit {
    /// Learned dictionary, one atom per row; its column count is the feature
    /// dimension of the reconstruction.
    pub decoder: Array2<f32>,
    /// Atom ids used by each sample (row `i` = sample `i`); each entry is a row
    /// index into `decoder`.
    pub indices: Array2<u32>,
    /// Coefficients paired element-wise with `indices`.
    pub codes: Array2<f32>,
    /// Explained-variance score of the fit.
    /// NOTE(review): exact definition lives in `update::run`; presumably a
    /// fraction in `[0, 1]` — confirm.
    pub explained_variance: f64,
    /// Number of training epochs performed.
    pub epochs: usize,
    /// Whether training reported convergence.
    /// NOTE(review): presumably "tolerance met before `max_epochs`" — confirm
    /// in `update::run`.
    pub converged: bool,
    /// Number of leading columns of `indices` / `codes` read per sample by
    /// [`SparseDictFit::reconstruct`].
    pub active: usize,
}

impl SparseDictFit {
    /// Rebuilds the dense approximation from the sparse codes.
    ///
    /// Returns an `(indices.nrows(), decoder.ncols())` matrix whose row `i` is
    /// `Σ_{j < active} codes[[i, j]] * decoder.row(indices[[i, j]])`.
    /// Slots whose coefficient is exactly `0.0` are skipped, so their atom id is
    /// never used to index `decoder` (it may safely be a placeholder).
    ///
    /// # Panics
    ///
    /// Panics if `active` exceeds the column count of `indices` or `codes`, if
    /// `codes` has fewer rows than `indices`, or if a slot with a non-zero code
    /// names an atom id that is not a valid row of `decoder`.
    pub fn reconstruct(&self) -> Array2<f32> {
        let n = self.indices.nrows();
        let p = self.decoder.ncols();
        let mut out = Array2::<f32>::zeros((n, p));
        for i in 0..n {
            let mut dst = out.row_mut(i);
            for j in 0..self.active {
                // u32 -> usize is lossless on every target with >= 32-bit pointers.
                let atom = self.indices[[i, j]] as usize;
                let code = self.codes[[i, j]];
                if code == 0.0 {
                    continue;
                }
                // Row-level axpy (`dst = dst + code * atom_row`): same arithmetic as
                // the element-wise loop, without per-element 2-D bounds checks.
                dst.scaled_add(code, &self.decoder.row(atom));
            }
        }
        out
    }
}
pub fn fit_sparse_dictionary(
x: ArrayView2<'_, f32>,
config: &SparseDictConfig,
) -> Result<SparseDictFit, String> {
update::run(x, config)
}