debot_ml 3.0.7

ML prediction
Documentation
// src/lib.rs

#[cfg(feature = "classification")]
mod classifier;

mod regressor;

use debot_db::{ModelParams, SerializableModel};
use smartcore_proba::ensemble::random_forest_regressor::RandomForestRegressor;
use smartcore_proba::linalg::basic::arrays::Array2;
use smartcore_proba::linalg::basic::matrix::DenseMatrix;

#[cfg(feature = "classification")]
use classifier::ModelWithWeights;
#[cfg(feature = "classification")]
use smartcore_proba::ensemble::random_forest_classifier::RandomForestClassifier;

/// Unified RandomForest wrapper supporting classification and regression.
///
/// Bundles the primary model (a [`RandomForestModel`]) with two auxiliary
/// regressors that back [`RandomForest::predict_timeout`].
pub struct RandomForest {
    /// Primary model; which variant is stored depends on the enabled cargo features.
    model: RandomForestModel,

    /// Auxiliary regressors consulted (in order) by `predict_timeout`.
    extra_regressors: [RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>; 2],

    /// Weights for expected-score calculation (loss, expired-profit, take-profit).
    #[cfg(feature = "classification")]
    weights: (f64, f64, f64),
}

impl RandomForest {
    /// Returns the weights that [`RandomForest::predict_proba`] applies to the
    /// first three class probabilities, in the order
    /// (loss, expired-profit, take-profit).
    #[cfg(feature = "classification")]
    pub fn weights(&self) -> (f64, f64, f64) {
        self.weights
    }

    /// Loads the primary model and the two extra regressors from storage.
    ///
    /// Entries are read from `model_params` under `"{key}_0"` (primary model),
    /// `"{key}_1"` and `"{key}_2"` (the regressors used by
    /// [`RandomForest::predict_timeout`]), and decoded with `bincode`.
    ///
    /// With the `classification` feature, `"{key}_0"` is decoded as a
    /// `ModelWithWeights` whose `model` and `weights` are stored on `self`.
    /// With only the `regression` feature, it is decoded as a
    /// `RandomForestRegressor`. If both features are enabled, classification
    /// takes precedence.
    ///
    /// # Panics
    ///
    /// Panics if any entry cannot be loaded or deserialized, or if neither the
    /// `classification` nor the `regression` feature is enabled.
    pub async fn new(key: &str, model_params: &ModelParams) -> Self {
        #[cfg(feature = "classification")]
        {
            let serial: SerializableModel = model_params
                .load_model(&format!("{}_0", key))
                .await
                .expect("Failed to load classifier model");

            let mw: ModelWithWeights = bincode::deserialize(&serial.model)
                .expect("Failed to deserialize ModelWithWeights");

            let extra_regressors = Self::load_extra_regressors(key, model_params).await;

            // `return` (rather than a tail expression) keeps the cfg-gated
            // branches below syntactically valid in every feature combination.
            return RandomForest {
                model: RandomForestModel::Classifier(mw.model),
                weights: mw.weights,
                extra_regressors,
            };
        }

        #[cfg(all(feature = "regression", not(feature = "classification")))]
        {
            let model = Self::load_regressor(key, 0, model_params).await;
            let extra_regressors = Self::load_extra_regressors(key, model_params).await;

            return RandomForest {
                model: RandomForestModel::Regressor(model),
                extra_regressors,
            };
        }

        #[cfg(not(any(feature = "classification", feature = "regression")))]
        panic!("Either 'classification' or 'regression' feature must be enabled");
    }

    /// Loads the two extra regressors stored under `"{key}_1"` and `"{key}_2"`.
    ///
    /// # Panics
    ///
    /// Panics if either entry cannot be loaded or deserialized.
    #[cfg(any(feature = "classification", feature = "regression"))]
    async fn load_extra_regressors(
        key: &str,
        model_params: &ModelParams,
    ) -> [RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>; 2] {
        let first = Self::load_regressor(key, 1, model_params).await;
        let second = Self::load_regressor(key, 2, model_params).await;
        [first, second]
    }

    /// Loads the entry `"{key}_{index}"` and decodes it with `bincode` as a
    /// `RandomForestRegressor`.
    ///
    /// # Panics
    ///
    /// Panics with a message naming `index` if the entry cannot be loaded or
    /// deserialized.
    #[cfg(any(feature = "classification", feature = "regression"))]
    async fn load_regressor(
        key: &str,
        index: usize,
        model_params: &ModelParams,
    ) -> RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>> {
        // Messages are built up front so `expect` can report which model failed.
        let load_msg = format!("Failed to load regression model {}", index);
        let decode_msg = format!("Failed to deserialize regression model {}", index);

        let serial: SerializableModel = model_params
            .load_model(&format!("{}_{}", key, index))
            .await
            .expect(&load_msg);

        bincode::deserialize(&serial.model).expect(&decode_msg)
    }

