mod common;
use treeboost::booster::{GBDTConfig, GBDTModel};
use treeboost::features::{
FeatureGenerator, FeatureSelector, InteractionGenerator, InteractionType, PolynomialGenerator,
RatioGenerator, SelectionConfig,
};
use common::create_synthetic_dataset;
/// Trains a small GBDT model on the synthetic dataset from `common` and checks
/// that the reported feature importances have the expected count, sum to
/// roughly one, and are all non-negative.
#[test]
fn test_features_selection_workflow() {
    let training_set = create_synthetic_dataset(500, 42);
    let model = GBDTModel::train_binned(
        &training_set,
        GBDTConfig::new().with_num_rounds(30).with_max_depth(4),
    )
    .expect("Training should succeed");

    let importances = model.feature_importance();
    assert_eq!(importances.len(), 5, "Should have 5 feature importances");

    // The importances are expected to sum to 1 within a small tolerance.
    let total = importances.iter().sum::<f32>();
    let deviation = (total - 1.0).abs();
    assert!(
        deviation < 0.01,
        "Importances should sum to 1, got {}",
        total
    );

    // No single importance may be negative.
    assert!(
        importances.iter().all(|&imp| imp >= 0.0),
        "Importances should be non-negative"
    );
}
/// Builds three features where `f1` is `f0` plus a tiny periodic offset, runs
/// `FeatureSelector::drop_collinear_features` on them, and checks that the
/// returned data, names and kept indices agree in size.
#[test]
fn test_features_drop_collinear() {
    let num_rows: usize = 100;
    let num_features: usize = 3;

    // Values are emitted one row at a time: f0, f1, f2.
    let data: Vec<f32> = (0..num_rows)
        .flat_map(|row| {
            let f0 = row as f32;
            // f1 follows f0 closely, differing by at most 0.04.
            let f1 = f0 + (row % 5) as f32 * 0.01;
            let f2 = (row * row) as f32 % 100.0;
            [f0, f1, f2]
        })
        .collect();

    let feature_names: Vec<String> = (0..num_features)
        .map(|idx| format!("f{}", idx))
        .collect();
    let targets: Vec<f32> = (0..num_rows).map(|row| row as f32 * 2.0).collect();

    let selector = FeatureSelector::new(
        SelectionConfig::default()
            .with_drop_collinear(true)
            .with_collinearity_threshold(0.95),
    );
    let (filtered_data, filtered_names, kept_indices) =
        selector.drop_collinear_features(&data, num_features, &feature_names, Some(&targets));

    let kept = kept_indices.len();
    assert!(
        kept <= num_features,
        "Should keep at most original features"
    );
    assert!(kept >= 2, "Should keep independent features: kept {}", kept);

    // The filtered outputs must be consistent with the kept index list.
    assert_eq!(filtered_names.len(), kept);
    assert_eq!(filtered_data.len(), num_rows * kept);
}
/// Generates Multiply and Subtract interactions for the explicit pairs
/// (0, 1) and (2, 3) and checks the output names, sizes and first-row values.
#[test]
fn test_interaction_generator_explicit_pairs() {
    let num_rows: usize = 100;
    let num_features: usize = 4;

    // Four values per row: 1.5 * i, sqrt(i), i mod 10, 100 - i.
    let data: Vec<f32> = (0..num_rows)
        .flat_map(|row| {
            [
                row as f32 * 1.5,
                (row as f32).sqrt(),
                (row % 10) as f32,
                100.0 - row as f32,
            ]
        })
        .collect();
    let names: Vec<String> = (0..num_features)
        .map(|idx| format!("feat_{}", idx))
        .collect();

    let generator = InteractionGenerator::from_pairs(vec![(0, 1), (2, 3)])
        .with_types(vec![InteractionType::Multiply, InteractionType::Subtract]);
    let (int_data, int_names) = generator.generate(&data, num_features, &names);

    // Two pairs times two interaction types gives four generated columns.
    assert_eq!(int_names.len(), 4);
    assert_eq!(int_data.len(), num_rows * 4);

    assert!(int_names.contains(&"feat_0_mul_feat_1".to_string()));
    assert!(int_names.contains(&"feat_0_sub_feat_1".to_string()));
    assert!(int_names.contains(&"feat_2_mul_feat_3".to_string()));
    assert!(int_names.contains(&"feat_2_sub_feat_3".to_string()));

    // The first input row is (0, 0, 0, 100).
    // NOTE(review): the expected 100.0 at index 3 suggests Subtract yields an
    // absolute difference - confirm against InteractionType.
    assert!((int_data[0] - 0.0).abs() < 1e-6);
    assert!((int_data[1] - 0.0).abs() < 1e-6);
    assert!((int_data[2] - 0.0).abs() < 1e-6);
    assert!((int_data[3] - 100.0).abs() < 1e-6);
}
/// Fits an all-pairs interaction generator on three features and checks that
/// it selects three pairs and emits one output column per pair.
#[test]
fn test_interaction_generator_all_pairs() {
    let num_rows: usize = 50;
    let num_features: usize = 3;

    // Fill the buffer with a repeating 0..20 ramp.
    let cell_count = num_rows * num_features;
    let mut data: Vec<f32> = Vec::with_capacity(cell_count);
    for cell in 0..cell_count {
        data.push((cell % 20) as f32);
    }
    let names: Vec<String> = vec![String::from("a"), String::from("b"), String::from("c")];

    let mut generator = InteractionGenerator::all_pairs();
    generator.fit(&data, num_features);
    assert_eq!(generator.pairs().unwrap().len(), 3);

    let (int_data, int_names) = generator.generate(&data, num_features, &names);
    assert_eq!(int_names.len(), 3);
    assert_eq!(int_data.len(), num_rows * 3);
}
/// Fits a `top_correlated(5)` generator with a 0.5 minimum correlation on five
/// features and checks the selected pair count, output sizes and finiteness.
#[test]
fn test_interaction_generator_auto_select() {
    let num_rows: usize = 200;
    let num_features: usize = 5;

    // Several columns are simple functions of the row index; one is a
    // scrambled sequence (i * 17 mod 100).
    let data: Vec<f32> = (0..num_rows)
        .flat_map(|row| {
            let base = row as f32;
            [
                base,
                base * 2.0,
                base + (row % 7) as f32,
                (row * 17 % 100) as f32,
                1000.0 - base,
            ]
        })
        .collect();
    let names: Vec<String> = (0..num_features)
        .map(|idx| format!("f{}", idx))
        .collect();

    let mut generator = InteractionGenerator::top_correlated(5).with_min_correlation(0.5);
    generator.fit(&data, num_features);

    // At least one pair must be chosen, and no more than the requested five.
    let pairs = generator.pairs().unwrap();
    assert!(!pairs.is_empty());
    assert!(pairs.len() <= 5);

    let (int_data, int_names) = generator.generate(&data, num_features, &names);
    assert_eq!(int_names.len(), pairs.len());
    assert_eq!(int_data.len(), num_rows * pairs.len());
    assert!(
        int_data.iter().all(|val| val.is_finite()),
        "Interaction values should be finite"
    );
}
/// Fits a target-based interaction generator where the target is the product
/// of the first two features, then checks that some interactions are produced
/// and that the output buffer matches their count.
#[test]
fn test_interaction_generator_target_based() {
    let num_rows: usize = 100;
    let num_features: usize = 4;

    // The first two features are shared between the data and the targets.
    let first = |row: usize| (row % 10 + 1) as f32;
    let second = |row: usize| ((row / 10) % 10 + 1) as f32;

    let data: Vec<f32> = (0..num_rows)
        .flat_map(|row| {
            [
                first(row),
                second(row),
                (row * 3 % 20) as f32,
                (row * 7 % 15) as f32,
            ]
        })
        .collect();

    // Target is exactly f0 * f1.
    let targets: Vec<f32> = (0..num_rows).map(|row| first(row) * second(row)).collect();
    let names: Vec<String> = (0..num_features)
        .map(|idx| format!("f{}", idx))
        .collect();

    let mut generator = InteractionGenerator::target_based(3, targets);
    generator.fit(&data, num_features);

    let (int_data, int_names) = generator.generate(&data, num_features, &names);
    assert!(!int_names.is_empty());
    assert_eq!(int_data.len(), num_rows * int_names.len());
}
/// Applies every `InteractionType` to the pair (0, 1) on a two-row input and
/// checks the generated column count and the values for the first row.
#[test]
fn test_interaction_all_types() {
    // Two rows of four features each.
    let data = vec![
        3.0, 5.0, 8.0, 2.0, //
        4.0, 6.0, 1.0, 9.0, //
    ];
    let names: Vec<String> = vec![
        String::from("a"),
        String::from("b"),
        String::from("c"),
        String::from("d"),
    ];

    let generator =
        InteractionGenerator::from_pairs(vec![(0, 1)]).with_types(InteractionType::all());
    let (int_data, int_names) = generator.generate(&data, 4, &names);

    assert_eq!(int_names.len(), 5);
    assert_eq!(int_data.len(), 2 * 5);

    // First row has a = 3 and b = 5.
    // NOTE(review): the expected values look like a*b, a+b, |a-b|, min, max in
    // the order returned by InteractionType::all() - confirm against the enum.
    assert!((int_data[0] - 15.0).abs() < 1e-6);
    assert!((int_data[1] - 8.0).abs() < 1e-6);
    assert!((int_data[2] - 2.0).abs() < 1e-6);
    assert!((int_data[3] - 3.0).abs() < 1e-6);
    assert!((int_data[4] - 5.0).abs() < 1e-6);
}
/// Runs the polynomial, ratio and interaction generators on the same input and
/// checks that each yields at least one feature with a matching buffer size.
#[test]
fn test_combined_feature_generators() {
    let num_rows: usize = 50;
    let num_features: usize = 3;

    // Values cycle through 1..=10, so no input cell is zero.
    let mut data: Vec<f32> = Vec::with_capacity(num_rows * num_features);
    data.extend((0..num_rows * num_features).map(|cell| (cell % 10 + 1) as f32));
    let names: Vec<String> = vec![String::from("x"), String::from("y"), String::from("z")];

    let polynomial = PolynomialGenerator::new();
    let (poly_data, poly_names) = polynomial.generate(&data, num_features, &names);

    let ratios = RatioGenerator::from_pairs(vec![(0, 1), (1, 2)]);
    let (ratio_data, ratio_names) = ratios.generate(&data, num_features, &names);

    let interactions = InteractionGenerator::from_pairs(vec![(0, 2)]);
    let (int_data, int_names) = interactions.generate(&data, num_features, &names);

    // Every generator must contribute something.
    assert!(!poly_names.is_empty());
    assert!(!ratio_names.is_empty());
    assert!(!int_names.is_empty());

    let total_features = num_features + poly_names.len() + ratio_names.len() + int_names.len();
    assert!(total_features > num_features);

    // Each output buffer holds one value per row per generated feature.
    assert_eq!(poly_data.len(), num_rows * poly_names.len());
    assert_eq!(ratio_data.len(), num_rows * ratio_names.len());
    assert_eq!(int_data.len(), num_rows * int_names.len());
}
/// Enables self-interactions on an all-pairs generator over three features and
/// checks that six pairs are produced, including `a_mul_a`.
#[test]
fn test_self_interactions() {
    // A single row with three features.
    let data = vec![2.0, 3.0, 4.0];
    let names: Vec<String> = vec![String::from("a"), String::from("b"), String::from("c")];

    let mut generator = InteractionGenerator::all_pairs().with_self_interactions(true);
    generator.fit(&data, 3);
    assert_eq!(generator.pairs().unwrap().len(), 6);

    let (int_data, int_names) = generator.generate(&data, 3, &names);
    assert_eq!(int_names.len(), 6);
    assert!(int_names.contains(&"a_mul_a".to_string()));

    // The first output value is expected to be 2 * 2.
    assert!((int_data[0] - 4.0).abs() < 1e-6);
}