use alloc::vec;
use alloc::vec::Vec;

use core::fmt;

#[cfg(feature = "parallel")]
use rayon::prelude::*;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::step::BoostingStep;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::Observation;
/// Online stochastic gradient-boosted trees whose per-step tree updates can
/// run in parallel (via rayon, behind the `parallel` feature).
///
/// `L` is the loss used to derive gradients/hessians and the prediction
/// transform; defaults to [`SquaredLoss`] for regression.
pub struct ParallelSGBT<L: Loss = SquaredLoss> {
    /// Ensemble-level hyperparameters.
    config: SGBTConfig,
    /// One boosting step (tree slot + drift detector) per configured step.
    steps: Vec<BoostingStep>,
    /// Loss providing gradient/hessian and the output transform.
    loss: L,
    /// Constant score added to every prediction; fixed once warm-up completes.
    base_prediction: f64,
    /// True once `base_prediction` has been computed from buffered targets.
    base_initialized: bool,
    /// Targets buffered during warm-up; cleared and shrunk once initialized.
    initial_targets: Vec<f64>,
    /// Number of targets to buffer before fixing the base prediction.
    initial_target_count: usize,
    /// Total samples consumed by `train_one`.
    samples_seen: u64,
    /// RNG state used when drawing per-step train counts (seeded from config).
    rng_state: u64,
    /// Scratch buffer of per-step train counts, reused across samples.
    train_counts_buf: Vec<usize>,
}
84
85impl<L: Loss + Clone> Clone for ParallelSGBT<L> {
86 fn clone(&self) -> Self {
87 Self {
88 config: self.config.clone(),
89 steps: self.steps.clone(),
90 loss: self.loss.clone(),
91 base_prediction: self.base_prediction,
92 base_initialized: self.base_initialized,
93 initial_targets: self.initial_targets.clone(),
94 initial_target_count: self.initial_target_count,
95 samples_seen: self.samples_seen,
96 rng_state: self.rng_state,
97 train_counts_buf: self.train_counts_buf.clone(),
98 }
99 }
100}
101
102impl<L: Loss> fmt::Debug for ParallelSGBT<L> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 f.debug_struct("ParallelSGBT")
105 .field("n_steps", &self.steps.len())
106 .field("samples_seen", &self.samples_seen)
107 .field("base_prediction", &self.base_prediction)
108 .field("base_initialized", &self.base_initialized)
109 .finish()
110 }
111}
112
113impl ParallelSGBT<SquaredLoss> {
118 pub fn new(config: SGBTConfig) -> Self {
120 Self::with_loss(config, SquaredLoss)
121 }
122}
123
124impl<L: Loss> ParallelSGBT<L> {
129 pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
143 let leaf_decay_alpha = config
144 .leaf_half_life
145 .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));
146
147 let tree_config = crate::ensemble::config::build_tree_config(&config)
148 .leaf_decay_alpha_opt(leaf_decay_alpha);
149
150 let max_tree_samples = config.max_tree_samples;
151
152 let steps: Vec<BoostingStep> = (0..config.n_steps)
153 .map(|i| {
154 let mut tc = tree_config.clone();
155 tc.seed = config.seed ^ (i as u64);
156 let detector = config.drift_detector.create();
157 BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
158 })
159 .collect();
160
161 let seed = config.seed;
162 let initial_target_count = config.initial_target_count;
163 let n_steps = steps.len();
164 Self {
165 config,
166 steps,
167 loss,
168 base_prediction: 0.0,
169 base_initialized: false,
170 initial_targets: Vec::new(),
171 initial_target_count,
172 samples_seen: 0,
173 rng_state: seed,
174 train_counts_buf: vec![0; n_steps],
175 }
176 }
177
    /// Consumes one labeled sample: advances the warm-up buffer, computes the
    /// loss derivatives at the current ensemble prediction, and trains every
    /// boosting step on the same (gradient, hessian) pair.
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Warm-up phase: buffer targets until enough have been seen, then fix
        // the base prediction from the loss's best-constant estimate.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                // The buffer is never used again; release its allocation.
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Derivatives are taken at the full ensemble prediction, i.e. each
        // step boosts on the current overall residual.
        let full_pred = self.predict(features);

        let gradient = self.loss.gradient(target, full_pred);
        let hessian = self.loss.hessian(target, full_pred);

        // Draw per-step train counts sequentially BEFORE the (possibly
        // parallel) step updates, so the RNG stream is consumed in a fixed
        // order and results are deterministic regardless of scheduling.
        for tc in self.train_counts_buf.iter_mut() {
            *tc = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);
        }

        // Steps are independent given (gradient, hessian, train_count), so
        // they can be updated in parallel when the feature is enabled.
        #[cfg(feature = "parallel")]
        {
            self.steps.par_iter_mut().enumerate().for_each(|(i, step)| {
                step.train_and_predict(features, gradient, hessian, self.train_counts_buf[i]);
            });
        }

        // Serial fallback with identical semantics.
        #[cfg(not(feature = "parallel"))]
        {
            for (i, step) in self.steps.iter_mut().enumerate() {
                step.train_and_predict(features, gradient, hessian, self.train_counts_buf[i]);
            }
        }
    }
