1use ndarray::prelude::*;
2use statrs::function::erf::erf;
3
4pub fn relu(x: ArrayView2<f32>) -> Array2<f32> {
5 x.mapv(|x| if x < 0.0 { 0.0 } else { x })
6}
7
8pub fn clip(x: ArrayView2<f32>, min: f32, max: f32) -> Array2<f32> {
9 x.mapv(|x| {
10 if x < min {
11 min
12 } else if x > max {
13 max
14 } else {
15 x
16 }
17 })
18}
19
20pub fn gelu(x: ArrayView2<f32>) -> Array2<f32> {
21 x.mapv(|x| 0.5 * x * (1.0 + erf((x / std::f32::consts::SQRT_2) as f64)) as f32)
22}
23
24pub fn softmax(logits: &Array2<f32>) -> Array2<f32> {
25 let mut softmax = logits.to_owned();
26 let max = softmax.fold_axis(Axis(1), 0.0, |x, y| if *x > *y { *x } else { *y });
28 for ((b, _), x) in softmax.indexed_iter_mut() {
29 *x = (*x - max[b]).exp();
30 }
31 let sum = softmax.sum_axis(Axis(1));
32 for ((b, _), x) in softmax.indexed_iter_mut() {
33 *x /= sum[b];
34 }
35 softmax
36}