use crate::analysis::{Confidence, DataFrameProfile, DatasetAnalysis};
use crate::dataset::feature_extractor::LinearFeatureConfig;
use crate::defaults::{auto as auto_defaults, seeds as seeds_defaults};
use crate::ensemble::{MultiSeedConfig, SelectionConfig, StackingConfig};
use crate::features::FeaturePlan;
use crate::model::progress::{ProgressCallback, QuietProgress};
use crate::model::{BoostingMode, UniversalModel};
use crate::preprocessing::PreprocessingPlan;
use crate::tuner::ltt::LttTuningResult;
use std::sync::Arc;
use std::time::Duration;
/// How much hyperparameter tuning effort the auto-builder should spend.
///
/// `Standard` is the default; `None` disables tuning entirely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum TuningLevel {
/// Fast, low-budget search.
Quick,
/// Balanced search budget (the default).
#[default]
Standard,
/// Exhaustive, high-budget search.
Thorough,
/// Skip tuning altogether.
None,
}
/// Strategy used to combine the predictions of ensemble members.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutoEnsembleMethod {
/// Combine members with a simple (unweighted) average.
SimpleAverage,
/// Combine members via a ridge-regression stacking layer.
RidgeStacking,
}
/// Configuration for automatic ensembling.
///
/// Groups the combination method with the sub-configs that control
/// multi-seed training, member selection, and stacking.
#[derive(Debug, Clone)]
pub struct AutoEnsembleConfig {
/// How member predictions are combined.
pub method: AutoEnsembleMethod,
/// Settings for training members from multiple seeds.
pub multi_seed: MultiSeedConfig,
/// Settings for selecting which members to keep.
pub selection: SelectionConfig,
/// Settings for the stacking meta-learner.
pub stacking: StackingConfig,
}
impl Default for AutoEnsembleConfig {
fn default() -> Self {
Self {
method: AutoEnsembleMethod::RidgeStacking,
multi_seed: MultiSeedConfig::default(),
selection: SelectionConfig::default(),
stacking: StackingConfig::default(),
}
}
}
impl AutoEnsembleConfig {
    /// Equivalent to [`AutoEnsembleConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the prediction-combination method.
    pub fn with_method(self, method: AutoEnsembleMethod) -> Self {
        Self { method, ..self }
    }

    /// Replaces the multi-seed training settings.
    pub fn with_multi_seed_config(self, config: MultiSeedConfig) -> Self {
        Self {
            multi_seed: config,
            ..self
        }
    }

    /// Replaces the member-selection settings.
    pub fn with_selection_config(self, config: SelectionConfig) -> Self {
        Self {
            selection: config,
            ..self
        }
    }

