radiate_extensions/problems/
regression.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
use super::data_set::Row;
use super::{data_set::DataSet, error_functions::ErrorFunction};
use num_traits::cast::FromPrimitive;
use num_traits::float::Float;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};

pub struct Regression<T> {
    pub sample_set: DataSet<T>,
    pub loss_function: ErrorFunction,
}

impl<T> Regression<T> {
    pub fn new(sample_set: DataSet<T>, loss_function: ErrorFunction) -> Self {
        Regression {
            sample_set,
            loss_function,
        }
    }

    pub fn from(loss_function: ErrorFunction, samples: Vec<(Vec<T>, Vec<T>)>) -> Self {
        let mut sample_set = DataSet::new();
        for (input, output) in samples {
            sample_set.add_sample(input, output);
        }
        Regression {
            sample_set,
            loss_function,
        }
    }

    pub fn error<F>(&self, mut error_fn: F) -> T
    where
        T: Clone
            + PartialEq
            + Default
            + Add<Output = T>
            + Div<Output = T>
            + Sub<Output = T>
            + Mul<Output = T>
            + Div<Output = T>
            + AddAssign
            + SubAssign
            + DivAssign
            + MulAssign
            + Float
            + FromPrimitive,
        F: FnMut(&Vec<T>) -> Vec<T>,
    {
        self.loss_function
            .calculate(&self.sample_set, &mut error_fn)
    }

    pub fn get_samples(&self) -> &[Row<T>] {
        self.sample_set.get_samples()
    }

    pub fn get_loss_function(&self) -> &ErrorFunction {
        &self.loss_function
    }
}