reductionml_core/
loss_function.rs1mod squared_loss;
2
3use serde::{Deserialize, Serialize};
4
5use self::squared_loss::SquaredLoss;
6
7#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
8pub enum LossFunctionType {
9 Squared,
10}
11
12impl LossFunctionType {
13 pub fn create(&self) -> Box<dyn LossFunction> {
14 match self {
15 LossFunctionType::Squared => Box::new(SquaredLoss::new()),
16 }
17 }
18}
19
20pub trait LossFunction: Send {
21 fn get_type(&self) -> LossFunctionType;
22 fn get_loss(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
23 fn get_update(
24 &self,
25 prediction: f32,
26 label: f32,
27 update_scale: f32,
28 pred_per_update: f32,
29 ) -> f32;
30 fn get_unsafe_update(&self, prediction: f32, label: f32, update_scale: f32) -> f32;
31 fn get_square_grad(&self, prediction: f32, label: f32) -> f32;
32 fn first_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
33 fn second_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32)
34 -> f32;
35}