use std::collections::HashMap;
use std::path::PathBuf;
use crate::defaults::{seeds as seeds_defaults, tuner as tuner_defaults};
use crate::TreeBoostError;
/// On-disk serialization formats supported for persisted models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelFormat {
    /// Archive format saved as `.rkyv`.
    Rkyv,
    /// Binary encoding saved as `.bin`.
    Bincode,
}

impl ModelFormat {
    /// File extension (without the leading dot) used for this format.
    pub fn extension(&self) -> &'static str {
        if matches!(self, Self::Rkyv) {
            "rkyv"
        } else {
            "bin"
        }
    }

    /// Canonical file name used when saving the best model in this format.
    pub fn filename(&self) -> &'static str {
        match *self {
            ModelFormat::Rkyv => "best_model.rkyv",
            ModelFormat::Bincode => "best_model.bin",
        }
    }
}
/// Search bounds for a single tunable hyperparameter.
#[derive(Debug, Clone, PartialEq)]
pub enum ParamBounds {
    /// Real-valued range, optionally explored on a log scale.
    Continuous { min: f32, max: f32, log_scale: bool },
    /// Integer range snapped to a fixed step size.
    Discrete { min: usize, max: usize, step: usize },
    /// Finite set of named choices, addressed by index.
    Categorical { values: Vec<String> },
}

impl ParamBounds {
    /// Linear continuous range `[min, max]`.
    pub fn continuous(min: f32, max: f32) -> Self {
        Self::Continuous { min, max, log_scale: false }
    }

    /// Continuous range `[min, max]` explored on a log scale.
    pub fn log_continuous(min: f32, max: f32) -> Self {
        Self::Continuous { min, max, log_scale: true }
    }

    /// Integer range `[min, max]` with unit step.
    pub fn discrete(min: usize, max: usize) -> Self {
        Self::discrete_step(min, max, 1)
    }

    /// Integer range `[min, max]` on multiples of `step` above `min`.
    pub fn discrete_step(min: usize, max: usize, step: usize) -> Self {
        Self::Discrete { min, max, step }
    }

    /// Categorical bounds over owned `values`.
    pub fn categorical(values: Vec<String>) -> Self {
        Self::Categorical { values }
    }

    /// Categorical bounds built from string slices.
    pub fn categorical_from_strs(values: &[&str]) -> Self {
        Self::Categorical {
            values: values.iter().map(|v| (*v).to_owned()).collect(),
        }
    }

    /// Clamp `value` into bounds. Discrete values are additionally snapped
    /// down onto the step grid; categorical values are clamped to the last
    /// valid index.
    pub fn clamp(&self, value: f32) -> f32 {
        match self {
            Self::Continuous { min, max, .. } => value.clamp(*min, *max),
            Self::Discrete { min, max, step } => {
                let (lo, hi, st) = (*min, *max, *step);
                // `as usize` saturates negative floats to 0 before clamping.
                let raw = (value as usize).clamp(lo, hi);
                (lo + ((raw - lo) / st) * st) as f32
            }
            Self::Categorical { values } => {
                let last = values.len().saturating_sub(1);
                (value as usize).min(last) as f32
            }
        }
    }

    /// Whether `value` lies inside bounds (index validity for Categorical).
    pub fn contains(&self, value: f32) -> bool {
        match self {
            Self::Continuous { min, max, .. } => (*min..=*max).contains(&value),
            Self::Discrete { min, max, .. } => (*min..=*max).contains(&(value as usize)),
            Self::Categorical { values } => (value as usize) < values.len(),
        }
    }

    /// Smallest representable value (0.0 for Categorical).
    pub fn min_value(&self) -> f32 {
        match self {
            Self::Continuous { min, .. } => *min,
            Self::Discrete { min, .. } => *min as f32,
            Self::Categorical { .. } => 0.0,
        }
    }

    /// Largest representable value (last index for Categorical).
    pub fn max_value(&self) -> f32 {
        match self {
            Self::Continuous { max, .. } => *max,
            Self::Discrete { max, .. } => *max as f32,
            Self::Categorical { values } => values.len().saturating_sub(1) as f32,
        }
    }

    /// True for a log-scale continuous range.
    pub fn is_log_scale(&self) -> bool {
        match self {
            Self::Continuous { log_scale, .. } => *log_scale,
            _ => false,
        }
    }

    /// True for categorical bounds.
    pub fn is_categorical(&self) -> bool {
        matches!(self, Self::Categorical { .. })
    }

    /// All categorical choices, if this is a categorical bound.
    pub fn categorical_values(&self) -> Option<&[String]> {
        if let Self::Categorical { values } = self {
            Some(values.as_slice())
        } else {
            None
        }
    }

    /// Choice at `index`, if categorical and the index is valid.
    pub fn get_categorical_value(&self, index: usize) -> Option<&str> {
        self.categorical_values()?.get(index).map(String::as_str)
    }

    /// Index of `value` among the categorical choices, if present.
    pub fn categorical_index(&self, value: &str) -> Option<usize> {
        self.categorical_values()?.iter().position(|v| v == value)
    }
}
/// A named hyperparameter: its search bounds plus the current search center.
#[derive(Debug, Clone)]
pub struct ParamDef {
    /// Parameter name as used in configs and history CSVs.
    pub name: String,
    /// Legal range for the parameter.
    pub bounds: ParamBounds,
    /// Current center of the search; always kept inside `bounds`.
    pub center: f32,
}

impl ParamDef {
    /// Build a definition; `center` is clamped into `bounds` up front.
    pub fn new(name: impl Into<String>, bounds: ParamBounds, center: f32) -> Self {
        Self {
            name: name.into(),
            center: bounds.clamp(center),
            bounds,
        }
    }

