minidx_core/layers/softmax.rs

use crate::Float;

/// Softmax activation with a configurable temperature (the inner `f32`).
/// A temperature of `1.0` is the standard softmax; higher temperatures
/// flatten the output distribution.
#[derive(Clone, Debug)]
pub struct Softmax(pub f32);

impl Default for Softmax {
    fn default() -> Self {
        Self(1.0)
    }
}

impl Softmax {
    /// Computes a temperature-scaled, numerically stable softmax: the max
    /// input is subtracted before exponentiation so `exp` cannot overflow
    /// for large inputs, and the shift cancels in the normalization.
    #[inline]
    fn forward<E: Float, const I: usize>(&self, input: &[E; I]) -> [E; I] {
        let t = E::from_f32(self.0).unwrap();
        let mut out: [E; I] = [E::default(); I];

        let max_val = input.iter().fold(E::NEG_INFINITY, |l, r| E::max(l, *r));

        out.iter_mut().zip(input.iter()).for_each(|(o, &x)| {
            *o = ((x - max_val) / t).exp();
        });

        let sum_exp = out.iter().fold(E::default(), |a, x| a + *x);
        out.iter_mut().for_each(|o| *o /= sum_exp);

        out
    }

    /// Backpropagates through the softmax using the Jacobian
    /// `dy_i/dx_j = y_i * (delta_ij - y_j) / t`, contracted against the
    /// upstream gradients: each entry is `y_i * sum_j (delta_ij - y_j) * g_j / t`.
    #[inline]
    fn backprop<E: Float, const I: usize>(
        &self,
        input: &[E; I],
        grads_wrt_output: &[E; I],
    ) -> [E; I] {
        let t = E::from_f32(self.0).unwrap();
        let output = self.forward(input);

        let mut out: [E; I] = [E::default(); I];
        out.iter_mut().enumerate().for_each(|(i, o)| {
            let mut sum = E::default();
            for j in 0..I {
                let kronecker = if i == j { E::ONE } else { E::default() };
                sum += (kronecker - output[j]) * grads_wrt_output[j];
            }
            *o = output[i] * sum / t;
        });

        out
    }
}

impl crate::BaseModule for Softmax {}

impl<E: Float, const I: usize> crate::Module<[E; I]> for Softmax {
    type Output = [E; I];

    fn forward(&self, x: &[E; I]) -> Result<Self::Output, crate::Error> {
        Ok(Softmax::forward(self, x))
    }
}

impl<E: Float, const I: usize> crate::RevModule<[E; I]> for Softmax {
    type SelfGrads = ();

    fn reverse(&self, inputs: &[E; I], grads_wrt_output: &[E; I]) -> ([E; I], Self::SelfGrads) {
        (self.backprop(inputs, grads_wrt_output), ())
    }

    // Softmax has no trainable parameters, so gradient application is a no-op.
    fn apply(
        &mut self,
        _applyer: &mut impl crate::optimizers::GradApplyer,
        _updates: Self::SelfGrads,
    ) -> Result<(), crate::Error> {
        Ok(())
    }
}

// No parameters to persist, so save/load are intentionally no-ops.
impl crate::LoadableModule for Softmax {
    fn save(
        &self,
        _path: String,
        _dict: &mut std::collections::HashMap<String, Vec<f64>>,
    ) -> Result<(), crate::LoadSaveError> {
        Ok(())
    }

    fn load(
        &mut self,
        _path: String,
        _dict: &std::collections::HashMap<String, Vec<f64>>,
    ) -> Result<(), crate::LoadSaveError> {
        Ok(())
    }
}

impl crate::ResetParams for Softmax {
    fn rand_params<RNG: rand::Rng>(
        &mut self,
        _rng: &mut RNG,
        _scale: f32,
    ) -> Result<(), crate::Error> {
        Ok(())
    }
}

impl crate::VisualizableUnit for Softmax {
    const KIND: &'static str = "softmax";
    type Params = ();
    fn params(&self) -> &Self::Params {
        &()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_softmax_zeros() {
        assert_eq!([0.5, 0.5f32], (Softmax::default()).forward(&[0.0, 0.0f32]));
        assert_eq!(
            [0.25, 0.25, 0.25, 0.25f32],
            (Softmax::default()).forward(&[0.0, 0.0, 0.0, 0.0f32])
        );
    }
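
    // A small extra sanity check (sketch; the input values are arbitrary):
    // every forward output is normalized by `sum_exp`, so the entries should
    // always sum to ~1, i.e. form a probability distribution.
    #[test]
    fn test_softmax_sums_to_one() {
        let out = Softmax::default().forward(&[0.5, -1.0, 2.0f32]);
        let sum: f32 = out.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }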

    #[test]
    fn test_softmax() {
        let [l, r] = (Softmax::default()).forward(&[0.01, 1.0f32]);
        assert!(l < r);
        assert!(l < r / 2.0);
        assert!(l > r / 100.0);
    }
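
    // A further check (sketch; the temperature 4.0 is an arbitrary choice):
    // a higher temperature divides the logits by a larger value before `exp`,
    // so it should flatten the distribution toward uniform.
    #[test]
    fn test_softmax_temperature() {
        let [l_cold, r_cold] = Softmax(1.0).forward(&[0.01, 1.0f32]);
        let [l_hot, r_hot] = Softmax(4.0).forward(&[0.01, 1.0f32]);
        assert!(l_hot > l_cold);
        assert!(r_hot < r_cold);
    }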

    #[test]
    fn test_softmax_backprop() {
        let [lg, rg] = (Softmax::default()).backprop(&[0.01, 1.0f32], &[1.0, -1.0f32]);
        assert!(lg > rg);
    }
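
    // A property-test sketch (input and upstream gradients are arbitrary):
    // each column of the softmax Jacobian sums to zero, so the input
    // gradients returned by `backprop` should sum to ~0.
    #[test]
    fn test_softmax_backprop_grads_sum_to_zero() {
        let grads = Softmax::default().backprop(&[0.5, -1.0, 2.0f32], &[0.3, -0.7, 1.1f32]);
        let sum: f32 = grads.iter().sum();
        assert!(sum.abs() < 1e-5);
    }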
}