use crate::core::{error::BellandeError, tensor::Tensor};
pub mod bce;
pub mod cross_entropy;
pub mod custom;
pub mod mse;
/// Common interface for loss functions.
///
/// Implementors must be `Send + Sync` (required by the supertrait bounds).
pub trait Loss: Send + Sync {
    /// Forward pass: evaluates the loss for `output` against `target`.
    ///
    /// # Errors
    ///
    /// Returns a [`BellandeError`] when the implementor cannot evaluate the
    /// loss; the exact failure conditions are defined by each implementation.
    fn forward(&self, output: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError>;

    /// Backward pass for the same `output` / `target` pair.
    ///
    /// NOTE(review): presumably returns the gradient of the loss with respect
    /// to `output`; the precise semantics are left to implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`BellandeError`] when the implementor cannot complete the
    /// backward computation.
    fn backward(&self, output: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError>;

    /// Human-readable name of the loss.
    ///
    /// The default implementation returns `"GenericLoss"`.
    fn name(&self) -> &str {
        "GenericLoss"
    }

    /// Reduction mode associated with this loss.
    ///
    /// The default implementation returns [`Reduction::Mean`].
    fn reduction(&self) -> Reduction {
        Reduction::Mean
    }
}
/// Strategy for collapsing an element-wise loss tensor.
///
/// Besides `PartialEq`, the enum derives `Eq` and `Hash`: it has no fields,
/// so equality is total, and the extra derives allow it to be used as a key
/// in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reduction {
    /// Leave the loss tensor unreduced.
    None,
    /// Average all loss elements into a single value.
    Mean,
    /// Add all loss elements into a single value.
    Sum,
}
/// Marker trait for [`Loss`] implementations that are also `'static`.
///
/// It declares no methods of its own.
pub trait StaticLoss: Loss + 'static {}

/// Blanket implementation: every `'static` type implementing [`Loss`] is
/// automatically a [`StaticLoss`].
impl<T> StaticLoss for T where T: Loss + 'static {}
/// Construction interface for [`Loss`] implementations.
pub trait LossInit: Loss {
    /// Builds the loss without any caller-supplied configuration.
    fn new() -> Self;

    /// Builds the loss with an explicitly chosen [`Reduction`].
    fn new_with_reduction(reduction: Reduction) -> Self;
}
/// Extension of [`Loss`] whose passes take an additional `weights` tensor.
///
/// NOTE(review): the required shape of `weights` (per-element, per-sample,
/// ...) is not fixed by this trait; confirm against each implementation.
pub trait WeightedLoss: Loss {
    /// Forward pass that additionally receives `weights`.
    ///
    /// # Errors
    ///
    /// Returns a [`BellandeError`] when the implementor cannot evaluate the
    /// weighted loss.
    fn forward_weighted(
        &self,
        output: &Tensor,
        target: &Tensor,
        weights: &Tensor,
    ) -> Result<Tensor, BellandeError>;

    /// Backward pass that additionally receives `weights`.
    ///
    /// # Errors
    ///
    /// Returns a [`BellandeError`] when the implementor cannot complete the
    /// weighted backward computation.
    fn backward_weighted(
        &self,
        output: &Tensor,
        target: &Tensor,
        weights: &Tensor,
    ) -> Result<Tensor, BellandeError>;
}
/// Extension of [`Loss`] that can hold a tensor of class weights.
pub trait ClassWeightedLoss: Loss {
    /// Stores `weights` as the class weights, taking ownership of the tensor.
    ///
    /// # Errors
    ///
    /// Returns a [`BellandeError`] when the implementor rejects the weights;
    /// the validation rules are implementation-defined.
    fn set_class_weights(&mut self, weights: Tensor) -> Result<(), BellandeError>;

    /// Borrows the class weights, or returns `None` when there are none to
    /// return.
    fn get_class_weights(&self) -> Option<&Tensor>;
}
pub mod utils {
use super::*;
/// Checks that `output` and `target` have identical shapes.
///
/// # Errors
///
/// Returns [`BellandeError::ShapeMismatch`] carrying both shapes in its
/// message when they differ.
pub fn validate_shapes(output: &Tensor, target: &Tensor) -> Result<(), BellandeError> {
    // Happy path first: matching shapes need no further work.
    if output.shape == target.shape {
        return Ok(());
    }

    let message = format!(
        "Output shape {:?} doesn't match target shape {:?}",
        output.shape, target.shape
    );
    Err(BellandeError::ShapeMismatch(message))
}
/// Reduces `loss` according to `reduction`.
///
/// * [`Reduction::None`] returns a clone of `loss`.
/// * [`Reduction::Sum`] returns a one-element tensor (shape `[1]`) holding
///   the sum of `loss.data`.
/// * [`Reduction::Mean`] returns a one-element tensor (shape `[1]`) holding
///   that sum divided by `loss.data.len()`. With no elements this is a
///   floating-point division of zero by zero, i.e. NaN.
///
/// Reduced tensors inherit `requires_grad`, `device` and `dtype` from `loss`.
///
/// # Errors
///
/// The current implementation always returns `Ok`.
pub fn apply_reduction(loss: &Tensor, reduction: Reduction) -> Result<Tensor, BellandeError> {
    // Wraps one value in a single-element tensor that mirrors the metadata
    // of the input loss tensor.
    let scalar_like = |value: f32| {
        Tensor::new(
            vec![value],
            vec![1],
            loss.requires_grad,
            loss.device.clone(),
            loss.dtype,
        )
    };
    // Evaluated lazily so that `Reduction::None` never walks the data.
    let total = || loss.data.iter().sum::<f32>();

    let reduced = match reduction {
        Reduction::None => loss.clone(),
        Reduction::Mean => scalar_like(total() / (loss.data.len() as f32)),
        Reduction::Sum => scalar_like(total()),
    };
    Ok(reduced)
}
/// Builds a loss tensor by applying `op` to each `(output, target)` pair of
/// elements.
///
/// The shapes are validated first. The result has the shape of `output` and
/// inherits its `requires_grad`, `device` and `dtype`.
///
/// # Errors
///
/// Propagates the error from `validate_shapes` when the two shapes differ.
pub fn compute_elementwise_loss(
    output: &Tensor,
    target: &Tensor,
    op: impl Fn(f32, f32) -> f32,
) -> Result<Tensor, BellandeError> {
    validate_shapes(output, target)?;

    let mut values: Vec<f32> = Vec::with_capacity(output.data.len());
    for (predicted, expected) in output.data.iter().zip(target.data.iter()) {
        values.push(op(*predicted, *expected));
    }

    let elementwise = Tensor::new(
        values,
        output.shape.clone(),
        output.requires_grad,
        output.device.clone(),
        output.dtype,
    );
    Ok(elementwise)
}
/// Multiplies every element of `loss` in place by the matching element of
/// `weights`.
///
/// # Errors
///
/// Returns [`BellandeError::ShapeMismatch`] when the shapes of `loss` and
/// `weights` differ; `loss` is left untouched in that case.
pub fn apply_weights(loss: &mut Tensor, weights: &Tensor) -> Result<(), BellandeError> {
    if loss.shape != weights.shape {
        let message = String::from("Weights shape doesn't match loss shape");
        return Err(BellandeError::ShapeMismatch(message));
    }

    loss.data
        .iter_mut()
        .zip(weights.data.iter())
        .for_each(|(value, weight)| *value *= *weight);
    Ok(())
}
}