use ndarray::{Array1, ArrayView1};
use crate::terms::analytic_penalties::{AnalyticPenalty, SparsityPenalty};
use crate::terms::atom_codes::{BitVec, SparseAtomCode, SparseAtomCodes};
use crate::terms::latent_coord::LatentCoordValues;
/// Small, copyable handle describing the shape an atom instantiates.
///
/// Plain data: an identifier plus two size parameters; it performs no
/// validation of its own.
#[derive(Clone, Copy, Debug)]
pub struct ShapeRef {
    /// Identifier of the shape.
    pub id: usize,
    /// Intrinsic dimension of the shape. `AtomRecord::new` requires the atom's
    /// coordinate `latent_dim()` to equal this value.
    pub intrinsic_dim: usize,
    /// Basis size associated with the shape (not interpreted in this module).
    pub basis_size: usize,
}
#[derive(Debug, Clone)]
pub struct AtomRecord {
pub shape: ShapeRef,
pub coords: LatentCoordValues,
}
impl AtomRecord {
pub fn new(shape: ShapeRef, coords: LatentCoordValues) -> Self {
assert_eq!(
coords.latent_dim(),
shape.intrinsic_dim,
"AtomRecord: coord latent_dim {} != shape.intrinsic_dim {}",
coords.latent_dim(),
shape.intrinsic_dim,
);
Self { shape, coords }
}
pub fn intrinsic_dim(&self) -> usize {
self.shape.intrinsic_dim
}
}
/// A non-empty collection of atoms that all report the same observation count.
#[derive(Debug, Clone)]
pub struct AtomLibrary {
    /// Atoms in library order; position `k` is what `atom(k)` returns.
    atoms: Vec<AtomRecord>,
    /// Observation count shared by every atom, captured in `new`.
    n_obs: usize,
}

impl AtomLibrary {
    /// Builds a library, checking that it is non-empty and that every atom's
    /// `coords.n_obs()` matches the first atom's.
    ///
    /// # Errors
    /// Returns `Err` with a descriptive message when `atoms` is empty, or when
    /// some atom's observation count differs from atom 0's.
    pub fn new(atoms: Vec<AtomRecord>) -> Result<Self, String> {
        let n_obs = match atoms.first() {
            Some(first) => first.coords.n_obs(),
            None => return Err("AtomLibrary::new: at least one atom required".into()),
        };
        let mismatch = atoms
            .iter()
            .enumerate()
            .find(|(_, atom)| atom.coords.n_obs() != n_obs);
        if let Some((k, atom)) = mismatch {
            return Err(format!(
                "AtomLibrary::new: atom {k} has n_obs={} but atom 0 has n_obs={n_obs}",
                atom.coords.n_obs()
            ));
        }
        Ok(Self { atoms, n_obs })
    }

    /// Number of observations shared by all atoms (fixed at construction).
    pub fn n_obs(&self) -> usize {
        self.n_obs
    }

    /// Number of atoms in the library.
    pub fn k_atoms(&self) -> usize {
        self.atoms.len()
    }

    /// Borrows atom `k`.
    ///
    /// # Panics
    /// Panics if `k >= self.k_atoms()`.
    pub fn atom(&self, k: usize) -> &AtomRecord {
        &self.atoms[k]
    }

    /// Mutably borrows atom `k`.
    ///
    /// NOTE(review): `coords` is public, so edits through this handle can make
    /// an atom's `n_obs()` disagree with `self.n_obs()`; nothing re-validates it.
    ///
    /// # Panics
    /// Panics if `k >= self.k_atoms()`.
    pub fn atom_mut(&mut self, k: usize) -> &mut AtomRecord {
        &mut self.atoms[k]
    }

    /// Iterates over the atoms in library order.
    pub fn iter(&self) -> impl Iterator<Item = &AtomRecord> {
        self.atoms.iter()
    }

    /// Sum of the intrinsic dimensions of all atoms.
    pub fn total_intrinsic_dim(&self) -> usize {
        self.iter().fold(0, |acc, atom| acc + atom.intrinsic_dim())
    }

