use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;

use crate::feature::FeatureType;
use crate::histogram::bins::LeafHistograms;
use crate::histogram::{BinEdges, BinnerKind};
use crate::math;
use crate::tree::builder::TreeConfig;
use crate::tree::leaf_model::{LeafModel, LeafModelType};
use crate::tree::node::{NodeId, TreeArena};
use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
use crate::tree::StreamingTree;

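/// Tie-breaking threshold for the Hoeffding split test: once the bound
/// `epsilon` falls below `TAU`, the best split is accepted even if its gain
/// gap to the runner-up is still within `epsilon`.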
const TAU: f64 = 0.05;

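/// Minimal xorshift64 PRNG; deterministic and never yields 0 for a non-zero
/// seed state.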
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    s
}

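/// Per-leaf training state: feature binners and gradient/hessian histograms
/// for split search, Welford statistics for gradient clipping and adaptive
/// output bounding, and an optional trainable leaf model.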
struct LeafState {
    histograms: Option<LeafHistograms>,
    binners: Vec<BinnerKind>,
    bins_ready: bool,
    grad_sum: f64,
    hess_sum: f64,
    last_reeval_count: u64,
    clip_grad_mean: f64,
    clip_grad_m2: f64,
    clip_grad_count: u64,
    output_mean: f64,
    output_m2: f64,
    output_count: u64,
    leaf_model: Option<Box<dyn LeafModel>>,
}

impl Clone for LeafState {
    fn clone(&self) -> Self {
        Self {
            histograms: self.histograms.clone(),
            binners: self.binners.clone(),
            bins_ready: self.bins_ready,
            grad_sum: self.grad_sum,
            hess_sum: self.hess_sum,
            last_reeval_count: self.last_reeval_count,
            clip_grad_mean: self.clip_grad_mean,
            clip_grad_m2: self.clip_grad_m2,
            clip_grad_count: self.clip_grad_count,
            output_mean: self.output_mean,
            output_m2: self.output_m2,
            output_count: self.output_count,
            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
        }
    }
}

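/// Clamps `gradient` to `mean ± sigma * std_dev` of the gradients seen at this
/// leaf, tracked with Welford's algorithm. The first 10 samples pass through
/// unclipped while the estimates warm up.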
#[inline]
fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
    state.clip_grad_count += 1;
    let n = state.clip_grad_count as f64;

    let delta = gradient - state.clip_grad_mean;
    state.clip_grad_mean += delta / n;
    let delta2 = gradient - state.clip_grad_mean;
    state.clip_grad_m2 += delta * delta2;

    if state.clip_grad_count < 10 {
        return gradient;
    }

    let variance = state.clip_grad_m2 / (n - 1.0);
    let std_dev = math::sqrt(variance);

    if std_dev < 1e-15 {
        return gradient;
    }

    let lo = state.clip_grad_mean - sigma * std_dev;
    let hi = state.clip_grad_mean + sigma * std_dev;
    gradient.clamp(lo, hi)
}

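/// Updates the running mean/variance of emitted leaf weights: plain Welford
/// updates by default, exponentially weighted ones when `decay_alpha` is set.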
#[inline]
fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
    state.output_count += 1;

    if let Some(alpha) = decay_alpha {
        if state.output_count == 1 {
            state.output_mean = weight;
            state.output_m2 = 0.0;
        } else {
            let diff = weight - state.output_mean;
            state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
            let diff2 = weight - state.output_mean;
            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
        }
    } else {
        let delta = weight - state.output_mean;
        state.output_mean += delta / (state.output_count as f64);
        let delta2 = weight - state.output_mean;
        state.output_m2 += delta * delta2;
    }
}

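/// Adaptive clamp for leaf outputs: `|mean| + k * std` of the observed leaf
/// weights, floored at 0.01. Returns `f64::MAX` for the first 10 observations
/// so young leaves stay unbounded.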
#[inline]
fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
    if state.output_count < 10 {
        return f64::MAX;
    }

    let variance = if decay_alpha.is_some() {
        state.output_m2.max(0.0)
    } else {
        state.output_m2 / (state.output_count as f64 - 1.0)
    };
    let std = math::sqrt(variance);

    (math::abs(state.output_mean) + k * std).max(0.01)
}

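/// Builds one binner per feature: categorical binners where the feature type
/// says so, uniform binners everywhere else.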
fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
    (0..n_features)
        .map(|i| {
            if let Some(ft) = feature_types {
                if i < ft.len() && ft[i] == FeatureType::Categorical {
                    return BinnerKind::categorical();
                }
            }
            BinnerKind::uniform()
        })
        .collect()
}

impl LeafState {
    fn new(n_features: usize) -> Self {
        Self::new_with_types(n_features, None)
    }

    fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
        let binners = make_binners(n_features, feature_types);

        Self {
            histograms: None,
            binners,
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: None,
        }
    }

    #[allow(dead_code)]
    fn with_histograms(histograms: LeafHistograms) -> Self {
        let n_features = histograms.n_features();
        let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();

        let grad_sum: f64 = histograms
            .histograms
            .first()
            .map_or(0.0, |h| h.total_gradient());
        let hess_sum: f64 = histograms
            .histograms
            .first()
            .map_or(0.0, |h| h.total_hessian());

        Self {
            histograms: Some(histograms),
            binners,
            bins_ready: true,
            grad_sum,
            hess_sum,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: None,
        }
    }
}

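/// Streaming regression tree grown with a Hoeffding-bound split test on
/// XGBoost-style gain. Nodes live in a flat `TreeArena`; every leaf carries a
/// `LeafState` with the histograms and statistics needed to decide splits
/// online.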
pub struct HoeffdingTree {
    arena: TreeArena,
    root: NodeId,
    config: TreeConfig,
    leaf_states: Vec<Option<LeafState>>,
    n_features: Option<usize>,
    samples_seen: u64,
    split_criterion: XGBoostGain,
    feature_mask: Vec<usize>,
    feature_mask_bits: Vec<u64>,
    rng_state: u64,
    split_gains: Vec<f64>,
    node_bandwidths: Vec<f64>,
}

impl HoeffdingTree {
    pub fn new(config: TreeConfig) -> Self {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let mut leaf_states = vec![None; root.0 as usize + 1];
        let root_model = match config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
        };
        leaf_states[root.0 as usize] = Some(LeafState {
            histograms: None,
            binners: Vec::new(),
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: root_model,
        });

        let seed = config.seed;
        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features: None,
            samples_seen: 0,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state: seed,
            split_gains: Vec::new(),
            node_bandwidths: Vec::new(),
        }
    }

    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
        match self.config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(
                self.config
                    .leaf_model_type
                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
            ),
        }
    }

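    /// Rebuilds a tree around an existing arena (e.g. after deserialization).
    /// Leaf histograms are not carried over, so every leaf restarts with a
    /// fresh `LeafState`; an empty arena falls back to a single root leaf.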
    pub fn from_arena(
        config: TreeConfig,
        arena: TreeArena,
        n_features: Option<usize>,
        samples_seen: u64,
        rng_state: u64,
    ) -> Self {
        let root = if arena.n_nodes() > 0 {
            NodeId(0)
        } else {
            let mut arena_mut = arena;
            let root = arena_mut.add_leaf(0);
            return Self {
                arena: arena_mut,
                root,
                config,
                leaf_states: {
                    let mut v = vec![None; root.0 as usize + 1];
                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
                    v
                },
                n_features,
                samples_seen,
                split_criterion: XGBoostGain::default(),
                feature_mask: Vec::new(),
                feature_mask_bits: Vec::new(),
                rng_state,
                split_gains: vec![0.0; n_features.unwrap_or(0)],
                node_bandwidths: Vec::new(),
            };
        };

        let nf = n_features.unwrap_or(0);
        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
        for (i, slot) in leaf_states.iter_mut().enumerate() {
            if arena.is_leaf[i] {
                *slot = Some(LeafState::new(nf));
            }
        }

        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features,
            samples_seen,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state,
            split_gains: vec![0.0; nf],
            node_bandwidths: Vec::new(),
        }
    }

    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }

    #[inline]
    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
        self.leaf_states
            .get(node.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|state| (state.grad_sum, state.hess_sum))
    }

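    /// Hard-routes a sample to its leaf: categorical nodes test the membership
    /// bitmask, numeric nodes compare against the split threshold.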
    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
        let mut current = self.root;
        while !self.arena.is_leaf(current) {
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }
        current
    }

    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            let val = if let Some(ref model) = state.leaf_model {
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            (0.0, None)
        };

        if let Some(bound) = leaf_bound {
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }

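    /// Smoothed prediction: every numeric split blends its children with
    /// weight `sigmoid((threshold - x) / bandwidth)` instead of branching.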
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }

    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }

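    /// Shrinkage-style prediction: blends the reached leaf with its parent
    /// using weight `hess / (hess + lambda)`, so thinly populated leaves lean
    /// on the parent estimate.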
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        let mut current = self.root;
        let mut parent = None;
        while !self.arena.is_leaf(current) {
            parent = Some(current);
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }

        let leaf_pred = self.leaf_prediction(current, features);

        let parent_id = match parent {
            Some(p) => p,
            None => return leaf_pred,
        };

        let parent_pred = self.leaf_prediction(parent_id, features);

        let leaf_hess = self
            .leaf_states
            .get(current.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|s| s.hess_sum)
            .unwrap_or(0.0);

        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
        alpha * leaf_pred + (1.0 - alpha) * parent_pred
    }

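    /// Within `bandwidths[feat]` of a numeric threshold, linearly interpolates
    /// between the left and right subtree predictions; outside that margin it
    /// routes hard.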
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }

    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        if !margin.is_finite() || margin <= 0.0 {
            return if features[feat_idx] <= threshold {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let dist = features[feat_idx] - threshold;

        if dist < -margin {
            self.predict_sibling_recursive(left, features, bandwidths)
        } else if dist > margin {
            self.predict_sibling_recursive(right, features, bandwidths)
        } else {
            let t = (dist + margin) / (2.0 * margin);
            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
            (1.0 - t) * left_pred + t * right_pred
        }
    }

    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
        let n = self.n_features.unwrap_or(0);
        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];

        for i in 0..self.arena.n_nodes() {
            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
                let feat_idx = self.arena.feature_idx[i] as usize;
                if feat_idx < n {
                    thresholds[feat_idx].push(self.arena.threshold[i]);
                }
            }
        }

        thresholds
    }

    fn compute_node_bandwidth(&self, node: NodeId, all_thresholds: &[Vec<f64>]) -> f64 {
        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let threshold = self.arena.get_threshold(node);

        let thresholds = if feat_idx < all_thresholds.len() {
            &all_thresholds[feat_idx]
        } else {
            return f64::INFINITY;
        };

        let below = thresholds.iter().rev().find(|&&t| t < threshold - 1e-15);
        let above = thresholds.iter().find(|&&t| t > threshold + 1e-15);

        match (below, above) {
            (Some(&b), Some(&a)) => (threshold - b).min(a - threshold),
            (Some(&b), None) => threshold - b,
            (None, Some(&a)) => a - threshold,
            (None, None) => f64::INFINITY,
        }
    }

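    /// Recomputes every internal node's soft-routing bandwidth as the distance
    /// to the nearest other split threshold on the same feature.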
    pub fn recompute_bandwidths(&mut self) {
        let n = self.arena.n_nodes();
        self.node_bandwidths.resize(n, f64::INFINITY);

        let mut all_thresholds = self.collect_split_thresholds_per_feature();
        for v in &mut all_thresholds {
            v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
        }

        for i in 0..n {
            let nid = NodeId(i as u32);
            if !self.arena.is_leaf(nid) {
                self.node_bandwidths[i] = self.compute_node_bandwidth(nid, &all_thresholds);
            } else {
                self.node_bandwidths[i] = f64::INFINITY;
            }
        }
    }

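    /// Soft-routed prediction using the per-node bandwidths maintained by
    /// `recompute_bandwidths`; nodes without a finite bandwidth fall back to a
    /// narrow sigmoid around the threshold.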
    pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
        self.predict_soft_recursive(self.root, features)
    }

    fn predict_soft_recursive(&self, node: NodeId, features: &[f64]) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_soft_recursive(left, features)
            } else {
                self.predict_soft_recursive(right, features)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let margin = self
            .node_bandwidths
            .get(node.0 as usize)
            .copied()
            .unwrap_or(f64::INFINITY);

        let left_pred = self.predict_soft_recursive(left, features);
        let right_pred = self.predict_soft_recursive(right, features);

        if !margin.is_finite() || margin <= 0.0 {
            let dist = features[feat_idx] - threshold;
            let scale = math::abs(threshold) * 0.01 + 1e-10;
            let z = (-dist / scale).clamp(-500.0, 500.0);
            let t = 1.0 / (1.0 + math::exp(z));
            return (1.0 - t) * left_pred + t * right_pred;
        }

        let dist = features[feat_idx] - threshold;
        let t = ((dist + margin) / (2.0 * margin)).clamp(0.0, 1.0);
        (1.0 - t) * left_pred + t * right_pred
    }

    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_recursive(left, features, bandwidth)
            } else {
                self.predict_smooth_recursive(right, features, bandwidth)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let z = (threshold - features[feat_idx]) / bandwidth;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

    fn predict_smooth_auto_recursive(
        &self,
        node: NodeId,
        features: &[f64],
        bandwidths: &[f64],
    ) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        if !bw.is_finite() {
            return if features[feat_idx] <= threshold {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let z = (threshold - features[feat_idx]) / bw;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

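    /// Samples the feature subset for the next split attempt: each feature is
    /// kept with probability `feature_subsample_rate`, then unpicked features
    /// are appended until the target count `ceil(rate * n_features)` is met.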
    fn generate_feature_mask(&mut self, n_features: usize) {
        self.feature_mask.clear();

        if self.config.feature_subsample_rate >= 1.0 {
            self.feature_mask.extend(0..n_features);
        } else {
            let target_count =
                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
                    as usize;
            let target_count = target_count.max(1).min(n_features);

            let n_words = n_features.div_ceil(64);
            self.feature_mask_bits.clear();
            self.feature_mask_bits.resize(n_words, 0u64);

            for i in 0..n_features {
                let r = xorshift64(&mut self.rng_state);
                let p = (r as f64) / (u64::MAX as f64);
                if p < self.config.feature_subsample_rate {
                    self.feature_mask.push(i);
                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                }
            }

            if self.feature_mask.len() < target_count {
                for i in 0..n_features {
                    if self.feature_mask.len() >= target_count {
                        break;
                    }
                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
                        self.feature_mask.push(i);
                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                    }
                }
            }
        }
    }

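    /// Tries to split a leaf. Candidates are scored per masked feature
    /// (categorical bins are Fisher-ordered by grad/hess ratio first), then
    /// filtered by monotone constraints and the adaptive-depth penalty. The
    /// winner is accepted only when `best_gain - second_best_gain` exceeds the
    /// Hoeffding bound `epsilon = sqrt(R^2 * ln(1/delta) / (2 * n))`, or when
    /// `epsilon` has shrunk below `TAU`. Returns `true` if the leaf was split.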
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);

        let hard_ceiling = if self.config.adaptive_depth.is_some() {
            self.config.max_depth.saturating_mul(2)
        } else {
            self.config.max_depth
        };
        let at_max_depth = depth as usize >= hard_ceiling;

        if at_max_depth {
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                }
            }
        }

        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        self.generate_feature_mask(n_features);

        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        let feature_types = &self.config.feature_types;
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else if let Some(candidate) = self.split_criterion.evaluate(
                &hist.grad_sums,
                &hist.hess_sums,
                total_grad,
                total_hess,
                self.config.gamma,
                self.config.lambda,
            ) {
                candidates.push((feat_idx, candidate, None));
            }
        }

        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true;
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true;
                }

                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    left_val <= right_val
                } else {
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        if let Some(cir_factor) = self.config.adaptive_depth {
            let n = sample_count as f64;
            if n > 1.0 {
                let effective_n = match self.config.leaf_decay_alpha {
                    Some(alpha) => n.min(1.0 / (1.0 - alpha)),
                    None => n,
                };

                let grad_var = self
                    .leaf_states
                    .get(leaf_id.0 as usize)
                    .and_then(|o| o.as_ref())
                    .map(|leaf_state| {
                        if leaf_state.clip_grad_count > 1 {
                            leaf_state.clip_grad_m2 / (leaf_state.clip_grad_count as f64 - 1.0)
                        } else {
                            let mean_grad = leaf_state.grad_sum / leaf_state.hess_sum.max(1.0);
                            mean_grad * mean_grad + 1.0
                        }
                    })
                    .unwrap_or(1.0);

                let n_feat = self.n_features.unwrap_or(1) as f64;
                let penalty = cir_factor * grad_var / effective_n * n_feat;

                if best_gain <= penalty {
                    return false;
                }
            }
        }

        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        self.recompute_bandwidths();

        true
    }
}

impl StreamingTree for HoeffdingTree {
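    /// One online update: routes the sample to its leaf, optionally clips the
    /// gradient, updates binners/histograms, leaf value and leaf model, and
    /// attempts a split every `grace_period` samples at that leaf.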
    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
        self.samples_seen += 1;

        let n_features = if let Some(n) = self.n_features {
            n
        } else {
            let n = features.len();
            self.n_features = Some(n);
            self.split_gains.resize(n, 0.0);

            if let Some(state) = self
                .leaf_states
                .get_mut(self.root.0 as usize)
                .and_then(|o| o.as_mut())
            {
                state.binners = make_binners(n, self.config.feature_types.as_deref());
            }
            n
        };

        debug_assert_eq!(
            features.len(),
            n_features,
            "feature count mismatch: got {} but expected {}",
            features.len(),
            n_features,
        );

        let leaf_id = self.route_to_leaf(features);

        self.arena.increment_sample_count(leaf_id);
        let sample_count = self.arena.get_sample_count(leaf_id);

        let idx = leaf_id.0 as usize;
        if self.leaf_states.len() <= idx {
            self.leaf_states.resize_with(idx + 1, || None);
        }
        if self.leaf_states[idx].is_none() {
            self.leaf_states[idx] = Some(LeafState::new_with_types(
                n_features,
                self.config.feature_types.as_deref(),
            ));
        }
        let state = self.leaf_states[idx].as_mut().unwrap();

        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
            clip_gradient(state, gradient, sigma)
        } else {
            gradient
        };

        if !state.bins_ready {
            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
                binner.observe(val);
            }

            if let Some(alpha) = self.config.leaf_decay_alpha {
                state.grad_sum = alpha * state.grad_sum + gradient;
                state.hess_sum = alpha * state.hess_sum + hessian;
            } else {
                state.grad_sum += gradient;
                state.hess_sum += hessian;
            }

            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
            self.arena.set_leaf_value(leaf_id, lw);

            if self.config.adaptive_leaf_bound.is_some() {
                update_output_stats(state, lw, self.config.leaf_decay_alpha);
            }

            if let Some(ref mut model) = state.leaf_model {
                model.update(features, gradient, hessian, self.config.lambda);
            }

            if sample_count >= self.config.grace_period as u64 {
                let edges_per_feature: Vec<BinEdges> = state
                    .binners
                    .iter()
                    .map(|b| b.compute_edges(self.config.n_bins))
                    .collect();

                let mut histograms = LeafHistograms::new(&edges_per_feature);

                if let Some(alpha) = self.config.leaf_decay_alpha {
                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
                } else {
                    histograms.accumulate(features, gradient, hessian);
                }

                state.histograms = Some(histograms);
                state.bins_ready = true;
            }

            return;
        }

        if let Some(ref mut histograms) = state.histograms {
            if let Some(alpha) = self.config.leaf_decay_alpha {
                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
            } else {
                histograms.accumulate(features, gradient, hessian);
            }
        }

        if let Some(alpha) = self.config.leaf_decay_alpha {
            state.grad_sum = alpha * state.grad_sum + gradient;
            state.hess_sum = alpha * state.hess_sum + hessian;
        } else {
            state.grad_sum += gradient;
            state.hess_sum += hessian;
        }
        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
        self.arena.set_leaf_value(leaf_id, lw);

        if self.config.adaptive_leaf_bound.is_some() {
            update_output_stats(state, lw, self.config.leaf_decay_alpha);
        }

        if let Some(ref mut model) = state.leaf_model {
            model.update(features, gradient, hessian, self.config.lambda);
        }

        if sample_count % (self.config.grace_period as u64) == 0 {
            self.attempt_split(leaf_id);
        }
    }

    fn predict(&self, features: &[f64]) -> f64 {
        let leaf_id = self.route_to_leaf(features);
        self.leaf_prediction(leaf_id, features)
    }

    #[inline]
    fn n_leaves(&self) -> usize {
        self.arena.n_leaves()
    }

    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        self.arena.reset();
        let root = self.arena.add_leaf(0);
        self.root = root;
        self.leaf_states.clear();

        let n_features = self.n_features.unwrap_or(0);
        self.leaf_states.resize_with(root.0 as usize + 1, || None);
        let mut root_state =
            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
        root_state.leaf_model = self.make_leaf_model(root);
        self.leaf_states[root.0 as usize] = Some(root_state);

        self.samples_seen = 0;
        self.feature_mask.clear();
        self.feature_mask_bits.clear();
        self.rng_state = self.config.seed;
        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
        self.node_bandwidths.clear();
    }

    fn split_gains(&self) -> &[f64] {
        &self.split_gains
    }

    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
        let leaf_id = self.route_to_leaf(features);
        let value = self.leaf_prediction(leaf_id, features);
        if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            let variance = 1.0 / (state.hess_sum + self.config.lambda);
            (value, variance)
        } else {
            (value, f64::INFINITY)
        }
    }
}

impl Clone for HoeffdingTree {
    fn clone(&self) -> Self {
        Self {
            arena: self.arena.clone(),
            root: self.root,
            config: self.config.clone(),
            leaf_states: self.leaf_states.clone(),
            n_features: self.n_features,
            samples_seen: self.samples_seen,
            split_criterion: self.split_criterion,
            feature_mask: self.feature_mask.clone(),
            feature_mask_bits: self.feature_mask_bits.clone(),
            rng_state: self.rng_state,
            split_gains: self.split_gains.clone(),
            node_bandwidths: self.node_bandwidths.clone(),
        }
    }
}

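// SAFETY: all fields are owned values with no shared interior mutability. This
// additionally assumes that every `Box<dyn LeafModel>` implementation is safe
// to move and share across threads; the trait object carries no Send/Sync
// bounds, so the compiler cannot check this here.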
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::builder::TreeConfig;
    use crate::tree::StreamingTree;

    fn test_xorshift(state: &mut u64) -> u64 {
        xorshift64(state)
    }

    fn test_rand_f64(state: &mut u64) -> f64 {
        let r = test_xorshift(state);
        (r as f64) / (u64::MAX as f64)
    }

    #[test]
    fn single_sample_predict_not_nan() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        let features = vec![1.0, 2.0, 3.0];
        tree.train_one(&features, -0.5, 1.0);

        let pred = tree.predict(&features);
        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );

        assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
    }

    #[test]
    fn linear_signal_rmse_improves() {
        let config = TreeConfig::new()
            .max_depth(4)
            .n_bins(32)
            .grace_period(50)
            .lambda(0.1)
            .gamma(0.0)
            .delta(1e-3);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 12345;

        let n_train = 1000;
        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
        let mut targets: Vec<f64> = Vec::with_capacity(n_train);

        for _ in 0..n_train {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
            let y = 2.0 * x + noise;
            features_all.push(x);
            targets.push(y);
        }

        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
        let initial_rmse = initial_mse.sqrt();

        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let gradient = pred - targets[i];
            let hessian = 1.0;
            tree.train_one(&feat, gradient, hessian);
        }

        let mut post_mse = 0.0;
        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let err = pred - targets[i];
            post_mse += err * err;
        }
        post_mse /= n_train as f64;
        let post_rmse = post_mse.sqrt();

        assert!(
            post_rmse < initial_rmse,
            "RMSE should decrease after training: initial={:.4}, post={:.4}",
            initial_rmse,
            post_rmse,
        );
    }

    #[test]
    fn no_splits_before_grace_period() {
        let grace = 100;
        let config = TreeConfig::new()
            .grace_period(grace)
            .max_depth(4)
            .n_bins(16)
            .delta(1e-1);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 99999;

        for _ in 0..(grace - 1) {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 2.0 * x;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        assert_eq!(
            tree.n_leaves(),
            1,
            "should be exactly 1 leaf before grace_period, got {}",
            tree.n_leaves()
        );
    }

    #[test]
    fn respects_max_depth() {
        let max_depth = 3;
        let config = TreeConfig::new()
            .max_depth(max_depth)
            .grace_period(20)
            .n_bins(16)
            .lambda(0.01)
            .gamma(0.0)
            .delta(1e-1);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 7777;

        for _ in 0..5000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = if x < 2.5 {
                -5.0
            } else if x < 5.0 {
                -1.0
            } else if x < 7.5 {
                1.0
            } else {
                5.0
            };
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let max_leaves = 1usize << max_depth;
        assert!(
            tree.n_leaves() <= max_leaves,
            "tree has {} leaves, but max_depth={} allows at most {}",
            tree.n_leaves(),
            max_depth,
            max_leaves,
        );
    }

    #[test]
    fn reset_returns_to_single_leaf() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .delta(1e-1);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 54321;

        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x - 5.0;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pre_reset_samples = tree.n_samples_seen();
        assert!(pre_reset_samples > 0);

        tree.reset();

        assert_eq!(
            tree.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            tree.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );

        let pred = tree.predict(&[5.0]);
        assert!(
            pred.abs() < 1e-10,
            "prediction after reset should be ~0.0, got {}",
            pred
        );
    }

    #[test]
    fn multi_feature_training() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(4)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 11111;

        for _ in 0..1000 {
            let x0 = test_rand_f64(&mut rng_state) * 5.0;
            let x1 = test_rand_f64(&mut rng_state) * 5.0;
            let y = x0 + 2.0 * x1;
            let feat = [x0, x1];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pred = tree.predict(&[2.5, 2.5]);
        assert!(
            pred.is_finite(),
            "multi-feature prediction should be finite"
        );
        assert_eq!(tree.n_samples_seen(), 1000);
    }

    #[test]
    fn feature_subsampling_works() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(3)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2)
            .feature_subsample_rate(0.5);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 33333;

        for _ in 0..1000 {
            let feats: Vec<f64> = (0..5)
                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
                .collect();
            let y: f64 = feats.iter().sum();
            let pred = tree.predict(&feats);
            tree.train_one(&feats, pred - y, 1.0);
        }

        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(pred.is_finite(), "subsampled prediction should be finite");
    }

    #[test]
    fn xorshift64_deterministic() {
        let mut s1: u64 = 42;
        let mut s2: u64 = 42;

        let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
        let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();

        assert_eq!(seq1, seq2, "xorshift64 should be deterministic");

        for &v in &seq1 {
            assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
        }
    }

    #[test]
    fn ewma_leaf_decay_recent_data_dominates() {
        let alpha = (-(2.0_f64.ln()) / 50.0).exp();
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
            .leaf_decay_alpha(alpha);
        let mut tree = HoeffdingTree::new(config);

        for _ in 0..1000 {
            let pred = tree.predict(&[1.0, 2.0]);
            let grad = pred - 1.0;
            tree.train_one(&[1.0, 2.0], grad, 1.0);
        }

        for _ in 0..100 {
            let pred = tree.predict(&[1.0, 2.0]);
            let grad = pred - 5.0;
            tree.train_one(&[1.0, 2.0], grad, 1.0);
        }

        let pred = tree.predict(&[1.0, 2.0]);
        assert!(
            pred > 2.0,
            "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
            pred,
        );
    }

    #[test]
    fn ewma_disabled_matches_traditional() {
        let config_no_ewma = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config_no_ewma);

        let mut rng_state: u64 = 99999;
        for _ in 0..200 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x + 1.0;
            let pred = tree.predict(&[x]);
            tree.train_one(&[x], pred - y, 1.0);
        }

        let pred = tree.predict(&[5.0]);
        assert!(
            pred.is_finite(),
            "prediction without EWMA should be finite, got {}",
            pred
        );
    }

    #[test]
    fn split_reeval_at_max_depth() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .lambda(1.0)
            .split_reeval_interval(50);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state: u64 = 54321;
        for _ in 0..2000 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let y = 2.0 * x1 + 3.0 * x2;
            let pred = tree.predict(&[x1, x2]);
            tree.train_one(&[x1, x2], pred - y, 1.0);
        }

        let leaves = tree.n_leaves();
        assert!(
            leaves >= 4,
            "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
            leaves,
        );
    }

    #[test]
    fn split_reeval_disabled_matches_traditional() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state: u64 = 77777;
        for _ in 0..2000 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let y = 2.0 * x1 + 3.0 * x2;
            let pred = tree.predict(&[x1, x2]);
            tree.train_one(&[x1, x2], pred - y, 1.0);
        }

        let leaves = tree.n_leaves();
        assert!(
            leaves <= 4,
            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
            leaves,
        );
    }

    #[test]
    fn gradient_clipping_clamps_outliers() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .gradient_clip_sigma(2.0);

        let mut tree = HoeffdingTree::new(config);

        let mut rng_state = 42u64;
        for _ in 0..50 {
            let x = test_rand_f64(&mut rng_state) * 2.0;
            let grad = x * 0.1;
            tree.train_one(&[x], grad, 1.0);
        }

        let pred_before = tree.predict(&[1.0]);

        tree.train_one(&[1.0], 1000.0, 1.0);

        let pred_after = tree.predict(&[1.0]);

        let delta = (pred_after - pred_before).abs();
        assert!(
            delta < 100.0,
            "gradient clipping should limit impact of outlier, but prediction changed by {}",
            delta,
        );
    }

    #[test]
    fn clip_gradient_welford_tracks_stats() {
        let mut state = LeafState::new(1);

        for i in 0..20 {
            let grad = 1.0 + (i as f64) * 0.1;
            let clipped = clip_gradient(&mut state, grad, 3.0);
            assert!(
                (clipped - grad).abs() < 1e-10,
                "normal gradients should not be clipped at 3-sigma"
            );
        }

        let clipped = clip_gradient(&mut state, 100.0, 3.0);
        assert!(
            clipped < 100.0,
            "extreme outlier should be clipped, got {}",
            clipped,
        );
        assert!(
            clipped > 0.0,
            "clipped value should be positive, got {}",
            clipped,
        );
    }

    #[test]
    fn clip_gradient_warmup_no_clipping() {
        let mut state = LeafState::new(1);

        for i in 0..9 {
            let val = if i == 8 { 1000.0 } else { 1.0 };
            let clipped = clip_gradient(&mut state, val, 2.0);
            assert_eq!(clipped, val, "warmup should not clip");
        }
    }

    #[test]
    fn adaptive_bound_warmup_returns_max() {
        let mut state = LeafState::new(1);
        for i in 0..9 {
            update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
    }

    #[test]
    fn adaptive_bound_tightens_after_warmup() {
        let mut state = LeafState::new(1);
        for i in 0..20 {
            let w = 0.3 + (i as f64 - 10.0) * 0.01;
            update_output_stats(&mut state, w, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        assert!(
            bound < 1.0,
            "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
            bound,
        );
        assert!(bound > 0.2, "bound should be > |mean|, got {}", bound);
    }

    #[test]
    fn adaptive_bound_clamps_outlier_leaf() {
        let mut state = LeafState::new(1);
        for _ in 0..20 {
            update_output_stats(&mut state, 0.3, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        let clamped = (2.9_f64).clamp(-bound, bound);
        assert!(
            clamped < 2.9,
            "2.9 should be clamped by adaptive bound {}, got {}",
            bound,
            clamped,
        );
    }

    #[test]
    fn adaptive_bound_with_decay_adapts() {
        let alpha = 0.95;
        let mut state = LeafState::new(1);

        for _ in 0..30 {
            update_output_stats(&mut state, 0.3, Some(alpha));
        }
        let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));

        for _ in 0..100 {
            update_output_stats(&mut state, 2.0, Some(alpha));
        }
        let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));

        assert!(
            bound_phase2 > bound_phase1,
            "EWMA bound should adapt: phase1={}, phase2={}",
            bound_phase1,
            bound_phase2,
        );
    }

2335
2336 #[test]
2340 fn adaptive_bound_disabled_by_default() {
2341 let config = TreeConfig::default();
2342 assert!(
2343 config.adaptive_leaf_bound.is_none(),
2344 "adaptive_leaf_bound should default to None",
2345 );
2346 }
2347
2348 #[test]
2352 fn adaptive_bound_warmup_falls_back_to_global() {
2353 let mut state = LeafState::new(1);
2354 for _ in 0..5 {
2356 update_output_stats(&mut state, 0.3, None);
2357 }
2358 let bound = adaptive_bound(&state, 3.0, None);
2359 assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
2360 }
2362
    #[test]
    fn monotonic_constraint_splits_respected() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(4)
            .n_bins(16)
            .monotone_constraints(vec![1]);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let grad = x * 0.5 - 2.5;
            tree.train_one(&[x], grad, 1.0);
        }

        let pred_low = tree.predict(&[0.0]);
        let pred_mid = tree.predict(&[5.0]);
        let pred_high = tree.predict(&[10.0]);

        assert!(
            pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
            "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
            pred_low,
            pred_mid,
            pred_high,
        );
    }

    #[test]
    fn predict_with_variance_finite() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        for i in 0..30 {
            let x = i as f64 * 0.1;
            tree.train_one(&[x], x - 1.0, 1.0);
        }

        let (value, variance) = tree.predict_with_variance(&[1.0]);
        assert!(value.is_finite(), "value should be finite");
        assert!(variance.is_finite(), "variance should be finite");
        assert!(variance > 0.0, "variance should be positive");
    }

    #[test]
    fn predict_with_variance_decreases_with_data() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        for _ in 0..20 {
            tree.train_one(&[1.0], 0.5, 1.0);
        }
        let (_, var_20) = tree.predict_with_variance(&[1.0]);

        for _ in 0..200 {
            tree.train_one(&[1.0], 0.5, 1.0);
        }
        let (_, var_220) = tree.predict_with_variance(&[1.0]);

        assert!(
            var_220 < var_20,
            "variance should decrease with more data: var@20={} vs var@220={}",
            var_20,
            var_220,
        );
    }

    #[test]
    fn predict_smooth_matches_hard_at_small_bandwidth() {
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            let hess = 1.0;
            tree.train_one(&features, grad, hess);
        }

        let features = vec![5.0, 2.5];
        let hard = tree.predict(&features);
        let smooth = tree.predict_smooth(&features, 0.001);
        assert!(
            (hard - smooth).abs() < 0.1,
            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
            hard,
            smooth,
        );
    }

    #[test]
    fn predict_smooth_is_continuous() {
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            tree.train_one(&features, grad, 1.0);
        }

        let bandwidth = 1.0;
        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
        let diff = (base - nudged).abs();
        assert!(
            diff < 0.1,
            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
            base,
            nudged,
            diff,
        );
    }

    #[test]
    fn leaf_grad_hess_returns_sums() {
        let config = TreeConfig::new().grace_period(100).lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let features = vec![1.0, 2.0, 3.0];

        for _ in 0..10 {
            tree.train_one(&features, -0.5, 1.0);
        }

        let root = tree.root();
        let (grad, hess) = tree
            .leaf_grad_hess(root)
            .expect("root should have leaf state");

        assert!(
            (grad - (-5.0)).abs() < 1e-10,
            "grad_sum should be -5.0, got {}",
            grad
        );
        assert!(
            (hess - 10.0).abs() < 1e-10,
            "hess_sum should be 10.0, got {}",
            hess
        );
    }

    #[test]
    fn leaf_grad_hess_returns_none_for_invalid_node() {
        let config = TreeConfig::new();
        let tree = HoeffdingTree::new(config);

        assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
        assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
    }

    #[test]
    fn adaptive_depth_none_identical_to_static_max_depth() {
        let config_static = TreeConfig::new()
            .max_depth(3)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3);

        let config_none = TreeConfig::new()
            .max_depth(3)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3);

        assert!(config_none.adaptive_depth.is_none());

        let mut tree_static = HoeffdingTree::new(config_static);
        let mut tree_none = HoeffdingTree::new(config_none);

        let mut rng_state: u64 = 42;
        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 2.0 * x;
            let feat = [x, x * 0.5, x * x];
            let pred_s = tree_static.predict(&feat);
            let pred_n = tree_none.predict(&feat);
            tree_static.train_one(&feat, pred_s - y, 1.0);
            tree_none.train_one(&feat, pred_n - y, 1.0);
        }

        assert_eq!(
            tree_static.arena().n_nodes(),
            tree_none.arena().n_nodes(),
            "adaptive_depth=None should produce identical tree structure to static max_depth"
        );
    }

    #[test]
    fn adaptive_depth_few_samples_stays_shallow() {
        let config = TreeConfig::new()
            .max_depth(6)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3)
            .adaptive_depth(7.5);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 99;

        for _ in 0..100 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 20.0;
            let y = 0.1 * x + noise;
            let feat = [x, test_rand_f64(&mut rng_state) * 5.0];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let n_nodes = tree.arena().n_nodes();
        assert!(
            n_nodes <= 15,
            "adaptive_depth with few noisy samples should keep tree shallow, got {} nodes",
            n_nodes
        );
    }

    #[test]
    fn adaptive_depth_many_samples_grows_deeper() {
        // Identical configs and data distribution; only the sample count
        // differs, so only the CIR penalty (which shrinks with n) changes.
        let config_few = TreeConfig::new()
            .max_depth(6)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3)
            .adaptive_depth(7.5);

        let config_many = TreeConfig::new()
            .max_depth(6)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3)
            .adaptive_depth(7.5);

        let mut tree_few = HoeffdingTree::new(config_few);
        let mut tree_many = HoeffdingTree::new(config_many);

        let mut rng_state: u64 = 42;

        // 200 samples of a clean linear target.
        for _ in 0..200 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0;
            let x2 = test_rand_f64(&mut rng_state) * 5.0;
            let y = 3.0 * x1 + 2.0 * x2;
            let feat = [x1, x2];
            let pred = tree_few.predict(&feat);
            tree_few.train_one(&feat, pred - y, 1.0);
        }

        // 5000 samples drawn from the same stream (same seed).
        let mut rng_state2: u64 = 42;
        for _ in 0..5000 {
            let x1 = test_rand_f64(&mut rng_state2) * 10.0;
            let x2 = test_rand_f64(&mut rng_state2) * 5.0;
            let y = 3.0 * x1 + 2.0 * x2;
            let feat = [x1, x2];
            let pred = tree_many.predict(&feat);
            tree_many.train_one(&feat, pred - y, 1.0);
        }

        assert!(
            tree_many.arena().n_nodes() >= tree_few.arena().n_nodes(),
            "more samples should allow deeper growth: many={} vs few={}",
            tree_many.arena().n_nodes(),
            tree_few.arena().n_nodes()
        );
    }

    #[test]
    fn adaptive_depth_penalty_scales_inversely_with_n() {
        // The CIR penalty has the form
        //     penalty = cir_factor * grad_var / n * n_feat,
        // so for fixed gradient variance and feature count it decays as 1/n:
        // a split gain that is rejected at small n is accepted once enough
        // samples have been seen.
        let cir_factor: f64 = 7.5;
        let grad_var: f64 = 1.0;
        let n_feat: f64 = 2.0;

        let penalty_100 = cir_factor * grad_var / 100.0 * n_feat;
        let penalty_1000 = cir_factor * grad_var / 1000.0 * n_feat;

        assert!(
            (penalty_100 - 0.15).abs() < 1e-10,
            "penalty at n=100 should be 0.15, got {}",
            penalty_100
        );
        assert!(
            (penalty_1000 - 0.015).abs() < 1e-10,
            "penalty at n=1000 should be 0.015, got {}",
            penalty_1000
        );
        assert!(
            penalty_100 > penalty_1000,
            "penalty should decrease with more samples"
        );

        // A fixed gain sits between the two penalties.
        let gain = 0.05;
        assert!(gain <= penalty_100, "gain should fail CIR at n=100");
        assert!(gain > penalty_1000, "gain should pass CIR at n=1000");
    }

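    // A hedged follow-up sketch: solving gain = cir_factor * grad_var / n * n_feat
    // for n gives the crossover sample count n* = cir_factor * grad_var * n_feat / gain.
    // For the numbers above, n* = 7.5 * 1.0 * 2.0 / 0.05 = 300. The closed
    // form is inferred from the penalty formula exercised above, not a
    // documented API.
    #[test]
    fn adaptive_depth_penalty_crossover_point() {
        let cir_factor: f64 = 7.5;
        let grad_var: f64 = 1.0;
        let n_feat: f64 = 2.0;
        let gain: f64 = 0.05;

        // n* is where the penalty exactly equals the gain.
        let n_star = cir_factor * grad_var * n_feat / gain;
        assert!((n_star - 300.0).abs() < 1e-10, "crossover should be n=300, got {}", n_star);

        // Just below the crossover the gain still fails; just above, it passes.
        let penalty_below = cir_factor * grad_var / (n_star - 1.0) * n_feat;
        let penalty_above = cir_factor * grad_var / (n_star + 1.0) * n_feat;
        assert!(gain < penalty_below);
        assert!(gain > penalty_above);
    }
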
    #[test]
    fn adaptive_depth_hard_ceiling_respected() {
        // A tiny adaptive factor makes the CIR penalty almost never bind,
        // so only the hard ceiling (max_depth * 2) limits growth.
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(32)
            .grace_period(10)
            .lambda(0.01)
            .gamma(0.0)
            .delta(1e-2)
            .adaptive_depth(0.001);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 777;

        for _ in 0..10000 {
            let x = test_rand_f64(&mut rng_state) * 100.0;
            // A strongly nonlinear target encourages many splits.
            let y = x * x;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        // Hard ceiling: depth max_depth * 2 = 6, i.e. at most 2^6 leaves.
        let max_leaves = 1usize << 6;
        let n_leaves = tree.arena().n_leaves();
        assert!(
            n_leaves <= max_leaves,
            "tree should respect hard ceiling of max_depth*2=6 ({} max leaves), got {} leaves",
            max_leaves,
            n_leaves
        );
    }

    #[test]
    // Soft-routed predictions should be finite everywhere, including at
    // points near the edge of the training range.
    fn soft_routed_prediction_finite() {
        let config = TreeConfig::new().grace_period(20).max_depth(4).n_bins(16);
        let mut tree = HoeffdingTree::new(config);
        let mut rng: u64 = 42;

        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = x * 2.0 + 1.0;
            let feat = [x, x * 0.5];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        // Probe a grid of points across the data range.
        for i in 0..20 {
            let x = i as f64 * 0.5;
            let pred = tree.predict_soft_routed(&[x, x * 0.5]);
            assert!(
                pred.is_finite(),
                "soft-routed prediction should be finite at x={}, got {}",
                x,
                pred
            );
        }
    }

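    // A hedged sketch. Assumption: with no splits there are no routing
    // decisions, so soft and hard routing both read the single root leaf and
    // must agree exactly. This property is inferred from the routing
    // definition, not asserted anywhere in the original suite.
    #[test]
    fn soft_routed_matches_hard_on_unsplit_tree() {
        let config = TreeConfig::new().grace_period(1000);
        let mut tree = HoeffdingTree::new(config);

        // A handful of samples, far fewer than the grace period: no split.
        for i in 0..10 {
            tree.train_one(&[i as f64], -1.0, 1.0);
        }
        assert_eq!(tree.arena().n_nodes(), 1, "tree should still be a single leaf");

        for i in 0..10 {
            let feat = [i as f64];
            let hard = tree.predict(&feat);
            let soft = tree.predict_soft_routed(&feat);
            assert!(
                (hard - soft).abs() < 1e-12,
                "hard and soft must agree on a stump: hard={}, soft={}",
                hard,
                soft
            );
        }
    }
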
    #[test]
    fn soft_routed_smoother_than_hard() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(32)
            .delta(1e-2);
        let mut tree = HoeffdingTree::new(config);
        let mut rng: u64 = 123;

        for _ in 0..5000 {
            // x uniform in [0, 2π), target sin(x).
            let x = test_rand_f64(&mut rng) * 6.283;
            let y = math::sin(x);
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        // If the tree never split there is nothing to compare.
        if tree.arena().n_nodes() < 3 {
            return;
        }

        // Evaluate both predictors on a dense grid and compare their total
        // variation (sum of absolute step-to-step changes).
        let n_points = 100;
        let hard_preds: Vec<f64> = (0..n_points)
            .map(|i| {
                let x = i as f64 * 6.283 / n_points as f64;
                tree.predict(&[x])
            })
            .collect();
        let soft_preds: Vec<f64> = (0..n_points)
            .map(|i| {
                let x = i as f64 * 6.283 / n_points as f64;
                tree.predict_soft_routed(&[x])
            })
            .collect();

        let hard_tv: f64 = hard_preds.windows(2).map(|w| math::abs(w[1] - w[0])).sum();
        let soft_tv: f64 = soft_preds.windows(2).map(|w| math::abs(w[1] - w[0])).sum();

        assert!(
            soft_tv <= hard_tv + 1e-10,
            "soft routing should have <= total variation than hard: soft_tv={}, hard_tv={}",
            soft_tv,
            hard_tv
        );
    }

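    // A small hedged refactor sketch: the total-variation computation above
    // appears twice, so it could be factored into a helper like this. The
    // helper name is ours, not part of the crate.
    fn total_variation(preds: &[f64]) -> f64 {
        preds.windows(2).map(|w| math::abs(w[1] - w[0])).sum()
    }

    #[test]
    fn total_variation_helper_matches_inline_computation() {
        // A zig-zag sequence with steps 1, 2, 3 has total variation 6.
        let seq = [0.0, 1.0, -1.0, 2.0];
        assert!((total_variation(&seq) - 6.0).abs() < 1e-12);
        // A monotone sequence's total variation is just last - first.
        let mono = [0.0, 0.5, 1.5, 4.0];
        assert!((total_variation(&mono) - 4.0).abs() < 1e-12);
    }
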
    #[test]
    fn node_bandwidths_computed_after_splits() {
        let config = TreeConfig::new().grace_period(20).max_depth(4).n_bins(16);
        let mut tree = HoeffdingTree::new(config);
        let mut rng: u64 = 999;

        for _ in 0..3000 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = x * x;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        // Without splits there are no internal nodes to check.
        if tree.arena().n_nodes() < 3 {
            return;
        }

        assert!(
            !tree.node_bandwidths.is_empty(),
            "node_bandwidths should be non-empty after splits"
        );

        // Every internal node must carry a positive bandwidth, and at least
        // one of them should be finite.
        let mut found_finite = false;
        for i in 0..tree.arena().n_nodes() {
            let nid = NodeId(i as u32);
            if !tree.arena().is_leaf(nid) {
                let bw = tree.node_bandwidths[i];
                assert!(
                    bw > 0.0,
                    "bandwidth for internal node {} should be > 0, got {}",
                    i,
                    bw
                );
                if bw.is_finite() {
                    found_finite = true;
                }
            }
        }
        assert!(
            found_finite,
            "at least one internal node should have a finite bandwidth"
        );
    }

    #[test]
    fn soft_routed_agrees_at_training_points() {
        let config = TreeConfig::new().grace_period(20).max_depth(3).n_bins(16);
        let mut tree = HoeffdingTree::new(config);

        // 500 evenly spaced points on a clean linear target.
        let training_data: Vec<(f64, f64)> = (0..500)
            .map(|i| {
                let x = i as f64 * 0.02;
                let y = x * 3.0 + 1.0;
                (x, y)
            })
            .collect();

        for &(x, y) in &training_data {
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        // Compare hard and soft predictions at the training points.
        let mut total_diff = 0.0;
        let mut count = 0;
        for &(x, _) in &training_data {
            let feat = [x];
            let hard = tree.predict(&feat);
            let soft = tree.predict_soft_routed(&feat);
            if hard.is_finite() && soft.is_finite() {
                total_diff += math::abs(hard - soft);
                count += 1;
            }
        }

        if count > 0 {
            let mean_diff = total_diff / count as f64;
            // Soft routing blends neighbouring leaves, so exact agreement is
            // not expected; 5.0 is a loose sanity bound for this target.
            assert!(
                mean_diff < 5.0,
                "mean difference between hard and soft should be reasonable, got {}",
                mean_diff
            );
        }
    }
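
    // A hedged final sketch. Assumption: training is fully deterministic —
    // the only randomness is the caller-seeded xorshift stream — so two
    // trees fed an identical stream must agree everywhere.
    #[test]
    fn training_is_deterministic_for_identical_streams() {
        let mk = || {
            let config = TreeConfig::new().grace_period(20).max_depth(3).n_bins(16);
            let mut tree = HoeffdingTree::new(config);
            let mut rng: u64 = 2024;
            for _ in 0..1000 {
                let x = test_rand_f64(&mut rng) * 10.0;
                let y = 3.0 * x + 1.0;
                let pred = tree.predict(&[x]);
                tree.train_one(&[x], pred - y, 1.0);
            }
            tree
        };
        let tree_a = mk();
        let tree_b = mk();

        assert_eq!(tree_a.arena().n_nodes(), tree_b.arena().n_nodes());
        for i in 0..10 {
            let x = i as f64;
            let pa = tree_a.predict(&[x]);
            let pb = tree_b.predict(&[x]);
            assert!(
                (pa - pb).abs() < 1e-15,
                "identical streams should give identical predictions: {} vs {}",
                pa,
                pb
            );
        }
    }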
}