use crate::analysis::{Confidence, DataFrameProfile, DatasetAnalysis};
use crate::dataset::{BinnedDataset, DataPipeline};
use crate::features::FeaturePlan;
use crate::loss::MseLoss;
use crate::model::{
AutoBuilder, AutoConfig, BoostingMode, BuildPhaseTimes, BuildResult, TreeTuningResult,
TuningLevel, UniversalModel,
};
use crate::preprocessing::PreprocessingPlan;
use crate::tuner::ltt::LttTuningResult;
use crate::{Result, TreeBoostError};
use polars::prelude::*;
use std::time::Duration;
/// A trained model produced by the automatic build pipeline, bundled with the
/// artefacts recorded while it was built: profile, plans, analysis, tuning results and timings.
///
/// Instances come from the `AutoModel::train*` constructors (via `BuildResult`)
/// or from `AutoModel::load_trb`. The latter leaves all optional metadata unset.
pub struct AutoModel {
    // ----- core model -----
    /// The fitted model that produces predictions.
    model: UniversalModel,
    /// Boosting strategy; a copy taken at construction time.
    mode: BoostingMode,
    /// Name of the target column.
    /// Passed to the feature extractor at predict time and read by `update`.
    target_column: String,

    // ----- inference state -----
    /// Fitted preprocessing/binning state replayed on new data.
    /// `None` for models loaded with `load_trb`.
    pipeline_state: Option<crate::dataset::PipelineState>,

    // ----- build-time decisions (reporting only) -----
    /// How confident the builder was in the chosen `mode`, if recorded.
    mode_confidence: Option<Confidence>,
    /// Column profile computed before preprocessing.
    column_profile: Option<DataFrameProfile>,
    /// Preprocessing steps chosen by the builder, with their reasoning.
    preprocessing_plan: Option<PreprocessingPlan>,
    /// Engineered features (polynomial, ratio, interaction) and their reasoning.
    feature_plan: Option<FeaturePlan>,
    /// Linear-vs-tree signal analysis used for mode selection.
    analysis: Option<DatasetAnalysis>,
    /// Tuning outcome for LinearThenTree builds.
    ltt_tuning: Option<LttTuningResult>,
    /// Tuning outcome for tree hyper-parameters.
    tree_tuning: Option<TreeTuningResult>,

