/// A loss defined on a single scalar prediction and its scalar target.
///
/// Implementors supply the loss value ([`forward`](LossFunction::forward))
/// and a companion gradient value ([`backward`](LossFunction::backward)).
pub trait LossFunction {
    /// Returns the loss value for one `predicted` / `target` pair.
    fn forward(&self, predicted: f32, target: f32) -> f32;

    /// Returns the gradient value for one `predicted` / `target` pair.
    ///
    /// In the implementations in this file this is the derivative of
    /// [`forward`](LossFunction::forward) with respect to `predicted`.
    fn backward(&self, predicted: f32, target: f32) -> f32;
}
/// Marker type selecting the (halved) squared-error loss.
///
/// The type carries no state; `Debug`, `Default` and `PartialEq` are derived
/// so it can be logged, default-constructed and compared like any other
/// configuration value.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MSELoss;
impl LossFunction for MSELoss {
    /// Half of the squared difference: `0.5 * (predicted - target)^2`.
    fn forward(&self, predicted: f32, target: f32) -> f32 {
        let residual = predicted - target;
        residual.powi(2) * 0.5
    }

    /// The signed difference `predicted - target`, i.e. the derivative of
    /// `forward` with respect to `predicted` (the `0.5` factor cancels the
    /// exponent).
    fn backward(&self, predicted: f32, target: f32) -> f32 {
        let residual = predicted - target;
        residual
    }
}
/// Marker type selecting the absolute-error loss.
///
/// The type carries no state; `Debug`, `Default` and `PartialEq` are derived
/// so it can be logged, default-constructed and compared like any other
/// configuration value.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MAELoss;
impl LossFunction for MAELoss {
    /// Absolute difference `|predicted - target|`.
    fn forward(&self, predicted: f32, target: f32) -> f32 {
        (predicted - target).abs()
    }

    /// Sign of `predicted - target`: `1.0`, `-1.0`, or `0.0` when both
    /// values compare equal.
    ///
    /// If either input is NaN the result is NaN, matching `forward`, so a
    /// diverged prediction is not silently reported as a zero gradient.
    fn backward(&self, predicted: f32, target: f32) -> f32 {
        if predicted > target {
            1.0
        } else if predicted < target {
            -1.0
        } else if predicted == target {
            0.0
        } else {
            // None of `>`, `<`, `==` holds only when an operand is NaN.
            f32::NAN
        }
    }
}
/// Huber loss configuration.
///
/// `Debug` and `PartialEq` are derived so a configured loss can be logged
/// and compared.
#[derive(Debug, Clone, PartialEq)]
pub struct HuberLoss {
    /// Threshold on `|predicted - target|` at which the loss switches from
    /// its quadratic branch to its linear branch.
    delta: f32,
}
impl HuberLoss {
pub fn new(delta: f32) -> Self {
Self { delta }
}
}
impl LossFunction for HuberLoss {
    /// Huber loss of `diff = predicted - target`:
    ///
    /// * `0.5 * diff^2` when `|diff| <= delta`,
    /// * `delta * |diff| - 0.5 * delta^2` otherwise.
    fn forward(&self, predicted: f32, target: f32) -> f32 {
        let diff = predicted - target;
        let abs_diff = diff.abs();
        if abs_diff <= self.delta {
            0.5 * diff.powi(2)
        } else {
            self.delta * abs_diff - 0.5 * self.delta.powi(2)
        }
    }

    /// Derivative of `forward` with respect to `predicted`:
    ///
    /// * `diff` when `|diff| <= delta`,
    /// * `delta` or `-delta` (following the sign of `diff`) otherwise.
    ///
    /// If `predicted - target` is NaN the result is NaN, matching `forward`,
    /// rather than falling through to the `-delta` branch.
    fn backward(&self, predicted: f32, target: f32) -> f32 {
        let diff = predicted - target;
        if diff.is_nan() {
            // Every comparison with NaN is false, so without this guard a NaN
            // input would be reported as the finite gradient `-delta`.
            return diff;
        }
        if diff.abs() <= self.delta {
            diff
        } else if diff > 0.0 {
            self.delta
        } else {
            -self.delta
        }
    }
}
/// Marker type selecting the binary cross-entropy loss.
///
/// The type carries no state; `Debug`, `Default` and `PartialEq` are derived
/// so it can be logged, default-constructed and compared like any other
/// configuration value.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BCELoss;
impl LossFunction for BCELoss {
    /// Binary cross-entropy `-(t * ln(p) + (1 - t) * ln(1 - p))`, where `t`
    /// is `target` and `p` is `predicted` clamped to `[1e-7, 1 - 1e-7]`.
    ///
    /// The clamp keeps both logarithms finite when `predicted` is `0` or `1`.
    fn forward(&self, predicted: f32, target: f32) -> f32 {
        const EPS: f32 = 1e-7;
        let prob = predicted.clamp(EPS, 1.0 - EPS);
        let positive_term = target * prob.ln();
        let negative_term = (1.0 - target) * (1.0 - prob).ln();
        -(positive_term + negative_term)
    }

    /// `(p - t) / (p * (1 - p))`, using the same clamped `p` as `forward`.
    ///
    /// The clamp keeps the denominator away from zero.
    fn backward(&self, predicted: f32, target: f32) -> f32 {
        const EPS: f32 = 1e-7;
        let prob = predicted.clamp(EPS, 1.0 - EPS);
        let numerator = prob - target;
        let denominator = prob * (1.0 - prob);
        numerator / denominator
    }
}