/// Supported regression losses.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossKind {
    /// Mean squared error: (1/n) * sum((p - t)^2).
    Mse,
    /// Mean absolute error: (1/n) * sum(|p - t|).
    Mae,
    /// Huber loss: quadratic within `delta` of the target, linear beyond it.
    Huber { delta: f32 },
}
/// Errors returned by `loss_and_gradient`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LossError {
    /// `prediction` was empty.
    Empty,
    /// `prediction` and `target` lengths differ, or `grad_out` is too short.
    ShapeMismatch,
    /// The computed loss was NaN or infinite, or `delta` was not positive and finite.
    NonFinite,
}
/// Computes the mean loss over `prediction`/`target` pairs and writes the
/// gradient of the mean loss with respect to each prediction into the first
/// `prediction.len()` entries of `grad_out` (extra trailing entries are left
/// untouched).
pub fn loss_and_gradient(
    kind: LossKind,
    prediction: &[f32],
    target: &[f32],
    grad_out: &mut [f32],
) -> Result<f32, LossError> {
    if prediction.is_empty() {
        return Err(LossError::Empty);
    }
    if prediction.len() != target.len() || grad_out.len() < prediction.len() {
        return Err(LossError::ShapeMismatch);
    }
    let n = prediction.len() as f32;
    // Accumulates the unnormalized sum; each branch divides by n at the end.
    let mut loss = 0.0f32;
    match kind {
        LossKind::Mse => {
            // d/dp (1/n) * sum((p - t)^2) = 2 * (p - t) / n.
            let scale = 2.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff * diff;
                grad_out[i] = scale * diff;
            }
            loss /= n;
        }
        LossKind::Mae => {
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                loss += diff.abs();
                // Subgradient of |diff|: sign(diff), taking 0 at the kink.
                grad_out[i] = if diff > 0.0 {
                    inv_n
                } else if diff < 0.0 {
                    -inv_n
                } else {
                    0.0
                };
            }
            loss *= inv_n;
        }
        LossKind::Huber { delta } => {
            if !delta.is_finite() || delta <= 0.0 {
                return Err(LossError::NonFinite);
            }
            let inv_n = 1.0 / n;
            for i in 0..prediction.len() {
                let diff = prediction[i] - target[i];
                let ad = diff.abs();
                if ad <= delta {
                    // Quadratic region: 0.5 * diff^2, gradient diff.
                    loss += 0.5 * diff * diff;
                    grad_out[i] = diff * inv_n;
                } else {
                    // Linear region: delta * (|diff| - 0.5 * delta),
                    // gradient delta * sign(diff).
                    loss += delta * (ad - 0.5 * delta);
                    grad_out[i] = if diff > 0.0 {
                        delta * inv_n
                    } else {
                        -delta * inv_n
                    };
                }
            }
            loss *= inv_n;
        }
    }
    // NaN or infinite inputs surface here as a non-finite mean loss.
    if !loss.is_finite() {
        return Err(LossError::NonFinite);
    }
    Ok(loss)
}
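
// A minimal usage sketch (assumed, not part of the original API surface):
// the tests below check the MSE and Huber branches against hand-computed
// values and confirm that shape errors are reported. Run with `cargo test`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mse_matches_hand_computation() {
        let prediction = [1.0f32, 2.0, 4.0];
        let target = [1.0f32, 1.0, 2.0];
        let mut grad = [0.0f32; 3];
        // diffs = [0, 1, 2]; loss = (0 + 1 + 4) / 3; grad[i] = 2 * diff / 3.
        let loss = loss_and_gradient(LossKind::Mse, &prediction, &target, &mut grad).unwrap();
        assert!((loss - 5.0 / 3.0).abs() < 1e-6);
        assert!((grad[2] - 4.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn huber_switches_to_linear_beyond_delta() {
        let prediction = [0.0f32, 3.0];
        let target = [0.5f32, 0.0];
        let mut grad = [0.0f32; 2];
        // With delta = 1: |diff| = [0.5, 3.0], so quadratic then linear.
        // loss = (0.5 * 0.25 + 1.0 * (3.0 - 0.5)) / 2 = 2.625 / 2.
        let loss =
            loss_and_gradient(LossKind::Huber { delta: 1.0 }, &prediction, &target, &mut grad)
                .unwrap();
        assert!((loss - 2.625 / 2.0).abs() < 1e-6);
        // Quadratic region: grad = diff / n = -0.5 / 2.
        assert!((grad[0] + 0.25).abs() < 1e-6);
        // Linear region: grad = delta * sign(diff) / n = 1.0 / 2.
        assert!((grad[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn shape_mismatch_is_reported() {
        // grad_out shorter than prediction must be rejected.
        let mut grad = [0.0f32; 1];
        let err = loss_and_gradient(LossKind::Mae, &[1.0, 2.0], &[1.0, 2.0], &mut grad);
        assert_eq!(err, Err(LossError::ShapeMismatch));
    }
}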