1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
//! Trial result types for hyperparameter tuning
use std::collections::HashMap;
/// Outcome of evaluating one hyperparameter candidate (a single trial).
///
/// Bundles the parameter assignment that was tried with the metrics and
/// bookkeeping recorded while training and validating it.
#[derive(Debug, Clone)]
pub struct TrialResult {
    /// Unique identifier of this trial.
    pub trial_id: usize,
    /// Zoom level (search iteration) during which the trial ran.
    pub iteration: usize,
    /// Hyperparameter name mapped to the value used for this trial.
    pub params: HashMap<String, f32>,
    /// Metric measured on the validation data; for MSE and LogLoss, smaller is better.
    pub val_metric: f32,
    /// Metric measured on the training data.
    pub train_metric: f32,
    /// Trees actually built; can be fewer than `num_rounds` when early stopping triggers.
    pub num_trees: usize,
    /// Training duration in milliseconds.
    pub train_time_ms: u64,
    /// F1 score for classification tasks; `None` for regression.
    ///
    /// F1 is the harmonic mean of precision and recall, so a low value points
    /// to a lopsided model: always predicting the negative class, for
    /// example, yields an F1 of 0.
    pub f1_score: Option<f32>,
    /// Area under the ROC curve for binary classification; `None` for
    /// regression and multi-class tasks.
    ///
    /// Reflects how well the model ranks examples, independent of any threshold.
    pub roc_auc: Option<f64>,
}
impl TrialResult {
    /// Names of the fixed CSV columns, in the order [`TrialResult::to_csv_row`] emits them.
    ///
    /// Hyperparameter columns are not part of this list; render an individual
    /// parameter value with [`TrialResult::param_to_csv`].
    pub fn csv_headers() -> &'static [&'static str] {
        const HEADERS: &[&str] = &[
            "trial_id",
            "iteration",
            "val_metric",
            "train_metric",
            "f1_score",
            "roc_auc",
            "num_trees",
            "train_time_ms",
        ];
        HEADERS
    }

    /// Render the fixed columns of this trial as CSV cell strings.
    ///
    /// Cells line up one-to-one with [`TrialResult::csv_headers`]. Metrics are
    /// printed with six decimal places, except the F1 score, which uses four.
    /// A missing optional metric becomes an empty cell. Hyperparameter values
    /// are not included.
    pub fn to_csv_row(&self) -> Vec<String> {
        // Optional metrics collapse to an empty cell rather than a placeholder.
        let f1_cell = match self.f1_score {
            Some(f1) => format!("{:.4}", f1),
            None => String::new(),
        };
        let auc_cell = match self.roc_auc {
            Some(auc) => format!("{:.6}", auc),
            None => String::new(),
        };

        let mut cells = Vec::with_capacity(8);
        cells.push(self.trial_id.to_string());
        cells.push(self.iteration.to_string());
        cells.push(format!("{:.6}", self.val_metric));
        cells.push(format!("{:.6}", self.train_metric));
        cells.push(f1_cell);
        cells.push(auc_cell);
        cells.push(self.num_trees.to_string());
        cells.push(self.train_time_ms.to_string());
        cells
    }

    /// Format the value of hyperparameter `name` for a CSV cell.
    ///
    /// Returns the value with six decimal places, or an empty string when this
    /// trial has no entry named `name`.
    pub fn param_to_csv(&self, name: &str) -> String {
        if let Some(value) = self.params.get(name) {
            format!("{:.6}", value)
        } else {
            String::new()
        }
    }
}