use ndarray::{ArrayView1, ArrayView2};
use crate::inference::row_metric::{MetricProvenance, RowMetric};
use crate::terms::sae::manifold::SaeManifoldTerm;
/// Assignment mass that a row must strictly exceed for an atom to count as
/// active on that row (used by [`atom_two_lens`]).
pub const SAE_TRUST_ACTIVE_MASS_FLOOR: f64 = 1.0e-6;
/// Per-atom result of the two-lens (presence vs. coupling) comparison.
///
/// `presence` is always filled in; the coupling-related fields are `None`
/// when the coupling lens produced no value for this atom.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomLensEntry {
    /// Atom name.
    pub name: String,
    /// Raw presence score.
    pub presence: f64,
    /// Raw coupling score, when the coupling lens produced one.
    pub coupling: Option<f64>,
    /// `presence` scaled by the largest presence across atoms (as filled in
    /// by `atom_two_lens`; 0 when that maximum is not positive).
    pub presence_normalized: f64,
    /// `coupling` scaled by the largest coupling across atoms (as filled in
    /// by `atom_two_lens`).
    pub coupling_normalized: Option<f64>,
    /// `presence_normalized - coupling_normalized`, when coupling exists.
    pub discrepancy: Option<f64>,
}

/// Minimum discrepancy for an atom to be flagged "represented but not used".
const REPRESENTED_NOT_USED_THRESHOLD: f64 = 0.5;
/// Maximum discrepancy for an atom to be flagged "used".
const USED_THRESHOLD: f64 = 0.0;

impl AtomLensEntry {
    /// `true` when a discrepancy exists and is at least
    /// `REPRESENTED_NOT_USED_THRESHOLD`.
    ///
    /// A missing or NaN discrepancy yields `false`.
    pub fn is_represented_not_used(&self) -> bool {
        self.discrepancy
            .is_some_and(|gap| gap >= REPRESENTED_NOT_USED_THRESHOLD)
    }

    /// `true` when a discrepancy exists and is at most `USED_THRESHOLD`.
    ///
    /// A missing or NaN discrepancy yields `false`.
    pub fn is_used(&self) -> bool {
        self.discrepancy.is_some_and(|gap| gap <= USED_THRESHOLD)
    }
}
/// Report produced by [`atom_two_lens`]: one entry per atom plus the
/// provenance of the row metric that was supplied.
#[derive(Clone, Debug, PartialEq)]
pub struct AtomTwoLensReport {
    /// Per-atom lens results (in model atom order when built by
    /// `atom_two_lens`).
    pub atoms: Vec<AtomLensEntry>,
    /// Provenance of the supplied metric. `atom_two_lens` always sets this to
    /// `Some(metric.provenance())`, even when it skipped the coupling lens.
    pub coupling_provenance: Option<MetricProvenance>,
}

impl AtomTwoLensReport {
    /// Whether this report actually contains coupling values.
    ///
    /// This requires two things:
    /// - a provenance that carries behavior (anything but `Euclidean`);
    /// - at least one atom with a `coupling` value.
    ///
    /// The second check is needed because the provenance is recorded even
    /// when the coupling lens was skipped. That happens when the metric's row
    /// count or output dimension does not match the model. In that case every
    /// entry's `coupling` is `None`, and reporting "available" would be wrong.
    pub fn coupling_available(&self) -> bool {
        self.coupling_provenance.is_some_and(metric_carries_behavior)
            && self.atoms.iter().any(|entry| entry.coupling.is_some())
    }
}
/// Decides whether a metric with provenance `p` is treated as carrying
/// behavioral signal for the coupling lens.
///
/// Only `MetricProvenance::Euclidean` is treated as carrying none.
///
/// The match is intentionally exhaustive, with no `_` arm. Adding a new
/// `MetricProvenance` variant will therefore fail to compile here until that
/// variant is classified explicitly.
fn metric_carries_behavior(p: MetricProvenance) -> bool {
    match p {
        MetricProvenance::OutputFisher { .. }
        | MetricProvenance::OutputFisherDownstream { .. }
        | MetricProvenance::WhitenedStructured { .. } => true,
        MetricProvenance::Euclidean => false,
    }
}
pub fn atom_two_lens(
model: &SaeManifoldTerm,
metric: &RowMetric,
assignments_override: Option<ArrayView2<'_, f64>>,
) -> AtomTwoLensReport {
let n = model.n_obs();
let k = model.k_atoms();
let provenance = metric.provenance();
let coupling_axis_available = metric_carries_behavior(provenance)
&& metric.n_rows() == n
&& metric.p_out() == model.output_dim();
let assignments_owned;
let assignments = match assignments_override {
Some(view) => view,
None => {
assignments_owned = model.assignment.assignments();
assignments_owned.view()
}
};
let mut presence = vec![0.0_f64; k];
let mut coupling_raw = vec![0.0_f64; k];
let mut any_coupling = vec![false; k];
for (atom_idx, atom) in model.atoms.iter().enumerate() {
let decoder_norm = atom
.decoder_coefficients
.iter()
.map(|&b| b * b)
.sum::<f64>()
.sqrt();
let latent_dim = atom.latent_dim;
let mut active_mass_sum = 0.0_f64;
let mut active_row_count = 0.0_f64;
let mut coupling_sum = 0.0_f64;
for row in 0..n {
let mass = assignments[[row, atom_idx]];
if !(mass > SAE_TRUST_ACTIVE_MASS_FLOOR) {
continue;
}
active_mass_sum += mass;
active_row_count += 1.0;
if coupling_axis_available {
let mut row_tangent_mass = 0.0_f64;
for axis in 0..latent_dim {
let dg = atom.decoded_derivative_row(row, axis);
let dg_view: ArrayView1<'_, f64> = dg.view();
row_tangent_mass += metric.fisher_mass(row, dg_view);
}
coupling_sum += mass * row_tangent_mass;
any_coupling[atom_idx] = true;
}
}
let mean_active_mass = if active_row_count > 0.0 {
active_mass_sum / active_row_count
} else {
0.0
};
presence[atom_idx] = mean_active_mass * decoder_norm;
if coupling_axis_available && active_row_count > 0.0 {
coupling_raw[atom_idx] = coupling_sum / active_row_count;
}
}
let presence_max = presence.iter().copied().fold(0.0_f64, f64::max);
let coupling_max = coupling_raw
.iter()
.zip(any_coupling.iter())
.filter(|&(_, &has)| has)
.map(|(&c, _)| c)
.fold(0.0_f64, f64::max);
let mut entries = Vec::with_capacity(k);
for (atom_idx, atom) in model.atoms.iter().enumerate() {
let p = presence[atom_idx];
let presence_normalized = if presence_max > 0.0 {
p / presence_max
} else {
0.0
};
let (coupling, coupling_normalized, discrepancy) =
if coupling_axis_available && any_coupling[atom_idx] {
let c = coupling_raw[atom_idx];
let c_norm = if coupling_max > 0.0 {
c / coupling_max
} else {
0.0
};
(Some(c), Some(c_norm), Some(presence_normalized - c_norm))
} else {
(None, None, None)
};
entries.push(AtomLensEntry {
name: atom.name.clone(),
presence: p,
coupling,
presence_normalized,
coupling_normalized,
discrepancy,
});
}
AtomTwoLensReport {
atoms: entries,
coupling_provenance: Some(provenance),
}
}