#[cfg(feature = "classification")]
mod classifier;
mod regressor;
use debot_db::{ModelParams, SerializableModel};
use smartcore_proba::ensemble::random_forest_regressor::RandomForestRegressor;
use smartcore_proba::linalg::basic::arrays::Array2;
use smartcore_proba::linalg::basic::matrix::DenseMatrix;
#[cfg(feature = "classification")]
use classifier::ModelWithWeights;
#[cfg(feature = "classification")]
use smartcore_proba::ensemble::random_forest_classifier::RandomForestClassifier;
/// Random-forest model bundle built by [`RandomForest::new`] from models stored in
/// `ModelParams`.
///
/// It holds one primary model, plus two extra regressors that are always present
/// regardless of the enabled features.
pub struct RandomForest {
    /// Primary model.
    ///
    /// This is the `Classifier` variant when the `classification` feature is enabled.
    /// It is the `Regressor` variant when only the `regression` feature is enabled.
    model: RandomForestModel,
    /// Weights loaded together with the classifier.
    ///
    /// `predict_proba` uses them to combine the first three class probabilities
    /// into one score.
    #[cfg(feature = "classification")]
    weights: (f64, f64, f64),
    /// Extra regressors loaded from `"{key}_1"` and `"{key}_2"`.
    ///
    /// They are queried by `predict_timeout`.
    extra_regressors: [RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>; 2],
}
impl RandomForest {
    /// Returns the `(w0, w1, w2)` weights loaded alongside the classifier.
    ///
    /// [`RandomForest::predict_proba`] uses them to turn the first three class
    /// probabilities into a single weighted score.
    #[cfg(feature = "classification")]
    pub fn weights(&self) -> (f64, f64, f64) {
        self.weights
    }

    /// Loads the primary model and the two extra regressors stored under `key`.
    ///
    /// Models are fetched from `model_params` under the names `"{key}_0"`,
    /// `"{key}_1"` and `"{key}_2"`, and decoded with `bincode`:
    ///
    /// * With the `classification` feature, `"{key}_0"` holds a `ModelWithWeights`
    ///   (classifier plus score weights).
    /// * With only the `regression` feature, `"{key}_0"` holds a regressor.
    /// * In both cases, `"{key}_1"` and `"{key}_2"` hold the regressors used by
    ///   [`RandomForest::predict_timeout`].
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * a model cannot be loaded or deserialized;
    /// * neither the `classification` nor the `regression` feature is enabled.
    pub async fn new(key: &str, model_params: &ModelParams) -> Self {
        #[cfg(feature = "classification")]
        {
            let serial_0: SerializableModel = model_params
                .load_model(&format!("{}_0", key))
                .await
                .expect("Failed to load classifier model");
            let mw: ModelWithWeights = bincode::deserialize(&serial_0.model)
                .expect("Failed to deserialize ModelWithWeights");
            let serial_1: SerializableModel = model_params
                .load_model(&format!("{}_1", key))
                .await
                .expect("Failed to load regression model 1");
            let serial_2: SerializableModel = model_params
                .load_model(&format!("{}_2", key))
                .await
                .expect("Failed to load regression model 2");
            let extra_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
                bincode::deserialize(&serial_1.model)
                    .expect("Failed to deserialize regression model 1");
            let extra_2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
                bincode::deserialize(&serial_2.model)
                    .expect("Failed to deserialize regression model 2");
            // An explicit `return` is needed because the following cfg blocks
            // may or may not be compiled in.
            return RandomForest {
                model: RandomForestModel::Classifier(mw.model),
                weights: mw.weights,
                extra_regressors: [extra_1, extra_2],
            };
        }
        #[cfg(all(feature = "regression", not(feature = "classification")))]
        {
            let serial_0: SerializableModel = model_params
                .load_model(&format!("{}_0", key))
                .await
                .expect("Failed to load regression model 0");
            let serial_1: SerializableModel = model_params
                .load_model(&format!("{}_1", key))
                .await
                .expect("Failed to load regression model 1");
            let serial_2: SerializableModel = model_params
                .load_model(&format!("{}_2", key))
                .await
                .expect("Failed to load regression model 2");
            let primary: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
                bincode::deserialize(&serial_0.model)
                    .expect("Failed to deserialize regression model 0");
            let extra_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
                bincode::deserialize(&serial_1.model)
                    .expect("Failed to deserialize regression model 1");
            let extra_2: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
                bincode::deserialize(&serial_2.model)
                    .expect("Failed to deserialize regression model 2");
            return RandomForest {
                model: RandomForestModel::Regressor(primary),
                extra_regressors: [extra_1, extra_2],
            };
        }
        #[cfg(not(any(feature = "classification", feature = "regression")))]
        {
            // Without either feature there is nothing to load.
            // Mark the inputs as used so the build stays warning-free.
            let _ = (key, model_params);
            panic!("Either 'classification' or 'regression' feature must be enabled");
        }
    }