    /// Move the search center, clamping it into bounds.
    pub fn set_center(&mut self, center: f32) {
        self.center = self.bounds.clamp(center);
    }
}
/// An ordered collection of tunable parameter definitions.
#[derive(Debug, Clone)]
pub struct ParameterSpace {
// Private so mutation goes through `with_param`/`set_centers`, which keep
// names unique and centers clamped.
params: Vec<ParamDef>,
}
/// Built-in parameter-space presets (see `ParameterSpace::with_preset`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpacePreset {
/// Two parameters only (depth, learning rate): fastest sweep.
Minimal,
/// Default five-parameter space for regression tasks.
Regression,
/// Tighter ranges for classification tasks.
Classification,
/// All core boosting knobs, including GOSS sampling rates.
Exhaustive,
/// Universal-model space: mode selection plus shared/tree settings.
Universal,
}
// The default space targets regression, matching `TunerConfig::default`.
impl Default for ParameterSpace {
fn default() -> Self {
Self::with_preset(SpacePreset::Regression)
}
}
impl ParameterSpace {
    /// Empty space with no parameters.
    pub fn new() -> Self {
        Self { params: Vec::new() }
    }

    /// Build the space associated with `preset`.
    pub fn with_preset(preset: SpacePreset) -> Self {
        match preset {
            SpacePreset::Minimal => Self::minimal_space(),
            SpacePreset::Regression => Self::regression_space(),
            SpacePreset::Classification => Self::classification_space(),
            SpacePreset::Exhaustive => Self::exhaustive(),
            SpacePreset::Universal => Self::universal_space(),
        }
    }

