reductionml_core/
loss_function.rs

1mod squared_loss;
2
3use serde::{Deserialize, Serialize};
4
5use self::squared_loss::SquaredLoss;
6
7#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
8pub enum LossFunctionType {
9    Squared,
10}
11
12impl LossFunctionType {
13    pub fn create(&self) -> Box<dyn LossFunction> {
14        match self {
15            LossFunctionType::Squared => Box::new(SquaredLoss::new()),
16        }
17    }
18}
19
20pub trait LossFunction: Send {
21    fn get_type(&self) -> LossFunctionType;
22    fn get_loss(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
23    fn get_update(
24        &self,
25        prediction: f32,
26        label: f32,
27        update_scale: f32,
28        pred_per_update: f32,
29    ) -> f32;
30    fn get_unsafe_update(&self, prediction: f32, label: f32, update_scale: f32) -> f32;
31    fn get_square_grad(&self, prediction: f32, label: f32) -> f32;
32    fn first_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32) -> f32;
33    fn second_derivative(&self, min_label: f32, max_label: f32, prediction: f32, label: f32)
34        -> f32;
35}