#![warn(missing_docs)]
use std::{error::Error, str::FromStr};
use crate::{Tensor, TensorElement};
/// Interface for loss functions over [`Tensor`] values with element type `T`.
///
/// `T` must be a [`TensorElement`] whose `FromStr` error type implements
/// [`Error`]. Both methods are associated functions (they take no `self`
/// receiver), so they are called on the implementing type itself.
pub trait Loss<'a, T>
where
T: TensorElement,
<T as FromStr>::Err: Error,
{
/// Reduces the pair of tensors `y` and `y_hat` to a single loss value of type `T`.
///
/// NOTE(review): presumably `y` is the target and `y_hat` the prediction;
/// the trait itself does not fix this, so confirm against implementors.
fn loss(y: &Tensor<'a, T>, y_hat: &Tensor<'a, T>) -> T;
/// Backward-pass hook; takes no arguments and returns nothing.
///
/// NOTE(review): the signature gives no access to inputs or gradients, so
/// what an implementation is expected to do here is unspecified; TODO confirm.
fn backward();
}
/// Placeholder for a loss over `y` and `y_hat` that returns a single `T`.
///
/// NOTE(review): the name suggests binary cross-entropy ("Entroypy" looks like
/// a typo for "Entropy"), and it does not follow `snake_case`; renaming would
/// break callers, so confirm before changing.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!()`.
pub fn BCEntroypyLoss<'a, T>(y: &Tensor<'a, T>, y_hat: &Tensor<'a, T>) -> T
where
T: TensorElement,
<T as FromStr>::Err: Error,
{
// Not implemented yet; any call aborts with the standard `unimplemented!` panic.
unimplemented!()
}
/// Placeholder for a loss over `y` and `y_hat` that returns a single `T`.
///
/// NOTE(review): the name suggests (categorical) cross-entropy ("Entroypy"
/// looks like a typo for "Entropy"), and it does not follow `snake_case`;
/// renaming would break callers, so confirm before changing.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!()`.
pub fn CEntroypyLoss<'a, T>(y: &Tensor<'a, T>, y_hat: &Tensor<'a, T>) -> T
where
T: TensorElement,
<T as FromStr>::Err: Error,
{
// Not implemented yet; any call aborts with the standard `unimplemented!` panic.
unimplemented!()
}
/// Computes the L1 loss of `y` against `y_hat` as a single value of type `T`.
///
/// The result is `cumsum()` applied to the tensor produced by
/// `y.sub_abs(y_hat)`.
///
/// NOTE(review): going by the method names this is the sum of element-wise
/// absolute differences; the exact semantics of `sub_abs` and `cumsum` live in
/// `Tensor` and should be confirmed there.
///
/// # Panics
///
/// Panics if `y.sub_abs(y_hat)` does not yield a value, because the result is
/// `unwrap()`ped (presumably on incompatible tensor shapes; TODO confirm).
pub fn L1Loss<'a, T>(y: &Tensor<'a, T>, y_hat: &Tensor<'a, T>) -> T
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    // Keep the fallible step separate from the reduction so the panic site is obvious.
    let abs_residual = y.sub_abs(y_hat);
    abs_residual.unwrap().cumsum()
}
/// Computes the L2 loss of `y` against `y_hat` as a single value of type `T`.
///
/// The result is `cumsum()` applied to `pow(2)` of the tensor produced by
/// `y.sub(y_hat)`.
///
/// NOTE(review): going by the method names this is the sum of squared
/// element-wise differences (no square root and no mean are applied here); the
/// exact semantics of `sub`, `pow` and `cumsum` live in `Tensor` and should be
/// confirmed there.
///
/// # Panics
///
/// Panics if `y.sub(y_hat)` does not yield a value, because the result is
/// `unwrap()`ped (presumably on incompatible tensor shapes; TODO confirm).
pub fn L2Loss<'a, T>(y: &Tensor<'a, T>, y_hat: &Tensor<'a, T>) -> T
where
    T: TensorElement,
    <T as FromStr>::Err: Error,
{
    // Keep the fallible step separate from the square-and-reduce chain.
    let residual = y.sub(y_hat);
    residual.unwrap().pow(2).cumsum()
}