use std::fmt;
use crate::ensemble::config::SGBTConfig;
use crate::ensemble::lr_schedule::LRScheduler;
use crate::ensemble::SGBT;
use crate::learner::StreamingLearner;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::{Observation, SampleRef};
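
/// A streaming gradient-boosted-tree model that adapts its learning rate
/// between steps using a pluggable [`LRScheduler`].
///
/// A minimal usage sketch (marked `ignore` since it elides the crate paths;
/// `SGBTConfig`, `ConstantLR`, and `StreamingLearner` are the same items
/// exercised by the tests at the bottom of this file):
///
/// ```ignore
/// let config = SGBTConfig::builder()
///     .n_steps(5)
///     .learning_rate(0.1)
///     .build()
///     .unwrap();
/// let mut model = AdaptiveSGBT::new(config, ConstantLR::new(0.1));
/// model.train_one(&[1.0, 2.0], 3.0, 1.0); // via `StreamingLearner`
/// let pred = model.predict(&[1.0, 2.0]);
/// ```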
pub struct AdaptiveSGBT<L: Loss = SquaredLoss> {
    inner: SGBT<L>,
    scheduler: Box<dyn LRScheduler>,
    /// Number of calls to `train_one_obs` since construction or reset.
    step_count: u64,
    /// Squared error of the most recent observation, fed to the scheduler.
    last_loss: f64,
    /// Learning rate taken from the config at construction time.
    base_lr: f64,
}

impl AdaptiveSGBT<SquaredLoss> {
    /// Builds an adaptive model with the default squared loss.
    pub fn new(config: SGBTConfig, scheduler: impl LRScheduler + 'static) -> Self {
        let base_lr = config.learning_rate;
        Self {
            inner: SGBT::new(config),
            scheduler: Box::new(scheduler),
            step_count: 0,
            last_loss: 0.0,
            base_lr,
        }
    }
}

impl<L: Loss> AdaptiveSGBT<L> {
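    /// Builds an adaptive model with an explicit loss function.
    ///
    /// A sketch of the call shape (marked `ignore`; it assumes `SquaredLoss`
    /// provides a `Default` impl, which this file does not show):
    ///
    /// ```ignore
    /// let model = AdaptiveSGBT::with_loss(
    ///     config,
    ///     SquaredLoss::default(),
    ///     ConstantLR::new(0.1),
    /// );
    /// ```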
    pub fn with_loss(config: SGBTConfig, loss: L, scheduler: impl LRScheduler + 'static) -> Self {
        let base_lr = config.learning_rate;
        Self {
            inner: SGBT::with_loss(config, loss),
            scheduler: Box::new(scheduler),
            step_count: 0,
            last_loss: 0.0,
            base_lr,
        }
    }

    /// Trains on a single observation, consulting the scheduler first.
    ///
    /// The scheduler is driven by the squared error of the model's current
    /// prediction regardless of the ensemble's loss `L`; it serves only as
    /// a progress signal for the learning-rate schedule.
    pub fn train_one_obs(&mut self, sample: &impl Observation) {
        // Measure the error of the current model before it sees the sample.
        let pred = self.inner.predict(sample.features());
        let err = sample.target() - pred;
        self.last_loss = err * err;
        // Ask the scheduler for the rate to use at this step, then train.
        let lr = self
            .scheduler
            .learning_rate(self.step_count, self.last_loss);
        self.inner.set_learning_rate(lr);
        self.step_count += 1;
        self.inner.train_one(sample);
    }

    /// The learning rate most recently applied to the inner model.
    pub fn current_lr(&self) -> f64 {
        self.inner.config().learning_rate
    }

    /// The learning rate the model was configured with at construction.
    pub fn base_lr(&self) -> f64 {
        self.base_lr
    }

    pub fn step_count(&self) -> u64 {
        self.step_count
    }

    pub fn last_loss(&self) -> f64 {
        self.last_loss
    }

    pub fn scheduler(&self) -> &dyn LRScheduler {
        &*self.scheduler
    }

    pub fn scheduler_mut(&mut self) -> &mut dyn LRScheduler {
        &mut *self.scheduler
    }

    pub fn inner(&self) -> &SGBT<L> {
        &self.inner
    }

    pub fn inner_mut(&mut self) -> &mut SGBT<L> {
        &mut self.inner
    }

    /// Consumes the wrapper and returns the trained inner ensemble.
    pub fn into_inner(self) -> SGBT<L> {
        self.inner
    }

    pub fn predict(&self, features: &[f64]) -> f64 {
        self.inner.predict(features)
    }
}

impl<L: Loss> StreamingLearner for AdaptiveSGBT<L> {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        let sample = SampleRef::weighted(features, target, weight);
        self.train_one_obs(&sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        // Match the inherent `predict` so both paths give the same answer.
        self.inner.predict(features)
    }

    fn n_samples_seen(&self) -> u64 {
        self.inner.n_samples_seen()
    }

    fn reset(&mut self) {
        self.inner.reset();
        self.scheduler.reset();
        self.step_count = 0;
        self.last_loss = 0.0;
    }
}

impl<L: Loss> fmt::Debug for AdaptiveSGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The boxed scheduler is a trait object and is omitted here.
        f.debug_struct("AdaptiveSGBT")
            .field("inner", &self.inner)
            .field("step_count", &self.step_count)
            .field("last_loss", &self.last_loss)
            .field("base_lr", &self.base_lr)
            .field("current_lr", &self.current_lr())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ensemble::lr_schedule::{ConstantLR, ExponentialDecayLR, PlateauLR};

    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .build()
            .unwrap()
    }

    #[test]
    fn construction_and_initial_state() {
        let model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.base_lr() - 0.1).abs() < 1e-12);
        assert!((model.current_lr() - 0.1).abs() < 1e-12);
    }

    #[test]
    fn train_increments_step_count() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train_one(&[1.0, 2.0], 3.0, 1.0);
        model.train_one(&[4.0, 5.0], 6.0, 1.0);
        assert_eq!(model.step_count(), 2);
        assert_eq!(model.n_samples_seen(), 2);
    }

    #[test]
    fn exponential_decay_reduces_lr() {
        let mut model = AdaptiveSGBT::new(test_config(), ExponentialDecayLR::new(0.1, 0.9));
        for i in 0..10 {
            model.train_one(&[i as f64, (i * 2) as f64], i as f64, 1.0);
        }
        let lr = model.current_lr();
        assert!(lr < 0.1, "LR should have decayed from 0.1, got {}", lr);
    }

    #[test]
    fn predict_is_finite() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train_one(&[1.0, 2.0], 3.0, 1.0);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn reset_clears_all_state() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train_one(&[1.0], 2.0, 1.0);
        model.train_one(&[3.0], 4.0, 1.0);
        assert_eq!(model.step_count(), 2);
        model.reset();
        assert_eq!(model.step_count(), 0);
        assert_eq!(model.n_samples_seen(), 0);
        assert!((model.last_loss()).abs() < 1e-12);
    }

    #[test]
    fn as_streaming_learner_trait_object() {
        let model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        boxed.train_one(&[1.0, 2.0], 3.0, 1.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }

    #[test]
    fn plateau_lr_reduces_on_stagnation() {
        let scheduler = PlateauLR::new(0.1, 0.5, 5, 1e-6);
        let mut model = AdaptiveSGBT::new(test_config(), scheduler);
        for _ in 0..50 {
            model.train_one(&[1.0, 1.0], 1.0, 1.0);
        }
        // Loose bound: how soon the plateau triggers depends on how fast
        // the loss stops improving, so only require the LR not to grow.
        let lr = model.current_lr();
        assert!(lr <= 0.1, "PlateauLR should not exceed the base LR, got {}", lr);
}
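
    // A small additional round-trip check (a sketch relying only on the API
    // exercised above): training through the wrapper should be visible on
    // the inner ensemble returned by `into_inner`.
    #[test]
    fn into_inner_preserves_trained_state() {
        let mut model = AdaptiveSGBT::new(test_config(), ConstantLR::new(0.1));
        model.train_one(&[1.0, 2.0], 3.0, 1.0);
        let inner = model.into_inner();
        assert_eq!(inner.n_samples_seen(), 1);
    }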
}