use ndarray::{Array, ArrayBase};
/// Tensor-style operations on `ndarray` arrays.
///
/// The type parameters mirror those of [`ArrayBase`]: `S` is the storage
/// type, `T` the element type (`S::Elem`), and `D` the dimension type.
pub trait NdArrayTensor<S, T, D> {
/// Computes the softmax of `self` along `axis` and returns it as a new,
/// owned array with the same dimension type as `self`.
///
/// Every lane along `axis` is mapped to `exp(x_i) / sum_j exp(x_j)`, so the
/// values within each lane sum to 1 (up to floating-point rounding).
/// `self` is not modified.
///
/// # Panics
///
/// Expected to panic if `axis` is out of bounds for `self`. The `ArrayBase`
/// implementation relies on ndarray's axis reductions, which panic in that case.
fn softmax(&self, axis: ndarray::Axis) -> Array<T, D>
where
D: ndarray::RemoveAxis,
S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
<S as ndarray::RawData>::Elem: std::clone::Clone,
T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign;
}
impl<S, T, D> NdArrayTensor<S, T, D> for ArrayBase<S, D>
where
    D: ndarray::RemoveAxis,
    S: ndarray::RawData + ndarray::Data + ndarray::RawData<Elem = T>,
    <S as ndarray::RawData>::Elem: std::clone::Clone,
    T: ndarray::NdFloat + std::ops::SubAssign + std::ops::DivAssign,
{
    /// Softmax of `self` along `axis`, returned as a new owned array of the
    /// same shape.
    ///
    /// Each lane along `axis` is first shifted by its own maximum, then
    /// exponentiated and divided by the lane's sum of exponentials. Softmax is
    /// invariant to adding a constant to every element of a lane, so the shift
    /// does not change the mathematical result. It does keep every exponent
    /// `<= 0`, so large inputs (e.g. `100.0_f32`) give a finite distribution.
    /// Without it, `exp` would overflow to infinity and the division would
    /// yield NaN.
    ///
    /// # Panics
    ///
    /// Panics if `axis` is out of bounds for `self`.
    fn softmax(&self, axis: ndarray::Axis) -> Array<T, D> {
        // Per-lane maximum. The reduced axis is re-inserted with length 1 so
        // the result broadcasts against `self` in the `-=` below.
        let lane_max = self
            .fold_axis(axis, T::neg_infinity(), |&acc, &x| acc.max(x))
            .insert_axis(axis);

        let mut new_array: Array<T, D> = self.to_owned();
        new_array -= &lane_max;
        new_array.mapv_inplace(|v| v.exp());

        // Same length-1 re-insertion trick, so each lane is divided by its own sum.
        let sum = new_array.sum_axis(axis).insert_axis(axis);
        new_array /= &sum;
        new_array
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{arr1, arr2, arr3};
    use test_env_log::test;

    /// Input row shared by the 1-D, 2-D and 3-D cases.
    const INPUT_ROW: [f32; 7] = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0];

    /// Softmax of `INPUT_ROW`.
    const EXPECTED_ROW: [f32; 7] = [
        0.02364054, 0.06426166, 0.1746813, 0.474833, 0.02364054, 0.06426166, 0.1746813,
    ];

    /// Absolute tolerance for the f32 comparisons. It is roughly 8 f32
    /// epsilons, so the tests do not depend on the last-bit accuracy of the
    /// platform's `exp` or on the exact evaluation order of the softmax.
    const TOLERANCE: f32 = 1.0e-6;

    /// Asserts that `actual` and `expected` have the same shape and that every
    /// pair of corresponding elements differs by less than `TOLERANCE`.
    fn assert_all_close<D: ndarray::Dimension>(
        actual: &ndarray::Array<f32, D>,
        expected: &ndarray::Array<f32, D>,
    ) {
        assert_eq!(actual.shape(), expected.shape());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < TOLERANCE, "got {}, expected {}", a, e);
        }
    }

    #[test]
    fn softmax_1d() {
        let array = arr1(&INPUT_ROW);
        let softmax = array.softmax(ndarray::Axis(0));
        assert_all_close(&softmax, &arr1(&EXPECTED_ROW));
    }

    #[test]
    fn softmax_2d() {
        let array = arr2(&[INPUT_ROW; 2]);
        let softmax = array.softmax(ndarray::Axis(1));
        assert_all_close(&softmax, &arr2(&[EXPECTED_ROW; 2]));
    }

    #[test]
    fn softmax_3d() {
        let array = arr3(&[[INPUT_ROW; 2]; 3]);
        let softmax = array.softmax(ndarray::Axis(2));
        assert_all_close(&softmax, &arr3(&[[EXPECTED_ROW; 2]; 3]));
    }

    #[test]
    fn softmax_2d_axis0() {
        // The columns differ, so this fails if normalisation happens along the
        // wrong axis. Identical rows, as in the tests above, cannot tell.
        let array = arr2(&[[1.0_f32, 2.0], [2.0, 4.0]]);
        let expected = arr2(&[[0.2689414_f32, 0.1192029], [0.7310586, 0.880797]]);
        let softmax = array.softmax(ndarray::Axis(0));
        assert_all_close(&softmax, &expected);
        assert_all_close(&softmax.sum_axis(ndarray::Axis(0)), &arr1(&[1.0_f32, 1.0]));
    }
}