use std::collections::HashMap;
use crate::error::{MIError, Result};
use crate::stoicheia::config::StoicheiaConfig;
use crate::stoicheia::fast::{self, RnnWeights};
/// Bit-packed record of which `(timestep, hidden unit)` slots had a
/// non-negative pre-activation.
///
/// Slot `(t, j)` is stored at flat bit index `t * hidden_size + j`, spread over
/// five `u64` words, which caps a pattern at 320 slots. Equality and hashing
/// take both the bits and the two dimensions into account.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ActivationPattern {
    /// Packed slot bits: flat slot `i` is bit `i % 64` of word `i / 64`.
    bits: [u64; 5],
    /// Number of slots per timestep.
    hidden_size: usize,
    /// Number of timesteps.
    seq_len: usize,
}

impl ActivationPattern {
    /// Number of slots the five-word bitset can hold (5 × 64).
    const CAPACITY_BITS: usize = 320;

    /// Builds a pattern from a flat, timestep-major slice of pre-activations.
    ///
    /// `pre_acts[t * hidden_size + j]` is the value for slot `(t, j)`. A slot
    /// is marked active when its value is `>= 0.0`; negative values and NaN
    /// leave it inactive. Only the first `hidden_size * seq_len` values are
    /// read; any extra trailing values are ignored.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_size * seq_len` exceeds 320 (including when the
    /// product would overflow `usize`), or if `pre_acts` holds fewer than
    /// `hidden_size * seq_len` values.
    #[must_use]
    // Indexing is in bounds: `idx < total <= 320`, so `idx / 64 < 5`.
    #[allow(clippy::indexing_slicing)]
    pub fn from_pre_activations(pre_acts: &[f32], hidden_size: usize, seq_len: usize) -> Self {
        // Saturating keeps an overflowing product above the capacity instead of
        // letting it wrap around to a small value that would pass the check.
        let total = hidden_size.saturating_mul(seq_len);
        assert!(
            total <= Self::CAPACITY_BITS,
            "hidden_size ({hidden_size}) * seq_len ({seq_len}) = {total} exceeds 320-bit capacity"
        );
        assert!(
            pre_acts.len() >= total,
            "pre_acts holds {} values but hidden_size ({hidden_size}) * seq_len ({seq_len}) = {total} are required",
            pre_acts.len()
        );

        let mut bits = [0_u64; 5];
        for (idx, &value) in pre_acts.iter().take(total).enumerate() {
            if value >= 0.0 {
                bits[idx / 64] |= 1_u64 << (idx % 64);
            }
        }
        Self {
            bits,
            hidden_size,
            seq_len,
        }
    }

    /// Returns whether slot `(t, j)` is marked active.
    ///
    /// Neither index is range-checked against the pattern's dimensions: a
    /// `j >= hidden_size` reads the bit of a slot in a later timestep.
    ///
    /// # Panics
    ///
    /// Panics if the flat index `t * hidden_size + j` is 320 or larger.
    #[must_use]
    #[allow(clippy::indexing_slicing)]
    pub const fn is_active(&self, t: usize, j: usize) -> bool {
        let idx = t * self.hidden_size + j;
        (self.bits[idx / 64] >> (idx % 64)) & 1 == 1
    }

    /// Returns the number of active slots across all timesteps.
    #[must_use]
    pub fn count_active(&self) -> u32 {
        self.bits.iter().map(|word| word.count_ones()).sum()
    }

    /// Returns the total number of slots, `hidden_size * seq_len`.
    #[must_use]
    pub const fn total_slots(&self) -> usize {
        self.hidden_size * self.seq_len
    }