    /// Creates `SparseAtomCodes::empty(n_obs, k_atoms)` sized for this library.
    pub fn fresh_codes(&self) -> SparseAtomCodes {
        let (rows, atoms) = (self.n_obs(), self.k_atoms());
        SparseAtomCodes::empty(rows, atoms)
    }
}
/// Couples a selection strategy with a `SparsityPenalty` evaluated on one row
/// of free amplitudes.
pub trait AssignmentSparsityCoupling {
    /// Returns `(value, gradient)` of the sparsity penalty for one row.
    ///
    /// The gradient has the same length as `free_amplitudes_row`. `rho` holds
    /// the penalty's parameters; the implementations in this module check its
    /// length against `penalty.rho_count()` or forward it to the penalty.
    fn penalty_value_and_grad(
        &self,
        penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>);
}
/// A rule that maps one row of free amplitudes to a sparse atom code, together
/// with the matching gradient (backward) map.
///
/// Every strategy must also implement [`AssignmentSparsityCoupling`].
pub trait AtomSelectionStrategy: AssignmentSparsityCoupling {
    /// Short, static identifier for the strategy (e.g. `"topk"`).
    fn name(&self) -> &'static str;

    /// Forward map: selects the active atoms and their weights for one row.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode;

    /// Backward map: given the forward input row, the code for that row
    /// (presumably the one `apply` returned — confirm with callers) and the
    /// upstream gradient on the code's weights, returns the gradient with
    /// respect to `free_amplitudes_row`.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64>;
}
/// Temperature-scaled softmax selection over a row of free amplitudes, with an
/// optional hard threshold that zeroes small weights.
#[derive(Debug, Clone)]
pub struct EntropicSoftmax {
    /// Softmax temperature `tau`; logits are divided by it before `exp`.
    /// `EntropicSoftmax::new` requires it to be finite and strictly positive.
    pub temperature: f64,
    /// When `Some(thr)`, atoms whose softmax weight is below `thr` are masked
    /// out (weight forced to `0.0`) by `apply`.
    pub mask_threshold: Option<f64>,
}

impl EntropicSoftmax {
    /// Creates a softmax selector with the given temperature and no mask.
    ///
    /// # Panics
    /// Panics if `temperature` is not finite or not strictly positive.
    pub fn new(temperature: f64) -> Self {
        assert!(
            temperature.is_finite() && temperature > 0.0,
            "EntropicSoftmax temperature must be finite and positive, got {temperature}"
        );
        Self {
            temperature,
            mask_threshold: None,
        }
    }

    /// Builder-style setter for the hard mask threshold.
    ///
    /// # Panics
    /// Panics if `thr` is not finite.
    pub fn with_mask_threshold(mut self, thr: f64) -> Self {
        assert!(
            thr.is_finite(),
            "EntropicSoftmax mask threshold must be finite, got {thr}"
        );
        self.mask_threshold = Some(thr);
        self
    }

    /// Computes `softmax(logits / tau)`.
    ///
    /// The maximum scaled logit is subtracted before exponentiating, so the
    /// largest term is `exp(0) = 1` and the sum cannot overflow. An empty row
    /// yields an empty result. Logits equal to `-inf` get weight `0.0` as long
    /// as at least one logit is larger.
    ///
    /// # Panics
    /// Panics if the normaliser is not strictly positive. That happens when a
    /// logit is NaN or `+inf`, or when every logit is `-inf`.
    fn softmax(&self, logits: ArrayView1<'_, f64>) -> Array1<f64> {
        if logits.is_empty() {
            return Array1::<f64>::zeros(0);
        }
        let tau = self.temperature;
        // `f64::max` ignores NaN operands, so NaN logits never become the shift.
        let max_scaled = logits
            .iter()
            .map(|&l| l / tau)
            .fold(f64::NEG_INFINITY, f64::max);
        let mut out = logits.mapv(|l| (l / tau - max_scaled).exp());
        // Plain left-to-right accumulation, starting from 0.0.
        let total = out.iter().fold(0.0_f64, |acc, &v| acc + v);
        assert!(
            total > 0.0,
            "EntropicSoftmax softmax normaliser must be positive (got {total}); \
             logits must not contain NaN or +inf and must not all be -inf"
        );
        out.mapv_inplace(|v| v / total);
        out
    }

