use crate::error::Result;
use crate::stoicheia::config::{StoicheiaConfig, StoicheiaOutput, StoicheiaTask};
use crate::stoicheia::fast::{self, RnnWeights, argmax_f32};
/// Common interface for anything that can guess an output index for one
/// input sequence.
///
/// `surprise_accounting` takes an implementor as `&dyn MechanisticEstimator`
/// and compares its predictions with the model's argmax and with the oracle.
pub trait MechanisticEstimator {
    /// Returns the predicted output index for `input`.
    fn predict(&self, input: &[f32]) -> u32;

    /// Returns a static, human-readable label identifying this estimator.
    fn description(&self) -> &'static str;
}
/// Estimator that computes the task's reference answer directly from the
/// input values.
pub struct OracleEstimator {
    /// Task whose answer `predict` evaluates.
    task: StoicheiaTask,
    /// Configured sequence length; `predict` reads it only for
    /// `StoicheiaTask::Argmedian`, where `seq_len / 2` is the median rank.
    seq_len: usize,
}
impl OracleEstimator {
#[must_use]
pub const fn new(task: StoicheiaTask, seq_len: usize) -> Self {
Self { task, seq_len }
}
}
impl MechanisticEstimator for OracleEstimator {
    /// Returns the reference answer for `input` under the configured task.
    ///
    /// * `SecondArgmax` – position of the second-largest value
    ///   (see `second_largest_position`).
    /// * `Argmedian` – original position of the element of rank
    ///   `seq_len / 2` in ascending order (see `median_position`).
    /// * `Median` and `LongestCycle` – always `0`.
    ///
    /// Never panics, including for empty input or input shorter than the
    /// configured sequence length.
    fn predict(&self, input: &[f32]) -> u32 {
        let position = match self.task {
            StoicheiaTask::SecondArgmax => second_largest_position(input),
            StoicheiaTask::Argmedian => median_position(input, self.seq_len),
            StoicheiaTask::Median | StoicheiaTask::LongestCycle => 0,
        };
        // A position is an index into `input`; the cast can only truncate for
        // slices longer than `u32::MAX` elements.
        #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
        {
            position as u32
        }
    }

    /// Returns the fixed label used for this estimator in reports.
    fn description(&self) -> &'static str {
        "oracle (ground-truth task function)"
    }
}

/// Finds the position of the second-largest value in `input`.
///
/// Comparisons are strict (`>`): on ties the earliest occurrence is the
/// maximum and a later equal value becomes the runner-up. NaN never compares
/// greater and is therefore skipped. Empty and single-element input yield `0`.
fn second_largest_position(input: &[f32]) -> usize {
    let (mut max_pos, mut second_pos) = (0_usize, 0_usize);
    let mut max_val = f32::NEG_INFINITY;
    let mut second_val = f32::NEG_INFINITY;
    for (i, &x) in input.iter().enumerate() {
        if x > max_val {
            // The previous maximum is demoted to runner-up.
            second_val = max_val;
            second_pos = max_pos;
            max_val = x;
            max_pos = i;
        } else if x > second_val {
            second_val = x;
            second_pos = i;
        }
    }
    // No runner-up above negative infinity was recorded (every other element
    // is `-inf` or NaN): pick index 1 if the recorded maximum is at index 0,
    // otherwise index 0, so the answer differs from the maximum's position.
    if second_val == f32::NEG_INFINITY && input.len() > 1 {
        second_pos = usize::from(max_pos == 0);
    }
    second_pos
}

