use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;

use crate::feature::FeatureType;
use crate::histogram::bins::LeafHistograms;
use crate::histogram::{BinEdges, BinnerKind};
use crate::math;
use crate::tree::builder::TreeConfig;
use crate::tree::leaf_model::{LeafModel, LeafModelType};
use crate::tree::node::{NodeId, TreeArena};
use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
use crate::tree::StreamingTree;

const TAU: f64 = 0.05;

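/// Marsaglia xorshift64 PRNG step: cheap, deterministic, and `no_std`-friendly.
/// The state must be non-zero, or the sequence degenerates to all zeros.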
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    s
}

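/// Mutable training state attached to each leaf: histogram sketches for
/// split finding, running gradient/hessian sums for the closed-form leaf
/// weight, and Welford-style statistics for gradient clipping and adaptive
/// output bounds.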
struct LeafState {
    /// Per-feature gradient/hessian histograms; built once bin edges are frozen.
    histograms: Option<LeafHistograms>,
    /// Per-feature bin-edge estimators used during the warm-up phase.
    binners: Vec<BinnerKind>,
    /// True once bin edges are frozen and `histograms` is populated.
    bins_ready: bool,
    /// Running (optionally decayed) gradient sum for this leaf.
    grad_sum: f64,
    /// Running (optionally decayed) hessian sum for this leaf.
    hess_sum: f64,
    /// Sample count at the last split re-evaluation (used at max depth).
    last_reeval_count: u64,
    /// Welford running mean of observed gradients (for clipping).
    clip_grad_mean: f64,
    /// Welford running M2 (sum of squared deviations) of gradients.
    clip_grad_m2: f64,
    /// Number of gradients folded into the clipping statistics.
    clip_grad_count: u64,
    /// Running mean of emitted leaf outputs (for the adaptive bound).
    output_mean: f64,
    /// Running M2 of emitted leaf outputs.
    output_m2: f64,
    /// Number of outputs folded into the output statistics.
    output_count: u64,
    /// Optional trainable per-leaf model; `None` means the closed-form
    /// `-grad_sum / (hess_sum + lambda)` leaf weight is used.
    leaf_model: Option<Box<dyn LeafModel>>,
}

impl Clone for LeafState {
    fn clone(&self) -> Self {
        Self {
            histograms: self.histograms.clone(),
            binners: self.binners.clone(),
            bins_ready: self.bins_ready,
            grad_sum: self.grad_sum,
            hess_sum: self.hess_sum,
            last_reeval_count: self.last_reeval_count,
            clip_grad_mean: self.clip_grad_mean,
            clip_grad_m2: self.clip_grad_m2,
            clip_grad_count: self.clip_grad_count,
            output_mean: self.output_mean,
            output_m2: self.output_m2,
            output_count: self.output_count,
            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
        }
    }
}

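/// Clip `gradient` to `mean ± sigma * std_dev`, with mean and standard
/// deviation tracked online via Welford's algorithm. During the first ten
/// samples (warm-up), and whenever the variance is degenerate, the gradient
/// passes through unclipped.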
#[inline]
fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
    state.clip_grad_count += 1;
    let n = state.clip_grad_count as f64;

    let delta = gradient - state.clip_grad_mean;
    state.clip_grad_mean += delta / n;
    let delta2 = gradient - state.clip_grad_mean;
    state.clip_grad_m2 += delta * delta2;

    if state.clip_grad_count < 10 {
        return gradient;
    }

    let variance = state.clip_grad_m2 / (n - 1.0);
    let std_dev = math::sqrt(variance);

    if std_dev < 1e-15 {
        return gradient;
    }

    let lo = state.clip_grad_mean - sigma * std_dev;
    let hi = state.clip_grad_mean + sigma * std_dev;
    gradient.clamp(lo, hi)
}

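/// Fold a freshly computed leaf output `weight` into the leaf's running
/// output statistics: EWMA mean/variance when `decay_alpha` is set, plain
/// Welford updates otherwise.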
#[inline]
fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
    state.output_count += 1;

    if let Some(alpha) = decay_alpha {
        if state.output_count == 1 {
            state.output_mean = weight;
            state.output_m2 = 0.0;
        } else {
            let diff = weight - state.output_mean;
            state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
            let diff2 = weight - state.output_mean;
            state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
        }
    } else {
        let delta = weight - state.output_mean;
        state.output_mean += delta / (state.output_count as f64);
        let delta2 = weight - state.output_mean;
        state.output_m2 += delta * delta2;
    }
}

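/// Adaptive clamp for leaf outputs: `|mean| + k * std` over recently emitted
/// leaf weights, floored at 0.01 so the leaf is never pinned to zero.
/// Returns `f64::MAX` (no bound) until at least ten outputs have been seen.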
#[inline]
fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
    if state.output_count < 10 {
        return f64::MAX;
    }

    let variance = if decay_alpha.is_some() {
        state.output_m2.max(0.0)
    } else {
        state.output_m2 / (state.output_count as f64 - 1.0)
    };
    let std = math::sqrt(variance);

    (math::abs(state.output_mean) + k * std).max(0.01)
}

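/// Build one binner per feature: a categorical binner where the declared
/// feature type is categorical, a uniform binner otherwise.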
fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
    (0..n_features)
        .map(|i| {
            if let Some(ft) = feature_types {
                if i < ft.len() && ft[i] == FeatureType::Categorical {
                    return BinnerKind::categorical();
                }
            }
            BinnerKind::uniform()
        })
        .collect()
}

impl LeafState {
    fn new(n_features: usize) -> Self {
        Self::new_with_types(n_features, None)
    }

    fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
        let binners = make_binners(n_features, feature_types);

        Self {
            histograms: None,
            binners,
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: None,
        }
    }

    #[allow(dead_code)]
    fn with_histograms(histograms: LeafHistograms) -> Self {
        let n_features = histograms.n_features();
        let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();

        let grad_sum: f64 = histograms
            .histograms
            .first()
            .map_or(0.0, |h| h.total_gradient());
        let hess_sum: f64 = histograms
            .histograms
            .first()
            .map_or(0.0, |h| h.total_hessian());

        Self {
            histograms: Some(histograms),
            binners,
            bins_ready: true,
            grad_sum,
            hess_sum,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: None,
        }
    }
}

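/// Streaming regression tree grown with the Hoeffding bound: a leaf is split
/// only when the gain gap between its best and second-best split candidates
/// exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, or when `epsilon`
/// itself has shrunk below the tie-break threshold [`TAU`].
///
/// A minimal usage sketch (the configuration values here are hypothetical):
///
/// ```ignore
/// let mut tree = HoeffdingTree::new(TreeConfig::new().grace_period(50));
/// // Stream (features, gradient, hessian) triples, e.g. squared-error
/// // gradients `pred - y` with unit hessians.
/// tree.train_one(&[0.5], -1.0, 1.0);
/// let y_hat = tree.predict(&[0.5]);
/// ```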
pub struct HoeffdingTree {
    /// Flat arena storing the tree topology and per-node split data.
    arena: TreeArena,
    /// Root node id.
    root: NodeId,
    /// Tree hyperparameters.
    config: TreeConfig,
    /// Per-node training state, indexed by `NodeId`; `None` for internal nodes.
    leaf_states: Vec<Option<LeafState>>,
    /// Feature count, learned lazily from the first training sample.
    n_features: Option<usize>,
    /// Total number of samples seen by `train_one`.
    samples_seen: u64,
    /// Gain criterion used to score split candidates.
    split_criterion: XGBoostGain,
    /// Scratch list of feature indices sampled for the current split attempt.
    feature_mask: Vec<usize>,
    /// Bitset mirror of `feature_mask`, one bit per feature.
    feature_mask_bits: Vec<u64>,
    /// xorshift64 state used for feature subsampling.
    rng_state: u64,
    /// Cumulative gain credited to each feature at split time.
    split_gains: Vec<f64>,
}

impl HoeffdingTree {
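    /// Create an empty tree (a single root leaf) from `config`. The root's
    /// leaf model is instantiated up front unless the closed-form leaf
    /// weight is configured.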
    pub fn new(config: TreeConfig) -> Self {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let mut leaf_states = vec![None; root.0 as usize + 1];
        let root_model = match config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
        };
        leaf_states[root.0 as usize] = Some(LeafState {
            histograms: None,
            binners: Vec::new(),
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: root_model,
        });

        let seed = config.seed;
        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features: None,
            samples_seen: 0,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state: seed,
            split_gains: Vec::new(),
        }
    }

    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
        match self.config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(
                self.config
                    .leaf_model_type
                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
            ),
        }
    }

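    /// Rebuild a tree around an existing `arena` (e.g. after deserialization).
    /// Leaf states are recreated empty, so histogram sketches and leaf models
    /// are not restored; an empty arena falls back to a fresh single-leaf tree.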
    pub fn from_arena(
        config: TreeConfig,
        arena: TreeArena,
        n_features: Option<usize>,
        samples_seen: u64,
        rng_state: u64,
    ) -> Self {
        let root = if arena.n_nodes() > 0 {
            NodeId(0)
        } else {
            let mut arena_mut = arena;
            let root = arena_mut.add_leaf(0);
            return Self {
                arena: arena_mut,
                root,
                config: config.clone(),
                leaf_states: {
                    let mut v = vec![None; root.0 as usize + 1];
                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
                    v
                },
                n_features,
                samples_seen,
                split_criterion: XGBoostGain::default(),
                feature_mask: Vec::new(),
                feature_mask_bits: Vec::new(),
                rng_state,
                split_gains: vec![0.0; n_features.unwrap_or(0)],
            };
        };

        let nf = n_features.unwrap_or(0);
        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
        for (i, slot) in leaf_states.iter_mut().enumerate() {
            if arena.is_leaf[i] {
                *slot = Some(LeafState::new(nf));
            }
        }

        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features,
            samples_seen,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state,
            split_gains: vec![0.0; nf],
        }
    }

    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }

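    /// Return the running `(grad_sum, hess_sum)` pair for a leaf, or `None`
    /// if `node` is out of range or has no live leaf state.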
    #[inline]
    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
        self.leaf_states
            .get(node.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|state| (state.grad_sum, state.hess_sum))
    }

    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
        let mut current = self.root;
        while !self.arena.is_leaf(current) {
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                // Categorical split: bit `cat_val` of the mask selects the
                // left child; category ids >= 64 always route right.
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }
        current
    }

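    /// Raw prediction for a single leaf: the leaf model's output if one is
    /// attached, otherwise the closed-form leaf weight, clamped by the
    /// adaptive bound and/or the configured `max_leaf_output`.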
    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            let val = if let Some(ref model) = state.leaf_model {
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            (0.0, None)
        };

        if let Some(bound) = leaf_bound {
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }

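    /// Smoothed prediction: every numeric split blends its children with a
    /// sigmoid gate `alpha = 1 / (1 + exp(-(threshold - x) / bandwidth))`
    /// instead of a hard comparison, yielding a prediction that is continuous
    /// in the features. Categorical splits remain hard routes.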
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }

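    /// Like [`predict_smooth`](Self::predict_smooth), but with one bandwidth
    /// per feature; a non-finite bandwidth falls back to hard routing for
    /// that feature.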
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }

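    /// Prediction shrunk toward the parent: the leaf value is blended with
    /// its parent's value using `alpha = hess / (hess + lambda)`, so thinly
    /// populated leaves lean on the parent estimate.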
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        let mut current = self.root;
        let mut parent = None;
        while !self.arena.is_leaf(current) {
            parent = Some(current);
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }

        let leaf_pred = self.leaf_prediction(current, features);

        let parent_id = match parent {
            Some(p) => p,
            None => return leaf_pred,
        };

        let parent_pred = self.leaf_prediction(parent_id, features);

        let leaf_hess = self
            .leaf_states
            .get(current.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|s| s.hess_sum)
            .unwrap_or(0.0);

        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
        alpha * leaf_pred + (1.0 - alpha) * parent_pred
    }

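    /// Prediction with linear cross-fading near split boundaries: within
    /// `margin = bandwidths[feat]` of a numeric threshold, both children are
    /// evaluated and mixed linearly; outside the margin, routing is hard.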
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }

    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        if !margin.is_finite() || margin <= 0.0 {
            return if features[feat_idx] <= threshold {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let dist = features[feat_idx] - threshold;

        if dist < -margin {
            self.predict_sibling_recursive(left, features, bandwidths)
        } else if dist > margin {
            self.predict_sibling_recursive(right, features, bandwidths)
        } else {
            // Inside the transition band: map dist in [-margin, margin] to
            // t in [0, 1] and linearly mix the two children.
            let t = (dist + margin) / (2.0 * margin);
            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
            (1.0 - t) * left_pred + t * right_pred
        }
    }

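    /// Collect every numeric split threshold in the tree, grouped by feature
    /// index (categorical splits are skipped), e.g. for choosing per-feature
    /// smoothing bandwidths.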
    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
        let n = self.n_features.unwrap_or(0);
        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];

        for i in 0..self.arena.n_nodes() {
            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
                let feat_idx = self.arena.feature_idx[i] as usize;
                if feat_idx < n {
                    thresholds[feat_idx].push(self.arena.threshold[i]);
                }
            }
        }

        thresholds
    }

    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_recursive(left, features, bandwidth)
            } else {
                self.predict_smooth_recursive(right, features, bandwidth)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let z = (threshold - features[feat_idx]) / bandwidth;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

    fn predict_smooth_auto_recursive(
        &self,
        node: NodeId,
        features: &[f64],
        bandwidths: &[f64],
    ) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        if !bw.is_finite() {
            return if features[feat_idx] <= threshold {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let z = (threshold - features[feat_idx]) / bw;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

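    /// Sample the feature subset considered for the next split attempt. Each
    /// feature is kept independently with probability `feature_subsample_rate`;
    /// if too few survive, unselected features are added deterministically
    /// until at least `ceil(rate * n_features)` (minimum one) are present.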
    fn generate_feature_mask(&mut self, n_features: usize) {
        self.feature_mask.clear();

        if self.config.feature_subsample_rate >= 1.0 {
            self.feature_mask.extend(0..n_features);
        } else {
            let target_count =
                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
                    as usize;
            let target_count = target_count.max(1).min(n_features);

            let n_words = n_features.div_ceil(64);
            self.feature_mask_bits.clear();
            self.feature_mask_bits.resize(n_words, 0u64);

            for i in 0..n_features {
                let r = xorshift64(&mut self.rng_state);
                let p = (r as f64) / (u64::MAX as f64);
                if p < self.config.feature_subsample_rate {
                    self.feature_mask.push(i);
                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                }
            }

            if self.feature_mask.len() < target_count {
                // Top up deterministically so at least `target_count`
                // features are always considered.
                for i in 0..n_features {
                    if self.feature_mask.len() >= target_count {
                        break;
                    }
                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
                        self.feature_mask.push(i);
                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                    }
                }
            }
        }
    }

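    /// Try to split `leaf_id`. Evaluates one split candidate per sampled
    /// feature (using Fisher-ordered bins for categorical features), applies
    /// any monotone constraints, and accepts the best candidate only if the
    /// Hoeffding test passes: gap > epsilon, or epsilon < `TAU` (tie-break).
    /// Returns `true` if the leaf was split.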
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);
        let at_max_depth = depth as usize >= self.config.max_depth;

        if at_max_depth {
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                }
            }
        }

        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        self.generate_feature_mask(n_features);

        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        let feature_types = &self.config.feature_types;
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                // Fisher ordering: sort non-empty bins by grad/hess ratio so
                // the optimal category partition becomes a 1-D prefix scan.
                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else if let Some(candidate) = self.split_criterion.evaluate(
                &hist.grad_sums,
                &hist.hess_sums,
                total_grad,
                total_hess,
                self.config.gamma,
                self.config.lambda,
            ) {
                candidates.push((feat_idx, candidate, None));
            }
        }

        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true;
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true;
                }

                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    left_val <= right_val
                } else {
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        // Hoeffding bound with gain range R^2 = 1.0; under exponential decay
        // the effective sample size is capped at 1 / (1 - alpha).
        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            // Rebuild the category mask from the Fisher-ordered prefix that
            // the best candidate assigns to the left child.
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                // Children inherit the parent's bin edges but start with
                // empty histograms and statistics; the leaf model is
                // warm-started from the parent.
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        true
    }
}

impl StreamingTree for HoeffdingTree {
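    /// Process one `(features, gradient, hessian)` observation: route it to
    /// a leaf, update that leaf's statistics, histograms, and optional model,
    /// and attempt a split every `grace_period` samples at that leaf.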
    fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
        self.samples_seen += 1;

        let n_features = if let Some(n) = self.n_features {
            n
        } else {
            // First sample: lock in the feature count and build the root's binners.
            let n = features.len();
            self.n_features = Some(n);
            self.split_gains.resize(n, 0.0);

            if let Some(state) = self
                .leaf_states
                .get_mut(self.root.0 as usize)
                .and_then(|o| o.as_mut())
            {
                state.binners = make_binners(n, self.config.feature_types.as_deref());
            }
            n
        };

        debug_assert_eq!(
            features.len(),
            n_features,
            "feature count mismatch: got {} but expected {}",
            features.len(),
            n_features,
        );

        let leaf_id = self.route_to_leaf(features);

        self.arena.increment_sample_count(leaf_id);
        let sample_count = self.arena.get_sample_count(leaf_id);

        let idx = leaf_id.0 as usize;
        if self.leaf_states.len() <= idx {
            self.leaf_states.resize_with(idx + 1, || None);
        }
        if self.leaf_states[idx].is_none() {
            self.leaf_states[idx] = Some(LeafState::new_with_types(
                n_features,
                self.config.feature_types.as_deref(),
            ));
        }
        let state = self.leaf_states[idx].as_mut().unwrap();

        let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
            clip_gradient(state, gradient, sigma)
        } else {
            gradient
        };

        if !state.bins_ready {
            // Warm-up phase: feed the binners and keep plain sums until
            // enough samples arrive to freeze the bin edges.
            for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
                binner.observe(val);
            }

            if let Some(alpha) = self.config.leaf_decay_alpha {
                state.grad_sum = alpha * state.grad_sum + gradient;
                state.hess_sum = alpha * state.hess_sum + hessian;
            } else {
                state.grad_sum += gradient;
                state.hess_sum += hessian;
            }

            let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
            self.arena.set_leaf_value(leaf_id, lw);

            if self.config.adaptive_leaf_bound.is_some() {
                update_output_stats(state, lw, self.config.leaf_decay_alpha);
            }

            if let Some(ref mut model) = state.leaf_model {
                model.update(features, gradient, hessian, self.config.lambda);
            }

            if sample_count >= self.config.grace_period as u64 {
                let edges_per_feature: Vec<BinEdges> = state
                    .binners
                    .iter()
                    .map(|b| b.compute_edges(self.config.n_bins))
                    .collect();

                let mut histograms = LeafHistograms::new(&edges_per_feature);

                if let Some(alpha) = self.config.leaf_decay_alpha {
                    histograms.accumulate_with_decay(features, gradient, hessian, alpha);
                } else {
                    histograms.accumulate(features, gradient, hessian);
                }

                state.histograms = Some(histograms);
                state.bins_ready = true;
            }

            return;
        }

        if let Some(ref mut histograms) = state.histograms {
            if let Some(alpha) = self.config.leaf_decay_alpha {
                histograms.accumulate_with_decay(features, gradient, hessian, alpha);
            } else {
                histograms.accumulate(features, gradient, hessian);
            }
        }

        if let Some(alpha) = self.config.leaf_decay_alpha {
            state.grad_sum = alpha * state.grad_sum + gradient;
            state.hess_sum = alpha * state.hess_sum + hessian;
        } else {
            state.grad_sum += gradient;
            state.hess_sum += hessian;
        }
        let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
        self.arena.set_leaf_value(leaf_id, lw);

        if self.config.adaptive_leaf_bound.is_some() {
            update_output_stats(state, lw, self.config.leaf_decay_alpha);
        }

        if let Some(ref mut model) = state.leaf_model {
            model.update(features, gradient, hessian, self.config.lambda);
        }

        if sample_count % (self.config.grace_period as u64) == 0 {
            self.attempt_split(leaf_id);
        }
    }

    fn predict(&self, features: &[f64]) -> f64 {
        let leaf_id = self.route_to_leaf(features);
        self.leaf_prediction(leaf_id, features)
    }

    #[inline]
    fn n_leaves(&self) -> usize {
        self.arena.n_leaves()
    }

    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.samples_seen
    }

    fn reset(&mut self) {
        self.arena.reset();
        let root = self.arena.add_leaf(0);
        self.root = root;
        self.leaf_states.clear();

        let n_features = self.n_features.unwrap_or(0);
        self.leaf_states.resize_with(root.0 as usize + 1, || None);
        let mut root_state =
            LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
        root_state.leaf_model = self.make_leaf_model(root);
        self.leaf_states[root.0 as usize] = Some(root_state);

        self.samples_seen = 0;
        self.feature_mask.clear();
        self.feature_mask_bits.clear();
        self.rng_state = self.config.seed;
        self.split_gains.iter_mut().for_each(|g| *g = 0.0);
    }

    fn split_gains(&self) -> &[f64] {
        &self.split_gains
    }

    fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
        let leaf_id = self.route_to_leaf(features);
        let value = self.leaf_prediction(leaf_id, features);
        if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            // Inverse-curvature variance: shrinks as hessian mass accumulates.
            let variance = 1.0 / (state.hess_sum + self.config.lambda);
            (value, variance)
        } else {
            (value, f64::INFINITY)
        }
    }
}

impl Clone for HoeffdingTree {
    fn clone(&self) -> Self {
        Self {
            arena: self.arena.clone(),
            root: self.root,
            config: self.config.clone(),
            leaf_states: self.leaf_states.clone(),
            n_features: self.n_features,
            samples_seen: self.samples_seen,
            split_criterion: self.split_criterion,
            feature_mask: self.feature_mask.clone(),
            feature_mask_bits: self.feature_mask_bits.clone(),
            rng_state: self.rng_state,
            split_gains: self.split_gains.clone(),
        }
    }
}

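// SAFETY: `HoeffdingTree` owns all of its data; the manual impls exist
// because `Box<dyn LeafModel>` does not carry auto `Send`/`Sync` bounds.
// This assumes every `LeafModel` implementation is safe to move and share
// across threads, an assumption inherited from the surrounding crate.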
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::builder::TreeConfig;
    use crate::tree::StreamingTree;

    fn test_xorshift(state: &mut u64) -> u64 {
        xorshift64(state)
    }

    fn test_rand_f64(state: &mut u64) -> f64 {
        let r = test_xorshift(state);
        (r as f64) / (u64::MAX as f64)
    }

    #[test]
    fn single_sample_predict_not_nan() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        let features = vec![1.0, 2.0, 3.0];
        tree.train_one(&features, -0.5, 1.0);

        let pred = tree.predict(&features);
        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );

        // -grad / (hess + lambda) = 0.5 / (1.0 + 1.0) = 0.25 with the default lambda.
        assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
    }

    #[test]
    fn linear_signal_rmse_improves() {
        let config = TreeConfig::new()
            .max_depth(4)
            .n_bins(32)
            .grace_period(50)
            .lambda(0.1)
            .gamma(0.0)
            .delta(1e-3);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 12345;

        let n_train = 1000;
        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
        let mut targets: Vec<f64> = Vec::with_capacity(n_train);

        for _ in 0..n_train {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
            let y = 2.0 * x + noise;
            features_all.push(x);
            targets.push(y);
        }

        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
        let initial_rmse = initial_mse.sqrt();

        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let gradient = pred - targets[i];
            let hessian = 1.0;
            tree.train_one(&feat, gradient, hessian);
        }

        let mut post_mse = 0.0;
        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let err = pred - targets[i];
            post_mse += err * err;
        }
        post_mse /= n_train as f64;
        let post_rmse = post_mse.sqrt();

        assert!(
            post_rmse < initial_rmse,
            "RMSE should decrease after training: initial={:.4}, post={:.4}",
            initial_rmse,
            post_rmse,
        );
    }

    #[test]
    fn no_splits_before_grace_period() {
        let grace = 100;
        let config = TreeConfig::new()
            .grace_period(grace)
            .max_depth(4)
            .n_bins(16)
            .delta(1e-1);
        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 99999;

        for _ in 0..(grace - 1) {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 2.0 * x;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        assert_eq!(
            tree.n_leaves(),
            1,
            "should be exactly 1 leaf before grace_period, got {}",
            tree.n_leaves()
        );
    }

    #[test]
    fn respects_max_depth() {
        let max_depth = 3;
        let config = TreeConfig::new()
            .max_depth(max_depth)
            .grace_period(20)
            .n_bins(16)
            .lambda(0.01)
            .gamma(0.0)
            .delta(1e-1);
        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 7777;

        for _ in 0..5000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = if x < 2.5 {
                -5.0
            } else if x < 5.0 {
                -1.0
            } else if x < 7.5 {
                1.0
            } else {
                5.0
            };
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let max_leaves = 1usize << max_depth;
        assert!(
            tree.n_leaves() <= max_leaves,
            "tree has {} leaves, but max_depth={} allows at most {}",
            tree.n_leaves(),
            max_depth,
            max_leaves,
        );
    }

    #[test]
    fn reset_returns_to_single_leaf() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .delta(1e-1);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 54321;

        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x - 5.0;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pre_reset_samples = tree.n_samples_seen();
        assert!(pre_reset_samples > 0);

        tree.reset();

        assert_eq!(
            tree.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            tree.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );

        let pred = tree.predict(&[5.0]);
        assert!(
            pred.abs() < 1e-10,
            "prediction after reset should be ~0.0, got {}",
            pred
        );
    }

    #[test]
    fn multi_feature_training() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(4)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 11111;

        for _ in 0..1000 {
            let x0 = test_rand_f64(&mut rng_state) * 5.0;
            let x1 = test_rand_f64(&mut rng_state) * 5.0;
            let y = x0 + 2.0 * x1;
            let feat = [x0, x1];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pred = tree.predict(&[2.5, 2.5]);
        assert!(
            pred.is_finite(),
            "multi-feature prediction should be finite"
        );
        assert_eq!(tree.n_samples_seen(), 1000);
    }

    #[test]
    fn feature_subsampling_works() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(3)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2)
            .feature_subsample_rate(0.5);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 33333;

        for _ in 0..1000 {
            let feats: Vec<f64> = (0..5)
                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
                .collect();
            let y: f64 = feats.iter().sum();
            let pred = tree.predict(&feats);
            tree.train_one(&feats, pred - y, 1.0);
        }

        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(pred.is_finite(), "subsampled prediction should be finite");
    }

    #[test]
    fn xorshift64_deterministic() {
        let mut s1: u64 = 42;
        let mut s2: u64 = 42;

        let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
        let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();

        assert_eq!(seq1, seq2, "xorshift64 should be deterministic");

        for &v in &seq1 {
            assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
        }
    }

    #[test]
    fn ewma_leaf_decay_recent_data_dominates() {
        // Decay factor corresponding to a half-life of roughly 50 samples.
        let alpha = (-(2.0_f64.ln()) / 50.0).exp();
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0)
            .leaf_decay_alpha(alpha);
        let mut tree = HoeffdingTree::new(config);

        for _ in 0..1000 {
            let pred = tree.predict(&[1.0, 2.0]);
            let grad = pred - 1.0;
            tree.train_one(&[1.0, 2.0], grad, 1.0);
        }

        for _ in 0..100 {
            let pred = tree.predict(&[1.0, 2.0]);
            let grad = pred - 5.0;
            tree.train_one(&[1.0, 2.0], grad, 1.0);
        }

        let pred = tree.predict(&[1.0, 2.0]);
        assert!(
            pred > 2.0,
            "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
            pred,
        );
    }

    #[test]
    fn ewma_disabled_matches_traditional() {
        let config_no_ewma = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config_no_ewma);

        let mut rng_state: u64 = 99999;
        for _ in 0..200 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x + 1.0;
            let pred = tree.predict(&[x]);
            tree.train_one(&[x], pred - y, 1.0);
        }

        let pred = tree.predict(&[5.0]);
        assert!(
            pred.is_finite(),
            "prediction without EWMA should be finite, got {}",
            pred
        );
    }

    #[test]
    fn split_reeval_at_max_depth() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .lambda(1.0)
            .split_reeval_interval(50);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state: u64 = 54321;
        for _ in 0..2000 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let y = 2.0 * x1 + 3.0 * x2;
            let pred = tree.predict(&[x1, x2]);
            tree.train_one(&[x1, x2], pred - y, 1.0);
        }

        let leaves = tree.n_leaves();
        assert!(
            leaves >= 4,
            "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
            leaves,
        );
    }

    #[test]
    fn split_reeval_disabled_matches_traditional() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state: u64 = 77777;
        for _ in 0..2000 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let y = 2.0 * x1 + 3.0 * x2;
            let pred = tree.predict(&[x1, x2]);
            tree.train_one(&[x1, x2], pred - y, 1.0);
        }

        let leaves = tree.n_leaves();
        assert!(
            leaves <= 4,
            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
            leaves,
        );
    }

    #[test]
    fn gradient_clipping_clamps_outliers() {
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .gradient_clip_sigma(2.0);

        let mut tree = HoeffdingTree::new(config);

        let mut rng_state = 42u64;
        for _ in 0..50 {
            let x = test_rand_f64(&mut rng_state) * 2.0;
            let grad = x * 0.1;
            tree.train_one(&[x], grad, 1.0);
        }

        let pred_before = tree.predict(&[1.0]);

        // Feed a single extreme outlier gradient.
        tree.train_one(&[1.0], 1000.0, 1.0);

        let pred_after = tree.predict(&[1.0]);

        let delta = (pred_after - pred_before).abs();
        assert!(
            delta < 100.0,
            "gradient clipping should limit impact of outlier, but prediction changed by {}",
            delta,
        );
    }

    #[test]
    fn clip_gradient_welford_tracks_stats() {
        let mut state = LeafState::new(1);

        for i in 0..20 {
            let grad = 1.0 + (i as f64) * 0.1;
            let clipped = clip_gradient(&mut state, grad, 3.0);
            assert!(
                (clipped - grad).abs() < 1e-10,
                "normal gradients should not be clipped at 3-sigma"
            );
        }

        let clipped = clip_gradient(&mut state, 100.0, 3.0);
        assert!(
            clipped < 100.0,
            "extreme outlier should be clipped, got {}",
            clipped,
        );
        assert!(
            clipped > 0.0,
            "clipped value should be positive, got {}",
            clipped,
        );
    }

    #[test]
    fn clip_gradient_warmup_no_clipping() {
        let mut state = LeafState::new(1);

        for i in 0..9 {
            let val = if i == 8 { 1000.0 } else { 1.0 };
            let clipped = clip_gradient(&mut state, val, 2.0);
            assert_eq!(clipped, val, "warmup should not clip");
        }
    }

    #[test]
    fn adaptive_bound_warmup_returns_max() {
        let mut state = LeafState::new(1);
        for i in 0..9 {
            update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
    }

    #[test]
    fn adaptive_bound_tightens_after_warmup() {
        let mut state = LeafState::new(1);
        for i in 0..20 {
            let w = 0.3 + (i as f64 - 10.0) * 0.01;
            update_output_stats(&mut state, w, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        assert!(
            bound < 1.0,
            "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
            bound,
        );
        assert!(bound > 0.2, "bound should be > |mean|, got {}", bound);
    }

    #[test]
    fn adaptive_bound_clamps_outlier_leaf() {
        let mut state = LeafState::new(1);
        for _ in 0..20 {
            update_output_stats(&mut state, 0.3, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        let clamped = (2.9_f64).clamp(-bound, bound);
        assert!(
            clamped < 2.9,
            "2.9 should be clamped by adaptive bound {}, got {}",
            bound,
            clamped,
        );
    }

    #[test]
    fn adaptive_bound_with_decay_adapts() {
        let alpha = 0.95;
        let mut state = LeafState::new(1);

        for _ in 0..30 {
            update_output_stats(&mut state, 0.3, Some(alpha));
        }
        let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));

        for _ in 0..100 {
            update_output_stats(&mut state, 2.0, Some(alpha));
        }
        let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));

        assert!(
            bound_phase2 > bound_phase1,
            "EWMA bound should adapt: phase1={}, phase2={}",
            bound_phase1,
            bound_phase2,
        );
    }

    #[test]
    fn adaptive_bound_disabled_by_default() {
        let config = TreeConfig::default();
        assert!(
            config.adaptive_leaf_bound.is_none(),
            "adaptive_leaf_bound should default to None",
        );
    }

    #[test]
    fn adaptive_bound_warmup_falls_back_to_global() {
        let mut state = LeafState::new(1);
        for _ in 0..5 {
            update_output_stats(&mut state, 0.3, None);
        }
        let bound = adaptive_bound(&state, 3.0, None);
        assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
    }

    #[test]
    fn monotonic_constraint_splits_respected() {
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(4)
            .n_bins(16)
            .monotone_constraints(vec![1]);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let grad = x * 0.5 - 2.5;
            tree.train_one(&[x], grad, 1.0);
        }

        let pred_low = tree.predict(&[0.0]);
        let pred_mid = tree.predict(&[5.0]);
        let pred_high = tree.predict(&[10.0]);

        assert!(
            pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
            "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
            pred_low,
            pred_mid,
            pred_high,
        );
    }

    #[test]
    fn predict_with_variance_finite() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        for i in 0..30 {
            let x = i as f64 * 0.1;
            tree.train_one(&[x], x - 1.0, 1.0);
        }

        let (value, variance) = tree.predict_with_variance(&[1.0]);
        assert!(value.is_finite(), "value should be finite");
        assert!(variance.is_finite(), "variance should be finite");
        assert!(variance > 0.0, "variance should be positive");
    }

    #[test]
    fn predict_with_variance_decreases_with_data() {
        let config = TreeConfig::new().grace_period(10);
        let mut tree = HoeffdingTree::new(config);

        for _ in 0..20 {
            tree.train_one(&[1.0], 0.5, 1.0);
        }
        let (_, var_20) = tree.predict_with_variance(&[1.0]);

        for _ in 0..200 {
            tree.train_one(&[1.0], 0.5, 1.0);
        }
        let (_, var_220) = tree.predict_with_variance(&[1.0]);

        assert!(
            var_220 < var_20,
            "variance should decrease with more data: var@20={} vs var@220={}",
            var_20,
            var_220,
        );
    }

    #[test]
    fn predict_smooth_matches_hard_at_small_bandwidth() {
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            let hess = 1.0;
            tree.train_one(&features, grad, hess);
        }

        let features = vec![5.0, 2.5];
        let hard = tree.predict(&features);
        let smooth = tree.predict_smooth(&features, 0.001);
        assert!(
            (hard - smooth).abs() < 0.1,
            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
            hard,
            smooth,
        );
    }

    #[test]
    fn predict_smooth_is_continuous() {
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            tree.train_one(&features, grad, 1.0);
        }

        let bandwidth = 1.0;
        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
        let diff = (base - nudged).abs();
        assert!(
            diff < 0.1,
            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
            base,
            nudged,
            diff,
        );
    }

    #[test]
    fn leaf_grad_hess_returns_sums() {
        let config = TreeConfig::new().grace_period(100).lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let features = vec![1.0, 2.0, 3.0];

        for _ in 0..10 {
            tree.train_one(&features, -0.5, 1.0);
        }

        let root = tree.root();
        let (grad, hess) = tree
            .leaf_grad_hess(root)
            .expect("root should have leaf state");

        assert!(
            (grad - (-5.0)).abs() < 1e-10,
            "grad_sum should be -5.0, got {}",
            grad
        );
        assert!(
            (hess - 10.0).abs() < 1e-10,
            "hess_sum should be 10.0, got {}",
            hess
        );
    }

    #[test]
    fn leaf_grad_hess_returns_none_for_invalid_node() {
        let config = TreeConfig::new();
        let tree = HoeffdingTree::new(config);

        assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
        assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
    }
}