irithyll_core/ensemble/
multiclass.rs1use alloc::vec::Vec;
8
9use crate::ensemble::config::SGBTConfig;
10use crate::ensemble::SGBT;
11use crate::error::{ConfigError, IrithyllError};
12use crate::loss::softmax::SoftmaxLoss;
13use crate::sample::{Observation, SampleRef};
14
15#[derive(Debug)]
23pub struct MulticlassSGBT {
24 committees: Vec<SGBT<SoftmaxLoss>>,
26 n_classes: usize,
28 samples_seen: u64,
30}
31
32impl MulticlassSGBT {
33 pub fn new(config: SGBTConfig, n_classes: usize) -> crate::error::Result<Self> {
39 if n_classes < 2 {
40 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
41 "n_classes",
42 "must be >= 2",
43 n_classes,
44 )));
45 }
46
47 let committees = (0..n_classes)
48 .map(|_| SGBT::with_loss(config.clone(), SoftmaxLoss { n_classes }))
49 .collect();
50
51 Ok(Self {
52 committees,
53 n_classes,
54 samples_seen: 0,
55 })
56 }
57
58 pub fn train_one(&mut self, sample: &impl Observation) {
65 self.samples_seen += 1;
66 let target = sample.target();
67 let features = sample.features();
68
69 if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
71 return;
72 }
73
74 let class_idx = target as usize;
75
76 for (c, committee) in self.committees.iter_mut().enumerate() {
77 let binary_target = if c == class_idx { 1.0 } else { 0.0 };
79 let binary_ref = SampleRef::new(features, binary_target);
80 committee.train_one(&binary_ref);
81 }
82 }
83
84 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
86 for sample in samples {
87 self.train_one(sample);
88 }
89 }
90
91 pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
95 let raw: Vec<f64> = self
97 .committees
98 .iter()
99 .map(|c| c.predict(features))
100 .collect();
101
102 let max_raw = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
104 let exp_sum: f64 = raw.iter().map(|&r| crate::math::exp(r - max_raw)).sum();
105 raw.iter()
106 .map(|&r| crate::math::exp(r - max_raw) / exp_sum)
107 .collect()
108 }
109
110 pub fn predict(&self, features: &[f64]) -> usize {
112 let proba = self.predict_proba(features);
113 proba
114 .iter()
115 .enumerate()
116 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
117 .map(|(idx, _)| idx)
118 .unwrap_or(0)
119 }
120
121 pub fn n_classes(&self) -> usize {
123 self.n_classes
124 }
125
126 pub fn n_samples_seen(&self) -> u64 {
128 self.samples_seen
129 }
130
131 pub fn reset(&mut self) {
133 for committee in &mut self.committees {
134 committee.reset();
135 }
136 self.samples_seen = 0;
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::string::ToString;
    use alloc::vec;

    /// Small, fast configuration shared by every test in this module.
    fn test_config() -> SGBTConfig {
        let built = SGBTConfig::builder()
            .n_steps(5)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(8)
            .build();
        built.unwrap()
    }

    /// Fresh three-class model built from `test_config`.
    fn three_class_model() -> MulticlassSGBT {
        MulticlassSGBT::new(test_config(), 3).unwrap()
    }

    /// Asserts that every probability equals one third (three-class uniform).
    fn assert_uniform_thirds(proba: &[f64]) {
        proba
            .iter()
            .for_each(|&p| assert!((p - 1.0 / 3.0).abs() < 1e-10));
    }

    #[test]
    fn new_multiclass_creates_committees() {
        let model = three_class_model();
        assert_eq!(model.n_classes(), 3);
    }

    #[test]
    fn new_multiclass_rejects_less_than_two_classes() {
        let err = MulticlassSGBT::new(test_config(), 1).unwrap_err();
        let rendered = err.to_string();
        assert!(
            rendered.contains("n_classes"),
            "error should mention n_classes: {}",
            err
        );
    }

    #[test]
    fn predict_proba_sums_to_one() {
        let model = three_class_model();
        let sum: f64 = model.predict_proba(&[1.0, 2.0]).iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities should sum to 1.0, got {}",
            sum
        );
    }

    #[test]
    fn predict_proba_uniform_before_training() {
        let model = three_class_model();
        assert_uniform_thirds(&model.predict_proba(&[1.0, 2.0]));
    }

    #[test]
    fn train_one_does_not_panic() {
        let mut model = three_class_model();
        let labelled = vec![
            (vec![1.0, 2.0], 0.0),
            (vec![3.0, 4.0], 1.0),
            (vec![5.0, 6.0], 2.0),
        ];
        for (features, label) in labelled.into_iter() {
            model.train_one(&Sample::new(features, label));
        }
        assert_eq!(model.n_samples_seen(), 3);
    }

    #[test]
    fn reset_clears_state() {
        let mut model = three_class_model();
        (0..20).for_each(|i| {
            model.train_one(&Sample::new(vec![i as f64], (i % 3) as f64));
        });
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert_uniform_thirds(&model.predict_proba(&[1.0]));
    }
}