use std::fmt;
use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::SampleRef;
pub use irithyll_core::learner::{HasReadout, StreamingLearner, Structural, Tunable};
/// Adapter exposing an [`SGBT`] model through the generic `irithyll_core`
/// learner traits ([`StreamingLearner`], [`Tunable`], [`Structural`]).
///
/// The loss parameter defaults to [`SquaredLoss`], matching the
/// `from_config` constructor below.
pub struct SGBTLearner<L: Loss = SquaredLoss> {
    // Wrapped gradient-boosting model; every trait impl delegates here.
    inner: SGBT<L>,
}
impl<L: Loss> SGBTLearner<L> {
#[inline]
pub fn new(model: SGBT<L>) -> Self {
Self { inner: model }
}
#[inline]
pub fn inner(&self) -> &SGBT<L> {
&self.inner
}
#[inline]
pub fn inner_mut(&mut self) -> &mut SGBT<L> {
&mut self.inner
}
#[inline]
pub fn into_inner(self) -> SGBT<L> {
self.inner
}
}
impl SGBTLearner<SquaredLoss> {
    /// Builds a squared-loss learner directly from an [`SGBTConfig`].
    #[inline]
    pub fn from_config(config: SGBTConfig) -> Self {
        Self::new(SGBT::new(config))
    }
}
/// Streaming-learner facade: every method forwards to the wrapped [`SGBT`].
///
/// The deprecated tuning/structural hooks are routed through the dedicated
/// [`Tunable`]/[`Structural`] impls below via fully-qualified calls; the
/// qualification disambiguates from the same-named methods declared on
/// `StreamingLearner` itself.
impl<L: Loss> StreamingLearner for SGBTLearner<L> {
    #[inline]
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Package the borrowed row as a weighted sample for the model.
        let sample = SampleRef::weighted(features, target, weight);
        self.inner.train_one(&sample);
    }
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        self.inner.predict(features)
    }
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.inner.n_samples_seen()
    }
    #[inline]
    fn reset(&mut self) {
        self.inner.reset();
    }
    // Deprecated hook: delegate to the Tunable impl.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as Tunable>::diagnostics_array(self)
    }
    // Deprecated hook: delegate to the Tunable impl.
    #[allow(deprecated)]
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        <Self as Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
    }
    // Deprecated hook: delegate to the Structural impl.
    #[allow(deprecated)]
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        <Self as Structural>::apply_structural_change(self, depth_delta, steps_delta);
    }
    // Deprecated hook: delegate to the Structural impl.
    #[allow(deprecated)]
    fn replacement_count(&self) -> u64 {
        <Self as Structural>::replacement_count(self)
    }
    // Deprecated hook: delegate to the Structural impl.
    #[allow(deprecated)]
    fn check_proactive_prune(&mut self) -> bool {
        <Self as Structural>::check_proactive_prune(self)
    }
    // Deprecated hook: delegate to the Structural impl.
    #[allow(deprecated)]
    fn set_prune_half_life(&mut self, hl: usize) {
        <Self as Structural>::set_prune_half_life(self, hl);
    }
    // Deprecated hook: delegate to the Structural impl.
    #[allow(deprecated)]
    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
        <Self as Structural>::tree_structure(self)
    }
}
/// Hyper-parameter tuning hooks backed by the wrapped model.
impl<L: Loss> Tunable for SGBTLearner<L> {
    /// Returns the five config-diagnostic signals, or all zeros when the
    /// model has no diagnostics available.
    fn diagnostics_array(&self) -> [f64; 5] {
        // Trait must be in scope for `config_diagnostics` method resolution.
        use crate::automl::DiagnosticSource;
        if let Some(d) = self.inner.config_diagnostics() {
            [
                d.residual_alignment,
                d.regularization_sensitivity,
                d.depth_sufficiency,
                d.effective_dof,
                d.uncertainty,
            ]
        } else {
            [0.0; 5]
        }
    }

    /// Scales the learning rate multiplicatively and shifts lambda additively.
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        let new_lr = self.inner.config().learning_rate * lr_multiplier;
        self.inner.set_learning_rate(new_lr);
        let new_lambda = self.inner.config().lambda + lambda_delta;
        self.inner.set_lambda(new_lambda);
    }
}
/// Structural adaptation hooks (depth/steps changes, pruning) delegated to
/// the wrapped model.
impl<L: Loss> Structural for SGBTLearner<L> {
    /// Applies signed deltas to max depth and step count; results are
    /// clamped to depth >= 1 and steps >= 3. A zero delta leaves the
    /// corresponding setting untouched.
    fn apply_structural_change(&mut self, depth_delta: i32, steps_delta: i32) {
        if depth_delta != 0 {
            let new_depth = self.inner.config().max_depth as i32 + depth_delta;
            self.inner.set_max_depth(new_depth.max(1) as usize);
        }
        if steps_delta != 0 {
            let new_steps = self.inner.config().n_steps as i32 + steps_delta;
            self.inner.set_n_steps(new_steps.max(3) as usize);
        }
    }

    fn replacement_count(&self) -> u64 {
        self.inner.total_replacements()
    }

    fn check_proactive_prune(&mut self) -> bool {
        self.inner.check_proactive_prune()
    }

    fn set_prune_half_life(&mut self, hl: usize) {
        self.inner.set_prune_half_life(hl);
    }

