debot_ml/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
use debot_db::ModelParams;
use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
use smartcore::ensemble::random_forest_regressor::RandomForestRegressor;
use smartcore::linalg::basic::matrix::DenseMatrix;

mod classifier;
mod regression;

pub use classifier::grid_search_and_train_classifier;
pub use regression::grid_search_and_train_regressor;

/// A pair of smartcore random-forest models that are loaded together under one key.
///
/// `model_0` is a classifier producing `i32` labels and `model_1` is a regressor producing
/// `f64` values; both take `f64` feature rows packed into a [`DenseMatrix`].
pub struct RandomForest {
    /// Classifier queried by `classify_profitability`; loaded from the `"{key}_0"` entry.
    model_0: RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
    /// Regressor queried by `regress_profit_ratio`; loaded from the `"{key}_1"` entry.
    model_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>,
}

impl RandomForest {
    /// Loads both persisted models stored under `key` in `model_params`.
    ///
    /// The classifier is read from the entry named `"{key}_0"` and the regressor from the
    /// entry named `"{key}_1"`; each entry's `model` bytes are decoded with `bincode`.
    ///
    /// # Panics
    ///
    /// Panics if either entry cannot be loaded, or if its bytes do not decode into the
    /// expected smartcore model type. Decode failures report the entry name and the
    /// underlying `bincode` error so the broken artifact can be identified.
    pub async fn new(key: &str, model_params: &ModelParams) -> Self {
        let name_0 = format!("{}_0", key);
        let serializable_model_0 = model_params
            .load_model(&name_0)
            .await
            .expect("Failed to load model 0");
        // The bytes come from external storage; a bare `unwrap()` would hide which
        // artifact is corrupt or was written by an incompatible model version.
        let model_0: RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>> =
            bincode::deserialize(&serializable_model_0.model).unwrap_or_else(|e| {
                panic!("Failed to deserialize model 0 ({}): {}", name_0, e)
            });

        let name_1 = format!("{}_1", key);
        let serializable_model_1 = model_params
            .load_model(&name_1)
            .await
            .expect("Failed to load model 1");
        let model_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
            bincode::deserialize(&serializable_model_1.model).unwrap_or_else(|e| {
                panic!("Failed to deserialize model 1 ({}): {}", name_1, e)
            });

        Self { model_0, model_1 }
    }

    /// Runs the classifier (`model_0`) on `x` and returns its predicted `i32` labels.
    ///
    /// The prediction is logged at `trace` level.
    ///
    /// # Panics
    ///
    /// Panics if smartcore's `predict` returns an error. An incompatible input shape is a
    /// likely cause, but this is not verified here. The panic message includes that error.
    pub fn classify_profitability(&self, x: &DenseMatrix<f64>) -> Vec<i32> {
        let prediction = self
            .model_0
            .predict(x)
            .unwrap_or_else(|e| panic!("profitability classification failed: {:?}", e));
        log::trace!("predicted profitability: {:?}", prediction);
        prediction
    }

    /// Runs the regressor (`model_1`) on `x` and returns its predicted `f64` values.
    ///
    /// The prediction is logged at `trace` level.
    ///
    /// # Panics
    ///
    /// Panics if smartcore's `predict` returns an error. An incompatible input shape is a
    /// likely cause, but this is not verified here. The panic message includes that error.
    pub fn regress_profit_ratio(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
        let prediction = self
            .model_1
            .predict(x)
            .unwrap_or_else(|e| panic!("profit_ratio regression failed: {:?}", e));
        log::trace!("predicted profit_ratio: {:?}", prediction);
        prediction
    }
}