ferrite-dl 0.2.0

Deep learning library written in pure Rust
Documentation
use super::loss::*;
use crate::tensor::*;

pub struct MSELoss {
  is_mean_reduction: bool
}

impl MSELoss {
  pub fn new(reduction: &str) -> Self {
    let is_mean_reduction = match reduction {
      "mean" => true,
      "sum" => false,
      _ => panic!("Reduction must be either 'mean' or 'sum'"),
    };

    Self{ is_mean_reduction }
  }
}

impl LossTrait for MSELoss {
  fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
    let z_1 = x.sub_tensor(y);
    let z_2 = z_1.pow_f32(2.); 

    if self.is_mean_reduction {
      z_2.mean()
    } else {
      z_2.sum()
    }
  }
}


pub struct MAELoss {
  is_mean_reduction: bool
}

impl MAELoss {
  pub fn new(reduction: &str) -> Self {
    let is_mean_reduction = match reduction {
      "mean" => true,
      "sum" => false,
      _ => panic!("Reduction must be either 'mean' or 'sum'"),
    };

    Self{ is_mean_reduction }
  }
}

impl LossTrait for MAELoss {
  fn loss(&self, x: &Tensor, y: &Tensor) -> Tensor {
    let z_1 = x.sub_tensor(y);
    let z_2 = z_1.abs(); 

    if self.is_mean_reduction {
      z_2.mean()
    } else {
      z_2.sum()
    }
  }
}