use std::path::Path;
use std::time::Instant;

use treeboost::booster::{GBDTConfig, GBDTModel};
use treeboost::dataset::{DataPipeline, DatasetLoader, PipelineConfig};
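
/// Loads the large synthetic regression sample, trains a binned GBDT, and
/// checks that the in-sample fit is reasonable (R² > 0.5).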
#[test]
#[ignore = "requires generated sample data; see scripts/generate_samples.py"]
fn test_parquet_large_regression() {
    let parquet_path = Path::new("samples/synthetic/large_regression.parquet");
    if !parquet_path.exists() {
        eprintln!("Skipping test: {} not found", parquet_path.display());
        eprintln!("Run: python scripts/generate_samples.py --small");
        return;
    }
    let loader = DatasetLoader::new(64);
    let dataset = loader
        .load_parquet(parquet_path.to_str().unwrap(), "target", None)
        .expect("Should load parquet");
    assert!(dataset.num_rows() >= 10_000, "Expected at least 10K rows");
    assert_eq!(dataset.num_features(), 10, "Expected 10 features");
    let config = GBDTConfig::new()
        .with_num_rounds(20)
        .with_max_depth(5)
        .with_learning_rate(0.1);
    let model = GBDTModel::train_binned(&dataset, config).expect("Training should succeed");
    let predictions = model.predict(&dataset);
    assert_eq!(predictions.len(), dataset.num_rows());
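    // In-sample fit check: R² = 1 - SS_res / SS_tot over the training targets.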
    let targets = dataset.targets();
    let mean_target: f32 = targets.iter().sum::<f32>() / targets.len() as f32;
    let ss_tot: f32 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
    let ss_res: f32 = predictions
        .iter()
        .zip(targets.iter())
        .map(|(p, t)| (p - t).powi(2))
        .sum();
    let r2 = 1.0 - ss_res / ss_tot;
    assert!(r2 > 0.5, "R² should be > 0.5, got {}", r2);
}
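
/// Loads the mixed numeric/categorical sample through the full `DataPipeline`,
/// verifying that all five categorical columns receive encodings before training.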
#[test]
#[ignore = "requires generated sample data; see scripts/generate_samples.py"]
fn test_parquet_large_mixed() {
    let parquet_path = Path::new("samples/synthetic/large_mixed.parquet");
    if !parquet_path.exists() {
        eprintln!("Skipping test: {} not found", parquet_path.display());
        eprintln!("Run: python scripts/generate_samples.py --small");
        return;
    }
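    // Pipeline config: 64 feature bins, CMS rare-category filtering, and
    // smoothing for the categorical encodings.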
    let pipeline = DataPipeline::new(
        PipelineConfig::new()
            .with_num_bins(64)
            .with_cms_params(0.01, 0.99, 10)
            .with_smoothing(10.0),
    );
    let (dataset, state, _filtered_df) = pipeline
        .load_parquet_for_training(
            parquet_path.to_str().unwrap(),
            "target",
            Some(&[
                "neighborhood",
                "property_type",
                "condition",
                "has_pool",
                "has_garage",
            ]),
        )
        .expect("Should load parquet with categoricals");
    assert!(dataset.num_rows() >= 10_000, "Expected at least 10K rows");
    assert_eq!(
        state.categorical_encodings.len(),
        5,
        "Expected 5 categorical encodings"
    );
    let config = GBDTConfig::new()
        .with_num_rounds(30)
        .with_max_depth(6)
        .with_learning_rate(0.1);
    let model = GBDTModel::train_binned(&dataset, config).expect("Training should succeed");
    let predictions = model.predict(&dataset);
    assert_eq!(predictions.len(), dataset.num_rows());
}
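
/// Exercises the pipeline on the deliberately dirty sample: rows may be
/// filtered out during loading, rare categories are pruned aggressively, and
/// training uses a robust pseudo-Huber loss.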
#[test]
#[ignore = "requires generated sample data; see scripts/generate_samples.py"]
fn test_parquet_large_dirty() {
    let parquet_path = Path::new("samples/synthetic/large_dirty.parquet");
    if !parquet_path.exists() {
        eprintln!("Skipping test: {} not found", parquet_path.display());
        eprintln!("Run: python scripts/generate_samples.py --small");
        return;
    }
    let pipeline = DataPipeline::new(
        PipelineConfig::new()
            .with_num_bins(32)
            .with_cms_params(0.001, 0.99, 50)
            .with_smoothing(20.0),
    );
    let (dataset, state, _filtered_df) = pipeline
        .load_parquet_for_training(
            parquet_path.to_str().unwrap(),
            "target",
            Some(&["category", "group"]),
        )
        .expect("Should load dirty parquet");
    assert!(dataset.num_rows() > 0, "Should have rows after filtering");
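    // Pseudo-Huber loss (delta = 1.0) keeps gradients bounded, so outliers in
    // the dirty sample do not dominate training.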
    let config = GBDTConfig::new()
        .with_num_rounds(20)
        .with_max_depth(4)
        .with_pseudo_huber_loss(1.0);
    let model = GBDTModel::train_binned(&dataset, config).expect("Training should succeed");
    let predictions = model.predict(&dataset);
    assert_eq!(predictions.len(), dataset.num_rows());
    let cat_state = &state.categorical_encodings[0];
    assert!(
        cat_state.category_mapping.category_to_idx.len() <= 10,
        "Should have filtered most rare categories, got {}",
        cat_state.category_mapping.category_to_idx.len()
    );
}
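
/// Loads the high-cardinality sample (user, product, region, and merchant IDs)
/// and checks that rare-category filtering keeps the encodings to a manageable size.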
#[test]
#[ignore = "requires generated sample data; see scripts/generate_samples.py"]
fn test_parquet_high_cardinality() {
    let parquet_path = Path::new("samples/synthetic/large_high_cardinality.parquet");
    if !parquet_path.exists() {
        eprintln!("Skipping test: {} not found", parquet_path.display());
        eprintln!("Run: python scripts/generate_samples.py --small");
        return;
    }
    let pipeline = DataPipeline::new(
        PipelineConfig::new()
            .with_num_bins(64)
            .with_cms_params(0.001, 0.99, 20)
            .with_smoothing(50.0),
    );
    let (dataset, state, _filtered_df) = pipeline
        .load_parquet_for_training(
            parquet_path.to_str().unwrap(),
            "target",
            Some(&["user_id", "product_id", "region", "merchant_id"]),
        )
        .expect("Should load high-cardinality parquet");
    assert!(dataset.num_rows() >= 10_000, "Expected at least 10K rows");
    assert_eq!(
        state.categorical_encodings.len(),
        4,
        "Expected 4 categorical encodings"
    );
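    // Rare-category filtering should have dropped at least some of the many user IDs.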
    let user_encoding = &state.categorical_encodings[0];
    assert!(
        user_encoding.category_mapping.category_to_idx.len() < 10_000,
        "Should have filtered some rare users"
    );
    let config = GBDTConfig::new()
        .with_num_rounds(20)
        .with_max_depth(5)
        .with_learning_rate(0.1);
    let model = GBDTModel::train_binned(&dataset, config).expect("Training should succeed");
    let predictions = model.predict(&dataset);
    assert_eq!(predictions.len(), dataset.num_rows());
}
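
/// Stress test on the 100K+ row sample: times the load, train, and predict
/// phases separately (run with `cargo test -- --ignored --nocapture` to see them).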
#[test]
#[ignore = "requires generated sample data; see scripts/generate_samples.py"]
fn test_parquet_stress_test() {
    let parquet_path = Path::new("samples/synthetic/stress_test.parquet");
    if !parquet_path.exists() {
        eprintln!("Skipping test: {} not found", parquet_path.display());
        eprintln!("Run: python scripts/generate_samples.py");
        return;
    }
    let pipeline = DataPipeline::new(
        PipelineConfig::new()
            .with_num_bins(255)
            .with_cms_params(0.001, 0.99, 100)
            .with_smoothing(10.0),
    );
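    // Time each phase (load, train, predict) independently.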
    let start = Instant::now();
    let (dataset, _state, _filtered_df) = pipeline
        .load_parquet_for_training(parquet_path.to_str().unwrap(), "target", Some(&["cat"]))
        .expect("Should load stress test parquet");
    let load_time = start.elapsed();
    println!("Loaded {} rows in {:?}", dataset.num_rows(), load_time);
    assert!(dataset.num_rows() >= 100_000, "Expected at least 100K rows");
    let config = GBDTConfig::new()
        .with_num_rounds(50)
        .with_max_depth(6)
        .with_learning_rate(0.1);
    let start = Instant::now();
    let model = GBDTModel::train_binned(&dataset, config).expect("Training should succeed");
    let train_time = start.elapsed();
    println!("Trained {} trees in {:?}", model.num_trees(), train_time);
    let start = Instant::now();
    let predictions = model.predict(&dataset);
    let predict_time = start.elapsed();
    println!("Predicted {} rows in {:?}", predictions.len(), predict_time);
    assert_eq!(predictions.len(), dataset.num_rows());
}