use crate::error::Result;
use crate::stoicheia::config::StoicheiaConfig;
use crate::stoicheia::fast::{self, RnnWeights};
/// Interpretable function that a hidden neuron's final activation is matched to.
///
/// Assigned by [`probe_neurons`]. Each variant other than `Unknown` names the
/// reference signal (computed from a probe input sequence) whose absolute
/// correlation with the neuron's activation was highest.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NeuronRole {
    /// Maximum over the input sequence, clamped at zero.
    RunningMax,
    /// Minimum over the input sequence, clamped at zero.
    RunningMin,
    /// Amount by which the final step raised the (zero-clamped) running maximum.
    MaxIncrement,
    /// Zero-clamped sequence maximum minus the final input, clamped at zero.
    LeaveOneOutMax,
    /// The final input, clamped at zero.
    RecentInput,
    /// Zero-clamped difference between the inputs at two distinct time steps.
    Comparator,
    /// No reference signal correlated strongly enough.
    Unknown,
}
/// Probe outcome for a single hidden neuron.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuronProbeResult {
    /// Index of the neuron within the hidden layer.
    pub neuron: usize,
    /// Best-matching role, or `NeuronRole::Unknown` when no reference signal
    /// correlated strongly enough.
    pub role: NeuronRole,
    /// Absolute Pearson correlation with the best-matching reference signal.
    /// Reported even when `role` is `NeuronRole::Unknown`.
    pub correlation: f32,
}

