// irithyll_core/ensemble/multiclass.rs

//! Multi-class classification via one-vs-rest SGBT committees.
//!
//! For C classes, maintains C independent SGBT ensembles, each trained with
//! logistic loss on a binarized target (1.0 for the class, 0.0 otherwise).
//! Final predictions are normalized via softmax across committee outputs.

use alloc::vec::Vec;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::SGBT;
use crate::error::{ConfigError, IrithyllError};
use crate::loss::softmax::SoftmaxLoss;
use crate::sample::{Observation, SampleRef};
/// Multi-class SGBT using one-vs-rest committee of ensembles.
///
/// Each class gets its own `SGBT<SoftmaxLoss>` trained with softmax
/// (logistic per-class) loss. The concrete loss type is monomorphized
/// for each committee -- no `Box<dyn Loss>` overhead.
///
/// Predictions are softmax-normalized across all class committees.
#[derive(Debug)]
pub struct MulticlassSGBT {
    /// One SGBT per class, each with monomorphized SoftmaxLoss.
    /// Length is always `n_classes` (established in `new`).
    committees: Vec<SGBT<SoftmaxLoss>>,
    /// Number of classes; `new` rejects values < 2.
    n_classes: usize,
    /// Total samples passed to `train_one`. Note: incremented before the
    /// input-validity guard, so skipped (non-finite) samples are counted too.
    samples_seen: u64,
}

32impl MulticlassSGBT {
33    /// Create a new multi-class SGBT.
34    ///
35    /// # Errors
36    ///
37    /// Returns [`IrithyllError::InvalidConfig`] if `n_classes < 2`.
38    pub fn new(config: SGBTConfig, n_classes: usize) -> crate::error::Result<Self> {
39        if n_classes < 2 {
40            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
41                "n_classes",
42                "must be >= 2",
43                n_classes,
44            )));
45        }
46
47        let committees = (0..n_classes)
48            .map(|_| SGBT::with_loss(config.clone(), SoftmaxLoss { n_classes }))
49            .collect();
50
51        Ok(Self {
52            committees,
53            n_classes,
54            samples_seen: 0,
55        })
56    }
57
58    /// Train on a single observation.
59    ///
60    /// The observation's target should be the class index as f64 (0.0, 1.0, 2.0, ...).
61    ///
62    /// Uses [`SampleRef`] internally to avoid cloning feature vectors for each
63    /// committee (N classes = 0 clones instead of N clones).
64    pub fn train_one(&mut self, sample: &impl Observation) {
65        self.samples_seen += 1;
66        let target = sample.target();
67        let features = sample.features();
68
69        // Guard: skip non-finite inputs to prevent NaN/Inf from corrupting model state.
70        if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
71            return;
72        }
73
74        let class_idx = target as usize;
75
76        for (c, committee) in self.committees.iter_mut().enumerate() {
77            // Binary target: 1.0 for the correct class, 0.0 otherwise
78            let binary_target = if c == class_idx { 1.0 } else { 0.0 };
79            let binary_ref = SampleRef::new(features, binary_target);
80            committee.train_one(&binary_ref);
81        }
82    }
83
84    /// Train on a batch of observations.
85    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
86        for sample in samples {
87            self.train_one(sample);
88        }
89    }
90
91    /// Predict class probabilities via softmax normalization.
92    ///
93    /// Returns a vector of length `n_classes` summing to ~1.0.
94    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
95        // Get raw predictions from each committee
96        let raw: Vec<f64> = self
97            .committees
98            .iter()
99            .map(|c| c.predict(features))
100            .collect();
101
102        // Softmax normalization (numerically stable)
103        let max_raw = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
104        let exp_sum: f64 = raw.iter().map(|&r| crate::math::exp(r - max_raw)).sum();
105        raw.iter()
106            .map(|&r| crate::math::exp(r - max_raw) / exp_sum)
107            .collect()
108    }
109
110    /// Predict the most likely class.
111    pub fn predict(&self, features: &[f64]) -> usize {
112        let proba = self.predict_proba(features);
113        proba
114            .iter()
115            .enumerate()
116            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
117            .map(|(idx, _)| idx)
118            .unwrap_or(0)
119    }
120
121    /// Number of classes.
122    pub fn n_classes(&self) -> usize {
123        self.n_classes
124    }
125
126    /// Total samples trained.
127    pub fn n_samples_seen(&self) -> u64 {
128        self.samples_seen
129    }
130
131    /// Reset all committees.
132    pub fn reset(&mut self) {
133        for committee in &mut self.committees {
134            committee.reset();
135        }
136        self.samples_seen = 0;
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::string::ToString;
    use alloc::vec;

    /// Small, fast SGBT config shared by all tests below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build()
            .unwrap()
    }

    // Construction stores the requested class count.
    #[test]
    fn new_multiclass_creates_committees() {
        let model = MulticlassSGBT::new(test_config(), 3).unwrap();
        assert_eq!(model.n_classes(), 3);
    }

    // n_classes < 2 must be rejected with an error naming the parameter.
    #[test]
    fn new_multiclass_rejects_less_than_two_classes() {
        let err = MulticlassSGBT::new(test_config(), 1).unwrap_err();
        assert!(
            err.to_string().contains("n_classes"),
            "error should mention n_classes: {}",
            err
        );
    }

    // Softmax output must be a proper probability distribution.
    #[test]
    fn predict_proba_sums_to_one() {
        let model = MulticlassSGBT::new(test_config(), 3).unwrap();
        let proba = model.predict_proba(&[1.0, 2.0]);
        let sum: f64 = proba.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities should sum to 1.0, got {}",
            sum
        );
    }

    // Untrained committees all output the same raw score, so the softmax
    // must be exactly uniform.
    #[test]
    fn predict_proba_uniform_before_training() {
        let model = MulticlassSGBT::new(test_config(), 3).unwrap();
        let proba = model.predict_proba(&[1.0, 2.0]);
        // All committees predict 0.0, softmax of equal values = uniform
        for &p in &proba {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
    }

    // Smoke test: one valid sample per class trains without panicking and
    // the sample counter advances.
    #[test]
    fn train_one_does_not_panic() {
        let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
        model.train_one(&Sample::new(vec![1.0, 2.0], 0.0));
        model.train_one(&Sample::new(vec![3.0, 4.0], 1.0));
        model.train_one(&Sample::new(vec![5.0, 6.0], 2.0));
        assert_eq!(model.n_samples_seen(), 3);
    }

    // After reset, the counter is zero and predictions return to the
    // untrained uniform distribution.
    #[test]
    fn reset_clears_state() {
        let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
        for i in 0..20 {
            model.train_one(&Sample::new(vec![i as f64], (i % 3) as f64));
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        let proba = model.predict_proba(&[1.0]);
        for &p in &proba {
            assert!((p - 1.0 / 3.0).abs() < 1e-10);
        }
    }
}