use crate::error::Result;
use crate::stoicheia::config::StoicheiaConfig;
use crate::stoicheia::fast::{self, RnnWeights, argmax_f32};
/// Outcome of ablating a single hidden neuron.
#[derive(Debug, Clone, PartialEq)]
pub struct NeuronAblationResult {
    /// Index of the ablated hidden neuron.
    pub neuron: usize,
    /// Fraction of inputs predicted correctly with this neuron ablated.
    pub ablated_accuracy: f32,
    /// `ablated_accuracy` minus the baseline accuracy; negative values mean
    /// the ablation reduced accuracy.
    pub accuracy_delta: f32,
}

/// Result of a single-neuron ablation sweep over a model.
#[derive(Debug, Clone, PartialEq)]
pub struct AblationSweep {
    /// Accuracy of the model with no neuron ablated.
    pub baseline_accuracy: f32,
    /// One entry per hidden neuron. `ablate_neurons` returns these sorted
    /// ascending by `accuracy_delta` (most harmful ablation first).
    pub results: Vec<NeuronAblationResult>,
    /// Number of inputs the accuracies were computed over.
    pub n_inputs: usize,
}
/// Outcome of ablating two hidden neurons at the same time.
#[derive(Debug, Clone, PartialEq)]
pub struct PairAblationResult {
    /// Index of the first ablated neuron (`ablate_neuron_pairs` emits pairs
    /// with `neuron_a < neuron_b`).
    pub neuron_a: usize,
    /// Index of the second ablated neuron.
    pub neuron_b: usize,
    /// Fraction of inputs predicted correctly with both neurons ablated.
    pub ablated_accuracy: f32,
    /// `ablated_accuracy` minus the baseline accuracy.
    pub accuracy_delta: f32,
    /// `accuracy_delta` minus the sum of the two single-neuron deltas.
    /// Zero means the two ablations combine additively.
    pub interaction_score: f32,
}
/// Ablates each hidden neuron in turn and records how accuracy changes.
///
/// The baseline accuracy comes from `fast::accuracy` on the intact model.
/// Then, for every hidden index in `0..weights.hidden_size`, the model is run
/// through `fast::forward_fast_ablated` with only that neuron flagged, and
/// the `argmax_f32` of each `output_size`-wide row of the output buffer is
/// compared against the corresponding entry of `targets`.
///
/// Accuracy is `correct / n_inputs`, so `targets` is expected to hold one
/// label per input. Targets beyond `n_inputs` are ignored rather than sliced
/// out of bounds; missing targets never count as correct. When `n_inputs` is
/// zero the ablated accuracies are NaN (`0.0 / 0.0`).
///
/// The returned results are sorted ascending by `accuracy_delta`, so the
/// neurons whose removal hurts accuracy the most come first. The sort is
/// stable: ties keep neuron-index order.
///
/// # Errors
///
/// Propagates any error returned by `fast::accuracy` or
/// `fast::forward_fast_ablated`.
pub fn ablate_neurons(
    weights: &RnnWeights,
    inputs: &[f32],
    targets: &[u32],
    n_inputs: usize,
    config: &StoicheiaConfig,
) -> Result<AblationSweep> {
    let h = weights.hidden_size;
    let out_size = weights.output_size;
    let baseline_accuracy = fast::accuracy(weights, inputs, targets, n_inputs, config)?;

    // Scratch buffers shared by every ablation run: the output buffer is
    // handed to each forward pass, and the mask is rewritten in place
    // instead of being reallocated per neuron.
    let mut outputs = vec![0.0_f32; n_inputs * out_size];
    let mut ablated = vec![false; h];
    let mut results = Vec::with_capacity(h);

    for neuron in 0..h {
        for (idx, flag) in ablated.iter_mut().enumerate() {
            *flag = idx == neuron;
        }
        fast::forward_fast_ablated(weights, inputs, &mut outputs, n_inputs, config, &ablated)?;

        // Checked row lookup: a target without a matching output row is
        // skipped instead of panicking on an out-of-range slice.
        let correct = targets
            .iter()
            .enumerate()
            .filter(|&(i, &target)| {
                outputs
                    .get(i * out_size..(i + 1) * out_size)
                    .is_some_and(|row| target == argmax_f32(row))
            })
            .count();

        // There is no lossless usize -> f32 conversion; rounding only kicks
        // in above 2^24 samples, which is acceptable for an accuracy ratio.
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let ablated_accuracy = correct as f32 / n_inputs as f32;

        results.push(NeuronAblationResult {
            neuron,
            ablated_accuracy,
            accuracy_delta: ablated_accuracy - baseline_accuracy,
        });
    }

    // `total_cmp` is a genuine total order (NaN included), which is what
    // `sort_by` requires of its comparator.
    results.sort_by(|a, b| a.accuracy_delta.total_cmp(&b.accuracy_delta));

    Ok(AblationSweep {
        baseline_accuracy,
        results,
        n_inputs,
    })
}
/// Ablates every unordered pair of hidden neurons and scores their interaction.
///
/// For each pair `(a, b)` with `a < b < weights.hidden_size`, the model is
/// run through `fast::forward_fast_ablated` with both neurons flagged, and
/// accuracy is computed as `correct / n_inputs` by comparing the
/// `argmax_f32` of each output row against `targets`.
///
/// `single_results` supplies the baseline accuracy and the single-neuron
/// deltas. The interaction score is
/// `pair_delta - (delta_a + delta_b)`, i.e. how far the joint effect departs
/// from the sum of the individual effects. Neurons missing from
/// `single_results` contribute a delta of `0.0`; entries whose index is not
/// below `weights.hidden_size` are ignored.
///
/// Targets beyond `n_inputs` are ignored rather than sliced out of bounds.
/// When `n_inputs` is zero the accuracies are NaN (`0.0 / 0.0`).
///
/// Results are returned in generation order: `a` ascending, then `b`
/// ascending.
///
/// # Errors
///
/// Propagates any error returned by `fast::forward_fast_ablated`.
pub fn ablate_neuron_pairs(
    weights: &RnnWeights,
    inputs: &[f32],
    targets: &[u32],
    n_inputs: usize,
    config: &StoicheiaConfig,
    single_results: &AblationSweep,
) -> Result<Vec<PairAblationResult>> {
    let h = weights.hidden_size;
    let out_size = weights.output_size;
    let baseline = single_results.baseline_accuracy;

    // Re-index the sweep by neuron, since its results are not required to be
    // in neuron order. A checked lookup keeps a sweep produced for a wider
    // model from panicking here.
    let mut single_deltas = vec![0.0_f32; h];
    for r in &single_results.results {
        if let Some(slot) = single_deltas.get_mut(r.neuron) {
            *slot = r.accuracy_delta;
        }
    }

    // Scratch buffers reused across all pairs; the mask is rewritten in
    // place rather than reallocated for each of the h * (h - 1) / 2 pairs.
    let mut outputs = vec![0.0_f32; n_inputs * out_size];
    let mut ablated = vec![false; h];
    let mut pair_results = Vec::with_capacity(h * h.saturating_sub(1) / 2);

    for (a, &delta_a) in single_deltas.iter().enumerate() {
        for (b, &delta_b) in single_deltas.iter().enumerate().skip(a + 1) {
            for (idx, flag) in ablated.iter_mut().enumerate() {
                *flag = idx == a || idx == b;
            }
            fast::forward_fast_ablated(weights, inputs, &mut outputs, n_inputs, config, &ablated)?;

            // Checked row lookup: a target without a matching output row is
            // skipped instead of panicking on an out-of-range slice.
            let correct = targets
                .iter()
                .enumerate()
                .filter(|&(i, &target)| {
                    outputs
                        .get(i * out_size..(i + 1) * out_size)
                        .is_some_and(|row| target == argmax_f32(row))
                })
                .count();

            // There is no lossless usize -> f32 conversion; rounding only
            // kicks in above 2^24 samples, which is acceptable for a ratio.
            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
            let ablated_accuracy = correct as f32 / n_inputs as f32;

            let pair_delta = ablated_accuracy - baseline;
            let interaction_score = pair_delta - (delta_a + delta_b);

            pair_results.push(PairAblationResult {
                neuron_a: a,
                neuron_b: b,
                ablated_accuracy,
                accuracy_delta: pair_delta,
                interaction_score,
            });
        }
    }

    Ok(pair_results)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stoicheia::config::{StoicheiaConfig, StoicheiaTask};

    /// Small fixed network with two hidden neurons and two outputs.
    fn test_weights() -> RnnWeights {
        RnnWeights::new(
            vec![1.0, -1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        )
    }

    fn test_config() -> StoicheiaConfig {
        StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2)
    }

    /// Runs the intact model and returns its own predictions as targets, so
    /// the un-ablated model is expected to score perfectly against them.
    fn self_targets(
        weights: &RnnWeights,
        config: &StoicheiaConfig,
        inputs: &[f32],
        n: usize,
    ) -> Vec<u32> {
        let mut outputs = vec![0.0_f32; n * 2];
        fast::forward_fast(weights, inputs, &mut outputs, n, config).unwrap();
        (0..n)
            .map(|i| argmax_f32(&outputs[i * 2..(i + 1) * 2]))
            .collect()
    }

    #[test]
    fn no_ablation_matches_baseline() {
        let weights = test_weights();
        let config = test_config();
        let inputs = vec![0.5_f32, -0.3, 1.0, 2.0, -1.0, 0.7];
        let n = 3;
        let targets = self_targets(&weights, &config, &inputs, n);

        let sweep = ablate_neurons(&weights, &inputs, &targets, n, &config).unwrap();

        assert!(
            (sweep.baseline_accuracy - 1.0).abs() < 1e-6,
            "baseline = {}",
            sweep.baseline_accuracy
        );
        for r in &sweep.results {
            assert!(
                (0.0..=1.0).contains(&r.ablated_accuracy),
                "neuron {} accuracy = {}",
                r.neuron,
                r.ablated_accuracy
            );
        }
    }

    #[test]
    fn sweep_covers_every_neuron_sorted_by_delta() {
        let weights = test_weights();
        let config = test_config();
        // Indices below 200 convert to f32 exactly, so the cast is lossless.
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let inputs: Vec<f32> = (0..200)
            .map(|i| ((i as f32) * 0.618_034).sin() * 3.0)
            .collect();
        let n = 100;
        let targets = self_targets(&weights, &config, &inputs, n);

        let sweep = ablate_neurons(&weights, &inputs, &targets, n, &config).unwrap();

        assert_eq!(sweep.results.len(), 2);
        assert_eq!(sweep.n_inputs, n);

        // Each hidden neuron appears exactly once.
        let mut neurons: Vec<usize> = sweep.results.iter().map(|r| r.neuron).collect();
        neurons.sort_unstable();
        assert_eq!(neurons, vec![0, 1]);

        // Most harmful ablation first.
        assert!(
            sweep
                .results
                .windows(2)
                .all(|w| w[0].accuracy_delta <= w[1].accuracy_delta)
        );

        for r in &sweep.results {
            let expected = r.ablated_accuracy - sweep.baseline_accuracy;
            assert!(
                (r.accuracy_delta - expected).abs() < 1e-6,
                "neuron {} delta = {}, expected {}",
                r.neuron,
                r.accuracy_delta,
                expected
            );
        }
    }

    #[test]
    fn pairwise_ablation_runs() {
        let weights = test_weights();
        let config = test_config();
        let inputs = vec![0.5_f32, -0.3, 1.0, 2.0];
        let n = 2;
        let targets = self_targets(&weights, &config, &inputs, n);

        let sweep = ablate_neurons(&weights, &inputs, &targets, n, &config).unwrap();
        let pairs = ablate_neuron_pairs(&weights, &inputs, &targets, n, &config, &sweep).unwrap();

        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].neuron_a, 0);
        assert_eq!(pairs[0].neuron_b, 1);

        // The interaction score is the pair delta minus the sum of the two
        // single-neuron deltas from the sweep.
        let single_delta = |neuron: usize| {
            sweep
                .results
                .iter()
                .find(|r| r.neuron == neuron)
                .map(|r| r.accuracy_delta)
                .unwrap()
        };
        let expected = pairs[0].accuracy_delta - (single_delta(0) + single_delta(1));
        assert!(
            (pairs[0].interaction_score - expected).abs() < 1e-6,
            "interaction = {}, expected {}",
            pairs[0].interaction_score,
            expected
        );
    }
}