    /// Widest built-in space: all core boosting knobs plus GOSS rates.
    pub fn exhaustive() -> Self {
        Self {
            params: vec![
                ParamDef::new("max_depth", ParamBounds::discrete(2, 12), 6.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.5), 0.1),
                ParamDef::new("subsample", ParamBounds::continuous(0.5, 1.0), 0.8),
                ParamDef::new("colsample", ParamBounds::continuous(0.5, 1.0), 1.0),
                ParamDef::new("lambda", ParamBounds::continuous(0.0, 10.0), 1.0),
                ParamDef::new("entropy_weight", ParamBounds::continuous(0.0, 0.5), 0.0),
                ParamDef::new("goss_top_rate", ParamBounds::continuous(0.1, 0.4), 0.2),
                ParamDef::new("goss_other_rate", ParamBounds::continuous(0.05, 0.2), 0.1),
            ],
        }
    }

    /// Space that only searches over the universal model `mode`.
    pub fn universal_mode_only() -> Self {
        Self {
            params: vec![ParamDef::new(
                "mode",
                ParamBounds::categorical_from_strs(&["PureTree", "LinearThenTree", "RandomForest"]),
                0.0,
            )],
        }
    }

    /// Default space for regression tasks.
    fn regression_space() -> Self {
        Self {
            params: vec![
                ParamDef::new("max_depth", ParamBounds::discrete(2, 12), 6.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.5), 0.1),
                ParamDef::new("subsample", ParamBounds::continuous(0.5, 1.0), 0.8),
                ParamDef::new("lambda", ParamBounds::continuous(0.0, 10.0), 1.0),
                ParamDef::new("entropy_weight", ParamBounds::continuous(0.0, 0.5), 0.0),
            ],
        }
    }

    /// Slightly tighter ranges suited to classification tasks.
    fn classification_space() -> Self {
        Self {
            params: vec![
                ParamDef::new("max_depth", ParamBounds::discrete(2, 10), 5.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.3), 0.1),
                ParamDef::new("subsample", ParamBounds::continuous(0.6, 1.0), 0.8),
                ParamDef::new("lambda", ParamBounds::continuous(0.0, 5.0), 1.0),
                ParamDef::new("entropy_weight", ParamBounds::continuous(0.0, 0.3), 0.0),
            ],
        }
    }

    /// Two-parameter space for very fast smoke runs.
    fn minimal_space() -> Self {
        Self {
            params: vec![
                ParamDef::new("max_depth", ParamBounds::discrete(3, 10), 6.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.3), 0.1),
            ],
        }
    }

    /// Space for the universal model: mode plus shared and tree settings.
    fn universal_space() -> Self {
        Self {
            params: vec![
                ParamDef::new(
                    "mode",
                    ParamBounds::categorical_from_strs(&[
                        "PureTree",
                        "LinearThenTree",
                        "RandomForest",
                    ]),
                    0.0,
                ),
                ParamDef::new("num_rounds", ParamBounds::discrete(50, 200), 100.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.3), 0.1),
                ParamDef::new("subsample", ParamBounds::continuous(0.6, 1.0), 0.8),
                ParamDef::new("tree_max_depth", ParamBounds::discrete(3, 10), 6.0),
                ParamDef::new("tree_lambda", ParamBounds::continuous(0.0, 10.0), 1.0),
            ],
        }
    }

    /// Space tailored to the LinearThenTree universal mode.
    pub fn universal_linear_then_tree() -> Self {
        Self {
            params: vec![
                ParamDef::new("num_rounds", ParamBounds::discrete(30, 150), 50.0),
                ParamDef::new("learning_rate", ParamBounds::log_continuous(0.01, 0.3), 0.1),
                ParamDef::new("linear_rounds", ParamBounds::discrete(5, 30), 10.0),
                ParamDef::new(
                    "linear_lambda",
                    ParamBounds::log_continuous(0.01, 10.0),
                    1.0,
                ),
                ParamDef::new("tree_max_depth", ParamBounds::discrete(3, 8), 5.0),
            ],
        }
    }

    /// Add a parameter, replacing any existing definition with the same name.
    pub fn with_param(mut self, name: &str, bounds: ParamBounds, center: f32) -> Self {
        self.params.retain(|p| p.name != name);
        self.params.push(ParamDef::new(name, bounds, center));
        self
    }

    /// Remove a parameter if present.
    pub fn without_param(mut self, name: &str) -> Self {
        self.params.retain(|p| p.name != name);
        self
    }

    /// Look up a parameter by name.
    pub fn get(&self, name: &str) -> Option<&ParamDef> {
        self.params.iter().find(|p| p.name == name)
    }

    /// Mutable lookup by name.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut ParamDef> {
        self.params.iter_mut().find(|p| p.name == name)
    }

    /// All parameter definitions, in insertion order.
    pub fn params(&self) -> &[ParamDef] {
        &self.params
    }

    /// Mutable access to all parameter definitions.
    pub fn params_mut(&mut self) -> &mut [ParamDef] {
        &mut self.params
    }

    /// Number of parameters in the space.
    pub fn len(&self) -> usize {
        self.params.len()
    }

    /// True when the space has no parameters.
    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Names of all parameters, in order.
    pub fn param_names(&self) -> Vec<String> {
        self.params.iter().map(|p| p.name.clone()).collect()
    }

    /// Current search centers keyed by parameter name.
    pub fn centers(&self) -> HashMap<String, f32> {
        self.params
            .iter()
            .map(|p| (p.name.clone(), p.center))
            .collect()
    }

    /// Update centers from a name -> value map. Unknown names are ignored;
    /// values are clamped into each parameter's bounds.
    pub fn set_centers(&mut self, centers: &HashMap<String, f32>) {
        // Fixed: these references were corrupted into HTML entities
        // (`¢er` / `¶m.name`), which did not compile.
        for param in &mut self.params {
            if let Some(&center) = centers.get(&param.name) {
                param.set_center(center);
            }
        }
    }

    /// Check that every parameter name is one the trainer understands.
    pub fn validate(&self) -> Result<(), String> {
        // Keep in sync with the names produced by the preset constructors
        // above — including the universal presets, whose names were
        // previously missing here and made `universal_space()` fail its
        // own validation.
        const VALID_PARAMS: &[&str] = &[
            "max_depth",
            "learning_rate",
            "subsample",
            "colsample",
            "lambda",
            "entropy_weight",
            "min_samples_leaf",
            "min_hessian_leaf",
            "min_gain",
            "num_rounds",
            "goss_top_rate",
            "goss_other_rate",
            "mode",
            "tree_max_depth",
            "tree_lambda",
            "linear_rounds",
            "linear_lambda",
        ];
        for param in &self.params {
            if !VALID_PARAMS.contains(&param.name.as_str()) {
                return Err(format!(
                    "Unknown parameter '{}'. Valid parameters: {:?}",
                    param.name, VALID_PARAMS
                ));
            }
        }
        Ok(())
    }

    /// Narrow each parameter's bounds around the best historical trials.
    ///
    /// Reads every `iteration_*.csv` in `history_dir`, keeps the top
    /// `top_percentile` fraction of trials ranked by `metric_column`
    /// (direction given by `higher_is_better`), shrinks continuous bounds to
    /// the observed range plus a 10% margin and discrete bounds to the
    /// observed (step-snapped) range, then re-centers each parameter.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` when the directory is missing, no
    /// iteration CSVs are found, a CSV cannot be read, or no trial row
    /// contains the metric column.
    pub fn constrain_from_history<P: AsRef<std::path::Path>>(
        mut self,
        history_dir: P,
        top_percentile: f32,
        metric_column: &str,
        higher_is_better: bool,
    ) -> crate::Result<Self> {
        use std::fs;
        let dir = history_dir.as_ref();
        if !dir.exists() {
            return Err(TreeBoostError::Data(format!(
                "History directory not found: {}",
                dir.display()
            )));
        }
        let mut csv_files: Vec<_> = fs::read_dir(dir)
            .map_err(|e| TreeBoostError::Data(format!("Failed to read directory: {}", e)))?
            .filter_map(|entry| {
                let entry = entry.ok()?;
                let path = entry.path();
                let name = path.file_name()?.to_str()?;
                if name.starts_with("iteration_") && name.ends_with(".csv") {
                    Some(path)
                } else {
                    None
                }
            })
            .collect();
        if csv_files.is_empty() {
            return Err(TreeBoostError::Data(
                "No iteration_*.csv files found in history directory".into(),
            ));
        }
        // Deterministic (lexicographic) processing order.
        csv_files.sort();
        let mut all_trials: Vec<HashMap<String, f32>> = Vec::new();
        for csv_path in &csv_files {
            let mut reader = csv::Reader::from_path(csv_path)
                .map_err(|e| TreeBoostError::Data(format!("Failed to open CSV: {}", e)))?;
            let headers: Vec<String> = reader
                .headers()
                .map_err(|e| TreeBoostError::Data(format!("Failed to read headers: {}", e)))?
                .iter()
                .map(|s| s.to_string())
                .collect();
            for result in reader.records() {
                let record = result
                    .map_err(|e| TreeBoostError::Data(format!("Failed to read record: {}", e)))?;
                // Keep only cells that parse as f32; other columns are skipped.
                let mut trial: HashMap<String, f32> = HashMap::new();
                for (i, value) in record.iter().enumerate() {
                    if let Some(header) = headers.get(i) {
                        if let Ok(v) = value.parse::<f32>() {
                            trial.insert(header.clone(), v);
                        }
                    }
                }
                // Rows without the metric cannot be ranked; drop them.
                if trial.contains_key(metric_column) {
                    all_trials.push(trial);
                }
            }
        }
        if all_trials.is_empty() {
            return Err(TreeBoostError::Data(
                "No valid trials found in CSV files".into(),
            ));
        }
        // Best trials first, according to the metric direction.
        all_trials.sort_by(|a, b| {
            let a_val = a.get(metric_column).copied().unwrap_or(f32::NAN);
            let b_val = b.get(metric_column).copied().unwrap_or(f32::NAN);
            if higher_is_better {
                b_val
                    .partial_cmp(&a_val)
                    .unwrap_or(std::cmp::Ordering::Equal)
            } else {
                a_val
                    .partial_cmp(&b_val)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }
        });
        // At least one trial is always kept.
        let n_top = ((all_trials.len() as f32) * top_percentile).ceil() as usize;
        let n_top = n_top.max(1);
        let top_trials = &all_trials[..n_top.min(all_trials.len())];
        for param in &mut self.params {
            let values: Vec<f32> = top_trials
                .iter()
                .filter_map(|t| t.get(&param.name).copied())
                .filter(|v| !v.is_nan())
                .collect();
            if values.is_empty() {
                // Parameter never logged: leave its bounds untouched.
                continue;
            }
            let min_val = values.iter().copied().fold(f32::INFINITY, f32::min);
            let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            match &mut param.bounds {
                ParamBounds::Continuous { min, max, .. } => {
                    // Shrink to the observed range plus a 10% margin; never
                    // widen beyond the original bounds.
                    let range = max_val - min_val;
                    let margin = range * 0.1;
                    *min = (min_val - margin).max(*min);
                    *max = (max_val + margin).min(*max);
                }
                ParamBounds::Discrete { min, max, step } => {
                    let new_min = (min_val as usize).max(*min);
                    let new_max = (max_val as usize).min(*max);
                    // Snap the new minimum down onto the original step grid,
                    // then snap the maximum onto the grid anchored at that
                    // (already updated) minimum.
                    *min = ((new_min - *min) / *step) * *step + *min;
                    *max = ((new_max - *min) / *step) * *step + *min;
                }
                ParamBounds::Categorical { .. } => {
                    // Categorical choices are never pruned by history.
                }
            }
            // Re-center on the middle of the (possibly narrowed) bounds;
            // categorical parameters keep their current choice if still valid.
            let mid = match &param.bounds {
                ParamBounds::Continuous { min, max, .. } => (*min + *max) / 2.0,
                ParamBounds::Discrete { min, max, .. } => ((*min + *max) / 2) as f32,
                ParamBounds::Categorical { values } => {
                    let current = param.center as usize;
                    if current < values.len() {
                        current as f32
                    } else {
                        0.0
                    }
                }
            };
            param.set_center(mid);
        }
        Ok(self)
    }
}
/// How trial models are evaluated during tuning.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EvalStrategy {
    /// Plain train/validation split, optionally repeated over folds.
    Holdout {
        validation_ratio: f32,
        folds: usize,
    },
    /// Split-conformal evaluation with a calibration set and target quantile.
    Conformal {
        calibration_ratio: f32,
        quantile: f32,
        folds: usize,
    },
}

