use crate::core::{error::BellandeError, tensor::Tensor};
use crate::loss::bce::Reduction;
/// A user-defined loss function that [`CustomLoss`] can wrap.
///
/// Implementors receive the model output and the expected values and produce
/// a loss tensor from them.
pub trait CustomLossFunction {
/// Computes the loss of `prediction` with respect to `target`.
///
/// Returns the loss as a `Tensor`, or a `BellandeError` if it cannot be computed.
///
/// NOTE(review): presumably the returned tensor is unreduced (one value per
/// element), since `CustomLoss::forward` applies the reduction afterwards —
/// confirm with implementors.
fn compute(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError>;
}
/// Pairs a user-supplied [`CustomLossFunction`] with the [`Reduction`] that is
/// applied to its output.
pub struct CustomLoss {
/// Boxed trait object that computes the raw loss tensor.
loss_fn: Box<dyn CustomLossFunction>,
/// How `forward` collapses the raw loss (`None`, `Mean` or `Sum`).
reduction: Reduction,
}
impl CustomLoss {
    /// Creates a loss wrapper from a boxed loss function and the reduction to
    /// apply to that function's output.
    pub fn new(loss_fn: Box<dyn CustomLossFunction>, reduction: Reduction) -> Self {
        Self { loss_fn, reduction }
    }

    /// Evaluates the wrapped loss function and applies the configured reduction.
    ///
    /// * `Reduction::None` — the tensor returned by the loss function is passed
    ///   through unchanged.
    /// * `Reduction::Sum` — returns a one-element tensor (shape `[1]`) holding
    ///   the sum of all values in the loss tensor's `data`.
    /// * `Reduction::Mean` — returns a one-element tensor (shape `[1]`) holding
    ///   that sum divided by the number of values. If `data` is empty this is
    ///   `0.0 / 0.0`, i.e. NaN.
    ///
    /// Reduced tensors reuse the `device` and `dtype` of the loss tensor.
    ///
    /// # Errors
    ///
    /// Propagates any `BellandeError` returned by the wrapped `compute` call.
    pub fn forward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError> {
        let raw = self.loss_fn.compute(prediction, target)?;

        // Collapse the raw loss to a single scalar; the unreduced case has
        // nothing to collapse and hands the tensor straight back.
        let scalar = match self.reduction {
            Reduction::None => return Ok(raw),
            Reduction::Sum => raw.data.iter().sum::<f32>(),
            Reduction::Mean => {
                let total: f32 = raw.data.iter().sum();
                total / raw.data.len() as f32
            }
        };

        // The third argument is always `true` here (presumably the
        // requires-grad flag of `Tensor::new` — confirm against its definition).
        Ok(Tensor::new(
            vec![scalar],
            vec![1],
            true,
            raw.device,
            raw.dtype,
        ))
    }
}