    /// Summarizes each tree as (max_depth, n_leaves, 0.0, 0.0, n_samples);
    /// the two f64 slots are always filled with 0.0 here.
    fn tree_structure(&self) -> Vec<(usize, usize, f64, f64, u64)> {
        let overview = self.inner.diagnostics_overview();
        let mut rows = Vec::with_capacity(overview.trees.len());
        for tree in overview.trees.iter() {
            rows.push((tree.max_depth, tree.n_leaves, 0.0, 0.0, tree.n_samples));
        }
        rows
    }
}
impl<L: Loss + Clone> Clone for SGBTLearner<L> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<L: Loss> fmt::Debug for SGBTLearner<L> {
    /// Formats as `SGBTLearner { inner: .. }` via the model's own `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SGBTLearner");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
/// Exposes the wrapped model's configuration diagnostics so the learner
/// can itself serve as an automl diagnostic source.
impl<L: Loss> crate::automl::DiagnosticSource for SGBTLearner<L> {
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        self.inner.config_diagnostics()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SGBTConfig;

    // Small shared configuration so every test trains quickly.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .build()
            .unwrap()
    }

    // A fresh learner has seen no samples and predicts zero.
    #[test]
    fn test_sgbt_learner_creation() {
        let learner = SGBTLearner::from_config(default_config());
        assert_eq!(learner.n_samples_seen(), 0);
        let pred = learner.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.abs() < 1e-12);
    }

    // The learner must be usable behind `Box<dyn StreamingLearner>`
    // (i.e. the trait is object-safe), including train/predict/reset.
    #[test]
    fn test_trait_object_safety() {
        let learner = SGBTLearner::from_config(default_config());
        let mut boxed: Box<dyn StreamingLearner> = Box::new(learner);
        boxed.train(&[1.0, 2.0, 3.0], 5.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    // Training on varying targets should move the prediction away from
    // its initial value and keep it finite.
    #[test]
    fn test_train_and_predict() {
        let mut learner = SGBTLearner::from_config(default_config());
        let features = [1.0, 2.0, 3.0];
        let pred_before = learner.predict(&features);
        for i in 0..100 {
            learner.train(&features, (i as f64) * 0.1);
        }
        assert_eq!(learner.n_samples_seen(), 100);
        let pred_after = learner.predict(&features);
        assert!(
            (pred_after - pred_before).abs() > 1e-6,
            "prediction should change after training: before={}, after={}",
            pred_before,
            pred_after,
        );
        assert!(pred_after.is_finite());
    }

    // Reset returns the learner to its untrained state: zero samples
    // seen and a zero prediction.
    #[test]
    fn test_reset() {
        let mut learner = SGBTLearner::from_config(default_config());
        for i in 0..50 {
            learner.train(&[1.0, 2.0], i as f64);
        }
        assert_eq!(learner.n_samples_seen(), 50);
        learner.reset();
        assert_eq!(learner.n_samples_seen(), 0);
        let pred = learner.predict(&[1.0, 2.0]);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be zero, got {}",
            pred,
        );
    }

    // Batch prediction must agree exactly with per-row prediction.
    #[test]
    fn test_predict_batch() {
        let mut learner = SGBTLearner::from_config(default_config());
        for i in 0..60 {
            learner.train(&[i as f64, (i as f64) * 0.5], i as f64);
        }
        let rows: Vec<&[f64]> = vec![&[1.0, 0.5], &[10.0, 5.0], &[30.0, 15.0]];
        let batch = learner.predict_batch(&rows);
        assert_eq!(batch.len(), rows.len());
        for (i, row) in rows.iter().enumerate() {
            let individual = learner.predict(row);
            assert!(
                (batch[i] - individual).abs() < 1e-12,
                "batch[{}]={} != individual={}",
                i,
                batch[i],
                individual,
            );
        }
    }

    // `inner()` exposes the wrapped model with the configuration intact.
    #[test]
    fn test_inner_access() {
        let config = default_config();
        let learner = SGBTLearner::from_config(config.clone());
        let sgbt = learner.inner();
        assert_eq!(sgbt.n_samples_seen(), 0);
        assert_eq!(sgbt.config().n_steps, config.n_steps);
    }

    // A clone must predict identically to the original, and further
    // training on the clone must not affect the original (deep copy).
    #[test]
    fn test_clone() {
        let mut learner = SGBTLearner::from_config(default_config());
        for i in 0..80 {
            learner.train(&[i as f64, (i as f64) * 2.0], i as f64);
        }
        let mut cloned = learner.clone();
        assert_eq!(cloned.n_samples_seen(), learner.n_samples_seen());
        let features = [5.0, 10.0];
        let pred_orig = learner.predict(&features);
        let pred_clone = cloned.predict(&features);
        assert!(
            (pred_orig - pred_clone).abs() < 1e-12,
            "clone prediction should match original: {} vs {}",
            pred_orig,
            pred_clone,
        );
        // Diverge the clone with constant-target training.
        for i in 0..50 {
            cloned.train(&[i as f64, (i as f64) * 2.0], 999.0);
        }
        assert_eq!(learner.n_samples_seen(), 80);
        assert_eq!(cloned.n_samples_seen(), 130);
        let pred_orig_after = learner.predict(&features);
        let pred_clone_after = cloned.predict(&features);
        assert!(
            (pred_orig - pred_orig_after).abs() < 1e-12,
            "original should be unchanged after training clone",
        );
        assert!(
            (pred_clone_after - pred_clone).abs() > 1e-6,
            "clone prediction should change after further training",
        );
    }
}