use irithyll::ensemble::multiclass::MulticlassSGBT;
use irithyll::ensemble::variants::SGBTVariant;
use irithyll::loss::logistic::LogisticLoss;
use irithyll::{SGBTConfig, Sample, SGBT};
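
// Deterministic xorshift64 PRNG mapped onto the unit interval, so the tests
// stay reproducible without pulling in an external rand crate.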
fn xorshift64(state: &mut u64) -> f64 {
*state ^= *state << 13;
*state ^= *state >> 7;
*state ^= *state << 17;
(*state as f64) / (u64::MAX as f64)
}
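
// Small shared configuration used by most tests below.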
fn default_config() -> SGBTConfig {
SGBTConfig::builder()
.n_steps(20)
.learning_rate(0.1)
.grace_period(10)
.max_depth(4)
.n_bins(16)
.build()
.unwrap()
}
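
// Root-mean-square error computed from a slice of already-squared errors.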
fn rmse(squared_errors: &[f64]) -> f64 {
    (squared_errors.iter().sum::<f64>() / squared_errors.len() as f64).sqrt()
}
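
// Prequential (predict-then-train) regression on a noisy linear target:
// the RMSE over late samples must fall below the RMSE over early samples.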
#[test]
fn sgbt_regression_converges() {
let config = default_config();
let mut model = SGBT::new(config);
let mut rng: u64 = 0xCAFE_BABE_1234_5678;
let mut early_errors = Vec::new();
let mut late_errors = Vec::new();
for i in 0..1000 {
let x1 = xorshift64(&mut rng) * 10.0 - 5.0;
let x2 = xorshift64(&mut rng) * 10.0 - 5.0;
let noise = (xorshift64(&mut rng) - 0.5) * 0.5;
let target = 2.0 * x1 + 3.0 * x2 + noise;
let pred = model.predict(&[x1, x2]);
let err_sq = (pred - target).powi(2);
if (50..150).contains(&i) {
early_errors.push(err_sq);
}
if i >= 900 {
late_errors.push(err_sq);
}
model.train_one(&Sample::new(vec![x1, x2], target));
}
let early_rmse = rmse(&early_errors);
let late_rmse = rmse(&late_errors);
assert!(
late_rmse < early_rmse,
"RMSE should decrease over time: early={:.4}, late={:.4}",
early_rmse,
late_rmse
);
}
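
// Binary classification with LogisticLoss on two well-separated 1-D clusters:
// both the raw score and the predicted probability must rank the class-1
// region (x > 0) above the class-0 region (x < 0).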
#[test]
fn sgbt_classification_accuracy() {
    let config = default_config();
let mut model = SGBT::with_loss(config, LogisticLoss);
let mut rng: u64 = 12345;
    for _ in 0..2000 {
        // Class 0 example: x drawn uniformly from [-5.0, -1.0].
        let x = -5.0 + xorshift64(&mut rng) * 4.0;
        model.train_one(&Sample::new(vec![x], 0.0));
        // Class 1 example: x drawn uniformly from [1.0, 5.0].
        let x = 1.0 + xorshift64(&mut rng) * 4.0;
        model.train_one(&Sample::new(vec![x], 1.0));
    }
let raw_class1 = model.predict(&[3.0]);
let raw_class0 = model.predict(&[-3.0]);
assert!(
raw_class1 > raw_class0,
"Raw prediction for class 1 region ({:.4}) should exceed class 0 region ({:.4})",
raw_class1,
raw_class0
);
let prob_class1 = model.predict_proba(&[3.0]);
let prob_class0 = model.predict_proba(&[-3.0]);
assert!(
prob_class1 > prob_class0,
"P(class1|x=3) = {:.4} should > P(class1|x=-3) = {:.4}",
prob_class1,
prob_class0
);
}
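
// train_batch over a slice must leave the model in the same state as feeding
// the same samples one at a time through train_one.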
#[test]
fn sgbt_batch_equals_sequential() {
let config = default_config();
let mut model_seq = SGBT::new(config.clone());
let mut model_batch = SGBT::new(config);
let samples: Vec<Sample> = (0..100)
.map(|i| {
let x = i as f64 * 0.1;
Sample::new(vec![x, x * 2.0, x * 0.5], x * 3.0 + 1.0)
})
.collect();
for s in &samples {
model_seq.train_one(s);
}
model_batch.train_batch(&samples);
let test_points: Vec<Vec<f64>> = vec![
vec![0.5, 1.0, 0.25],
vec![3.0, 6.0, 1.5],
vec![7.0, 14.0, 3.5],
];
for features in &test_points {
let pred_seq = model_seq.predict(features);
let pred_batch = model_batch.predict(features);
assert!(
(pred_seq - pred_batch).abs() < 1e-10,
"Predictions should be identical: seq={}, batch={} for features {:?}",
pred_seq,
pred_batch,
features
);
}
assert_eq!(model_seq.n_samples_seen(), model_batch.n_samples_seen());
}
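
// reset() must clear the sample counter and the initialized flag, and the
// model must predict ~0 again afterwards.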
#[test]
fn sgbt_reset_restores_initial_state() {
let mut model = SGBT::new(default_config());
for i in 0..200 {
let x = i as f64 * 0.1;
model.train_one(&Sample::new(vec![x, x * 2.0], x * 3.0));
}
    assert_eq!(model.n_samples_seen(), 200);
assert!(model.is_initialized());
model.reset();
assert_eq!(
model.n_samples_seen(),
0,
"samples_seen should be 0 after reset"
);
assert!(
!model.is_initialized(),
"model should not be initialized after reset"
);
let pred = model.predict(&[1.0, 2.0]);
assert!(
pred.abs() < 1e-12,
"prediction should be ~0 after reset, got {}",
pred
);
}
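
// The Skip { k: 5 } variant must still drive the RMSE down over time on a
// noise-free linear target.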
#[test]
fn sgbt_with_skip_variant() {
let config = SGBTConfig::builder()
.n_steps(20)
.learning_rate(0.1)
.grace_period(10)
.max_depth(4)
.n_bins(16)
.variant(SGBTVariant::Skip { k: 5 })
.build()
.unwrap();
let mut model = SGBT::new(config);
let mut rng: u64 = 0xABCD_EF01_2345_6789;
let mut early_errors = Vec::new();
let mut late_errors = Vec::new();
for i in 0..500 {
let x1 = xorshift64(&mut rng) * 10.0 - 5.0;
let x2 = xorshift64(&mut rng) * 10.0 - 5.0;
let target = 2.0 * x1 + 3.0 * x2;
let pred = model.predict(&[x1, x2]);
let err_sq = (pred - target).powi(2);
if (50..150).contains(&i) {
early_errors.push(err_sq);
}
if i >= 400 {
late_errors.push(err_sq);
}
model.train_one(&Sample::new(vec![x1, x2], target));
}
let early_rmse = rmse(&early_errors);
let late_rmse = rmse(&late_errors);
assert!(
late_rmse < early_rmse,
"Skip variant should still converge: early={:.4}, late={:.4}",
early_rmse,
late_rmse
);
}
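
// MulticlassSGBT with 3 classes: predict_proba must return one non-negative
// probability per class, and the probabilities must sum to ~1.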
#[test]
fn sgbt_multiclass_basic() {
let config = SGBTConfig::builder()
.n_steps(10)
.learning_rate(0.1)
.grace_period(10)
.max_depth(3)
.n_bins(8)
.build()
.unwrap();
let mut model = MulticlassSGBT::new(config, 3).unwrap();
let mut rng: u64 = 0x1234_5678_ABCD_EF00;
for i in 0..300 {
let class = (i % 3) as f64;
let x1 = xorshift64(&mut rng) * 10.0 - 5.0;
let x2 = xorshift64(&mut rng) * 10.0 - 5.0;
model.train_one(&Sample::new(vec![x1, x2], class));
}
let proba = model.predict_proba(&[1.0, 2.0]);
assert_eq!(proba.len(), 3, "should return 3 probabilities");
let sum: f64 = proba.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-6,
"probabilities should sum to ~1.0, got {:.8}",
sum
);
for (i, &p) in proba.iter().enumerate() {
assert!(
p >= 0.0,
"probability for class {} should be non-negative, got {}",
i,
p
);
}
}
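
// Online training on a nonlinear target must never produce NaN or infinite
// predictions.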
#[test]
fn sgbt_predictions_always_finite() {
let mut model = SGBT::new(default_config());
let mut rng: u64 = 0xFEED_FACE_DEAD_BEEF;
for _ in 0..500 {
let x1 = xorshift64(&mut rng) * 100.0 - 50.0;
let x2 = xorshift64(&mut rng) * 100.0 - 50.0;
let target = x1.sin() + x2.cos();
model.train_one(&Sample::new(vec![x1, x2], target));
let pred = model.predict(&[x1, x2]);
assert!(pred.is_finite(), "prediction must be finite, got {}", pred);
}
}
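
// Metadata sanity checks: n_steps() reflects the configured step count and
// total_leaves() never drops below it as the ensemble grows.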
#[test]
fn sgbt_metadata_consistency() {
let config = SGBTConfig::builder()
.n_steps(15)
.learning_rate(0.1)
.grace_period(10)
.max_depth(3)
.n_bins(8)
.build()
.unwrap();
let mut model = SGBT::new(config);
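    // A freshly constructed model reports its configured step count; before any
    // training, total_leaves equals n_steps.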
assert_eq!(model.n_steps(), 15);
assert_eq!(model.total_leaves(), 15);
let mut rng: u64 = 0x9999_8888_7777_6666;
for _ in 0..300 {
let x = xorshift64(&mut rng) * 10.0;
model.train_one(&Sample::new(vec![x], x * 2.0));
}
assert!(
model.total_leaves() >= model.n_steps(),
"total_leaves ({}) should be >= n_steps ({})",
model.total_leaves(),
model.n_steps()
);
}