use scirs2_core::ndarray::{
    Array, Array1, Array2, ArrayView, ArrayView1, ArrayView2, Axis, Dimension,
};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{NeuralError, Result};
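
/// Return type of [`batch_norm_forward`]: the normalized output together with
/// an optional cache `(x_hat, batch_mean, batch_var, gamma, eps)`. The cache
/// is `Some` only in training mode, since the backward pass needs it.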
type BatchNormForwardReturn<F> = (
    Array2<F>,
    Option<(Array2<F>, Array1<F>, Array1<F>, Array1<F>, F)>,
);
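
/// Return type of [`layer_norm`]: the normalized output together with the
/// cache `(x_hat, mean, var, gamma, eps)` consumed by [`layer_norm_backward`].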
type LayerNormForwardReturn<F> = (Array2<F>, (Array2<F>, Array1<F>, Array1<F>, Array1<F>, F));
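
/// Batch normalization forward pass over a `(batch, features)` input.
///
/// In training mode the batch statistics normalize `x` and the running
/// statistics are updated in place as
/// `running = momentum * running + (1 - momentum) * batch_stat`; the returned
/// cache (needed by [`batch_norm_backward`]) is `Some`. In inference mode the
/// running statistics are used unchanged and the cache is `None`.
///
/// # Errors
///
/// Returns [`NeuralError::ShapeMismatch`] if `gamma`, `beta`, `running_mean`,
/// or `running_var` does not have one entry per feature column of `x`.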
#[allow(clippy::too_many_arguments)]
pub fn batch_norm_forward<F>(
    x: &ArrayView2<F>,
    gamma: &ArrayView1<F>,
    beta: &ArrayView1<F>,
    eps: F,
    momentum: F,
    running_mean: &mut Array1<F>,
    running_var: &mut Array1<F>,
    training: bool,
) -> Result<BatchNormForwardReturn<F>>
where
    F: Float + Debug + FromPrimitive,
{
    let batch_size = x.shape()[0];
    let num_features = x.shape()[1];
    if gamma.shape()[0] != num_features || beta.shape()[0] != num_features {
        return Err(NeuralError::ShapeMismatch(format!(
            "Parameter shape mismatch in batch_norm_forward: x shape {:?}, gamma shape {:?}, beta shape {:?}",
            x.shape(),
            gamma.shape(),
            beta.shape()
        )));
    }
    if running_mean.shape()[0] != num_features || running_var.shape()[0] != num_features {
        return Err(NeuralError::ShapeMismatch(format!(
            "Running stats shape mismatch in batch_norm_forward: x shape {:?}, running_mean shape {:?}, running_var shape {:?}",
            x.shape(),
            running_mean.shape(),
            running_var.shape()
        )));
    }
    let mut out = Array2::<F>::zeros(x.raw_dim());
    if training {
        // Per-feature mean and biased variance (divided by N) over the batch axis.
        let batch_mean = x.mean_axis(Axis(0)).expect("batch must be non-empty");
        let mut batch_var = Array1::<F>::zeros(num_features);
        for i in 0..batch_size {
            for j in 0..num_features {
                let diff = x[[i, j]] - batch_mean[j];
                batch_var[j] = batch_var[j] + diff * diff;
            }
        }
        let batch_size_f = F::from(batch_size).expect("failed to convert batch size to float");
        batch_var.mapv_inplace(|v| v / batch_size_f);
        // Exponential moving average of the batch statistics; `momentum`
        // weights the previous running value.
        for j in 0..num_features {
            running_mean[j] = momentum * running_mean[j] + (F::one() - momentum) * batch_mean[j];
            running_var[j] = momentum * running_var[j] + (F::one() - momentum) * batch_var[j];
        }
        // Normalize with the batch statistics, then scale and shift.
        let mut x_hat = Array2::<F>::zeros(x.raw_dim());
        for i in 0..batch_size {
            for j in 0..num_features {
                x_hat[[i, j]] = (x[[i, j]] - batch_mean[j]) / (batch_var[j] + eps).sqrt();
                out[[i, j]] = gamma[j] * x_hat[[i, j]] + beta[j];
            }
        }
        let cache = (x_hat, batch_mean, batch_var, gamma.to_owned(), eps);
        Ok((out, Some(cache)))
    } else {
        // Inference: normalize with the running statistics; no cache is needed.
        for i in 0..batch_size {
            for j in 0..num_features {
                let x_hat = (x[[i, j]] - running_mean[j]) / (running_var[j] + eps).sqrt();
                out[[i, j]] = gamma[j] * x_hat + beta[j];
            }
        }
        Ok((out, None))
    }
}
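
/// Batch normalization backward pass.
///
/// Takes the upstream gradient `dout` and the training-mode cache returned by
/// [`batch_norm_forward`], and returns `(dx, dgamma, dbeta)`.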
pub fn batch_norm_backward<F>(
    dout: &ArrayView2<F>,
    cache: &(Array2<F>, Array1<F>, Array1<F>, Array1<F>, F),
) -> Result<(Array2<F>, Array1<F>, Array1<F>)>
where
    F: Float + Debug + FromPrimitive,
{
    let (x_hat, _batch_mean, batch_var, gamma, eps) = cache;
    let batch_size = dout.shape()[0];
    let num_features = dout.shape()[1];
    if x_hat.shape() != dout.shape() {
        return Err(NeuralError::ShapeMismatch(format!(
            "Shape mismatch in batch_norm_backward: dout shape {:?}, x_hat shape {:?}",
            dout.shape(),
            x_hat.shape()
        )));
    }
    let mut dx = Array2::<F>::zeros(dout.raw_dim());
    let mut dgamma = Array1::<F>::zeros(gamma.raw_dim());
    let mut dbeta = Array1::<F>::zeros(gamma.raw_dim());
    // dbeta[j] = sum_i dout[i, j]; dgamma[j] = sum_i dout[i, j] * x_hat[i, j].
    for j in 0..num_features {
        for i in 0..batch_size {
            dbeta[j] = dbeta[j] + dout[[i, j]];
            dgamma[j] = dgamma[j] + dout[[i, j]] * x_hat[[i, j]];
        }
    }
    // Gradient with respect to the normalized input.
    let mut dx_hat = Array2::<F>::zeros(dout.raw_dim());
    for i in 0..batch_size {
        for j in 0..num_features {
            dx_hat[[i, j]] = dout[[i, j]] * gamma[j];
        }
    }
    // Standard batch norm input gradient, with the mean and variance branches
    // of the chain rule folded into one pass per feature:
    // dx = std_inv / N * (N * dx_hat - sum_i dx_hat - x_hat * sum_i dx_hat * x_hat).
    let batch_size_f = F::from(batch_size).expect("failed to convert batch size to float");
    for j in 0..num_features {
        let std_inv = F::one() / (batch_var[j] + *eps).sqrt();
        let mut sum_dx_hat = F::zero();
        let mut sum_dx_hat_x_hat = F::zero();
        for i in 0..batch_size {
            sum_dx_hat = sum_dx_hat + dx_hat[[i, j]];
            sum_dx_hat_x_hat = sum_dx_hat_x_hat + dx_hat[[i, j]] * x_hat[[i, j]];
        }
        for i in 0..batch_size {
            dx[[i, j]] = (dx_hat[[i, j]]
                - (sum_dx_hat + x_hat[[i, j]] * sum_dx_hat_x_hat) / batch_size_f)
                * std_inv;
        }
    }
    Ok((dx, dgamma, dbeta))
}
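
/// Layer normalization forward pass over a `(batch, features)` input.
///
/// Each sample (row) is normalized with its own mean and variance across the
/// feature axis, then scaled by `gamma` and shifted by `beta`. Returns the
/// output together with the cache consumed by [`layer_norm_backward`].
///
/// # Errors
///
/// Returns [`NeuralError::ShapeMismatch`] if `gamma` or `beta` does not have
/// one entry per feature column of `x`.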
pub fn layer_norm<F>(
    x: &ArrayView2<F>,
    gamma: &ArrayView1<F>,
    beta: &ArrayView1<F>,
    eps: F,
) -> Result<LayerNormForwardReturn<F>>
where
    F: Float + Debug + FromPrimitive,
{
    let batch_size = x.shape()[0];
    let num_features = x.shape()[1];
    if gamma.shape()[0] != num_features || beta.shape()[0] != num_features {
        return Err(NeuralError::ShapeMismatch(format!(
            "Parameter shape mismatch in layer_norm: x shape {:?}, gamma shape {:?}, beta shape {:?}",
            x.shape(),
            gamma.shape(),
            beta.shape()
        )));
    }
    let mut out = Array2::<F>::zeros(x.raw_dim());
    let mut x_hat = Array2::<F>::zeros(x.raw_dim());
    let mut mean = Array1::<F>::zeros(batch_size);
    let mut var = Array1::<F>::zeros(batch_size);
    let num_features_f = F::from(num_features).expect("failed to convert feature count to float");
    for i in 0..batch_size {
        // Per-sample mean and biased variance over the feature axis.
        for j in 0..num_features {
            mean[i] = mean[i] + x[[i, j]];
        }
        mean[i] = mean[i] / num_features_f;
        for j in 0..num_features {
            let diff = x[[i, j]] - mean[i];
            var[i] = var[i] + diff * diff;
        }
        var[i] = var[i] / num_features_f;
        // Normalize the sample, then scale and shift per feature.
        for j in 0..num_features {
            x_hat[[i, j]] = (x[[i, j]] - mean[i]) / (var[i] + eps).sqrt();
            out[[i, j]] = gamma[j] * x_hat[[i, j]] + beta[j];
        }
    }
    let cache = (x_hat, mean, var, gamma.to_owned(), eps);
    Ok((out, cache))
}
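
/// Layer normalization backward pass.
///
/// Takes the upstream gradient `dout` and the cache returned by
/// [`layer_norm`], and returns `(dx, dgamma, dbeta)`.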
pub fn layer_norm_backward<F>(
    dout: &ArrayView2<F>,
    cache: &(Array2<F>, Array1<F>, Array1<F>, Array1<F>, F),
) -> Result<(Array2<F>, Array1<F>, Array1<F>)>
where
    F: Float + Debug + FromPrimitive,
{
    let (x_hat, _mean, var, gamma, eps) = cache;
    let batch_size = dout.shape()[0];
    let num_features = dout.shape()[1];
    if x_hat.shape() != dout.shape() {
        return Err(NeuralError::ShapeMismatch(format!(
            "Shape mismatch in layer_norm_backward: dout shape {:?}, x_hat shape {:?}",
            dout.shape(),
            x_hat.shape()
        )));
    }
    let mut dx = Array2::<F>::zeros(dout.raw_dim());
    let mut dgamma = Array1::<F>::zeros(gamma.raw_dim());
    let mut dbeta = Array1::<F>::zeros(gamma.raw_dim());
    // dbeta[j] = sum_i dout[i, j]; dgamma[j] = sum_i dout[i, j] * x_hat[i, j].
    for j in 0..num_features {
        for i in 0..batch_size {
            dbeta[j] = dbeta[j] + dout[[i, j]];
            dgamma[j] = dgamma[j] + dout[[i, j]] * x_hat[[i, j]];
        }
    }
    // Gradient with respect to the normalized input.
    let mut dx_hat = Array2::<F>::zeros(dout.raw_dim());
    for i in 0..batch_size {
        for j in 0..num_features {
            dx_hat[[i, j]] = dout[[i, j]] * gamma[j];
        }
    }
    // Same reduction as batch norm, but per sample over the feature axis:
    // dx = std_inv / D * (D * dx_hat - sum_j dx_hat - x_hat * sum_j dx_hat * x_hat).
    let num_features_f = F::from(num_features).expect("failed to convert feature count to float");
    for i in 0..batch_size {
        let std_inv = F::one() / (var[i] + *eps).sqrt();
        let mut sum_dx_hat = F::zero();
        let mut sum_dx_hat_x_hat = F::zero();
        for j in 0..num_features {
            sum_dx_hat = sum_dx_hat + dx_hat[[i, j]];
            sum_dx_hat_x_hat = sum_dx_hat_x_hat + dx_hat[[i, j]] * x_hat[[i, j]];
        }
        for j in 0..num_features {
            dx[[i, j]] = (dx_hat[[i, j]]
                - (sum_dx_hat + x_hat[[i, j]] * sum_dx_hat_x_hat) / num_features_f)
                * std_inv;
        }
    }
    Ok((dx, dgamma, dbeta))
}
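
/// Clips a gradient tensor so its global L2 norm does not exceed `max_norm`.
///
/// If the norm is already within the budget the gradient is returned
/// unchanged; otherwise every element is scaled by `max_norm / grad_norm`.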
pub fn clip_grad_norm<F, D>(grad: &ArrayView<F, D>, max_norm: F) -> Result<Array<F, D>>
where
    F: Float + Debug + FromPrimitive,
    D: Dimension,
{
    // Global L2 norm over all elements of the gradient tensor.
    let mut grad_squared_sum = F::zero();
    for &g in grad.iter() {
        grad_squared_sum = grad_squared_sum + g * g;
    }
    let grad_norm = grad_squared_sum.sqrt();
    // Within the budget: return the gradient unchanged.
    if grad_norm <= max_norm {
        return Ok(grad.to_owned());
    }
    // Otherwise rescale so the clipped gradient has norm exactly `max_norm`.
    let scale = max_norm / grad_norm;
    let mut clipped_grad = grad.to_owned();
    clipped_grad.mapv_inplace(|g| g * scale);
    Ok(clipped_grad)
}
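
// A minimal usage sketch, written as tests. It assumes `scirs2_core::ndarray`
// re-exports the usual `ndarray` constructors (`arr1`, `arr2`) and that the
// numeric traits behave like their `num-traits` counterparts; adjust the
// imports if the crate lays things out differently.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2};

    #[test]
    fn batch_norm_roundtrip() {
        let x = arr2(&[[1.0f64, 2.0], [3.0, 4.0]]);
        let gamma = arr1(&[1.0f64, 1.0]);
        let beta = arr1(&[0.0f64, 0.0]);
        let mut running_mean = arr1(&[0.0f64, 0.0]);
        let mut running_var = arr1(&[1.0f64, 1.0]);

        let (out, cache) = batch_norm_forward(
            &x.view(),
            &gamma.view(),
            &beta.view(),
            1e-5,
            0.9,
            &mut running_mean,
            &mut running_var,
            true,
        )
        .unwrap();
        // With gamma = 1 and beta = 0 each feature column is standardized,
        // so the output columns sum to (approximately) zero.
        for j in 0..2 {
            let col_sum: f64 = out.column(j).sum();
            assert!(col_sum.abs() < 1e-6);
        }

        // A constant upstream gradient yields dx ~ 0: shifting every sample
        // equally does not change the batch statistics.
        let dout = arr2(&[[1.0f64, 1.0], [1.0, 1.0]]);
        let (dx, _dgamma, dbeta) = batch_norm_backward(&dout.view(), &cache.unwrap()).unwrap();
        assert!(dx.iter().all(|v| v.abs() < 1e-6));
        assert_eq!(dbeta, arr1(&[2.0f64, 2.0]));
    }

    #[test]
    fn clip_grad_norm_scales_to_budget() {
        let grad = arr1(&[3.0f64, 4.0]); // L2 norm 5
        let clipped = clip_grad_norm(&grad.view(), 1.0).unwrap();
        let norm: f64 = clipped.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-12);
    }
}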