    /// Replaces the stacking settings.
    pub fn with_stacking_config(self, config: StackingConfig) -> Self {
        Self {
            stacking: config,
            ..self
        }
    }
}
/// Top-level configuration for the automatic model-building pipeline.
///
/// `Clone` and `Debug` are implemented by hand below because of the
/// `Arc<dyn ProgressCallback>` field.
pub struct AutoConfig {
/// Tuning effort level.
pub tuning_level: TuningLevel,
/// Fraction of data held out for validation.
pub val_ratio: f32,
/// Enable automatic feature generation.
pub auto_features: bool,
/// Enable automatic preprocessing.
pub auto_preprocessing: bool,
/// Enable automatic boosting-mode detection.
pub auto_mode: bool,
/// Explicit boosting mode; set via `with_mode`, which also disables `auto_mode`.
pub force_mode: Option<BoostingMode>,
/// Cap on the number of generated features.
pub max_generated_features: usize,
/// RNG seed for reproducibility.
pub seed: u64,
/// Emit verbose output.
pub verbose: bool,
/// Optional wall-clock budget for the whole build.
pub time_budget: Option<Duration>,
/// Progress reporting hook; defaults to `QuietProgress`.
pub progress_callback: Arc<dyn ProgressCallback>,
/// Settings for linear feature extraction.
pub linear_feature_config: LinearFeatureConfig,
/// Fully custom model config; overrides auto-derived settings when set.
pub custom_config: Option<crate::model::UniversalConfig>,
/// Ensemble settings; `None` disables ensembling.
pub ensemble: Option<AutoEnsembleConfig>,
}
// Hand-written `Clone`: the `Arc<dyn ProgressCallback>` is cloned via
// `Arc::clone` (refcount bump only — the callback itself is shared).
// NOTE(review): `Arc<dyn T>` is `Clone`, so `#[derive(Clone)]` on the struct
// would likely produce the same behavior — confirm and simplify if so.
impl Clone for AutoConfig {
fn clone(&self) -> Self {
Self {
tuning_level: self.tuning_level,
val_ratio: self.val_ratio,
auto_features: self.auto_features,
auto_preprocessing: self.auto_preprocessing,
auto_mode: self.auto_mode,
force_mode: self.force_mode,
max_generated_features: self.max_generated_features,
seed: self.seed,
verbose: self.verbose,
time_budget: self.time_budget,
progress_callback: Arc::clone(&self.progress_callback),
linear_feature_config: self.linear_feature_config.clone(),
custom_config: self.custom_config.clone(),
ensemble: self.ensemble.clone(),
}
}
}
// Hand-written `Debug`: `Arc<dyn ProgressCallback>` has no `Debug` bound, so
// the callback is shown as a placeholder. `custom_config` is reported by
// presence only (its type is not required to be `Debug` here).
impl std::fmt::Debug for AutoConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AutoConfig")
            .field("tuning_level", &self.tuning_level)
            .field("val_ratio", &self.val_ratio)
            .field("auto_features", &self.auto_features)
            .field("auto_preprocessing", &self.auto_preprocessing)
            .field("auto_mode", &self.auto_mode)
            .field("force_mode", &self.force_mode)
            .field("max_generated_features", &self.max_generated_features)
            .field("seed", &self.seed)
            .field("verbose", &self.verbose)
            .field("time_budget", &self.time_budget)
            .field("progress_callback", &"<callback>")
            .field("linear_feature_config", &self.linear_feature_config)
            // Previously omitted entirely (inconsistent with `Clone`); show
            // presence without requiring `UniversalConfig: Debug`.
            .field(
                "custom_config",
                &self.custom_config.as_ref().map(|_| "<custom>"),
            )
            .field("ensemble", &self.ensemble)
            // Signal that not every field is printed verbatim.
            .finish_non_exhaustive()
    }
}
impl Default for AutoConfig {
/// Fully-automatic defaults: standard tuning, auto features/preprocessing/
/// mode detection on, quiet progress, no time budget, no ensembling.
fn default() -> Self {
Self {
tuning_level: TuningLevel::Standard,
val_ratio: auto_defaults::DEFAULT_VALIDATION_RATIO,
auto_features: true,
auto_preprocessing: true,
auto_mode: true,
force_mode: None,
max_generated_features: auto_defaults::AUTO_FEATURES_DEFAULT_COUNT,
seed: seeds_defaults::DEFAULT_SEED,
verbose: false,
time_budget: None,
// No-op progress reporting by default.
progress_callback: Arc::new(QuietProgress),
linear_feature_config: LinearFeatureConfig::default(),
custom_config: None,
ensemble: None,
}
}
}
impl AutoConfig {
    /// Equivalent to [`AutoConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the tuning effort level.
    pub fn with_tuning(self, level: TuningLevel) -> Self {
        Self {
            tuning_level: level,
            ..self
        }
    }

    /// Sets the validation split ratio, clamped to `[0.1, 0.4]`.
    pub fn with_validation_split(self, ratio: f32) -> Self {
        Self {
            val_ratio: ratio.clamp(0.1, 0.4),
            ..self
        }
    }

    /// Toggles automatic feature generation.
    pub fn with_auto_features(self, enabled: bool) -> Self {
        Self {
            auto_features: enabled,
            ..self
        }
    }