238
239 pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
241 for sample in samples {
242 self.train_one(sample);
243 }
244 }
245
246 pub fn predict(&self, features: &[f64]) -> f64 {
251 let mut pred = self.base_prediction;
252 for step in &self.steps {
253 pred += self.config.learning_rate * step.predict(features);
254 }
255 pred
256 }
257
    /// Prediction mapped through the loss's output transform (e.g. a sigmoid
    /// for a logistic loss; identity for squared loss).
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }

    /// Alias for [`Self::predict_transformed`]; reads naturally for
    /// classification losses.
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }
267
268 pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
270 feature_matrix.iter().map(|f| self.predict(f)).collect()
271 }
272
    /// Number of boosting steps (fixed at construction).
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total trees currently held, counting the extra alternate tree in any
    /// step that reports `has_alternate()`.
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Sum of leaf counts across all steps.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Samples consumed via training so far.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// Constant base score added to every prediction (0.0 until warm-up
    /// completes).
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether warm-up has finished and the base prediction is fixed.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// The configuration this ensemble was built from.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// The loss function in use.
    pub fn loss(&self) -> &L {
        &self.loss
    }
312
313 pub fn feature_importances(&self) -> Vec<f64> {
318 let mut totals: Vec<f64> = Vec::new();
319 for step in &self.steps {
320 let gains = step.slot().split_gains();
321 if totals.is_empty() && !gains.is_empty() {
322 totals.resize(gains.len(), 0.0);
323 }
324 for (i, &g) in gains.iter().enumerate() {
325 if i < totals.len() {
326 totals[i] += g;
327 }
328 }
329 }
330
331 let sum: f64 = totals.iter().sum();
332 if sum > 0.0 {
333 totals.iter_mut().for_each(|v| *v /= sum);
334 }
335 totals
336 }
337
338 pub fn reset(&mut self) {
340 for step in &mut self.steps {
341 step.reset();
342 }
343 self.base_prediction = 0.0;
344 self.base_initialized = false;
345 self.initial_targets.clear();
346 self.samples_seen = 0;
347 self.rng_state = self.config.seed;
348 }
349}
350
351#[cfg(test)]
356mod tests {
357 use super::*;
358 use crate::sample::Sample;
359 use alloc::format;
360 use alloc::vec;
361 use alloc::vec::Vec;
362
    /// Small, fast configuration shared by most tests.
    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .build()
            .unwrap()
    }
373
    /// An untrained model predicts exactly its (zero) base score.
    #[test]
    fn new_model_predicts_zero() {
        let model = ParallelSGBT::new(default_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.abs() < 1e-12);
    }

    /// Smoke test: a single training sample is consumed without panicking
    /// and the sample counter advances.
    #[test]
    fn train_one_does_not_panic() {
        let mut model = ParallelSGBT::new(default_config());
        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
        assert_eq!(model.n_samples_seen(), 1);
    }

    /// After a burst of training the prediction must remain finite.
    #[test]
    fn prediction_changes_after_training() {
        let mut model = ParallelSGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..100 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
        }
        let pred = model.predict(&features);
        assert!(pred.is_finite());
    }
407
    /// On a stationary linear signal, prequential (predict-then-train)
    /// squared error late in the stream should beat error early in the
    /// stream, i.e. the model actually learns.
    #[test]
    fn linear_signal_rmse_improves() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.15)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        // xorshift64-style PRNG for reproducible pseudo-random features.
        let mut rng: u64 = 12345;
        let mut early_errors = Vec::new();
        let mut late_errors = Vec::new();

        for i in 0..1000 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;

            // Score the sample BEFORE training on it (prequential protocol).
            let pred = model.predict(&[x1, x2]);
            let error = (pred - target).powi(2);

            // The first 100 samples are skipped entirely (warm-up noise).
            if (100..300).contains(&i) {
                early_errors.push(error);
            }
            if i >= 800 {
                late_errors.push(error);
            }

            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();

        assert!(
            late_rmse < early_rmse,
            "RMSE should decrease: early={:.4}, late={:.4}",
            early_rmse,
            late_rmse
        );
    }
