1use alloc::vec::Vec;
30
31use crate::ensemble::config::SGBTConfig;
32use crate::ensemble::SGBT;
33use crate::error::{ConfigError, IrithyllError};
34use crate::loss::squared::SquaredLoss;
35use crate::loss::Loss;
36use crate::sample::Observation;
37
38use crate::rng::xorshift64_f64;
39
40fn poisson_sample(rng: &mut u64) -> usize {
50 let l = crate::math::exp(-1.0_f64); let mut k: usize = 0;
52 let mut p: f64 = 1.0;
53 loop {
54 k += 1;
55 let u = xorshift64_f64(rng);
56 p *= u;
57 if p < l {
58 return k - 1;
59 }
60 }
61}
62
/// An online bagged ensemble of [`SGBT`] models.
///
/// Each incoming sample is replayed `Poisson(1)` times per bag (online
/// bootstrap), and predictions are the mean over all bags.
pub struct BaggedSGBT<L: Loss = SquaredLoss> {
    /// The independently seeded ensemble members.
    bags: Vec<SGBT<L>>,
    /// Number of bags; invariant: `n_bags == bags.len()` and `n_bags >= 1`.
    n_bags: usize,
    /// Count of samples offered to `train_one`, including skipped
    /// non-finite ones.
    samples_seen: u64,
    /// Current xorshift state driving the Poisson resampling.
    rng_state: u64,
    /// Initial seed, retained so `reset` can restore `rng_state`.
    seed: u64,
}
83
84impl<L: Loss + Clone> Clone for BaggedSGBT<L> {
85 fn clone(&self) -> Self {
86 Self {
87 bags: self.bags.clone(),
88 n_bags: self.n_bags,
89 samples_seen: self.samples_seen,
90 rng_state: self.rng_state,
91 seed: self.seed,
92 }
93 }
94}
95
96impl<L: Loss> core::fmt::Debug for BaggedSGBT<L> {
97 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
98 f.debug_struct("BaggedSGBT")
99 .field("n_bags", &self.n_bags)
100 .field("samples_seen", &self.samples_seen)
101 .finish()
102 }
103}
104
105impl BaggedSGBT<SquaredLoss> {
106 pub fn new(config: SGBTConfig, n_bags: usize) -> crate::error::Result<Self> {
112 Self::with_loss(config, SquaredLoss, n_bags)
113 }
114}
115
116impl<L: Loss + Clone> BaggedSGBT<L> {
117 pub fn with_loss(config: SGBTConfig, loss: L, n_bags: usize) -> crate::error::Result<Self> {
126 if n_bags < 1 {
127 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
128 "n_bags",
129 "must be >= 1",
130 n_bags,
131 )));
132 }
133
134 let seed = config.seed;
135 let bags = (0..n_bags)
136 .map(|i| {
137 let mut cfg = config.clone();
138 cfg.seed = config.seed ^ (0xBA6_0000_0000_0000 | i as u64);
140 SGBT::with_loss(cfg, loss.clone())
141 })
142 .collect();
143
144 Ok(Self {
145 bags,
146 n_bags,
147 samples_seen: 0,
148 rng_state: seed,
149 seed,
150 })
151 }
152
153 pub fn train_one(&mut self, sample: &impl Observation) {
159 self.samples_seen += 1;
160 let target = sample.target();
161 let features = sample.features();
162
163 if !target.is_finite() || !features.iter().all(|f| f.is_finite()) {
165 return;
166 }
167
168 for bag in &mut self.bags {
169 let k = poisson_sample(&mut self.rng_state);
170 for _ in 0..k {
171 bag.train_one(sample);
172 }
173 }
174 }
175
176 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
178 for sample in samples {
179 self.train_one(sample);
180 }
181 }
182
183 pub fn predict(&self, features: &[f64]) -> f64 {
185 let sum: f64 = self.bags.iter().map(|b| b.predict(features)).sum();
186 sum / self.n_bags as f64
187 }
188
189 pub fn predict_transformed(&self, features: &[f64]) -> f64 {
192 let sum: f64 = self
193 .bags
194 .iter()
195 .map(|b| b.predict_transformed(features))
196 .sum();
197 sum / self.n_bags as f64
198 }
199
200 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
202 feature_matrix.iter().map(|f| self.predict(f)).collect()
203 }
204
205 #[inline]
207 pub fn n_bags(&self) -> usize {
208 self.n_bags
209 }
210
211 #[inline]
213 pub fn n_samples_seen(&self) -> u64 {
214 self.samples_seen
215 }
216
217 pub fn bags(&self) -> &[SGBT<L>] {
219 &self.bags
220 }
221
222 pub fn bag(&self, idx: usize) -> &SGBT<L> {
228 &self.bags[idx]
229 }
230
231 pub fn is_initialized(&self) -> bool {
233 self.bags.iter().all(|b| b.is_initialized())
234 }
235
236 pub fn reset(&mut self) {
238 for bag in &mut self.bags {
239 bag.reset();
240 }
241 self.samples_seen = 0;
242 self.rng_state = self.seed;
243 }
244}
245
246use crate::learner::StreamingLearner;
251use crate::sample::SampleRef;
252
253impl<L: Loss + Clone> StreamingLearner for BaggedSGBT<L> {
254 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
255 let sample = SampleRef::weighted(features, target, weight);
256 BaggedSGBT::train_one(self, &sample);
258 }
259
260 fn predict(&self, features: &[f64]) -> f64 {
261 BaggedSGBT::predict(self, features)
262 }
263
264 fn n_samples_seen(&self) -> u64 {
265 self.samples_seen
266 }
267
268 fn reset(&mut self) {
269 BaggedSGBT::reset(self);
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    // Small, fast configuration shared by most tests below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    // Construction wires up exactly `n_bags` members and starts untrained.
    #[test]
    fn creates_correct_number_of_bags() {
        let model = BaggedSGBT::new(test_config(), 7).unwrap();
        assert_eq!(model.n_bags(), 7);
        assert_eq!(model.bags().len(), 7);
        assert_eq!(model.n_samples_seen(), 0);
    }

    // n_bags == 0 must be rejected at construction time.
    #[test]
    fn rejects_zero_bags() {
        let result = BaggedSGBT::new(test_config(), 0);
        assert!(result.is_err());
    }

    // A one-bag ensemble should still train and produce finite output
    // (smoke test; exact equality with a bare SGBT is not asserted because
    // the bag's seed is derived, not identical).
    #[test]
    fn single_bag_equals_single_sgbt() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 1).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[0.5]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    // Empirical mean of the Poisson(1) sampler over 10k draws should be ~1.
    #[test]
    fn poisson_mean_approximately_one() {
        let mut rng = 0xDEAD_BEEF_u64;
        let n = 10_000;
        let sum: usize = (0..n).map(|_| poisson_sample(&mut rng)).sum();
        let mean = sum as f64 / n as f64;
        assert!(
            (mean - 1.0).abs() < 0.1,
            "Poisson(1) mean should be ~1.0, got {}",
            mean
        );
    }

    // Tail check: draws stay small. NOTE(review): the bound is 20 while the
    // message says "rarely exceed 10" — intentionally loose bound, presumably;
    // confirm and align message/bound if desired.
    #[test]
    fn poisson_never_negative() {
        let mut rng = 42u64;
        for _ in 0..10_000 {
            let k = poisson_sample(&mut rng);
            assert!(k < 20, "Poisson(1) should rarely exceed 10, got {}", k);
        }
    }

    // Identical seeds and identical sample streams must yield identical
    // predictions (full determinism of training + resampling).
    #[test]
    fn deterministic_with_same_seed() {
        let config = test_config();
        let mut model1 = BaggedSGBT::new(config.clone(), 3).unwrap();
        let mut model2 = BaggedSGBT::new(config, 3).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-10,
            "same seed should give identical predictions: {} vs {}",
            pred1,
            pred2
        );
    }

    // Ensemble prediction must equal the arithmetic mean of per-bag
    // predictions.
    #[test]
    fn predict_averages_bags() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let features = [0.5];
        let individual_sum: f64 = model.bags().iter().map(|b| b.predict(&features)).sum();
        let expected = individual_sum / model.n_bags() as f64;
        let actual = model.predict(&features);
        assert!(
            (actual - expected).abs() < 1e-10,
            "predict should be mean of bags: {} vs {}",
            actual,
            expected
        );
    }

    // reset() zeroes the sample counter (and, internally, bag state + RNG).
    #[test]
    fn reset_clears_state() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 3).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    // Training on y = 2x + 1 should produce finite predictions that trend in
    // the right direction (loose directional check, not exact convergence).
    #[test]
    fn convergence_on_linear_target() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..500 {
            let x = (i % 100) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], 2.0 * x + 1.0));
        }

        let test_points = [0.0, 0.5, 1.0];
        for &x in &test_points {
            let pred = model.predict(&[x]);
            assert!(
                pred.is_finite(),
                "at x={}: prediction should be finite, got {}",
                x,
                pred
            );
        }
        let p0 = model.predict(&[0.0]);
        let p1 = model.predict(&[1.0]);
        assert!(
            p1 > p0 || (p1 - p0).abs() < 5.0,
            "directional: pred(1.0)={}, pred(0.0)={}",
            p1,
            p0
        );
    }

    // Per-bag predictions should differ (bootstrap diversity) while the
    // ensemble output equals their mean; variance is only sanity-checked
    // as finite and non-negative.
    #[test]
    fn variance_reduction() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 10).unwrap();

        for i in 0..200 {
            let x = (i % 50) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * x));
        }

        let features = [0.3];
        let preds: Vec<f64> = model.bags().iter().map(|b| b.predict(&features)).collect();
        let mean = preds.iter().sum::<f64>() / preds.len() as f64;
        let variance = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64;

        assert!(
            preds.len() > 1,
            "need multiple bags to test variance reduction"
        );
        let ensemble_pred = model.predict(&features);
        assert!(
            (ensemble_pred - mean).abs() < 1e-10,
            "ensemble prediction should be mean of bags"
        );
        assert!(variance >= 0.0 && variance.is_finite());
    }

    // The ensemble is usable behind `dyn StreamingLearner` (object safety +
    // trait-method delegation, including `train`, which is presumably a
    // provided method wrapping `train_one`).
    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = BaggedSGBT::new(config, 3).unwrap();
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}