    /// Toggles automatic preprocessing.
    pub fn with_auto_preprocessing(self, enabled: bool) -> Self {
        Self {
            auto_preprocessing: enabled,
            ..self
        }
    }

    /// Toggles automatic boosting-mode detection.
    pub fn with_auto_mode(self, enabled: bool) -> Self {
        Self {
            auto_mode: enabled,
            ..self
        }
    }

    /// Forces a specific boosting mode and disables auto-detection.
    pub fn with_mode(self, mode: BoostingMode) -> Self {
        Self {
            force_mode: Some(mode),
            auto_mode: false,
            ..self
        }
    }

    /// Enables ensembling with default settings.
    pub fn with_ensemble(self) -> Self {
        Self {
            ensemble: Some(AutoEnsembleConfig::default()),
            ..self
        }
    }

    /// Enables ensembling using the given combination method.
    pub fn with_ensemble_method(self, method: AutoEnsembleMethod) -> Self {
        Self {
            ensemble: Some(AutoEnsembleConfig::default().with_method(method)),
            ..self
        }
    }

    /// Enables ensembling with a fully specified configuration.
    pub fn with_ensemble_config(self, config: AutoEnsembleConfig) -> Self {
        Self {
            ensemble: Some(config),
            ..self
        }
    }

    /// Sets the RNG seed.
    pub fn with_seed(self, seed: u64) -> Self {
        Self { seed, ..self }
    }

    /// Toggles verbose output.
    pub fn with_verbose(self, verbose: bool) -> Self {
        Self { verbose, ..self }
    }

    /// Sets a wall-clock budget for the build.
    pub fn with_time_budget(self, budget: Duration) -> Self {
        Self {
            time_budget: Some(budget),
            ..self
        }
    }

    /// Installs a progress-reporting callback.
    pub fn with_progress_callback(self, callback: Arc<dyn ProgressCallback>) -> Self {
        Self {
            progress_callback: callback,
            ..self
        }
    }

    /// Replaces the linear feature extraction settings.
    pub fn with_linear_feature_config(self, config: LinearFeatureConfig) -> Self {
        Self {
            linear_feature_config: config,
            ..self
        }
    }

    /// Supplies a fully custom model configuration.
    pub fn with_custom_config(self, config: crate::model::UniversalConfig) -> Self {
        Self {
            custom_config: Some(config),
            ..self
        }
    }
}
/// Outcome of a tree hyperparameter tuning run.
#[derive(Debug, Clone)]
pub struct TreeTuningResult {
/// Number of trials evaluated.
pub num_trials: usize,
/// Best metric value found.
pub best_metric: f32,
/// Best parameter set, keyed by parameter name.
pub best_params: std::collections::HashMap<String, f32>,
}
/// Search-space and budget settings for the tree hyperparameter tuner.
///
/// Construct via [`TreeTunerConfig::with_preset`].
#[derive(Debug, Clone)]
pub struct TreeTunerConfig {
/// Inclusive (min, max) range for tree depth.
pub max_depth_range: (usize, usize),
/// (min, max) range for the learning rate.
pub learning_rate_range: (f32, f32),
/// Number of parameter samples to draw.
pub n_samples: usize,
/// Number of tuning iterations.
pub n_iterations: usize,
/// Maximum boosting rounds per trial.
pub max_rounds: usize,
/// Rounds without improvement before a trial stops early.
pub early_stopping_rounds: usize,
/// Fraction of data used for validation during tuning.
pub validation_ratio: f32,
/// Minimum metric improvement counted as progress.
pub improvement_threshold: f32,
/// Minimum acceptable F1 score.
pub min_f1_score: f32,
}
/// Named budget presets for [`TreeTunerConfig::with_preset`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeTunerPreset {
/// Small budget: few samples, one iteration.
Quick,
/// Medium budget.
Standard,
/// Large budget: many samples and iterations.
Thorough,
}
impl TreeTunerConfig {
    /// Builds a config from the knobs that vary between presets; everything
    /// else (early stopping, validation ratio, improvement threshold) is
    /// shared across all presets.
    fn from_knobs(
        max_depth_range: (usize, usize),
        learning_rate_range: (f32, f32),
        n_samples: usize,
        n_iterations: usize,
        max_rounds: usize,
        min_f1_score: f32,
    ) -> Self {
        Self {
            max_depth_range,
            learning_rate_range,
            n_samples,
            n_iterations,
            max_rounds,
            early_stopping_rounds: 10,
            validation_ratio: auto_defaults::DEFAULT_VALIDATION_RATIO,
            improvement_threshold: 0.001,
            min_f1_score,
        }
    }

