use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::softmax::SoftmaxLoss;
use crate::sample::{Observation, SampleRef};
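/// Multiclass classifier built from one boosted committee per class.
///
/// Each committee is trained one-vs-rest with [`SoftmaxLoss`]: the committee for the
/// observed class receives a target of 1.0, all others 0.0. Predictions apply a
/// softmax over the per-class raw scores to obtain a probability distribution.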
#[derive(Debug)]
pub struct MulticlassSGBT {
    /// One boosted committee per class, indexed by class label.
    committees: Vec<SGBT<SoftmaxLoss>>,
    /// Number of target classes (>= 2).
    n_classes: usize,
    /// Total number of training samples observed.
    samples_seen: u64,
}
impl MulticlassSGBT {
    /// Creates a multiclass model with one committee per class; errors if `n_classes < 2`.
    pub fn new(config: SGBTConfig, n_classes: usize) -> crate::error::Result<Self> {
if n_classes < 2 {
return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
"n_classes",
"must be >= 2",
n_classes,
)));
}
let committees = (0..n_classes)
.map(|_| SGBT::with_loss(config.clone(), SoftmaxLoss { n_classes }))
.collect();
Ok(Self {
committees,
n_classes,
samples_seen: 0,
})
}
    /// Trains on a single observation whose target encodes the class label.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        // The target is expected to be a non-negative class index (0, 1, ...).
        let class_idx = sample.target() as usize;
        let features = sample.features();
        for (c, committee) in self.committees.iter_mut().enumerate() {
            // One-vs-rest: the committee for the observed class sees 1.0, all others 0.0.
            let binary_target = if c == class_idx { 1.0 } else { 0.0 };
            let binary_ref = SampleRef::new(features, binary_target);
committee.train_one(&binary_ref);
}
}
    /// Trains on each sample in the batch, in order.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
for sample in samples {
self.train_one(sample);
}
}
    /// Returns the class probability distribution for `features`.
    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
        // Raw (unnormalized) score from each per-class committee.
        let raw: Vec<f64> = self
            .committees
            .iter()
            .map(|c| c.predict(features))
            .collect();
        // Numerically stable softmax: subtract the max score before exponentiating.
        let max_raw = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exp_sum: f64 = raw.iter().map(|&r| (r - max_raw).exp()).sum();
raw.iter().map(|&r| (r - max_raw).exp() / exp_sum).collect()
}
    /// Returns the index of the most probable class.
    pub fn predict(&self, features: &[f64]) -> usize {
        let proba = self.predict_proba(features);
        proba
            .iter()
            .enumerate()
            // Argmax over probabilities; the unwrap is safe as long as scores are not NaN.
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.map(|(idx, _)| idx)
.unwrap_or(0)
}
    /// Number of classes (and per-class committees).
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }
    /// Total number of training samples observed so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }
    /// Resets every committee and the sample counter.
    pub fn reset(&mut self) {
for committee in &mut self.committees {
committee.reset();
}
self.samples_seen = 0;
}
    /// Captures the full model as a `MulticlassModelState` (requires a serde feature).
    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
    pub fn to_multiclass_state(&self) -> crate::serde_support::MulticlassModelState {
use crate::serde_support::MulticlassModelState;
let committees = self
.committees
.iter()
.map(|c| {
c.to_model_state()
.expect("SoftmaxLoss always provides loss_type()")
})
.collect();
MulticlassModelState {
n_classes: self.n_classes,
committees,
samples_seen: self.samples_seen,
}
}
    /// Rebuilds a model from a previously captured `MulticlassModelState`.
    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
    pub fn from_multiclass_state(state: crate::serde_support::MulticlassModelState) -> Self {
let n_classes = state.n_classes;
let committees = state
.committees
.into_iter()
.map(|model_state| {
SGBT::from_model_state_with_loss(model_state, SoftmaxLoss { n_classes })
})
.collect();
Self {
committees,
n_classes,
samples_seen: state.samples_seen,
}
}
}
impl crate::learner::StreamingLearner for MulticlassSGBT {
    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
        // The f64 target is interpreted as a class index; per-sample weights are ignored.
        MulticlassSGBT::train_one(self, &SampleRef::new(features, target));
}
    fn predict(&self, features: &[f64]) -> f64 {
        // The trait's scalar prediction carries the predicted class index as f64.
        MulticlassSGBT::predict(self, features) as f64
}
fn n_samples_seen(&self) -> u64 {
self.samples_seen
}
fn reset(&mut self) {
MulticlassSGBT::reset(self);
}
    // The deprecated trait methods below forward to the Tunable/Structural impls.
    #[allow(deprecated)]
    fn diagnostics_array(&self) -> [f64; 5] {
        <Self as crate::learner::Tunable>::diagnostics_array(self)
}
#[allow(deprecated)]
fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
<Self as crate::learner::Tunable>::adjust_config(self, lr_multiplier, lambda_delta);
}
#[allow(deprecated)]
fn replacement_count(&self) -> u64 {
<Self as crate::learner::Structural>::replacement_count(self)
}
}
impl crate::learner::Tunable for MulticlassSGBT {
fn diagnostics_array(&self) -> [f64; 5] {
use crate::automl::DiagnosticSource;
        if let Some(first) = self.committees.first() {
            use crate::learner::SGBTLearner;
            // Diagnostics are read from a clone of the first committee, which stands in
            // for the ensemble as a whole.
            let learner = SGBTLearner::new(first.clone());
match learner.config_diagnostics() {
Some(d) => [
d.residual_alignment,
d.regularization_sensitivity,
d.depth_sufficiency,
d.effective_dof,
d.uncertainty,
],
None => [0.0; 5],
}
} else {
[0.0; 5]
}
}
    fn adjust_config(&mut self, lr_multiplier: f64, lambda_delta: f64) {
        // Apply the same learning-rate and lambda adjustment to every committee.
        for committee in &mut self.committees {
let new_lr = committee.config().learning_rate * lr_multiplier;
committee.set_learning_rate(new_lr);
let new_lambda = committee.config().lambda + lambda_delta;
committee.set_lambda(new_lambda);
}
}
}
impl crate::learner::Structural for MulticlassSGBT {
    fn apply_structural_change(&mut self, _depth_delta: i32, _steps_delta: i32) {
        // Intentionally a no-op: structural changes are not propagated to the
        // per-class committees.
    }
fn replacement_count(&self) -> u64 {
self.committees.iter().map(|c| c.total_replacements()).sum()
}
}
impl crate::automl::DiagnosticSource for MulticlassSGBT {
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        // No aggregated diagnostics are exposed for the multiclass wrapper.
        None
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sample::Sample;
    /// Small, fast configuration shared by the tests below.
    fn test_config() -> SGBTConfig {
SGBTConfig::builder()
.n_steps(5)
.learning_rate(0.1)
.grace_period(10)
.max_depth(3)
.n_bins(8)
.build()
.unwrap()
}
#[test]
fn new_multiclass_creates_committees() {
let model = MulticlassSGBT::new(test_config(), 3).unwrap();
assert_eq!(model.n_classes(), 3);
}
#[test]
fn new_multiclass_rejects_less_than_two_classes() {
let err = MulticlassSGBT::new(test_config(), 1).unwrap_err();
assert!(
err.to_string().contains("n_classes"),
"error should mention n_classes: {}",
err
);
}
#[test]
fn predict_proba_sums_to_one() {
let model = MulticlassSGBT::new(test_config(), 3).unwrap();
let proba = model.predict_proba(&[1.0, 2.0]);
let sum: f64 = proba.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-10,
"probabilities should sum to 1.0, got {}",
sum
);
}
#[test]
fn predict_proba_uniform_before_training() {
let model = MulticlassSGBT::new(test_config(), 3).unwrap();
let proba = model.predict_proba(&[1.0, 2.0]);
for &p in &proba {
assert!((p - 1.0 / 3.0).abs() < 1e-10);
}
}
#[test]
fn train_one_does_not_panic() {
let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
model.train_one(&Sample::new(vec![1.0, 2.0], 0.0));
model.train_one(&Sample::new(vec![3.0, 4.0], 1.0));
model.train_one(&Sample::new(vec![5.0, 6.0], 2.0));
assert_eq!(model.n_samples_seen(), 3);
}
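    // Sketch of additional coverage: `train_batch` forwards each sample to `train_one`,
    // so the sample counter should match the batch length.
    #[test]
    fn train_batch_counts_all_samples() {
        let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
        let samples = vec![
            Sample::new(vec![1.0, 2.0], 0.0),
            Sample::new(vec![3.0, 4.0], 1.0),
            Sample::new(vec![5.0, 6.0], 2.0),
        ];
        model.train_batch(&samples);
        assert_eq!(model.n_samples_seen(), 3);
    }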
#[test]
fn reset_clears_state() {
let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
for i in 0..20 {
model.train_one(&Sample::new(vec![i as f64], (i % 3) as f64));
}
model.reset();
assert_eq!(model.n_samples_seen(), 0);
let proba = model.predict_proba(&[1.0]);
for &p in &proba {
assert!((p - 1.0 / 3.0).abs() < 1e-10);
}
}
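    // Sketch: `predict` is the argmax of `predict_proba`, so it should always return an
    // index strictly below `n_classes`, even before any training.
    #[test]
    fn predict_returns_valid_class_index() {
        let mut model = MulticlassSGBT::new(test_config(), 4).unwrap();
        assert!(model.predict(&[0.5, -0.5]) < 4);
        for i in 0..20 {
            model.train_one(&Sample::new(vec![i as f64, -(i as f64)], (i % 4) as f64));
        }
        assert!(model.predict(&[0.5, -0.5]) < 4);
    }
    // Sketch, gated on the serde features: the state round-trip copies `n_classes` and
    // `samples_seen` verbatim, so both should survive save/restore.
    #[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
    #[test]
    fn multiclass_state_roundtrip_preserves_counters() {
        let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
        for i in 0..15 {
            model.train_one(&Sample::new(vec![i as f64], (i % 3) as f64));
        }
        let state = model.to_multiclass_state();
        let restored = MulticlassSGBT::from_multiclass_state(state);
        assert_eq!(restored.n_classes(), 3);
        assert_eq!(restored.n_samples_seen(), model.n_samples_seen());
    }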
}