1use ndarray::{Array, ArrayBase};
4
5pub trait NdArrayTensor<S, T, D> {
15 fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
29 where
30 D: ndarray::RemoveAxis,
31 S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
32 <S as ndarray::RawData>::Elem: std::clone::Clone,
33 T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
34}
35
36impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
37where
38 D: ndarray::RemoveAxis,
39 S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
40 <S as ndarray::RawData>::Elem: std::clone::Clone,
41 T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
42{
43 fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
44 let mut new_array: Array<T, D> = self.to_owned();
45 new_array.map_inplace(|v| *v = v.exp());
49 let sum = new_array.sum_axis(axis).insert_axis(axis);
50 new_array /= ∑
51
52 new_array
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, arr3, Array, Axis, Dimension};
    use test_env_log::test;

    /// Input lane used by every test below.
    const INPUT_ROW: [f32; 7] = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0];

    /// Softmax of [`INPUT_ROW`], rounded to about 7 significant digits.
    const EXPECTED_ROW: [f32; 7] = [
        0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
    ];

    /// Absolute tolerance for element-wise comparisons.
    ///
    /// The reference values carry only ~7 significant digits. The computation
    /// (exp, a 7-term sum, a division) also accumulates several f32 rounding
    /// errors. A tolerance of 1e-7 is only ~3 ulps near 0.47 and breaks on
    /// mathematically equivalent reformulations. 1e-6 still catches any real
    /// error.
    const TOLERANCE: f32 = 1.0e-6;

    /// Asserts equal shapes and element-wise agreement within [`TOLERANCE`].
    fn assert_all_close<D: Dimension>(actual: &Array<f32, D>, expected: &Array<f32, D>) {
        assert_eq!(actual.shape(), expected.shape());
        for (a, e) in actual.iter().zip(expected.iter()) {
            // `NaN < TOLERANCE` is false, so NaN results fail here as intended.
            assert!((a - e).abs() < TOLERANCE, "got {}, expected {}", a, e);
        }
    }

    #[test]
    fn softmax_1d() {
        let array = arr1(&INPUT_ROW);
        let expected = arr1(&EXPECTED_ROW);

        assert_all_close(&array.softmax(Axis(0)), &expected);
    }

    #[test]
    fn softmax_2d() {
        let array = arr2(&[INPUT_ROW; 2]);
        let expected = arr2(&[EXPECTED_ROW; 2]);

        assert_all_close(&array.softmax(Axis(1)), &expected);
    }

    #[test]
    fn softmax_3d() {
        let array = arr3(&[[INPUT_ROW; 2]; 3]);
        let expected = arr3(&[[EXPECTED_ROW; 2]; 3]);

        assert_all_close(&array.softmax(Axis(2)), &expected);
    }

    #[test]
    fn softmax_over_leading_axis() {
        // Both rows are identical, so every column's softmax splits evenly between them.
        let array = arr2(&[INPUT_ROW; 2]);
        let expected = Array::from_elem((2, 7), 0.5_f32);

        assert_all_close(&array.softmax(Axis(0)), &expected);
    }

    #[test]
    fn softmax_is_shift_invariant() {
        let array = arr1(&INPUT_ROW);
        let shifted = array.mapv(|v| v + 10.0);

        assert_all_close(&shifted.softmax(Axis(0)), &array.softmax(Axis(0)));
    }
}