minidx_core/layers/softmax.rs

use crate::Float;

/// A softmax activation layer with a fixed temperature and no trainable
/// parameters. The inner `f32` is the softmax temperature the inputs are
/// divided by before exponentiation; the default is 1.0 (standard softmax).
#[derive(Clone, Debug)]
pub struct Softmax(pub f32);

impl Default for Softmax {
    fn default() -> Self {
        Self(1.0)
    }
}

impl Softmax {
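    /// Computes the temperature-scaled softmax of `input`:
    /// `out[i] = exp((input[i] - max) / t) / sum_j exp((input[j] - max) / t)`,
    /// where `t` is the temperature stored in `self.0`. Subtracting the
    /// maximum input does not change the result but keeps `exp` from
    /// overflowing.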
    #[inline]
    fn forward<E: Float, const I: usize>(&self, input: &[E; I]) -> [E; I] {
        let t = E::from_f32(self.0).unwrap();
        let mut out: [E; I] = [E::default(); I];

        let max_val = input.iter().fold(E::NEG_INFINITY, |l, r| E::max(l, *r));

        // Exponentiate the temperature-scaled difference between each input
        // and the maximum value (subtracting the max for numerical stability).
        out.iter_mut().zip(input.iter()).for_each(|(o, &x)| {
            *o = ((x - max_val) / t).exp();
        });

        // Normalize so the outputs sum to one.
        let sum_exp = out.iter().fold(E::default(), |a, x| a + *x);
        out.iter_mut().for_each(|o| *o /= sum_exp);

        out
    }

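    /// Computes the gradient of the loss with respect to `input`, given the
    /// gradient with respect to the softmax output. Uses the softmax Jacobian
    /// `d out[i] / d input[j] = out[i] * (delta_ij - out[j]) / t`, so each
    /// entry of the result is `out[i] * (g[i] - sum_j out[j] * g[j]) / t`.
    /// Note that this recomputes the forward pass internally.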
    #[inline]
    fn backprop<E: Float, const I: usize>(
        &self,
        input: &[E; I],
        grads_wrt_output: &[E; I],
    ) -> [E; I] {
        let t = E::from_f32(self.0).unwrap();
        let output = self.forward(input);

        let mut out: [E; I] = [E::default(); I];
        out.iter_mut().enumerate().for_each(|(i, o)| {
            // Accumulate sum_j (delta_ij - output[j]) * grads_wrt_output[j],
            // i.e. the i-th row of the Jacobian contracted with the upstream
            // gradient, up to the output[i] / t factor applied below.
            let mut sum = E::default();
            for j in 0..I {
                let kronecker = if i == j { E::ONE } else { E::default() };
                sum += (kronecker - output[j]) * grads_wrt_output[j];
            }
            *o = output[i] * sum / t;
        });

        out
    }
}

impl crate::BaseModule for Softmax {}

impl<E: Float, const I: usize> crate::Module<[E; I]> for Softmax {
    type Output = [E; I];

    fn forward(&self, x: &[E; I]) -> Result<Self::Output, crate::Error> {
        Ok(Softmax::forward(self, x))
    }
}

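// Softmax has no trainable parameters, so the reverse pass only produces
// gradients with respect to its inputs and `apply` is a no-op.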
impl<E: Float, const I: usize> crate::RevModule<[E; I]> for Softmax {
    type SelfGrads = ();

    fn reverse(&self, inputs: &[E; I], grads_wrt_output: &[E; I]) -> ([E; I], Self::SelfGrads) {
        (self.backprop(inputs, grads_wrt_output), ())
    }

    fn apply(
        &mut self,
        _applyer: &mut impl crate::optimizers::GradApplyer,
        _updates: Self::SelfGrads,
    ) -> Result<(), crate::Error> {
        Ok(())
    }
}

impl crate::LoadableModule for Softmax {
    fn save(
        &self,
        _path: String,
        _dict: &mut std::collections::HashMap<String, Vec<f64>>,
    ) -> Result<(), crate::LoadSaveError> {
        Ok(())
    }

    fn load(
        &mut self,
        _path: String,
        _dict: &std::collections::HashMap<String, Vec<f64>>,
    ) -> Result<(), crate::LoadSaveError> {
        Ok(())
    }
}

impl crate::ResetParams for Softmax {
    fn rand_params<RNG: rand::Rng>(
        &mut self,
        _rng: &mut RNG,
        _scale: f32,
    ) -> Result<(), crate::Error> {
        Ok(())
    }
}

impl crate::VisualizableUnit for Softmax {
    const KIND: &'static str = "softmax";
    type Params = ();
    fn params(&self) -> &Self::Params {
        &()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_zeros() {
        assert_eq!([0.5, 0.5f32], (Softmax::default()).forward(&[0.0, 0.0f32]));
        assert_eq!(
            [0.25, 0.25, 0.25, 0.25f32],
            (Softmax::default()).forward(&[0.0, 0.0, 0.0, 0.0f32])
        );
    }

    #[test]
    fn test_softmax() {
        let [l, r] = (Softmax::default()).forward(&[0.01, 1.0f32]);
        assert!(l < r);
        assert!(l < r / 2.0);
        assert!(l > r / 100.0);
    }
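
    // Added sanity check on the temperature parameter (uses only this file's
    // own API): dividing the inputs by a larger temperature shrinks their
    // differences, so the output distribution should be flatter (closer to
    // uniform) than with the default temperature of 1.0.
    #[test]
    fn test_softmax_temperature() {
        let [l_default, _] = (Softmax::default()).forward(&[0.01, 1.0f32]);
        let [l_hot, _] = Softmax(10.0).forward(&[0.01, 1.0f32]);
        // The smaller input's probability moves toward 0.5 but stays below it.
        assert!(l_hot > l_default);
        assert!(l_hot < 0.5);
    }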

    #[test]
    fn test_softmax_backprop() {
        let [lg, rg] = (Softmax::default()).backprop(&[0.01, 1.0f32], &[1.0, -1.0f32]);
        assert!(lg > rg);
        // TODO: Needs more thorough checking.
    }
}