impl Default for EvalStrategy {
    /// Single 80/20 holdout split.
    fn default() -> Self {
        Self::holdout(0.2)
    }
}

impl EvalStrategy {
    /// Single-fold holdout with the given validation fraction.
    pub fn holdout(validation_ratio: f32) -> Self {
        Self::Holdout { validation_ratio, folds: 1 }
    }

    /// Single-fold conformal evaluation.
    pub fn conformal(calibration_ratio: f32, quantile: f32) -> Self {
        Self::Conformal { calibration_ratio, quantile, folds: 1 }
    }

    /// Conformal evaluation at the 0.9 quantile.
    pub fn conformal_90(calibration_ratio: f32) -> Self {
        Self::conformal(calibration_ratio, 0.9)
    }

    /// Set the fold count, keeping every other setting unchanged.
    pub fn with_folds(mut self, folds: usize) -> Self {
        let slot = match &mut self {
            Self::Holdout { folds: f, .. } | Self::Conformal { folds: f, .. } => f,
        };
        *slot = folds;
        self
    }

    /// Number of evaluation folds.
    pub fn folds(&self) -> usize {
        match self {
            Self::Holdout { folds, .. } | Self::Conformal { folds, .. } => *folds,
        }
    }

    /// Pick a holdout strategy from the dataset size: smaller datasets are
    /// evaluated over more folds to reduce variance.
    pub fn auto(num_samples: usize) -> Self {
        let folds = if num_samples < 1_000 {
            5
        } else if num_samples < 5_000 {
            3
        } else {
            1
        };
        Self::holdout(0.2).with_folds(folds)
    }

    /// Check that every ratio is in (0, 1) and that folds >= 1.
    pub fn validate(&self) -> Result<(), String> {
        // Shared open-interval check; error text matches the field name.
        let check_ratio = |name: &str, v: f32| -> Result<(), String> {
            if v <= 0.0 || v >= 1.0 {
                Err(format!("{} must be in (0, 1), got {}", name, v))
            } else {
                Ok(())
            }
        };
        match self {
            Self::Holdout { validation_ratio, folds } => {
                check_ratio("validation_ratio", *validation_ratio)?;
                if *folds == 0 {
                    return Err("folds must be >= 1".into());
                }
            }
            Self::Conformal { calibration_ratio, quantile, folds } => {
                check_ratio("calibration_ratio", *calibration_ratio)?;
                check_ratio("quantile", *quantile)?;
                if *folds == 0 {
                    return Err("folds must be >= 1".into());
                }
            }
        }
        Ok(())
    }
}
/// How candidate points are generated inside the current search region.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GridStrategy {
    /// Full cartesian grid with `points_per_dim` points along each axis.
    Cartesian {
        points_per_dim: usize,
    },
    /// Latin hypercube sample of `n_samples` points.
    LatinHypercube {
        n_samples: usize,
    },
    /// Uniform random sample of `n_samples` points.
    Random {
        n_samples: usize,
    },
}

