use scirs2_core::ndarray::Array1;
/// Selects which normalization formula a normalizer applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormType {
    /// Classic layer normalization: subtract the mean, divide by the
    /// standard deviation, then scale and shift.
    LayerNorm,
    /// Root-mean-square normalization (no mean centering); the default.
    #[default]
    RMSNorm,
    /// Pass the input through unchanged.
    None,
}
/// Normalization layer holding a learned scale (`gamma`), shift (`beta`),
/// a numerical-stability epsilon, and the variant to apply.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    /// Per-element scale applied after normalization.
    gamma: Array1<f32>,
    /// Per-element shift (used by the LayerNorm variant only).
    beta: Array1<f32>,
    /// Small constant added under the square root for numerical stability.
    eps: f32,
    /// Which normalization formula `forward` dispatches to.
    norm_type: NormType,
}
impl LayerNorm {
    /// Creates a normalizer for `dim`-element vectors with `gamma = 1`,
    /// `beta = 0`, and `eps = 1e-5`.
    pub fn new(dim: usize, norm_type: NormType) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
            norm_type,
        }
    }

    /// Builder-style override of the stability epsilon.
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Replaces the learned scale vector.
    pub fn set_gamma(&mut self, gamma: Array1<f32>) {
        self.gamma = gamma;
    }

    /// Replaces the learned shift vector.
    pub fn set_beta(&mut self, beta: Array1<f32>) {
        self.beta = beta;
    }

    /// Applies the configured normalization to `x`.
    ///
    /// For the `LayerNorm`/`RMSNorm` variants, `x.len()` must match the
    /// dimension this layer was built with; a mismatch panics.
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.norm_type {
            NormType::LayerNorm => self.layer_norm(x),
            NormType::RMSNorm => self.rms_norm(x),
            NormType::None => x.clone(),
        }
    }

    /// `(x - mean) / sqrt(var + eps) * gamma + beta`, using the biased
    /// (1/n) variance.
    fn layer_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let mean = x.sum() / n;
        let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = (var + self.eps).sqrt();
        // Element-wise array arithmetic avoids per-index bounds checks and
        // panics loudly on a gamma/beta length mismatch instead of silently
        // applying only a prefix of the parameters.
        x.mapv(|v| (v - mean) / std) * &self.gamma + &self.beta
    }

    /// `x / sqrt(mean(x^2) + eps) * gamma` — no mean centering, no shift.
    fn rms_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + self.eps).sqrt();
        x.mapv(|v| v / rms) * &self.gamma
    }

    /// Returns the configured normalization variant.
    pub fn norm_type(&self) -> NormType {
        self.norm_type
    }

    /// Dimension this layer expects (the length of `gamma`).
    pub fn dim(&self) -> usize {
        self.gamma.len()
    }
}
pub fn layer_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
let n = x.len() as f32;
let mean = x.sum() / n;
let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
let std = (var + eps).sqrt();
let mut result = Array1::zeros(x.len());
for i in 0..x.len() {
result[i] = (x[i] - mean) / std;
}
result
}
pub fn rms_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
let n = x.len() as f32;
let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + eps).sqrt();
let mut result = Array1::zeros(x.len());
for i in 0..x.len() {
result[i] = x[i] / rms;
}
result
}
/// Selects which element-wise activation function to apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ActivationType {
    /// Rectified linear unit: `max(0, x)`.
    ReLU,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// Sigmoid-weighted linear unit (a.k.a. swish); the default.
    #[default]
    SiLU,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Identity — pass the input through unchanged.
    None,
}
/// Thin wrapper that dispatches to one of the free activation functions
/// based on a stored [`ActivationType`].
#[derive(Debug, Clone)]
pub struct Activation {
    /// The activation variant `forward` applies.
    act_type: ActivationType,
}
impl Activation {
    /// Wraps the given activation choice.
    pub fn new(act_type: ActivationType) -> Self {
        Self { act_type }
    }

    /// Applies the selected activation element-wise to `x`.
    /// `None` returns an unmodified copy of the input.
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.act_type {
            ActivationType::None => x.clone(),
            ActivationType::ReLU => relu(x),
            ActivationType::Sigmoid => sigmoid(x),
            ActivationType::Tanh => tanh(x),
            ActivationType::SiLU => silu(x),
            ActivationType::GELU => gelu(x),
        }
    }

    /// Returns the configured activation kind.
    pub fn act_type(&self) -> ActivationType {
        self.act_type
    }
}
/// Rectified linear unit: clamps every negative entry to zero.
pub fn relu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| if v > 0.0 { v } else { 0.0 })
}
/// Leaky ReLU: identity for non-negative entries, `alpha * v` for
/// negative ones.
pub fn leaky_relu(x: &Array1<f32>, alpha: f32) -> Array1<f32> {
    x.mapv(|v| if v < 0.0 { alpha * v } else { v })
}
pub fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
}
/// Hyperbolic tangent, applied element-wise.
pub fn tanh(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(f32::tanh)
}
/// SiLU (swish) activation: `v * sigmoid(v)`, computed as `v / (1 + e^-v)`.
pub fn silu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| {
        let denom = 1.0 + (-v).exp();
        v / denom
    })
}
/// GELU activation, tanh approximation:
/// `0.5 * v * (1 + tanh(sqrt(2/pi) * (v + 0.044715 * v^3)))`.
pub fn gelu(x: &Array1<f32>) -> Array1<f32> {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEF: f32 = 0.044715;
    x.mapv(|v| {
        let cubic = v + COEF * v.powi(3);
        0.5 * v * (1.0 + (SQRT_2_OVER_PI * cubic).tanh())
    })
}
/// Fast GELU approximation via a scaled sigmoid: `v * sigmoid(1.702 * v)`.
pub fn gelu_fast(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| {
        let gate = 1.0 + (-1.702 * v).exp();
        v / gate
    })
}
pub fn softmax(x: &Array1<f32>) -> Array1<f32> {
let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_x: Vec<f32> = x.iter().map(|&v| (v - max_val).exp()).collect();
let sum: f32 = exp_x.iter().sum();
Array1::from_vec(exp_x.iter().map(|&v| v / sum).collect())
}
/// Log-softmax via the shifted log-sum-exp trick for numerical stability:
/// `x - max(x) - ln(sum(exp(x - max(x))))`.
pub fn log_softmax(x: &Array1<f32>) -> Array1<f32> {
    let peak = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let centered: Array1<f32> = x.mapv(|v| v - peak);
    let log_sum_exp = centered.mapv(f32::exp).sum().ln();
    centered.mapv(|v| v - log_sum_exp)
}
/// Gated linear unit over a concatenated `[value | gate]` input vector.
#[derive(Debug, Clone)]
pub struct GatedLinearUnit {
    /// Activation applied to the gate half before the element-wise product.
    gate_activation: ActivationType,
}
impl GatedLinearUnit {
    /// Classic GLU: sigmoid-activated gate.
    pub fn new() -> Self {
        Self {
            gate_activation: ActivationType::Sigmoid,
        }
    }

