use crate::core::{error::BellandeError, tensor::Tensor};
use crate::loss::{bce::Reduction, Loss};
/// Mean-squared-error loss.
///
/// Each element contributes `(prediction - target)^2`; how those per-element
/// values are combined is controlled by the stored [`Reduction`].
pub struct MSELoss {
    /// Strategy used to combine the per-element squared errors.
    reduction: Reduction,
}

impl MSELoss {
    /// Creates an MSE loss that combines per-element errors using `reduction`.
    pub fn new(reduction: Reduction) -> Self {
        Self { reduction }
    }
}
impl Loss for MSELoss {
    /// Computes the mean-squared-error loss between `prediction` and `target`.
    ///
    /// Every element contributes `(prediction - target)^2`. The per-element values
    /// are then combined according to `self.reduction`:
    /// - `Reduction::None`: returned unreduced, with `prediction`'s shape.
    /// - `Reduction::Mean`: averaged into a single value with shape `[1]`
    ///   (an empty input evaluates `0.0 / 0.0`, i.e. NaN).
    /// - `Reduction::Sum`: summed into a single value with shape `[1]`.
    ///
    /// The result reuses `prediction`'s device and dtype, and passes `true` as the
    /// third `Tensor::new` argument (presumably `requires_grad` — confirm).
    ///
    /// # Errors
    /// Returns `BellandeError::DimensionMismatch` when the shapes differ.
    fn forward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError> {
        if prediction.shape != target.shape {
            return Err(BellandeError::DimensionMismatch);
        }

        // Element-wise squared error; `zip` pairs elements in storage order.
        let squared: Vec<f32> = prediction
            .data
            .iter()
            .zip(target.data.iter())
            .map(|(&p, &t)| (p - t).powi(2))
            .collect();

        let device = prediction.device.clone();
        let dtype = prediction.dtype;

        let output = match self.reduction {
            Reduction::None => Tensor::new(squared, prediction.shape.clone(), true, device, dtype),
            Reduction::Mean => {
                let total = squared.iter().sum::<f32>();
                let mean = total / squared.len() as f32;
                Tensor::new(vec![mean], vec![1], true, device, dtype)
            }
            Reduction::Sum => {
                let total = squared.iter().sum::<f32>();
                Tensor::new(vec![total], vec![1], true, device, dtype)
            }
        };

        Ok(output)
    }

    /// Computes the gradient of the MSE loss with respect to `prediction`.
    ///
    /// Each element of the result is `2 * (prediction - target)`. Under
    /// `Reduction::Mean` every element is additionally multiplied by
    /// `1 / prediction.data.len()`; `None` and `Sum` leave it unscaled. No upstream
    /// gradient is taken as input, so none is applied.
    ///
    /// The result always has `prediction`'s shape, device and dtype, and passes
    /// `true` as the third `Tensor::new` argument (presumably `requires_grad` —
    /// confirm).
    ///
    /// # Errors
    /// Returns `BellandeError::DimensionMismatch` when the shapes differ.
    fn backward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, BellandeError> {
        if prediction.shape != target.shape {
            return Err(BellandeError::DimensionMismatch);
        }

        // Only the mean reduction rescales; the sum's derivative is 1 per element.
        let mean_scale = match self.reduction {
            Reduction::Mean => Some(1.0 / prediction.data.len() as f32),
            Reduction::None | Reduction::Sum => None,
        };

        let grad: Vec<f32> = prediction
            .data
            .iter()
            .zip(target.data.iter())
            .map(|(&p, &t)| {
                let g = 2.0 * (p - t);
                match mean_scale {
                    Some(scale) => g * scale,
                    None => g,
                }
            })
            .collect();

        Ok(Tensor::new(
            grad,
            prediction.shape.clone(),
            true,
            prediction.device.clone(),
            prediction.dtype,
        ))
    }
}
unsafe impl Send for MSELoss {}
unsafe impl Sync for MSELoss {}