1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
mod squared_loss;

use serde::{Deserialize, Serialize};

use self::squared_loss::SquaredLoss;

/// Selects which loss function implementation to use.
///
/// The type is `Copy` and derives serde's `Serialize`/`Deserialize`, so it can be passed by
/// value and stored in serialized data. Use [`LossFunctionType::create`] to obtain the
/// matching [`LossFunction`] object.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub enum LossFunctionType {
    /// Squared loss; `create` maps this variant to `SquaredLoss`.
    Squared,
}

impl LossFunctionType {
    /// Instantiates the loss function that corresponds to this variant.
    ///
    /// The result is handed back as a boxed [`LossFunction`] trait object, so callers do not
    /// need to name the concrete implementing type.
    pub fn create(&self) -> Box<dyn LossFunction> {
        // Deliberately exhaustive (no `_` arm): a newly added variant must fail to compile
        // here until it is mapped to an implementation.
        match *self {
            Self::Squared => Box::new(SquaredLoss::new()),
        }
    }
}

/// Common interface for loss functions; [`LossFunctionType::create`] returns one as a
/// `Box<dyn LossFunction>`.
///
/// Every numeric input and output is an `f32`. The actual formulas live in the implementing
/// types (e.g. in the `squared_loss` module) and are not visible from this definition, so the
/// notes below describe only what the signatures show. Anything phrased as "presumably" is an
/// assumption drawn from the names and should be confirmed against an implementation.
pub trait LossFunction {
    /// Returns a [`LossFunctionType`]; presumably the variant identifying this implementation.
    fn get_type(&self) -> LossFunctionType;
    /// Loss value for `prediction` given the true `label` (per the method name).
    ///
    /// NOTE(review): `min_label` / `max_label` are presumably the bounds of the label range;
    /// how (or whether) an implementation uses them is not visible here — confirm.
    fn get_loss(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
    /// Update amount for an example with the given `prediction` and `label`.
    ///
    /// NOTE(review): `update_scale` and `pred_per_update` look like the scaling of the update
    /// and the change in prediction per unit of update, respectively — TODO confirm their
    /// exact meaning and valid ranges with an implementation.
    fn get_update(
        &self,
        prediction: f32,
        label: f32,
        update_scale: f32,
        pred_per_update: f32,
    ) -> f32;
    /// Variant of [`LossFunction::get_update`] that takes no `pred_per_update` argument.
    ///
    /// NOTE(review): "unsafe" here is part of the method name, not Rust's `unsafe` keyword;
    /// it presumably refers to an update computed without some safeguard — confirm.
    fn get_unsafe_update(&self, prediction: f32, label: f32, update_scale: f32) -> f32;
    /// Presumably the squared gradient of the loss at `prediction` for `label` — confirm.
    fn get_square_grad(&self, prediction: f32, label: f32) -> f32;
    /// First derivative of the loss (per the method name), presumably with respect to
    /// `prediction`; takes the same four inputs as [`LossFunction::get_loss`].
    fn first_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
    /// Second derivative of the loss (per the method name), presumably with respect to
    /// `prediction`; takes the same four inputs as [`LossFunction::get_loss`].
    fn second_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32)
        -> f32;
}