use debot_db::ModelParams;
use smartcore::ensemble::random_forest_classifier::RandomForestClassifier;
use smartcore::ensemble::random_forest_regressor::RandomForestRegressor;
use smartcore::linalg::basic::matrix::DenseMatrix;
mod classifier;
mod regression;
pub use classifier::grid_search_and_train_classifier;
pub use regression::grid_search_and_train_regressor;
/// Pair of persisted random-forest models used together for prediction.
///
/// The classifier produces `i32` labels and the regressor produces `f64` values. Both
/// consume a `DenseMatrix<f64>` of features.
pub struct RandomForest {
    /// Classifier deserialized from the model stored as `"{key}_0"`.
    model_0: RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
    /// Regressor deserialized from the model stored as `"{key}_1"`.
    model_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>,
}
impl RandomForest {
    /// Loads the two persisted forests stored under `key` and builds a predictor from them.
    ///
    /// The classifier is read from the model named `"{key}_0"` and the regressor from the
    /// model named `"{key}_1"`. Each stored `model` byte buffer is decoded with
    /// `bincode::deserialize`.
    ///
    /// # Panics
    ///
    /// Panics if either model cannot be loaded from `model_params`. Also panics if a
    /// loaded byte buffer does not deserialize into the expected smartcore model type;
    /// that message includes the model name so the failing entry can be located.
    pub async fn new(key: &str, model_params: &ModelParams) -> Self {
        let name_0 = format!("{}_0", key);
        let serializable_model_0 = model_params
            .load_model(&name_0)
            .await
            .expect("Failed to load model 0");
        // Persisted bytes may be stale or corrupt, so report which entry failed
        // instead of a bare `unwrap` panic.
        let model_0: RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>> =
            bincode::deserialize(&serializable_model_0.model).unwrap_or_else(|e| {
                panic!("Failed to deserialize model 0 ({}): {}", name_0, e)
            });

        let name_1 = format!("{}_1", key);
        let serializable_model_1 = model_params
            .load_model(&name_1)
            .await
            .expect("Failed to load model 1");
        let model_1: RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> =
            bincode::deserialize(&serializable_model_1.model).unwrap_or_else(|e| {
                panic!("Failed to deserialize model 1 ({}): {}", name_1, e)
            });

        Self { model_0, model_1 }
    }

    /// Predicts a profitability label for each row of `x` using the classifier.
    ///
    /// Returns one `i32` label per input row, as produced by the classifier's `predict`.
    ///
    /// # Panics
    ///
    /// Panics if the classifier's `predict` returns an error.
    pub fn classify_profitability(&self, x: &DenseMatrix<f64>) -> Vec<i32> {
        let prediction = self
            .model_0
            .predict(x)
            .expect("Failed to predict profitability with model 0");
        log::trace!("predicted profitability: {:?}", prediction);
        prediction
    }

    /// Predicts a profit ratio for each row of `x` using the regressor.
    ///
    /// Returns one `f64` value per input row, as produced by the regressor's `predict`.
    ///
    /// # Panics
    ///
    /// Panics if the regressor's `predict` returns an error.
    pub fn regress_profit_ratio(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
        let prediction = self
            .model_1
            .predict(x)
            .expect("Failed to predict profit_ratio with model 1");
        log::trace!("predicted profit_ratio: {:?}", prediction);
        prediction
    }
}