    /// Low-budget preset: 30 samples, 1 iteration, 100 rounds max.
    fn preset_quick() -> Self {
        Self::from_knobs(
            auto_defaults::QUICK_DEPTH_RANGE,
            auto_defaults::QUICK_LR_RANGE,
            30,
            1,
            100,
            0.80,
        )
    }

    /// Medium-budget preset: 100 samples, 3 iterations, 200 rounds max.
    fn preset_standard() -> Self {
        Self::from_knobs(
            auto_defaults::STANDARD_DEPTH_RANGE,
            auto_defaults::STANDARD_LR_RANGE,
            100,
            3,
            200,
            0.85,
        )
    }

    /// High-budget preset: 150 samples, 15 iterations, 200 rounds max.
    fn preset_thorough() -> Self {
        Self::from_knobs(
            auto_defaults::THOROUGH_DEPTH_RANGE,
            auto_defaults::STANDARD_LR_RANGE,
            150,
            15,
            200,
            0.85,
        )
    }

    /// Returns the configuration for the given named preset.
    pub fn with_preset(preset: TreeTunerPreset) -> Self {
        match preset {
            TreeTunerPreset::Quick => Self::preset_quick(),
            TreeTunerPreset::Standard => Self::preset_standard(),
            TreeTunerPreset::Thorough => Self::preset_thorough(),
        }
    }
}
/// Everything produced by an automatic build: the trained model plus all
/// intermediate artifacts and timing information. Optional fields are `None`
/// when the corresponding pipeline stage was skipped.
#[derive(Debug)]
pub struct BuildResult {
/// The trained model.
pub model: UniversalModel,
/// Boosting mode that was used.
pub mode: BoostingMode,
/// Name of the target column.
pub target_column: String,
/// Confidence of the automatic mode decision, if auto mode ran.
pub mode_confidence: Option<Confidence>,
/// Preprocessing plan applied, if any.
pub preprocessing_plan: Option<PreprocessingPlan>,
/// Feature engineering plan applied, if any.
pub feature_plan: Option<FeaturePlan>,
/// LTT tuning result, if that tuner ran.
pub ltt_tuning: Option<LttTuningResult>,
/// Tree tuning result, if that tuner ran.
pub tree_tuning: Option<TreeTuningResult>,
/// Column-level profile of the input data, if profiling ran.
pub column_profile: Option<DataFrameProfile>,
/// Dataset analysis, if analysis ran.
pub analysis: Option<DatasetAnalysis>,
/// Saved pipeline state for reuse at prediction time, if available.
pub pipeline_state: Option<crate::dataset::PipelineState>,
/// Total wall-clock time for the build.
pub build_time: Duration,
/// Per-phase breakdown of `build_time`.
pub phase_times: BuildPhaseTimes,
}
/// Wall-clock time spent in each phase of the automatic build.
#[derive(Debug, Clone, Default)]
pub struct BuildPhaseTimes {
/// Data profiling phase.
pub profiling: Duration,
/// Preprocessing phase.
pub preprocessing: Duration,
/// Feature engineering phase.
pub feature_engineering: Duration,
/// Dataset analysis phase.
pub analysis: Duration,
/// Hyperparameter tuning phase.
pub tuning: Duration,
/// Final model training phase.
pub training: Duration,
}