1use alloc::boxed::Box;
8use alloc::string::String;
9use alloc::vec::Vec;
10
11use crate::drift::adwin::Adwin;
12use crate::drift::ddm::Ddm;
13use crate::drift::pht::PageHinkleyTest;
14use crate::drift::DriftDetector;
15use crate::ensemble::variants::SGBTVariant;
16use crate::error::Result;
17use crate::tree::leaf_model::LeafModelType;
18
19mod display;
20mod tree_config_helper;
21mod validation;
22
23pub(crate) use tree_config_helper::build_tree_config;
24
25pub use crate::feature::FeatureType;
26
/// Strategy selector for how the ensemble obtains its scale estimate.
///
/// `Empirical` is the default variant. NOTE(review): the exact semantics of
/// each mode live in the code that reads `SGBTConfig::scale_mode`; the
/// `SGBTConfig::empirical_sigma_alpha` field appears to parameterize the
/// `Empirical` mode — confirm there.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(
    feature = "_serde_support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum ScaleMode {
    /// Default mode (presumably an empirically tracked scale — confirm).
    #[default]
    Empirical,
    /// Alternative mode (presumably scale derived from a tree chain — confirm).
    TreeChain,
}
50
/// Chooses a concept-drift detector and carries the parameters used to build it.
///
/// A live detector is produced from this description by
/// `DriftDetectorType::create`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(
    feature = "_serde_support",
    derive(serde::Serialize, serde::Deserialize)
)]
#[non_exhaustive]
pub enum DriftDetectorType {
    /// Page-Hinkley test; both fields are forwarded to
    /// `PageHinkleyTest::with_params(delta, lambda)`.
    PageHinkley {
        /// First argument of `PageHinkleyTest::with_params`.
        delta: f64,
        /// Second argument of `PageHinkleyTest::with_params`.
        lambda: f64,
    },
    /// ADWIN detector, built with `Adwin::with_delta(delta)`.
    Adwin {
        /// Argument of `Adwin::with_delta`.
        delta: f64,
    },
    /// DDM detector, built with
    /// `Ddm::with_params(warning_level, drift_level, min_instances)`.
    Ddm {
        /// First argument of `Ddm::with_params`.
        warning_level: f64,
        /// Second argument of `Ddm::with_params`.
        drift_level: f64,
        /// Third argument of `Ddm::with_params`.
        min_instances: u64,
    },
}

