// irithyll_core/ensemble/multiclass.rs
//! Multi-class classification via one-vs-rest SGBT committees.
//!
//! For C classes, maintains C independent SGBT ensembles, each trained with
//! logistic loss on a binarized target (1.0 for the class, 0.0 otherwise).
//! Final predictions are normalized via softmax across committee outputs.

use alloc::vec::Vec;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::softmax::SoftmaxLoss;
use crate::sample::{Observation, SampleRef};
/// Multi-class SGBT using one-vs-rest committee of ensembles.
///
/// Each class gets its own `SGBT<SoftmaxLoss>` trained with softmax
/// (logistic per-class) loss. The concrete loss type is monomorphized
/// for each committee -- no `Box<dyn Loss>` overhead.
///
/// Predictions are softmax-normalized across all class committees.
///
/// Invariant: `committees.len() == n_classes`, with `n_classes >= 2`
/// (enforced by the constructor).
#[derive(Debug)]
pub struct MulticlassSGBT {
    /// One SGBT per class, each with monomorphized SoftmaxLoss.
    committees: Vec<SGBT<SoftmaxLoss>>,
    /// Number of classes.
    n_classes: usize,
    /// Total samples seen.
    samples_seen: u64,
}

32impl MulticlassSGBT {
33    /// Create a new multi-class SGBT.
34    ///
35    /// # Errors
36    ///
37    /// Returns [`IrithyllError::InvalidConfig`] if `n_classes < 2`.
38    pub fn new(config: SGBTConfig, n_classes: usize) -> crate::error::Result<Self> {
39        if n_classes < 2 {
40            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
41                "n_classes",
42                "must be >= 2",
43                n_classes,
44            )));
45        }
46
47        let committees = (0..n_classes)
48            .map(|_| SGBT::with_loss(config.clone(), SoftmaxLoss { n_classes }))
49            .collect();
50
51        Ok(Self {
52            committees,
53            n_classes,
54            samples_seen: 0,
55        })
56    }
57
58    /// Train on a single observation.
59    ///
60    /// The observation's target should be the class index as f64 (0.0, 1.0, 2.0, ...).
61    ///
62    /// Uses [`SampleRef`] internally to avoid cloning feature vectors for each
63    /// committee (N classes = 0 clones instead of N clones).
64    pub fn train_one(&mut self, sample: &impl Observation) {
65        self.samples_seen += 1;
66        let class_idx = sample.target() as usize;
67        let features = sample.features();
68
69        for (c, committee) in self.committees.iter_mut().enumerate() {
70            // Binary target: 1.0 for the correct class, 0.0 otherwise
71            let binary_target = if c == class_idx { 1.0 } else { 0.0 };
72            let binary_ref = SampleRef::new(features, binary_target);
73            committee.train_one(&binary_ref);
74        }
75    }
76
77    /// Train on a batch of observations.
78    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
79        for sample in samples {
80            self.train_one(sample);
81        }
82    }
83
84    /// Predict class probabilities via softmax normalization.
85    ///
86    /// Returns a vector of length `n_classes` summing to ~1.0.
87    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
88        // Get raw predictions from each committee
89        let raw: Vec<f64> = self
90            .committees
91            .iter()
92            .map(|c| c.predict(features))
93            .collect();
94
95        // Softmax normalization (numerically stable)
96        let max_raw = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
97        let exp_sum: f64 = raw.iter().map(|&r| crate::math::exp(r - max_raw)).sum();
98        raw.iter()
99            .map(|&r| crate::math::exp(r - max_raw) / exp_sum)
100            .collect()
101    }
102
103    /// Predict the most likely class.
104    pub fn predict(&self, features: &[f64]) -> usize {
105        let proba = self.predict_proba(features);
106        proba
107            .iter()
108            .enumerate()
109            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
110            .map(|(idx, _)| idx)
111            .unwrap_or(0)
112    }
113
114    /// Number of classes.
115    pub fn n_classes(&self) -> usize {
116        self.n_classes
117    }
118
119    /// Total samples trained.
120    pub fn n_samples_seen(&self) -> u64 {
121        self.samples_seen
122    }
123
124    /// Reset all committees.
125    pub fn reset(&mut self) {
126        for committee in &mut self.committees {
127            committee.reset();
128        }
129        self.samples_seen = 0;
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::string::ToString;
    use alloc::vec;

    /// Small, fast configuration shared by every test below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    #[test]
    fn new_multiclass_creates_committees() {
        let clf = MulticlassSGBT::new(test_config(), 3).unwrap();
        assert_eq!(clf.n_classes(), 3);
    }

    #[test]
    fn new_multiclass_rejects_less_than_two_classes() {
        let err = MulticlassSGBT::new(test_config(), 1).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("n_classes"),
            "error should mention n_classes: {}",
            err
        );
    }

    #[test]
    fn predict_proba_sums_to_one() {
        let clf = MulticlassSGBT::new(test_config(), 3).unwrap();
        let sum: f64 = clf.predict_proba(&[1.0, 2.0]).iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities should sum to 1.0, got {}",
            sum
        );
    }

    #[test]
    fn predict_proba_uniform_before_training() {
        let clf = MulticlassSGBT::new(test_config(), 3).unwrap();
        // Untrained committees all emit 0.0; softmax of equal scores is uniform.
        assert!(clf
            .predict_proba(&[1.0, 2.0])
            .iter()
            .all(|&p| (p - 1.0 / 3.0).abs() < 1e-10));
    }

    #[test]
    fn train_one_does_not_panic() {
        let mut clf = MulticlassSGBT::new(test_config(), 3).unwrap();
        let labeled = [
            (vec![1.0, 2.0], 0.0),
            (vec![3.0, 4.0], 1.0),
            (vec![5.0, 6.0], 2.0),
        ];
        for (features, label) in labeled {
            clf.train_one(&Sample::new(features, label));
        }
        assert_eq!(clf.n_samples_seen(), 3);
    }

    #[test]
    fn reset_clears_state() {
        let mut clf = MulticlassSGBT::new(test_config(), 3).unwrap();
        for i in 0..20u32 {
            clf.train_one(&Sample::new(vec![f64::from(i)], f64::from(i % 3)));
        }
        clf.reset();
        assert_eq!(clf.n_samples_seen(), 0);
        // After a reset the model must be back to the uniform prior.
        assert!(clf
            .predict_proba(&[1.0])
            .iter()
            .all(|&p| (p - 1.0 / 3.0).abs() < 1e-10));
    }
}