//! Helper traits to extend [`ndarray`] functionality.

use ndarray::{Array, ArrayBase};

/// Trait extending [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)
/// with useful tensor operations.
///
/// # Generics
///
/// The trait is generic over:
/// * `S`: [`ndarray::ArrayBase`](https://docs.rs/ndarray/latest/ndarray/struct.ArrayBase.html)'s data container
/// * `T`: Type contained inside the tensor (for example `f32`)
/// * `D`: Tensor's dimension ([`ndarray::Dimension`](https://docs.rs/ndarray/latest/ndarray/trait.Dimension.html))
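///
/// For example, for an owned `ndarray::Array2<f32>` these parameters resolve to
/// `S = ndarray::OwnedRepr<f32>`, `T = f32` and `D = ndarray::Ix2`.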
pub trait ArrayExtensions<S, T, D> {
	/// Calculate the [softmax](https://en.wikipedia.org/wiki/Softmax_function) of the tensor along a given axis.
	///
	/// # Trait Bounds
	///
	/// The function is generic and thus has some trait bounds:
	/// * `D: ndarray::RemoveAxis`: Reducing over an axis (taking the maximum and the sum) lowers the dimension of
	///   the tensor, so a 0-D tensor cannot have a softmax calculated.
	/// * `S: ndarray::Data<Elem = T>`: The storage of the tensor can be an owned array
	///   ([`ndarray::Array`](https://docs.rs/ndarray/latest/ndarray/type.Array.html)) or an array view
	///   ([`ndarray::ArrayView`](https://docs.rs/ndarray/latest/ndarray/type.ArrayView.html)).
	/// * `T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign`: The elements of the tensor must behave
	///   like floats and must support the `-=` and `/=` operations.
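	///
	/// # Example
	///
	/// A minimal usage sketch, mirroring the 1-D test at the bottom of this module. The block is
	/// marked `ignore` because the `use` path for this trait depends on where the module sits in
	/// your crate:
	///
	/// ```ignore
	/// use ndarray::arr1;
	///
	/// let array = arr1(&[1.0_f32, 2.0, 3.0]);
	/// let softmax = array.softmax(ndarray::Axis(0));
	///
	/// // Softmax yields positive weights that sum to 1 along the reduced axis.
	/// assert!((softmax.sum() - 1.0).abs() < 1.0e-6);
	/// ```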
	fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
	where
		D: ndarray::RemoveAxis,
		S: ndarray::Data<Elem = T>,
		T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
}

impl<S, T, D> ArrayExtensions<S, T, D> for ArrayBase<S, D>
where
	D: ndarray::RemoveAxis,
	S: ndarray::Data<Elem = T>,
	T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign
{
	fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
		let mut new_array: Array<T, D> = self.to_owned();
		// Softmax is shift-invariant, so subtract the per-axis maximum before exponentiating;
		// this leaves the result unchanged while preventing `exp` from overflowing on large
		// inputs: softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
		let max = new_array.fold_axis(axis, T::neg_infinity(), |&acc, &v| acc.max(v)).insert_axis(axis);
		new_array -= &max;
		new_array.map_inplace(|v| *v = v.exp());
		let sum = new_array.sum_axis(axis).insert_axis(axis);
		new_array /= &sum;

		new_array
	}
}

#[cfg(test)]
mod tests {
	use ndarray::{arr1, arr2, arr3};
	use test_log::test;

	use super::*;

	#[test]
	fn softmax_1d() {
		let array = arr1(&[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]);

		let expected_softmax = arr1(&[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]);

		let softmax = array.softmax(ndarray::Axis(0));

		assert_eq!(softmax.shape(), expected_softmax.shape());

		let diff = softmax - expected_softmax;

		assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
	}

	#[test]
	fn softmax_2d() {
		let array = arr2(&[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]]);

		let expected_softmax = arr2(&[
			[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
			[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
		]);

		let softmax = array.softmax(ndarray::Axis(1));

		assert_eq!(softmax.shape(), expected_softmax.shape());

		let diff = softmax - expected_softmax;

		assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
	}

	#[test]
	fn softmax_3d() {
		let array = arr3(&[
			[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]],
			[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]],
			[[1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0], [1.0_f32, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]]
		]);

		let expected_softmax = arr3(&[
			[
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
			],
			[
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
			],
			[
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813],
				[0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813]
			]
		]);

		let softmax = array.softmax(ndarray::Axis(2));

		assert_eq!(softmax.shape(), expected_softmax.shape());

		let diff = softmax - expected_softmax;

		assert!(diff.iter().all(|d| d.abs() < 1.0e-7));
	}
}