    // ----- timings -----
    /// Total build time reported by the builder.
    build_time: Duration,
    /// Per-phase breakdown of `build_time`.
    phase_times: BuildPhaseTimes,
}
impl AutoModel {
pub fn from_build_result(result: BuildResult) -> Self {
Self {
model: result.model,
mode: result.mode,
target_column: result.target_column,
mode_confidence: result.mode_confidence,
preprocessing_plan: result.preprocessing_plan,
feature_plan: result.feature_plan,
ltt_tuning: result.ltt_tuning,
tree_tuning: result.tree_tuning,
column_profile: result.column_profile,
pipeline_state: result.pipeline_state,
analysis: result.analysis,
build_time: result.build_time,
phase_times: result.phase_times,
}
}
pub fn train(df: &DataFrame, target_col: &str) -> Result<Self> {
let builder = AutoBuilder::new();
let result = builder.fit(df, target_col)?;
Ok(Self::from_build_result(result))
}
pub fn train_quick(df: &DataFrame, target_col: &str) -> Result<Self> {
let builder = AutoBuilder::new().with_tuning(TuningLevel::Quick);
let result = builder.fit(df, target_col)?;
Ok(Self::from_build_result(result))
}
pub fn train_thorough(df: &DataFrame, target_col: &str) -> Result<Self> {
let builder = AutoBuilder::new().with_tuning(TuningLevel::Thorough);
let result = builder.fit(df, target_col)?;
Ok(Self::from_build_result(result))
}
pub fn train_with_mode(df: &DataFrame, target_col: &str, mode: BoostingMode) -> Result<Self> {
let builder = AutoBuilder::new().with_mode(mode);
let result = builder.fit(df, target_col)?;
Ok(Self::from_build_result(result))
}
pub fn train_with_config(df: &DataFrame, target_col: &str, config: AutoConfig) -> Result<Self> {
let builder = AutoBuilder::with_config(config);
let result = builder.fit(df, target_col)?;
Ok(Self::from_build_result(result))
}
pub fn predict(&self, df: &DataFrame) -> Result<Vec<f32>> {
let (preprocessed_df, dataset) = self.prepare_dataset_for_prediction(df)?;
if matches!(self.mode, crate::model::BoostingMode::LinearThenTree) {
if let Some(ref extractor) = self.model.feature_extractor() {
let (raw_features, _num_features) =
extractor.extract(&preprocessed_df, &self.target_column)?;
return Ok(self
.model
.predict_with_raw_features(&dataset, &raw_features));
}
}
Ok(self.model.predict(&dataset))
}
pub fn predict_linear_only(&self, df: &DataFrame) -> Result<Vec<f32>> {
if !matches!(self.mode, crate::model::BoostingMode::LinearThenTree) {
return Err(TreeBoostError::Config(
"predict_linear_only() only available for LinearThenTree mode".to_string(),
));
}
let (preprocessed_df, dataset) = self.prepare_dataset_for_prediction(df)?;
if let Some(ref extractor) = self.model.feature_extractor() {
let (raw_features, _num_features) =
extractor.extract(&preprocessed_df, &self.target_column)?;
return Ok(self.model.predict_linear_only(&dataset, &raw_features)?);
}
Err(TreeBoostError::Config(
"LinearThenTree model missing FeatureExtractor - cannot predict".to_string(),
))
}
pub fn predict_binned(&self, dataset: &BinnedDataset) -> Vec<f32> {
self.model.predict(dataset)
}
pub fn mode(&self) -> BoostingMode {
self.mode
}
pub fn mode_confidence(&self) -> Option<Confidence> {
self.mode_confidence
}
pub fn build_time(&self) -> Duration {
self.build_time
}
pub fn phase_times(&self) -> &BuildPhaseTimes {
&self.phase_times
}
pub fn preprocessing_plan(&self) -> Option<&PreprocessingPlan> {
self.preprocessing_plan.as_ref()
}
pub fn feature_plan(&self) -> Option<&FeaturePlan> {
self.feature_plan.as_ref()
}
pub fn column_profile(&self) -> Option<&DataFrameProfile> {
self.column_profile.as_ref()
}
pub fn analysis(&self) -> Option<&DatasetAnalysis> {
self.analysis.as_ref()
}
pub fn ltt_tuning(&self) -> Option<&LttTuningResult> {
self.ltt_tuning.as_ref()
}
pub fn tree_tuning(&self) -> Option<&TreeTuningResult> {
self.tree_tuning.as_ref()
}
pub fn num_trees(&self) -> usize {
self.model.num_trees()
}
pub fn num_features(&self) -> usize {
self.model.num_features()
}
pub fn inner(&self) -> &UniversalModel {
&self.model
}
pub fn config(&self) -> &crate::model::UniversalConfig {
self.model.config()
}
pub fn summary(&self) -> String {
let mut lines = vec![
"┌─────────────────────────────────────────────────────────────────┐".to_string(),
"│ TreeBoost Pipeline Report │".to_string(),
"└─────────────────────────────────────────────────────────────────┘".to_string(),
"".to_string(),
];
if let Some(ref profile) = self.column_profile {
lines.push("═══ DATA PROFILE ═══".to_string());
lines.push(format!(" Rows: {}", profile.num_rows));
lines.push(format!(" Columns: {} total", profile.columns.len()));
lines.push(format!(" • Numeric: {}", profile.num_numeric));
lines.push(format!(" • Categorical: {}", profile.num_categorical));
lines.push(format!(
" Target: {} ({:?})",
self.target_column, profile.task_type
));
if !profile.drop_columns.is_empty() {
lines.push("".to_string());
lines.push(format!(" Dropped {} columns:", profile.drop_columns.len()));
for dropped in &profile.drop_columns {
lines.push(format!(" • '{}' - {}", dropped.name, dropped.reason));
}
}
lines.push("".to_string());
}
if let Some(ref plan) = self.preprocessing_plan {
lines.push("═══ PREPROCESSING DECISIONS ═══".to_string());
if !plan.reasoning.is_empty() {
for reason in &plan.reasoning {
lines.push(format!(" • {}", reason));
}
} else {
lines.push(" • No special preprocessing required".to_string());
}
lines.push("".to_string());
}
if let Some(ref plan) = self.feature_plan {
lines.push("═══ FEATURE ENGINEERING ═══".to_string());
if !plan.polynomial_features.is_empty() {
lines.push(format!(
" Polynomial features ({}): ",
plan.polynomial_features.len()
));
for feat in &plan.polynomial_features {
lines.push(format!(" • {}", feat));
}
}
if !plan.ratio_pairs.is_empty() {
lines.push(format!(" Ratio features ({}): ", plan.ratio_pairs.len()));
for (f1, f2) in &plan.ratio_pairs {
lines.push(format!(" • {}/{}", f1, f2));
}
}
if !plan.interaction_pairs.is_empty() {
lines.push(format!(
" Interaction features ({}): ",
plan.interaction_pairs.len()
));
for (f1, f2) in plan.interaction_pairs.iter().take(5) {
lines.push(format!(" • {} × {}", f1, f2));
}
if plan.interaction_pairs.len() > 5 {
lines.push(format!(
" ... and {} more",
plan.interaction_pairs.len() - 5
));
}
}
if !plan.reasoning.is_empty() {
lines.push("".to_string());
lines.push(" Reasoning:".to_string());
for reason in &plan.reasoning {
lines.push(format!(" • {}", reason));
}
}
lines.push("".to_string());
}
lines.push("═══ MODE SELECTION ═══".to_string());
lines.push(format!(" Selected: {:?}", self.mode));
lines.push(format!(
" Confidence: {:?}",
self.mode_confidence
.as_ref()
.map(|c| format!("{:?}", c))
.unwrap_or("N/A".to_string())
));
if let Some(ref analysis) = self.analysis {
lines.push("".to_string());
lines.push(" Analysis Results:".to_string());
lines.push(format!(
" • Linear R²: {:.4} ({})",
analysis.linear_r2,
if analysis.linear_r2 > 0.5 {
"Strong"
} else if analysis.linear_r2 > 0.3 {
"Moderate"
} else {
"Weak"
}
));
lines.push(format!(
" • Tree Gain: {:.4} ({})",
analysis.tree_gain,
if analysis.tree_gain > 0.3 {
"Strong"
} else if analysis.tree_gain > 0.1 {
"Moderate"
} else {
"Weak"
}
));
let recommended_mode = analysis.recommend_mode();
let reasoning = if analysis.linear_r2 > 0.5 && analysis.tree_gain > 0.1 {
"Strong linear trend + residual structure → Hybrid approach"
} else if analysis.linear_r2 > 0.5 {
"Strong linear relationship → Linear model dominates"
} else if analysis.tree_gain > 0.1 {
"Non-linear patterns → Tree-based approach"
} else {
"Moderate signals → Pure tree model"
};
lines.push("".to_string());
lines.push(format!(" Recommended: {:?}", recommended_mode));
lines.push(format!(" Reasoning: {}", reasoning));
}
lines.push("".to_string());
if let Some(ref tuning) = self.ltt_tuning {
lines.push("═══ LTT TUNING RESULTS ═══".to_string());
lines.push(" Linear Phase:".to_string());
lines.push(format!(" • R²: {:.4}", tuning.linear_r2));
lines.push(format!(" • Lambda: {:.4}", tuning.linear_params.lambda));
lines.push(format!(
" • L1 Ratio: {:.4} ({})",
tuning.linear_params.l1_ratio,
if tuning.linear_params.l1_ratio == 0.0 {
"Ridge"
} else if tuning.linear_params.l1_ratio == 1.0 {
"LASSO"
} else {
"ElasticNet"
}
));
lines.push("".to_string());
lines.push(" Tree Phase:".to_string());
lines.push(format!(" • Max Depth: {}", tuning.tree_params.max_depth));
lines.push(format!(
" • Learning Rate: {:.4}",
tuning.tree_params.learning_rate
));
lines.push(format!(
" • Num Rounds: {}",
tuning.tree_params.num_rounds
));
lines.push("".to_string());
lines.push(format!(" Final RMSE: {:.4}", tuning.final_rmse));
lines.push("".to_string());
}
lines.push("═══ TRAINING SUMMARY ═══".to_string());
lines.push(format!(
" Total Time: {:.3}s",
self.build_time.as_secs_f64()
));
lines.push("".to_string());
lines.push(" Phase Breakdown:".to_string());
lines.push(format!(" • Profiling: {:?}", self.phase_times.profiling));
lines.push(format!(
" • Preprocessing: {:?}",
self.phase_times.preprocessing
));
lines.push(format!(
" • Feature Engineering: {:?}",
self.phase_times.feature_engineering
));
lines.push(format!(" • Analysis: {:?}", self.phase_times.analysis));
lines.push(format!(" • Tuning: {:?}", self.phase_times.tuning));
lines.push(format!(" • Training: {:?}", self.phase_times.training));
lines.push("".to_string());
lines.push(
"┌─────────────────────────────────────────────────────────────────┐".to_string(),
);
lines.push(
"│ TreeBoost: The Smart Engineer That Explains Itself │".to_string(),
);
lines.push(
"└─────────────────────────────────────────────────────────────────┘".to_string(),
);
lines.join("\n")
}
fn prepare_dataset_for_prediction(&self, df: &DataFrame) -> Result<(DataFrame, BinnedDataset)> {
let pipeline_state = self.pipeline_state.as_ref().ok_or_else(|| {
TreeBoostError::Data(
"AutoModel missing fitted pipeline state - cannot make predictions".to_string(),
)
})?;
let pipeline = DataPipeline::with_defaults();
let (preprocessed_df, dataset) =
pipeline.process_for_inference(df.clone(), pipeline_state)?;
Ok((preprocessed_df, dataset))
}
pub fn save_config(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
let config = self.config();
let json = serde_json::to_string_pretty(config).map_err(|e| {
TreeBoostError::Serialization(format!("Failed to serialize config to JSON: {}", e))
})?;
std::fs::write(path, json)?;
Ok(())
}
pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
self.model.save(path)
}
pub fn update(
&mut self,
df: &DataFrame,
additional_rounds: usize,
) -> Result<AutoModelUpdateReport> {
let rows_before = df.height();
let (_preprocessed_df, dataset) = self.prepare_dataset_for_prediction(df)?;
let target_series = df.column(&self.target_column).map_err(|e| {
TreeBoostError::Data(format!(
"Target column '{}' not found: {}",
self.target_column, e
))
})?;
let targets: Vec<f32> = target_series
.cast(&polars::datatypes::DataType::Float32)
.map_err(|e| TreeBoostError::Data(format!("Failed to cast target to f32: {}", e)))?
.f32()
.map_err(|e| TreeBoostError::Data(format!("Failed to get f32 values: {}", e)))?
.into_no_null_iter()
.collect();
let update_dataset = dataset.with_targets(targets);
let loss_fn = MseLoss::new();
let model_report = self
.model
.update(&update_dataset, &loss_fn, additional_rounds)?;
Ok(AutoModelUpdateReport {
rows_trained: rows_before,
trees_before: model_report.trees_before,
trees_after: model_report.trees_after,
trees_added: model_report.trees_added,
mode: self.mode,
target_column: self.target_column.clone(),
})
}
pub fn save_trb(&self, path: impl AsRef<std::path::Path>, description: &str) -> Result<()> {
self.model.save_trb(path, description)
}
pub fn save_trb_update(
&self,
path: impl AsRef<std::path::Path>,
rows_trained: usize,
description: &str,
) -> Result<()> {
self.model.save_trb_update(path, rows_trained, description)
}
pub fn load_trb(path: impl AsRef<std::path::Path>, target_column: &str) -> Result<Self> {
let model = crate::model::UniversalModel::load_trb(path)?;
let mode = model.mode();
Ok(Self {
model,
mode,
target_column: target_column.to_string(),
mode_confidence: None,
preprocessing_plan: None,
feature_plan: None,
ltt_tuning: None,
tree_tuning: None,
column_profile: None,
analysis: None,
pipeline_state: None, build_time: Duration::default(),
phase_times: BuildPhaseTimes::default(),
})
}
pub fn is_compatible_for_update(&self, df: &DataFrame) -> bool {
if df.column(&self.target_column).is_err() {
return false;
}
self.prepare_dataset_for_prediction(df).is_ok()
}
pub fn model_mut(&mut self) -> &mut UniversalModel {
&mut self.model
}
}
/// Outcome of one incremental `AutoModel::update` call.
#[derive(Clone, Debug)]
pub struct AutoModelUpdateReport {
    /// Number of rows in the DataFrame passed to `update`.
    pub rows_trained: usize,
    /// Tree count of the model before the update.
    pub trees_before: usize,
    /// Tree count of the model after the update.
    pub trees_after: usize,
    /// Trees appended by this update.
    pub trees_added: usize,
    /// Boosting mode of the updated model.
    pub mode: BoostingMode,
    /// Target column the update was trained against.
    pub target_column: String,
}
impl std::fmt::Display for AutoModelUpdateReport {
    /// One-line human-readable summary: rows, target, tree growth and mode.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            rows_trained,
            trees_before,
            trees_after,
            trees_added,
            mode,
            target_column,
        } = self;
        write!(
            f,
            "AutoModel Update: {} rows on '{}', {} trees added ({} -> {}), mode={:?}",
            rows_trained, target_column, trees_added, trees_before, trees_after, mode
        )
    }
}
impl std::fmt::Debug for AutoModel {
    /// Compact debug view of the identifying metadata.
    ///
    /// The inner model and the build artefacts are omitted. `finish_non_exhaustive`
    /// marks the omission with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AutoModel")
            .field("mode", &self.mode)
            .field("mode_confidence", &self.mode_confidence)
            .field("target_column", &self.target_column)
            .field("build_time", &self.build_time)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // These two tests previously had empty bodies and passed vacuously, which reported
    // coverage that did not exist. Each stays ignored until a real fixture is written.
    #[test]
    #[ignore = "TODO: needs a BuildResult fixture"]
    fn test_auto_model_from_build_result() {}

    #[test]
    #[ignore = "TODO: needs a trained AutoModel fixture"]
    fn test_auto_model_summary_format() {}

    /// The update report renders rows, target and tree growth on one line.
    #[test]
    fn test_update_report_display() {
        let report = AutoModelUpdateReport {
            rows_trained: 100,
            trees_before: 10,
            trees_after: 15,
            trees_added: 5,
            mode: BoostingMode::LinearThenTree,
            target_column: "price".to_string(),
        };
        let text = report.to_string();
        assert!(text.starts_with("AutoModel Update: 100 rows on 'price'"));
        assert!(text.contains("5 trees added (10 -> 15)"));
    }
}