    /// Predicts raw outputs for every row of `x`.
    ///
    /// For a classifier, these are the predicted class labels converted to
    /// `f64`. For a regressor, they are the predicted continuous values.
    ///
    /// # Panics
    ///
    /// Panics if the underlying model's prediction fails.
    pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<f64> {
        match &self.model {
            #[cfg(feature = "classification")]
            RandomForestModel::Classifier(clf) => clf
                .predict(x)
                .expect("Classifier prediction failed")
                .into_iter()
                .map(f64::from)
                .collect(),

            #[cfg(feature = "regression")]
            RandomForestModel::Regressor(reg) => {
                reg.predict(x).expect("Regressor prediction failed")
            }

            // Only reachable when no model variant is compiled in; the allow
            // silences the lint in configurations where the arm is redundant.
            #[allow(unreachable_patterns)]
            _ => unreachable!(),
        }
    }

    /// Predicts class probabilities and an expected score.
    ///
    /// Returns the full probability matrix from the classifier, together with
    /// `w0 * p0 + w1 * p1 + w2 * p2`. Here `(w0, w1, w2)` are
    /// [`RandomForest::weights`] and `p0..p2` are the first three probabilities
    /// of the FIRST row of `x`. Any further rows and columns do not affect the
    /// score.
    ///
    /// # Panics
    ///
    /// Panics if probability prediction fails, or if the classifier yields
    /// fewer than three probability columns. It presumably also panics if `x`
    /// has no rows, via `get_row(0)` (not verified against smartcore).
    #[cfg(feature = "classification")]
    pub fn predict_proba(&self, x: &DenseMatrix<f64>) -> (DenseMatrix<f64>, f64) {
        let clf = match &self.model {
            RandomForestModel::Classifier(clf) => clf,
            #[cfg(feature = "regression")]
            RandomForestModel::Regressor(_) => unreachable!("predict_proba on regression model"),
        };
        let probs: DenseMatrix<f64> = clf
            .predict_proba(x)
            .expect("Classifier probability prediction failed");
        log::trace!("Predicted probabilities: {:?}", probs);

        let row = probs.get_row(0).iterator(0).copied().collect::<Vec<f64>>();
        // A classifier trained without one of the classes returns fewer columns.
        // Fail with a clear message instead of an opaque index-out-of-bounds panic.
        assert!(
            row.len() >= 3,
            "expected at least 3 class probabilities to compute the expected score, got {}",
            row.len()
        );

        let (w0, w1, w2) = self.weights;
        let exp_score = w0 * row[0] + w1 * row[1] + w2 * row[2];
        (probs, exp_score)
    }

    /// Predicts timeout durations from the two extra regressors.
    ///
    /// Returns each regressor's prediction for the FIRST row of `x`, as
    /// `(first_regressor, second_regressor)`.
    ///
    /// # Panics
    ///
    /// Panics if either regressor's prediction fails, or if it returns no
    /// values (e.g. `x` has no rows).
    pub fn predict_timeout(&self, x: &DenseMatrix<f64>) -> (f64, f64) {
        let [first, second] = &self.extra_regressors;
        let pred1 = first
            .predict(x)
            .expect("Timeout regressor 1 prediction failed");
        let pred2 = second
            .predict(x)
            .expect("Timeout regressor 2 prediction failed");
        (pred1[0], pred2[0])
    }
}

/// Underlying model stored inside a [`RandomForest`].
///
/// The available variants depend on the enabled cargo features:
/// `classification` adds `Classifier`, and `regression` adds `Regressor`.
pub enum RandomForestModel {
    /// Random-forest classifier over `f64` features with `i32` class labels.
    #[cfg(feature = "classification")]
    Classifier(RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>),
    /// Random-forest regressor over `f64` features with `f64` targets.
    #[cfg(feature = "regression")]
    Regressor(RandomForestRegressor<f64, f64, DenseMatrix<f64>, Vec<f64>>),
}

// Re-export training functions
#[cfg(feature = "classification")]
pub use classifier::grid_search_and_train_classifier;
#[cfg(feature = "classification")]
pub use classifier::Metric::*;

pub use regressor::grid_search_and_train_regressor;