debot_ml/
lib.rs

1// src/lib.rs
2
3#[cfg(feature = "classification")]
4mod classifier;
5
6mod regressor;
7
8use debot_db::{ModelParams, SerializableModel};
9use smartcore_proba::ensemble::random_forest_regressor::RandomForestRegressor;
10use smartcore_proba::linalg::basic::arrays::Array2;
11use smartcore_proba::linalg::basic::matrix::DenseMatrix;
12
13#[cfg(feature = "classification")]
14use classifier::ModelWithWeights;
15#[cfg(feature = "classification")]
16use smartcore_proba::ensemble::random_forest_classifier::RandomForestClassifier;
17
18/// Unified RandomForest wrapper supporting classification and regression
19pub struct RandomForest {
20    model: RandomForestModel,
21    #[cfg(feature = "classification")]
22    /// Weights for expected-score calculation (loss, expired-profit, take-profit)
23    weights: (f64, f64, f64),
24
25    extra_regressors: [RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>; 2],
26}
27
28impl RandomForest {
29    /// Access classification weights
30    #[cfg(feature = "classification")]
31    pub fn weights(&self) -> (f64, f64, f64) {
32        self.weights
33    }
34
35    /// Load model from storage (classification or regression)
36    pub async fn new(key: &str, model_params: &ModelParams) -> Self {
37        #[cfg(feature = "classification")]
38        {
39            let serial: SerializableModel = model_params
40                .load_model(&format!("{}_0", key))
41                .await
42                .expect("Failed to load classifier model");
43
44            let mw: ModelWithWeights = bincode::deserialize(&serial.model)
45                .expect("Failed to deserialize ModelWithWeights");
46
47            let extra_1: SerializableModel = model_params
48                .load_model(&format!("{}_1", key))
49                .await
50                .expect("Failed to load regression model 1");
51            let extra_2: SerializableModel = model_params
52                .load_model(&format!("{}_2", key))
53                .await
54                .expect("Failed to load regression model 2");
55
56            let reg1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
57                bincode::deserialize(&extra_1.model).unwrap();
58            let reg2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
59                bincode::deserialize(&extra_2.model).unwrap();
60
61            return RandomForest {
62                model: RandomForestModel::Classifier(mw.model),
63                weights: mw.weights,
64                extra_regressors: [reg1, reg2],
65            };
66        }
67
68        #[cfg(all(feature = "regression", not(feature = "classification")))]
69        {
70            let reg0: SerializableModel = model_params
71                .load_model(&format!("{}_0", key))
72                .await
73                .expect("Failed to load regression model 0");
74            let reg1: SerializableModel = model_params
75                .load_model(&format!("{}_1", key))
76                .await
77                .expect("Failed to load regression model 1");
78            let reg2: SerializableModel = model_params
79                .load_model(&format!("{}_2", key))
80                .await
81                .expect("Failed to load regression model 2");
82
83            let model: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
84                bincode::deserialize(&reg0.model).unwrap();
85            let extra_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
86                bincode::deserialize(&reg1.model).unwrap();
87            let extra_2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
88                bincode::deserialize(&reg2.model).unwrap();
89
90            return RandomForest {
91                model: RandomForestModel::Regressor(model),
92                extra_regressors: [extra_1, extra_2],
93            };
94        }
95
96        #[cfg(not(any(feature = "classification", feature = "regression")))]
97        panic!("Either 'classification' or 'regression' feature must be enabled");
98    }
99
100    /// Predict raw outputs: class labels (as f64) or continuous values
101    pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
102        match &self.model {
103            #[cfg(feature = "classification")]
104            RandomForestModel::Classifier(clf) => clf
105                .predict(x)
106                .unwrap()
107                .into_iter()
108                .map(|v| v as f64)
109                .collect(),
110
111            #[cfg(feature = "regression")]
112            RandomForestModel::Regressor(reg) => reg.predict(x).unwrap(),
113
114            #[allow(unreachable_patterns)]
115            _ => unreachable!(),
116        }
117    }
118
119    /// For classification: predict probabilities and compute expected score
120    #[cfg(feature = "classification")]
121    pub fn predict_proba(&self, x: &DenseMatrix<f64>) -> (DenseMatrix<f64>, f64) {
122        let clf = match &self.model {
123            RandomForestModel::Classifier(clf) => clf,
124            #[cfg(feature = "regression")]
125            RandomForestModel::Regressor(_) => unreachable!("predict_proba on regression model"),
126        };
127        let probs: DenseMatrix<f64> = clf.predict_proba(x).unwrap();
128        log::trace!("Predicted probabilities: {:?}", probs);
129        let row = probs.get_row(0).iterator(0).copied().collect::<Vec<f64>>();
130        let (w0, w1, w2) = self.weights;
131        let exp_score = w0 * row[0] + w1 * row[1] + w2 * row[2];
132        (probs, exp_score)
133    }
134
135    /// Predict timeout durations from extra regressors
136    pub fn predict_timeout(&self, x: &DenseMatrix<f64>) -> (f64, f64) {
137        let pred1 = self.extra_regressors[0].predict(x).unwrap();
138        let pred2 = self.extra_regressors[1].predict(x).unwrap();
139        (pred1[0], pred2[0])
140    }
141}
142
/// Primary model held by [`RandomForest`].
///
/// Which variants exist depends on the enabled Cargo features (`classification`,
/// `regression`).
pub enum RandomForestModel {
    /// Random-forest classifier over `f64` features with `i32` labels
    /// (present only with the `classification` feature).
    #[cfg(feature = "classification")]
    Classifier(
        RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
    ),

    /// Random-forest regressor over `f64` features with `f64` targets
    /// (present only with the `regression` feature).
    #[cfg(feature = "regression")]
    Regressor(
        RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>,
    ),
}
151
152// Re-export training functions
153#[cfg(feature = "classification")]
154pub use classifier::grid_search_and_train_classifier;
155#[cfg(feature = "classification")]
156pub use classifier::Metric::*;
157
158pub use regressor::grid_search_and_train_regressor;