1#[cfg(feature = "classification")]
4mod classifier;
5
6mod regressor;
7
8use debot_db::{ModelParams, SerializableModel};
9use smartcore_proba::ensemble::random_forest_regressor::RandomForestRegressor;
10use smartcore_proba::linalg::basic::arrays::Array2;
11use smartcore_proba::linalg::basic::matrix::DenseMatrix;
12
13#[cfg(feature = "classification")]
14use classifier::ModelWithWeights;
15#[cfg(feature = "classification")]
16use smartcore_proba::ensemble::random_forest_classifier::RandomForestClassifier;
17
18pub struct RandomForest {
20 model: RandomForestModel,
21 #[cfg(feature = "classification")]
22 weights: (f64, f64, f64),
24
25 extra_regressors: [RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>; 2],
26}
27
28impl RandomForest {
29 #[cfg(feature = "classification")]
31 pub fn weights(&self) -> (f64, f64, f64) {
32 self.weights
33 }
34
35 pub async fn new(key: &str, model_params: &ModelParams) -> Self {
37 #[cfg(feature = "classification")]
38 {
39 let serial: SerializableModel = model_params
40 .load_model(&format!("{}_0", key))
41 .await
42 .expect("Failed to load classifier model");
43
44 let mw: ModelWithWeights = bincode::deserialize(&serial.model)
45 .expect("Failed to deserialize ModelWithWeights");
46
47 let extra_1: SerializableModel = model_params
48 .load_model(&format!("{}_1", key))
49 .await
50 .expect("Failed to load regression model 1");
51 let extra_2: SerializableModel = model_params
52 .load_model(&format!("{}_2", key))
53 .await
54 .expect("Failed to load regression model 2");
55
56 let reg1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
57 bincode::deserialize(&extra_1.model).unwrap();
58 let reg2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
59 bincode::deserialize(&extra_2.model).unwrap();
60
61 return RandomForest {
62 model: RandomForestModel::Classifier(mw.model),
63 weights: mw.weights,
64 extra_regressors: [reg1, reg2],
65 };
66 }
67
68 #[cfg(all(feature = "regression", not(feature = "classification")))]
69 {
70 let reg0: SerializableModel = model_params
71 .load_model(&format!("{}_0", key))
72 .await
73 .expect("Failed to load regression model 0");
74 let reg1: SerializableModel = model_params
75 .load_model(&format!("{}_1", key))
76 .await
77 .expect("Failed to load regression model 1");
78 let reg2: SerializableModel = model_params
79 .load_model(&format!("{}_2", key))
80 .await
81 .expect("Failed to load regression model 2");
82
83 let model: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
84 bincode::deserialize(®0.model).unwrap();
85 let extra_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
86 bincode::deserialize(®1.model).unwrap();
87 let extra_2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
88 bincode::deserialize(®2.model).unwrap();
89
90 return RandomForest {
91 model: RandomForestModel::Regressor(model),
92 extra_regressors: [extra_1, extra_2],
93 };
94 }
95
96 #[cfg(not(any(feature = "classification", feature = "regression")))]
97 panic!("Either 'classification' or 'regression' feature must be enabled");
98 }
99
100 pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
102 match &self.model {
103 #[cfg(feature = "classification")]
104 RandomForestModel::Classifier(clf) => clf
105 .predict(x)
106 .unwrap()
107 .into_iter()
108 .map(|v| v as f64)
109 .collect(),
110
111 #[cfg(feature = "regression")]
112 RandomForestModel::Regressor(reg) => reg.predict(x).unwrap(),
113
114 #[allow(unreachable_patterns)]
115 _ => unreachable!(),
116 }
117 }
118
119 #[cfg(feature = "classification")]
121 pub fn predict_proba(&self, x: &DenseMatrix<f64>) -> (DenseMatrix<f64>, f64) {
122 let clf = match &self.model {
123 RandomForestModel::Classifier(clf) => clf,
124 #[cfg(feature = "regression")]
125 RandomForestModel::Regressor(_) => unreachable!("predict_proba on regression model"),
126 };
127 let probs: DenseMatrix<f64> = clf.predict_proba(x).unwrap();
128 log::trace!("Predicted probabilities: {:?}", probs);
129 let row = probs.get_row(0).iterator(0).copied().collect::<Vec<f64>>();
130 let (w0, w1, w2) = self.weights;
131 let exp_score = w0 * row[0] + w1 * row[1] + w2 * row[2];
132 (probs, exp_score)
133 }
134
135 pub fn predict_timeout(&self, x: &DenseMatrix<f64>) -> (f64, f64) {
137 let pred1 = self.extra_regressors[0].predict(x).unwrap();
138 let pred2 = self.extra_regressors[1].predict(x).unwrap();
139 (pred1[0], pred2[0])
140 }
141}
142
/// The primary model held by `RandomForest`.
///
/// Which variants exist depends on the enabled Cargo features. With neither feature
/// enabled, the enum has no variants.
pub enum RandomForestModel {
    /// Random-forest classifier over `f64` features with `i32` labels
    /// (compiled only with the `classification` feature).
    #[cfg(feature = "classification")]
    Classifier(RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>),

    /// Random-forest regressor over `f64` features and `f64` targets
    /// (compiled only with the `regression` feature).
    #[cfg(feature = "regression")]
    Regressor(RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>),
}
151
152#[cfg(feature = "classification")]
154pub use classifier::grid_search_and_train_classifier;
155#[cfg(feature = "classification")]
156pub use classifier::Metric::*;
157
158pub use regressor::grid_search_and_train_regressor;