use debot_db::{ModelParams, SerializableModel};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rayon::prelude::*;
use smartcore_proba::ensemble::random_forest_regressor::{
RandomForestRegressor, RandomForestRegressorParameters,
};
use smartcore_proba::linalg::basic::arrays::{Array, Array2};
use smartcore_proba::linalg::basic::matrix::DenseMatrix;
use smartcore_proba::metrics::mean_squared_error;
use std::cmp;
/// Builds a class-balanced copy of a regression data set.
///
/// Rows are grouped by their label in `y_class`. Every group is shuffled and then
/// brought to one common size: groups that are too small are padded with rows drawn
/// at random (with replacement) from the same group, groups that are too large are
/// cut down. The common size is the size of the largest group, capped by
/// `max_per_class` when that is provided.
///
/// Returns the selected feature rows together with the matching `y_reg` targets.
/// Groups are emitted in `HashMap` iteration order, so the order of the groups in
/// the output is unspecified.
///
/// # Panics
/// Panics if a row index taken from `y_class` is out of range for `y_reg`, or if
/// the matrix cannot be rebuilt from the selected rows.
fn balance_for_regression(
    x: &DenseMatrix<f64>,
    y_class: &[i32],
    y_reg: &[f64],
    max_per_class: Option<usize>,
) -> (DenseMatrix<f64>, Vec<f64>) {
    use rand::{seq::SliceRandom, thread_rng};
    use std::collections::HashMap;

    let mut rng = thread_rng();

    // Row indices of every class label.
    let mut by_class: HashMap<i32, Vec<usize>> = HashMap::new();
    for (row, &label) in y_class.iter().enumerate() {
        by_class.entry(label).or_insert_with(Vec::new).push(row);
    }

    // Target size per class: the largest class, optionally capped by the caller.
    let largest = by_class.values().map(Vec::len).max().unwrap_or(0);
    let target = match max_per_class {
        Some(cap) => cap.min(largest),
        None => largest,
    };

    let mut rows = Vec::new();
    let mut y_bal = Vec::new();
    for members in by_class.values() {
        let mut picked = members.clone();
        picked.shuffle(&mut rng);
        // Oversample with replacement until the class reaches the target size.
        while picked.len() < target {
            picked.push(*members.choose(&mut rng).unwrap());
        }
        // Undersample classes that exceed the target size.
        picked.truncate(target);

        for &row in &picked {
            rows.push(x.get_row(row).iterator(0).copied().collect::<Vec<f64>>());
            y_bal.push(y_reg[row]);
        }
    }

    let x_bal = DenseMatrix::from_2d_vec(&rows).expect("Matrix reconstruction failed");
    (x_bal, y_bal)
}
/// Balances the data, selects random-forest hyper-parameters, fits and stores the
/// final regressor, and logs a sanity check on the selected score.
///
/// Steps:
/// 1. `balance_for_regression` resamples the rows so that every label in `y_class`
///    contributes the same number of rows.
/// 2. `staged_grid_search` picks the parameters with the lowest `k`-fold CV MSE.
/// 3. A `RandomForestRegressor` is fitted on the whole balanced set, serialized
///    with bincode and stored through `model_params.save_model` under the name
///    `"{key}_{model_suffix}"`.
/// 4. The chosen parameters are cross-validated once more and the fresh MSE is
///    compared with the MSE reported by the search.
///
/// NOTE(review): the search and the fit are CPU-bound and run synchronously inside
/// this `async fn`, before the only `.await`.
///
/// # Panics
/// Panics if the model cannot be fitted, serialized or saved.
pub async fn grid_search_and_train_regressor(
    key: &str,
    model_params: &ModelParams,
    x: DenseMatrix<f64>,
    y_reg: Vec<f64>,
    y_class: Vec<i32>,
    k: usize,
    max_per_class: Option<usize>,
    model_suffix: usize,
) {
    let (x_bal, y_bal) = balance_for_regression(&x, &y_class, &y_reg, max_per_class);
    let (best_params, best_score) = staged_grid_search(&x_bal, &y_bal, k);

    let model = RandomForestRegressor::fit(&x_bal, &y_bal, best_params.clone())
        .expect("Failed to fit RandomForestRegressor");
    let serialized = bincode::serialize(&model).expect("Failed to serialize regressor model");
    let serializable = SerializableModel { model: serialized };
    model_params
        .save_model(&format!("{}_{}", key, model_suffix), &serializable)
        .await
        .expect("Failed to save regressor model");

    let avg_cv_score = cross_validate(&x_bal, &y_bal, k, &best_params);
    log::info!(
        "Average CV score (MSE) on balanced data = {:.8}",
        avg_cv_score
    );

    // Both scores are mean squared errors, so lower is better. The search score is
    // the minimum over many evaluations and is therefore optimistic; overfitting
    // shows up as the fresh CV error being HIGHER than the score that won the
    // search, not lower.
    // NOTE(review): the two scores come from differently shuffled folds, so part of
    // the gap is resampling noise; confirm the 1 % threshold is appropriate.
    let delta = avg_cv_score - best_score;
    let relative_delta = if best_score > 0.0 {
        delta / best_score
    } else if delta > 0.0 {
        // A perfect search score with a non-zero re-evaluation: unbounded relative gap.
        f64::INFINITY
    } else {
        0.0
    };
    const OVERFITTING_RELATIVE_THRESHOLD: f64 = 0.01;
    if relative_delta > OVERFITTING_RELATIVE_THRESHOLD {
        log::warn!(
            "Overfitting suspected: grid-search MSE ({:.8}) < re-evaluated CV MSE ({:.8}); relative diff = {:.4}%",
            best_score,
            avg_cv_score,
            relative_delta * 100.0
        );
    } else {
        log::info!(
            "No significant overfitting: grid-search MSE ({:.8}), re-evaluated CV MSE ({:.8}), relative diff = {:.4}%",
            best_score,
            avg_cv_score,
            relative_delta * 100.0
        );
    }
    log::info!(
        "Final best CV score (MSE) on balanced data: {:.8}",
        best_score
    );
}
/// Runs the hyper-parameter search on random subsamples of 50 %, 70 % and 100 % of
/// the rows and returns the best parameters seen in any stage.
///
/// For the two partial samples the best quick-search candidate is used as is. For
/// the full sample the five best quick-search candidates are cross-validated once
/// more (each call to `cross_validate` reshuffles the folds) and the best of those
/// re-evaluations is used. The overall winner is the stage result with the lowest
/// MSE.
///
/// Returns the winning parameters and their MSE.
///
/// # Panics
/// Panics if no candidate produces an MSE below `f64::INFINITY`.
fn staged_grid_search(
    x: &DenseMatrix<f64>,
    y: &Vec<f64>,
    k: usize,
) -> (RandomForestRegressorParameters, f64) {
    let n = x.shape().0;
    let sample_sizes = [(n as f64 * 0.5) as usize, (n as f64 * 0.7) as usize, n];

    let mut best_params: Option<RandomForestRegressorParameters> = None;
    let mut best_score = f64::INFINITY;

    for &size in &sample_sizes {
        let (x_sub, y_sub) = sample_data(x, y, size);
        log::info!("--- grid search with sample size = {} / {} ---", size, n);

        let (params, score) = if size < n {
            quick_grid_search(&x_sub, &y_sub, k)
        } else {
            let mut candidates = quick_grid_search_candidates(&x_sub, &y_sub, k);
            // Ascending by MSE. NaN scores are ordered last instead of panicking,
            // which also keeps the comparator a total order as `sort_by` requires.
            candidates.sort_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            });

            let mut local_best: Option<RandomForestRegressorParameters> = None;
            let mut local_score = f64::INFINITY;
            // Re-evaluate only the most promising candidates on the full data.
            for (p, _) in candidates.into_iter().take(5) {
                let mse = cross_validate(&x_sub, &y_sub, k, &p);
                if mse < local_score {
                    local_score = mse;
                    local_best = Some(p);
                }
            }
            (local_best.expect("No top candidate found"), local_score)
        };

        log::info!("Sample size {} -> best MSE = {:.8}", size, score);
        if score < best_score {
            best_score = score;
            best_params = Some(params);
        }
    }

    (best_params.expect("No params found"), best_score)
}
/// Cross-validates a small grid of random-forest configurations and returns every
/// configuration that was evaluated, paired with its mean CV MSE.
///
/// The grid is the product of `compute_m` (outer loop) and `compute_n_trees`
/// (inner loop) for the number of feature columns in `x`. All other parameters are
/// fixed: unlimited depth, `min_samples_split = 2`, `min_samples_leaf = 1`,
/// `seed = 42`, `keep_samples = false`.
///
/// The search stops early once five consecutive candidates fail to improve the
/// best MSE so far by more than `1e-6`. Candidates are returned in evaluation
/// order, not sorted by score.
fn quick_grid_search_candidates(
    x: &DenseMatrix<f64>,
    y: &Vec<f64>,
    k: usize,
) -> Vec<(RandomForestRegressorParameters, f64)> {
    /// Number of consecutive non-improving candidates tolerated before stopping.
    const EARLY_STOP_PATIENCE: usize = 5;
    /// Minimum MSE decrease that counts as an improvement.
    const MIN_IMPROVEMENT: f64 = 1e-6;

    let n_features = x.shape().1;
    let m_values = compute_m(n_features);
    let n_trees_values = compute_n_trees(n_features);

    let mut candidates = Vec::new();
    let mut best = f64::INFINITY;
    let mut no_imp: usize = 0;

    'outer: for &m in &m_values {
        for &n_trees in &n_trees_values {
            let params = RandomForestRegressorParameters {
                n_trees,
                m: Some(m),
                max_depth: None,
                min_samples_split: 2,
                min_samples_leaf: 1,
                seed: 42,
                keep_samples: false,
            };
            let mse = cross_validate(x, y, k, &params);
            log::info!("Reg quick candidate: {:?}, MSE: {:.8}", params, mse);
            // `params` is not needed after this point, so it is moved, not cloned.
            candidates.push((params, mse));

            if mse < best - MIN_IMPROVEMENT {
                best = mse;
                no_imp = 0;
            } else {
                no_imp += 1;
            }
            if no_imp >= EARLY_STOP_PATIENCE {
                break 'outer;
            }
        }
    }
    candidates
}
/// Runs the quick grid search and returns the single candidate with the lowest
/// mean CV MSE, together with that MSE.
///
/// Ties are resolved in favour of the candidate that was evaluated first. A NaN
/// score is treated as worse than any number, so it is only returned when every
/// candidate scored NaN.
///
/// # Panics
/// Panics if `quick_grid_search_candidates` returns no candidates.
fn quick_grid_search(
    x: &DenseMatrix<f64>,
    y: &Vec<f64>,
    k: usize,
) -> (RandomForestRegressorParameters, f64) {
    quick_grid_search_candidates(x, y, k)
        .into_iter()
        // Only the minimum is needed, so a full sort is unnecessary. `min_by`
        // returns the first of several equal minima, matching a stable sort.
        .min_by(|a, b| {
            a.1.partial_cmp(&b.1)
                .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
        })
        .expect("quick grid search evaluated no candidates")
}
/// Estimates the mean squared error of a random forest with `params` by `k`-fold
/// cross-validation on `(x, y)`.
///
/// The row indices are shuffled once, then split into `k` contiguous folds of
/// `n / k` rows; the last fold also takes the remainder. For each fold a model is
/// fitted on all rows outside the fold and scored on the rows inside it. The folds
/// are evaluated in parallel and the per-fold MSEs are averaged.
///
/// # Panics
/// Panics unless `2 <= k <= n` (where `n` is the number of rows in `x`), or if a
/// fold matrix cannot be built, or if fitting or prediction fails.
fn cross_validate(
    x: &DenseMatrix<f64>,
    y: &Vec<f64>,
    k: usize,
    params: &RandomForestRegressorParameters,
) -> f64 {
    /// Copies the given rows of `x` into a new matrix.
    fn take_rows(x: &DenseMatrix<f64>, rows: &[usize]) -> DenseMatrix<f64> {
        let data = rows
            .iter()
            .map(|&r| x.get_row(r).iterator(0).copied().collect::<Vec<f64>>())
            .collect::<Vec<_>>();
        DenseMatrix::from_2d_vec(&data).expect("Failed to build cross-validation fold matrix")
    }

    let n = x.shape().0;
    // With k < 2 there is no training data, with k > n some folds are empty.
    assert!(
        k >= 2 && k <= n,
        "cross_validate requires 2 <= k <= n (k = {}, n = {})",
        k,
        n
    );

    let mut idx: Vec<usize> = (0..n).collect();
    idx.shuffle(&mut thread_rng());
    let fold = n / k;

    let mses: Vec<f64> = (0..k)
        .into_par_iter()
        .map(|i| {
            let start = i * fold;
            let end = if i == k - 1 { n } else { (i + 1) * fold };

            // `start..end` are POSITIONS in the shuffled index. The training set
            // must be the positional complement of the test fold; filtering on the
            // row-index values instead would let test rows leak into training.
            let test_idx = &idx[start..end];
            let train_idx: Vec<usize> = idx[..start]
                .iter()
                .chain(&idx[end..])
                .copied()
                .collect();

            let x_tr = take_rows(x, &train_idx);
            let y_tr: Vec<f64> = train_idx.iter().map(|&r| y[r]).collect();
            let x_te = take_rows(x, test_idx);
            let y_te: Vec<f64> = test_idx.iter().map(|&r| y[r]).collect();

            let model = RandomForestRegressor::fit(&x_tr, &y_tr, params.clone()).unwrap();
            let y_pred = model.predict(&x_te).unwrap();
            mean_squared_error(&y_te, &y_pred)
        })
        .collect();

    mses.iter().sum::<f64>() / (mses.len() as f64)
}
fn sample_data(x: &DenseMatrix<f64>, y: &Vec<f64>, size: usize) -> (DenseMatrix<f64>, Vec<f64>) {
let total = y.len();
let mut idx: Vec<usize> = (0..total).collect();
idx.shuffle(&mut thread_rng());
let sel = &idx[..cmp::min(size, total)];
let x_sub = DenseMatrix::from_2d_vec(
&sel.iter()
.map(|&r| x.get_row(r).iterator(0).copied().collect::<Vec<f64>>())
.collect::<Vec<_>>(),
)
.unwrap();
let y_sub = sel.iter().map(|&r| y[r]).collect();
(x_sub, y_sub)
}
/// Returns the candidate values for `m`, the number of features considered at
/// each split, for a data set with `n_features` columns.
///
/// The candidates are `round(sqrt(n))`, `round(log2(n))` and `n / 2`, in that
/// order. Each one is clamped to `1..=max(n_features, 1)` so that a value of 0
/// (e.g. `log2(1)`) is never returned, and repeated values are dropped so the grid
/// search does not evaluate the same configuration more than once.
fn compute_m(n_features: usize) -> Vec<usize> {
    let sqrt_m = (n_features as f64).sqrt().round() as usize;
    let log2_m = (n_features as f64).log2().round() as usize;
    let half_m = cmp::max(1, n_features / 2);

    let upper = cmp::max(1, n_features);
    let mut values = Vec::with_capacity(3);
    for raw in [sqrt_m, log2_m, half_m] {
        let m = raw.clamp(1, upper);
        // At most three entries, so a linear `contains` check is enough.
        if !values.contains(&m) {
            values.push(m);
        }
    }
    values
}
/// Returns the candidate forest sizes for a data set with `n_features` columns.
///
/// The smallest candidate is `round(10 * sqrt(n_features))`; the other two add 50
/// and 100 trees to it.
fn compute_n_trees(n_features: usize) -> Vec<usize> {
    let root = (n_features as f64).sqrt();
    let base = (root * 10.0).round() as usize;
    [0usize, 50, 100]
        .iter()
        .map(|&extra| base + extra)
        .collect()
}