1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

use super::*;

/// Stateless softmax operator.
///
/// Its `forward` pass delegates to `bmls::softmax` and its `backward` pass to
/// `bmls::softmax_wrt_x`. It is a unit struct, so cloning or (de)serializing
/// it carries no data; the derives exist so it can be stored as a boxed
/// operator inside a node.
#[derive(Clone, Serialize, Deserialize)]
struct Softmax;

impl Operator for Softmax {
    /// Forward pass: writes `softmax(x)` into this node's output buffer.
    ///
    /// The computation is delegated to `bmls::softmax`; any error it returns
    /// is propagated unchanged via `?`.
    fn forward(&mut self, node: &Node) -> Result<()> {
        let (output, _) = node.y();
        // NOTE(review): `x(1)` is assumed to address this node's only input
        // (the single dependency registered by `softmax`) — confirm whether
        // `Node::x` is 1-based.
        let (input, _) = node.x(1);

        // Guards are taken in the same order as before: input read, then
        // output write. `shape2()` is queried while both are held.
        let src = input.read();
        let mut dst = output.write();
        bmls::softmax(&src, &mut dst, input.shape2())?;

        Ok(())
    }

    /// Backward pass: pushes the output gradient back to the input gradient.
    ///
    /// It uses the forward *output* rather than the input, because the softmax
    /// Jacobian can be written purely in terms of the output values.
    /// NOTE(review): whether `bmls::softmax_wrt_x` overwrites or accumulates
    /// into the input gradient is decided by `bmls` — confirm there.
    fn backward(&mut self, node: &Node) -> Result<()> {
        let (output, output_grad) = node.y();
        let (_, input_grad) = node.x(1);

        let probs = output.read();
        let upstream = output_grad.read();
        let mut downstream = input_grad.write();
        bmls::softmax_wrt_x(&probs, &upstream, &mut downstream, output.shape2())?;

        Ok(())
    }
}

impl Display for Softmax {
    /// Renders the operator as its bare name, `Softmax`.
    ///
    /// Fill, width and alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Softmax")
    }
}

/// Creates a softmax node whose single dependency is `x`.
///
/// The new node keeps `x`'s shape and batching flag. Its `skip` flag is
/// `false` and it has no `init`. The node description is handed to
/// `x.extend`, and the resulting `Var` is returned.
pub fn softmax<'t>(x: Var<'t>) -> Var<'t> {
    let builder = NodeBuilder {
        op: Box::new(Softmax),
        deps: vec![x.index],
        shape: x.shape,
        is_batched: x.is_batched,
        init: None,
        skip: false,
    };
    x.extend(builder)
}