rogue_net/
fun.rs

1use ndarray::prelude::*;
2use statrs::function::erf::erf;
3
4pub fn relu(x: ArrayView2<f32>) -> Array2<f32> {
5    x.mapv(|x| if x < 0.0 { 0.0 } else { x })
6}
7
8pub fn clip(x: ArrayView2<f32>, min: f32, max: f32) -> Array2<f32> {
9    x.mapv(|x| {
10        if x < min {
11            min
12        } else if x > max {
13            max
14        } else {
15            x
16        }
17    })
18}
19
20pub fn gelu(x: ArrayView2<f32>) -> Array2<f32> {
21    x.mapv(|x| 0.5 * x * (1.0 + erf((x / std::f32::consts::SQRT_2) as f64)) as f32)
22}
23
24pub fn softmax(logits: &Array2<f32>) -> Array2<f32> {
25    let mut softmax = logits.to_owned();
26    // Calculate softmax
27    let max = softmax.fold_axis(Axis(1), 0.0, |x, y| if *x > *y { *x } else { *y });
28    for ((b, _), x) in softmax.indexed_iter_mut() {
29        *x = (*x - max[b]).exp();
30    }
31    let sum = softmax.sum_axis(Axis(1));
32    for ((b, _), x) in softmax.indexed_iter_mut() {
33        *x /= sum[b];
34    }
35    softmax
36}