use crate::activation::Function;
use crate::linalg::{Matrix, Vector};
use crate::Float;
use rayon::prelude::*;
/// Row-wise softmax activation.
///
/// Each row `x` of the input is mapped to `exp(x_i) / sum_j exp(x_j)`, so every
/// output row is non-negative and sums to one (up to floating-point rounding).
#[derive(Debug, Clone, Copy, Default)]
pub struct SoftMax;

impl SoftMax {
    /// Creates a softmax activation. Equivalent to [`SoftMax::default`].
    pub fn new() -> Self {
        Self
    }

    /// Applies softmax to a single row vector.
    ///
    /// If `max_val` yields no maximum (presumably only for an empty vector), the
    /// vector is returned unchanged instead of panicking.
    fn vec_fun<T: Float>(&self, vector: Vector<T>) -> Vector<T> {
        let max = match vector.max_val() {
            Some(max) => max,
            // Nothing to normalise: the softmax of an empty row is the empty row.
            None => return vector,
        };
        // Shifting every input by the row maximum keeps all exponents <= 0, so
        // `exp` cannot overflow. Softmax is invariant under adding a constant to
        // all inputs, so the result is unchanged mathematically.
        let exps = vector.map_vec(|x| (x - max).exp());
        let sum: T = exps.sum_all();
        exps.map_vec(|x| x / sum)
    }
}
impl<T: Float> Function<T> for SoftMax {
    fn name(&self) -> String {
        String::from("SoftMax")
    }

    /// Applies softmax independently to every row of `matrix`.
    ///
    /// Rows are processed in parallel with rayon.
    fn call(&self, matrix: Matrix<T>) -> Matrix<T> {
        let data: Vec<Vector<T>> = (0..matrix.rows)
            .into_par_iter()
            .map(|i| self.vec_fun(matrix.get_row(i)))
            .collect();
        Matrix::from(data)
    }

    /// Returns, element-wise, `s_i * (1 - s_i)` where `s = softmax(row)`.
    ///
    /// This is only the diagonal of the softmax Jacobian
    /// `ds_i/dx_j = s_i * (delta_ij - s_j)`. The off-diagonal terms `-s_i * s_j`
    /// are not represented, because the result has the same shape as the input.
    fn derivative(&self, matrix: Matrix<T>) -> Matrix<T> {
        // Compute each softmax row and its derivative in one pass. This avoids
        // cloning `matrix` and building an intermediate softmax matrix.
        let grad_rows: Vec<Vector<T>> = (0..matrix.rows)
            .into_par_iter()
            .map(|i| {
                self.vec_fun(matrix.get_row(i))
                    .map_vec(|s| s * (T::one() - s))
            })
            .collect();
        Matrix::from(grad_rows)
    }
}
#[cfg(test)]
mod tests {
    use crate::activation::{Function, SoftMax};
    use crate::linalg::Matrix;
    use crate::matrix;

    /// sigma(2) = e^2 / (1 + e^2): the larger softmax entry of any row `[a, a + 2]`.
    const SIGMOID_2: f64 = 0.880_797_077_977_882_3;
    /// sigma'(2) = sigma(2) * (1 - sigma(2)): every derivative entry of a row `[a, a + 2]`.
    const SIGMOID_PRIME_2: f64 = 0.104_993_585_403_506_5;

    #[test]
    fn softmax_call() {
        let matrix: Matrix<f64> = matrix![[2.0, 4.0], [1.0, 3.0]];
        let a = SoftMax::new();
        let matrix = a.call(matrix);
        println!("{}", matrix);
        for i in 0..matrix.rows {
            let row = matrix.get_row(i);
            let total: f64 = row.sum_all();
            assert!((total - 1.0).abs() < 1e-9, "row {} sums to {}", i, total);
            let largest: f64 = row.max_val().unwrap();
            assert!((largest - SIGMOID_2).abs() < 1e-9, "row {} max is {}", i, largest);
        }
    }

    #[test]
    fn der_softmax() {
        let matrix: Matrix<f64> = matrix![[2.0, 4.0], [1.0, 3.0]];
        let a = SoftMax::new();
        let matrix = a.derivative(matrix);
        println!("{}", matrix);
        for i in 0..matrix.rows {
            let row = matrix.get_row(i);
            // Both entries equal s * (1 - s) = sigma'(2), because s_0 = 1 - s_1.
            let total: f64 = row.sum_all();
            assert!((total - 2.0 * SIGMOID_PRIME_2).abs() < 1e-9, "row {} sums to {}", i, total);
            let largest: f64 = row.max_val().unwrap();
            assert!((largest - SIGMOID_PRIME_2).abs() < 1e-9, "row {} max is {}", i, largest);
        }
    }

    #[test]
    fn softmax() {
        let matrix: Matrix<f32> = matrix![[0.9, 0.1, 0.8, 0.2]];
        let softmax = SoftMax::new();

        let activated = softmax.call(matrix.clone());
        println!("{}", activated);
        let row = activated.get_row(0);
        let total: f32 = row.sum_all();
        assert!((total - 1.0).abs() < 1e-5, "softmax row sums to {}", total);
        // exp(0) / (1 + e^-0.8 + e^-0.1 + e^-0.7) for the 0.9 entry.
        let largest: f32 = row.max_val().unwrap();
        assert!((largest - 0.350_785).abs() < 1e-5, "softmax max is {}", largest);

        let grad = softmax.derivative(matrix);
        println!("{}", grad);
        // The largest s * (1 - s) comes from the entry closest to 0.5 (s = 0.350785).
        let largest: f32 = grad.get_row(0).max_val().unwrap();
        assert!((largest - 0.227_735).abs() < 1e-5, "derivative max is {}", largest);
    }
}