use crate::error::{MIError, Result};
use crate::stoicheia::StoicheiaRnn;
use crate::stoicheia::fast::RnnWeights;
/// An RNN whose per-neuron input weights have been rescaled to unit
/// magnitude, together with the scale factors that were divided out.
///
/// Produced by [`standardize_weights`]; the scales are re-absorbed into
/// `weight_hh` and `weight_oh` so the network's outputs are unchanged.
#[derive(Debug, Clone)]
#[allow(clippy::similar_names)]
pub struct StandardizedRnn {
    /// Per-hidden-neuron scale factor `|W_ih[j]|` of the source network.
    pub scales: Vec<f32>,
    /// Standardized input-to-hidden weights (each entry is +1 or -1).
    pub weight_ih: Vec<f32>,
    /// Rescaled hidden-to-hidden weights, row-major `hidden_size * hidden_size`.
    pub weight_hh: Vec<f32>,
    /// Rescaled hidden-to-output weights, row-major `output_size * hidden_size`.
    pub weight_oh: Vec<f32>,
    pub hidden_size: usize,
    pub output_size: usize,
}

impl StandardizedRnn {
    /// Packs the standardized weights back into an [`RnnWeights`] so they
    /// can be fed to the same forward pass as the original network.
    #[must_use]
    pub fn to_rnn_weights(&self) -> RnnWeights {
        let Self {
            weight_ih,
            weight_hh,
            weight_oh,
            hidden_size,
            output_size,
            ..
        } = self;
        RnnWeights::new(
            weight_ih.clone(),
            weight_hh.clone(),
            weight_oh.clone(),
            *hidden_size,
            *output_size,
        )
    }
}
/// Extracts the weights from `model` and standardizes them so each hidden
/// neuron's input weight has unit magnitude.
///
/// # Errors
///
/// Returns an error if weight extraction fails, or if any input weight's
/// magnitude is below `f32::EPSILON` (see [`standardize_weights`]).
pub fn standardize_rnn(model: &StoicheiaRnn) -> Result<StandardizedRnn> {
    RnnWeights::from_model(model).and_then(|weights| standardize_weights(&weights))
}
/// Rescales `weights` so every input-to-hidden weight has unit magnitude,
/// compensating in the recurrent and output weights so the network's
/// outputs are preserved.
///
/// For each hidden neuron `j` the scale `s_j = |W_ih[j]|` is divided out
/// of `W_ih[j]` and re-absorbed via the similarity transform
/// `W_hh[j][k] * s_k / s_j` and `W_oh[o][j] * s_j`.
///
/// # Errors
///
/// Returns [`MIError::Config`] if any `|W_ih[j]|` is below `f32::EPSILON`:
/// dividing by such a scale would amplify the neuron's weights unboundedly.
#[allow(clippy::similar_names)]
pub fn standardize_weights(weights: &RnnWeights) -> Result<StandardizedRnn> {
    let h = weights.hidden_size;
    let out_size = weights.output_size;

    // Per-neuron scales; reject degenerate (near-zero) input weights up front
    // so no partially-scaled result can escape.
    let mut scales = Vec::with_capacity(h);
    #[allow(clippy::indexing_slicing)]
    for j in 0..h {
        let s = weights.weight_ih[j].abs();
        if s < f32::EPSILON {
            return Err(MIError::Config(format!(
                "W_ih[{j}] magnitude {s} is below f32::EPSILON \
                (degenerate neuron, standardization would amplify by 1/{s})"
            )));
        }
        scales.push(s);
    }

    // Standardized input weights: W_ih[j] / |W_ih[j]| is exactly +1 or -1.
    let std_ih: Vec<f32> = weights
        .weight_ih
        .iter()
        .zip(&scales)
        .map(|(&w, &s)| w / s)
        .collect();

    // Recurrent weights: multiplying incoming weight by s_k compensates the
    // shrunken unit k, dividing by s_j keeps row j's own normalization.
    let mut std_hh = Vec::with_capacity(h * h);
    for (j, &sj) in scales.iter().enumerate() {
        let row = j * h;
        for (k, &sk) in scales.iter().enumerate() {
            #[allow(clippy::indexing_slicing)]
            std_hh.push(weights.weight_hh[row + k] * sk / sj);
        }
    }

    // Output weights absorb the remaining s_j so final outputs are unchanged.
    let mut std_oh = Vec::with_capacity(out_size * h);
    for o in 0..out_size {
        let row = o * h;
        for (j, &sj) in scales.iter().enumerate() {
            #[allow(clippy::indexing_slicing)]
            std_oh.push(weights.weight_oh[row + j] * sj);
        }
    }

    Ok(StandardizedRnn {
        scales,
        weight_ih: std_ih,
        weight_hh: std_hh,
        weight_oh: std_oh,
        hidden_size: h,
        output_size: out_size,
    })
}
/// Maximum deviation of the standardized input weights from unit magnitude.
///
/// Returns `0.0` when every `|W_ih[j]|` is exactly `1` (perfect
/// standardization); larger values indicate how far the worst neuron is off.
#[must_use]
pub fn standardization_quality(std_rnn: &StandardizedRnn) -> f32 {
    std_rnn
        .weight_ih
        .iter()
        .map(|w| (w.abs() - 1.0).abs())
        .fold(0.0_f32, f32::max)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny 2-neuron fixture with non-unit input weights (2.0 and -0.5),
    /// identity recurrence, and a symmetric output map.
    fn test_weights() -> RnnWeights {
        RnnWeights::new(
            vec![2.0, -0.5],
            vec![1.0, 0.0, 0.0, 1.0],
            vec![1.0, -1.0, -1.0, 1.0],
            2,
            2,
        )
    }

    #[test]
    fn standardized_wih_near_one() {
        let standardized = standardize_weights(&test_weights()).unwrap();
        // Signs survive standardization; magnitudes become exactly 1.
        let expected = [1.0_f32, -1.0];
        for (got, want) in standardized.weight_ih.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn standardization_quality_is_zero() {
        let standardized = standardize_weights(&test_weights()).unwrap();
        let quality = standardization_quality(&standardized);
        assert!(quality < 1e-6, "quality = {quality}");
    }

    #[test]
    fn standardize_preserves_output() {
        use crate::stoicheia::config::{StoicheiaConfig, StoicheiaTask};
        use crate::stoicheia::fast::forward_fast;

        let original = test_weights();
        let config = StoicheiaConfig::from_task(StoicheiaTask::SecondArgmax, 2, 2);
        let standardized = standardize_weights(&original).unwrap().to_rnn_weights();

        let inputs = vec![0.5_f32, -0.3, 1.0, 2.0, -1.0, 0.7];
        let steps = 3;
        let mut expected = vec![0.0_f32; steps * 2];
        let mut actual = vec![0.0_f32; steps * 2];
        forward_fast(&original, &inputs, &mut expected, steps, &config).unwrap();
        forward_fast(&standardized, &inputs, &mut actual, steps, &config).unwrap();

        for (a, b) in expected.iter().zip(&actual) {
            assert!((a - b).abs() < 1e-5, "output mismatch: orig={a}, std={b}");
        }
    }

    #[test]
    fn degenerate_neuron_errors() {
        // First neuron's input weight is exactly zero, so standardization
        // must refuse rather than divide by (near) zero.
        let degenerate = RnnWeights::new(
            vec![0.0, 1.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0, 1.0],
            2,
            2,
        );
        assert!(standardize_weights(&degenerate).is_err());
    }

    #[test]
    fn to_rnn_weights_roundtrip() {
        let standardized = standardize_weights(&test_weights()).unwrap();
        let roundtrip = standardized.to_rnn_weights();
        assert_eq!(roundtrip.hidden_size, standardized.hidden_size);
        assert_eq!(roundtrip.output_size, standardized.output_size);
        assert_eq!(roundtrip.weight_ih, standardized.weight_ih);
    }
}