use polars::prelude::*;
use treeboost::{
dataset::feature_extractor::LinearFeatureConfig,
learner::{LinearConfig, TreeConfig},
model::{AutoConfig, AutoModel, BoostingMode, TuningLevel, UniversalConfig},
};
#[test]
fn test_shrinkage_factor_applied() {
    // Training data: y = 2x + 5 over 50 evenly spaced points.
    let xs: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
    let ys: Vec<f64> = xs.iter().map(|v| 2.0 * v + 5.0).collect();
    let frame = DataFrame::new(vec![
        Series::new("x".into(), xs).into(),
        Series::new("target".into(), ys.clone()).into(),
    ])
    .unwrap();
    // Each candidate shrinkage factor must round-trip through the stored
    // config, and the trained model must still produce finite predictions.
    for shrinkage in [0.1f32, 0.5, 1.0] {
        let linear = LinearConfig::default()
            .with_preset(treeboost::LinearPreset::Ridge)
            .with_lambda(0.01)
            .with_shrinkage_factor(shrinkage)
            .with_max_iter(500);
        let trees = TreeConfig::default().with_max_depth(3);
        let universal = UniversalConfig::default()
            .with_mode(BoostingMode::LinearThenTree)
            .with_linear_config(linear)
            .with_tree_config(trees)
            .with_num_rounds(20)
            .with_learning_rate(0.1);
        let auto = AutoConfig::new()
            .with_auto_features(false)
            .with_tuning(TuningLevel::None)
            .with_custom_config(universal);
        let model = AutoModel::train_with_config(&frame, "target", auto).unwrap();
        let predictions = model.predict(&frame).unwrap();
        // The configured shrinkage must be what the model actually stored.
        let stored = model.inner().config().linear_config.shrinkage_factor;
        assert!(
            (stored - shrinkage).abs() < 1e-6,
            "shrinkage_factor not stored correctly: expected {}, got {}",
            shrinkage,
            stored
        );
        assert_eq!(predictions.len(), frame.height());
        assert!(predictions.iter().all(|p| p.is_finite()));
    }
}
#[test]
fn test_ltt_pure_linear_data() {
    // y = 2x + 3 on 100 points: the LinearThenTree model's linear stage
    // should fit this almost exactly on its own.
    let x_values: Vec<f64> = (0..100).map(|i| i as f64).collect();
    let y_values: Vec<f64> = x_values.iter().map(|x| x * 2.0 + 3.0).collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    println!("\n=== Test: Pure Linear Data (y = 2x + 3) ===");
    println!(
        "Input DataFrame: {} rows × {} cols",
        df.height(),
        df.width()
    );
    let linear_config = LinearConfig::default()
        .with_preset(treeboost::LinearPreset::Ridge)
        .with_lambda(0.01)
        .with_shrinkage_factor(1.0)
        .with_max_iter(500);
    let tree_config = TreeConfig::default()
        .with_max_depth(3)
        .with_min_samples_leaf(5);
    let univ_config = UniversalConfig::default()
        .with_mode(BoostingMode::LinearThenTree)
        .with_linear_config(linear_config)
        .with_tree_config(tree_config)
        .with_num_rounds(50)
        .with_learning_rate(0.1);
    let config = AutoConfig::new()
        .with_auto_features(false)
        .with_tuning(TuningLevel::None)
        .with_custom_config(univ_config);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    assert_eq!(
        predictions.len(),
        df.height(),
        "Predictions length must match input DataFrame rows"
    );
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("Predictions: {} rows", predictions.len());
    println!("Pure linear data RMSE: {:.4}", rmse);
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 5.0 check; consider tightening once calibrated.
    assert!(
        rmse < 5.0 * n.sqrt(),
        "RMSE should be low for pure linear data, got {:.4}",
        rmse
    );
}
#[test]
fn test_ltt_linear_plus_residual() {
    // y = 2x + sin(x): the linear stage should capture the trend and the
    // tree stage the bounded sin(x) residual.
    let x_values: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
    let y_values: Vec<f64> = x_values.iter().map(|x| x * 2.0 + x.sin()).collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    println!("\n=== Test: Linear + Nonlinear (y = 2x + sin(x)) ===");
    println!(
        "Input DataFrame: {} rows × {} cols",
        df.height(),
        df.width()
    );
    let linear_config = LinearConfig::default()
        .with_preset(treeboost::LinearPreset::Ridge)
        .with_lambda(0.01)
        .with_shrinkage_factor(1.0)
        .with_max_iter(500);
    let tree_config = TreeConfig::default().with_max_depth(6);
    let univ_config = UniversalConfig::default()
        .with_mode(BoostingMode::LinearThenTree)
        .with_linear_config(linear_config)
        .with_tree_config(tree_config)
        .with_num_rounds(200)
        .with_learning_rate(0.1);
    let config = AutoConfig::new()
        .with_auto_features(false)
        .with_tuning(TuningLevel::None)
        .with_custom_config(univ_config);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    assert_eq!(
        predictions.len(),
        df.height(),
        "Predictions length must match input DataFrame rows"
    );
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("Predictions: {} rows", predictions.len());
    println!("Linear + residual RMSE: {:.4}", rmse);
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 0.5 check; consider tightening once calibrated.
    assert!(
        rmse < 0.5 * n.sqrt(),
        "RMSE should be low (LTT captures both linear trend and sin(x) residual), got {:.4}",
        rmse
    );
}
#[test]
fn test_ltt_with_categoricals() {
    // One numeric feature plus a 4-level categorical; the target depends
    // only on x, so the categorical must not derail training.
    let x_values: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
    let cat_values: Vec<&str> = (0..100)
        .map(|i| match i % 4 {
            0 => "A",
            1 => "B",
            2 => "C",
            _ => "D",
        })
        .collect();
    let y_values: Vec<f64> = x_values.iter().map(|x| x * 2.0 + 3.0).collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("category".into(), cat_values).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    println!("\n=== Test: LTT with Categoricals ===");
    println!(
        "Input DataFrame: {} rows × {} cols (1 numeric, 1 categorical)",
        df.height(),
        df.width()
    );
    let config = AutoConfig::new()
        .with_mode(BoostingMode::LinearThenTree)
        .with_auto_features(false)
        .with_tuning(TuningLevel::None);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    assert_eq!(
        predictions.len(),
        df.height(),
        "Predictions length must match input DataFrame rows"
    );
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("Predictions: {} rows", predictions.len());
    println!("LTT with categoricals RMSE: {:.4}", rmse);
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 2.0 check; consider tightening once calibrated.
    assert!(
        rmse < 2.0 * n.sqrt(),
        "RMSE should be low even with categorical features, got {:.4}",
        rmse
    );
    assert!(
        model.inner().feature_extractor().is_some(),
        "FeatureExtractor must be stored for LTT mode"
    );
}
#[test]
fn test_ltt_with_id_like_columns() {
    // A high-cardinality "ID_0000"-style string column should be
    // auto-excluded rather than treated as a predictive feature.
    let x_values: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
    let id_values: Vec<String> = (0..100).map(|i| format!("ID_{:04}", i)).collect();
    let y_values: Vec<f64> = x_values.iter().map(|x| x * 2.0 + 3.0).collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("id".into(), id_values).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    let linear_config = LinearConfig::default()
        .with_preset(treeboost::LinearPreset::Ridge)
        .with_lambda(0.01)
        .with_shrinkage_factor(1.0)
        .with_max_iter(500);
    let tree_config = TreeConfig::default().with_max_depth(3);
    let univ_config = UniversalConfig::default()
        .with_mode(BoostingMode::LinearThenTree)
        .with_linear_config(linear_config)
        .with_tree_config(tree_config)
        .with_num_rounds(50)
        .with_learning_rate(0.1);
    let config = AutoConfig::new()
        .with_auto_features(false)
        .with_custom_config(univ_config);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("LTT with ID-like columns RMSE: {:.4}", rmse);
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 5.0 check; consider tightening once calibrated.
    assert!(
        rmse < 5.0 * n.sqrt(),
        "RMSE should be low, ID columns should be auto-excluded"
    );
}
#[test]
fn test_ltt_with_user_exclusions() {
    // "correlated" (= 3x) is explicitly excluded via LinearFeatureConfig;
    // the model must still fit y = 2x + 3 from "x" alone.
    let x_values: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
    let corr_values: Vec<f64> = x_values.iter().map(|x| x * 3.0).collect();
    let y_values: Vec<f64> = x_values.iter().map(|x| x * 2.0 + 3.0).collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("correlated".into(), corr_values).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    let linear_config = LinearFeatureConfig::default().with_exclude_columns(&["correlated"]);
    let config = AutoConfig::new()
        .with_mode(BoostingMode::LinearThenTree)
        .with_linear_feature_config(linear_config)
        .with_tuning(TuningLevel::Quick);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("LTT with user exclusions RMSE: {:.4}", rmse);
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 2.0 check; consider tightening once calibrated.
    assert!(
        rmse < 2.0 * n.sqrt(),
        "RMSE should be low with user-specified exclusions"
    );
}
#[test]
fn test_ltt_feature_extractor_storage() {
    // 50 rows: one numeric column, one 3-level categorical, linear target.
    let xs: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
    let cats: Vec<&str> = (0..50)
        .map(|i| match i % 3 {
            0 => "A",
            1 => "B",
            _ => "C",
        })
        .collect();
    let ys: Vec<f64> = xs.iter().map(|x| x * 2.0 + 3.0).collect();
    let frame = DataFrame::new(vec![
        Series::new("x".into(), xs.clone()).into(),
        Series::new("category".into(), cats).into(),
        Series::new("target".into(), ys.clone()).into(),
    ])
    .unwrap();
    let config = AutoConfig::new()
        .with_mode(BoostingMode::LinearThenTree)
        .with_auto_features(false)
        .with_tuning(TuningLevel::None);
    let model = AutoModel::train_with_config(&frame, "target", config).unwrap();
    // LTT mode must persist the feature extractor on the model so that
    // predict() can re-apply identical preprocessing.
    assert!(
        model.inner().feature_extractor().is_some(),
        "FeatureExtractor should be stored in model"
    );
    let extractor = model.inner().feature_extractor().unwrap();
    println!("Feature extractor config: {:?}", extractor.config());
    println!(
        "Exclude categorical: {}",
        extractor.config().exclude_categorical
    );
    println!("Exclude ID: {}", extractor.config().exclude_id);
    println!("FeatureExtractor storage successful");
}
#[test]
fn test_ltt_with_pipeline_encoded_categoricals() {
    // 200 rows, one numeric column and two categoricals; the target depends
    // on x and on cat1 levels, so the encoded categoricals must survive the
    // train/predict preprocessing round-trip with a consistent feature count.
    let x_values: Vec<f64> = (0..200).map(|i| i as f64 * 0.1).collect();
    let cat1_values: Vec<&str> = (0..200)
        .map(|i| match i % 4 {
            0 => "A",
            1 => "B",
            2 => "C",
            _ => "D",
        })
        .collect();
    let cat2_values: Vec<&str> = (0..200)
        .map(|i| match i % 3 {
            0 => "X",
            1 => "Y",
            _ => "Z",
        })
        .collect();
    let y_values: Vec<f64> = x_values
        .iter()
        .zip(cat1_values.iter())
        .map(|(x, &cat)| {
            let cat_effect = match cat {
                "A" => 10.0,
                "B" => 20.0,
                "C" => 30.0,
                _ => 40.0,
            };
            x * 2.0 + cat_effect
        })
        .collect();
    let df = DataFrame::new(vec![
        Series::new("x".into(), x_values.clone()).into(),
        Series::new("cat1".into(), cat1_values).into(),
        Series::new("cat2".into(), cat2_values).into(),
        Series::new("target".into(), y_values.clone()).into(),
    ])
    .unwrap();
    println!("\n=== Testing LTT with Pipeline-Encoded Categoricals ===");
    println!(
        "Original DataFrame: {} rows × {} cols",
        df.height(),
        df.width()
    );
    println!("Original dtypes:");
    for col in df.get_columns() {
        println!(" {} : {:?}", col.name(), col.dtype());
    }
    let linear_config = LinearConfig::default()
        .with_preset(treeboost::LinearPreset::Ridge)
        .with_lambda(0.01)
        .with_shrinkage_factor(1.0)
        .with_max_iter(100);
    let tree_config = TreeConfig::default().with_max_depth(6);
    let univ_config = UniversalConfig::default()
        .with_mode(BoostingMode::LinearThenTree)
        .with_linear_config(linear_config)
        .with_tree_config(tree_config)
        .with_num_rounds(100)
        .with_learning_rate(0.1);
    let config = AutoConfig::new()
        .with_auto_features(false)
        .with_tuning(TuningLevel::None)
        .with_verbose(false)
        .with_custom_config(univ_config);
    let model = AutoModel::train_with_config(&df, "target", config).unwrap();
    let predictions = model.predict(&df).unwrap();
    assert_eq!(
        predictions.len(),
        df.height(),
        "Predictions length must match input DataFrame rows. \
         Mismatch indicates preprocessing inconsistency between train and predict."
    );
    // True RMSE is sqrt(mean(squared error)). The previous code computed
    // sqrt(SSE) / n, which understates the error by a factor of sqrt(n).
    let n = predictions.len() as f32;
    let sse: f32 = predictions
        .iter()
        .zip(y_values.iter())
        .map(|(pred, &actual)| (pred - actual as f32).powi(2))
        .sum();
    let rmse = (sse / n).sqrt();
    println!("Predictions: {} rows", predictions.len());
    println!("Pipeline-encoded categoricals RMSE: {:.4}", rmse);
    assert!(
        model.inner().feature_extractor().is_some(),
        "FeatureExtractor should be stored"
    );
    // Bound scaled by sqrt(n) so it is exactly equivalent to the old
    // sqrt(SSE) / n < 5.0 check; consider tightening once calibrated.
    assert!(
        rmse < 5.0 * n.sqrt(),
        "RMSE too high - feature count mismatch or preprocessing issue. RMSE: {:.4}",
        rmse
    );
    println!("✓ LTT correctly handles pipeline-encoded categoricals");
}