//! yarnn 0.1.0
//!
//! Yet Another Rust Neural Network framework
//! Documentation
use crate::loss::Loss;
use crate::backend::{Backend, BackendMse};
use crate::tensor::Tensor;

use core::marker::PhantomData;


/// Mean squared error loss, generic over the numeric type `N` and the
/// backend type `B`.
///
/// The type carries no runtime state: its only field is a zero-sized marker
/// that ties the two type parameters to the struct.
pub struct MeanSquareErrorLoss<N, B> {
    /// Marker binding `N` and `B` without storing a value of either type.
    _m: PhantomData<fn(N, B)>,
}

impl<N, B> MeanSquareErrorLoss<N, B> {
    pub fn new() -> Self {
        Self {
            _m: Default::default()
        }
    }
}

/// Mean squared error implemented on top of the backend's
/// `scaled_square_diff` / `scaled_diff` kernels.
///
/// Both methods read the size of `pred`'s first dimension and use it as the
/// batch size for normalisation.
impl<N, B> Loss<N, B> for MeanSquareErrorLoss<N, B>
    where B: Backend<N> + BackendMse<N>
{
    /// Writes the loss for `pred` against `target` into `dst`.
    ///
    /// The scale handed to the backend is `0.5 / batch_size`, so that the
    /// value produced here is the antiderivative of what [`derivative`]
    /// produces: d/dp [ (0.5 / n) * (p - t)^2 ] = (1 / n) * (p - t).
    ///
    /// NOTE(review): this assumes `scaled_square_diff` multiplies the squared
    /// difference by the given scalar — confirm against the backend.
    ///
    /// [`derivative`]: Loss::derivative
    fn compute(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
        let batch_size = pred.shape().get(0) as f32;
        // Divide (not multiply) by the batch size, matching `derivative`.
        let scale = backend.scalar_f32(0.5 / batch_size);

        backend.scaled_square_diff(dst, target, pred, scale);
    }

    /// Writes the gradient of the loss with respect to `pred` into `dst`.
    ///
    /// The backend is asked for the difference of `pred` and `target`
    /// scaled by `1 / batch_size`.
    fn derivative(&self, backend: &B, dst: &mut B::Tensor, pred: &B::Tensor, target: &B::Tensor) {
        let batch_size = pred.shape().get(0) as f32;
        let scale = backend.scalar_f32(1.0 / batch_size);

        backend.scaled_diff(dst, pred, target, scale);
    }
}