kizzasi_core/nn.rs

//! Neural network building blocks: normalization and activation functions
//!
//! Provides layer normalization variants and gating mechanisms for SSM architectures.

use scirs2_core::ndarray::Array1;

// ============================================================================
// Layer Normalization
// ============================================================================

/// Type of normalization to apply
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormType {
    /// Standard Layer Normalization (Ba et al., 2016)
    LayerNorm,
    /// RMS Layer Normalization (Zhang & Sennrich, 2019)
    #[default]
    RMSNorm, // RMSNorm is faster and commonly used in modern SSMs
    /// No normalization
    None,
}

/// Layer normalization with learnable parameters
#[derive(Debug, Clone)]
pub struct LayerNorm {
    gamma: Array1<f32>, // scale
    beta: Array1<f32>,  // shift
    eps: f32,
    norm_type: NormType,
}

impl LayerNorm {
    /// Create a new LayerNorm
    pub fn new(dim: usize, norm_type: NormType) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
            norm_type,
        }
    }

    /// Create with custom epsilon
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    /// Set gamma (scale) parameters
    pub fn set_gamma(&mut self, gamma: Array1<f32>) {
        self.gamma = gamma;
    }

    /// Set beta (shift) parameters
    pub fn set_beta(&mut self, beta: Array1<f32>) {
        self.beta = beta;
    }

    /// Apply normalization to input
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.norm_type {
            NormType::LayerNorm => self.layer_norm(x),
            NormType::RMSNorm => self.rms_norm(x),
            NormType::None => x.clone(),
        }
    }

    /// Standard layer normalization: (x - mean) / std * gamma + beta
    fn layer_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let mean = x.sum() / n;
        let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = (var + self.eps).sqrt();

        let mut result = Array1::zeros(x.len());
        for i in 0..x.len() {
            result[i] = ((x[i] - mean) / std) * self.gamma[i] + self.beta[i];
        }
        result
    }

    /// RMS layer normalization: x / rms(x) * gamma (beta is not applied)
    fn rms_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + self.eps).sqrt();

        let mut result = Array1::zeros(x.len());
        for i in 0..x.len() {
            result[i] = (x[i] / rms) * self.gamma[i];
        }
        result
    }

    /// Get norm type
    pub fn norm_type(&self) -> NormType {
        self.norm_type
    }

    /// Get dimension
    pub fn dim(&self) -> usize {
        self.gamma.len()
    }
}
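
// Usage sketch (illustrative, not part of the API surface): normalize a vector
// with the default RMSNorm variant, using the same `Array1` constructors as the
// tests below.
//
//     let norm = LayerNorm::new(4, NormType::RMSNorm).with_eps(1e-6);
//     let x = Array1::from_vec(vec![1.0_f32, 2.0, 3.0, 4.0]);
//     let y = norm.forward(&x); // y[i] = x[i] / rms(x) * gamma[i]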

/// Standalone layer normalization function
pub fn layer_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
    let n = x.len() as f32;
    let mean = x.sum() / n;
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
    let std = (var + eps).sqrt();

    let mut result = Array1::zeros(x.len());
    for i in 0..x.len() {
        result[i] = (x[i] - mean) / std;
    }
    result
}

/// Standalone RMS normalization function
pub fn rms_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
    let n = x.len() as f32;
    let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + eps).sqrt();

    let mut result = Array1::zeros(x.len());
    for i in 0..x.len() {
        result[i] = x[i] / rms;
    }
    result
}
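
// Worked example (sketch): for x = [1, 2, 3, 4],
// rms(x) = sqrt((1 + 4 + 9 + 16) / 4) = sqrt(7.5) ≈ 2.739, so
// rms_norm(&x, 1e-5) ≈ [0.365, 0.730, 1.095, 1.461], while
// layer_norm(&x, 1e-5) ≈ [-1.342, -0.447, 0.447, 1.342] (mean 2.5, std ≈ 1.118).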

// ============================================================================
// Activation Functions (Gating Mechanisms)
// ============================================================================

/// Type of activation function
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ActivationType {
    /// Rectified Linear Unit: max(0, x)
    ReLU,
    /// Gaussian Error Linear Unit: x * Phi(x)
    GELU,
    /// Sigmoid Linear Unit (Swish): x * sigmoid(x)
    #[default]
    SiLU, // SiLU is commonly used in Mamba
    /// Sigmoid: 1 / (1 + exp(-x))
    Sigmoid,
    /// Hyperbolic tangent
    Tanh,
    /// No activation (identity)
    None,
}

/// Activation function with configurable type
#[derive(Debug, Clone)]
pub struct Activation {
    act_type: ActivationType,
}

impl Activation {
    /// Create a new activation
    pub fn new(act_type: ActivationType) -> Self {
        Self { act_type }
    }

    /// Apply activation to input
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.act_type {
            ActivationType::ReLU => relu(x),
            ActivationType::GELU => gelu(x),
            ActivationType::SiLU => silu(x),
            ActivationType::Sigmoid => sigmoid(x),
            ActivationType::Tanh => tanh(x),
            ActivationType::None => x.clone(),
        }
    }

    /// Get activation type
    pub fn act_type(&self) -> ActivationType {
        self.act_type
    }
}
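
// Usage sketch (illustrative): apply the default SiLU activation, the gating
// nonlinearity referenced above for Mamba-style blocks.
//
//     let act = Activation::new(ActivationType::SiLU);
//     let y = act.forward(&Array1::from_vec(vec![-1.0_f32, 0.0, 1.0]));
//     // y ≈ [-0.269, 0.0, 0.731]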

/// ReLU activation: max(0, x)
pub fn relu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v.max(0.0))
}

/// Leaky ReLU: x if x >= 0, otherwise alpha * x
pub fn leaky_relu(x: &Array1<f32>, alpha: f32) -> Array1<f32> {
    x.mapv(|v| if v >= 0.0 { v } else { alpha * v })
}

/// Sigmoid activation: 1 / (1 + exp(-x))
pub fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
}

/// Tanh activation
pub fn tanh(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v.tanh())
}

/// SiLU (Swish) activation: x * sigmoid(x)
///
/// Commonly used in Mamba and modern SSM architectures
pub fn silu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v / (1.0 + (-v).exp()))
}

/// GELU activation: x * Phi(x) where Phi is the CDF of the standard normal
///
/// Approximation: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
pub fn gelu(x: &Array1<f32>) -> Array1<f32> {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6; // sqrt(2/pi)
    const COEF: f32 = 0.044715;

    x.mapv(|v| {
        let inner = SQRT_2_OVER_PI * (v + COEF * v.powi(3));
        0.5 * v * (1.0 + inner.tanh())
    })
}

/// Fast GELU approximation using sigmoid
pub fn gelu_fast(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v / (1.0 + (-1.702 * v).exp()))
}
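
// Worked example (sketch): at x = 1.0 the exact GELU is 1.0 * Phi(1.0) ≈ 0.841;
// the tanh approximation in `gelu` also gives ≈ 0.841, while `gelu_fast` gives
// ≈ 0.846, trading a little accuracy for a single cheaper sigmoid evaluation.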

/// Softmax: exp(x_i) / sum_j exp(x_j), computed with the max subtracted for
/// numerical stability
pub fn softmax(x: &Array1<f32>) -> Array1<f32> {
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_x: Vec<f32> = x.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f32 = exp_x.iter().sum();
    Array1::from_vec(exp_x.iter().map(|&v| v / sum).collect())
}

/// Log softmax: log(softmax(x)), computed stably via the log-sum-exp trick
pub fn log_softmax(x: &Array1<f32>) -> Array1<f32> {
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let shifted: Array1<f32> = x.mapv(|v| v - max_val);
    let log_sum_exp = shifted.mapv(|v| v.exp()).sum().ln();
    shifted.mapv(|v| v - log_sum_exp)
}
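
// Worked example (sketch): softmax([1, 2, 3]) ≈ [0.090, 0.245, 0.665] and
// log_softmax([1, 2, 3]) ≈ [-2.408, -1.408, -0.408]; exponentiating the latter
// recovers the former. Both subtract max(x) first so no intermediate exp overflows.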

// ============================================================================
// Gated Linear Unit
// ============================================================================

/// Gated Linear Unit: splits the input in half and gates one half with the other
///
/// The classical formulation is GLU(x, W, V) = (xW + b) * sigmoid(xV + c); this
/// struct applies only the gating step to an input already laid out as
/// [value, gate] halves (the linear projections are assumed to happen elsewhere).
#[derive(Debug, Clone)]
pub struct GatedLinearUnit {
    /// Activation for the gate
    gate_activation: ActivationType,
}

impl GatedLinearUnit {
    /// Create a new GLU with default sigmoid gate
    pub fn new() -> Self {
        Self {
            gate_activation: ActivationType::Sigmoid,
        }
    }

    /// Create with SiLU gate (SwiGLU, commonly used in modern architectures)
    pub fn swiglu() -> Self {
        Self {
            gate_activation: ActivationType::SiLU,
        }
    }

    /// Create with GELU gate (GeGLU)
    pub fn geglu() -> Self {
        Self {
            gate_activation: ActivationType::GELU,
        }
    }

    /// Apply GLU to input (input should have even dimension)
    ///
    /// Splits input into [x, gate] and returns x * activation(gate). Inputs with
    /// fewer than two elements are returned unchanged; when the length is odd,
    /// the trailing element is ignored.
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len();
        if n < 2 {
            return x.clone();
        }

        let half = n / 2;
        let x_part: Array1<f32> = Array1::from_vec(x.iter().take(half).cloned().collect());
        let gate_part: Array1<f32> =
            Array1::from_vec(x.iter().skip(half).take(half).cloned().collect());

        let gate = match self.gate_activation {
            ActivationType::Sigmoid => sigmoid(&gate_part),
            ActivationType::SiLU => silu(&gate_part),
            ActivationType::GELU => gelu(&gate_part),
            _ => sigmoid(&gate_part), // other activation types fall back to a sigmoid gate
        };

        &x_part * &gate
    }
}
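
// Usage sketch (illustrative): gate a 2-element value with a 2-element gate,
// matching the layout used in the tests below.
//
//     let glu = GatedLinearUnit::swiglu();
//     let x = Array1::from_vec(vec![1.0_f32, 2.0, 1.0, 1.0]); // [value | gate]
//     let y = glu.forward(&x); // y ≈ [0.731, 1.462] = value * silu(gate)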

impl Default for GatedLinearUnit {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::LayerNorm);
        let y = norm.forward(&x);

        // After normalization, mean should be ~0, std should be ~1
        let mean: f32 = y.sum() / y.len() as f32;
        assert!(mean.abs() < 0.01);
        let var: f32 = y.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / y.len() as f32;
        assert!((var.sqrt() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_rms_norm() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::RMSNorm);
        let y = norm.forward(&x);

        // RMS of output should be ~1 (scaled by gamma=1)
        let rms = (y.iter().map(|v| v * v).sum::<f32>() / y.len() as f32).sqrt();
        assert!((rms - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_relu() {
        let x = Array1::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = relu(&x);
        assert_eq!(y[0], 0.0);
        assert_eq!(y[1], 0.0);
        assert_eq!(y[2], 0.0);
        assert_eq!(y[3], 1.0);
        assert_eq!(y[4], 2.0);
    }

    #[test]
    fn test_sigmoid() {
        let x = Array1::from_vec(vec![-10.0, 0.0, 10.0]);
        let y = sigmoid(&x);
        assert!(y[0] < 0.01); // sigmoid(-10) ≈ 0
        assert!((y[1] - 0.5).abs() < 0.01); // sigmoid(0) = 0.5
        assert!(y[2] > 0.99); // sigmoid(10) ≈ 1
    }

    #[test]
    fn test_silu() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = silu(&x);
        assert!((y[0] - 0.0).abs() < 0.01); // silu(0) = 0 * 0.5 = 0
        assert!((y[1] - 0.731).abs() < 0.01); // silu(1) ≈ 0.731
    }

    #[test]
    fn test_gelu() {
        let x = Array1::from_vec(vec![-1.0, 0.0, 1.0]);
        let y = gelu(&x);
        assert!((y[1] - 0.0).abs() < 0.01); // gelu(0) = 0
        assert!(y[2] > 0.5); // gelu(1) > 0.5
        assert!(y[0] < 0.0); // gelu(-1) < 0
    }

    #[test]
    fn test_softmax() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = softmax(&x);

        // Sum should be 1
        assert!((y.sum() - 1.0).abs() < 0.01);
        // Values should be ordered
        assert!(y[2] > y[1] && y[1] > y[0]);
    }

    #[test]
    fn test_glu() {
        let x = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0]); // [data, gate]
        let glu = GatedLinearUnit::new();
        let y = glu.forward(&x);

        assert_eq!(y.len(), 2);
        // gate = sigmoid([0, 0]) = [0.5, 0.5]
        // result = [1, 2] * [0.5, 0.5] = [0.5, 1.0]
        assert!((y[0] - 0.5).abs() < 0.01);
        assert!((y[1] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_swiglu() {
        let x = Array1::from_vec(vec![1.0, 2.0, 1.0, 1.0]);
        let glu = GatedLinearUnit::swiglu();
        let y = glu.forward(&x);

        assert_eq!(y.len(), 2);
        // SiLU gate applied
        assert!(y[0] > 0.0);
        assert!(y[1] > 0.0);
    }
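
    // Added illustrative check (sketch): exponentiating log_softmax should
    // recover softmax, so the result sums to ~1 and every entry is <= 0.
    #[test]
    fn test_log_softmax() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = log_softmax(&x);
        let back = y.mapv(|v| v.exp());
        assert!((back.sum() - 1.0).abs() < 0.01);
        assert!(y.iter().all(|&v| v <= 0.0));
    }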
}