1use alloc::vec::Vec;
30
31use crate::ensemble::config::SGBTConfig;
32use crate::ensemble::SGBT;
33use crate::error::{ConfigError, IrithyllError};
34use crate::loss::squared::SquaredLoss;
35use crate::loss::Loss;
36use crate::sample::Observation;
37
#[inline]
/// Advances the xorshift64 PRNG `state` in place and returns a uniform
/// `f64` in `[0, 1)`.
///
/// The top 53 bits of the new state are used so the result carries full
/// `f64` mantissa precision.
fn xorshift64_f64(state: &mut u64) -> f64 {
    let mut s = *state;
    // Zero is the unique fixed point of xorshift64: a zero state would emit
    // 0.0 forever, which makes every Poisson(1) draw 0 and silently disables
    // bag training. Remap it to an arbitrary nonzero constant so seed 0 is
    // usable. (Nonzero states can never reach zero, so this triggers at most
    // once per seed.)
    if s == 0 {
        s = 0x9E37_79B9_7F4A_7C15;
    }
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    (s >> 11) as f64 / ((1u64 << 53) as f64)
}
53
54fn poisson_sample(rng: &mut u64) -> usize {
60 let l = crate::math::exp(-1.0_f64); let mut k: usize = 0;
62 let mut p: f64 = 1.0;
63 loop {
64 k += 1;
65 let u = xorshift64_f64(rng);
66 p *= u;
67 if p < l {
68 return k - 1;
69 }
70 }
71}
72
/// An online bagging ensemble of [`SGBT`] models.
///
/// Each training sample is fed to every bag `k` times with `k ~ Poisson(1)`
/// (an online approximation of bootstrap resampling), and predictions are
/// the mean over all bags.
pub struct BaggedSGBT<L: Loss = SquaredLoss> {
    // The independent ensemble members, one per bag.
    bags: Vec<SGBT<L>>,
    // Number of bags; cached so prediction can average without `bags.len()`.
    n_bags: usize,
    // Samples passed to `train_one` since construction or the last `reset`.
    samples_seen: u64,
    // Current xorshift64 state driving the Poisson resampling.
    rng_state: u64,
    // The original seed, retained so `reset` can restore `rng_state`.
    seed: u64,
}
93
94impl<L: Loss + Clone> Clone for BaggedSGBT<L> {
95 fn clone(&self) -> Self {
96 Self {
97 bags: self.bags.clone(),
98 n_bags: self.n_bags,
99 samples_seen: self.samples_seen,
100 rng_state: self.rng_state,
101 seed: self.seed,
102 }
103 }
104}
105
106impl<L: Loss> core::fmt::Debug for BaggedSGBT<L> {
107 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
108 f.debug_struct("BaggedSGBT")
109 .field("n_bags", &self.n_bags)
110 .field("samples_seen", &self.samples_seen)
111 .finish()
112 }
113}
114
impl BaggedSGBT<SquaredLoss> {
    /// Creates a bagged ensemble of `n_bags` SGBT models using the default
    /// squared-error loss. Convenience wrapper around [`BaggedSGBT::with_loss`].
    ///
    /// # Errors
    /// Returns an error if `n_bags` is zero.
    pub fn new(config: SGBTConfig, n_bags: usize) -> crate::error::Result<Self> {
        Self::with_loss(config, SquaredLoss, n_bags)
    }
}
125
impl<L: Loss + Clone> BaggedSGBT<L> {
    /// Creates a bagged ensemble of `n_bags` SGBT models that share `config`
    /// and `loss`, each bag receiving its own deterministic derived seed.
    ///
    /// # Errors
    /// Returns `IrithyllError::InvalidConfig` if `n_bags` is zero.
    pub fn with_loss(config: SGBTConfig, loss: L, n_bags: usize) -> crate::error::Result<Self> {
        if n_bags < 1 {
            return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
                "n_bags",
                "must be >= 1",
                n_bags,
            )));
        }

        let seed = config.seed;
        let bags = (0..n_bags)
            .map(|i| {
                let mut cfg = config.clone();
                // Derive a distinct per-bag seed from the base seed; the
                // 0xBA6... tag keeps every bag seed different from the base
                // seed and from each other.
                cfg.seed = config.seed ^ (0xBA6_0000_0000_0000 | i as u64);
                SGBT::with_loss(cfg, loss.clone())
            })
            .collect();

        Ok(Self {
            bags,
            n_bags,
            samples_seen: 0,
            rng_state: seed,
            seed,
        })
    }

    /// Trains the ensemble on one observation: each bag independently sees
    /// the sample `k` times with `k ~ Poisson(1)` (online bootstrap), so a
    /// given bag may skip the sample entirely or see it several times.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        for bag in &mut self.bags {
            let k = poisson_sample(&mut self.rng_state);
            for _ in 0..k {
                bag.train_one(sample);
            }
        }
    }

    /// Trains on each sample of `samples` in order via [`Self::train_one`].
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    /// Predicts the target for `features` as the mean of all bag predictions.
    pub fn predict(&self, features: &[f64]) -> f64 {
        let sum: f64 = self.bags.iter().map(|b| b.predict(features)).sum();
        sum / self.n_bags as f64
    }

    /// Like [`Self::predict`], but averages the bags' transformed outputs
    /// (as defined by the underlying `SGBT::predict_transformed`).
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        let sum: f64 = self
            .bags
            .iter()
            .map(|b| b.predict_transformed(features))
            .sum();
        sum / self.n_bags as f64
    }

    /// Predicts for every feature row in `feature_matrix`, preserving order.
    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }

    /// Number of bags in the ensemble.
    #[inline]
    pub fn n_bags(&self) -> usize {
        self.n_bags
    }

    /// Number of samples seen since construction or the last [`Self::reset`].
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Read-only access to all ensemble members.
    pub fn bags(&self) -> &[SGBT<L>] {
        &self.bags
    }

    /// Read-only access to the bag at `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.n_bags()`.
    pub fn bag(&self, idx: usize) -> &SGBT<L> {
        &self.bags[idx]
    }

    /// Returns `true` only if every bag reports itself initialized.
    pub fn is_initialized(&self) -> bool {
        self.bags.iter().all(|b| b.is_initialized())
    }

    /// Resets every bag, the sample counter, and the bagging RNG back to the
    /// original seed, restoring the ensemble to its freshly-constructed state.
    pub fn reset(&mut self) {
        for bag in &mut self.bags {
            bag.reset();
        }
        self.samples_seen = 0;
        self.rng_state = self.seed;
    }
}
247
248use crate::learner::StreamingLearner;
253use crate::sample::SampleRef;
254
// Adapter impl so a `BaggedSGBT` can be used wherever a `StreamingLearner`
// trait object is expected; every method delegates to the inherent API.
impl<L: Loss + Clone> StreamingLearner for BaggedSGBT<L> {
    fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
        // Wrap the raw triple in a borrowed sample to avoid allocation.
        let sample = SampleRef::weighted(features, target, weight);
        BaggedSGBT::train_one(self, &sample);
    }

    fn predict(&self, features: &[f64]) -> f64 {
        BaggedSGBT::predict(self, features)
    }

    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        BaggedSGBT::reset(self);
    }
}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::boxed::Box;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Small, fast configuration shared by most tests.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    #[test]
    fn creates_correct_number_of_bags() {
        let model = BaggedSGBT::new(test_config(), 7).unwrap();
        assert_eq!(model.n_bags(), 7);
        assert_eq!(model.bags().len(), 7);
        assert_eq!(model.n_samples_seen(), 0);
    }

    #[test]
    fn rejects_zero_bags() {
        let result = BaggedSGBT::new(test_config(), 0);
        assert!(result.is_err());
    }

    #[test]
    fn single_bag_equals_single_sgbt() {
        // Sanity check only: with one bag the ensemble must still train and
        // produce a finite prediction (exact equality with a standalone SGBT
        // is not checked because Poisson resampling alters the sample stream).
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 1).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[0.5]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    #[test]
    fn poisson_mean_approximately_one() {
        let mut rng = 0xDEAD_BEEF_u64;
        let n = 10_000;
        let sum: usize = (0..n).map(|_| poisson_sample(&mut rng)).sum();
        let mean = sum as f64 / n as f64;
        assert!(
            (mean - 1.0).abs() < 0.1,
            "Poisson(1) mean should be ~1.0, got {}",
            mean
        );
    }

    #[test]
    fn poisson_tail_is_bounded() {
        // `poisson_sample` returns usize, so negativity is impossible; the
        // meaningful property is that the tail is bounded: P(K >= 20) for
        // Poisson(1) is astronomically small, so any such draw indicates a
        // broken sampler rather than bad luck.
        let mut rng = 42u64;
        for _ in 0..10_000 {
            let k = poisson_sample(&mut rng);
            assert!(k < 20, "Poisson(1) draw unexpectedly large: {}", k);
        }
    }

    #[test]
    fn deterministic_with_same_seed() {
        let config = test_config();
        let mut model1 = BaggedSGBT::new(config.clone(), 3).unwrap();
        let mut model2 = BaggedSGBT::new(config, 3).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-10,
            "same seed should give identical predictions: {} vs {}",
            pred1,
            pred2
        );
    }

    #[test]
    fn predict_averages_bags() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let features = [0.5];
        let individual_sum: f64 = model.bags().iter().map(|b| b.predict(&features)).sum();
        let expected = individual_sum / model.n_bags() as f64;
        let actual = model.predict(&features);
        assert!(
            (actual - expected).abs() < 1e-10,
            "predict should be mean of bags: {} vs {}",
            actual,
            expected
        );
    }

    #[test]
    fn reset_clears_state() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 3).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    #[test]
    fn convergence_on_linear_target() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 5).unwrap();

        // Cycle through x in [0, 10) learning y = 2x + 1.
        for i in 0..500 {
            let x = (i % 100) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], 2.0 * x + 1.0));
        }

        let test_points = [0.0, 0.5, 1.0];
        for &x in &test_points {
            let pred = model.predict(&[x]);
            assert!(
                pred.is_finite(),
                "at x={}: prediction should be finite, got {}",
                x,
                pred
            );
        }
        // Loose directional check: the model should not be wildly inverted.
        let p0 = model.predict(&[0.0]);
        let p1 = model.predict(&[1.0]);
        assert!(
            p1 > p0 || (p1 - p0).abs() < 5.0,
            "directional: pred(1.0)={}, pred(0.0)={}",
            p1,
            p0
        );
    }

    #[test]
    fn variance_reduction() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 10).unwrap();

        for i in 0..200 {
            let x = (i % 50) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * x));
        }

        let features = [0.3];
        let preds: Vec<f64> = model.bags().iter().map(|b| b.predict(&features)).collect();
        let mean = preds.iter().sum::<f64>() / preds.len() as f64;
        let variance = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64;

        assert!(
            preds.len() > 1,
            "need multiple bags to test variance reduction"
        );
        let ensemble_pred = model.predict(&features);
        assert!(
            (ensemble_pred - mean).abs() < 1e-10,
            "ensemble prediction should be mean of bags"
        );
        assert!(variance >= 0.0 && variance.is_finite());
    }

    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = BaggedSGBT::new(config, 3).unwrap();
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}