use num_traits::Float;
/// A set of loss computations that share one input, output, and error type.
///
/// The type parameter `T` does not appear in any method signature; it only lets a single
/// type provide several independent implementations (for example both
/// `LossFunction<f32>` and `LossFunction<f64>`).
pub trait LossFunction<T> {
    /// Type of both the predictions and the targets handed to every loss method.
    type Input;
    /// Type of a successfully computed loss.
    type Output;
    /// Type reported when a loss cannot be computed.
    type Error;

    /// Mean squared error (MSE) of `predictions` against `targets`, as defined by the
    /// implementor.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor cannot compute the loss.
    fn mse_loss(
        &self,
        predictions: Self::Input,
        targets: Self::Input,
    ) -> Result<Self::Output, Self::Error>;

    /// Mean absolute error (MAE) of `predictions` against `targets`, as defined by the
    /// implementor.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor cannot compute the loss.
    fn mae_loss(
        &self,
        predictions: Self::Input,
        targets: Self::Input,
    ) -> Result<Self::Output, Self::Error>;

    /// Binary cross-entropy (BCE) of `predictions` against `targets`, as defined by the
    /// implementor.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor cannot compute the loss.
    fn bce_loss(
        &self,
        predictions: Self::Input,
        targets: Self::Input,
    ) -> Result<Self::Output, Self::Error>;

    /// Cross-entropy of `predictions` against `targets`, as defined by the implementor.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor cannot compute the loss.
    fn cross_entropy_loss(
        &self,
        predictions: Self::Input,
        targets: Self::Input,
    ) -> Result<Self::Output, Self::Error>;
}
pub mod shared_losses {
    //! Loss functions over pairs of equal-length slices of floating-point values.
    //!
    //! Every function returns the mean of a per-element term, returns `Ok(0)` when both
    //! slices are empty, and returns `Err` when the slices differ in length.
    use super::*;

    /// Error message shared by the prediction/target based losses.
    const LENGTH_MISMATCH: &str = "Predictions and targets must have the same length";

    /// Error message used when the element count cannot be converted into `T`.
    const COUNT_NOT_REPRESENTABLE: &str = "Element count is not representable in the float type";

    /// Averages `term(a, b)` over the element-wise pairs of `lhs` and `rhs`.
    ///
    /// Returns `Ok(T::zero())` for empty input.
    ///
    /// # Errors
    ///
    /// * `mismatch_msg` if the slices have different lengths.
    /// * [`COUNT_NOT_REPRESENTABLE`] if `lhs.len()` cannot be converted to `T`.
    fn mean_of_pairs<T, F>(
        lhs: &[T],
        rhs: &[T],
        mismatch_msg: &'static str,
        term: F,
    ) -> Result<T, &'static str>
    where
        T: Float,
        F: Fn(T, T) -> T,
    {
        if lhs.len() != rhs.len() {
            return Err(mismatch_msg);
        }
        // Empty input is defined as zero loss rather than 0 / 0 = NaN.
        if lhs.is_empty() {
            return Ok(T::zero());
        }
        // Fail loudly instead of dividing by a made-up count (which would return a sum).
        let count = T::from(lhs.len()).ok_or(COUNT_NOT_REPRESENTABLE)?;
        let total = lhs
            .iter()
            .zip(rhs)
            .fold(T::zero(), |acc, (&a, &b)| acc + term(a, b));
        Ok(total / count)
    }

    /// Mean squared error: the mean of `(prediction - target)^2`.
    ///
    /// Returns `Ok(0)` when both slices are empty.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `predictions` and `targets` have different lengths.
    pub fn mse_loss_vec<T: Float>(predictions: &[T], targets: &[T]) -> Result<T, &'static str> {
        mean_of_pairs(predictions, targets, LENGTH_MISMATCH, |pred, target| {
            let diff = pred - target;
            diff * diff
        })
    }

    /// Mean absolute error: the mean of `|prediction - target|`.
    ///
    /// Returns `Ok(0)` when both slices are empty.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `predictions` and `targets` have different lengths.
    pub fn mae_loss_vec<T: Float>(predictions: &[T], targets: &[T]) -> Result<T, &'static str> {
        mean_of_pairs(predictions, targets, LENGTH_MISMATCH, |pred, target| {
            (pred - target).abs()
        })
    }

    /// Binary cross-entropy: the mean of
    /// `-(t * ln(p) + (1 - t) * ln(1 - p))`, with `p` clamped to `[eps, 1 - eps]`.
    ///
    /// The clamp (`eps = 1e-7`) keeps both logarithms finite when a prediction is exactly
    /// 0 or 1. Returns `Ok(0)` when both slices are empty.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `predictions` and `targets` have different lengths.
    pub fn bce_loss_vec<T: Float>(predictions: &[T], targets: &[T]) -> Result<T, &'static str> {
        // A zero fallback would defeat the clamp and re-admit ln(0) = -inf, so fall back to
        // the type's machine epsilon if 1e-7 cannot be represented.
        let eps = T::from(1e-7).unwrap_or_else(T::epsilon);
        let upper = T::one() - eps;
        mean_of_pairs(predictions, targets, LENGTH_MISMATCH, |pred, target| {
            let p = pred.max(eps).min(upper);
            -(target * p.ln() + (T::one() - target) * (T::one() - p).ln())
        })
    }

    /// Cross-entropy from log-probabilities: the mean of `-target * log_prob`.
    ///
    /// `log_probs` is presumably the output of a log-softmax (or similar); this function
    /// does not check it. Pairs whose `target` is zero contribute zero, even when
    /// `log_prob` is `-inf` (the `0 * log 0 = 0` convention); a NaN `log_prob` still
    /// propagates. Returns `Ok(0)` when both slices are empty.
    ///
    /// NOTE(review): the sum is divided by the total element count, not by a batch size.
    /// For a flattened `batch x classes` layout this is the per-element mean, which differs
    /// from the usual per-sample cross-entropy by a factor of the class count. Confirm
    /// this is intended.
    ///
    /// # Errors
    ///
    /// Returns `Err` if `log_probs` and `targets` have different lengths.
    pub fn cross_entropy_loss_vec<T: Float>(
        log_probs: &[T],
        targets: &[T],
    ) -> Result<T, &'static str> {
        mean_of_pairs(
            log_probs,
            targets,
            "Log probabilities and targets must have the same length",
            |log_prob, target| {
                // IEEE 0 * -inf is NaN; a zero-weight impossible class must not poison the loss.
                if target == T::zero() && !log_prob.is_nan() {
                    T::zero()
                } else {
                    -target * log_prob
                }
            },
        )
    }
}