/// Result of probing every hidden neuron of a model.
#[derive(Debug, Clone, PartialEq)]
pub struct ProbeReport {
    /// One entry per hidden neuron, in neuron-index order.
    pub neurons: Vec<NeuronProbeResult>,
    /// Number of probe input sequences the correlations were computed over.
    pub n_probes: usize,
}
/// Minimum absolute correlation for a neuron to be given a role other than `Unknown`.
const CORRELATION_THRESHOLD: f32 = 0.80;
#[allow(clippy::needless_range_loop)]
pub fn probe_neurons(
weights: &RnnWeights,
config: &StoicheiaConfig,
n_probes: usize,
) -> Result<ProbeReport> {
let h = weights.hidden_size;
let seq_len = config.seq_len;
let out_size = weights.output_size;
let inputs = generate_probe_inputs(n_probes, seq_len);
let mut neuron_activations = vec![vec![0.0_f32; n_probes]; h];
let mut pre_acts = vec![0.0_f32; seq_len * h];
let mut output = vec![0.0_f32; out_size];
for (idx, input) in inputs.iter().enumerate() {
fast::forward_fast_traced(weights, input, &mut pre_acts, &mut output, config)?;
for j in 0..h {
#[allow(clippy::indexing_slicing)]
let pre = pre_acts[(seq_len - 1) * h + j];
#[allow(clippy::indexing_slicing)]
{
neuron_activations[j][idx] = pre.max(0.0);
}
}
}
let mut results = Vec::with_capacity(h);
for j in 0..h {
#[allow(clippy::indexing_slicing)]
let activations = &neuron_activations[j];
let (role, corr) = best_probe_match(activations, &inputs, seq_len);
results.push(NeuronProbeResult {
neuron: j,
role,
correlation: corr,
});
}
Ok(ProbeReport {
neurons: results,
n_probes,
})
}
/// Builds `n_probes` deterministic pseudo-random input sequences, each `seq_len` long.
///
/// Values come from a 64-bit linear congruential generator. It is seeded with 42
/// and advanced once per value, filling sequence by sequence. The top 31 bits of
/// the state are scaled to approximately `[0, 1)` and then mapped affinely onto
/// roughly `[-3, 3]`. Identical arguments always yield identical inputs.
fn generate_probe_inputs(n_probes: usize, seq_len: usize) -> Vec<Vec<f32>> {
    // 2^31: dividing the top 31 state bits by this lands in [0, 1).
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let scale = (1_u64 << 31) as f32;
    let mut lcg = 42_u64;
    let mut next_value = || {
        lcg = lcg.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let top_bits = (lcg >> 33) as f32;
        (top_bits / scale - 0.5) * 6.0
    };
    (0..n_probes)
        .map(|_| (0..seq_len).map(|_| next_value()).collect::<Vec<f32>>())
        .collect()
}
/// Matches one neuron's activation series against a fixed menu of reference signals.
///
/// `activations[p]` is the neuron's value on probe `inputs[p]`. Each reference is
/// a function of a probe sequence:
/// - the zero-clamped running max;
/// - the zero-clamped running min;
/// - the final step's increase of the running max (only when `seq_len >= 2`);
/// - the max minus the last input;
/// - the zero-clamped last input;
/// - every zero-clamped ordered pairwise difference between distinct steps.
///
/// The reference with the highest [`pearson_abs`] score wins. Candidates are
/// tried in exactly that order, and only a strictly larger score replaces the
/// incumbent, so earlier roles win ties.
///
/// Returns `(role, correlation)`. The role is downgraded to `NeuronRole::Unknown`
/// when the correlation is below `CORRELATION_THRESHOLD`; the correlation itself
/// is still returned. A `seq_len` of zero yields `(NeuronRole::Unknown, 0.0)`.
fn best_probe_match(activations: &[f32], inputs: &[Vec<f32>], seq_len: usize) -> (NeuronRole, f32) {
    // Several references read the final time step. With no steps there is nothing
    // to match, and `seq_len - 1` would underflow.
    if seq_len == 0 {
        return (NeuronRole::Unknown, 0.0);
    }
    let last = seq_len - 1;

    let mut best_role = NeuronRole::Unknown;
    let mut best_corr = 0.0_f32;
    let mut consider = |role: NeuronRole, reference: &[f32]| {
        let c = pearson_abs(activations, reference);
        if c > best_corr {
            best_corr = c;
            best_role = role;
        }
    };

    consider(
        NeuronRole::RunningMax,
        &reference_signal(inputs, |inp| max_of(inp).max(0.0)),
    );
    consider(
        NeuronRole::RunningMin,
        &reference_signal(inputs, |inp| {
            inp.iter().copied().fold(f32::INFINITY, f32::min).max(0.0)
        }),
    );
    // With a single step this reference collapses to relu(x_0), which is the same
    // signal as RecentInput, so it is only tried for longer sequences.
    if seq_len >= 2 {
        consider(
            NeuronRole::MaxIncrement,
            &reference_signal(inputs, |inp| {
                // Max over the steps before the last one (all of them if the input is short).
                let prev_max = max_of(inp.get(..last).unwrap_or(inp));
                (max_of(inp).max(0.0) - prev_max.max(0.0)).max(0.0)
            }),
        );
    }
    // NOTE(review): despite the name, this is max(all steps) - last step, not the
    // max over the remaining steps. Confirm this is the intended reference.
    consider(
        NeuronRole::LeaveOneOutMax,
        &reference_signal(inputs, |inp| {
            (max_of(inp).max(0.0) - value_at(inp, last)).max(0.0)
        }),
    );
    consider(
        NeuronRole::RecentInput,
        &reference_signal(inputs, |inp| value_at(inp, last).max(0.0)),
    );
    // Every ordered pair of distinct steps; one buffer is reused across all pairs.
    let mut pairwise: Vec<f32> = Vec::with_capacity(inputs.len());
    for a in 0..seq_len {
        for b in (0..seq_len).filter(|&b| b != a) {
            pairwise.clear();
            pairwise.extend(
                inputs
                    .iter()
                    .map(|inp| (value_at(inp, a) - value_at(inp, b)).max(0.0)),
            );
            consider(NeuronRole::Comparator, &pairwise);
        }
    }

    if best_corr < CORRELATION_THRESHOLD {
        best_role = NeuronRole::Unknown;
    }
    (best_role, best_corr)
}

/// Evaluates `signal` on each probe input, giving one reference value per probe.
fn reference_signal(inputs: &[Vec<f32>], signal: impl Fn(&[f32]) -> f32) -> Vec<f32> {
    inputs.iter().map(|inp| signal(inp.as_slice())).collect()
}

/// Largest value in `values` (NaNs are skipped by `f32::max`); negative infinity when empty.
fn max_of(values: &[f32]) -> f32 {
    values.iter().copied().fold(f32::NEG_INFINITY, f32::max)
}