466
    /// `train_batch` must be exactly equivalent to calling `train_one` on
    /// each sample in order (same seed, same data, same predictions).
    #[test]
    fn train_batch_equivalent_to_sequential() {
        let config = default_config();
        let mut model_seq = ParallelSGBT::new(config.clone());
        let mut model_batch = ParallelSGBT::new(config);

        let samples: Vec<Sample> = (0..20)
            .map(|i| {
                let x = i as f64 * 0.5;
                Sample::new(vec![x, x * 2.0], x * 3.0)
            })
            .collect();

        for s in &samples {
            model_seq.train_one(s);
        }
        model_batch.train_batch(&samples);

        let pred_seq = model_seq.predict(&[1.0, 2.0]);
        let pred_batch = model_batch.predict(&[1.0, 2.0]);

        assert!(
            (pred_seq - pred_batch).abs() < 1e-10,
            "seq={}, batch={}",
            pred_seq,
            pred_batch
        );
    }
498
    /// `reset` must restore the freshly-constructed state: zero samples,
    /// uninitialized base, zero prediction.
    #[test]
    fn reset_returns_to_initial() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
    }

    /// After enough samples the base prediction should land near the mean of
    /// the warm-up targets (here 100..150, mean 124.5). Assumes the default
    /// warm-up count is at most 50 — tolerance of 1.0 absorbs the rest.
    #[test]
    fn base_prediction_initializes() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
        }
        assert!(model.is_initialized());
        let expected = (100.0 + 149.0) / 2.0;
        assert!((model.base_prediction() - expected).abs() < 1.0);
    }

    /// With logistic loss on an untrained model (raw score 0), the
    /// transformed prediction should be sigmoid(0) = 0.5.
    #[test]
    fn with_loss_uses_custom_loss() {
        use crate::loss::logistic::LogisticLoss;
        let model = ParallelSGBT::with_loss(default_config(), LogisticLoss);
        let pred = model.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );
    }

    /// The custom `Debug` impl should at least print the struct name.
    #[test]
    fn debug_format_works() {
        let model = ParallelSGBT::new(default_config());
        let debug_str = format!("{:?}", model);
        assert!(
            debug_str.contains("ParallelSGBT"),
            "debug output should contain 'ParallelSGBT', got: {}",
            debug_str,
        );
    }
556
557 #[test]
561 fn accessors_return_expected_values() {
562 let config = default_config();
563 let n = config.n_steps;
564 let model = ParallelSGBT::new(config);
565
566 assert_eq!(model.n_steps(), n);
567 assert_eq!(model.n_trees(), n); assert_eq!(model.total_leaves(), n); assert_eq!(model.n_samples_seen(), 0);
570 assert!(!model.is_initialized());
571 }
572
    /// `predict_batch` must agree with per-row `predict` on the same model.
    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = ParallelSGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..50 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.5));
        }

        let matrix = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![0.0, 0.0, 0.0],
        ];
        let batch_preds = model.predict_batch(&matrix);

        for (feats, batch_pred) in matrix.iter().zip(batch_preds.iter()) {
            let single_pred = model.predict(feats);
            assert!(
                (single_pred - batch_pred).abs() < 1e-12,
                "batch and single predictions should match",
            );
        }
    }
599
    /// If any splits occurred, importances must be non-negative and sum
    /// to 1.0 (the empty case is legitimately skipped).
    #[test]
    fn feature_importances_normalized() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        // xorshift64-style PRNG; the target depends on both features so both
        // can plausibly earn split gain.
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 3.0 * x1 - x2;
            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let importances = model.feature_importances();
        if !importances.is_empty() {
            let sum: f64 = importances.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-8,
                "importances should sum to 1.0, got {}",
                sum,
            );
            for &v in &importances {
                assert!(v >= 0.0, "importances should be non-negative");
            }
        }
    }
643
    /// The `Skip` variant (k = 3) must run cleanly through the parallel
    /// step-update path and produce finite predictions.
    #[test]
    fn skip_variant_works_with_parallel() {
        use crate::ensemble::variants::SGBTVariant;

        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .variant(SGBTVariant::Skip { k: 3 })
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        assert_eq!(model.n_samples_seen(), 100);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }

    /// The `MultipleIterations` variant must run cleanly through the
    /// parallel step-update path and produce finite predictions.
    #[test]
    fn mi_variant_works_with_parallel() {
        use crate::ensemble::variants::SGBTVariant;

        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .variant(SGBTVariant::MultipleIterations { multiplier: 2.0 })
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        assert_eq!(model.n_samples_seen(), 100);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }
698
    /// `predict_proba` is a thin alias for `predict_transformed`; both must
    /// return the same value on the same input.
    #[test]
    fn predict_proba_equals_predict_transformed() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        let feats = [1.0, 2.0];
        let transformed = model.predict_transformed(&feats);
        let proba = model.predict_proba(&feats);
        assert!(
            (transformed - proba).abs() < 1e-12,
            "predict_proba and predict_transformed should be identical",
        );
    }
717}