#[path = "common/mod.rs"]
mod common;
use treeboost::booster::{GBDTConfig, GBDTModel};
use treeboost::dataset::BinnedDataset;
/// Builds a synthetic two-class dataset with 20 byte-valued features per sample.
///
/// Samples whose index is below `n_samples / 2` get target `0.0` and feature
/// values scattered around `0.3`; the remaining samples get target `1.0` and
/// values scattered around `0.7`. Each value is scaled by 255, clamped to
/// `0..=255`, and stored as a `u8`.
///
/// Random draws come from `common::SimpleRng` seeded with `seed`, consumed
/// sample by sample and, within a sample, feature by feature.
/// NOTE(review): `next_f32` presumably yields values in `[0, 1)` — confirm
/// against `common::SimpleRng`.
///
/// The feature buffer handed to `BinnedDataset::new` is laid out column-major:
/// the value of feature `f` for sample `r` lives at index `f * n_samples + r`.
fn create_classification_dataset(n_samples: usize, seed: u64) -> BinnedDataset {
    const FEATURE_COUNT: usize = 20;

    let mut rng = common::SimpleRng::new(seed);
    let half = n_samples / 2;

    let mut labels: Vec<f32> = Vec::with_capacity(n_samples);
    let mut columns = vec![0u8; n_samples * FEATURE_COUNT];

    for row in 0..n_samples {
        // First half of the samples is class 0, second half is class 1.
        let (label, center) = if row < half { (0.0, 0.3) } else { (1.0, 0.7) };

        for col in 0..FEATURE_COUNT {
            let jitter = (rng.next_f32() - 0.5) * 0.3;
            let scaled = (center + jitter) * 255.0;
            // Write directly into the column-major slot for (col, row).
            columns[col * n_samples + row] = scaled.clamp(0.0, 255.0) as u8;
        }

        labels.push(label);
    }

    let feature_info = common::create_feature_info(FEATURE_COUNT, "feature");
    BinnedDataset::new(n_samples, columns, labels, feature_info)
}
/// Logistic function: maps a raw score `x` to `1 / (1 + e^(-x))`.
///
/// Used to turn a logit into a probability-like value in `[0, 1]`.
fn sigmoid(x: f32) -> f32 {
    let decay = f32::exp(-x);
    (1.0 + decay).recip()
}
/// Runs the binary classification demo end to end.
///
/// Steps, mirrored by the numbered console output:
/// 1. generate a synthetic dataset via `create_classification_dataset`,
/// 2. configure a `GBDTConfig` with binary log-loss,
/// 3. train a `GBDTModel` on the binned dataset and report the elapsed time,
/// 4. predict logits on the training data and convert them with `sigmoid`,
/// 5. report accuracy, precision, recall and F1 at a 0.5 threshold,
/// 6. print the five largest feature importances,
/// 7. print a ten-bucket histogram of the predicted probabilities,
/// 8. print a handful of evenly spaced sample predictions.
///
/// # Panics
///
/// Panics with "Training failed" if `GBDTModel::train_binned` returns an error.
fn main() {
    let banner = "=".repeat(60);
    println!("{}", banner);
    println!("TreeBoost: Binary Classification Example");
    println!("{}", banner);
    println!();

    let n_samples = 5000;
    let seed = 42;

    println!("1. Generating synthetic classification dataset...");
    println!(" Samples: {}", n_samples);
    println!(" Classes: 2 (binary)");
    println!(" Features: 20");
    println!();
    let dataset = create_classification_dataset(n_samples, seed);

    println!("2. Configuring GBDT model for binary classification...");
    let config = GBDTConfig::new()
        .with_num_rounds(100)
        .with_max_depth(6)
        .with_learning_rate(0.1)
        .with_binary_logloss()
        .with_subsample(0.8)
        .with_seed(42);
    println!(" Rounds: 100");
    println!(" Max depth: 6");
    println!(" Loss function: Binary Logistic Loss");
    println!();

    println!("3. Training model...");
    let start = std::time::Instant::now();
    let model = GBDTModel::train_binned(&dataset, config).expect("Training failed");
    let elapsed = start.elapsed();
    println!(" Time: {:.2?}", elapsed);
    println!(" Trees: {}", model.num_trees());
    println!();

    println!("4. Making predictions...");
    let logits = model.predict(&dataset);
    let probabilities: Vec<f32> = logits.iter().map(|&x| sigmoid(x)).collect();
    let threshold = 0.5;
    // Hard class labels: 1.0 when the probability reaches the threshold.
    let predictions: Vec<f32> = probabilities
        .iter()
        .map(|&p| if p >= threshold { 1.0 } else { 0.0 })
        .collect();
    println!();

    println!("5. Evaluating classification performance...");
    let targets = dataset.targets();

    // Single pass over (prediction, target) pairs to gather every counter
    // needed for accuracy, precision and recall.
    let mut correct = 0usize;
    let mut tp = 0usize;
    let mut fp = 0usize;
    let mut fn_ = 0usize;
    for (&pred, &target) in predictions.iter().zip(targets.iter()) {
        if (pred - target).abs() < 0.01 {
            correct += 1;
        }
        if pred >= 0.5 && target >= 0.5 {
            tp += 1;
        } else if pred >= 0.5 && target < 0.5 {
            fp += 1;
        } else if pred < 0.5 && target >= 0.5 {
            fn_ += 1;
        }
    }

    let accuracy = correct as f32 / predictions.len() as f32;
    // Each ratio falls back to 0.0 when its denominator would be zero.
    let precision = if tp + fp > 0 {
        tp as f32 / (tp + fp) as f32
    } else {
        0.0
    };
    let recall = if tp + fn_ > 0 {
        tp as f32 / (tp + fn_) as f32
    } else {
        0.0
    };
    let f1 = if precision + recall > 0.0 {
        2.0 * (precision * recall) / (precision + recall)
    } else {
        0.0
    };
    println!(" Accuracy: {:.4}", accuracy);
    println!(" Precision: {:.4}", precision);
    println!(" Recall: {:.4}", recall);
    println!(" F1-Score: {:.4}", f1);
    println!();

    println!("6. Feature Importance (top 5):");
    let importances = model.feature_importance();
    let mut indexed: Vec<(usize, f32)> = importances.iter().copied().enumerate().collect();
    // Descending by importance. `total_cmp` is a total order over f32, so a
    // NaN importance cannot make the sort panic the way
    // `partial_cmp(..).unwrap()` would.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    for (feature_idx, importance) in indexed.iter().take(5) {
        println!(" Feature {}: {:.6}", feature_idx, importance);
    }
    println!();

    println!("7. Probability Distribution:");
    let mut prob_bins = vec![0usize; 10];
    for prob in &probabilities {
        // Bucket width is 0.1; a probability of exactly 1.0 is folded into
        // the last bucket by capping at 9.
        let bin = ((prob * 10.0).min(9.0)) as usize;
        prob_bins[bin] += 1;
    }
    for (i, count) in prob_bins.iter().enumerate() {
        let start = i as f32 / 10.0;
        let end = (i + 1) as f32 / 10.0;
        // One asterisk per 50 samples in the bucket.
        let bar = "*".repeat(*count / 50);
        println!(" [{:.1}-{:.1}]: {} ({})", start, end, bar, count);
    }
    println!();

    println!("8. Sample Predictions:");
    // `step_by` panics on a zero step, which `len / 5` produces for fewer
    // than five predictions; clamp the stride to at least 1.
    let sample_step = (predictions.len() / 5).max(1);
    for i in (0..predictions.len()).step_by(sample_step) {
        println!(
            " Sample {}: Logit={:7.4}, Prob={:.4}, Pred={}, True={}",
            i, logits[i], probabilities[i], predictions[i] as i32, targets[i] as i32
        );
    }
    println!();

    println!("{}", banner);
    println!("Example completed successfully!");
    println!("{}", banner);
}