use crate::analysis::{Confidence, DataFrameProfile, DatasetAnalysis};
use crate::dataset::feature_extractor::LinearFeatureConfig;
use crate::dataset::{BinnedDataset, DataPipeline};
use crate::defaults::auto as auto_defaults;
use crate::features::{FeaturePlan, SmartFeatureEngine};
use crate::model::config::{
AutoConfig, AutoEnsembleConfig, AutoEnsembleMethod, BuildPhaseTimes, BuildResult, TuningLevel,
};
use crate::model::progress::{ProgressCallback, ProgressUpdate, TrainingPhase};
use crate::model::{BoostingMode, UniversalConfig, UniversalModel};
use crate::preprocessing::{ModelType, PreprocessingPlan, SmartPreprocessor};
use crate::Result;
use polars::prelude::*;
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::tuning;
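/// Builds a model end to end: profiling, preprocessing planning, optional
/// feature engineering, dataset preparation, mode selection, hyperparameter
/// tuning, and training, with optional adaptation to a wall-clock budget.
///
/// A minimal usage sketch (the `df` binding and column name are assumptions,
/// not part of the API):
///
/// ```ignore
/// let result = AutoBuilder::new()
///     .with_tuning(TuningLevel::Quick)
///     .with_time_budget(Duration::from_secs(60))
///     .fit(&df, "target")?;
/// println!("mode: {:?} in {:?}", result.mode, result.build_time);
/// ```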
pub struct AutoBuilder {
config: AutoConfig,
}
impl AutoBuilder {
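/// Creates a builder with default [`AutoConfig`] settings.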
pub fn new() -> Self {
Self {
config: AutoConfig::default(),
}
}
pub fn with_config(config: AutoConfig) -> Self {
Self { config }
}
pub fn with_tuning(mut self, level: TuningLevel) -> Self {
self.config.tuning_level = level;
self
}
pub fn with_validation_split(mut self, ratio: f32) -> Self {
self.config.val_ratio = ratio;
self
}
pub fn with_auto_features(mut self, enabled: bool) -> Self {
self.config.auto_features = enabled;
self
}
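/// Forces a specific boosting mode and disables automatic mode selection.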
pub fn with_mode(mut self, mode: BoostingMode) -> Self {
self.config.force_mode = Some(mode);
self.config.auto_mode = false;
self
}
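/// Enables multi-seed ensembling with default settings. Only tree-based
/// modes (`PureTree`, `LinearThenTree`) actually train an ensemble.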
pub fn with_ensemble(mut self) -> Self {
self.config.ensemble = Some(AutoEnsembleConfig::default());
self
}
pub fn with_ensemble_method(mut self, method: AutoEnsembleMethod) -> Self {
self.config.ensemble = Some(AutoEnsembleConfig::default().with_method(method));
self
}
pub fn with_ensemble_config(mut self, config: AutoEnsembleConfig) -> Self {
self.config.ensemble = Some(config);
self
}
pub fn with_verbose(mut self, verbose: bool) -> Self {
self.config.verbose = verbose;
self
}
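/// Sets a soft wall-clock budget for the whole build. As the budget runs
/// out, feature engineering is skipped and tuning is downgraded or skipped.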
pub fn with_time_budget(mut self, budget: Duration) -> Self {
self.config.time_budget = Some(budget);
self
}
pub fn with_progress_callback(mut self, callback: Arc<dyn ProgressCallback>) -> Self {
self.config.progress_callback = callback;
self
}
pub fn with_linear_feature_config(mut self, config: LinearFeatureConfig) -> Self {
self.config.linear_feature_config = config;
self
}
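/// Runs the full build pipeline on `df` with `target_col` as the label.
///
/// Returns a [`BuildResult`] carrying the trained model plus the plans,
/// tuning history, analysis, and per-phase timings gathered along the way.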
pub fn fit(&self, df: &DataFrame, target_col: &str) -> Result<BuildResult> {
let start = Instant::now();
let mut phase_times = BuildPhaseTimes::default();
let mut adapted_config = self.config.clone();
if self.config.verbose {
if let Some(budget) = self.config.time_budget {
println!("AutoBuilder: Time budget set to {:?}", budget);
}
println!("AutoBuilder: Starting build process...");
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Profiling,
progress_pct: TrainingPhase::Profiling.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!("Analyzing {} columns", df.width())),
});
let phase_start = Instant::now();
let profile = self.profile_dataframe(df, target_col)?;
phase_times.profiling = phase_start.elapsed();
if self.config.verbose {
println!(
" [Profile] {} columns analyzed, {} dropped, task: {:?}",
profile.columns.len(),
profile.drop_columns.len(),
profile.task_type
);
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Preprocessing,
progress_pct: TrainingPhase::Preprocessing.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!(
"{} columns retained",
profile
.columns
.len()
.saturating_sub(profile.drop_columns.len())
)),
});
let phase_start = Instant::now();
let (model_type, preprocessing_plan) = self.plan_preprocessing(&profile)?;
phase_times.preprocessing = phase_start.elapsed();
if self.config.verbose {
println!(
" [Preprocess] Model type: {:?}, {} steps planned",
model_type,
preprocessing_plan.steps.len()
);
}
let mut skip_features = !adapted_config.auto_features;
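// Skip feature engineering when less than 20s of the budget remains.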
if let Some(budget) = self.config.time_budget {
let elapsed = start.elapsed();
let remaining = budget.saturating_sub(elapsed);
if remaining < Duration::from_secs(20) {
skip_features = true;
if self.config.verbose && adapted_config.auto_features {
println!(" [Budget] Low time remaining, skipping feature engineering");
}
}
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::FeatureEngineering,
progress_pct: TrainingPhase::FeatureEngineering.progress_pct(),
elapsed: start.elapsed(),
message: if skip_features {
Some("Skipped".to_string())
} else {
Some("Planning features".to_string())
},
});
let phase_start = Instant::now();
let feature_plan = if skip_features {
None
} else {
Some(self.plan_features(&profile)?)
};
phase_times.feature_engineering = phase_start.elapsed();
if self.config.verbose {
if let Some(ref plan) = feature_plan {
println!(
" [Features] {} polynomial, {} ratio, {} interaction features",
plan.polynomial_features.len(),
plan.ratio_pairs.len(),
plan.interaction_pairs.len()
);
}
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::DatasetPreparation,
progress_pct: TrainingPhase::DatasetPreparation.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!("{} rows", df.height())),
});
let (dataset, pipeline_state, filtered_df) =
self.prepare_dataset_and_state(df, target_col)?;
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Analysis,
progress_pct: TrainingPhase::Analysis.progress_pct(),
elapsed: start.elapsed(),
message: Some("Selecting optimal mode".to_string()),
});
let phase_start = Instant::now();
let (mode, analysis, mode_confidence) = self.select_mode(&dataset)?;
phase_times.analysis = phase_start.elapsed();
if self.config.verbose {
println!(
" [Analysis] Selected mode: {:?}, confidence: {:?}",
mode, mode_confidence
);
}
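// Adapt the tuning level to the remaining budget: skip tuning entirely
// under 10s, downgrade to quick tuning under 30s.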
if let Some(budget) = self.config.time_budget {
let elapsed = start.elapsed();
let remaining = budget.saturating_sub(elapsed);
if remaining < Duration::from_secs(10) {
adapted_config.tuning_level = TuningLevel::None;
if self.config.verbose {
println!(" [Budget] Low time remaining, skipping tuning");
}
} else if remaining < Duration::from_secs(30)
&& adapted_config.tuning_level != TuningLevel::None
{
adapted_config.tuning_level = TuningLevel::Quick;
if self.config.verbose {
println!(" [Budget] Limited time, using quick tuning");
}
}
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Tuning,
progress_pct: TrainingPhase::Tuning.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!("{:?} - {:?}", adapted_config.tuning_level, mode)),
});
let phase_start = Instant::now();
let (universal_config, ltt_tuning, tree_tuning) =
if let Some(ref custom) = self.config.custom_config {
(custom.clone(), None, None)
} else if adapted_config.tuning_level == TuningLevel::None {
let config = tuning::create_config_for_mode(mode, self.config.tuning_level);
(config, None, None)
} else {
tuning::tune_hyperparameters(
&self.config,
&dataset,
mode,
&filtered_df,
target_col,
&profile,
)?
};
phase_times.tuning = phase_start.elapsed();
if self.config.verbose {
let num_trials = ltt_tuning
.as_ref()
.map(|t| t.history.linear_trials.len())
.or_else(|| tree_tuning.as_ref().map(|t| t.num_trials))
.unwrap_or(0);
println!(" [Tuning] {} trials completed", num_trials);
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Training,
progress_pct: TrainingPhase::Training.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!("{:?} model", mode)),
});
let phase_start = Instant::now();
let (feature_extractor, raw_features, linear_indices) =
if matches!(mode, BoostingMode::LinearThenTree) {
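// For LinearThenTree, extract all features for the tree stage, then compute
// the index subset the linear stage may use under the configured filter.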
let all_config = LinearFeatureConfig {
exclude_columns: std::collections::HashSet::new(),
exclude_categorical: false,
exclude_id: false,
exclude_constant: false,
exclude_boolean: false,
exclude_datetime: false,
exclude_text: false,
};
let all_extractor =
crate::dataset::feature_extractor::FeatureExtractor::with_config(all_config);
let (all_features, _num_all_features) =
all_extractor.extract(&filtered_df, target_col)?;
let filter_config = &self.config.linear_feature_config;
let filter_extractor =
crate::dataset::feature_extractor::FeatureExtractor::with_config(
filter_config.clone(),
);
let mut linear_feature_indices = Vec::new();
let col_names: Vec<String> = filtered_df
.get_column_names()
.iter()
.map(|s| s.to_string())
.collect();
// feature_idx counts non-target columns in column order; keep only the
// indices that the linear-stage filter does not exclude.
let mut feature_idx = 0;
for col_name in &col_names {
if col_name == target_col {
continue;
}
if !filter_extractor.should_exclude_column(&filtered_df, col_name, target_col) {
linear_feature_indices.push(feature_idx);
}
feature_idx += 1;
}
(
Some(all_extractor),
Some(all_features),
Some(linear_feature_indices),
)
} else {
(None, None, None)
};
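// Multi-seed ensembling applies only to tree-based modes; otherwise the
// tuned config is used as-is.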
let final_config = if let Some(ref ensemble_config) = adapted_config.ensemble {
if matches!(mode, BoostingMode::PureTree | BoostingMode::LinearThenTree) {
let seeds: Vec<u64> = (0..ensemble_config.multi_seed.n_seeds)
.map(|i| ensemble_config.multi_seed.base_seed + i as u64)
.collect();
if self.config.verbose {
println!(" [Ensemble] Training with {} seeds", seeds.len());
}
let stacking_strategy = crate::model::universal::config::StackingStrategy::Ridge {
alpha: ensemble_config.stacking.alpha,
rank_transform: ensemble_config.stacking.rank_transform,
fit_intercept: ensemble_config.stacking.fit_intercept,
min_weight: ensemble_config.stacking.min_weight,
};
universal_config
.with_ensemble_seeds(seeds)
.with_stacking_strategy(stacking_strategy)
} else {
if self.config.verbose {
println!(
" [Ensemble] Skipped (mode {:?} does not support ensembles)",
mode
);
}
universal_config
}
} else {
universal_config
};
let model = self.train_model(
&dataset,
final_config,
feature_extractor,
raw_features,
linear_indices,
)?;
phase_times.training = phase_start.elapsed();
if self.config.verbose {
println!(" [Train] Model trained in {:?}", phase_times.training);
}
self.config.progress_callback.on_progress(&ProgressUpdate {
phase: TrainingPhase::Complete,
progress_pct: TrainingPhase::Complete.progress_pct(),
elapsed: start.elapsed(),
message: Some(format!("Total: {:?}", start.elapsed())),
});
Ok(BuildResult {
model,
mode,
target_column: target_col.to_string(),
mode_confidence: Some(mode_confidence),
preprocessing_plan: Some(preprocessing_plan),
feature_plan,
ltt_tuning,
tree_tuning,
column_profile: Some(profile),
analysis,
pipeline_state: Some(pipeline_state),
build_time: start.elapsed(),
phase_times,
})
}
fn profile_dataframe(&self, df: &DataFrame, target_col: &str) -> Result<DataFrameProfile> {
DataFrameProfile::analyze(df, target_col)
}
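/// Chooses `LinearThenTree` when any column's absolute target correlation
/// exceeds the linear-signal threshold; otherwise plans for pure trees.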
fn plan_preprocessing(
&self,
profile: &DataFrameProfile,
) -> Result<(ModelType, PreprocessingPlan)> {
let has_linear_signal = profile.columns.iter().any(|c| {
c.target_correlation
.map(|r| r.abs() > auto_defaults::LINEAR_SIGNAL_THRESHOLD)
.unwrap_or(false)
});
let model_type = if has_linear_signal {
ModelType::LinearThenTree
} else {
ModelType::Tree
};
let plan = SmartPreprocessor::infer(profile, model_type);
Ok((model_type, plan))
}
fn plan_features(&self, profile: &DataFrameProfile) -> Result<FeaturePlan> {
let plan = SmartFeatureEngine::infer(profile, None);
Ok(plan)
}
fn prepare_dataset_and_state(
&self,
df: &DataFrame,
target_col: &str,
) -> Result<(BinnedDataset, crate::dataset::PipelineState, DataFrame)> {
let pipeline = DataPipeline::with_defaults();
let (dataset, pipeline_state, filtered_df) =
pipeline.process_for_training(df.clone(), target_col, None)?;
Ok((dataset, pipeline_state, filtered_df))
}
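/// Resolves the boosting mode: a custom config or forced mode wins outright;
/// with auto mode disabled, pure trees are used; otherwise dataset analysis
/// recommends a mode with a confidence score.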
fn select_mode(
&self,
dataset: &BinnedDataset,
) -> Result<(BoostingMode, Option<DatasetAnalysis>, Confidence)> {
if let Some(ref custom) = self.config.custom_config {
return Ok((custom.mode, None, Confidence::High));
}
if let Some(mode) = self.config.force_mode {
return Ok((mode, None, Confidence::High));
}
if !self.config.auto_mode {
return Ok((BoostingMode::PureTree, None, Confidence::Medium));
}
let analysis = DatasetAnalysis::analyze(dataset)?;
let mode = analysis.recommend_mode();
let confidence = analysis.confidence();
Ok((mode, Some(analysis), confidence))
}
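/// Trains the final model with a fixed MSE loss, routing `LinearThenTree`
/// through the linear-feature-selection entry point when raw features and
/// indices are available.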
fn train_model(
&self,
dataset: &BinnedDataset,
mut config: UniversalConfig,
feature_extractor: Option<crate::dataset::feature_extractor::FeatureExtractor>,
raw_features: Option<Vec<f32>>,
linear_indices: Option<Vec<usize>>,
) -> Result<UniversalModel> {
let loss = crate::loss::MseLoss;
config.feature_extractor = feature_extractor;
match (config.mode, raw_features, linear_indices) {
(crate::model::BoostingMode::LinearThenTree, Some(features), Some(indices)) => {
UniversalModel::train_with_linear_feature_selection(
dataset, &features, &indices, config, &loss,
)
}
_ => UniversalModel::train(dataset, config, &loss),
}
}
}
impl Default for AutoBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_auto_config_defaults() {
let config = AutoConfig::default();
assert_eq!(config.tuning_level, TuningLevel::Standard);
assert!((config.val_ratio - 0.2).abs() < 0.01);
assert!(config.auto_features);
assert!(config.auto_preprocessing);
assert!(config.auto_mode);
}
#[test]
fn test_auto_config_builder() {
let config = AutoConfig::new()
.with_tuning(TuningLevel::Thorough)
.with_validation_split(0.3)
.with_auto_features(false)
.with_mode(BoostingMode::LinearThenTree);
assert_eq!(config.tuning_level, TuningLevel::Thorough);
assert!((config.val_ratio - 0.3).abs() < 0.01);
assert!(!config.auto_features);
assert_eq!(config.force_mode, Some(BoostingMode::LinearThenTree));
assert!(!config.auto_mode);
}
#[test]
fn test_auto_builder_creation() {
let builder = AutoBuilder::new()
.with_tuning(TuningLevel::Quick)
.with_verbose(true);
assert_eq!(builder.config.tuning_level, TuningLevel::Quick);
assert!(builder.config.verbose);
}
#[test]
fn test_tuning_level_variants() {
assert_eq!(TuningLevel::default(), TuningLevel::Standard);
}
}