    /// Pulls an upstream gradient `g_a` on the softmax output `a` back to the
    /// unscaled logits: `a_i * (g_a_i - <a, g_a>) / tau`.
    ///
    /// The softmax Jacobian `diag(a) - a a^T` is symmetric, so this product is
    /// both the JVP and the VJP.
    ///
    /// # Panics
    /// Panics if `a` and `g_a` have different lengths.
    pub fn jvp_logits(&self, a: ArrayView1<'_, f64>, g_a: ArrayView1<'_, f64>) -> Array1<f64> {
        assert_eq!(
            a.len(),
            g_a.len(),
            "EntropicSoftmax jvp_logits length mismatch"
        );
        let dot = a
            .iter()
            .zip(g_a.iter())
            .fold(0.0_f64, |acc, (&ai, &gi)| acc + ai * gi);
        let inv_tau = 1.0 / self.temperature;
        Array1::<f64>::from_iter(
            a.iter()
                .zip(g_a.iter())
                .map(|(&ai, &gi)| ai * (gi - dot) * inv_tau),
        )
    }
}
impl AtomSelectionStrategy for EntropicSoftmax {
fn name(&self) -> &'static str {
"entropic_softmax"
}
fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
let a = self.softmax(free_amplitudes_row);
let k = a.len();
let mut mask = BitVec::ones(k);
if let Some(thr) = self.mask_threshold {
for i in 0..k {
if a[i] < thr {
mask.set(i, false);
}
}
}
let mut weights = vec![0.0_f64; k];
for i in 0..k {
if mask.get(i) {
weights[i] = a[i];
}
}
SparseAtomCode {
active_mask: mask,
weights,
}
}
fn backward(
&self,
free_amplitudes_row: ArrayView1<'_, f64>,
atom_code: &SparseAtomCode,
grad_a_row: ArrayView1<'_, f64>,
) -> Array1<f64> {
assert_eq!(
grad_a_row.len(),
free_amplitudes_row.len(),
"EntropicSoftmax backward gradient length mismatch"
);
assert_eq!(
atom_code.k_atoms(),
free_amplitudes_row.len(),
"EntropicSoftmax backward code/free-amplitude length mismatch"
);
assert!(
atom_code.weights.iter().all(|weight| weight.is_finite()),
"EntropicSoftmax backward requires finite assignment weights"
);
let a = self.softmax(free_amplitudes_row);
self.jvp_logits(a.view(), grad_a_row)
}
}
impl AssignmentSparsityCoupling for EntropicSoftmax {
    /// Contributes no sparsity penalty. After validating `rho`, returns a value
    /// of `0.0` and an all-zero gradient sized to the row.
    ///
    /// # Panics
    /// Panics if `rho.len() != sparsity_penalty.rho_count()` or if any entry
    /// of `rho` is not finite.
    fn penalty_value_and_grad(
        &self,
        sparsity_penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let expected_rho = sparsity_penalty.rho_count();
        assert_eq!(
            rho.len(),
            expected_rho,
            "EntropicSoftmax sparsity rho length mismatch"
        );
        let has_non_finite = rho.iter().any(|value| !value.is_finite());
        assert!(!has_non_finite, "EntropicSoftmax sparsity rho must be finite");
        let zero_grad = Array1::<f64>::zeros(free_amplitudes_row.len());
        (0.0, zero_grad)
    }
}
/// Hard top-`k` selection over a row of free amplitudes.
#[derive(Debug, Clone, Copy)]
pub struct TopK {
    /// Number of atoms to keep per row; `TopK::new` requires `k > 0`.
    pub k: usize,
}

impl TopK {
    /// Creates a top-`k` selector.
    ///
    /// # Panics
    /// Panics if `k == 0`.
    pub fn new(k: usize) -> Self {
        assert!(k > 0, "TopK requires k > 0");
        Self { k }
    }