impl Default for DriftDetectorType {
    /// Page-Hinkley with `delta = 0.005` and `lambda = 50.0`.
    fn default() -> Self {
        let (delta, lambda) = (0.005, 50.0);
        Self::PageHinkley { delta, lambda }
    }
}
93
94impl DriftDetectorType {
95 pub fn create(&self) -> Box<dyn DriftDetector> {
97 match self {
98 Self::PageHinkley { delta, lambda } => {
99 Box::new(PageHinkleyTest::with_params(*delta, *lambda))
100 }
101 Self::Adwin { delta } => Box::new(Adwin::with_delta(*delta)),
102 Self::Ddm {
103 warning_level,
104 drift_level,
105 min_instances,
106 } => Box::new(Ddm::with_params(
107 *warning_level,
108 *drift_level,
109 *min_instances,
110 )),
111 }
112 }
113}
114
115#[derive(Debug, Clone, PartialEq)]
119#[cfg_attr(
120 feature = "_serde_support",
121 derive(serde::Serialize, serde::Deserialize)
122)]
123pub struct SGBTConfig {
124 pub n_steps: usize,
126 pub learning_rate: f64,
128 pub feature_subsample_rate: f64,
130 pub max_depth: usize,
132 pub n_bins: usize,
134 pub lambda: f64,
136 pub gamma: f64,
138 pub grace_period: usize,
140 pub delta: f64,
142 pub drift_detector: DriftDetectorType,
144 pub variant: SGBTVariant,
146 pub seed: u64,
148 pub initial_target_count: usize,
150
151 #[cfg_attr(feature = "_serde_support", serde(default))]
153 pub leaf_half_life: Option<usize>,
154
155 #[cfg_attr(feature = "_serde_support", serde(default))]
157 pub max_tree_samples: Option<u64>,
158
159 #[cfg_attr(feature = "_serde_support", serde(default))]
162 pub adaptive_mts: Option<(u64, f64)>,
163
164 #[cfg_attr(feature = "_serde_support", serde(default))]
166 pub adaptive_mts_floor: f64,
167
168 #[cfg_attr(feature = "_serde_support", serde(default))]
170 pub proactive_prune_interval: Option<u64>,
171
172 #[cfg_attr(feature = "_serde_support", serde(default))]
174 pub split_reeval_interval: Option<usize>,
175
176 #[cfg_attr(feature = "_serde_support", serde(default))]
178 pub feature_names: Option<Vec<String>>,
179
180 #[cfg_attr(feature = "_serde_support", serde(default))]
182 pub feature_types: Option<Vec<FeatureType>>,
183
184 #[cfg_attr(feature = "_serde_support", serde(default))]
186 pub gradient_clip_sigma: Option<f64>,
187
188 #[cfg_attr(feature = "_serde_support", serde(default))]
190 pub monotone_constraints: Option<Vec<i8>>,
191
192 #[cfg_attr(feature = "_serde_support", serde(default))]
194 pub quality_prune_alpha: Option<f64>,
195
196 #[cfg_attr(
198 feature = "_serde_support",
199 serde(default = "default_quality_prune_threshold")
200 )]
201 pub quality_prune_threshold: f64,
202
203 #[cfg_attr(
205 feature = "_serde_support",
206 serde(default = "default_quality_prune_patience")
207 )]
208 pub quality_prune_patience: u64,
209
210 #[cfg_attr(feature = "_serde_support", serde(default))]
212 pub error_weight_alpha: Option<f64>,
213
214 #[cfg_attr(feature = "_serde_support", serde(default))]
216 pub uncertainty_modulated_lr: bool,
217
218 #[cfg_attr(feature = "_serde_support", serde(default))]
220 pub scale_mode: ScaleMode,
221
222 #[cfg_attr(
224 feature = "_serde_support",
225 serde(default = "default_empirical_sigma_alpha")
226 )]
227 pub empirical_sigma_alpha: f64,
228
229 #[cfg_attr(feature = "_serde_support", serde(default))]
231 pub max_leaf_output: Option<f64>,
232
233 #[cfg_attr(feature = "_serde_support", serde(default))]
235 pub adaptive_leaf_bound: Option<f64>,
236
237 #[cfg_attr(feature = "_serde_support", serde(default))]
239 pub adaptive_depth: Option<f64>,
240
241 #[cfg_attr(feature = "_serde_support", serde(default))]
243 pub min_hessian_sum: Option<f64>,
244
245 #[cfg_attr(feature = "_serde_support", serde(default))]
247 pub huber_k: Option<f64>,
248
249 #[cfg_attr(feature = "_serde_support", serde(default))]
251 pub shadow_warmup: Option<usize>,
252
253 #[cfg_attr(feature = "_serde_support", serde(default))]
255 pub leaf_model_type: LeafModelType,
256
257 #[cfg_attr(feature = "_serde_support", serde(default))]
259 pub packed_refresh_interval: u64,
260
261 #[cfg_attr(feature = "_serde_support", serde(default))]
263 pub hoeffding_r: Option<f64>,
264}
265
/// Serde fallback for `SGBTConfig::empirical_sigma_alpha` when the field is
/// missing from serialized input.
///
/// Reads the value from `SGBTConfig::default()` instead of repeating the
/// literal, so the deserialization default cannot diverge from `Default`.
#[cfg(feature = "_serde_support")]
fn default_empirical_sigma_alpha() -> f64 {
    SGBTConfig::default().empirical_sigma_alpha
}

/// Serde fallback for `SGBTConfig::quality_prune_threshold` when the field
/// is missing from serialized input.
///
/// Reads the value from `SGBTConfig::default()` so the two stay in sync.
#[cfg(feature = "_serde_support")]
fn default_quality_prune_threshold() -> f64 {
    SGBTConfig::default().quality_prune_threshold
}

