use crate::error::{Result, TinyDancerError};
use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Hyperparameters describing a [`FastGRNN`] model's topology and gate scaling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FastGRNNConfig {
/// Number of features in each input vector fed to `forward`.
pub input_dim: usize,
/// Size of the recurrent hidden state.
pub hidden_dim: usize,
/// Number of output units (only index 0 is read by `forward`).
pub output_dim: usize,
/// Scale applied to the gate pre-activations inside the sigmoid.
pub nu: f32,
/// Scale applied to the candidate pre-activation inside the tanh.
pub zeta: f32,
/// Optional low-rank factorization rank.
/// NOTE(review): not referenced by any visible code path — confirm intended use.
pub rank: Option<usize>,
}
impl Default for FastGRNNConfig {
fn default() -> Self {
Self {
input_dim: 5, hidden_dim: 8,
output_dim: 1,
nu: 1.0,
zeta: 1.0,
rank: Some(4),
}
}
}
/// A FastGRNN-style gated recurrent model with dense f32 weights.
pub struct FastGRNN {
/// Topology and gate-scaling hyperparameters.
config: FastGRNNConfig,
/// Reset-gate input weights, shape (hidden_dim, input_dim).
w_reset: Array2<f32>,
/// Update-gate input weights, shape (hidden_dim, input_dim).
w_update: Array2<f32>,
/// Candidate-state input weights, shape (hidden_dim, input_dim).
w_candidate: Array2<f32>,
/// Hidden-to-hidden weights, shape (hidden_dim, hidden_dim).
w_recurrent: Array2<f32>,
/// Output projection, shape (output_dim, hidden_dim).
w_output: Array2<f32>,
/// Reset-gate bias, length hidden_dim.
b_reset: Array1<f32>,
/// Update-gate bias, length hidden_dim.
b_update: Array1<f32>,
/// Candidate-state bias, length hidden_dim.
b_candidate: Array1<f32>,
/// Output bias, length output_dim.
b_output: Array1<f32>,
/// Bookkeeping flag read by `size_bytes`; weights themselves stay f32.
quantized: bool,
}
impl FastGRNN {
pub fn new(config: FastGRNNConfig) -> Result<Self> {
use rand::Rng;
let mut rng = rand::thread_rng();
let w_reset = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
rng.gen_range(-0.1..0.1)
});
let w_update = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
rng.gen_range(-0.1..0.1)
});
let w_candidate = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
rng.gen_range(-0.1..0.1)
});
let w_recurrent = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
rng.gen_range(-0.1..0.1)
});
let w_output = Array2::from_shape_fn((config.output_dim, config.hidden_dim), |_| {
rng.gen_range(-0.1..0.1)
});
let b_reset = Array1::zeros(config.hidden_dim);
let b_update = Array1::zeros(config.hidden_dim);
let b_candidate = Array1::zeros(config.hidden_dim);
let b_output = Array1::zeros(config.output_dim);
Ok(Self {
config,
w_reset,
w_update,
w_candidate,
w_recurrent,
w_output,
b_reset,
b_update,
b_candidate,
b_output,
quantized: false,
})
}
pub fn load<P: AsRef<Path>>(_path: P) -> Result<Self> {
Self::new(FastGRNNConfig::default())
}
pub fn save<P: AsRef<Path>>(&self, _path: P) -> Result<()> {
Ok(())
}
pub fn forward(&self, input: &[f32], initial_hidden: Option<&[f32]>) -> Result<f32> {
if input.len() != self.config.input_dim {
return Err(TinyDancerError::InvalidInput(format!(
"Expected input dimension {}, got {}",
self.config.input_dim,
input.len()
)));
}
let x = Array1::from_vec(input.to_vec());
let mut h = if let Some(hidden) = initial_hidden {
Array1::from_vec(hidden.to_vec())
} else {
Array1::zeros(self.config.hidden_dim)
};
let r = sigmoid(&(self.w_reset.dot(&x) + &self.b_reset), self.config.nu);
let u = sigmoid(&(self.w_update.dot(&x) + &self.b_update), self.config.nu);
let c = tanh(
&(self.w_candidate.dot(&x) + self.w_recurrent.dot(&(&r * &h)) + &self.b_candidate),
self.config.zeta,
);
h = &u * &h + &((Array1::<f32>::ones(u.len()) - &u) * &c);
let output = self.w_output.dot(&h) + &self.b_output;
Ok(sigmoid_scalar(output[0]))
}
pub fn forward_batch(&self, inputs: &[Vec<f32>]) -> Result<Vec<f32>> {
inputs
.iter()
.map(|input| self.forward(input, None))
.collect()
}
pub fn quantize(&mut self) -> Result<()> {
self.quantized = true;
Ok(())
}
pub fn prune(&mut self, sparsity: f32) -> Result<()> {
if !(0.0..=1.0).contains(&sparsity) {
return Err(TinyDancerError::InvalidInput(
"Sparsity must be between 0.0 and 1.0".to_string(),
));
}
Ok(())
}
pub fn size_bytes(&self) -> usize {
let params = self.w_reset.len()
+ self.w_update.len()
+ self.w_candidate.len()
+ self.w_recurrent.len()
+ self.w_output.len()
+ self.b_reset.len()
+ self.b_update.len()
+ self.b_candidate.len()
+ self.b_output.len();
params * if self.quantized { 1 } else { 4 } }
pub fn config(&self) -> &FastGRNNConfig {
&self.config
}
}
fn sigmoid(x: &Array1<f32>, scale: f32) -> Array1<f32> {
x.mapv(|v| sigmoid_scalar(v * scale))
}
/// Numerically stable logistic function.
///
/// Branches on the sign of `x` so that `exp` is only ever evaluated on a
/// non-positive argument, avoiding overflow for large |x|.
fn sigmoid_scalar(x: f32) -> f32 {
    match x > 0.0 {
        true => 1.0 / (1.0 + (-x).exp()),
        false => {
            let e = x.exp();
            e / (1.0 + e)
        }
    }
}
fn tanh(x: &Array1<f32>, scale: f32) -> Array1<f32> {
x.mapv(|v| (v * scale).tanh())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed model must report a non-trivial footprint.
    #[test]
    fn test_fastgrnn_creation() {
        let model = FastGRNN::new(FastGRNNConfig::default()).unwrap();
        assert!(model.size_bytes() > 0);
    }

    /// A single forward step yields a sigmoid-bounded score in [0, 1].
    #[test]
    fn test_forward_pass() {
        let config = FastGRNNConfig {
            input_dim: 10,
            hidden_dim: 8,
            output_dim: 1,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let features = vec![0.5; 10];
        let score = model.forward(&features, None).unwrap();
        assert!((0.0..=1.0).contains(&score));
    }

    /// Batch inference produces one score per input row.
    #[test]
    fn test_batch_inference() {
        let config = FastGRNNConfig {
            input_dim: 10,
            ..Default::default()
        };
        let model = FastGRNN::new(config).unwrap();
        let batch = vec![vec![0.5; 10], vec![0.3; 10], vec![0.8; 10]];
        let scores = model.forward_batch(&batch).unwrap();
        assert_eq!(scores.len(), 3);
    }
}