use crate::autograd::Variable;
use crate::tensor::Tensor;
use num_traits::Float;
use rand_distr::{Distribution, Normal};
use std::fmt::Debug;
/// Hyper-parameters describing a single recurrent cell.
///
/// Instances are normally built through the [`RecurrentConfig::rnn`],
/// [`RecurrentConfig::gru`] and [`RecurrentConfig::lstm`] presets, which differ
/// only in `num_gates` and all start with `training` set to `true`.
#[derive(Debug, Clone)]
pub struct RecurrentConfig {
    /// Width of each input vector.
    pub input_size: usize,
    /// Width of the hidden state.
    pub hidden_size: usize,
    /// Number of gate blocks (1 for `rnn`, 3 for `gru`, 4 for `lstm`).
    pub num_gates: usize,
    /// Bias flag supplied by the caller (presumably controls whether bias vectors exist).
    pub bias: bool,
    /// Training-mode flag; every preset initialises it to `true`.
    pub training: bool,
}

impl RecurrentConfig {
    /// Builds a config with the given gate count, starting in training mode.
    fn preset(input_size: usize, hidden_size: usize, num_gates: usize, bias: bool) -> Self {
        Self {
            input_size,
            hidden_size,
            num_gates,
            bias,
            training: true,
        }
    }

    /// Config for a plain (Elman) RNN cell: a single gate block.
    pub fn rnn(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::preset(input_size, hidden_size, 1, bias)
    }

    /// Config for a GRU cell: three gate blocks.
    pub fn gru(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::preset(input_size, hidden_size, 3, bias)
    }

    /// Config for an LSTM cell: four gate blocks.
    pub fn lstm(input_size: usize, hidden_size: usize, bias: bool) -> Self {
        Self::preset(input_size, hidden_size, 4, bias)
    }
}
/// Common accessors exposed by recurrent cells built from a [`RecurrentConfig`].
///
/// `T` is the scalar element type of the cell's tensors.
pub trait RecurrentCell<T>
where
    T: Float + Send + Sync + Debug + 'static,
{
    /// Returns the configuration the cell was built from.
    fn config(&self) -> &RecurrentConfig;

    /// Width of the input vectors the cell accepts.
    fn input_size(&self) -> usize;

    /// Width of the cell's hidden state.
    fn hidden_size(&self) -> usize;

    /// Reports whether the cell is currently in training mode.
    fn is_training(&self) -> bool;

    /// Switches the cell between training (`true`) and evaluation (`false`) mode.
    fn set_training(&mut self, training: bool);
}
pub struct RecurrentOps;
impl RecurrentOps {
pub fn init_weights<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
input_size: usize,
hidden_size: usize,
num_gates: usize,
) -> (Variable<T>, Variable<T>) {
let mut rng = rand::thread_rng();
let normal = Normal::new(0.0, 0.1).unwrap();
let weight_ih_data: Vec<T> = (0..num_gates * hidden_size * input_size)
.map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
.collect();
let weight_ih = Variable::new(
Tensor::from_vec(weight_ih_data, vec![num_gates * hidden_size, input_size]),
true,
);
let weight_hh_data: Vec<T> = (0..num_gates * hidden_size * hidden_size)
.map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
.collect();
let weight_hh = Variable::new(
Tensor::from_vec(weight_hh_data, vec![num_gates * hidden_size, hidden_size]),
true,
);
(weight_ih, weight_hh)
}
pub fn init_bias<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
hidden_size: usize,
num_gates: usize,
) -> (Option<Variable<T>>, Option<Variable<T>>) {
let mut rng = rand::thread_rng();
let normal = Normal::new(0.0, 0.1).unwrap();
let bias_ih_data: Vec<T> = (0..num_gates * hidden_size)
.map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
.collect();
let bias_ih = Some(Variable::new(
Tensor::from_vec(bias_ih_data, vec![num_gates * hidden_size]),
true,
));
let bias_hh_data: Vec<T> = (0..num_gates * hidden_size)
.map(|_| num_traits::cast(normal.sample(&mut rng) as f64).unwrap_or(T::zero()))
.collect();
let bias_hh = Some(Variable::new(
Tensor::from_vec(bias_hh_data, vec![num_gates * hidden_size]),
true,
));
(bias_ih, bias_hh)
}
pub fn linear_transform<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
input: &Variable<T>,
weight: &Variable<T>,
bias: Option<&Variable<T>>,
) -> Variable<T> {
let output = Self::matmul_variables(input, &Self::transpose_variable(weight));
match bias {
Some(b) => Self::add_variables(&output, b),
None => output,
}
}
pub fn matmul_variables<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
a: &Variable<T>,
b: &Variable<T>,
) -> Variable<T> {
a.matmul(b)
}
pub fn add_variables<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
a: &Variable<T>,
b: &Variable<T>,
) -> Variable<T> {
a + b
}
pub fn multiply_variables<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
a: &Variable<T>,
b: &Variable<T>,
) -> Variable<T> {
a * b
}
pub fn subtract_from_scalar<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
var: &Variable<T>,
scalar: T,
) -> Variable<T> {
let var_binding = var.data();
let var_data = var_binding.read().unwrap();
let result_data = var_data.map(|x| scalar - x);
Variable::new(result_data, var.requires_grad())
}
pub fn transpose_variable<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
var: &Variable<T>,
) -> Variable<T> {
let var_binding = var.data();
let var_data = var_binding.read().unwrap();
let transposed_data = var_data.transpose().unwrap();
Variable::new(transposed_data, var.requires_grad())
}
pub fn sigmoid<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
var: &Variable<T>,
) -> Variable<T> {
let var_binding = var.data();
let var_data = var_binding.read().unwrap();
let sigmoid_data = var_data.map(|x| T::one() / (T::one() + (-x).exp()));
Variable::new(sigmoid_data, var.requires_grad())
}
pub fn tanh<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
var: &Variable<T>,
) -> Variable<T> {
let var_binding = var.data();
let var_data = var_binding.read().unwrap();
let tanh_data = var_data.map(|x| x.tanh());
Variable::new(tanh_data, var.requires_grad())
}
pub fn slice_gates<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
gates: &Variable<T>,
gate_idx: usize,
hidden_size: usize,
) -> Variable<T> {
let start_idx = gate_idx * hidden_size;
let end_idx = (gate_idx + 1) * hidden_size;
let gates_binding = gates.data();
let gates_data = gates_binding.read().unwrap();
let gate_data: Vec<T> = gates_data.as_slice().unwrap()[start_idx..end_idx].to_vec();
Variable::new(
Tensor::from_vec(gate_data, vec![gates_data.shape()[0], hidden_size]),
gates.requires_grad(),
)
}
pub fn zero_hidden_state<
T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
batch_size: usize,
hidden_size: usize,
) -> Variable<T> {
Variable::new(Tensor::zeros(&[batch_size, hidden_size]), false)
}
}
/// Training vs. evaluation mode, convertible to and from `bool`.
///
/// `true` corresponds to [`TrainingMode::Train`] and `false` to [`TrainingMode::Eval`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrainingMode {
    /// Training mode (`true`).
    Train,
    /// Evaluation mode (`false`).
    Eval,
}

