// tensorrs 0.3.2
//
// Tensors is a lightweight machine learning library in Rust.
// Documentation
use crate::linalg::Matrix;
use crate::loss::Loss;
use crate::Float;

/// Binary cross-entropy (BCE) loss.
///
/// Wraps a single value of the floating-point type `T`. The loss methods
/// implemented in this file never read that value, so it presumably exists
/// only to pin down `T` at construction time.
pub struct BCE<T>(T)
where
    T: Float;

impl<T: Float> BCE<T> {
    pub fn new(_data_type: T) -> Self {
        Self(_data_type)
    }
}

impl<T: Float> Loss<T> for BCE<T> {
    /// Returns the mean binary cross-entropy between `output` and `target`:
    ///
    /// `-(1/n) * sum(y * ln(p) + (1 - y) * ln(1 - p))`
    ///
    /// where `p` is a prediction from `output`, `y` the matching entry of
    /// `target`, and `n` the number of elements. Predictions are clamped to
    /// `[epsilon, 1 - epsilon]` before the logarithm is taken so that `ln(0)`
    /// is never evaluated. An empty `output` yields `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `output` and `target` do not have the same shape.
    fn call(&self, output: &Matrix<T>, target: &Matrix<T>) -> T {
        if output.shape() != target.shape() {
            panic!("!!!Size of output matrix and target must be equal!!!\nOutput size:{:?} Target size: {:?}", output.shape(), target.shape())
        }
        let n = output.data.len();
        if n == 0 {
            return T::default();
        }

        // NOTE(review): `f32_f64` presumably picks the first value when `T` is
        // f32 and the second when `T` is f64 — confirm against `Float`.
        let epsilon = T::f32_f64(1e-7, 1e-15);
        let output_clamped = output.max(epsilon).min(T::one() - epsilon);

        // NOTE(review): `&` on matrices is assumed to be the element-wise
        // product — confirm against `Matrix`'s operator impls.
        // a = y * ln(p), b = (1 - y) * ln(1 - p)
        let a = target & &output_clamped.ln();
        let b = target.map(|x| T::one() - x) & &output_clamped.map(|z| T::one() - z).ln();
        let loss = -(a + &b);
        loss.sum() / T::from_usize(n)
    }

    /// Returns the derivative of the mean loss computed by `call` with
    /// respect to every element of `output`:
    ///
    /// `((1 - y) / (1 - p) - y / p) / n`
    ///
    /// Predictions are clamped to `[epsilon, 1 - epsilon]` exactly as in
    /// `call`, so a prediction of exactly 0 or 1 cannot cause a division by
    /// zero.
    ///
    /// # Panics
    ///
    /// Panics if `output` and `target` do not have the same shape.
    fn gradient(&self, output: &Matrix<T>, target: &Matrix<T>) -> Matrix<T> {
        if output.shape() != target.shape() {
            panic!("!!!Size of output matrix and target must be equal!!!\nOutput size:{:?} Target size: {:?}", output.shape(), target.shape())
        }

        // Same clamp as in `call`, keeping both denominators away from zero.
        let epsilon = T::f32_f64(1e-7, 1e-15);
        let output_clamped = output.max(epsilon).min(T::one() - epsilon);

        // d/dp of -(y ln p + (1 - y) ln(1 - p)) is (1 - y)/(1 - p) - y/p.
        // The sign matters: this must be the gradient of the *loss* so that a
        // descent step lowers the value returned by `call`.
        let grads = output_clamped.zip_with(target, |y_pred, y_true| {
            (T::one() - y_true) / (T::one() - y_pred) - y_true / y_pred
        });
        // `call` averages over all elements, so the gradient is scaled by 1/n.
        let n = grads.data.len();
        grads * (T::one() / T::from_usize(n))
    }
}

#[cfg(test)]
mod tests {
    use crate::linalg::Matrix;
    use crate::loss::{Loss, BCE};
    use crate::{matrix, DataType};

    /// Predicting 0.5 for every element gives a mean loss of ln(2).
    #[test]
    fn call() {
        let a = matrix![[1.0, 0.0]];
        let b = matrix![[0.5, 0.5]];

        let bce = BCE::new(DataType::f64());
        let loss = bce.call(&b, &a);
        println!("{}", loss);
        assert!(
            (loss - std::f64::consts::LN_2).abs() < 1e-12,
            "expected ln(2), got {}",
            loss
        );
    }

    /// The gradient has one entry per prediction, i.e. the same shape as the
    /// output it was computed for.
    #[test]
    fn grad() {
        let a = matrix![[1.0, 0.0]];
        let b = matrix![[0.5, 0.5]];

        let bce = BCE::new(DataType::f64());
        let grads = bce.gradient(&b, &a);
        println!("{}", grads);
        assert_eq!(grads.shape(), b.shape());
    }

    /// End-to-end check on f32 data with confident, correct predictions.
    #[test]
    fn help() {
        let tar = matrix![[1.0, 0.0, 1.0, 0.0]];
        let out = matrix![[0.9, 0.1, 0.8, 0.2]];
        let bce = BCE::new(DataType::f32());
        let loss = bce.call(&out, &tar);
        println!("{}", loss);
        // -(2 * ln(0.9) + 2 * ln(0.8)) / 4 ≈ 0.164252
        assert!(
            (loss - 0.164_252_f32).abs() < 1e-5,
            "expected ~0.164252, got {}",
            loss
        );

        let grads = bce.gradient(&out, &tar);
        println!("{}", grads);
        assert_eq!(grads.shape(), out.shape());
    }
}