impl Default for GridStrategy {
    /// 3-point cartesian grid per dimension.
    fn default() -> Self {
        Self::Cartesian { points_per_dim: 3 }
    }
}

impl GridStrategy {
    /// Cartesian-grid constructor.
    pub fn cartesian(points_per_dim: usize) -> Self {
        Self::Cartesian { points_per_dim }
    }

    /// Latin-hypercube constructor.
    pub fn lhs(n_samples: usize) -> Self {
        Self::LatinHypercube { n_samples }
    }

    /// Random-sampling constructor.
    pub fn random(n_samples: usize) -> Self {
        Self::Random { n_samples }
    }

    /// Number of candidates evaluated per iteration for `num_params`
    /// dimensions. Cartesian growth is exponential, so the count saturates
    /// at `usize::MAX` instead of overflowing (the previous `pow` call
    /// panicked in debug builds / wrapped in release for large spaces).
    pub fn num_candidates(&self, num_params: usize) -> usize {
        match self {
            Self::Cartesian { points_per_dim } => {
                // Checked narrowing: a usize exponent beyond u32 saturates too.
                let exp = u32::try_from(num_params).unwrap_or(u32::MAX);
                points_per_dim.saturating_pow(exp)
            }
            Self::LatinHypercube { n_samples } | Self::Random { n_samples } => *n_samples,
        }
    }

    /// Reject degenerate settings (grids need >= 2 points per dimension,
    /// samplers need >= 1 sample).
    pub fn validate(&self) -> Result<(), String> {
        match self {
            Self::Cartesian { points_per_dim } => {
                if *points_per_dim < 2 {
                    return Err(format!(
                        "points_per_dim must be >= 2, got {}",
                        points_per_dim
                    ));
                }
            }
            Self::LatinHypercube { n_samples } | Self::Random { n_samples } => {
                if *n_samples < 1 {
                    return Err(format!("n_samples must be >= 1, got {}", n_samples));
                }
            }
        }
        Ok(())
    }
}
/// Whether tuning scores trials optimistically or conservatively.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum TuningMode {
    /// Favorable evaluation (the default).
    #[default]
    Optimistic,
    /// Conservative, deployment-like evaluation.
    Realistic,
}

impl TuningMode {
    /// Shorthand for [`TuningMode::Optimistic`].
    pub fn optimistic() -> Self {
        Self::Optimistic
    }

    /// Shorthand for [`TuningMode::Realistic`].
    pub fn realistic() -> Self {
        Self::Realistic
    }

    /// True when this is the optimistic mode.
    pub fn is_optimistic(&self) -> bool {
        *self == Self::Optimistic
    }

    /// True when this is the realistic mode.
    pub fn is_realistic(&self) -> bool {
        *self == Self::Realistic
    }
}
/// Metric the tuner optimizes across trials.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum OptimizationMetric {
    /// Minimize validation loss (the default).
    #[default]
    ValidationLoss,
    /// Maximize F1 score.
    F1Score,
    /// Maximize area under the ROC curve.
    RocAuc,
}

impl OptimizationMetric {
    /// Direction of improvement: true when larger values are better.
    pub fn higher_is_better(&self) -> bool {
        match self {
            Self::ValidationLoss => false,
            Self::F1Score | Self::RocAuc => true,
        }
    }

    /// Stable snake_case identifier (e.g. for CSV column headers).
    pub fn name(&self) -> &'static str {
        match self {
            Self::ValidationLoss => "validation_loss",
            Self::F1Score => "f1_score",
            Self::RocAuc => "roc_auc",
        }
    }
}
/// The learning task being tuned.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum TaskType {
    /// Continuous targets.
    Regression,
    /// Two-class targets (the default).
    #[default]
    BinaryClassification,
    /// More than two classes.
    MultiClassClassification,
}

impl TaskType {
    /// True for either classification flavor.
    pub fn is_classification(&self) -> bool {
        !self.is_regression()
    }

    /// True only for binary classification.
    pub fn is_binary(&self) -> bool {
        matches!(self, Self::BinaryClassification)
    }