impl From<bool> for TrainingMode {
    /// Maps `true` to `Train` and `false` to `Eval`.
    fn from(is_training: bool) -> Self {
        match is_training {
            true => Self::Train,
            false => Self::Eval,
        }
    }
}

impl From<TrainingMode> for bool {
    /// Yields `true` only for `TrainingMode::Train`.
    fn from(mode: TrainingMode) -> Self {
        mode == TrainingMode::Train
    }
}
/// Gathers a recurrent cell's parameters into one list.
///
/// The order is `weight_ih`, `weight_hh`, then `bias_ih` and `bias_hh` when they
/// are `Some`. Each entry is produced by `Variable::clone`; whether that shares
/// storage with the original depends on `Variable`'s `Clone` impl.
pub fn collect_recurrent_parameters<
    T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
>(
    weight_ih: &Variable<T>,
    weight_hh: &Variable<T>,
    bias_ih: &Option<Variable<T>>,
    bias_hh: &Option<Variable<T>>,
) -> Vec<Variable<T>> {
    std::iter::once(weight_ih)
        .chain(std::iter::once(weight_hh))
        .chain(bias_ih.as_ref())
        .chain(bias_hh.as_ref())
        .cloned()
        .collect()
}
/// Layout helpers for running recurrent layers over sequences and layer stacks.
///
/// Every item is an associated function; the struct itself carries no data.
pub struct MultiLayerUtils;

impl MultiLayerUtils {
    /// Returns the first two dimensions of `var`'s tensor.
    ///
    /// The read guard is dropped before returning, so callers never hold two
    /// read guards on the same variable at once.
    fn leading_dims<
        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
    >(
        var: &Variable<T>,
    ) -> (usize, usize) {
        let binding = var.data();
        let data = binding.read().unwrap();
        let shape = data.shape();
        (shape[0], shape[1])
    }