    /// Returns the indices of the `min(k, len)` largest amplitudes, largest
    /// first.
    ///
    /// Ties keep the lower index first because the sort is stable. NaN
    /// amplitudes rank below every number, so they are only selected when the
    /// row has fewer than `k` non-NaN entries.
    fn topk_indices(&self, amps: ArrayView1<'_, f64>) -> Vec<usize> {
        let n = amps.len();
        let k_use = self.k.min(n);
        if k_use == 0 {
            return Vec::new();
        }
        // Map NaN to -inf so the comparator below is a consistent total order;
        // `partial_cmp` alone treats NaN as "equal" to everything, which
        // breaks the sort's ordering contract.
        let rank = |i: usize| -> f64 {
            let v = amps[i];
            if v.is_nan() {
                f64::NEG_INFINITY
            } else {
                v
            }
        };
        let mut idx: Vec<usize> = (0..n).collect();
        // Descending by rank. `rank` never yields NaN, so the fallback is unreachable.
        idx.sort_by(|&a, &b| {
            rank(b)
                .partial_cmp(&rank(a))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        idx.truncate(k_use);
        idx
    }

    /// Straight-through backward: copies `grad_a_row[i]` for every active atom
    /// `i` in `code` and leaves all other entries at `0.0`.
    ///
    /// # Panics
    /// Panics if an active index is out of range for `grad_a_row`.
    pub fn backward_straight_through(
        &self,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let mut out = Array1::<f64>::zeros(grad_a_row.len());
        for i in code.active_mask.iter_ones() {
            out[i] = grad_a_row[i];
        }
        out
    }
}
impl AtomSelectionStrategy for TopK {
    /// Identifier used for this strategy.
    fn name(&self) -> &'static str {
        "topk"
    }

    /// Marks the top-`k` atoms active and copies their raw amplitudes as
    /// weights; every other atom is inactive with weight `0.0`.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
        let width = free_amplitudes_row.len();
        let mut active_mask = BitVec::zeros(width);
        let mut weights = vec![0.0_f64; width];
        self.topk_indices(free_amplitudes_row)
            .into_iter()
            .for_each(|i| {
                active_mask.set(i, true);
                weights[i] = free_amplitudes_row[i];
            });
        SparseAtomCode {
            active_mask,
            weights,
        }
    }

    /// Straight-through gradient after validating that the row, the code and
    /// the gradient all have the same length.
    ///
    /// # Panics
    /// Panics on any length mismatch.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let n_grad = grad_a_row.len();
        assert_eq!(
            free_amplitudes_row.len(),
            n_grad,
            "TopK backward free-amplitude/gradient length mismatch"
        );
        assert_eq!(
            code.k_atoms(),
            n_grad,
            "TopK backward code/gradient length mismatch"
        );
        self.backward_straight_through(code, grad_a_row)
    }
}
impl AssignmentSparsityCoupling for TopK {
    /// Contributes no sparsity penalty. After validating `rho`, returns a value
    /// of `0.0` and an all-zero gradient sized to the row.
    ///
    /// # Panics
    /// Panics if `rho.len() != sparsity_penalty.rho_count()` or if any entry
    /// of `rho` is not finite.
    fn penalty_value_and_grad(
        &self,
        sparsity_penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let expected_rho = sparsity_penalty.rho_count();
        assert_eq!(rho.len(), expected_rho, "TopK sparsity rho length mismatch");
        let has_non_finite = rho.iter().any(|value| !value.is_finite());
        assert!(!has_non_finite, "TopK sparsity rho must be finite");
        let zero_grad = Array1::<f64>::zeros(free_amplitudes_row.len());
        (0.0, zero_grad)
    }
}
/// Relaxed selection: an atom is active when its amplitude clipped at zero,
/// `max(x, 0)`, exceeds `active_threshold`. Its weight is that clipped
/// amplitude.
#[derive(Debug, Clone)]
pub struct L1Relaxed {
    /// Threshold compared (strictly) against the clipped amplitude.
    pub active_threshold: f64,
}

impl L1Relaxed {
    /// Creates a selector with `active_threshold = 0.0`, so every strictly
    /// positive amplitude is active.
    pub fn new() -> Self {
        Self::with_threshold(0.0)
    }

    /// Creates a selector with a custom activation threshold.
    ///
    /// # Panics
    /// Panics if `thr` is not finite. A NaN threshold would otherwise silently
    /// deactivate every atom, because `x > NaN` is always false.
    pub fn with_threshold(thr: f64) -> Self {
        assert!(
            thr.is_finite(),
            "L1Relaxed active threshold must be finite, got {thr}"
        );
        Self {
            active_threshold: thr,
        }
    }
}