    /// Runs the primary model on `x` and returns one value per prediction.
    ///
    /// Classifier labels (`i32`) are converted losslessly to `f64`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying model's `predict` returns an error.
    pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
        match &self.model {
            #[cfg(feature = "classification")]
            RandomForestModel::Classifier(clf) => clf
                .predict(x)
                .expect("classifier prediction failed")
                .into_iter()
                .map(f64::from)
                .collect(),
            #[cfg(feature = "regression")]
            RandomForestModel::Regressor(reg) => {
                reg.predict(x).expect("regressor prediction failed")
            }
            // The wildcard is needed only when no feature is enabled (the enum is
            // then empty). With a feature enabled it is redundant, hence the allow.
            #[allow(unreachable_patterns)]
            _ => unreachable!(),
        }
    }

    /// Computes class probabilities for `x` and a weighted score for its first row.
    ///
    /// Returns the full probability matrix together with
    /// `w0 * p0 + w1 * p1 + w2 * p2`. Here `p0..p2` are the first three
    /// probabilities of row 0 and `(w0, w1, w2)` are [`RandomForest::weights`].
    /// Any additional probability columns are ignored in the score.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * the classifier's `predict_proba` fails;
    /// * the first row has fewer than three probabilities;
    /// * the model is a regressor.
    #[cfg(feature = "classification")]
    pub fn predict_proba(&self, x: &DenseMatrix<f64>) -> (DenseMatrix<f64>, f64) {
        let clf = match &self.model {
            RandomForestModel::Classifier(clf) => clf,
            #[cfg(feature = "regression")]
            RandomForestModel::Regressor(_) => unreachable!("predict_proba on regression model"),
        };
        let probs: DenseMatrix<f64> = clf
            .predict_proba(x)
            .expect("classifier probability prediction failed");
        log::trace!("Predicted probabilities: {:?}", probs);
        let row: Vec<f64> = probs.get_row(0).iterator(0).copied().collect();
        let (w0, w1, w2) = self.weights;
        // The score is defined over exactly three classes. Fail with a clear
        // message instead of an index-out-of-bounds panic.
        let exp_score = match row.as_slice() {
            [p0, p1, p2, ..] => w0 * p0 + w1 * p1 + w2 * p2,
            _ => panic!(
                "predict_proba expected at least 3 class probabilities, got {}",
                row.len()
            ),
        };
        (probs, exp_score)
    }

    /// Runs both extra regressors on `x`.
    ///
    /// Returns each regressor's first prediction as a pair: the regressor loaded
    /// from `"{key}_1"` first, then the one from `"{key}_2"`.
    ///
    /// # Panics
    ///
    /// Panics if either regressor's `predict` fails, or if either returns no
    /// predictions.
    pub fn predict_timeout(&self, x: &DenseMatrix<f64>) -> (f64, f64) {
        let [first, second] = &self.extra_regressors;
        let pred1 = first
            .predict(x)
            .expect("extra regressor 1 prediction failed");
        let pred2 = second
            .predict(x)
            .expect("extra regressor 2 prediction failed");
        match (pred1.first(), pred2.first()) {
            (Some(&a), Some(&b)) => (a, b),
            _ => panic!("predict_timeout: extra regressors returned no predictions"),
        }
    }
}
/// Primary model held by [`RandomForest`].
///
/// Which variants exist depends on the enabled cargo features. With neither
/// `classification` nor `regression` enabled, this enum has no variants.
pub enum RandomForestModel {
    /// Random-forest classifier over `f64` features with `i32` class labels.
    #[cfg(feature = "classification")]
    Classifier(RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>),
    /// Random-forest regressor over `f64` features with `f64` targets.
    #[cfg(feature = "regression")]
    Regressor(RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>),
}
#[cfg(feature = "classification")]
pub use classifier::grid_search_and_train_classifier;
#[cfg(feature = "classification")]
pub use classifier::Metric::*;
pub use regressor::grid_search_and_train_regressor;