// irithyll_core/ensemble/multiclass.rs
use alloc::vec::Vec;
8
9use crate::ensemble::config::SGBTConfig;
10use crate::ensemble::SGBT;
11use crate::error::{ConfigError, IrithyllError};
12use crate::loss::softmax::SoftmaxLoss;
13use crate::sample::{Observation, SampleRef};
14
/// One-vs-rest multiclass ensemble built from binary [`SGBT`] committees.
///
/// Holds one committee per class; committee `c` is trained with a binary
/// target of 1.0 on samples whose class index is `c`, and 0.0 otherwise
/// (see `train_one`). Predictions softmax the raw committee scores.
#[derive(Debug)]
pub struct MulticlassSGBT {
    // One binary committee per class; length equals `n_classes`.
    committees: Vec<SGBT<SoftmaxLoss>>,
    // Number of target classes; validated to be >= 2 in `new`.
    n_classes: usize,
    // Running count of samples fed to `train_one`; cleared by `reset`.
    samples_seen: u64,
}
31
32impl MulticlassSGBT {
33 pub fn new(config: SGBTConfig, n_classes: usize) -> crate::error::Result<Self> {
39 if n_classes < 2 {
40 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
41 "n_classes",
42 "must be >= 2",
43 n_classes,
44 )));
45 }
46
47 let committees = (0..n_classes)
48 .map(|_| SGBT::with_loss(config.clone(), SoftmaxLoss { n_classes }))
49 .collect();
50
51 Ok(Self {
52 committees,
53 n_classes,
54 samples_seen: 0,
55 })
56 }
57
58 pub fn train_one(&mut self, sample: &impl Observation) {
65 self.samples_seen += 1;
66 let class_idx = sample.target() as usize;
67 let features = sample.features();
68
69 for (c, committee) in self.committees.iter_mut().enumerate() {
70 let binary_target = if c == class_idx { 1.0 } else { 0.0 };
72 let binary_ref = SampleRef::new(features, binary_target);
73 committee.train_one(&binary_ref);
74 }
75 }
76
77 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
79 for sample in samples {
80 self.train_one(sample);
81 }
82 }
83
84 pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
88 let raw: Vec<f64> = self
90 .committees
91 .iter()
92 .map(|c| c.predict(features))
93 .collect();
94
95 let max_raw = raw.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
97 let exp_sum: f64 = raw.iter().map(|&r| crate::math::exp(r - max_raw)).sum();
98 raw.iter()
99 .map(|&r| crate::math::exp(r - max_raw) / exp_sum)
100 .collect()
101 }
102
103 pub fn predict(&self, features: &[f64]) -> usize {
105 let proba = self.predict_proba(features);
106 proba
107 .iter()
108 .enumerate()
109 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
110 .map(|(idx, _)| idx)
111 .unwrap_or(0)
112 }
113
114 pub fn n_classes(&self) -> usize {
116 self.n_classes
117 }
118
119 pub fn n_samples_seen(&self) -> u64 {
121 self.samples_seen
122 }
123
124 pub fn reset(&mut self) {
126 for committee in &mut self.committees {
127 committee.reset();
128 }
129 self.samples_seen = 0;
130 }
131}
132
133#[cfg(test)]
134mod tests {
135 use super::*;
136 use crate::sample::Sample;
137 use alloc::string::ToString;
138 use alloc::vec;
139
140 fn test_config() -> SGBTConfig {
141 SGBTConfig::builder()
142 .n_steps(5)
143 .learning_rate(0.1)
144 .grace_period(10)
145 .max_depth(3)
146 .n_bins(8)
147 .build()
148 .unwrap()
149 }
150
151 #[test]
152 fn new_multiclass_creates_committees() {
153 let model = MulticlassSGBT::new(test_config(), 3).unwrap();
154 assert_eq!(model.n_classes(), 3);
155 }
156
157 #[test]
158 fn new_multiclass_rejects_less_than_two_classes() {
159 let err = MulticlassSGBT::new(test_config(), 1).unwrap_err();
160 assert!(
161 err.to_string().contains("n_classes"),
162 "error should mention n_classes: {}",
163 err
164 );
165 }
166
167 #[test]
168 fn predict_proba_sums_to_one() {
169 let model = MulticlassSGBT::new(test_config(), 3).unwrap();
170 let proba = model.predict_proba(&[1.0, 2.0]);
171 let sum: f64 = proba.iter().sum();
172 assert!(
173 (sum - 1.0).abs() < 1e-10,
174 "probabilities should sum to 1.0, got {}",
175 sum
176 );
177 }
178
179 #[test]
180 fn predict_proba_uniform_before_training() {
181 let model = MulticlassSGBT::new(test_config(), 3).unwrap();
182 let proba = model.predict_proba(&[1.0, 2.0]);
183 for &p in &proba {
185 assert!((p - 1.0 / 3.0).abs() < 1e-10);
186 }
187 }
188
189 #[test]
190 fn train_one_does_not_panic() {
191 let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
192 model.train_one(&Sample::new(vec![1.0, 2.0], 0.0));
193 model.train_one(&Sample::new(vec![3.0, 4.0], 1.0));
194 model.train_one(&Sample::new(vec![5.0, 6.0], 2.0));
195 assert_eq!(model.n_samples_seen(), 3);
196 }
197
198 #[test]
199 fn reset_clears_state() {
200 let mut model = MulticlassSGBT::new(test_config(), 3).unwrap();
201 for i in 0..20 {
202 model.train_one(&Sample::new(vec![i as f64], (i % 3) as f64));
203 }
204 model.reset();
205 assert_eq!(model.n_samples_seen(), 0);
206 let proba = model.predict_proba(&[1.0]);
207 for &p in &proba {
208 assert!((p - 1.0 / 3.0).abs() < 1e-10);
209 }
210 }
211}