/// Per-sample loss interface exposing a loss value together with gradient and
/// hessian terms, plus batch helpers built on top of them.
///
/// Implementors must be `Send + Sync`. Only [`loss`](Self::loss),
/// [`gradient`](Self::gradient), [`hessian`](Self::hessian) and
/// [`name`](Self::name) are required; every other method has a default
/// implementation that may be overridden.
pub trait LossFunction: Send + Sync {
    /// Returns the loss value for a single `(target, prediction)` pair.
    fn loss(&self, target: f32, prediction: f32) -> f32;

    /// Returns the gradient term for a single `(target, prediction)` pair.
    ///
    /// NOTE(review): presumably the first derivative of [`loss`](Self::loss)
    /// with respect to `prediction`; the exact contract is set by implementors.
    fn gradient(&self, target: f32, prediction: f32) -> f32;

    /// Returns the hessian term for a single `(target, prediction)` pair.
    ///
    /// NOTE(review): presumably the second derivative of [`loss`](Self::loss)
    /// with respect to `prediction`; the exact contract is set by implementors.
    fn hessian(&self, target: f32, prediction: f32) -> f32;

    /// Returns `(gradient, hessian)` for a single pair.
    ///
    /// The default simply calls [`gradient`](Self::gradient) and then
    /// [`hessian`](Self::hessian); override it when both can share work.
    #[inline]
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        (
            self.gradient(target, prediction),
            self.hessian(target, prediction),
        )
    }

    /// Fills `gradients[i]` and `hessians[i]` from
    /// [`gradient_hessian`](Self::gradient_hessian) applied to
    /// `(targets[i], predictions[i])` for every index of `targets`.
    ///
    /// # Panics
    ///
    /// Panics if `predictions`, `gradients` or `hessians` is shorter than
    /// `targets`. In builds with debug assertions enabled it also panics when
    /// any of the four lengths differ. Without debug assertions, elements of a
    /// longer slice beyond `targets.len()` are ignored / left untouched.
    fn compute_gradients(
        &self,
        targets: &[f32],
        predictions: &[f32],
        gradients: &mut [f32],
        hessians: &mut [f32],
    ) {
        debug_assert_eq!(targets.len(), predictions.len());
        debug_assert_eq!(targets.len(), gradients.len());
        debug_assert_eq!(targets.len(), hessians.len());

        // Trim everything to the number of targets up front: a too-short slice
        // panics here, before any output is written, and the loop below needs
        // no per-element bounds checks.
        let n = targets.len();
        let predictions = &predictions[..n];
        let gradients = &mut gradients[..n];
        let hessians = &mut hessians[..n];

        let inputs = targets.iter().zip(predictions.iter());
        let outputs = gradients.iter_mut().zip(hessians.iter_mut());
        for ((&target, &prediction), (grad_out, hess_out)) in inputs.zip(outputs) {
            let (g, h) = self.gradient_hessian(target, prediction);
            *grad_out = g;
            *hess_out = h;
        }
    }

    /// Returns the arithmetic mean of `targets`, or `0.0` when `targets` is
    /// empty.
    ///
    /// The sum is accumulated in `f64` and narrowed to `f32` only at the end.
    /// An `f32` accumulator silently drops small addends once the running sum
    /// grows large, and can overflow to infinity even though the mean itself
    /// is representable.
    fn initial_prediction(&self, targets: &[f32]) -> f32 {
        if targets.is_empty() {
            return 0.0;
        }
        let total: f64 = targets.iter().map(|&t| f64::from(t)).sum();
        // Narrowing is intentional: the public result type is `f32`.
        (total / targets.len() as f64) as f32
    }

    /// Returns a static identifier for this loss function.
    fn name(&self) -> &'static str;
}