use crate::error::{MIError, Result};
use crate::stoicheia::StoicheiaRnn;
use crate::stoicheia::config::StoicheiaConfig;
/// Largest `hidden_size` the fast path supports.
///
/// The forward functions keep the hidden state and pre-activations in
/// fixed-size `[f32; MAX_H]` local arrays, so weights with a larger hidden
/// size are rejected before any computation runs.
const MAX_H: usize = 32;
/// Flattened `f32` weights consumed by the `forward_fast*` functions.
///
/// Each matrix is stored as a flat vector and indexed manually by the
/// kernels in this file; the per-field comments give the index expression
/// that is actually used.
#[derive(Debug, Clone)]
// The `ih` / `hh` / `oh` suffixes are intentionally close to each other.
#[allow(clippy::similar_names)]
pub struct RnnWeights {
/// Input-to-hidden weights: `weight_ih[j]` scales the scalar input for
/// hidden unit `j` (`RnnWeights::new` requires length `hidden_size`).
pub weight_ih: Vec<f32>,
/// Hidden-to-hidden weights, read as `weight_hh[j * hidden_size + k]`
/// for target unit `j` and previous-step unit `k`
/// (`RnnWeights::new` requires length `hidden_size * hidden_size`).
pub weight_hh: Vec<f32>,
/// Hidden-to-output weights, read as `weight_oh[o * hidden_size + j]`
/// for output `o` and hidden unit `j`
/// (`RnnWeights::new` requires length `output_size * hidden_size`).
pub weight_oh: Vec<f32>,
/// Number of hidden units; must not exceed `MAX_H` for the fast path.
pub hidden_size: usize,
/// Number of output values produced per input sequence.
pub output_size: usize,
}
impl RnnWeights {
    /// Copies the three weight tensors out of `model` as flat `f32` vectors.
    ///
    /// `hidden_size` and `output_size` are taken from the model's
    /// configuration. No length or `MAX_H` check is performed here.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while flattening a weight tensor or
    /// converting it to a vector.
    #[allow(clippy::similar_names)]
    pub fn from_model(model: &StoicheiaRnn) -> Result<Self> {
        let config = model.config();
        // Struct-literal fields are evaluated top to bottom, so the tensors
        // are still extracted in ih, hh, oh order.
        Ok(Self {
            weight_ih: model.weight_ih().flatten_all()?.to_vec1()?,
            weight_hh: model.weight_hh().flatten_all()?.to_vec1()?,
            weight_oh: model.weight_oh().flatten_all()?.to_vec1()?,
            hidden_size: config.hidden_size,
            output_size: config.output_size(),
        })
    }

