irithyll_core/ensemble/parallel.rs

//! Parallel SGBT training with delayed gradient updates.
//!
//! Instead of propagating gradients sequentially through the boosting steps,
//! this module computes a single gradient from the full ensemble prediction
//! and hands it to every step simultaneously. Each step then trains
//! independently on that shared gradient, enabling rayon-based parallelism
//! across steps.
//!
//! # Algorithm
//!
//! For each incoming sample `(x, y)`:
//! 1. Compute the full ensemble prediction: `F(x) = base + lr * sum tree_s(x)`
//! 2. Compute gradient `g = loss.gradient(y, F(x))` and hessian `h = loss.hessian(y, F(x))`
//! 3. Pre-compute `train_count` for each step (sequential, uses RNG state)
//! 4. Train ALL steps in parallel with the same `(x, g, h)` and per-step `train_count`
//!
//! This is a "delayed gradient" approach: all steps see the same gradient
//! computed from the full ensemble prediction, rather than the sequential
//! rolling prediction used in standard SGBT. This trades a small amount of
//! gradient freshness for parallelism across boosting steps.
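//!
//! As a concrete instance: with squared loss, `g = F(x) - y` and `h = 1`,
//! so every step fits the same whole-ensemble residual at once, instead of
//! the residual left over by the steps before it:
//!
//! ```text
//! F(x) = base + lr * (tree_0(x) + tree_1(x) + ... + tree_{S-1}(x))
//! g    = F(x) - y    // one shared gradient for all S steps
//! h    = 1.0         // squared loss has constant curvature
//! ```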
//!
//! Requires the `parallel` feature flag for rayon-based parallelism. Without
//! the feature, the module still compiles and works correctly using sequential
//! iteration (identical results, just no multi-core speedup).
//!
//! # Trade-offs
//!
//! - **Pro:** Near-linear speedup with the number of cores for large ensembles.
//! - **Con:** Gradient staleness may slow convergence slightly; this is
//!   typically offset by a slightly higher learning rate or more training samples.
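//!
//! # Example
//!
//! A minimal regression sketch; the paths mirror the constructor example
//! below, and `Sample` is the crate's owned observation type:
//!
//! ```ignore
//! use irithyll::SGBTConfig;
//! use irithyll::Sample;
//! use irithyll::ensemble::parallel::ParallelSGBT;
//!
//! let config = SGBTConfig::builder().n_steps(10).build().unwrap();
//! let mut model = ParallelSGBT::new(config); // squared loss by default
//!
//! // Streaming regression: train one sample at a time.
//! model.train_one(&Sample::new(vec![1.0, 2.0], 3.0));
//! let y_hat = model.predict(&[1.0, 2.0]);
//! ```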

use core::fmt;

use alloc::vec;
use alloc::vec::Vec;

// `rayon` is only needed when the `parallel` feature is enabled; the
// sequential fallback in `train_one` compiles without it.
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use crate::ensemble::config::SGBTConfig;
use crate::ensemble::step::BoostingStep;
use crate::loss::squared::SquaredLoss;
use crate::loss::Loss;
use crate::sample::Observation;

/// Parallel SGBT ensemble with delayed gradient updates.
///
/// All boosting steps train concurrently using the full ensemble prediction
/// for gradient computation. Predictions remain sequential (deterministic)
/// -- only training is parallelized.
///
/// Generic over `L: Loss` so the loss function's gradient/hessian calls
/// are monomorphized (inlined) into the training loop -- no virtual dispatch.
///
/// # Differences from [`SGBT`](super::SGBT)
///
/// | Aspect | `SGBT` | `ParallelSGBT` |
/// |--------|--------|----------------|
/// | Gradient target | Rolling (step-by-step) | Full ensemble prediction |
/// | Step training | Sequential | Parallel (rayon) |
/// | Prediction | Sequential | Sequential (identical) |
/// | Convergence | Optimal | Slightly delayed |
/// | Throughput | 1x | ~Nx (N = cores) |
pub struct ParallelSGBT<L: Loss = SquaredLoss> {
    /// Configuration.
    config: SGBTConfig,
    /// Boosting steps (one tree + drift detector each).
    steps: Vec<BoostingStep>,
    /// Loss function (monomorphized -- no vtable).
    loss: L,
    /// Base prediction (initial constant, computed from the first batch of targets).
    base_prediction: f64,
    /// Whether `base_prediction` has been initialized.
    base_initialized: bool,
    /// Running collection of initial targets for computing `base_prediction`.
    initial_targets: Vec<f64>,
    /// Number of initial targets to collect before setting `base_prediction`.
    initial_target_count: usize,
    /// Total samples trained.
    samples_seen: u64,
    /// RNG state for variant skip logic.
    rng_state: u64,
    /// Pre-allocated buffer for per-step train counts (avoids heap alloc per sample).
    train_counts_buf: Vec<usize>,
}

impl<L: Loss + Clone> Clone for ParallelSGBT<L> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            steps: self.steps.clone(),
            loss: self.loss.clone(),
            base_prediction: self.base_prediction,
            base_initialized: self.base_initialized,
            initial_targets: self.initial_targets.clone(),
            initial_target_count: self.initial_target_count,
            samples_seen: self.samples_seen,
            rng_state: self.rng_state,
            train_counts_buf: self.train_counts_buf.clone(),
        }
    }
}

impl<L: Loss> fmt::Debug for ParallelSGBT<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ParallelSGBT")
            .field("n_steps", &self.steps.len())
            .field("samples_seen", &self.samples_seen)
            .field("base_prediction", &self.base_prediction)
            .field("base_initialized", &self.base_initialized)
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Convenience constructor for the default loss (SquaredLoss)
// ---------------------------------------------------------------------------

impl ParallelSGBT<SquaredLoss> {
    /// Create a new parallel SGBT ensemble with squared loss (regression).
    pub fn new(config: SGBTConfig) -> Self {
        Self::with_loss(config, SquaredLoss)
    }
}

// ---------------------------------------------------------------------------
// General impl for all Loss types
// ---------------------------------------------------------------------------

impl<L: Loss> ParallelSGBT<L> {
    /// Create a new parallel SGBT ensemble with a specific loss function.
    ///
    /// The loss is stored by value (monomorphized), giving zero-cost
    /// gradient/hessian dispatch.
    ///
    /// ```ignore
    /// use irithyll::SGBTConfig;
    /// use irithyll::ensemble::parallel::ParallelSGBT;
    /// use irithyll::loss::logistic::LogisticLoss;
    ///
    /// let config = SGBTConfig::builder().n_steps(10).build().unwrap();
    /// let model = ParallelSGBT::with_loss(config, LogisticLoss);
    /// ```
    pub fn with_loss(config: SGBTConfig, loss: L) -> Self {
        let leaf_decay_alpha = config
            .leaf_half_life
            .map(|hl| crate::math::exp(-crate::math::ln(2.0) / hl as f64));

        let tree_config = crate::ensemble::config::build_tree_config(&config)
            .leaf_decay_alpha_opt(leaf_decay_alpha);

        let max_tree_samples = config.max_tree_samples;

        let steps: Vec<BoostingStep> = (0..config.n_steps)
            .map(|i| {
                let mut tc = tree_config.clone();
                tc.seed = config.seed ^ (i as u64);
                let detector = config.drift_detector.create();
                BoostingStep::new_with_max_samples(tc, detector, max_tree_samples)
            })
            .collect();

        let seed = config.seed;
        let initial_target_count = config.initial_target_count;
        let n_steps = steps.len();
        Self {
            config,
            steps,
            loss,
            base_prediction: 0.0,
            base_initialized: false,
            initial_targets: Vec::new(),
            initial_target_count,
            samples_seen: 0,
            rng_state: seed,
            train_counts_buf: vec![0; n_steps],
        }
    }

    /// Train on a single observation using delayed gradient updates.
    ///
    /// Accepts any type implementing [`Observation`], including [`Sample`](crate::Sample),
    /// [`SampleRef`](crate::SampleRef), or tuples like `(&[f64], f64)`.
    ///
    /// All boosting steps receive the same gradient/hessian computed from
    /// the full ensemble prediction, then train in parallel (when the
    /// `parallel` feature is enabled).
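    ///
    /// A quick sketch of two accepted input shapes (assuming `model` was
    /// built as in [`with_loss`](Self::with_loss)):
    ///
    /// ```ignore
    /// use irithyll::Sample;
    ///
    /// model.train_one(&Sample::new(vec![1.0, 2.0], 5.0)); // owned sample
    /// let features: &[f64] = &[1.0, 2.0];
    /// model.train_one(&(features, 5.0));                  // (&[f64], f64) tuple
    /// ```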
    pub fn train_one(&mut self, sample: &impl Observation) {
        self.samples_seen += 1;
        let target = sample.target();
        let features = sample.features();

        // Initialize base prediction from the first few targets.
        if !self.base_initialized {
            self.initial_targets.push(target);
            if self.initial_targets.len() >= self.initial_target_count {
                self.base_prediction = self.loss.initial_prediction(&self.initial_targets);
                self.base_initialized = true;
                self.initial_targets.clear();
                self.initial_targets.shrink_to_fit();
            }
        }

        // Compute the FULL ensemble prediction (same as predict()).
        let full_pred = self.predict(features);

        // Compute gradient and hessian from the full ensemble prediction.
        // All steps will use these same values (delayed gradient approach).
        let gradient = self.loss.gradient(target, full_pred);
        let hessian = self.loss.hessian(target, full_pred);

        // Pre-compute train_count for each step sequentially into the
        // pre-allocated buffer (zero heap alloc per sample).
        // The RNG state is sequential (xorshift), so we must advance it
        // in order before entering the parallel section.
        for tc in self.train_counts_buf.iter_mut() {
            *tc = self
                .config
                .variant
                .train_count(hessian, &mut self.rng_state);
        }

        // Train all steps with the same gradient/hessian.
        // When the `parallel` feature is enabled, use rayon for concurrency.
        // Otherwise, fall back to sequential iteration.
        #[cfg(feature = "parallel")]
        {
            self.steps.par_iter_mut().enumerate().for_each(|(i, step)| {
                step.train_and_predict(features, gradient, hessian, self.train_counts_buf[i]);
            });
        }

        #[cfg(not(feature = "parallel"))]
        {
            for (i, step) in self.steps.iter_mut().enumerate() {
                step.train_and_predict(features, gradient, hessian, self.train_counts_buf[i]);
            }
        }
    }

    /// Train on a batch of observations.
    pub fn train_batch<O: Observation>(&mut self, samples: &[O]) {
        for sample in samples {
            self.train_one(sample);
        }
    }

    /// Predict the raw output for a feature vector.
    ///
    /// Prediction is always sequential and deterministic, regardless of
    /// whether training uses parallelism.
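    ///
    /// Equivalent to the module-level formula:
    ///
    /// ```text
    /// F(x) = base_prediction + learning_rate * sum_s step_s.predict(x)
    /// ```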
    pub fn predict(&self, features: &[f64]) -> f64 {
        let mut pred = self.base_prediction;
        for step in &self.steps {
            pred += self.config.learning_rate * step.predict(features);
        }
        pred
    }

    /// Predict with the loss transform applied (e.g., sigmoid for logistic loss).
    pub fn predict_transformed(&self, features: &[f64]) -> f64 {
        self.loss.predict_transform(self.predict(features))
    }

    /// Predict probability (alias for `predict_transformed`).
    pub fn predict_proba(&self, features: &[f64]) -> f64 {
        self.predict_transformed(features)
    }

    /// Batch prediction.
    pub fn predict_batch(&self, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
        feature_matrix.iter().map(|f| self.predict(f)).collect()
    }

    /// Number of boosting steps.
    pub fn n_steps(&self) -> usize {
        self.steps.len()
    }

    /// Total trees (active + alternates).
    pub fn n_trees(&self) -> usize {
        self.steps.len() + self.steps.iter().filter(|s| s.has_alternate()).count()
    }

    /// Total leaves across all active trees.
    pub fn total_leaves(&self) -> usize {
        self.steps.iter().map(|s| s.n_leaves()).sum()
    }

    /// Total samples trained.
    pub fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    /// The current base prediction.
    pub fn base_prediction(&self) -> f64 {
        self.base_prediction
    }

    /// Whether the base prediction has been initialized.
    pub fn is_initialized(&self) -> bool {
        self.base_initialized
    }

    /// Access the configuration.
    pub fn config(&self) -> &SGBTConfig {
        &self.config
    }

    /// Immutable access to the loss function.
    pub fn loss(&self) -> &L {
        &self.loss
    }

    /// Feature importances based on accumulated split gains across all trees.
    ///
    /// Returns normalized importances (sum to 1.0) indexed by feature.
    /// Returns an empty `Vec` if no splits have occurred yet.
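    ///
    /// A usage sketch (the ranking logic is illustrative, not part of the API):
    ///
    /// ```ignore
    /// let importances = model.feature_importances();
    /// // Index of the most important feature, if any splits exist yet.
    /// let top = importances
    ///     .iter()
    ///     .enumerate()
    ///     .max_by(|a, b| a.1.total_cmp(b.1))
    ///     .map(|(i, _)| i);
    /// ```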
    pub fn feature_importances(&self) -> Vec<f64> {
        let mut totals: Vec<f64> = Vec::new();
        for step in &self.steps {
            let gains = step.slot().split_gains();
            if totals.is_empty() && !gains.is_empty() {
                totals.resize(gains.len(), 0.0);
            }
            for (i, &g) in gains.iter().enumerate() {
                if i < totals.len() {
                    totals[i] += g;
                }
            }
        }

        let sum: f64 = totals.iter().sum();
        if sum > 0.0 {
            totals.iter_mut().for_each(|v| *v /= sum);
        }
        totals
    }

    /// Reset the ensemble to its initial state.
    pub fn reset(&mut self) {
        for step in &mut self.steps {
            step.reset();
        }
        self.base_prediction = 0.0;
        self.base_initialized = false;
        self.initial_targets.clear();
        self.samples_seen = 0;
        self.rng_state = self.config.seed;
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;
    use alloc::format;
    use alloc::vec;
    use alloc::vec::Vec;

    fn default_config() -> SGBTConfig {
        SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .build()
            .unwrap()
    }

    // -------------------------------------------------------------------
    // 1. Fresh model predicts zero.
    // -------------------------------------------------------------------
    #[test]
    fn new_model_predicts_zero() {
        let model = ParallelSGBT::new(default_config());
        let pred = model.predict(&[1.0, 2.0, 3.0]);
        assert!(pred.abs() < 1e-12);
    }

    // -------------------------------------------------------------------
    // 2. train_one does not panic.
    // -------------------------------------------------------------------
    #[test]
    fn train_one_does_not_panic() {
        let mut model = ParallelSGBT::new(default_config());
        model.train_one(&Sample::new(vec![1.0, 2.0, 3.0], 5.0));
        assert_eq!(model.n_samples_seen(), 1);
    }

    // -------------------------------------------------------------------
    // 3. Prediction changes after training.
    // -------------------------------------------------------------------
    #[test]
    fn prediction_changes_after_training() {
        let mut model = ParallelSGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..100 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.1));
        }
        let pred = model.predict(&features);
        assert!(pred.is_finite());
    }

    // -------------------------------------------------------------------
    // 4. Linear signal RMSE improves over time.
    //
    // NOTE: The delayed gradient approach converges more slowly than
    // sequential SGBT because all steps see the same (slightly stale)
    // gradient. We compensate with a higher learning rate and more
    // training samples, and widen the measurement windows.
    // -------------------------------------------------------------------
    #[test]
    fn linear_signal_rmse_improves() {
        let config = SGBTConfig::builder()
            .n_steps(20)
            .learning_rate(0.15)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        let mut rng: u64 = 12345;
        let mut early_errors = Vec::new();
        let mut late_errors = Vec::new();

        for i in 0..1000 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 2.0 * x1 + 3.0 * x2;

            let pred = model.predict(&[x1, x2]);
            let error = (pred - target).powi(2);

            if (100..300).contains(&i) {
                early_errors.push(error);
            }
            if i >= 800 {
                late_errors.push(error);
            }

            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let early_rmse = (early_errors.iter().sum::<f64>() / early_errors.len() as f64).sqrt();
        let late_rmse = (late_errors.iter().sum::<f64>() / late_errors.len() as f64).sqrt();

        assert!(
            late_rmse < early_rmse,
            "RMSE should decrease: early={:.4}, late={:.4}",
            early_rmse,
            late_rmse
        );
    }

    // -------------------------------------------------------------------
    // 5. train_batch is equivalent to sequential train_one calls.
    // -------------------------------------------------------------------
    #[test]
    fn train_batch_equivalent_to_sequential() {
        let config = default_config();
        let mut model_seq = ParallelSGBT::new(config.clone());
        let mut model_batch = ParallelSGBT::new(config);

        let samples: Vec<Sample> = (0..20)
            .map(|i| {
                let x = i as f64 * 0.5;
                Sample::new(vec![x, x * 2.0], x * 3.0)
            })
            .collect();

        for s in &samples {
            model_seq.train_one(s);
        }
        model_batch.train_batch(&samples);

        let pred_seq = model_seq.predict(&[1.0, 2.0]);
        let pred_batch = model_batch.predict(&[1.0, 2.0]);

        assert!(
            (pred_seq - pred_batch).abs() < 1e-10,
            "seq={}, batch={}",
            pred_seq,
            pred_batch
        );
    }

    // -------------------------------------------------------------------
    // 6. Reset returns to initial state.
    // -------------------------------------------------------------------
    #[test]
    fn reset_returns_to_initial() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }
        model.reset();
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
        assert!(model.predict(&[1.0, 2.0]).abs() < 1e-12);
    }

    // -------------------------------------------------------------------
    // 7. Base prediction initializes correctly.
    // -------------------------------------------------------------------
    #[test]
    fn base_prediction_initializes() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0], i as f64 + 100.0));
        }
        assert!(model.is_initialized());
        let expected = (100.0 + 149.0) / 2.0;
        assert!((model.base_prediction() - expected).abs() < 1.0);
    }

    // -------------------------------------------------------------------
    // 8. with_loss uses a custom loss function.
    // -------------------------------------------------------------------
    #[test]
    fn with_loss_uses_custom_loss() {
        use crate::loss::logistic::LogisticLoss;
        let model = ParallelSGBT::with_loss(default_config(), LogisticLoss);
        let pred = model.predict_transformed(&[1.0, 2.0]);
        assert!(
            (pred - 0.5).abs() < 1e-6,
            "sigmoid(0) should be 0.5, got {}",
            pred
        );
    }

    // -------------------------------------------------------------------
    // 9. Debug formatting works.
    // -------------------------------------------------------------------
    #[test]
    fn debug_format_works() {
        let model = ParallelSGBT::new(default_config());
        let debug_str = format!("{:?}", model);
        assert!(
            debug_str.contains("ParallelSGBT"),
            "debug output should contain 'ParallelSGBT', got: {}",
            debug_str,
        );
    }

    // -------------------------------------------------------------------
    // 10. Accessors return expected values.
    // -------------------------------------------------------------------
    #[test]
    fn accessors_return_expected_values() {
        let config = default_config();
        let n = config.n_steps;
        let model = ParallelSGBT::new(config);

        assert_eq!(model.n_steps(), n);
        assert_eq!(model.n_trees(), n); // no alternates initially
        assert_eq!(model.total_leaves(), n); // 1 leaf per tree initially
        assert_eq!(model.n_samples_seen(), 0);
        assert!(!model.is_initialized());
    }

    // -------------------------------------------------------------------
    // 11. Batch prediction matches individual predictions.
    // -------------------------------------------------------------------
    #[test]
    fn batch_prediction_matches_individual() {
        let mut model = ParallelSGBT::new(default_config());
        let features = vec![1.0, 2.0, 3.0];
        for i in 0..50 {
            model.train_one(&Sample::new(features.clone(), (i as f64) * 0.5));
        }

        let matrix = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![0.0, 0.0, 0.0],
        ];
        let batch_preds = model.predict_batch(&matrix);

        for (feats, batch_pred) in matrix.iter().zip(batch_preds.iter()) {
            let single_pred = model.predict(feats);
            assert!(
                (single_pred - batch_pred).abs() < 1e-12,
                "batch and single predictions should match",
            );
        }
    }

    // -------------------------------------------------------------------
    // 12. Feature importances are normalized.
    // -------------------------------------------------------------------
    #[test]
    fn feature_importances_normalized() {
        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(10)
            .max_depth(3)
            .n_bins(16)
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        // Train enough for splits to occur.
        let mut rng: u64 = 42;
        for _ in 0..200 {
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x1 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            rng ^= rng << 13;
            rng ^= rng >> 7;
            rng ^= rng << 17;
            let x2 = (rng as f64 / u64::MAX as f64) * 10.0 - 5.0;
            let target = 3.0 * x1 - x2;
            model.train_one(&Sample::new(vec![x1, x2], target));
        }

        let importances = model.feature_importances();
        if !importances.is_empty() {
            let sum: f64 = importances.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-8,
                "importances should sum to 1.0, got {}",
                sum,
            );
            for &v in &importances {
                assert!(v >= 0.0, "importances should be non-negative");
            }
        }
    }

    // -------------------------------------------------------------------
    // 13. Variant train_counts are pre-computed correctly (Skip variant).
    // -------------------------------------------------------------------
    #[test]
    fn skip_variant_works_with_parallel() {
        use crate::ensemble::variants::SGBTVariant;

        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .variant(SGBTVariant::Skip { k: 3 })
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        // Should not panic, and should train with some steps skipped.
        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        assert_eq!(model.n_samples_seen(), 100);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }

    // -------------------------------------------------------------------
    // 14. MI variant works with parallel training.
    // -------------------------------------------------------------------
    #[test]
    fn mi_variant_works_with_parallel() {
        use crate::ensemble::variants::SGBTVariant;

        let config = SGBTConfig::builder()
            .n_steps(10)
            .learning_rate(0.1)
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .variant(SGBTVariant::MultipleIterations { multiplier: 2.0 })
            .build()
            .unwrap();
        let mut model = ParallelSGBT::new(config);

        for i in 0..100 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        assert_eq!(model.n_samples_seen(), 100);
        let pred = model.predict(&[1.0, 2.0]);
        assert!(pred.is_finite());
    }

    // -------------------------------------------------------------------
    // 15. predict_proba and predict_transformed are equivalent.
    // -------------------------------------------------------------------
    #[test]
    fn predict_proba_equals_predict_transformed() {
        let mut model = ParallelSGBT::new(default_config());
        for i in 0..50 {
            model.train_one(&Sample::new(vec![1.0, 2.0], i as f64));
        }

        let feats = [1.0, 2.0];
        let transformed = model.predict_transformed(&feats);
        let proba = model.predict_proba(&feats);
        assert!(
            (transformed - proba).abs() < 1e-12,
            "predict_proba and predict_transformed should be identical",
        );
    }
}
717}