1use crate::ensemble::config::SGBTConfig;
30use crate::ensemble::SGBT;
31use crate::error::{ConfigError, IrithyllError};
32use crate::loss::squared::SquaredLoss;
33use crate::loss::Loss;
34use crate::sample::Observation;
35
/// Steps a xorshift64 PRNG state and maps the result to a uniform `f64`
/// in `[0, 1)` using the top 53 bits (the full mantissa width of `f64`).
///
/// A zero state is a fixed point of xorshift (it would yield 0.0 on every
/// call), which silently disabled Poisson resampling when the configured
/// seed was 0; such a state is remapped to an arbitrary odd constant
/// (2^64 / golden ratio) before stepping.
#[inline]
fn xorshift64_f64(state: &mut u64) -> f64 {
    let mut s = if *state == 0 { 0x9E37_79B9_7F4A_7C15 } else { *state };
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    (s >> 11) as f64 / ((1u64 << 53) as f64)
}
51
52fn poisson_sample(rng: &mut u64) -> usize {
58 let l = (-1.0_f64).exp(); let mut k: usize = 0;
60 let mut p: f64 = 1.0;
61 loop {
62 k += 1;
63 let u = xorshift64_f64(rng);
64 p *= u;
65 if p < l {
66 return k - 1;
67 }
68 }
69}
70
/// Online bagging ensemble of SGBT models.
///
/// Each incoming sample is replicated into every bag a Poisson(1)-distributed
/// number of times, approximating bootstrap resampling in a streaming setting.
pub struct BaggedSGBT<L: Loss = SquaredLoss> {
    /// The individual boosted models, one per bag.
    bags: Vec<SGBT<L>>,
    /// Number of bags (invariant: >= 1, enforced at construction).
    n_bags: usize,
    /// Total samples passed to `train_one` since creation or the last `reset`.
    samples_seen: u64,
    /// Current xorshift state driving the Poisson resampling.
    rng_state: u64,
    /// Original seed, retained so `reset` can restore the RNG.
    seed: u64,
}
91
92impl<L: Loss + Clone> Clone for BaggedSGBT<L> {
93 fn clone(&self) -> Self {
94 Self {
95 bags: self.bags.clone(),
96 n_bags: self.n_bags,
97 samples_seen: self.samples_seen,
98 rng_state: self.rng_state,
99 seed: self.seed,
100 }
101 }
102}
103
104impl<L: Loss> std::fmt::Debug for BaggedSGBT<L> {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 f.debug_struct("BaggedSGBT")
107 .field("n_bags", &self.n_bags)
108 .field("samples_seen", &self.samples_seen)
109 .finish()
110 }
111}
112
impl BaggedSGBT<SquaredLoss> {
    /// Creates a bagged ensemble of `n_bags` SGBT models using squared-error
    /// loss; convenience wrapper around [`BaggedSGBT::with_loss`].
    ///
    /// # Errors
    ///
    /// Returns an error when `n_bags` is zero (propagated from `with_loss`).
    pub fn new(config: SGBTConfig, n_bags: usize) -> crate::error::Result<Self> {
        Self::with_loss(config, SquaredLoss, n_bags)
    }
}
123
124impl<L: Loss + Clone> BaggedSGBT<L> {
125 pub fn with_loss(config: SGBTConfig, loss: L, n_bags: usize) -> crate::error::Result<Self> {
134 if n_bags < 1 {
135 return Err(IrithyllError::InvalidConfig(ConfigError::out_of_range(
136 "n_bags",
137 "must be >= 1",
138 n_bags,
139 )));
140 }
141
142 let seed = config.seed;
143 let bags = (0..n_bags)
144 .map(|i| {
145 let mut cfg = config.clone();
146 cfg.seed = config.seed ^ (0xBA6_0000_0000_0000 | i as u64);
148 SGBT::with_loss(cfg, loss.clone())
149 })
150 .collect();
151
152 Ok(Self {
153 bags,
154 n_bags,
155 samples_seen: 0,
156 rng_state: seed,
157 seed,
158 })
159 }
160
161 pub fn train_one(&mut self, sample: &impl Observation) {
167 self.samples_seen += 1;
168 for bag in &mut self.bags {
169 let k = poisson_sample(&mut self.rng_state);
170 for _ in 0..k {
171 bag.train_one(sample);
172 }
173 }
174 }
175
176 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
178 for sample in samples {
179 self.train_one(sample);
180 }
181 }
182
183 pub fn predict(&self, features: &[f64]) -> f64 {
185 let sum: f64 = self.bags.iter().map(|b| b.predict(features)).sum();
186 sum / self.n_bags as f64
187 }
188
189 pub fn predict_transformed(&self, features: &[f64]) -> f64 {
192 let sum: f64 = self
193 .bags
194 .iter()
195 .map(|b| b.predict_transformed(features))
196 .sum();
197 sum / self.n_bags as f64
198 }
199
200 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
202 feature_matrix.iter().map(|f| self.predict(f)).collect()
203 }
204
205 #[inline]
207 pub fn n_bags(&self) -> usize {
208 self.n_bags
209 }
210
211 #[inline]
213 pub fn n_samples_seen(&self) -> u64 {
214 self.samples_seen
215 }
216
217 pub fn bags(&self) -> &[SGBT<L>] {
219 &self.bags
220 }
221
222 pub fn bag(&self, idx: usize) -> &SGBT<L> {
228 &self.bags[idx]
229 }
230
231 pub fn is_initialized(&self) -> bool {
233 self.bags.iter().all(|b| b.is_initialized())
234 }
235
236 pub fn reset(&mut self) {
238 for bag in &mut self.bags {
239 bag.reset();
240 }
241 self.samples_seen = 0;
242 self.rng_state = self.seed;
243 }
244}
245
246use crate::learner::StreamingLearner;
251use crate::sample::SampleRef;
252
253impl<L: Loss + Clone> StreamingLearner for BaggedSGBT<L> {
254 fn train_one(&mut self, features: &[f64], target: f64, weight: f64) {
255 let sample = SampleRef::weighted(features, target, weight);
256 BaggedSGBT::train_one(self, &sample);
258 }
259
260 fn predict(&self, features: &[f64]) -> f64 {
261 BaggedSGBT::predict(self, features)
262 }
263
264 fn n_samples_seen(&self) -> u64 {
265 self.samples_seen
266 }
267
268 fn reset(&mut self) {
269 BaggedSGBT::reset(self);
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;

    // Small, fast-to-train configuration shared by the tests below.
    fn test_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap()
    }

    // Construction honors the requested bag count and starts untrained.
    #[test]
    fn creates_correct_number_of_bags() {
        let model = BaggedSGBT::new(test_config(), 7).unwrap();
        assert_eq!(model.n_bags(), 7);
        assert_eq!(model.bags().len(), 7);
        assert_eq!(model.n_samples_seen(), 0);
    }

    // n_bags == 0 must be rejected at construction time.
    #[test]
    fn rejects_zero_bags() {
        let result = BaggedSGBT::new(test_config(), 0);
        assert!(result.is_err());
    }

    // Smoke test: a single-bag ensemble trains and produces finite output.
    #[test]
    fn single_bag_equals_single_sgbt() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 1).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
        }

        let pred = model.predict(&[0.5]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
    }

    // Empirical mean of Poisson(1) draws should be close to 1.
    #[test]
    fn poisson_mean_approximately_one() {
        let mut rng = 0xDEAD_BEEF_u64;
        let n = 10_000;
        let sum: usize = (0..n).map(|_| poisson_sample(&mut rng)).sum();
        let mean = sum as f64 / n as f64;
        assert!(
            (mean - 1.0).abs() < 0.1,
            "Poisson(1) mean should be ~1.0, got {}",
            mean
        );
    }

    // The usize return can't be negative; this bounds the upper tail instead
    // (Poisson(1) mass above ~10 is negligible over 10k draws).
    #[test]
    fn poisson_never_negative() {
        let mut rng = 42u64;
        for _ in 0..10_000 {
            let k = poisson_sample(&mut rng);
            assert!(k < 20, "Poisson(1) should rarely exceed 10, got {}", k);
        }
    }

    // Identical seeds must reproduce identical models bit-for-bit.
    #[test]
    fn deterministic_with_same_seed() {
        let config = test_config();
        let mut model1 = BaggedSGBT::new(config.clone(), 3).unwrap();
        let mut model2 = BaggedSGBT::new(config, 3).unwrap();

        let samples: Vec<Sample> = (0..50)
            .map(|i| {
                let x = i as f64 * 0.1;
                Sample::new(vec![x], x * 3.0)
            })
            .collect();

        for s in &samples {
            model1.train_one(s);
            model2.train_one(s);
        }

        let pred1 = model1.predict(&[0.5]);
        let pred2 = model2.predict(&[0.5]);
        assert!(
            (pred1 - pred2).abs() < 1e-10,
            "same seed should give identical predictions: {} vs {}",
            pred1,
            pred2
        );
    }

    // The ensemble prediction must equal the arithmetic mean of the bags.
    #[test]
    fn predict_averages_bags() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..100 {
            let x = i as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x));
        }

        let features = [0.5];
        let individual_sum: f64 = model.bags().iter().map(|b| b.predict(&features)).sum();
        let expected = individual_sum / model.n_bags() as f64;
        let actual = model.predict(&features);
        assert!(
            (actual - expected).abs() < 1e-10,
            "predict should be mean of bags: {} vs {}",
            actual,
            expected
        );
    }

    // reset() must zero the sample counter (and restore the RNG internally).
    #[test]
    fn reset_clears_state() {
        let config = test_config();
        let mut model = BaggedSGBT::new(config, 3).unwrap();

        for i in 0..100 {
            let x = i as f64;
            model.train_one(&Sample::new(vec![x], x));
        }
        assert!(model.n_samples_seen() > 0);

        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
    }

    // Loose convergence check on a linear target: finite outputs plus a
    // directional sanity bound (not a tight accuracy requirement).
    #[test]
    fn convergence_on_linear_target() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 5).unwrap();

        for i in 0..500 {
            let x = (i % 100) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], 2.0 * x + 1.0));
        }

        let test_points = [0.0, 0.5, 1.0];
        for &x in &test_points {
            let pred = model.predict(&[x]);
            assert!(
                pred.is_finite(),
                "at x={}: prediction should be finite, got {}",
                x,
                pred
            );
        }
        let p0 = model.predict(&[0.0]);
        let p1 = model.predict(&[1.0]);
        assert!(
            p1 > p0 || (p1 - p0).abs() < 5.0,
            "directional: pred(1.0)={}, pred(0.0)={}",
            p1,
            p0
        );
    }

    // Sanity check of the bagging statistics: the ensemble output equals the
    // bag mean and the across-bag variance is finite and non-negative.
    #[test]
    fn variance_reduction() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .initial_target_count(5)
            .build()
            .unwrap();

        let mut model = BaggedSGBT::new(config, 10).unwrap();

        for i in 0..200 {
            let x = (i % 50) as f64 * 0.1;
            model.train_one(&Sample::new(vec![x], x * x));
        }

        let features = [0.3];
        let preds: Vec<f64> = model.bags().iter().map(|b| b.predict(&features)).collect();
        let mean = preds.iter().sum::<f64>() / preds.len() as f64;
        let variance = preds.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / preds.len() as f64;

        assert!(
            preds.len() > 1,
            "need multiple bags to test variance reduction"
        );
        let ensemble_pred = model.predict(&features);
        assert!(
            (ensemble_pred - mean).abs() < 1e-10,
            "ensemble prediction should be mean of bags"
        );
        assert!(variance >= 0.0 && variance.is_finite());
    }

    // The ensemble must be usable behind a `dyn StreamingLearner` (object
    // safety + trait-method delegation).
    #[test]
    fn streaming_learner_trait_object() {
        let config = test_config();
        let model = BaggedSGBT::new(config, 3).unwrap();
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        for i in 0..100 {
            let x = i as f64 * 0.1;
            boxed.train(&[x], x * 2.0);
        }
        assert_eq!(boxed.n_samples_seen(), 100);
        let pred = boxed.predict(&[5.0]);
        assert!(pred.is_finite());
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }
}