    /// Builds weights from already-flattened vectors.
    ///
    /// # Panics
    ///
    /// Panics if `hidden_size` exceeds `MAX_H`, or if any vector length does
    /// not match the declared sizes (`hidden_size`,
    /// `hidden_size * hidden_size` and `output_size * hidden_size`
    /// respectively).
    #[must_use]
    #[allow(clippy::similar_names)]
    pub fn new(
        weight_ih: Vec<f32>,
        weight_hh: Vec<f32>,
        weight_oh: Vec<f32>,
        hidden_size: usize,
        output_size: usize,
    ) -> Self {
        assert!(
            hidden_size <= MAX_H,
            "hidden_size {hidden_size} exceeds MAX_H {MAX_H}"
        );

        let ih_len = hidden_size;
        assert_eq!(weight_ih.len(), ih_len);

        let hh_len = hidden_size * hidden_size;
        assert_eq!(weight_hh.len(), hh_len);

        let oh_len = output_size * hidden_size;
        assert_eq!(weight_oh.len(), oh_len);

        Self {
            weight_ih,
            weight_hh,
            weight_oh,
            hidden_size,
            output_size,
        }
    }
}
/// Runs the ReLU RNN over a batch of scalar sequences.
///
/// `inputs` holds `n_inputs` sequences of `config.seq_len` values laid out
/// back to back; `outputs` receives `weights.output_size` values per
/// sequence in the same order. Every sequence starts from an all-zero
/// hidden state, and only the final hidden state is projected to the output.
///
/// # Errors
///
/// Returns whatever `validate_fast_args` reports for inconsistent sizes.
pub fn forward_fast(
    weights: &RnnWeights,
    inputs: &[f32],
    outputs: &mut [f32],
    n_inputs: usize,
    config: &StoicheiaConfig,
) -> Result<()> {
    validate_fast_args(weights, inputs, outputs, n_inputs, config)?;

    let hidden_dim = weights.hidden_size;
    let steps = config.seq_len;
    let n_out = weights.output_size;

    // Scratch for pre-activations; `rnn_cell` overwrites the first
    // `hidden_dim` entries before they are read, so it can be shared.
    let mut scratch = [0.0_f32; MAX_H];

    for sample in 0..n_inputs {
        let mut state = [0.0_f32; MAX_H];

        // In bounds: validation guarantees inputs.len() == n_inputs * steps.
        #[allow(clippy::indexing_slicing)]
        let sequence = &inputs[sample * steps..(sample + 1) * steps];

        for &x_t in sequence {
            rnn_cell(weights, x_t, &state, &mut scratch, hidden_dim);

            // ReLU. In bounds: validation guarantees hidden_dim <= MAX_H.
            #[allow(clippy::indexing_slicing)]
            for (unit, &pre) in state[..hidden_dim].iter_mut().zip(&scratch[..hidden_dim]) {
                *unit = pre.max(0.0);
            }
        }

        output_projection(weights, &state, outputs, sample, hidden_dim, n_out);
    }

    Ok(())
}
/// Same as `forward_fast`, but forces selected hidden units to zero.
///
/// `ablated[j] == true` clamps hidden unit `j` to `0.0` after every
/// timestep; all other units go through the usual ReLU.
///
/// # Errors
///
/// Returns whatever `validate_fast_args` reports, or `MIError::Config` when
/// `ablated.len()` differs from `weights.hidden_size`.
pub fn forward_fast_ablated(
    weights: &RnnWeights,
    inputs: &[f32],
    outputs: &mut [f32],
    n_inputs: usize,
    config: &StoicheiaConfig,
    ablated: &[bool],
) -> Result<()> {
    validate_fast_args(weights, inputs, outputs, n_inputs, config)?;

    let hidden_dim = weights.hidden_size;
    if ablated.len() != hidden_dim {
        return Err(MIError::Config(format!(
            "ablated length {} != hidden_size {}",
            ablated.len(),
            weights.hidden_size
        )));
    }

    let steps = config.seq_len;
    let n_out = weights.output_size;

    // Scratch for pre-activations; `rnn_cell` overwrites the first
    // `hidden_dim` entries before they are read, so it can be shared.
    let mut scratch = [0.0_f32; MAX_H];

    for sample in 0..n_inputs {
        let mut state = [0.0_f32; MAX_H];

        // In bounds: validation guarantees inputs.len() == n_inputs * steps.
        #[allow(clippy::indexing_slicing)]
        let sequence = &inputs[sample * steps..(sample + 1) * steps];

        for &x_t in sequence {
            rnn_cell(weights, x_t, &state, &mut scratch, hidden_dim);

            // In bounds: hidden_dim <= MAX_H, and the mask has exactly
            // hidden_dim entries (checked above).
            #[allow(clippy::indexing_slicing)]
            for ((unit, &pre), &knocked_out) in state[..hidden_dim]
                .iter_mut()
                .zip(&scratch[..hidden_dim])
                .zip(ablated)
            {
                *unit = if knocked_out { 0.0 } else { pre.max(0.0) };
            }
        }

        output_projection(weights, &state, outputs, sample, hidden_dim, n_out);
    }

    Ok(())
}
/// Runs the ReLU RNN on a single sequence and records every pre-activation.
///
/// `input` holds `config.seq_len` scalars. `pre_activations` receives the
/// value of each hidden unit *before* the ReLU, laid out as
/// `pre_activations[t * hidden_size + j]` for timestep `t` and unit `j`.
/// `output` receives the projection of the final hidden state.
///
/// # Errors
///
/// Returns `MIError::Config` when
/// - `weights.hidden_size` exceeds `MAX_H`,
/// - `input`, `pre_activations` or `output` has the wrong length, or
/// - a weight vector is shorter than the index range the loops below read
///   (this used to surface as an out-of-bounds panic).
pub fn forward_fast_traced(
    weights: &RnnWeights,
    input: &[f32],
    pre_activations: &mut [f32],
    output: &mut [f32],
    config: &StoicheiaConfig,
) -> Result<()> {
    let h = weights.hidden_size;
    let seq_len = config.seq_len;
    let out_size = weights.output_size;

    if h > MAX_H {
        return Err(MIError::Config(format!(
            "hidden_size {h} exceeds MAX_H {MAX_H}"
        )));
    }
    if input.len() != seq_len {
        return Err(MIError::Config(format!(
            "input length {} != seq_len {seq_len}",
            input.len()
        )));
    }
    // Saturating: no `f32` slice can hold `usize::MAX` elements, so an
    // overflowing product is reported as a length mismatch instead of
    // panicking (debug) or wrapping (release).
    let trace_len = seq_len.saturating_mul(h);
    if pre_activations.len() != trace_len {
        return Err(MIError::Config(format!(
            "pre_activations length {} != seq_len * hidden_size {}",
            pre_activations.len(),
            trace_len
        )));
    }
    if output.len() != out_size {
        return Err(MIError::Config(format!(
            "output length {} != output_size {out_size}",
            output.len()
        )));
    }

    // The fields of `RnnWeights` are public, so the vectors may not match
    // the declared sizes. Checked after the buffer checks so existing error
    // precedence is unchanged.
    if weights.weight_ih.len() < h {
        return Err(MIError::Config(format!(
            "weight_ih length {} < hidden_size {h}",
            weights.weight_ih.len()
        )));
    }
    // `h <= MAX_H`, so `h * h` cannot overflow.
    if weights.weight_hh.len() < h * h {
        return Err(MIError::Config(format!(
            "weight_hh length {} < hidden_size * hidden_size {}",
            weights.weight_hh.len(),
            h * h
        )));
    }
    let oh_needed = out_size.saturating_mul(h);
    if weights.weight_oh.len() < oh_needed {
        return Err(MIError::Config(format!(
            "weight_oh length {} < output_size * hidden_size {}",
            weights.weight_oh.len(),
            oh_needed
        )));
    }

    let mut hidden = [0.0_f32; MAX_H];
    for (t, &x_t) in input.iter().enumerate() {
        // In bounds: pre_activations.len() == seq_len * h and t < seq_len.
        #[allow(clippy::indexing_slicing)]
        let step = &mut pre_activations[t * h..(t + 1) * h];

        for (j, slot) in step.iter_mut().enumerate() {
            // In bounds: j < h, and the weight lengths were checked above.
            #[allow(clippy::indexing_slicing)]
            let (input_term, recurrent_row) = (
                x_t * weights.weight_ih[j],
                &weights.weight_hh[j * h..(j + 1) * h],
            );
            // Same accumulation order as the batched kernel: start from the
            // input term, then fold in the recurrent terms for k = 0..h.
            #[allow(clippy::indexing_slicing)]
            let acc = recurrent_row
                .iter()
                .zip(&hidden[..h])
                .fold(input_term, |acc, (&w, &prev)| w.mul_add(prev, acc));
            *slot = acc;
        }

        // ReLU of the recorded pre-activations becomes the next hidden state.
        #[allow(clippy::indexing_slicing)]
        for (unit, &pre) in hidden[..h].iter_mut().zip(step.iter()) {
            *unit = pre.max(0.0);
        }
    }

    for (o, out) in output.iter_mut().enumerate() {
        // In bounds: o < out_size, and weight_oh.len() >= out_size * h.
        #[allow(clippy::indexing_slicing)]
        let row = &weights.weight_oh[o * h..(o + 1) * h];
        #[allow(clippy::indexing_slicing)]
        let acc = row
            .iter()
            .zip(&hidden[..h])
            .fold(0.0_f32, |acc, (&w, &unit)| w.mul_add(unit, acc));
        *out = acc;
    }

    Ok(())
}
/// Fraction of sequences whose arg-max output equals the target class.
///
/// Runs `forward_fast` on the whole batch, takes `argmax_f32` of each
/// output row and compares it with the matching entry of `targets`.
/// With `n_inputs == 0` the result is `0.0 / 0.0`, i.e. NaN.
///
/// # Errors
///
/// Returns `MIError::Config` when `targets.len() != n_inputs`, and
/// propagates any error from `forward_fast`.
pub fn accuracy(
    weights: &RnnWeights,
    inputs: &[f32],
    targets: &[u32],
    n_inputs: usize,
    config: &StoicheiaConfig,
) -> Result<f32> {
    if targets.len() != n_inputs {
        return Err(MIError::Config(format!(
            "targets length {} != n_inputs {n_inputs}",
            targets.len()
        )));
    }

    let row_len = weights.output_size;
    let mut logits = vec![0.0_f32; n_inputs * row_len];
    forward_fast(weights, inputs, &mut logits, n_inputs, config)?;

    let hits = targets
        .iter()
        .enumerate()
        .filter(|&(sample, &expected)| {
            // In bounds: logits has n_inputs rows and sample < n_inputs.
            #[allow(clippy::indexing_slicing)]
            let row = &logits[sample * row_len..(sample + 1) * row_len];
            argmax_f32(row) == expected
        })
        .count();

    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
    let ratio = hits as f32 / n_inputs as f32;
    Ok(ratio)
}
/// Computes one timestep of pre-activations for the first `h` hidden units.
///
/// For each unit `j`, writes
/// `x_t * weight_ih[j] + sum_k weight_hh[j * h + k] * hidden[k]` into
/// `pre_act[j]`, accumulating with fused multiply-add in increasing `k`.
/// Callers must have verified `h <= MAX_H`.
#[inline]
fn rnn_cell(
    weights: &RnnWeights,
    x_t: f32,
    hidden: &[f32; MAX_H],
    pre_act: &mut [f32; MAX_H],
    h: usize,
) {
    // Slicing to `h` panics (like the indexed form) if a caller ever passes
    // h > MAX_H, rather than silently truncating.
    #[allow(clippy::indexing_slicing)]
    let (prev_state, targets) = (&hidden[..h], &mut pre_act[..h]);

    for (j, slot) in targets.iter_mut().enumerate() {
        #[allow(clippy::indexing_slicing)]
        let (input_term, recurrent_row) = (
            x_t * weights.weight_ih[j],
            &weights.weight_hh[j * h..(j + 1) * h],
        );
        *slot = recurrent_row
            .iter()
            .zip(prev_state)
            .fold(input_term, |acc, (&w, &prev)| w.mul_add(prev, acc));
    }
}
/// Projects the final hidden state of sample `i` onto the output units.
///
/// Writes `sum_j weight_oh[o * h + j] * hidden[j]` into
/// `outputs[i * out_size + o]` for every `o < out_size`, accumulating with
/// fused multiply-add in increasing `j`.
#[inline]
fn output_projection(
    weights: &RnnWeights,
    hidden: &[f32; MAX_H],
    outputs: &mut [f32],
    i: usize,
    h: usize,
    out_size: usize,
) {
    let base = i * out_size;
    for o in 0..out_size {
        #[allow(clippy::indexing_slicing)]
        let (row, state) = (&weights.weight_oh[o * h..(o + 1) * h], &hidden[..h]);

        let projected = row
            .iter()
            .zip(state)
            .fold(0.0_f32, |acc, (&w, &unit)| w.mul_add(unit, acc));

        #[allow(clippy::indexing_slicing)]
        {
            outputs[base + o] = projected;
        }
    }
}
/// Index of the largest value in `slice`, as a `u32`.
///
/// Ties resolve to the earliest index because only a strictly greater value
/// replaces the current leader. An empty slice, or one containing only
/// `f32::NEG_INFINITY`, yields `0`. NaN inputs are a caller bug and trip a
/// debug assertion.
#[must_use]
pub fn argmax_f32(slice: &[f32]) -> u32 {
    debug_assert!(
        slice.iter().all(|v| !v.is_nan()),
        "argmax_f32: input contains NaN"
    );

    let (winner, _) = slice.iter().enumerate().fold(
        (0_usize, f32::NEG_INFINITY),
        |(lead_idx, lead_val), (idx, &val)| {
            if val > lead_val {
                (idx, val)
            } else {
                (lead_idx, lead_val)
            }
        },
    );

    // Deliberate narrowing: class indices are reported as `u32`.
    #[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
    {
        winner as u32
    }
}
/// Checks that the buffers and weights handed to the batched forward passes
/// are mutually consistent, so the lint-suppressed indexing in the kernels
/// cannot go out of bounds.
///
/// # Errors
///
/// Returns `MIError::Config` when
/// - `weights.hidden_size` exceeds `MAX_H`,
/// - `inputs.len() != n_inputs * config.seq_len`,
/// - `outputs.len() != n_inputs * weights.output_size`, or
/// - a weight vector is shorter than the index range the kernels read
///   (this used to surface as an out-of-bounds panic).
fn validate_fast_args(
    weights: &RnnWeights,
    inputs: &[f32],
    outputs: &[f32],
    n_inputs: usize,
    config: &StoicheiaConfig,
) -> Result<()> {
    let h = weights.hidden_size;
    let seq_len = config.seq_len;
    let out_size = weights.output_size;

    if h > MAX_H {
        return Err(MIError::Config(format!(
            "hidden_size {h} exceeds MAX_H {MAX_H}"
        )));
    }

    // Saturating: no `f32` slice can hold `usize::MAX` elements, so an
    // overflowing product is reported as a length mismatch instead of
    // panicking (debug) or wrapping (release).
    let expected_inputs = n_inputs.saturating_mul(seq_len);
    if inputs.len() != expected_inputs {
        return Err(MIError::Config(format!(
            "inputs length {} != n_inputs * seq_len {}",
            inputs.len(),
            expected_inputs
        )));
    }

    let expected_outputs = n_inputs.saturating_mul(out_size);
    if outputs.len() != expected_outputs {
        return Err(MIError::Config(format!(
            "outputs length {} != n_inputs * output_size {}",
            outputs.len(),
            expected_outputs
        )));
    }

    // The fields of `RnnWeights` are public, so the vectors may not match
    // the declared sizes. Checked last so existing error precedence is
    // unchanged.
    if weights.weight_ih.len() < h {
        return Err(MIError::Config(format!(
            "weight_ih length {} < hidden_size {h}",
            weights.weight_ih.len()
        )));
    }
    // `h <= MAX_H`, so `h * h` cannot overflow.
    if weights.weight_hh.len() < h * h {
        return Err(MIError::Config(format!(
            "weight_hh length {} < hidden_size * hidden_size {}",
            weights.weight_hh.len(),
            h * h
        )));
    }
    let oh_needed = out_size.saturating_mul(h);
    if weights.weight_oh.len() < oh_needed {
        return Err(MIError::Config(format!(
            "weight_oh length {} < output_size * hidden_size {}",
            weights.weight_oh.len(),
            oh_needed
        )));
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two hidden units, two outputs, no recurrence (`weight_hh` is zero).
    ///
    /// With these weights the final hidden state is
    /// `[relu(x_last), relu(-x_last)]` and the outputs are
    /// `[h0 - h1, h1 - h0]`, which keeps the expected values easy to derive.
    fn test_weights_2_2() -> RnnWeights {
        RnnWeights::new(
            vec![1.0, -1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        )
    }

    /// Configuration matching `test_weights_2_2`.
    fn test_config_2_2() -> StoicheiaConfig {
        StoicheiaConfig::from_task(crate::stoicheia::config::StoicheiaTask::SecondArgmax, 2, 2)
    }

    #[test]
    fn forward_fast_shape_and_output() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let mut outputs = vec![0.0_f32; 2];
        forward_fast(&weights, &inputs, &mut outputs, 1, &config).unwrap();
        assert!((outputs[0] - (-0.3)).abs() < 1e-6);
        assert!((outputs[1] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn forward_fast_batch() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5, -0.3, 1.0, 2.0];
        let mut outputs = vec![0.0_f32; 4];
        forward_fast(&weights, &inputs, &mut outputs, 2, &config).unwrap();
        // First sample: last input -0.3 -> hidden [0, 0.3] -> [-0.3, 0.3].
        assert!((outputs[0] - (-0.3)).abs() < 1e-6);
        assert!((outputs[1] - 0.3).abs() < 1e-6);
        // Second sample: last input 2.0 -> hidden [2, 0] -> [2, -2].
        // Guards the per-sample offsets and the hidden-state reset.
        assert!((outputs[2] - 2.0).abs() < 1e-6);
        assert!((outputs[3] - (-2.0)).abs() < 1e-6);
    }

    #[test]
    fn forward_fast_traced_matches_forward() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let input = vec![0.5_f32, -0.3];
        let mut pre_acts = vec![0.0_f32; 4];
        let mut traced_out = vec![0.0_f32; 2];
        forward_fast_traced(&weights, &input, &mut pre_acts, &mut traced_out, &config).unwrap();
        let mut fast_out = vec![0.0_f32; 2];
        forward_fast(&weights, &input, &mut fast_out, 1, &config).unwrap();
        for (a, b) in traced_out.iter().zip(&fast_out) {
            assert!((a - b).abs() < 1e-6, "traced={a}, fast={b}");
        }
        // Timestep 0: x = 0.5.
        assert!((pre_acts[0] - 0.5).abs() < 1e-6);
        assert!((pre_acts[1] - (-0.5)).abs() < 1e-6);
        // Timestep 1: x = -0.3 (no recurrent contribution).
        assert!((pre_acts[2] - (-0.3)).abs() < 1e-6);
        assert!((pre_acts[3] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn ablated_all_zeros_gives_zero_output() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let mut outputs = vec![0.0_f32; 2];
        let ablated = vec![true, true];
        forward_fast_ablated(&weights, &inputs, &mut outputs, 1, &config, &ablated).unwrap();
        assert!((outputs[0]).abs() < 1e-6);
        assert!((outputs[1]).abs() < 1e-6);
    }

    #[test]
    fn ablated_none_matches_normal() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let mut normal_out = vec![0.0_f32; 2];
        let mut ablated_out = vec![0.0_f32; 2];
        let ablated = vec![false, false];
        forward_fast(&weights, &inputs, &mut normal_out, 1, &config).unwrap();
        forward_fast_ablated(&weights, &inputs, &mut ablated_out, 1, &config, &ablated).unwrap();
        for (a, b) in normal_out.iter().zip(&ablated_out) {
            assert!((a - b).abs() < 1e-6);
        }
    }

    #[test]
    fn ablated_rejects_wrong_mask_length() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let mut outputs = vec![0.0_f32; 2];
        let ablated = vec![true];
        let result = forward_fast_ablated(&weights, &inputs, &mut outputs, 1, &config, &ablated);
        assert!(result.is_err());
    }

    #[test]
    fn accuracy_perfect() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let targets = vec![1_u32];
        let acc = accuracy(&weights, &inputs, &targets, 1, &config).unwrap();
        assert!((acc - 1.0).abs() < 1e-6);
    }

    #[test]
    fn accuracy_wrong() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let targets = vec![0_u32];
        let acc = accuracy(&weights, &inputs, &targets, 1, &config).unwrap();
        assert!(acc.abs() < 1e-6);
    }

    #[test]
    fn accuracy_rejects_target_length_mismatch() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32, -0.3];
        let targets = vec![1_u32, 0];
        let result = accuracy(&weights, &inputs, &targets, 1, &config);
        assert!(result.is_err());
    }

    #[test]
    fn argmax_prefers_first_on_tie_and_handles_empty() {
        assert_eq!(argmax_f32(&[]), 0);
        assert_eq!(argmax_f32(&[1.0, 3.0, 3.0]), 1);
        assert_eq!(argmax_f32(&[-2.0, -1.0, -3.0]), 1);
    }

    #[test]
    fn validation_rejects_wrong_input_length() {
        let weights = test_weights_2_2();
        let config = test_config_2_2();
        let inputs = vec![0.5_f32];
        let mut outputs = vec![0.0_f32; 2];
        let result = forward_fast(&weights, &inputs, &mut outputs, 1, &config);
        assert!(result.is_err());
    }
}