    /// SwiGLU variant: SiLU-activated gate.
    pub fn swiglu() -> Self {
        Self {
            gate_activation: ActivationType::SiLU,
        }
    }

    /// GeGLU variant: GELU-activated gate.
    pub fn geglu() -> Self {
        Self {
            gate_activation: ActivationType::GELU,
        }
    }

    /// Splits `x` into a value half and a gate half, activates the gate,
    /// and returns their element-wise product (output length `x.len() / 2`).
    ///
    /// Inputs shorter than two elements are returned unchanged; for odd
    /// lengths the trailing element is ignored.
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        let len = x.len();
        if len < 2 {
            return x.clone();
        }
        let half = len / 2;
        let values: Array1<f32> = x.iter().take(half).copied().collect();
        let raw_gate: Array1<f32> = x.iter().skip(half).take(half).copied().collect();
        let gate = match self.gate_activation {
            ActivationType::SiLU => silu(&raw_gate),
            ActivationType::GELU => gelu(&raw_gate),
            // Sigmoid is both the classic choice and the fallback for any
            // other configured activation.
            _ => sigmoid(&raw_gate),
        };
        &values * &gate
    }
}
impl Default for GatedLinearUnit {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm() {
        // A layer-normalized vector should have (near-)zero mean.
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::LayerNorm);
        let y = norm.forward(&x);
        let mean: f32 = y.sum() / y.len() as f32;
        assert!(mean.abs() < 0.01);
    }

    #[test]
    fn test_rms_norm() {
        // RMS normalization should bring the root-mean-square close to one.
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::RMSNorm);
        let y = norm.forward(&x);
        let rms = (y.iter().map(|v| v * v).sum::<f32>() / y.len() as f32).sqrt();
        assert!((rms - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_relu() {
        // Negatives clamp to zero; non-negatives pass through.
        let x = Array1::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = relu(&x);
        assert_eq!(y[0], 0.0);
        assert_eq!(y[1], 0.0);
        assert_eq!(y[2], 0.0);
        assert_eq!(y[3], 1.0);
        assert_eq!(y[4], 2.0);
    }

    #[test]
    fn test_sigmoid() {
        // Saturates toward 0 / 1 at the extremes, 0.5 at the origin.
        let x = Array1::from_vec(vec![-10.0, 0.0, 10.0]);
        let y = sigmoid(&x);
        assert!(y[0] < 0.01);
        assert!((y[1] - 0.5).abs() < 0.01);
        assert!(y[2] > 0.99);
    }

    #[test]
    fn test_silu() {
        // silu(0) = 0; silu(1) = 1 * sigmoid(1) ≈ 0.731.
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = silu(&x);
        assert!((y[0] - 0.0).abs() < 0.01);
        assert!((y[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_gelu() {
        // gelu(0) = 0, gelu(1) > 0.5, and gelu is slightly negative at -1.
        let x = Array1::from_vec(vec![-1.0, 0.0, 1.0]);
        let y = gelu(&x);
        assert!((y[1] - 0.0).abs() < 0.01);
        assert!(y[2] > 0.5);
        assert!(y[0] < 0.0);
    }

    #[test]
    fn test_softmax() {
        // Probabilities sum to one and preserve the input ordering.
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = softmax(&x);
        assert!((y.sum() - 1.0).abs() < 0.01);
        assert!(y[2] > y[1] && y[1] > y[0]);
    }

    #[test]
    fn test_glu() {
        // Gate half [0, 0] -> sigmoid 0.5, so the output is half the values.
        let x = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0]);
        let glu = GatedLinearUnit::new();
        let y = glu.forward(&x);
        assert_eq!(y.len(), 2);
        assert!((y[0] - 0.5).abs() < 0.01);
        assert!((y[1] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_swiglu() {
        // Positive values with a positive gate stay positive under SwiGLU.
        let x = Array1::from_vec(vec![1.0, 2.0, 1.0, 1.0]);
        let glu = GatedLinearUnit::swiglu();
        let y = glu.forward(&x);
        assert_eq!(y.len(), 2);
        assert!(y[0] > 0.0);
        assert!(y[1] > 0.0);
    }
}