/// Finds the original position of the median element of `input`.
///
/// Elements are stably sorted in ascending order and the one at rank
/// `seq_len / 2` is selected. When that rank lies beyond the input (the input
/// is shorter than configured) the input's own middle rank `input.len() / 2`
/// is used instead of panicking; empty input yields `0`.
fn median_position(input: &[f32], seq_len: usize) -> usize {
    let mut indexed: Vec<(usize, f32)> = input.iter().copied().enumerate().collect();
    // `partial_cmp` alone is not a total order once NaN is present, which
    // `sort_by` requires. Ordering NaN after every number (and equal to other
    // NaNs) restores a total order while leaving NaN-free inputs unchanged.
    indexed.sort_by(|a, b| {
        a.1.partial_cmp(&b.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    let configured_rank = seq_len / 2;
    let rank = if configured_rank < indexed.len() {
        configured_rank
    } else {
        indexed.len() / 2
    };
    // `get` only returns `None` for empty input.
    indexed.get(rank).map_or(0, |&(pos, _)| pos)
}
/// Summary statistics produced by `surprise_accounting`.
///
/// All rate fields are fractions of `n_samples`.
#[derive(Debug, Clone, PartialEq)]
pub struct SurpriseReport {
    /// Fraction of samples where the model's argmax equals the oracle answer.
    pub model_accuracy: f32,
    /// Fraction of samples where the estimator's prediction equals the
    /// oracle answer.
    pub estimate_accuracy: f32,
    /// Fraction of samples where model and estimator predictions differ.
    pub disagreement_rate: f32,
    /// `1 / output_size` when the configured output is
    /// `StoicheiaOutput::Distribution`, otherwise `0.0`.
    pub chance_accuracy: f32,
    /// Value of `StoicheiaConfig::param_count()` for the evaluated config.
    pub param_count: usize,
    /// Number of generated samples the rates were computed over.
    pub n_samples: usize,
    /// Fraction of samples where model and estimator predictions coincide.
    pub agreement_rate: f32,
    /// Label returned by the estimator's `description()`.
    pub estimator_description: String,
}
/// Measures how well `estimator` and the model described by `weights` agree
/// with each other and with the task oracle.
///
/// `n_samples` deterministic inputs of length `config.seq_len` are produced
/// by `generate_inputs`, pushed through `fast::forward_fast` as one flattened
/// batch, and each `weights.output_size`-wide output row is reduced with
/// `argmax_f32` to the model's prediction. That prediction, the estimator's
/// prediction and the answer of an `OracleEstimator` for `config.task` are
/// then compared pairwise and tallied.
///
/// # Errors
///
/// Propagates any error returned by `fast::forward_fast`.
///
/// # Notes
///
/// If `n_samples` is `0` and the forward pass succeeds, every rate in the
/// report is `0 / 0`, i.e. NaN.
pub fn surprise_accounting(
    weights: &RnnWeights,
    estimator: &dyn MechanisticEstimator,
    config: &StoicheiaConfig,
    n_samples: usize,
) -> Result<SurpriseReport> {
    let seq_len = config.seq_len;
    let out_size = weights.output_size;

    let inputs = generate_inputs(n_samples, seq_len);
    let flat_inputs: Vec<f32> = inputs.iter().flatten().copied().collect();
    let mut model_outputs = vec![0.0_f32; n_samples * out_size];
    fast::forward_fast(weights, &flat_inputs, &mut model_outputs, n_samples, config)?;

    // The oracle depends only on the config, so build it once, not per sample.
    let oracle = OracleEstimator::new(config.task, seq_len);

    let mut model_correct = 0_usize;
    let mut estimate_correct = 0_usize;
    let mut agree = 0_usize;
    for (i, input) in inputs.iter().enumerate() {
        // In bounds: `model_outputs` holds exactly `n_samples * out_size` values.
        #[allow(clippy::indexing_slicing)]
        let model_row = &model_outputs[i * out_size..(i + 1) * out_size];
        let model_pred = argmax_f32(model_row);
        let est_pred = estimator.predict(input);
        let truth = oracle.predict(input);

        if model_pred == truth {
            model_correct += 1;
        }
        if est_pred == truth {
            estimate_correct += 1;
        }
        if model_pred == est_pred {
            agree += 1;
        }
    }
    // Every sample either agrees or disagrees. Counting in integers avoids
    // adding 1.0 to an `f32` accumulator, which stops growing at 2^24.
    let disagree = inputs.len() - agree;

    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let n_f = n_samples as f32;
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let fraction = |count: usize| count as f32 / n_f;

    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let chance_accuracy = if config.output == StoicheiaOutput::Distribution {
        1.0 / out_size as f32
    } else {
        0.0
    };

    Ok(SurpriseReport {
        model_accuracy: fraction(model_correct),
        estimate_accuracy: fraction(estimate_correct),
        disagreement_rate: fraction(disagree),
        chance_accuracy,
        param_count: config.param_count(),
        n_samples,
        agreement_rate: fraction(agree),
        estimator_description: estimator.description().to_string(),
    })
}
/// Produces `n_samples` deterministic pseudo-random sequences, each holding
/// `seq_len` values.
///
/// Values come from a single 64-bit linear congruential generator with a
/// fixed seed, so every call yields the same stream. The top 31 bits of each
/// state are scaled to a unit fraction and then mapped to the range
/// `-3.0..=3.0`.
fn generate_inputs(n_samples: usize, seq_len: usize) -> Vec<Vec<f32>> {
    const SEED: u64 = 123_456_789;
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1;

    let mut lcg = SEED;
    // Advances the generator by one step and returns the scaled draw.
    let mut draw = || {
        lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let unit = (lcg >> 33) as f32 / (1_u64 << 31) as f32;
        (unit - 0.5) * 6.0
    };

    // Rows are filled strictly in order, so the stream is consumed exactly as
    // a nested loop over samples and positions would consume it.
    (0..n_samples)
        .map(|_| (0..seq_len).map(|_| draw()).collect())
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stoicheia::config::StoicheiaConfig;

    /// The oracle reports the position of the second-largest element.
    #[test]
    fn oracle_matches_task_second_argmax() {
        let oracle = OracleEstimator::new(StoicheiaTask::SecondArgmax, 3);
        let cases: [([f32; 3], u32); 2] = [([1.0, 3.0, 2.0], 2), ([5.0, 1.0, 3.0], 2)];
        for (input, expected) in cases.iter() {
            assert_eq!(oracle.predict(input), *expected);
        }
    }

    /// The oracle reports the original position of the middle-ranked element.
    #[test]
    fn oracle_matches_task_argmedian() {
        let samples = [1.0_f32, 3.0, 2.0];
        let oracle = OracleEstimator::new(StoicheiaTask::Argmedian, samples.len());
        assert_eq!(oracle.predict(&samples), 2);
    }

    /// End-to-end smoke test: the report is produced and its fields are sane.
    #[test]
    fn surprise_accounting_runs() {
        let config = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2);
        let oracle = OracleEstimator::new(StoicheiaTask::SecondArgmax, 2);
        let weights = RnnWeights::new(
            vec![1.0, -1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        );

        let report = surprise_accounting(&weights, &oracle, &config, 100).unwrap();

        assert_eq!(report.n_samples, 100);
        assert!((0.0..=1.0).contains(&report.model_accuracy));
        assert!(report.estimate_accuracy >= 0.0);
        assert!(report.agreement_rate >= 0.0);
        assert!((report.chance_accuracy - 0.5).abs() < 1e-6);
        assert_eq!(report.param_count, 10);
    }

    /// With the oracle passed in as the estimator, its accuracy against the
    /// oracle must exceed the chance level.
    #[test]
    fn perfect_model_high_agreement() {
        let config = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2);
        let oracle = OracleEstimator::new(StoicheiaTask::SecondArgmax, 2);
        let weights = RnnWeights::new(
            vec![1.0, -1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        );

        let SurpriseReport {
            estimate_accuracy,
            chance_accuracy,
            ..
        } = surprise_accounting(&weights, &oracle, &config, 500).unwrap();

        assert!(
            estimate_accuracy > chance_accuracy,
            "estimate accuracy {} should exceed chance {}",
            estimate_accuracy,
            chance_accuracy
        );
    }
}