use ndarray::Array2;
use crate::layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
use crate::optimizers::Optimizer;
/// Forward-pass cache for one time step of a [`GRUNetwork`].
///
/// Holds one [`GRUCellCache`] per layer, ordered from the bottom layer (index 0)
/// to the top layer, exactly as produced by `forward_sequence_with_cache`.
#[derive(Clone)]
pub struct GRUNetworkCache {
    /// Per-layer caches for this step; entry `i` belongs to layer `i`.
    pub caches: Vec<GRUCellCache>,
}
/// Dropout settings for a single GRU layer, consumed by
/// `GRUNetwork::with_layer_dropout`.
///
/// A rate of `0.0` means that kind of dropout is not applied (the network only
/// applies rates strictly greater than zero). `Default` yields the all-disabled
/// configuration, identical to `LayerDropoutConfig::new()`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct LayerDropoutConfig {
    /// Dropout rate applied to the layer's input.
    pub input_dropout_rate: f64,
    /// `variational` flag forwarded to the cell's input-dropout builder.
    pub input_variational: bool,
    /// Dropout rate applied to the layer's recurrent (hidden-state) path.
    pub recurrent_dropout_rate: f64,
    /// `variational` flag forwarded to the cell's recurrent-dropout builder.
    pub recurrent_variational: bool,
    /// Dropout rate applied to the layer's output; ignored for the top layer.
    pub output_dropout_rate: f64,
}
impl LayerDropoutConfig {
pub fn new() -> Self {
LayerDropoutConfig {
input_dropout_rate: 0.0,
input_variational: false,
recurrent_dropout_rate: 0.0,
recurrent_variational: false,
output_dropout_rate: 0.0,
}
}
pub fn with_input_dropout(mut self, rate: f64, variational: bool) -> Self {
self.input_dropout_rate = rate;
self.input_variational = variational;
self
}
pub fn with_recurrent_dropout(mut self, rate: f64, variational: bool) -> Self {
self.recurrent_dropout_rate = rate;
self.recurrent_variational = variational;
self
}
pub fn with_output_dropout(mut self, rate: f64) -> Self {
self.output_dropout_rate = rate;
self
}
}
/// A stack of [`GRUCell`]s in which each layer feeds its output to the layer above.
#[derive(Clone)]
pub struct GRUNetwork {
    /// Layer cells, bottom (index 0, fed by the network input) to top.
    cells: Vec<GRUCell>,
    /// Feature dimension of the network input (the input size of layer 0).
    pub input_size: usize,
    /// Hidden-state dimension shared by every layer.
    pub hidden_size: usize,
    /// Number of stacked layers; set at construction to the number of cells.
    pub num_layers: usize,
    /// Training-mode flag, toggled by `train()` / `eval()`.
    pub is_training: bool,
}
impl GRUNetwork {
    /// Builds a network of `num_layers` stacked GRU cells.
    ///
    /// Layer 0 maps `input_size` features to `hidden_size`. Every higher layer maps
    /// `hidden_size` to `hidden_size`, because it consumes the output of the layer below.
    /// The network starts in training mode (`is_training == true`).
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let cells = (0..num_layers)
            .map(|i| {
                let layer_input_size = if i == 0 { input_size } else { hidden_size };
                GRUCell::new(layer_input_size, hidden_size)
            })
            .collect();
        GRUNetwork {
            cells,
            input_size,
            hidden_size,
            num_layers,
            is_training: true,
        }
    }

    /// Rebuilds every cell through a by-value builder `f(layer_index, cell)`.
    ///
    /// The cells are moved out of `self` and handed straight to `f`. The cell
    /// builder methods therefore consume the existing cells instead of full clones.
    fn map_cells<F>(mut self, mut f: F) -> Self
    where
        F: FnMut(usize, GRUCell) -> GRUCell,
    {
        self.cells = std::mem::take(&mut self.cells)
            .into_iter()
            .enumerate()
            .map(|(i, cell)| f(i, cell))
            .collect();
        self
    }

    /// Applies input dropout with the given rate and `variational` flag to every layer.
    pub fn with_input_dropout(self, dropout_rate: f64, variational: bool) -> Self {
        self.map_cells(|_, cell| cell.with_input_dropout(dropout_rate, variational))
    }

    /// Applies recurrent dropout with the given rate and `variational` flag to every layer.
    pub fn with_recurrent_dropout(self, dropout_rate: f64, variational: bool) -> Self {
        self.map_cells(|_, cell| cell.with_recurrent_dropout(dropout_rate, variational))
    }

    /// Applies output dropout to every layer except the top one.
    ///
    /// The top layer is skipped, so the final network output is never dropped here.
    pub fn with_output_dropout(self, dropout_rate: f64) -> Self {
        // `saturating_sub` keeps a zero-layer network from underflowing.
        let last = self.num_layers.saturating_sub(1);
        self.map_cells(|i, cell| {
            if i < last {
                cell.with_output_dropout(dropout_rate)
            } else {
                cell
            }
        })
    }

    /// Applies a separate [`LayerDropoutConfig`] to each layer (config `i` ↔ layer `i`).
    ///
    /// Only rates strictly greater than `0.0` are applied. Output dropout is never
    /// applied to the top layer.
    ///
    /// # Panics
    ///
    /// Panics if `configs.len() != self.num_layers`.
    pub fn with_layer_dropout(self, configs: Vec<LayerDropoutConfig>) -> Self {
        if configs.len() != self.num_layers {
            panic!("Number of dropout configs must match number of layers");
        }
        let last = self.num_layers.saturating_sub(1);
        let mut configs = configs.into_iter();
        self.map_cells(|i, mut cell| {
            if let Some(config) = configs.next() {
                if config.input_dropout_rate > 0.0 {
                    cell = cell
                        .with_input_dropout(config.input_dropout_rate, config.input_variational);
                }
                if config.recurrent_dropout_rate > 0.0 {
                    cell = cell.with_recurrent_dropout(
                        config.recurrent_dropout_rate,
                        config.recurrent_variational,
                    );
                }
                if config.output_dropout_rate > 0.0 && i < last {
                    cell = cell.with_output_dropout(config.output_dropout_rate);
                }
            }
            cell
        })
    }

    /// Switches the network and every cell to training mode.
    pub fn train(&mut self) {
        self.is_training = true;
        for cell in &mut self.cells {
            cell.train();
        }
    }

    /// Switches the network and every cell to evaluation mode.
    pub fn eval(&mut self) {
        self.is_training = false;
        for cell in &mut self.cells {
            cell.eval();
        }
    }

    /// Runs one time step through every layer.
    ///
    /// `hx[i]` is the previous hidden state of layer `i`. Layer 0 reads `input`,
    /// and each higher layer reads the new hidden state of the layer below.
    /// Returns the new hidden state of every layer, bottom to top.
    ///
    /// # Panics
    ///
    /// Panics if `hx.len() != self.num_layers`.
    pub fn forward(&mut self, input: &Array2<f64>, hx: &[Array2<f64>]) -> Vec<Array2<f64>> {
        if hx.len() != self.num_layers {
            panic!("Number of hidden states must match number of layers");
        }
        let mut outputs: Vec<Array2<f64>> = Vec::with_capacity(self.cells.len());
        for (cell, h_prev) in self.cells.iter_mut().zip(hx) {
            // Borrow the previous layer's output rather than cloning it.
            let layer_input = outputs.last().unwrap_or(input);
            let hy = cell.forward(layer_input, h_prev);
            outputs.push(hy);
        }
        outputs
    }

    /// Runs a whole sequence through the network, starting from zero hidden states.
    ///
    /// For each step, returns `(top_layer_output, per_layer_outputs)` and a
    /// [`GRUNetworkCache`] holding one cell cache per layer, for use with
    /// [`Self::backward`].
    ///
    /// # Panics
    ///
    /// Panics if the network has no layers and `sequence` is non-empty.
    pub fn forward_sequence_with_cache(&mut self, sequence: &[Array2<f64>]) -> (Vec<(Array2<f64>, Vec<Array2<f64>>)>, Vec<GRUNetworkCache>) {
        let mut all_outputs = Vec::with_capacity(sequence.len());
        let mut all_caches = Vec::with_capacity(sequence.len());
        // NOTE(review): initial hidden states are single-column zeros, which
        // presumably assumes one column (batch of 1) per input — confirm.
        let mut hidden_states: Vec<Array2<f64>> = (0..self.num_layers)
            .map(|_| Array2::zeros((self.hidden_size, 1)))
            .collect();
        for input in sequence {
            let mut step_outputs: Vec<Array2<f64>> = Vec::with_capacity(self.cells.len());
            let mut step_caches = Vec::with_capacity(self.cells.len());
            for (cell, hidden) in self.cells.iter_mut().zip(hidden_states.iter_mut()) {
                let layer_input = step_outputs.last().unwrap_or(input);
                let (hy, cache) = cell.forward_with_cache(layer_input, &*hidden);
                // Carry this layer's state to the next time step.
                hidden.clone_from(&hy);
                step_outputs.push(hy);
                step_caches.push(cache);
            }
            let final_output = step_outputs
                .last()
                .cloned()
                .expect("GRUNetwork has no layers; cannot produce an output");
            all_outputs.push((final_output, step_outputs));
            all_caches.push(GRUNetworkCache { caches: step_caches });
        }
        (all_outputs, all_caches)
    }

    /// Backpropagates a single time step down the layer stack, from top to bottom.
    ///
    /// `dhy` is the gradient w.r.t. the top layer's output at this step. `cache`
    /// is that step's entry from [`Self::forward_sequence_with_cache`].
    ///
    /// Returns the per-layer parameter gradients (index `i` ↔ layer `i`) and the
    /// gradient w.r.t. this step's network input.
    ///
    /// Layer `i > 0` consumes layer `i - 1`'s output as its *input*, so the
    /// gradient handed down the stack is each cell's input gradient. This relies
    /// on `GRUCell::backward` returning `(param_grads, d_input, d_prev_hidden)`.
    /// The previous-hidden gradient belongs to the same layer at the previous time
    /// step; it is not propagated by this single-step method.
    pub fn backward(&self, dhy: &Array2<f64>, cache: &GRUNetworkCache) -> (Vec<GRUCellGradients>, Array2<f64>) {
        let mut gradients = Vec::with_capacity(self.cells.len());
        let mut upstream = dhy.clone();
        for (i, cell) in self.cells.iter().enumerate().rev() {
            let (cell_gradients, dx, _dh_prev) = cell.backward(&upstream, &cache.caches[i]);
            gradients.push(cell_gradients);
            upstream = dx;
        }
        // Gradients were collected top-down; restore bottom-to-top layer order.
        gradients.reverse();
        (gradients, upstream)
    }

    /// Updates each layer with its matching gradient via `optimizer`.
    ///
    /// Cells and gradients are paired by position. Each layer uses the parameter
    /// key prefix `layer_{i}`. Layers without a matching gradient are left unchanged.
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &[GRUCellGradients], optimizer: &mut O) {
        for (i, (cell, grad)) in self.cells.iter_mut().zip(gradients.iter()).enumerate() {
            cell.update_parameters(grad, optimizer, &format!("layer_{}", i));
        }
    }

    /// Returns one zeroed gradient set per layer.
    pub fn zero_gradients(&self) -> Vec<GRUCellGradients> {
        self.cells.iter().map(|cell| cell.zero_gradients()).collect()
    }

    /// Borrows the layer cells, bottom to top.
    pub fn get_cells(&self) -> &[GRUCell] {
        &self.cells
    }

    /// Mutably borrows the layer cells, bottom to top.
    pub fn get_cells_mut(&mut self) -> &mut [GRUCell] {
        &mut self.cells
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_gru_network_creation() {
        let network = GRUNetwork::new(3, 5, 2);
        assert_eq!(network.input_size, 3);
        assert_eq!(network.hidden_size, 5);
        assert_eq!(network.num_layers, 2);
        assert_eq!(network.cells.len(), 2);
    }

    #[test]
    fn test_gru_network_forward() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];
        let outputs = network.forward(&input, &hidden_states);
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].shape(), &[3, 1]);
        assert_eq!(outputs[1].shape(), &[3, 1]);
    }

    #[test]
    #[should_panic(expected = "Number of hidden states must match number of layers")]
    fn test_gru_network_forward_rejects_wrong_hidden_count() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let input = arr2(&[[1.0], [0.5]]);
        // Only one hidden state for a two-layer network.
        let hidden_states = vec![arr2(&[[0.0], [0.0], [0.0]])];
        let _ = network.forward(&input, &hidden_states);
    }

    #[test]
    fn test_gru_network_sequence() {
        let mut network = GRUNetwork::new(2, 3, 1);
        let sequence = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
            arr2(&[[-1.0], [0.5]]),
        ];
        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
        assert_eq!(outputs.len(), 3);
        assert_eq!(caches.len(), 3);
        for (output, _) in &outputs {
            assert_eq!(output.shape(), &[3, 1]);
        }
    }

    #[test]
    fn test_gru_network_sequence_multi_layer_outputs() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let sequence = vec![arr2(&[[1.0], [0.0]]), arr2(&[[0.0], [1.0]])];
        let (outputs, caches) = network.forward_sequence_with_cache(&sequence);
        assert_eq!(outputs.len(), 2);
        assert_eq!(caches.len(), 2);
        for ((final_output, layer_outputs), cache) in outputs.iter().zip(&caches) {
            // One output and one cache per layer at every step.
            assert_eq!(layer_outputs.len(), 2);
            assert_eq!(cache.caches.len(), 2);
            for layer_output in layer_outputs {
                assert_eq!(layer_output.shape(), &[3, 1]);
            }
            // The step's final output is the top layer's output.
            assert_eq!(final_output, layer_outputs.last().unwrap());
        }
    }

    #[test]
    fn test_gru_network_sequence_empty() {
        let mut network = GRUNetwork::new(2, 3, 2);
        let (outputs, caches) = network.forward_sequence_with_cache(&[]);
        assert!(outputs.is_empty());
        assert!(caches.is_empty());
    }

    #[test]
    fn test_gru_network_with_dropout() {
        let mut network = GRUNetwork::new(2, 3, 2)
            .with_input_dropout(0.2, true)
            .with_recurrent_dropout(0.3, false)
            .with_output_dropout(0.1);
        let input = arr2(&[[1.0], [0.5]]);
        let hidden_states = vec![
            arr2(&[[0.1], [0.2], [0.3]]),
            arr2(&[[0.0], [0.1], [0.2]]),
        ];
        network.train();
        assert!(network.is_training);
        let outputs_train = network.forward(&input, &hidden_states);
        network.eval();
        assert!(!network.is_training);
        let outputs_eval = network.forward(&input, &hidden_states);
        assert_eq!(outputs_train.len(), 2);
        assert_eq!(outputs_eval.len(), 2);
    }

    #[test]
    fn test_gru_network_layer_dropout() {
        let layer_configs = vec![
            LayerDropoutConfig::new().with_input_dropout(0.1, false),
            LayerDropoutConfig::new().with_recurrent_dropout(0.2, true),
        ];
        let network = GRUNetwork::new(2, 3, 2)
            .with_layer_dropout(layer_configs);
        assert_eq!(network.cells.len(), 2);
    }

    #[test]
    #[should_panic(expected = "Number of dropout configs must match number of layers")]
    fn test_gru_network_layer_dropout_rejects_wrong_config_count() {
        let _ = GRUNetwork::new(2, 3, 2).with_layer_dropout(vec![LayerDropoutConfig::new()]);
    }
}