impl Default for L1Relaxed {
    /// Same as [`L1Relaxed::new`]: threshold `0.0`.
    fn default() -> Self {
        Self::new()
    }
}
impl AtomSelectionStrategy for L1Relaxed {
    /// Identifier used for this strategy.
    fn name(&self) -> &'static str {
        "l1_relaxed"
    }

    /// Clips each amplitude at zero. An atom is active, with the clipped value
    /// as its weight, when that value is strictly above `active_threshold`.
    fn apply(&self, free_amplitudes_row: ArrayView1<'_, f64>) -> SparseAtomCode {
        let k = free_amplitudes_row.len();
        let mut mask = BitVec::zeros(k);
        let mut weights = vec![0.0_f64; k];
        for (i, &x) in free_amplitudes_row.iter().enumerate() {
            let a = x.max(0.0);
            if a > self.active_threshold {
                mask.set(i, true);
                weights[i] = a;
            }
        }
        SparseAtomCode {
            active_mask: mask,
            weights,
        }
    }

    /// Gradient of `max(x, 0)` restricted to active atoms: `grad_a_row[i]`
    /// passes through for active atoms with `x_i > 0`, and every other entry
    /// is `0.0`.
    ///
    /// # Panics
    /// Panics if the row, the code and the gradient do not all have the same
    /// length.
    fn backward(
        &self,
        free_amplitudes_row: ArrayView1<'_, f64>,
        code: &SparseAtomCode,
        grad_a_row: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let k = grad_a_row.len();
        assert_eq!(
            free_amplitudes_row.len(),
            k,
            "L1Relaxed backward free-amplitude/gradient length mismatch"
        );
        assert_eq!(
            code.k_atoms(),
            k,
            "L1Relaxed backward code/gradient length mismatch"
        );
        let mut out = Array1::<f64>::zeros(k);
        for i in code.active_mask.iter_ones() {
            // The clip has zero derivative where the raw amplitude is <= 0.
            if free_amplitudes_row[i] > 0.0 {
                out[i] = grad_a_row[i];
            }
        }
        out
    }
}
impl AssignmentSparsityCoupling for L1Relaxed {
    /// Evaluates `penalty` on the clipped amplitudes `max(x, 0)`.
    ///
    /// Returns the penalty value and `penalty.grad_target` on the clipped row,
    /// with entries whose raw amplitude is `<= 0` zeroed (the clip's
    /// derivative there is 0). `rho` is forwarded unchanged to the penalty.
    fn penalty_value_and_grad(
        &self,
        penalty: &SparsityPenalty,
        free_amplitudes_row: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> (f64, Array1<f64>) {
        let clipped = free_amplitudes_row.mapv(|x| x.max(0.0));
        let value = penalty.value(clipped.view(), rho);
        let mut grad = penalty.grad_target(clipped.view(), rho);
        for (i, &x) in free_amplitudes_row.iter().enumerate() {
            if x <= 0.0 {
                grad[i] = 0.0;
            }
        }
        (value, grad)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::terms::latent_coord::{LatentCoordValues, LatentIdMode};
    use ndarray::array;

    /// Fixture: a 2-D atom and a 1-D atom, each built from a two-row
    /// coordinate matrix.
    fn two_atom_library() -> AtomLibrary {
        let planar = AtomRecord::new(
            ShapeRef {
                id: 0,
                intrinsic_dim: 2,
                basis_size: 8,
            },
            LatentCoordValues::from_matrix(
                array![[0.0, 0.0], [0.1, 0.2]].view(),
                LatentIdMode::None,
            ),
        );
        let linear = AtomRecord::new(
            ShapeRef {
                id: 1,
                intrinsic_dim: 1,
                basis_size: 5,
            },
            LatentCoordValues::from_matrix(array![[0.0], [1.0]].view(), LatentIdMode::None),
        );
        AtomLibrary::new(vec![planar, linear]).expect("library")
    }

    #[test]
    fn library_construct() {
        let library = two_atom_library();
        assert_eq!(library.k_atoms(), 2);
        assert_eq!(library.n_obs(), 2);
        assert_eq!(library.total_intrinsic_dim(), 3);
    }

    #[test]
    fn softmax_is_simplex() {
        let logits = array![1.0_f64, 2.0, 3.0];
        let code = EntropicSoftmax::new(1.0).apply(logits.view());
        let total = code.weights.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-12);
        assert_eq!(code.active_mask.count_ones(), 3);
    }

    #[test]
    fn topk_keeps_top() {
        let amps = array![0.1_f64, 0.9, 0.4, 0.5];
        let code = TopK::new(2).apply(amps.view());
        assert_eq!(code.active_mask.count_ones(), 2);
        for &idx in &[1usize, 3] {
            assert!(code.active_mask.get(idx));
        }
    }

    #[test]
    fn l1_relaxed_clips_negatives() {
        let amps = array![-0.5_f64, 0.3, -0.1, 0.7];
        let code = L1Relaxed::new().apply(amps.view());
        assert_eq!(code.active_mask.count_ones(), 2);
        assert_eq!(code.weights[1], 0.3);
        assert_eq!(code.weights[3], 0.7);
    }
}