use scirs2_core::ndarray::Array1;

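/// Selects which normalization [`LayerNorm::forward`] applies.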
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum NormType {
    LayerNorm,
    #[default]
    RMSNorm,
    None,
}

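/// A normalization layer with learnable scale (`gamma`) and shift (`beta`)
/// parameters, dispatching on [`NormType`]. `beta` is only used by classic
/// LayerNorm.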
#[derive(Debug, Clone)]
pub struct LayerNorm {
    gamma: Array1<f32>,
    beta: Array1<f32>,
    eps: f32,
    norm_type: NormType,
}

impl LayerNorm {
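    /// Creates a layer for `dim`-dimensional inputs with `gamma = 1`,
    /// `beta = 0`, and `eps = 1e-5`.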
    pub fn new(dim: usize, norm_type: NormType) -> Self {
        Self {
            gamma: Array1::ones(dim),
            beta: Array1::zeros(dim),
            eps: 1e-5,
            norm_type,
        }
    }

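    /// Overrides the numerical-stability epsilon (builder style).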
    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

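    /// Replaces the scale parameters; the length should match the layer's
    /// dimension, otherwise `forward` may panic.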
    pub fn set_gamma(&mut self, gamma: Array1<f32>) {
        self.gamma = gamma;
    }

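    /// Replaces the shift parameters; the length should match the layer's
    /// dimension, otherwise `forward` may panic.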
    pub fn set_beta(&mut self, beta: Array1<f32>) {
        self.beta = beta;
    }

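    /// Normalizes `x` according to the configured [`NormType`];
    /// `NormType::None` returns the input unchanged.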
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.norm_type {
            NormType::LayerNorm => self.layer_norm(x),
            NormType::RMSNorm => self.rms_norm(x),
            NormType::None => x.clone(),
        }
    }

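    /// Classic LayerNorm: y_i = gamma_i * (x_i - mean) / sqrt(var + eps) + beta_i,
    /// with the mean and (biased) variance computed over `x`.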
    fn layer_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let mean = x.sum() / n;
        let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = (var + self.eps).sqrt();

        let mut result = Array1::zeros(x.len());
        for i in 0..x.len() {
            result[i] = ((x[i] - mean) / std) * self.gamma[i] + self.beta[i];
        }
        result
    }

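    /// RMSNorm: y_i = gamma_i * x_i / sqrt(mean(x^2) + eps).
    /// Unlike LayerNorm it subtracts no mean and applies no bias.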
    fn rms_norm(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len() as f32;
        let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + self.eps).sqrt();

        let mut result = Array1::zeros(x.len());
        for i in 0..x.len() {
            result[i] = (x[i] / rms) * self.gamma[i];
        }
        result
    }

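    /// Returns the configured normalization variant.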
    pub fn norm_type(&self) -> NormType {
        self.norm_type
    }

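    /// Returns the dimensionality this layer was constructed for.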
    pub fn dim(&self) -> usize {
        self.gamma.len()
    }
}

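/// Standalone LayerNorm without learnable parameters:
/// (x - mean) / sqrt(var + eps).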
pub fn layer_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
    let n = x.len() as f32;
    let mean = x.sum() / n;
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / n;
    let std = (var + eps).sqrt();

    let mut result = Array1::zeros(x.len());
    for i in 0..x.len() {
        result[i] = (x[i] - mean) / std;
    }
    result
}

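/// Standalone RMSNorm without learnable parameters:
/// x / sqrt(mean(x^2) + eps).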
pub fn rms_norm(x: &Array1<f32>, eps: f32) -> Array1<f32> {
    let n = x.len() as f32;
    let rms = (x.iter().map(|&v| v * v).sum::<f32>() / n + eps).sqrt();

    let mut result = Array1::zeros(x.len());
    for i in 0..x.len() {
        result[i] = x[i] / rms;
    }
    result
}

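/// Supported activation functions for [`Activation::forward`].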
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ActivationType {
    ReLU,
    GELU,
    #[default]
    SiLU,
    Sigmoid,
    Tanh,
    None,
}

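/// A thin wrapper that dispatches to one of the activation functions below.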
#[derive(Debug, Clone)]
pub struct Activation {
    act_type: ActivationType,
}

impl Activation {
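    /// Wraps the given activation type.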
    pub fn new(act_type: ActivationType) -> Self {
        Self { act_type }
    }

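    /// Applies the configured activation element-wise;
    /// `ActivationType::None` is the identity.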
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        match self.act_type {
            ActivationType::ReLU => relu(x),
            ActivationType::GELU => gelu(x),
            ActivationType::SiLU => silu(x),
            ActivationType::Sigmoid => sigmoid(x),
            ActivationType::Tanh => tanh(x),
            ActivationType::None => x.clone(),
        }
    }

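    /// Returns the configured activation type.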
    pub fn act_type(&self) -> ActivationType {
        self.act_type
    }
}

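/// ReLU: max(0, x).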
pub fn relu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v.max(0.0))
}

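/// Leaky ReLU: x for x >= 0, alpha * x otherwise.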
pub fn leaky_relu(x: &Array1<f32>, alpha: f32) -> Array1<f32> {
    x.mapv(|v| if v >= 0.0 { v } else { alpha * v })
}

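/// Logistic sigmoid: 1 / (1 + e^(-x)).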
pub fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
}

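/// Hyperbolic tangent, applied element-wise.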
pub fn tanh(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v.tanh())
}

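/// SiLU (a.k.a. swish): x * sigmoid(x).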
pub fn silu(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v / (1.0 + (-v).exp()))
}

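/// GELU via the common tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).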
pub fn gelu(x: &Array1<f32>) -> Array1<f32> {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEF: f32 = 0.044715;

    x.mapv(|v| {
        let inner = SQRT_2_OVER_PI * (v + COEF * v.powi(3));
        0.5 * v * (1.0 + inner.tanh())
    })
}

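/// Cheaper GELU approximation: x * sigmoid(1.702 * x).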
pub fn gelu_fast(x: &Array1<f32>) -> Array1<f32> {
    x.mapv(|v| v / (1.0 + (-1.702 * v).exp()))
}

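/// Numerically stable softmax: subtracts the maximum before exponentiating
/// so no exp overflows.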
pub fn softmax(x: &Array1<f32>) -> Array1<f32> {
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp_x = x.mapv(|v| (v - max_val).exp());
    let sum = exp_x.sum();
    exp_x / sum
}

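/// Numerically stable log-softmax:
/// x - max(x) - log(sum(exp(x - max(x)))).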
pub fn log_softmax(x: &Array1<f32>) -> Array1<f32> {
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let shifted: Array1<f32> = x.mapv(|v| v - max_val);
    let log_sum_exp = shifted.mapv(|v| v.exp()).sum().ln();
    shifted.mapv(|v| v - log_sum_exp)
}

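/// Gated linear unit: splits the input in half and multiplies the first
/// half element-wise by an activation of the second half.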
#[derive(Debug, Clone)]
pub struct GatedLinearUnit {
    gate_activation: ActivationType,
}

impl GatedLinearUnit {
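    /// Classic GLU with a sigmoid gate.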
    pub fn new() -> Self {
        Self {
            gate_activation: ActivationType::Sigmoid,
        }
    }

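    /// SwiGLU variant: SiLU-activated gate.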
    pub fn swiglu() -> Self {
        Self {
            gate_activation: ActivationType::SiLU,
        }
    }

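    /// GeGLU variant: GELU-activated gate.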
    pub fn geglu() -> Self {
        Self {
            gate_activation: ActivationType::GELU,
        }
    }

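    /// Gates the first half of `x` with the activated second half, so the
    /// output has half the input length. Inputs shorter than 2 are returned
    /// unchanged, and for odd lengths the trailing element is ignored.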
    pub fn forward(&self, x: &Array1<f32>) -> Array1<f32> {
        let n = x.len();
        if n < 2 {
            return x.clone();
        }

        let half = n / 2;
        let x_part: Array1<f32> = Array1::from_vec(x.iter().take(half).cloned().collect());
        let gate_part: Array1<f32> =
            Array1::from_vec(x.iter().skip(half).take(half).cloned().collect());

        let gate = match self.gate_activation {
            ActivationType::Sigmoid => sigmoid(&gate_part),
            ActivationType::SiLU => silu(&gate_part),
            ActivationType::GELU => gelu(&gate_part),
            _ => sigmoid(&gate_part),
        };

        &x_part * &gate
    }
}

impl Default for GatedLinearUnit {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::LayerNorm);
        let y = norm.forward(&x);

        let mean: f32 = y.sum() / y.len() as f32;
        assert!(mean.abs() < 0.01);
    }

    #[test]
    fn test_rms_norm() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let norm = LayerNorm::new(4, NormType::RMSNorm);
        let y = norm.forward(&x);

        let rms = (y.iter().map(|v| v * v).sum::<f32>() / y.len() as f32).sqrt();
        assert!((rms - 1.0).abs() < 0.1);
    }

    #[test]
    fn test_relu() {
        let x = Array1::from_vec(vec![-2.0, -1.0, 0.0, 1.0, 2.0]);
        let y = relu(&x);
        assert_eq!(y[0], 0.0);
        assert_eq!(y[1], 0.0);
        assert_eq!(y[2], 0.0);
        assert_eq!(y[3], 1.0);
        assert_eq!(y[4], 2.0);
    }

    #[test]
    fn test_sigmoid() {
        let x = Array1::from_vec(vec![-10.0, 0.0, 10.0]);
        let y = sigmoid(&x);
        assert!(y[0] < 0.01);
        assert!((y[1] - 0.5).abs() < 0.01);
        assert!(y[2] > 0.99);
    }

    #[test]
    fn test_silu() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = silu(&x);
        assert!((y[0] - 0.0).abs() < 0.01);
        assert!((y[1] - 0.731).abs() < 0.01);
    }

    #[test]
    fn test_gelu() {
        let x = Array1::from_vec(vec![-1.0, 0.0, 1.0]);
        let y = gelu(&x);
        assert!((y[1] - 0.0).abs() < 0.01);
        assert!(y[2] > 0.5);
        assert!(y[0] < 0.0);
    }

    #[test]
    fn test_softmax() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = softmax(&x);

        assert!((y.sum() - 1.0).abs() < 0.01);
        assert!(y[2] > y[1] && y[1] > y[0]);
    }

    #[test]
    fn test_glu() {
        let x = Array1::from_vec(vec![1.0, 2.0, 0.0, 0.0]);
        let glu = GatedLinearUnit::new();
        let y = glu.forward(&x);

        assert_eq!(y.len(), 2);
        assert!((y[0] - 0.5).abs() < 0.01);
        assert!((y[1] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_swiglu() {
        let x = Array1::from_vec(vec![1.0, 2.0, 1.0, 1.0]);
        let glu = GatedLinearUnit::swiglu();
        let y = glu.forward(&x);

        assert_eq!(y.len(), 2);
        assert!(y[0] > 0.0);
        assert!(y[1] > 0.0);
    }
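
    // Two additional sketch tests, not part of the original suite, covering
    // helpers the tests above leave unexercised. Expected values follow
    // directly from the definitions: leaky_relu(-2.0, 0.1) = -0.2, and
    // exponentiating log_softmax must recover softmax.
    #[test]
    fn test_leaky_relu() {
        let x = Array1::from_vec(vec![-2.0, 0.0, 3.0]);
        let y = leaky_relu(&x, 0.1);
        assert!((y[0] + 0.2).abs() < 1e-6);
        assert_eq!(y[1], 0.0);
        assert_eq!(y[2], 3.0);
    }

    #[test]
    fn test_log_softmax_matches_softmax() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let log_p = log_softmax(&x);
        let p = softmax(&x);
        for i in 0..x.len() {
            assert!((log_p[i].exp() - p[i]).abs() < 1e-5);
        }
    }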
}