use super::NnResult;
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{
Array, Array1, Array2, ArrayView, ArrayView1, ArrayView2, Axis, ScalarOperand,
};
use scirs2_core::numeric::Float;
use scirs2_core::random::*;
use scirs2_core::simd_ops::SimdUnifiedOps;
pub fn batch_norm_1d<T>(
x: &ArrayView2<T>,
gamma: &ArrayView1<T>,
beta: &ArrayView1<T>,
epsilon: T,
) -> NnResult<Array2<T>>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
if x.ncols() != gamma.len() || x.ncols() != beta.len() {
return Err(NumRs2Error::DimensionMismatch(
"Gamma and beta dimensions must match input features".to_string(),
));
}
let n = T::from(x.nrows())
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert batch size".to_string()))?;
let mut result = Array2::zeros(x.raw_dim());
for j in 0..x.ncols() {
let col = x.column(j);
let mean = col.sum() / n;
let var = col.mapv(|v| (v - mean) * (v - mean)).sum() / n;
let std = (var + epsilon).sqrt();
let g = gamma[j];
let b = beta[j];
for i in 0..x.nrows() {
result[[i, j]] = (x[[i, j]] - mean) / std * g + b;
}
}
Ok(result)
}
/// Normalizes every row of `x` across its columns and applies a per-column
/// affine transform.
///
/// For each row, the mean and the biased variance (sum of squared deviations
/// divided by the number of columns) are computed, and every element becomes
/// `(x - mean) / sqrt(var + epsilon) * gamma[j] + beta[j]`.
///
/// # Arguments
///
/// * `x` - 2-D input; statistics are taken per row.
/// * `gamma` - Per-column scale; its length must equal `x.ncols()`.
/// * `beta` - Per-column shift; its length must equal `x.ncols()`.
/// * `epsilon` - Value added to the variance before the square root.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `gamma` or `beta` does not have one
///   entry per column of `x`.
/// * `NumRs2Error::ConversionError` if the column count cannot be represented
///   as `T`.
pub fn layer_norm<T>(
    x: &ArrayView2<T>,
    gamma: &ArrayView1<T>,
    beta: &ArrayView1<T>,
    epsilon: T,
) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    let width = x.ncols();
    if gamma.len() != width || beta.len() != width {
        return Err(NumRs2Error::DimensionMismatch(
            "Gamma and beta dimensions must match input features".to_string(),
        ));
    }

    let feature_count = T::from(width).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert feature count".to_string())
    })?;

    let mut normalized = Array2::<T>::zeros(x.raw_dim());
    for (row, mut out_row) in x.outer_iter().zip(normalized.outer_iter_mut()) {
        // Per-row statistics; the variance is the biased (divide-by-n) estimate.
        let mean = row.sum() / feature_count;
        let variance = row.mapv(|v| (v - mean) * (v - mean)).sum() / feature_count;
        let std_dev = (variance + epsilon).sqrt();

        let affine_params = gamma.iter().zip(beta.iter());
        let cells = out_row.iter_mut().zip(row.iter());
        for ((out, &value), (&scale, &shift)) in cells.zip(affine_params) {
            *out = (value - mean) / std_dev * scale + shift;
        }
    }

    Ok(normalized)
}
/// Scales every row of `x` by the reciprocal of its root-mean-square and then
/// by a per-column gain.
///
/// For each row, `rms = sqrt(mean(x^2) + epsilon)` is computed over the
/// columns and every element becomes `x / rms * gamma[j]`. No mean is
/// subtracted and no shift is added.
///
/// # Arguments
///
/// * `x` - 2-D input; the RMS is taken per row.
/// * `gamma` - Per-column scale; its length must equal `x.ncols()`.
/// * `epsilon` - Value added to the mean square before the square root.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `gamma` does not have one entry per
///   column of `x`.
/// * `NumRs2Error::ConversionError` if the column count cannot be represented
///   as `T`.
pub fn rms_norm<T>(x: &ArrayView2<T>, gamma: &ArrayView1<T>, epsilon: T) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps + ScalarOperand,
{
    let width = x.ncols();
    if gamma.len() != width {
        return Err(NumRs2Error::DimensionMismatch(
            "Gamma dimension must match input features".to_string(),
        ));
    }

    let feature_count = T::from(width).ok_or_else(|| {
        NumRs2Error::ConversionError("Failed to convert feature count".to_string())
    })?;

    let mut normalized = Array2::<T>::zeros(x.raw_dim());
    for (row, mut out_row) in x.outer_iter().zip(normalized.outer_iter_mut()) {
        let mean_square = row.mapv(|v| v * v).sum() / feature_count;
        let rms = (mean_square + epsilon).sqrt();

        let cells = out_row.iter_mut().zip(row.iter());
        for ((out, &value), &gain) in cells.zip(gamma.iter()) {
            *out = value / rms * gain;
        }
    }

    Ok(normalized)
}
/// Applies element-wise inverted dropout to a 1-D array.
///
/// When `training` is `true` and `p > 0`, each element is independently
/// either zeroed or multiplied by `1 / (1 - p)`. An element is kept when a
/// fresh random `f64` draw is strictly greater than `p`. When `training` is
/// `false` or `p == 0`, an unmodified copy of `x` is returned.
///
/// # Arguments
///
/// * `x` - Input values.
/// * `p` - Drop probability; must lie in `[0, 1)` and must not be NaN.
/// * `training` - Whether dropout is active.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `p` is NaN or outside `[0, 1)`. The
///   check runs before the `training` flag is consulted.
/// * `NumRs2Error::ConversionError` if `p` cannot be converted to `f64`.
pub fn dropout<T>(x: &ArrayView1<T>, p: T, training: bool) -> NnResult<Array1<T>>
where
    T: Float + SimdUnifiedOps,
{
    // NaN compares false against every bound, so without the explicit check it
    // would pass validation and yield an all-zero mask with a NaN scale.
    if p.is_nan() || p < T::zero() || p >= T::one() {
        return Err(NumRs2Error::InvalidOperation(
            "Dropout probability must be in [0, 1)".to_string(),
        ));
    }
    if !training || p == T::zero() {
        return Ok(x.to_owned());
    }

    let mut rng = thread_rng();
    let threshold = p
        .to_f64()
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert probability".to_string()))?;
    // Kept elements are rescaled by 1 / (1 - p) to compensate for the dropped ones.
    let scale = T::one() / (T::one() - p);

    let mask: Array1<T> = Array1::from_shape_fn(x.len(), |_| {
        if rng.gen::<f64>() > threshold {
            scale
        } else {
            T::zero()
        }
    });
    Ok(x * &mask)
}
/// Applies element-wise inverted dropout to a 2-D array.
///
/// When `training` is `true` and `p > 0`, each element is independently
/// either zeroed or multiplied by `1 / (1 - p)`. An element is kept when a
/// fresh random `f64` draw is strictly greater than `p`. When `training` is
/// `false` or `p == 0`, an unmodified copy of `x` is returned.
///
/// # Arguments
///
/// * `x` - Input values.
/// * `p` - Drop probability; must lie in `[0, 1)` and must not be NaN.
/// * `training` - Whether dropout is active.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `p` is NaN or outside `[0, 1)`. The
///   check runs before the `training` flag is consulted.
/// * `NumRs2Error::ConversionError` if `p` cannot be converted to `f64`.
pub fn dropout_2d<T>(x: &ArrayView2<T>, p: T, training: bool) -> NnResult<Array2<T>>
where
    T: Float + SimdUnifiedOps,
{
    // NaN compares false against every bound, so without the explicit check it
    // would pass validation and yield an all-zero mask with a NaN scale.
    if p.is_nan() || p < T::zero() || p >= T::one() {
        return Err(NumRs2Error::InvalidOperation(
            "Dropout probability must be in [0, 1)".to_string(),
        ));
    }
    if !training || p == T::zero() {
        return Ok(x.to_owned());
    }

    let mut rng = thread_rng();
    let threshold = p
        .to_f64()
        .ok_or_else(|| NumRs2Error::ConversionError("Failed to convert probability".to_string()))?;
    // Kept elements are rescaled by 1 / (1 - p) to compensate for the dropped ones.
    let scale = T::one() / (T::one() - p);

    let mask: Array2<T> = Array2::from_shape_fn(x.raw_dim(), |_| {
        if rng.gen::<f64>() > threshold {
            scale
        } else {
            T::zero()
        }
    });
    Ok(x * &mask)
}
pub fn spatial_dropout<T>(x: &ArrayView2<T>, p: T, training: bool) -> NnResult<Array2<T>>
where
T: Float + SimdUnifiedOps + ScalarOperand,
{
if p < T::zero() || p >= T::one() {
return Err(NumRs2Error::InvalidOperation(
"Dropout probability must be in [0, 1)".to_string(),
));
}
if !training || p == T::zero() {
return Ok(x.to_owned());
}
let mut rng = thread_rng();
let threshold = p
.to_f64()
.ok_or_else(|| NumRs2Error::ConversionError("Failed to convert probability".to_string()))?;
let scale = T::one() / (T::one() - p);
let mut result = x.to_owned();
for j in 0..x.ncols() {
if rng.gen::<f64>() <= threshold {
for i in 0..x.nrows() {
result[[i, j]] = T::zero();
}
} else {
for i in 0..x.nrows() {
result[[i, j]] = result[[i, j]] * scale;
}
}
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// With unit gamma and zero beta, every normalized row must average to ~0.
    #[test]
    fn test_layer_norm() {
        let input = Array2::from_shape_fn((2, 3), |(i, j)| (i * 3 + j) as f64);
        let gamma = Array1::ones(3);
        let beta = Array1::zeros(3);

        let normalized = layer_norm(&input.view(), &gamma.view(), &beta.view(), 1e-5).unwrap();

        for row in normalized.outer_iter() {
            let mean = row.sum() / row.len() as f64;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-5);
        }
    }

    /// Outside training, dropout must hand back the input unchanged.
    #[test]
    fn test_dropout_inference() {
        let input = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let output = dropout(&input.view(), 0.5, false).unwrap();
        assert_eq!(output, input);
    }

    /// In training with p = 0.5, some elements are zeroed, some survive, and
    /// the observed drop rate stays within 10 percentage points of 50%.
    #[test]
    fn test_dropout_training() {
        let input = Array1::from_vec((1..=1000).map(|i| i as f64).collect());
        let output = dropout(&input.view(), 0.5, true).unwrap();

        // Inputs are all non-zero, so a zero in the output marks a dropped element.
        let zeroed = output.iter().filter(|v| **v == 0.0).count();
        assert!(zeroed > 0, "Expected some zeros in dropout, got none");

        let survivors = output.iter().filter(|v| **v != 0.0).count();
        assert!(
            survivors > 0,
            "Expected some non-zero values in dropout result"
        );

        let observed_rate = (zeroed as f64) / (input.len() as f64);
        assert!(
            (observed_rate - 0.5).abs() < 0.1,
            "Dropout rate {:.2}% should be close to 50%",
            observed_rate * 100.0
        );
    }
}