    /// Extracts the `[batch, feature]` slice at `timestep` from a
    /// `[batch, seq_len, feature]` input laid out row-major.
    ///
    /// # Panics
    /// Panics if `timestep >= seq_len`, if `input` has fewer than three
    /// dimensions, or if its tensor is not contiguous.
    pub fn get_timestep_input<
        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
    >(
        input: &Variable<T>,
        timestep: usize,
    ) -> Variable<T> {
        let input_binding = input.data();
        let input_data = input_binding.read().unwrap();
        let shape = input_data.shape();
        let (batch_size, seq_len, feature_size) = (shape[0], shape[1], shape[2]);
        assert!(
            timestep < seq_len,
            "get_timestep_input: timestep {} out of range for sequence length {}",
            timestep,
            seq_len
        );
        // Hoisted out of the per-element loop: fetch the flat buffer once.
        let flat = input_data
            .as_slice()
            .expect("get_timestep_input: input tensor must be contiguous");
        let mut timestep_data: Vec<T> = Vec::with_capacity(batch_size * feature_size);
        for batch_idx in 0..batch_size {
            // The features for (batch_idx, timestep) are contiguous in this layout.
            let start = (batch_idx * seq_len + timestep) * feature_size;
            timestep_data.extend_from_slice(&flat[start..start + feature_size]);
        }
        Variable::new(
            Tensor::from_vec(timestep_data, vec![batch_size, feature_size]),
            input.requires_grad(),
        )
    }

    /// Stacks per-timestep `[batch, hidden]` outputs into one `[batch, seq_len, hidden]` tensor.
    ///
    /// `batch` and `hidden` are taken from `outputs[0]`, as is `requires_grad`.
    /// Each output is locked exactly once and only one lock is held at a time.
    /// The previous version re-locked `outputs[0]` while already holding a read
    /// guard on it; `std::sync::RwLock` warns that this can deadlock.
    ///
    /// # Panics
    /// Panics if `outputs` is empty, if any output is not contiguous, or if an
    /// output holds fewer than `batch * hidden` elements.
    pub fn stack_outputs<
        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
    >(
        outputs: &[Variable<T>],
    ) -> Variable<T> {
        let first = outputs
            .first()
            .expect("stack_outputs requires at least one timestep output");
        let (batch_size, hidden_size) = Self::leading_dims(first);
        let seq_len = outputs.len();
        let mut stacked_data = vec![T::zero(); batch_size * seq_len * hidden_size];
        for (t, output) in outputs.iter().enumerate() {
            let output_binding = output.data();
            let output_data = output_binding.read().unwrap();
            let src = output_data
                .as_slice()
                .expect("stack_outputs: output tensor must be contiguous");
            for batch_idx in 0..batch_size {
                let src_start = batch_idx * hidden_size;
                // Destination index of element (batch_idx, t, 0) in [batch, seq_len, hidden].
                let dst_start = (batch_idx * seq_len + t) * hidden_size;
                stacked_data[dst_start..dst_start + hidden_size]
                    .copy_from_slice(&src[src_start..src_start + hidden_size]);
            }
        }
        Variable::new(
            Tensor::from_vec(stacked_data, vec![batch_size, seq_len, hidden_size]),
            first.requires_grad(),
        )
    }

    /// Concatenates per-layer `[batch, hidden]` states into a `[num_layers, batch, hidden]` tensor.
    ///
    /// `batch` and `hidden` come from `states[0]`, as does `requires_grad`.
    ///
    /// # Panics
    /// Panics if `states` is empty, if a state is not contiguous, or if the
    /// concatenated data does not contain exactly `num_layers * batch * hidden`
    /// elements.
    pub fn stack_hidden_states<
        T: Float + Send + Sync + Debug + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
    >(
        states: &[Variable<T>],
        num_layers: usize,
    ) -> Variable<T> {
        let first = states
            .first()
            .expect("stack_hidden_states requires at least one state");
        let (batch_size, hidden_size) = Self::leading_dims(first);
        let expected_len = num_layers * batch_size * hidden_size;
        let mut stacked_data: Vec<T> = Vec::with_capacity(expected_len);
        for state in states {
            let state_binding = state.data();
            let state_data = state_binding.read().unwrap();
            stacked_data.extend_from_slice(
                state_data
                    .as_slice()
                    .expect("stack_hidden_states: state tensor must be contiguous"),
            );
        }
        assert_eq!(
            stacked_data.len(),
            expected_len,
            "stack_hidden_states: data does not match shape [num_layers, batch, hidden]"
        );
        Variable::new(
            Tensor::from_vec(stacked_data, vec![num_layers, batch_size, hidden_size]),
            first.requires_grad(),
        )
    }
}