    /// True only for regression.
    pub fn is_regression(&self) -> bool {
        matches!(self, Self::Regression)
    }
}
/// Full configuration for the iterative zoom-in hyperparameter tuner.
#[derive(Debug, Clone)]
pub struct TunerConfig {
/// Parameter space searched each iteration.
pub space: ParameterSpace,
/// Number of zoom-in iterations.
pub n_iterations: usize,
/// Spread of the search region at iteration 0 (validated to (0, 1]).
pub initial_spread: f32,
/// Multiplicative shrink applied to the spread each iteration ((0, 1)).
pub zoom_factor: f32,
/// How candidate points are generated inside the current region.
pub grid_strategy: GridStrategy,
/// How each trial is evaluated (holdout or conformal, with folds).
pub eval_strategy: EvalStrategy,
/// Optimistic vs realistic scoring of trials.
pub tuning_mode: TuningMode,
/// Evaluate trials in parallel when true.
pub parallel_trials: bool,
/// Worker count for parallel trials; defaults to 0 (presumably
/// "auto/implementation default" — TODO confirm against the trial runner).
pub n_parallel: usize,
/// Boosting rounds per trial.
pub num_rounds: usize,
/// Early-stopping patience in rounds; 0 disables early stopping
/// (see `without_early_stopping`).
pub early_stopping_rounds: usize,
/// Validation fraction used alongside early stopping.
pub validation_ratio: f32,
/// Minimum metric improvement counted as progress.
pub improvement_threshold: f32,
/// Minimum acceptable F1 score; NOTE(review): consumer not visible in
/// this file — confirm how classification flows apply it.
pub min_f1_score: f32,
/// RNG seed for reproducible tuning runs.
pub seed: u64,
/// Emit progress logging when true.
pub verbose: bool,
/// Metric being optimized (and its direction).
pub optimization_metric: OptimizationMetric,
/// Task the model is tuned for.
pub task_type: TaskType,
/// Directory for tuning artifacts; `None` disables saving.
pub output_dir: Option<PathBuf>,
/// Formats in which the best model is persisted (empty = do not save).
pub save_model_formats: Vec<ModelFormat>,
}
/// Effort presets applied by `TunerConfig::with_preset`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TunerPreset {
/// Minimal iteration count for smoke checks (shares the Quick
/// rounds/early-stop/threshold constants).
SmokeTest,
/// Reduced iterations and rounds for fast feedback.
Quick,
/// The default effort level.
Balanced,
/// More iterations and rounds for final tuning.
Thorough,
}
impl Default for TunerConfig {
    /// Balanced defaults sourced from `crate::defaults`: regression space,
    /// 3-point cartesian grid, single holdout split, optimistic scoring,
    /// verbose logging, and no model saving.
    fn default() -> Self {
        Self {
            space: ParameterSpace::with_preset(SpacePreset::Regression),
            n_iterations: tuner_defaults::DEFAULT_N_ITERATIONS,
            initial_spread: tuner_defaults::DEFAULT_INITIAL_SPREAD,
            zoom_factor: tuner_defaults::DEFAULT_ZOOM_FACTOR,
            grid_strategy: GridStrategy::Cartesian { points_per_dim: 3 },
            eval_strategy: EvalStrategy::Holdout {
                validation_ratio: tuner_defaults::DEFAULT_TUNER_VAL_RATIO,
                folds: 1,
            },
            tuning_mode: TuningMode::Optimistic,
            parallel_trials: false,
            n_parallel: 0,
            num_rounds: tuner_defaults::DEFAULT_TUNER_ROUNDS,
            early_stopping_rounds: tuner_defaults::DEFAULT_TUNER_EARLY_STOP,
            validation_ratio: tuner_defaults::DEFAULT_TUNER_VAL_RATIO,
            improvement_threshold: tuner_defaults::DEFAULT_IMPROVEMENT_THRESHOLD,
            min_f1_score: tuner_defaults::DEFAULT_MIN_F1_SCORE,
            seed: seeds_defaults::DEFAULT_SEED,
            verbose: true,
            optimization_metric: OptimizationMetric::ValidationLoss,
            task_type: TaskType::Regression,
            output_dir: None,
            save_model_formats: Vec::new(),
        }
    }
}
impl TunerConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_preset(mut self, preset: TunerPreset) -> Self {
match preset {
TunerPreset::SmokeTest => {
self.n_iterations = tuner_defaults::SMOKE_TEST_N_ITERATIONS;
self.num_rounds = tuner_defaults::QUICK_TUNER_ROUNDS;
self.early_stopping_rounds = tuner_defaults::QUICK_TUNER_EARLY_STOP;
self.improvement_threshold = tuner_defaults::QUICK_IMPROVEMENT_THRESHOLD;
}
TunerPreset::Quick => {
self.n_iterations = tuner_defaults::QUICK_N_ITERATIONS;
self.num_rounds = tuner_defaults::QUICK_TUNER_ROUNDS;
self.early_stopping_rounds = tuner_defaults::QUICK_TUNER_EARLY_STOP;
self.improvement_threshold = tuner_defaults::QUICK_IMPROVEMENT_THRESHOLD;
}
TunerPreset::Balanced => {
self.n_iterations = tuner_defaults::DEFAULT_N_ITERATIONS;
self.num_rounds = tuner_defaults::DEFAULT_TUNER_ROUNDS;
self.early_stopping_rounds = tuner_defaults::DEFAULT_TUNER_EARLY_STOP;
self.improvement_threshold = tuner_defaults::DEFAULT_IMPROVEMENT_THRESHOLD;
}
TunerPreset::Thorough => {
self.n_iterations = tuner_defaults::THOROUGH_N_ITERATIONS;
self.num_rounds = tuner_defaults::THOROUGH_TUNER_ROUNDS;
self.early_stopping_rounds = tuner_defaults::THOROUGH_TUNER_EARLY_STOP;
self.improvement_threshold = tuner_defaults::THOROUGH_IMPROVEMENT_THRESHOLD;
}
}
self
}
pub fn with_space(mut self, space: ParameterSpace) -> Self {
self.space = space;
self
}
pub fn with_iterations(mut self, n: usize) -> Self {
self.n_iterations = n;
self
}
pub fn with_initial_spread(mut self, spread: f32) -> Self {
self.initial_spread = spread;
self
}
pub fn with_zoom_factor(mut self, factor: f32) -> Self {
self.zoom_factor = factor;
self
}
pub fn with_grid_strategy(mut self, strategy: GridStrategy) -> Self {
self.grid_strategy = strategy;
self
}
pub fn with_eval_strategy(mut self, strategy: EvalStrategy) -> Self {
self.eval_strategy = strategy;
self
}
pub fn with_parallel(mut self, enabled: bool) -> Self {
self.parallel_trials = enabled;
self
}
pub fn with_n_parallel(mut self, n: usize) -> Self {
self.n_parallel = n;
self
}
pub fn with_num_rounds(mut self, rounds: usize) -> Self {
self.num_rounds = rounds;
self
}
pub fn with_early_stopping(mut self, rounds: usize, validation_ratio: f32) -> Self {
self.early_stopping_rounds = rounds;
self.validation_ratio = validation_ratio;
self
}
pub fn without_early_stopping(mut self) -> Self {
self.early_stopping_rounds = 0;
self
}
pub fn with_improvement_threshold(mut self, threshold: f32) -> Self {
self.improvement_threshold = threshold;
self
}
pub fn with_min_f1_score(mut self, min_f1: f32) -> Self {
self.min_f1_score = min_f1;
self
}
pub fn with_tuning_mode(mut self, mode: TuningMode) -> Self {
self.tuning_mode = mode;
self
}
pub fn optimistic(mut self) -> Self {
self.tuning_mode = TuningMode::Optimistic;
self
}
pub fn realistic(mut self) -> Self {
self.tuning_mode = TuningMode::Realistic;
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}
pub fn with_verbose(mut self, verbose: bool) -> Self {
self.verbose = verbose;
self
}
pub fn with_optimization_metric(mut self, metric: OptimizationMetric) -> Self {
self.optimization_metric = metric;
self
}
pub fn with_task_type(mut self, task_type: TaskType) -> Self {
self.task_type = task_type;
self
}
pub fn with_output_dir<P: AsRef<std::path::Path>>(mut self, path: P) -> Self {
self.output_dir = Some(path.as_ref().to_path_buf());
self
}
pub fn with_save_model_formats(mut self, formats: Vec<ModelFormat>) -> Self {
self.save_model_formats = formats;
self
}
pub fn validate(&self) -> Result<(), String> {
if self.n_iterations == 0 {
return Err("n_iterations must be > 0".into());
}
if self.initial_spread <= 0.0 || self.initial_spread > 1.0 {
return Err(format!(
"initial_spread must be in (0, 1], got {}",
self.initial_spread
));
}
if self.zoom_factor <= 0.0 || self.zoom_factor >= 1.0 {
return Err(format!(
"zoom_factor must be in (0, 1), got {}",
self.zoom_factor
));
}
if self.num_rounds == 0 {
return Err("num_rounds must be > 0".into());
}
self.space.validate()?;
self.grid_strategy.validate()?;
self.eval_strategy.validate()?;
Ok(())
}
pub fn spread_for_iteration(&self, iteration: usize) -> f32 {
self.initial_spread * self.zoom_factor.powi(iteration as i32)
}
pub fn estimated_trials(&self) -> usize {
let candidates_per_iter = self.grid_strategy.num_candidates(self.space.len());
candidates_per_iter * self.n_iterations
}
}
#[cfg(test)]
mod tests {
use super::*;
// Continuous bounds: clamping to [min, max], containment, scale flag.
#[test]
fn test_param_bounds_continuous() {
let bounds = ParamBounds::continuous(0.0, 1.0);
assert_eq!(bounds.clamp(-0.5), 0.0);
assert_eq!(bounds.clamp(0.5), 0.5);
assert_eq!(bounds.clamp(1.5), 1.0);
assert!(bounds.contains(0.5));
assert!(!bounds.contains(-0.1));
assert!(!bounds.is_log_scale());
}
// Log-scale bounds keep their endpoints and set the log flag.
#[test]
fn test_param_bounds_log_continuous() {
let bounds = ParamBounds::log_continuous(0.01, 1.0);
assert!(bounds.is_log_scale());
assert_eq!(bounds.min_value(), 0.01);
assert_eq!(bounds.max_value(), 1.0);
}
// Discrete bounds clamp into the integer range.
#[test]
fn test_param_bounds_discrete() {
let bounds = ParamBounds::discrete(2, 10);
assert_eq!(bounds.clamp(1.0), 2.0);
assert_eq!(bounds.clamp(5.0), 5.0);
assert_eq!(bounds.clamp(15.0), 10.0);
assert!(!bounds.is_log_scale());
}
// Stepped bounds snap down onto the step grid (5 -> 4 with step 2).
#[test]
fn test_param_bounds_discrete_step() {
let bounds = ParamBounds::discrete_step(2, 10, 2);
assert_eq!(bounds.clamp(5.0), 4.0);
assert_eq!(bounds.clamp(6.0), 6.0);
}
// ParamDef clamps its center into bounds on construction and update.
#[test]
fn test_param_def() {
let mut param = ParamDef::new("test", ParamBounds::continuous(0.0, 1.0), 0.5);
assert_eq!(param.name, "test");
assert_eq!(param.center, 0.5);
param.set_center(2.0); assert_eq!(param.center, 1.0); }
// The regression preset exposes exactly the five expected parameters.
#[test]
fn test_parameter_space_default() {
let space = ParameterSpace::with_preset(SpacePreset::Regression);
assert_eq!(space.len(), 5);
assert!(space.get("max_depth").is_some());
assert!(space.get("learning_rate").is_some());
assert!(space.get("subsample").is_some());
assert!(space.get("lambda").is_some());
assert!(space.get("entropy_weight").is_some());
}
// with_param adds a new parameter to an existing space.
#[test]
fn test_parameter_space_with_param() {
let space = ParameterSpace::with_preset(SpacePreset::Minimal).with_param(
"colsample",
ParamBounds::continuous(0.5, 1.0),
0.8,
);
assert_eq!(space.len(), 3);
assert!(space.get("colsample").is_some());
}
// without_param removes a parameter by name.
#[test]
fn test_parameter_space_without_param() {
let space =
ParameterSpace::with_preset(SpacePreset::Regression).without_param("entropy_weight");
assert_eq!(space.len(), 4);
assert!(space.get("entropy_weight").is_none());
}
// centers() round-trips through set_centers().
#[test]
fn test_parameter_space_centers() {
let mut space = ParameterSpace::with_preset(SpacePreset::Minimal);
let centers = space.centers();
assert_eq!(centers.get("max_depth"), Some(&6.0));
assert_eq!(centers.get("learning_rate"), Some(&0.1));
let mut new_centers = HashMap::new();
new_centers.insert("max_depth".into(), 8.0);
space.set_centers(&new_centers);
assert_eq!(space.get("max_depth").unwrap().center, 8.0);
}
// validate() accepts known parameter names and rejects unknown ones.
#[test]
fn test_parameter_space_validate() {
let valid = ParameterSpace::with_preset(SpacePreset::Regression);
assert!(valid.validate().is_ok());
let invalid = ParameterSpace::new().with_param(
"invalid_param",
ParamBounds::continuous(0.0, 1.0),
0.5,
);
assert!(invalid.validate().is_err());
}
// Holdout/conformal construction, fold counts, and validation failures.
#[test]
fn test_eval_strategy() {
let holdout = EvalStrategy::holdout(0.2);
assert!(holdout.validate().is_ok());
assert_eq!(holdout.folds(), 1);
let holdout_5fold = EvalStrategy::holdout(0.2).with_folds(5);
assert!(holdout_5fold.validate().is_ok());
assert_eq!(holdout_5fold.folds(), 5);
let conformal = EvalStrategy::conformal(0.2, 0.9).with_folds(3);
assert!(conformal.validate().is_ok());
assert_eq!(conformal.folds(), 3);
let invalid_holdout = EvalStrategy::holdout(1.5);
assert!(invalid_holdout.validate().is_err());
let invalid_folds = EvalStrategy::holdout(0.2).with_folds(0);
assert!(invalid_folds.validate().is_err()); }
// auto() assigns more folds to smaller datasets.
#[test]
fn test_eval_strategy_auto() {
assert!(matches!(
EvalStrategy::auto(500),
EvalStrategy::Holdout { folds: 5, .. }
));
assert!(matches!(
EvalStrategy::auto(2000),
EvalStrategy::Holdout { folds: 3, .. }
));
assert!(matches!(
EvalStrategy::auto(10000),
EvalStrategy::Holdout { folds: 1, .. }
));
}
// Candidate counts: cartesian is points^dims, samplers are flat counts.
#[test]
fn test_grid_strategy() {
let cart = GridStrategy::cartesian(3);
assert_eq!(cart.num_candidates(5), 243);
let lhs = GridStrategy::lhs(50);
assert_eq!(lhs.num_candidates(5), 50);
let rand = GridStrategy::random(100);
assert_eq!(rand.num_candidates(5), 100);
}
// Grid validation rejects degenerate settings.
#[test]
fn test_grid_strategy_validate() {
assert!(GridStrategy::cartesian(3).validate().is_ok());
assert!(GridStrategy::cartesian(1).validate().is_err());
assert!(GridStrategy::lhs(0).validate().is_err());
}
// Default config matches the crate's default constants and validates.
#[test]
fn test_tuner_config_default() {
let config = TunerConfig::default();
assert_eq!(config.n_iterations, 5);
assert_eq!(config.initial_spread, 1.0); assert_eq!(config.zoom_factor, 0.8); assert_eq!(config.early_stopping_rounds, 10);
assert_eq!(config.validation_ratio, 0.2);
assert_eq!(config.improvement_threshold, 0.001);
assert!(config.validate().is_ok());
}
// Quick preset lowers iterations/rounds and loosens the threshold.
#[test]
fn test_tuner_config_quick() {
let config = TunerConfig::default().with_preset(TunerPreset::Quick);
assert_eq!(config.n_iterations, 2);
assert_eq!(config.num_rounds, 50);
assert_eq!(config.early_stopping_rounds, 5);
assert_eq!(config.improvement_threshold, 0.01); }
// Thorough preset raises iterations/rounds and tightens the threshold.
#[test]
fn test_tuner_config_thorough() {
let config = TunerConfig::default().with_preset(TunerPreset::Thorough);
assert_eq!(config.n_iterations, 7);
assert_eq!(config.num_rounds, 200);
assert_eq!(config.early_stopping_rounds, 20);
assert_eq!(config.improvement_threshold, 0.0001); }
// Builder methods set the corresponding fields.
#[test]
fn test_tuner_config_builders() {
let config = TunerConfig::new()
.with_iterations(5)
.with_num_rounds(200)
.with_seed(123)
.with_verbose(false);
assert_eq!(config.n_iterations, 5);
assert_eq!(config.num_rounds, 200);
assert_eq!(config.seed, 123);
assert!(!config.verbose);
}
// Spread decays geometrically by zoom_factor per iteration.
#[test]
fn test_tuner_config_spread_for_iteration() {
let config = TunerConfig::default();
assert_eq!(config.spread_for_iteration(0), 1.0);
assert_eq!(config.spread_for_iteration(1), 0.8);
assert!((config.spread_for_iteration(2) - 0.64).abs() < 0.001);
}
// Defaults: 5 params, 3 grid points each, 5 iterations -> 3^5 * 5 = 1215.
#[test]
fn test_tuner_config_estimated_trials() {
let config = TunerConfig::default();
assert_eq!(config.estimated_trials(), 1215);
}
// validate() rejects zero iterations/rounds and out-of-range factors.
#[test]
fn test_tuner_config_validate() {
assert!(TunerConfig::default().validate().is_ok());
let invalid = TunerConfig::default().with_iterations(0);
assert!(invalid.validate().is_err());
let invalid = TunerConfig::default().with_initial_spread(0.0);
assert!(invalid.validate().is_err());
let invalid = TunerConfig::default().with_zoom_factor(1.0);
assert!(invalid.validate().is_err());
let invalid = TunerConfig::default().with_num_rounds(0);
assert!(invalid.validate().is_err());
}
}