/// Serde fallback for `SGBTConfig::quality_prune_patience` when the field
/// is missing from serialized input.
///
/// Reads the value from `SGBTConfig::default()` so the two stay in sync.
#[cfg(feature = "_serde_support")]
fn default_quality_prune_patience() -> u64 {
    SGBTConfig::default().quality_prune_patience
}
280
281impl Default for SGBTConfig {
282 fn default() -> Self {
283 Self {
284 n_steps: 100,
285 learning_rate: 0.0125,
286 feature_subsample_rate: 0.75,
287 max_depth: 6,
288 n_bins: 64,
289 lambda: 1.0,
290 gamma: 0.0,
291 grace_period: 200,
292 delta: 1e-7,
293 drift_detector: DriftDetectorType::default(),
294 variant: SGBTVariant::default(),
295 seed: 0xDEAD_BEEF_CAFE_4242,
296 initial_target_count: 50,
297 leaf_half_life: None,
298 max_tree_samples: None,
299 adaptive_mts: None,
300 adaptive_mts_floor: 0.0,
301 proactive_prune_interval: None,
302 split_reeval_interval: None,
303 feature_names: None,
304 feature_types: None,
305 gradient_clip_sigma: None,
306 monotone_constraints: None,
307 quality_prune_alpha: None,
308 quality_prune_threshold: 1e-6,
309 quality_prune_patience: 500,
310 error_weight_alpha: None,
311 uncertainty_modulated_lr: false,
312 scale_mode: ScaleMode::default(),
313 empirical_sigma_alpha: 0.01,
314 max_leaf_output: None,
315 adaptive_leaf_bound: None,
316 adaptive_depth: None,
317 min_hessian_sum: None,
318 huber_k: None,
319 shadow_warmup: None,
320 leaf_model_type: LeafModelType::default(),
321 packed_refresh_interval: 0,
322 hoeffding_r: None,
323 }
324 }
325}
326
327impl SGBTConfig {
328 pub fn builder() -> SGBTConfigBuilder {
330 SGBTConfigBuilder::default()
331 }
332}
333
334#[derive(Debug, Clone, Default)]
336pub struct SGBTConfigBuilder {
337 config: SGBTConfig,
338}
339
340impl SGBTConfigBuilder {
341 pub fn n_steps(mut self, n: usize) -> Self {
343 self.config.n_steps = n;
344 self
345 }
346
347 pub fn learning_rate(mut self, lr: f64) -> Self {
349 self.config.learning_rate = lr;
350 self
351 }
352
353 pub fn feature_subsample_rate(mut self, rate: f64) -> Self {
355 self.config.feature_subsample_rate = rate;
356 self
357 }
358
359 pub fn max_depth(mut self, depth: usize) -> Self {
361 self.config.max_depth = depth;
362 self
363 }
364
365 pub fn n_bins(mut self, bins: usize) -> Self {
367 self.config.n_bins = bins;
368 self
369 }
370
371 pub fn lambda(mut self, l: f64) -> Self {
373 self.config.lambda = l;
374 self
375 }
376
377 pub fn gamma(mut self, g: f64) -> Self {
379 self.config.gamma = g;
380 self
381 }
382
383 pub fn grace_period(mut self, gp: usize) -> Self {
385 self.config.grace_period = gp;
386 self
387 }
388
389 pub fn delta(mut self, d: f64) -> Self {
391 self.config.delta = d;
392 self
393 }
394
395 pub fn drift_detector(mut self, dt: DriftDetectorType) -> Self {
397 self.config.drift_detector = dt;
398 self
399 }
400
401 pub fn variant(mut self, v: SGBTVariant) -> Self {
403 self.config.variant = v;
404 self
405 }
406
407 pub fn seed(mut self, seed: u64) -> Self {
409 self.config.seed = seed;
410 self
411 }
412
413 pub fn initial_target_count(mut self, count: usize) -> Self {
415 self.config.initial_target_count = count;
416 self
417 }
418
419 pub fn leaf_half_life(mut self, n: usize) -> Self {
421 self.config.leaf_half_life = Some(n);
422 self
423 }
424
425 pub fn max_tree_samples(mut self, n: u64) -> Self {
427 self.config.max_tree_samples = Some(n);
428 self
429 }
430
431 pub fn adaptive_mts(mut self, base_mts: u64, k: f64) -> Self {
433 self.config.adaptive_mts = Some((base_mts, k));
434 self
435 }
436
437 pub fn adaptive_mts_floor(mut self, fraction: f64) -> Self {
439 self.config.adaptive_mts_floor = fraction;
440 self
441 }
442
443 pub fn proactive_prune_interval(mut self, interval: u64) -> Self {
445 self.config.proactive_prune_interval = Some(interval);
446 self
447 }
448
449 pub fn split_reeval_interval(mut self, n: usize) -> Self {
451 self.config.split_reeval_interval = Some(n);
452 self
453 }
454
455 pub fn feature_names(mut self, names: Vec<String>) -> Self {
457 self.config.feature_names = Some(names);
458 self
459 }
460
461 pub fn feature_types(mut self, types: Vec<FeatureType>) -> Self {
463 self.config.feature_types = Some(types);
464 self
465 }
466
467 pub fn gradient_clip_sigma(mut self, sigma: f64) -> Self {
469 self.config.gradient_clip_sigma = Some(sigma);
470 self
471 }
472
473 pub fn monotone_constraints(mut self, constraints: Vec<i8>) -> Self {
475 self.config.monotone_constraints = Some(constraints);
476 self
477 }
478
479 pub fn quality_prune_alpha(mut self, alpha: f64) -> Self {
481 self.config.quality_prune_alpha = Some(alpha);
482 self
483 }
484
485 pub fn quality_prune_threshold(mut self, threshold: f64) -> Self {
487 self.config.quality_prune_threshold = threshold;
488 self
489 }
490
491 pub fn quality_prune_patience(mut self, patience: u64) -> Self {
493 self.config.quality_prune_patience = patience;
494 self
495 }
496
497 pub fn error_weight_alpha(mut self, alpha: f64) -> Self {
499 self.config.error_weight_alpha = Some(alpha);
500 self
501 }
502
503 pub fn uncertainty_modulated_lr(mut self, enabled: bool) -> Self {
505 self.config.uncertainty_modulated_lr = enabled;
506 self
507 }
508
509 pub fn scale_mode(mut self, mode: ScaleMode) -> Self {
511 self.config.scale_mode = mode;
512 self
513 }
514
515 pub fn empirical_sigma_alpha(mut self, alpha: f64) -> Self {
517 self.config.empirical_sigma_alpha = alpha;
518 self
519 }
520
521 pub fn max_leaf_output(mut self, max: f64) -> Self {
523 self.config.max_leaf_output = Some(max);
524 self
525 }
526
527 pub fn adaptive_leaf_bound(mut self, k: f64) -> Self {
529 self.config.adaptive_leaf_bound = Some(k);
530 self
531 }
532
533 pub fn adaptive_depth(mut self, factor: f64) -> Self {
535 self.config.adaptive_depth = Some(factor);
536 self
537 }
538
539 pub fn min_hessian_sum(mut self, min_h: f64) -> Self {
541 self.config.min_hessian_sum = Some(min_h);
542 self
543 }
544
545 pub fn huber_k(mut self, k: f64) -> Self {
547 self.config.huber_k = Some(k);
548 self
549 }
550
551 pub fn shadow_warmup(mut self, warmup: usize) -> Self {
553 self.config.shadow_warmup = Some(warmup);
554 self
555 }
556
557 pub fn leaf_model_type(mut self, lmt: LeafModelType) -> Self {
559 self.config.leaf_model_type = lmt;
560 self
561 }
562
563 pub fn packed_refresh_interval(mut self, interval: u64) -> Self {
565 self.config.packed_refresh_interval = interval;
566 self
567 }
568
569 pub fn hoeffding_r(mut self, r: f64) -> Self {
571 self.config.hoeffding_r = Some(r);
572 self
573 }
574
575 pub fn build(self) -> Result<SGBTConfig> {
577 validation::validate_and_build(self.config)
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// The defaults must be internally sensible AND pass the validator;
    /// the builder starts from `SGBTConfig::default()` and `build()` validates.
    #[test]
    fn default_config_valid() {
        let cfg = SGBTConfig::default();
        assert_eq!(cfg.n_steps, 100);
        assert_eq!(cfg.learning_rate, 0.0125);

        let built = SGBTConfig::builder()
            .build()
            .expect("SGBTConfig::default() must pass validation");
        assert_eq!(built.n_steps, cfg.n_steps);
        assert_eq!(built.learning_rate, cfg.learning_rate);
    }

    /// Guards against the serde fallback values drifting away from `Default`.
    #[cfg(feature = "_serde_support")]
    #[test]
    fn serde_field_defaults_match_default_impl() {
        let cfg = SGBTConfig::default();
        assert_eq!(default_empirical_sigma_alpha(), cfg.empirical_sigma_alpha);
        assert_eq!(
            default_quality_prune_threshold(),
            cfg.quality_prune_threshold
        );
        assert_eq!(default_quality_prune_patience(), cfg.quality_prune_patience);
    }

    #[test]
    fn builder_basic() {
        let cfg = SGBTConfig::builder()
            .n_steps(50)
            .learning_rate(0.05)
            .build()
            .unwrap();
        assert_eq!(cfg.n_steps, 50);
        assert_eq!(cfg.learning_rate, 0.05);
    }

    #[test]
    fn validation_rejects_zero_n_steps() {
        let result = SGBTConfig::builder().n_steps(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_valid_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(0.1).build();
        assert!(result.is_ok());
    }

    #[test]
    fn validation_rejects_zero_learning_rate() {
        let result = SGBTConfig::builder().learning_rate(0.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_rejects_learning_rate_above_one() {
        let result = SGBTConfig::builder().learning_rate(1.5).build();
        assert!(result.is_err());
    }

    #[test]
    fn validation_accepts_learning_rate_one() {
        let result = SGBTConfig::builder().learning_rate(1.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn drift_detector_type_create() {
        let dt = DriftDetectorType::PageHinkley {
            delta: 0.005,
            lambda: 50.0,
        };
        let mut detector = dt.create();
        for _ in 0..500 {
            detector.update(1.0);
        }
        let mut drifted = false;
        for _ in 0..500 {
            if detector.update(10.0) == crate::drift::DriftSignal::Drift {
                drifted = true;
                break;
            }
        }
        assert!(drifted);
    }

    #[test]
    fn boundary_n_bins_two_accepted() {
        let result = SGBTConfig::builder().n_bins(2).build();
        assert!(result.is_ok());
    }

    #[test]
    fn boundary_grace_period_one_accepted() {
        let result = SGBTConfig::builder().grace_period(1).build();
        assert!(result.is_ok());
    }

    #[test]
    fn feature_names_accepted() {
        let cfg = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "spread".into()])
            .build()
            .unwrap();
        assert_eq!(
            cfg.feature_names.as_ref().unwrap(),
            &["price", "volume", "spread"]
        );
    }

    #[test]
    fn feature_names_rejects_duplicates() {
        let result = SGBTConfig::builder()
            .feature_names(vec!["price".into(), "volume".into(), "price".into()])
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn feature_names_empty_vec_accepted() {
        let cfg = SGBTConfig::builder().feature_names(vec![]).build().unwrap();
        assert!(cfg.feature_names.unwrap().is_empty());
    }

    #[test]
    fn builder_adaptive_leaf_bound() {
        let cfg = SGBTConfig::builder()
            .adaptive_leaf_bound(3.0)
            .build()
            .unwrap();
        assert_eq!(cfg.adaptive_leaf_bound, Some(3.0));
    }

    #[test]
    fn validation_rejects_zero_adaptive_leaf_bound() {
        let result = SGBTConfig::builder().adaptive_leaf_bound(0.0).build();
        assert!(result.is_err());
    }
}