use std::f32;
#[derive(Clone)]
pub struct FastGRNN {
input_dim: usize,
hidden_dim: usize,
w_gate: Vec<f32>,
u_gate: Vec<f32>,
w_update: Vec<f32>,
u_update: Vec<f32>,
bias_gate: Vec<f32>,
bias_update: Vec<f32>,
zeta: f32,
nu: f32,
}
impl FastGRNN {
pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
let scale = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
Self {
input_dim,
hidden_dim,
w_gate: vec![0.1 * scale; input_dim * hidden_dim],
u_gate: vec![0.1 * scale; hidden_dim * hidden_dim],
w_update: vec![0.1 * scale; input_dim * hidden_dim],
u_update: vec![0.1 * scale; hidden_dim * hidden_dim],
bias_gate: vec![0.0; hidden_dim],
bias_update: vec![0.0; hidden_dim],
zeta: 1.0,
nu: 1.0,
}
}
pub fn from_weights(
input_dim: usize,
hidden_dim: usize,
w_gate: Vec<f32>,
u_gate: Vec<f32>,
w_update: Vec<f32>,
u_update: Vec<f32>,
bias_gate: Vec<f32>,
bias_update: Vec<f32>,
zeta: f32,
nu: f32,
) -> Self {
Self {
input_dim,
hidden_dim,
w_gate,
u_gate,
w_update,
u_update,
bias_gate,
bias_update,
zeta,
nu,
}
}
pub fn step(&self, input: &[f32], hidden: &[f32]) -> Vec<f32> {
assert_eq!(input.len(), self.input_dim, "Input dimension mismatch");
assert_eq!(hidden.len(), self.hidden_dim, "Hidden dimension mismatch");
let mut new_hidden = vec![0.0; self.hidden_dim];
let mut gate = vec![0.0; self.hidden_dim];
self.matmul_add(&self.w_gate, input, &mut gate);
self.matmul_add(&self.u_gate, hidden, &mut gate);
for i in 0..self.hidden_dim {
gate[i] = self.sigmoid(gate[i] + self.bias_gate[i]);
}
let mut update = vec![0.0; self.hidden_dim];
self.matmul_add(&self.w_update, input, &mut update);
self.matmul_add(&self.u_update, hidden, &mut update);
for i in 0..self.hidden_dim {
update[i] = self.tanh(update[i] + self.bias_update[i]);
}
for i in 0..self.hidden_dim {
let gate_factor = self.zeta * gate[i] + self.nu;
let gate_factor = gate_factor.min(1.0).max(0.0); new_hidden[i] = gate_factor * hidden[i] + (1.0 - gate_factor) * update[i];
}
new_hidden
}
pub fn forward_single(&self, input: &[f32]) -> Vec<f32> {
let hidden = vec![0.0; self.hidden_dim];
self.step(input, &hidden)
}
pub fn forward_sequence(&self, inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
let mut hidden = vec![0.0; self.hidden_dim];
let mut outputs = Vec::with_capacity(inputs.len());
for input in inputs {
hidden = self.step(input, &hidden);
outputs.push(hidden.clone());
}
outputs
}
fn matmul_add(&self, weights: &[f32], input: &[f32], result: &mut [f32]) {
let rows = result.len();
let cols = input.len();
for i in 0..rows {
for j in 0..cols {
result[i] += weights[i * cols + j] * input[j];
}
}
}
fn sigmoid(&self, x: f32) -> f32 {
1.0 / (1.0 + (-x).exp())
}
fn tanh(&self, x: f32) -> f32 {
x.tanh()
}
pub fn input_dim(&self) -> usize {
self.input_dim
}
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_fastgrnn_creation() {
let grnn = FastGRNN::new(10, 5);
assert_eq!(grnn.input_dim(), 10);
assert_eq!(grnn.hidden_dim(), 5);
}
#[test]
fn test_fastgrnn_step() {
let grnn = FastGRNN::new(4, 3);
let input = vec![1.0, 0.5, -0.5, 0.0];
let hidden = vec![0.1, 0.2, 0.3];
let new_hidden = grnn.step(&input, &hidden);
assert_eq!(new_hidden.len(), 3);
for &h in &new_hidden {
assert!(h.abs() <= 2.0, "Hidden state should be bounded");
}
}
#[test]
fn test_fastgrnn_forward_single() {
let grnn = FastGRNN::new(4, 3);
let input = vec![1.0, 0.5, -0.5, 0.0];
let output = grnn.forward_single(&input);
assert_eq!(output.len(), 3);
}
#[test]
fn test_fastgrnn_sequence() {
let grnn = FastGRNN::new(4, 3);
let inputs = vec![
vec![1.0, 0.5, -0.5, 0.0],
vec![0.5, 1.0, 0.0, -0.5],
vec![-0.5, 0.0, 1.0, 0.5],
];
let outputs = grnn.forward_sequence(&inputs);
assert_eq!(outputs.len(), 3);
assert_eq!(outputs[0].len(), 3);
}
#[test]
fn test_sigmoid() {
let grnn = FastGRNN::new(1, 1);
assert!((grnn.sigmoid(0.0) - 0.5).abs() < 1e-6);
assert!(grnn.sigmoid(10.0) > 0.99);
assert!(grnn.sigmoid(-10.0) < 0.01);
}
#[test]
fn test_tanh() {
let grnn = FastGRNN::new(1, 1);
assert!(grnn.tanh(0.0).abs() < 1e-6);
assert!(grnn.tanh(10.0) > 0.99);
assert!(grnn.tanh(-10.0) < -0.99);
}
#[test]
#[should_panic(expected = "Input dimension mismatch")]
fn test_wrong_input_dimension() {
let grnn = FastGRNN::new(4, 3);
let input = vec![1.0, 0.5]; let hidden = vec![0.1, 0.2, 0.3];
grnn.step(&input, &hidden);
}
#[test]
#[should_panic(expected = "Hidden dimension mismatch")]
fn test_wrong_hidden_dimension() {
let grnn = FastGRNN::new(4, 3);
let input = vec![1.0, 0.5, -0.5, 0.0];
let hidden = vec![0.1, 0.2]; grnn.step(&input, &hidden);
}
}