use debot_db::{ModelParams, SerializableModel};
use rand::seq::SliceRandom;
use rand::thread_rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use smartcore_proba::ensemble::random_forest_classifier::{
RandomForestClassifier, RandomForestClassifierParameters,
};
use smartcore_proba::linalg::basic::arrays::Array;
use smartcore_proba::linalg::basic::arrays::Array2;
use smartcore_proba::linalg::basic::matrix::DenseMatrix;
use smartcore_proba::tree::decision_tree_classifier::SplitCriterion;
use std::collections::HashMap;
/// Objective used to rank hyper-parameter candidates during grid search.
#[derive(Clone, Copy)]
pub enum Metric {
/// Rank candidates by the cross-validated weighted expected score (higher is better).
ExpectedScore,
/// Rank candidates by cross-validated mean squared error; the search negates it so
/// that "larger is better" holds for both metrics.
MSE,
}
/// A fitted random-forest classifier bundled with the per-class score weights it was
/// trained/evaluated with, so both can be serialized together.
///
/// Only compiled when the `classification` feature is enabled.
#[cfg(feature = "classification")]
#[derive(Serialize, Deserialize)]
pub struct ModelWithWeights {
/// The fitted forest (f64 features, i32 class labels).
pub model: RandomForestClassifier<f64, i32, DenseMatrix<f64>, Vec<i32>>,
/// Score weights; `grid_search_and_train_classifier` stores `(w_loss, w_expired, w_take)` here.
pub weights: (f64, f64, f64),
}
/// Returns the weighted sum of the first three class probabilities in `proba_row`:
/// `w0 * p[0] + w1 * p[1] + w2 * p[2]`. Any further elements are ignored.
///
/// # Panics
/// Panics if `proba_row` holds fewer than three values.
fn expected_score_from_proba(proba_row: &[f64], w0: f64, w1: f64, w2: f64) -> f64 {
    let p0 = proba_row[0];
    let p1 = proba_row[1];
    let p2 = proba_row[2];
    w0 * p0 + w1 * p1 + w2 * p2
}
/// Resamples `(x, y)` so that every class contributes the same number of rows.
///
/// The per-class target is the size of the largest class, optionally capped by
/// `max_per_class`. Classes larger than the target are randomly down-sampled;
/// smaller classes are padded with rows drawn (with replacement) from themselves.
/// Rows are emitted class by class, in `HashMap` iteration order.
///
/// # Panics
/// Panics if `DenseMatrix::from_2d_vec` rejects the collected rows.
fn balance_classes(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    max_per_class: Option<usize>,
) -> (DenseMatrix<f64>, Vec<i32>) {
    let mut by_class: HashMap<i32, Vec<usize>> = HashMap::new();
    let mut rng = thread_rng();

    // Group row positions by their label.
    for (row, &label) in y.iter().enumerate() {
        by_class.entry(label).or_default().push(row);
    }

    let largest = by_class.values().map(Vec::len).max().unwrap_or(0);
    let target = match max_per_class {
        Some(cap) => cap.min(largest),
        None => largest,
    };

    let mut rows: Vec<Vec<f64>> = Vec::new();
    let mut labels: Vec<i32> = Vec::new();
    for (&label, members) in &by_class {
        let mut picked = members.clone();
        picked.shuffle(&mut rng);
        // Over-sample minority classes up to the target size ...
        while picked.len() < target {
            picked.push(*members.choose(&mut rng).unwrap());
        }
        // ... and cut majority classes down to it.
        picked.truncate(target);

        for &row in &picked {
            rows.push(x.get_row(row).iterator(0).copied().collect::<Vec<f64>>());
            labels.push(label);
        }
    }

    let balanced = DenseMatrix::from_2d_vec(&rows).unwrap();
    (balanced, labels)
}
/// Estimates the mean weighted expected score of a random forest with `params`
/// using shuffled `k`-fold cross-validation.
///
/// Rows are shuffled once, split into `k` contiguous folds of the shuffled order
/// (the last fold absorbs the remainder), and the folds are evaluated in parallel.
/// For each fold a forest is fitted on the remaining rows, the predicted class
/// probabilities of the held-out rows are converted to an expected score via
/// `expected_score_from_proba`, and the per-fold means are averaged.
///
/// # Panics
/// Panics if `k == 0`, or if building a matrix, fitting, or predicting fails.
fn cross_validate_expected_score(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    k: usize,
    params: &RandomForestClassifierParameters,
    weights: (f64, f64, f64),
) -> f64 {
    assert!(k > 0, "cross-validation requires at least one fold");
    let (w0, w1, w2) = weights;
    let n = x.shape().0;
    let mut idx: Vec<usize> = (0..n).collect();
    idx.shuffle(&mut thread_rng());
    let fold = n / k;
    let scores: Vec<f64> = (0..k)
        .into_par_iter()
        .map(|i| {
            let start = i * fold;
            let end = if i == k - 1 { n } else { (i + 1) * fold };
            let valid = &idx[start..end];
            // Split by *position* in the shuffled order. Filtering on the row-id
            // value instead would leave most validation rows inside the training
            // set and leak held-out data into the fit.
            let train: Vec<usize> = idx[..start]
                .iter()
                .chain(idx[end..].iter())
                .copied()
                .collect();
            let x_tr = DenseMatrix::from_2d_vec(
                &train
                    .iter()
                    .map(|&r| x.get_row(r).iterator(0).copied().collect())
                    .collect::<Vec<_>>(),
            )
            .unwrap();
            let y_tr = train.iter().map(|&r| y[r]).collect::<Vec<_>>();
            let x_va = DenseMatrix::from_2d_vec(
                &valid
                    .iter()
                    .map(|&r| x.get_row(r).iterator(0).copied().collect())
                    .collect::<Vec<_>>(),
            )
            .unwrap();
            let clf = RandomForestClassifier::fit(&x_tr, &y_tr, params.clone()).unwrap();
            let proba: DenseMatrix<f64> = clf.predict_proba(&x_va).unwrap();
            let n_valid = proba.shape().0;
            let mut sum = 0.0;
            for r in 0..n_valid {
                let row = proba.get_row(r).iterator(0).copied().collect::<Vec<f64>>();
                sum += expected_score_from_proba(&row, w0, w1, w2);
            }
            sum / (n_valid as f64)
        })
        .collect();
    scores.iter().sum::<f64>() / (k as f64)
}
/// Estimates, via shuffled `k`-fold cross-validation, the mean squared error between
/// the predicted expected score and the score of the true class.
///
/// Rows are shuffled once, split into `k` contiguous folds of the shuffled order
/// (the last fold absorbs the remainder), and the folds are evaluated in parallel.
/// For each held-out row the prediction is `expected_score_from_proba` of its class
/// probabilities and the target is the weight of its true label
/// (`0 -> w0`, `1 -> w1`, `2 -> w2`). Per-fold MSEs are averaged.
///
/// # Panics
/// Panics if `k == 0`, if a label outside `0..=2` appears in a validation fold, or
/// if building a matrix, fitting, or predicting fails.
fn cross_validate_mse(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    k: usize,
    params: &RandomForestClassifierParameters,
    weights: (f64, f64, f64),
) -> f64 {
    assert!(k > 0, "cross-validation requires at least one fold");
    let (w0, w1, w2) = weights;
    let n = x.shape().0;
    let mut idx: Vec<usize> = (0..n).collect();
    idx.shuffle(&mut thread_rng());
    let fold = n / k;
    let mses: Vec<f64> = (0..k)
        .into_par_iter()
        .map(|i| {
            let start = i * fold;
            let end = if i == k - 1 { n } else { (i + 1) * fold };
            let valid = &idx[start..end];
            // Split by *position* in the shuffled order. Filtering on the row-id
            // value instead would leave most validation rows inside the training
            // set and leak held-out data into the fit.
            let train: Vec<usize> = idx[..start]
                .iter()
                .chain(idx[end..].iter())
                .copied()
                .collect();
            let x_tr = DenseMatrix::from_2d_vec(
                &train
                    .iter()
                    .map(|&r| x.get_row(r).iterator(0).copied().collect())
                    .collect::<Vec<_>>(),
            )
            .unwrap();
            let y_tr = train.iter().map(|&r| y[r]).collect::<Vec<_>>();
            let x_va = DenseMatrix::from_2d_vec(
                &valid
                    .iter()
                    .map(|&r| x.get_row(r).iterator(0).copied().collect())
                    .collect::<Vec<_>>(),
            )
            .unwrap();
            let y_va: Vec<i32> = valid.iter().map(|&r| y[r]).collect();
            let clf = RandomForestClassifier::fit(&x_tr, &y_tr, params.clone()).unwrap();
            let proba: DenseMatrix<f64> = clf.predict_proba(&x_va).unwrap();
            let mut sum_sq = 0.0;
            for (r, &true_label) in y_va.iter().enumerate() {
                let row = proba.get_row(r).iterator(0).copied().collect::<Vec<f64>>();
                let pred_score = expected_score_from_proba(&row, w0, w1, w2);
                let real_score = match true_label {
                    0 => w0,
                    1 => w1,
                    2 => w2,
                    _ => unreachable!(),
                };
                sum_sq += (pred_score - real_score).powi(2);
            }
            sum_sq / (y_va.len() as f64)
        })
        .collect();
    mses.iter().sum::<f64>() / (k as f64)
}
/// Balances the training data, runs the staged grid search, fits a final random
/// forest with the winning parameters, and persists it through `model_params`.
///
/// The serialized [`ModelWithWeights`] is saved under the name
/// `"{key}_{model_suffix}"`. If every label in `y` is identical (or `y` is empty)
/// an error is logged and nothing is trained or saved.
///
/// * `k` - number of cross-validation folds.
/// * `max_per_class` - optional cap on rows per class after balancing.
/// * `w_loss`, `w_expired`, `w_take` - score weights for classes `0`, `1`, `2`.
/// * `metric` - objective the grid search optimizes.
///
/// All training work runs synchronously inside this `async fn`; only the final
/// `save_model` call is awaited.
///
/// # Panics
/// Panics if fitting, serializing, or saving the model fails.
// NOTE(review): `ModelWithWeights` is gated behind the `classification` feature while
// this function is not — confirm the enclosing module is gated as well.
pub async fn grid_search_and_train_classifier(
    key: &str,
    model_params: &ModelParams,
    x: DenseMatrix<f64>,
    y: Vec<i32>,
    k: usize,
    max_per_class: Option<usize>,
    w_loss: f64,
    w_expired: f64,
    w_take: f64,
    metric: Metric,
    model_suffix: usize,
) {
    if y.iter().all(|&l| l == y[0]) {
        log::error!("Training data contains only one class");
        return;
    }

    let weights = (w_loss, w_expired, w_take);
    let (x_bal, y_bal) = balance_classes(&x, &y, max_per_class);
    let (best_params, best_score) = staged_grid_search(&x_bal, &y_bal, k, weights, metric);
    let avg_exp_score = cross_validate_expected_score(&x_bal, &y_bal, k, &best_params, weights);

    let model = RandomForestClassifier::fit(&x_bal, &y_bal, best_params.clone())
        .expect("failed to fit the final random forest");
    let mw = ModelWithWeights { model, weights };
    let serial = bincode::serialize(&mw).expect("failed to serialize the trained model");
    // Captured before `serial` is moved into the saved payload.
    let model_size_bytes = serial.len() as f64;
    model_params
        .save_model(
            &format!("{}_{}", key, model_suffix),
            &SerializableModel { model: serial },
        )
        .await
        .expect("failed to save the trained model");

    match metric {
        Metric::ExpectedScore => {
            log::info!("Final best expected_score = {:.8}", best_score);
        }
        Metric::MSE => {
            // The search maximizes the negated MSE; flip the sign back for display.
            log::info!("Final best MSE = {:.8}", -best_score);
        }
    }
    log::info!("Average expected_score (CV) = {:.8}", avg_exp_score);
    log::info!(
        "Model size = {:.2} MB, n_trees = {}, m = {:?}, min_samples_split = {}, min_samples_leaf = {}",
        model_size_bytes / 1_048_576.0,
        best_params.n_trees,
        best_params.m,
        best_params.min_samples_split,
        best_params.min_samples_leaf
    );

    match metric {
        Metric::ExpectedScore => {
            let delta = (avg_exp_score - best_score).abs();
            if delta > 0.0005 {
                log::warn!(
                    "Possible overfitting detected: avg_cv_score ({:.8}) and best_train_score ({:.8}) differ by {:.8}",
                    avg_exp_score,
                    best_score,
                    delta
                );
            } else {
                log::info!(
                    "No significant overfitting detected: avg_cv_score ({:.8}), best_train_score ({:.8})",
                    avg_exp_score,
                    best_score
                );
            }
        }
        Metric::MSE => {
            // `best_score` is a negated MSE here, which is not on the same scale as
            // an expected score; comparing the two would only yield spurious warnings.
            log::info!("Overfitting check skipped: the search score is an MSE, not an expected score");
        }
    }
}
/// Runs the grid search on progressively larger random subsamples (50 %, 70 %, 100 %
/// of the rows) and returns the best parameters together with their score.
///
/// On the partial stages the winner of `quick_grid_search` is taken directly. On the
/// full-size stage the five best candidates are re-scored with a fresh
/// cross-validation and the best of those is used. Scores are "larger is better":
/// the expected score itself, or the negated MSE. The loop stops early after two
/// stages that fail to improve on the best score seen so far.
///
/// # Panics
/// Panics if two candidate scores cannot be ordered (NaN).
fn staged_grid_search(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    k: usize,
    weights: (f64, f64, f64),
    metric: Metric,
) -> (RandomForestClassifierParameters, f64) {
    const EARLY_STOPPING_ROUNDS: usize = 2;

    let total = x.shape().0;
    let stage_sizes = [
        (total as f64 * 0.5) as usize,
        (total as f64 * 0.7) as usize,
        total,
    ];

    let mut best_params = params_default();
    let mut best_score = f64::NEG_INFINITY;
    let mut stale_rounds: usize = 0;

    for &stage in &stage_sizes {
        let (x_stage, y_stage) = sample_data(x, y, stage);
        log::info!("Grid search on size {}/{}", stage, total);

        let (stage_params, stage_score) = if stage >= total {
            // Full data: rank all candidates, then re-score the top five.
            let mut ranked = quick_grid_search_candidates(&x_stage, &y_stage, k, weights, metric);
            ranked.sort_by(|(_, lhs), (_, rhs)| rhs.partial_cmp(lhs).unwrap());

            let mut top_params = params_default();
            let mut top_score = f64::NEG_INFINITY;
            for (candidate, _) in ranked.into_iter().take(5) {
                let rescored = match metric {
                    Metric::ExpectedScore => {
                        cross_validate_expected_score(&x_stage, &y_stage, k, &candidate, weights)
                    }
                    Metric::MSE => -cross_validate_mse(&x_stage, &y_stage, k, &candidate, weights),
                };
                if rescored > top_score {
                    top_score = rescored;
                    top_params = candidate;
                }
            }
            (top_params, top_score)
        } else {
            quick_grid_search(&x_stage, &y_stage, k, weights, metric)
        };
        log::info!("Size {} -> score {:.8}", stage, stage_score);

        if stage_score > best_score {
            best_score = stage_score;
            best_params = stage_params;
            stale_rounds = 0;
            continue;
        }

        stale_rounds += 1;
        if stale_rounds >= EARLY_STOPPING_ROUNDS {
            log::warn!(
                "Early stopping triggered after {} no-improve rounds.",
                EARLY_STOPPING_ROUNDS
            );
            break;
        }
    }

    (best_params, best_score)
}
/// Evaluates a small hyper-parameter grid with cross-validation and returns every
/// evaluated `(parameters, score)` pair in evaluation order.
///
/// The grid spans the `m` values from `compute_m`, the tree counts from
/// `compute_n_trees`, both split criteria, `min_samples_leaf` in `{1, 2}` and
/// `min_samples_split` in `{2, 5}`. Scores are "larger is better": the expected
/// score itself, or the negated MSE. The whole search stops as soon as five
/// consecutive candidates fail to beat the best score seen so far.
fn quick_grid_search_candidates(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    k: usize,
    weights: (f64, f64, f64),
    metric: Metric,
) -> Vec<(RandomForestClassifierParameters, f64)> {
    let crits = [SplitCriterion::Gini, SplitCriterion::Entropy];
    let leafs = [1, 2];
    let splits = [2, 5];
    let mvals = compute_m(x.shape().1);
    let ntree = compute_n_trees(x.shape().1);
    let mut c = Vec::new();
    let mut best = f64::NEG_INFINITY;
    let mut noimp = 0;
    'o: for &m in &mvals {
        for &n in &ntree {
            for cr in &crits {
                for &lf in &leafs {
                    for &sp in &splits {
                        let param = RandomForestClassifierParameters {
                            criterion: cr.clone(),
                            max_depth: None,
                            min_samples_leaf: lf,
                            min_samples_split: sp,
                            // Saturate rather than wrap if the tree count exceeds u16.
                            n_trees: n.min(u16::MAX as usize) as u16,
                            m: Some(m),
                            keep_samples: false,
                            seed: 42,
                        };
                        let sc = match metric {
                            Metric::ExpectedScore => {
                                cross_validate_expected_score(x, y, k, &param, weights)
                            }
                            Metric::MSE => -cross_validate_mse(x, y, k, &param, weights),
                        };
                        c.push((param, sc));
                        if sc > best {
                            best = sc;
                            noimp = 0;
                        } else {
                            noimp += 1;
                        }
                        if noimp >= 5 {
                            break 'o;
                        }
                    }
                }
            }
        }
    }
    c
}
/// Runs `quick_grid_search_candidates` and returns the single highest-scoring
/// `(parameters, score)` pair.
///
/// # Panics
/// Panics if no candidate was produced or if two scores cannot be ordered (NaN).
fn quick_grid_search(
    x: &DenseMatrix<f64>,
    y: &Vec<i32>,
    k: usize,
    weights: (f64, f64, f64),
    metric: Metric,
) -> (RandomForestClassifierParameters, f64) {
    let mut ranked = quick_grid_search_candidates(x, y, k, weights, metric);
    // Descending by score, so the winner ends up at the front.
    ranked.sort_by(|(_, lhs), (_, rhs)| rhs.partial_cmp(lhs).unwrap());
    ranked.remove(0)
}
/// Baseline forest parameters used as the starting point before any candidate
/// has been scored: 10 Gini trees, unlimited depth, `m = 2`, fixed seed 42.
fn params_default() -> RandomForestClassifierParameters {
    RandomForestClassifierParameters {
        // Forest-level settings.
        n_trees: 10,
        m: Some(2),
        keep_samples: false,
        seed: 42,
        // Per-tree settings.
        criterion: SplitCriterion::Gini,
        max_depth: None,
        min_samples_split: 2,
        min_samples_leaf: 1,
    }
}
/// Draws `size` rows (capped at `y.len()`) uniformly at random without replacement
/// and returns the matching feature matrix and labels.
///
/// # Panics
/// Panics if `DenseMatrix::from_2d_vec` rejects the selected rows.
fn sample_data(x: &DenseMatrix<f64>, y: &Vec<i32>, size: usize) -> (DenseMatrix<f64>, Vec<i32>) {
    let total = y.len();
    let mut order: Vec<usize> = (0..total).collect();
    order.shuffle(&mut thread_rng());
    // Keep only the first `size` positions of the shuffled order.
    order.truncate(size.min(total));

    let mut rows: Vec<Vec<f64>> = Vec::with_capacity(order.len());
    let mut labels: Vec<i32> = Vec::with_capacity(order.len());
    for &row in &order {
        rows.push(x.get_row(row).iterator(0).copied().collect());
        labels.push(y[row]);
    }

    let subset = DenseMatrix::from_2d_vec(&rows).unwrap();
    (subset, labels)
}
/// Candidate values for `m` (features considered per split) given `f` features:
/// `round(sqrt(f))`, `round(log2(f))` and `f / 2`.
///
/// Every candidate is clamped to at least 1 — `log2(1)` is 0 and `log2(0)` is
/// negative infinity, both of which would otherwise yield `m = 0` — and duplicates
/// are dropped (first occurrence wins) so the grid never evaluates the same `m` twice.
fn compute_m(f: usize) -> Vec<usize> {
    let features = f as f64;
    let raw = [
        features.sqrt().round() as usize,
        // Float-to-int `as` casts saturate, so -inf becomes 0 here.
        features.log2().round() as usize,
        f / 2,
    ];

    let mut candidates = Vec::with_capacity(raw.len());
    for &value in &raw {
        let m = value.max(1);
        if !candidates.contains(&m) {
            candidates.push(m);
        }
    }
    candidates
}
/// Candidate forest sizes for `f` features: a base of `round(10 * sqrt(f))` trees,
/// plus the same base increased by 50 and by 100.
fn compute_n_trees(f: usize) -> Vec<usize> {
    let base = ((f as f64).sqrt() * 10.0).round() as usize;
    [0usize, 50, 100].iter().map(|&extra| base + extra).collect()
}