mcai_onnxruntime/tensor/ndarray_tensor.rs

//! Module containing a tensor trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)

use ndarray::{Array, ArrayBase};

/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
/// with useful tensor operations.
///
/// # Generics
///
/// The trait is generic over:
/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
/// * `T`: The element type contained in the tensor (for example `f32`)
/// * `D`: The tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
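///
/// For example, an owned two-dimensional `f32` tensor (`ndarray::Array2<f32>`, an alias for
/// `ArrayBase<OwnedRepr<f32>, Ix2>`) implements `NdArrayTensor<ndarray::OwnedRepr<f32>, f32, ndarray::Ix2>`.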
pub trait NdArrayTensor<S, T, D> {
    /// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis.
    ///
    /// # Trait Bounds
    ///
    /// The function is generic and thus has some trait bounds:
    /// * `D: ndarray::RemoveAxis`: Summing over an axis reduces the dimension of the tensor,
    ///   so a softmax cannot be calculated for a 0-D tensor.
    /// * `S: ndarray::Data<Elem = T>`: The storage of the tensor can be an owned
    ///   array ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an array view
    ///   ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)); `ndarray::Data`
    ///   already guarantees the elements are `Clone`.
    /// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must behave
    ///   as floats and must support the `-=` and `/=` operations.
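    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `mcai_onnxruntime::tensor::ndarray_tensor` import path is
    /// assumed from this file's location in the crate.
    ///
    /// ```
    /// use mcai_onnxruntime::tensor::ndarray_tensor::NdArrayTensor;
    /// use ndarray::{arr2, Axis};
    ///
    /// let logits = arr2(&[[1.0_f32, 2.0, 3.0], [3.0, 2.0, 1.0]]);
    /// // Normalize each row (axis 1) into a probability distribution.
    /// let probabilities = logits.softmax(Axis(1));
    ///
    /// assert_eq!(probabilities.shape(), logits.shape());
    /// // Each row of the result sums to 1.
    /// assert!(probabilities
    ///     .sum_axis(Axis(1))
    ///     .iter()
    ///     .all(|sum| (sum - 1.0).abs() < 1.0e-6));
    /// ```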
    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
    where
        D: ndarray::RemoveAxis,
        S: ndarray::Data<Elem = T>,
        T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
}

impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::Data<Elem = T>,
    T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
{
    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
        let mut new_array: Array<T, D> = self.to_owned();
        // Use the overflow-safe formulation: softmax(x) == softmax(x - max(x)),
        // so `exp` is only applied to values <= 0 and cannot overflow.
        // In numpy pseudocode:
        //   e = np.exp(a - np.max(a, axis=axis, keepdims=True))
        //   softmax = e / np.sum(e, axis=axis, keepdims=True)
        let max = new_array
            .fold_axis(axis, T::neg_infinity(), |&max, &v| max.max(v))
            .insert_axis(axis);
        new_array -= &max;
        new_array.map_inplace(|v| *v = v.exp());
        let sum = new_array.sum_axis(axis).insert_axis(axis);
        new_array /= &sum;

        new_array
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, arr3};
    use test_env_log::test;

    #[test]
    fn softmax_1d() {
        let array = arr1(&[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]);

        let expected_softmax = arr1(&[
            0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
        ]);

        let softmax = array.softmax(ndarray::Axis(0));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;

        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }

    #[test]
    fn softmax_2d() {
        let array = arr2(&[
            [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
        ]);

        let expected_softmax = arr2(&[
            [
                0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
            ],
            [
                0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
            ],
        ]);

        let softmax = array.softmax(ndarray::Axis(1));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;

        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }

    #[test]
    fn softmax_3d() {
        let array = arr3(&[
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
            [
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
                [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0],
            ],
        ]);

        let expected_softmax = arr3(&[
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
            [
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
                [
                    0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
                ],
            ],
        ]);

        let softmax = array.softmax(ndarray::Axis(2));

        assert_eq!(softmax.shape(), expected_softmax.shape());

        let diff = softmax - expected_softmax;

        assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
    }
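
    // A hypothetical regression test (not part of the original suite) for the
    // overflow-safe formulation used in `softmax` above: with the naive
    // `exp(x) / sum(exp(x))` formula, `exp(1000.0_f32)` overflows to infinity
    // and the result degenerates to NaN.
    #[test]
    fn softmax_large_values_do_not_overflow() {
        let array = arr1(&[1000.0_f32, 1000.0, 1000.0]);

        let softmax = array.softmax(ndarray::Axis(0));

        // Every entry must be finite and the distribution must sum to 1.
        assert!(softmax.iter().all(|v| v.is_finite()));
        assert!((softmax.sum() - 1.0).abs() < 1.0e-6);
    }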
}