    /// Returns, for each timestep in order, how many of its slots are active.
    #[must_use]
    // Each per-timestep count is at most `hidden_size`, which is bounded by the
    // 320-slot capacity whenever `seq_len > 0`, so the cast cannot truncate.
    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
    pub fn active_per_timestep(&self) -> Vec<u32> {
        (0..self.seq_len)
            .map(|t| {
                (0..self.hidden_size)
                    .filter(|&j| self.is_active(t, j))
                    .count() as u32
            })
            .collect()
    }
}
/// One group of inputs that all produced the same activation pattern.
pub struct RegionInfo {
/// Number of inputs that produced `pattern`.
pub count: usize,
/// One input that produced `pattern`; `classify_regions` stores the first
/// such input it encounters.
pub representative: Vec<f32>,
/// The activation pattern shared by every input counted in this region.
pub pattern: ActivationPattern,
}
/// Result of grouping a batch of inputs by activation pattern.
pub struct RegionMap {
/// The distinct regions found; `classify_regions` returns them sorted by
/// descending `count`.
pub regions: Vec<RegionInfo>,
/// Number of inputs that were classified (`classify_regions` sets this to
/// its `n_inputs` argument).
pub total_inputs: usize,
}
/// Groups inputs by the activation pattern they induce and counts each group.
///
/// `inputs` is read as `n_inputs` consecutive windows of `config.seq_len`
/// values. Each window is run through `fast::forward_fast_traced`, the traced
/// pre-activations are packed into an [`ActivationPattern`], and windows with
/// identical patterns are merged into one [`RegionInfo`] whose
/// `representative` is the first window that produced the pattern.
///
/// Regions are returned in descending order of `count`. Regions with equal
/// counts are ordered by the position of their first input, so the result is
/// reproducible across runs regardless of hash-map iteration order.
///
/// # Errors
///
/// Propagates any error returned by `fast::forward_fast_traced`.
///
/// # Panics
///
/// Panics if `inputs` holds fewer than `n_inputs * config.seq_len` values, or
/// if `weights.hidden_size * config.seq_len` exceeds the 320-slot capacity
/// enforced by `ActivationPattern::from_pre_activations`.
pub fn classify_regions(
    weights: &RnnWeights,
    inputs: &[f32],
    n_inputs: usize,
    config: &StoicheiaConfig,
) -> Result<RegionMap> {
    let h = weights.hidden_size;
    let seq_len = config.seq_len;

    // Scratch buffers reused for every forward pass.
    let mut pre_acts = vec![0.0_f32; seq_len * h];
    let mut output = vec![0.0_f32; weights.output_size];

    // pattern -> (index of the first input that produced it, count, that input)
    let mut by_pattern: HashMap<ActivationPattern, (usize, usize, Vec<f32>)> = HashMap::new();
    for i in 0..n_inputs {
        #[allow(clippy::indexing_slicing)]
        let input_slice = &inputs[i * seq_len..(i + 1) * seq_len];
        fast::forward_fast_traced(weights, input_slice, &mut pre_acts, &mut output, config)?;
        let pattern = ActivationPattern::from_pre_activations(&pre_acts, h, seq_len);
        let (_, count, _) = by_pattern
            .entry(pattern)
            .or_insert_with(|| (i, 0, input_slice.to_vec()));
        *count += 1;
    }

    // Hash-map iteration order is arbitrary, so carry the first-seen index
    // along and use it to break ties between equally populated regions.
    let mut ranked: Vec<(usize, RegionInfo)> = by_pattern
        .into_iter()
        .map(|(pattern, (first_seen, count, representative))| {
            (
                first_seen,
                RegionInfo {
                    count,
                    representative,
                    pattern,
                },
            )
        })
        .collect();
    // `first_seen` is unique per region, so the key is total and an unstable
    // sort still yields a deterministic order.
    ranked.sort_unstable_by_key(|(first_seen, region)| {
        (std::cmp::Reverse(region.count), *first_seen)
    });
    let regions: Vec<RegionInfo> = ranked.into_iter().map(|(_, region)| region).collect();

    Ok(RegionMap {
        regions,
        total_inputs: n_inputs,
    })
}
/// Computes the input-to-output linear map that applies while the activation
/// pattern stays fixed at `pattern`.
///
/// For every input timestep `t` the function:
/// 1. starts from `weights.weight_ih`, zeroing the units `pattern` marks
///    inactive at `t`;
/// 2. pushes that vector through `weights.weight_hh` once per later timestep
///    `s`, zeroing the units inactive at `s`;
/// 3. projects the final vector through `weights.weight_oh`.
///
/// The result is a row-major `weights.output_size × config.seq_len` matrix:
/// element `o * seq_len + t` is the coefficient of input timestep `t` in
/// output `o`.
///
/// `pattern` is queried with `t < config.seq_len` and `j < weights.hidden_size`;
/// it is the caller's responsibility to pass a pattern built with those same
/// dimensions, since `ActivationPattern::is_active` does not range-check them.
///
/// # Errors
///
/// Returns `MIError::Model` if any component of the propagated vector becomes
/// infinite.
///
/// # Panics
///
/// Panics if a weight vector is shorter than the indices derived from
/// `weights.hidden_size` / `weights.output_size`, or if `pattern.is_active`
/// is asked for a flat slot index of 320 or more.
// Index loops are kept on purpose: every access pairs a flat weight offset
// with a state slot, and the accumulation order must stay sequential.
#[allow(clippy::needless_range_loop, clippy::indexing_slicing)]
pub fn region_linear_map(
    weights: &RnnWeights,
    pattern: &ActivationPattern,
    config: &StoicheiaConfig,
) -> Result<Vec<f32>> {
    let h = weights.hidden_size;
    let seq_len = config.seq_len;
    let out_size = weights.output_size;
    let mut a_mat = vec![0.0_f32; out_size * seq_len];

    // Two scratch vectors sized to the actual hidden width (rather than a
    // fixed 32-element array), allocated once and swapped after each step.
    let mut state = vec![0.0_f32; h];
    let mut next = vec![0.0_f32; h];

    for t in 0..seq_len {
        // Contribution of input timestep `t` to the hidden state at `t`.
        for j in 0..h {
            state[j] = if pattern.is_active(t, j) {
                weights.weight_ih[j]
            } else {
                0.0
            };
        }

        // Carry that contribution forward through every later timestep.
        for s in (t + 1)..seq_len {
            for j in 0..h {
                next[j] = if pattern.is_active(s, j) {
                    let mut acc = 0.0_f32;
                    for k in 0..h {
                        acc = weights.weight_hh[j * h + k].mul_add(state[k], acc);
                    }
                    acc
                } else {
                    0.0
                };
            }
            // Every element of `next` was just overwritten, so after the swap
            // the stale contents left in `next` are never read.
            std::mem::swap(&mut state, &mut next);

            if state.iter().any(|v| v.is_infinite()) {
                return Err(MIError::Model(candle_core::Error::Msg(format!(
                    "region_linear_map: overflow at timestep {s} \
                     (W_hh likely has spectral radius > 1)"
                ))));
            }
        }

        // Project the final hidden contribution onto each output.
        for o in 0..out_size {
            let mut acc = 0.0_f32;
            for j in 0..h {
                acc = weights.weight_oh[o * h + j].mul_add(state[j], acc);
            }
            a_mat[o * seq_len + t] = acc;
        }
    }
    Ok(a_mat)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stoicheia::config::{StoicheiaConfig, StoicheiaTask};
    use crate::stoicheia::fast::RnnWeights;

    /// Weight fixture with two hidden units and two outputs; the middle
    /// weight vector is all zeros.
    fn test_weights_2_2() -> RnnWeights {
        let first = vec![1.0, -1.0];
        let second = vec![0.0; 4];
        let third = vec![1.0, -1.0, -1.0, 1.0];
        RnnWeights::new(first, second, third, 2, 2)
    }

    /// Config fixture matching `test_weights_2_2`.
    fn test_config_2_2() -> StoicheiaConfig {
        StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2)
    }

    /// Deterministic pseudo-scattered values in `[-3, 3]`, `count` of them.
    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    fn scattered_values(count: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(count);
        for i in 0..count {
            values.push(((i as f32) * 0.618_034).sin() * 3.0);
        }
        values
    }

    /// Packing and every read accessor agree on a hand-checked 2×2 example.
    #[test]
    fn activation_pattern_roundtrip() {
        let pattern = ActivationPattern::from_pre_activations(&[0.5_f32, -0.3, 1.0, 0.0], 2, 2);

        let expected_slots = [
            ((0, 0), true),
            ((0, 1), false),
            ((1, 0), true),
            ((1, 1), true),
        ];
        for ((t, j), want) in expected_slots {
            assert_eq!(pattern.is_active(t, j), want, "slot ({t}, {j})");
        }

        assert_eq!(pattern.count_active(), 3);
        assert_eq!(pattern.total_slots(), 4);
        assert_eq!(pattern.active_per_timestep(), vec![1, 2]);
    }

    /// A 2×2 model has four slots, so at most 16 patterns can ever appear,
    /// and the region counts must account for every input.
    #[test]
    fn m2_2_has_few_regions() {
        let n = 1000;
        let inputs = scattered_values(n * 2);

        let region_map =
            classify_regions(&test_weights_2_2(), &inputs, n, &test_config_2_2()).unwrap();

        let found = region_map.regions.len();
        assert!(found <= 16, "found {} regions, expected ≤ 16", found);
        assert_eq!(region_map.total_inputs, n);

        let mut counted = 0_usize;
        for region in &region_map.regions {
            counted += region.count;
        }
        assert_eq!(counted, n);
    }

    /// Applying the region's linear map to an input reproduces the traced
    /// forward pass for that same input.
    #[test]
    fn linear_map_matches_forward() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let input = vec![0.5_f32, -0.3];

        let mut pre_acts = vec![0.0_f32; 4];
        let mut traced_out = vec![0.0_f32; 2];
        fast::forward_fast_traced(&weights, &input, &mut pre_acts, &mut traced_out, &config)
            .unwrap();

        let pattern = ActivationPattern::from_pre_activations(&pre_acts, 2, 2);
        let a_mat = region_linear_map(&weights, &pattern, &config).unwrap();

        for (o, forward) in traced_out.iter().enumerate() {
            #[allow(clippy::indexing_slicing)]
            let linear = (0..2).fold(0.0_f32, |acc, t| acc + a_mat[o * 2 + t] * input[t]);
            let (a, b) = (&linear, forward);
            assert!((a - b).abs() < 1e-5, "linear={a}, forward={b}");
        }
    }
}