/// Input at time step `t`, or `0.0` if the sequence is shorter than expected.
fn value_at(input: &[f32], t: usize) -> f32 {
    input.get(t).copied().unwrap_or(0.0)
}
/// Absolute Pearson correlation coefficient of `a` and `b`, clamped to `[0, 1]`.
///
/// Only the first `min(a.len(), b.len())` elements are used. Returns `0.0` in three cases:
/// - that prefix is empty;
/// - either side has (near-)zero spread, i.e. the product of the two root sums of
///   squared deviations is not above `1e-12`;
/// - the data contains NaN.
fn pearson_abs(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    if n == 0 {
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let n_f = n as f32;
    let mean_a = a.iter().take(n).sum::<f32>() / n_f;
    let mean_b = b.iter().take(n).sum::<f32>() / n_f;

    // `zip` stops at the shorter slice, matching the `n` used for the means.
    let (cov, var_a, var_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(cov, var_a, var_b), (&x, &y)| {
            let da = x - mean_a;
            let db = y - mean_b;
            (da.mul_add(db, cov), da.mul_add(da, var_a), db.mul_add(db, var_b))
        },
    );

    // Multiply the square roots instead of taking the root of the product. For
    // large-magnitude data `var_a * var_b` overflows f32 to infinity, which would
    // collapse a perfect correlation to 0.
    let denom = var_a.sqrt() * var_b.sqrt();
    // Written with `partial_cmp` so that a NaN denominator is also rejected.
    if denom.partial_cmp(&1e-12) != Some(std::cmp::Ordering::Greater) {
        return 0.0;
    }
    // Rounding can push the ratio marginally above 1.
    (cov / denom).abs().min(1.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pearson_abs_perfect_correlation() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let c = pearson_abs(&a, &b);
        assert!((c - 1.0).abs() < 1e-5, "corr = {c}");
    }

    #[test]
    fn pearson_abs_negative_correlation() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let b = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let c = pearson_abs(&a, &b);
        assert!((c - 1.0).abs() < 1e-5, "corr = {c}");
    }

    #[test]
    fn pearson_abs_zero_variance() {
        let a = vec![1.0, 1.0, 1.0];
        let b = vec![1.0, 2.0, 3.0];
        let c = pearson_abs(&a, &b);
        assert!(c < 1e-6, "corr = {c}");
    }

    #[test]
    fn probe_inputs_are_deterministic_and_bounded() {
        let a = generate_probe_inputs(8, 5);
        let b = generate_probe_inputs(8, 5);
        assert_eq!(a, b);
        assert_eq!(a.len(), 8);
        for input in &a {
            assert_eq!(input.len(), 5);
            assert!(input.iter().all(|&v| (-3.0..=3.0).contains(&v)));
        }
    }

    #[test]
    fn best_probe_match_detects_running_max() {
        let seq_len = 4;
        let inputs = generate_probe_inputs(200, seq_len);
        let activations: Vec<f32> = inputs
            .iter()
            .map(|inp| inp.iter().copied().fold(f32::NEG_INFINITY, f32::max).max(0.0))
            .collect();
        let (role, corr) = best_probe_match(&activations, &inputs, seq_len);
        assert_eq!(role, NeuronRole::RunningMax);
        assert!(corr > 0.99, "corr = {corr}");
    }

    #[test]
    fn best_probe_match_detects_recent_input() {
        let seq_len = 4;
        let inputs = generate_probe_inputs(200, seq_len);
        let activations: Vec<f32> = inputs
            .iter()
            .map(|inp| inp.last().map_or(0.0, |x| x.max(0.0)))
            .collect();
        let (role, corr) = best_probe_match(&activations, &inputs, seq_len);
        assert_eq!(role, NeuronRole::RecentInput);
        assert!(corr > 0.99, "corr = {corr}");
    }

    #[test]
    fn best_probe_match_constant_activation_is_unknown() {
        let inputs = generate_probe_inputs(50, 3);
        let activations = vec![0.0_f32; 50];
        let (role, corr) = best_probe_match(&activations, &inputs, 3);
        assert_eq!(role, NeuronRole::Unknown);
        assert!(corr < 1e-6, "corr = {corr}");
    }

    #[test]
    fn probe_runs_on_tiny_model() {
        let weights = RnnWeights::new(
            vec![1.0, -1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        );
        let config = crate::stoicheia::config::StoicheiaConfig::from_task(
            crate::stoicheia::config::StoicheiaTask::SecondArgmax,
            2,
            2,
        );
        let report = probe_neurons(&weights, &config, 100).unwrap();
        assert_eq!(report.neurons.len(), 2);
        assert_eq!(report.n_probes, 100);
        for n in &report.neurons {
            assert!(n.correlation >= 0.0);
        }
    }
}