1use alloc::boxed::Box;
22use alloc::vec;
23use alloc::vec::Vec;
24
25use crate::feature::FeatureType;
26use crate::histogram::bins::LeafHistograms;
27use crate::histogram::{BinEdges, BinnerKind};
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::{LeafModel, LeafModelType};
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, SplitCandidate, SplitCriterion, XGBoostGain};
33use crate::tree::StreamingTree;
34
/// Hoeffding-bound tie-break threshold: a split whose gain gap to the
/// runner-up is within the bound `epsilon` is only blocked while
/// `epsilon >= TAU`; once the bound shrinks below this value, near-tied
/// candidates no longer stall splitting (see `attempt_split`).
const TAU: f64 = 0.05;

/// Advance the xorshift64 PRNG state in place and return the next value.
///
/// A state of zero is a fixed point of the xorshift transform: it would
/// yield zero forever, which degenerates feature subsampling when the tree
/// is seeded with `0` (every draw maps to `p = 0.0`). A zero state is
/// therefore re-seeded with a fixed nonzero constant before stepping.
#[inline]
fn xorshift64(state: &mut u64) -> u64 {
    // Escape the degenerate all-zero state; any nonzero constant works.
    if *state == 0 {
        *state = 0x9E37_79B9_7F4A_7C15;
    }
    let mut s = *state;
    s ^= s << 13;
    s ^= s >> 7;
    s ^= s << 17;
    *state = s;
    s
}
54
/// Mutable per-leaf training state: bin-edge estimation, gradient/hessian
/// histogram accumulators, running statistics, and the optional trainable
/// per-leaf model.
struct LeafState {
    /// Per-feature gradient/hessian histograms; `None` until the leaf's
    /// grace period has elapsed and bin edges have been frozen.
    histograms: Option<LeafHistograms>,

    /// One edge-estimating binner per feature, fed raw values during the
    /// warm-up phase before histograms exist.
    binners: Vec<BinnerKind>,

    /// True once `histograms` has been built from the binners' edges.
    bins_ready: bool,

    /// Running (optionally exponentially decayed) gradient sum.
    grad_sum: f64,

    /// Running (optionally exponentially decayed) hessian sum.
    hess_sum: f64,

    /// Sample count recorded at the last split re-evaluation; throttles
    /// repeated split attempts for leaves at the depth ceiling.
    last_reeval_count: u64,

    /// Welford running mean of observed gradients (gradient clipping).
    clip_grad_mean: f64,

    /// Welford sum of squared deviations of gradients (gradient clipping).
    clip_grad_m2: f64,

    /// Number of gradients folded into the clipping statistics.
    clip_grad_count: u64,

    /// Running mean of emitted leaf weights (adaptive output bound).
    output_mean: f64,

    /// Sum of squared deviations of leaf weights — or, when decay is
    /// enabled, an EWMA variance estimate (see `update_output_stats`).
    output_m2: f64,

    /// Number of leaf-weight observations.
    output_count: u64,

    /// Optional trainable per-leaf model; `None` means the closed-form
    /// `leaf_weight` of the accumulated sums is used instead.
    leaf_model: Option<Box<dyn LeafModel>>,
}
106
impl Clone for LeafState {
    // Hand-written because `Box<dyn LeafModel>` is not `Clone`: the model
    // is duplicated through the trait's own `clone_warm` hook. Every other
    // field is a plain copy/clone.
    fn clone(&self) -> Self {
        Self {
            histograms: self.histograms.clone(),
            binners: self.binners.clone(),
            bins_ready: self.bins_ready,
            grad_sum: self.grad_sum,
            hess_sum: self.hess_sum,
            last_reeval_count: self.last_reeval_count,
            clip_grad_mean: self.clip_grad_mean,
            clip_grad_m2: self.clip_grad_m2,
            clip_grad_count: self.clip_grad_count,
            output_mean: self.output_mean,
            output_m2: self.output_m2,
            output_count: self.output_count,
            leaf_model: self.leaf_model.as_ref().map(|m| m.clone_warm()),
        }
    }
}
126
127#[inline]
133fn clip_gradient(state: &mut LeafState, gradient: f64, sigma: f64) -> f64 {
134 state.clip_grad_count += 1;
135 let n = state.clip_grad_count as f64;
136
137 let delta = gradient - state.clip_grad_mean;
139 state.clip_grad_mean += delta / n;
140 let delta2 = gradient - state.clip_grad_mean;
141 state.clip_grad_m2 += delta * delta2;
142
143 if state.clip_grad_count < 10 {
145 return gradient;
146 }
147
148 let variance = state.clip_grad_m2 / (n - 1.0);
149 let std_dev = math::sqrt(variance);
150
151 if std_dev < 1e-15 {
152 return gradient; }
154
155 let lo = state.clip_grad_mean - sigma * std_dev;
156 let hi = state.clip_grad_mean + sigma * std_dev;
157 gradient.clamp(lo, hi)
158}
159
160#[inline]
165fn update_output_stats(state: &mut LeafState, weight: f64, decay_alpha: Option<f64>) {
166 state.output_count += 1;
167
168 if let Some(alpha) = decay_alpha {
169 if state.output_count == 1 {
171 state.output_mean = weight;
172 state.output_m2 = 0.0;
173 } else {
174 let diff = weight - state.output_mean;
175 state.output_mean = alpha * state.output_mean + (1.0 - alpha) * weight;
176 let diff2 = weight - state.output_mean;
177 state.output_m2 = alpha * state.output_m2 + (1.0 - alpha) * diff * diff2;
178 }
179 } else {
180 let delta = weight - state.output_mean;
182 state.output_mean += delta / (state.output_count as f64);
183 let delta2 = weight - state.output_mean;
184 state.output_m2 += delta * delta2;
185 }
186}
187
188#[inline]
193fn adaptive_bound(state: &LeafState, k: f64, decay_alpha: Option<f64>) -> f64 {
194 if state.output_count < 10 {
195 return f64::MAX; }
197
198 let variance = if decay_alpha.is_some() {
199 state.output_m2.max(0.0)
201 } else {
202 state.output_m2 / (state.output_count as f64 - 1.0)
204 };
205 let std = math::sqrt(variance);
206
207 (math::abs(state.output_mean) + k * std).max(0.01)
209}
210
211fn make_binners(n_features: usize, feature_types: Option<&[FeatureType]>) -> Vec<BinnerKind> {
213 (0..n_features)
214 .map(|i| {
215 if let Some(ft) = feature_types {
216 if i < ft.len() && ft[i] == FeatureType::Categorical {
217 return BinnerKind::categorical();
218 }
219 }
220 BinnerKind::uniform()
221 })
222 .collect()
223}
224
225impl LeafState {
226 fn new(n_features: usize) -> Self {
228 Self::new_with_types(n_features, None)
229 }
230
231 fn new_with_types(n_features: usize, feature_types: Option<&[FeatureType]>) -> Self {
233 let binners = make_binners(n_features, feature_types);
234
235 Self {
236 histograms: None,
237 binners,
238 bins_ready: false,
239 grad_sum: 0.0,
240 hess_sum: 0.0,
241 last_reeval_count: 0,
242 clip_grad_mean: 0.0,
243 clip_grad_m2: 0.0,
244 clip_grad_count: 0,
245 output_mean: 0.0,
246 output_m2: 0.0,
247 output_count: 0,
248 leaf_model: None,
249 }
250 }
251
252 #[allow(dead_code)]
255 fn with_histograms(histograms: LeafHistograms) -> Self {
256 let n_features = histograms.n_features();
257 let binners: Vec<BinnerKind> = (0..n_features).map(|_| BinnerKind::uniform()).collect();
258
259 let grad_sum: f64 = histograms
261 .histograms
262 .first()
263 .map_or(0.0, |h| h.total_gradient());
264 let hess_sum: f64 = histograms
265 .histograms
266 .first()
267 .map_or(0.0, |h| h.total_hessian());
268
269 Self {
270 histograms: Some(histograms),
271 binners,
272 bins_ready: true,
273 grad_sum,
274 hess_sum,
275 last_reeval_count: 0,
276 clip_grad_mean: 0.0,
277 clip_grad_m2: 0.0,
278 clip_grad_count: 0,
279 output_mean: 0.0,
280 output_m2: 0.0,
281 output_count: 0,
282 leaf_model: None,
283 }
284 }
285}
286
/// Streaming regression tree: leaves accumulate gradient/hessian
/// histograms and are split when a Hoeffding-style bound indicates the
/// best candidate split is reliably better than the runner-up.
pub struct HoeffdingTree {
    /// Flat node storage (structure, thresholds, leaf values, counts).
    arena: TreeArena,

    /// Root node id.
    root: NodeId,

    /// Hyper-parameters controlling growth, regularization, leaf models.
    config: TreeConfig,

    /// Per-node training state, indexed by `NodeId.0`; `None` for interior
    /// nodes and never-visited slots.
    leaf_states: Vec<Option<LeafState>>,

    /// Feature count, fixed by the first training sample.
    n_features: Option<usize>,

    /// Total samples ever passed to `train_one`.
    samples_seen: u64,

    /// Gain criterion used to score candidate splits.
    split_criterion: XGBoostGain,

    /// Scratch: feature indices selected for the current split attempt
    /// (column subsampling).
    feature_mask: Vec<usize>,

    /// Bitset mirror of `feature_mask`, used for the deterministic top-up
    /// pass in `generate_feature_mask`.
    feature_mask_bits: Vec<u64>,

    /// xorshift64 state driving feature subsampling.
    rng_state: u64,

    /// Cumulative split gain credited per feature (importance).
    split_gains: Vec<f64>,
}
339
impl HoeffdingTree {
    /// Create a tree consisting of a single root leaf.
    ///
    /// The root's binners are left empty because the feature count is
    /// unknown until the first `train_one` call (which fills them in).
    /// `ClosedForm` leaves get no per-leaf model.
    pub fn new(config: TreeConfig) -> Self {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let mut leaf_states = vec![None; root.0 as usize + 1];
        let root_model = match config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
        };
        leaf_states[root.0 as usize] = Some(LeafState {
            histograms: None,
            binners: Vec::new(),
            bins_ready: false,
            grad_sum: 0.0,
            hess_sum: 0.0,
            last_reeval_count: 0,
            clip_grad_mean: 0.0,
            clip_grad_m2: 0.0,
            clip_grad_count: 0,
            output_mean: 0.0,
            output_m2: 0.0,
            output_count: 0,
            leaf_model: root_model,
        });

        let seed = config.seed;
        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features: None,
            samples_seen: 0,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state: seed,
            split_gains: Vec::new(),
        }
    }

    /// Instantiate a fresh per-leaf model for `node`, or `None` for
    /// closed-form leaves. The node id is XORed into the seed so distinct
    /// leaves get distinct RNG streams.
    fn make_leaf_model(&self, node: NodeId) -> Option<Box<dyn LeafModel>> {
        match self.config.leaf_model_type {
            LeafModelType::ClosedForm => None,
            _ => Some(
                self.config
                    .leaf_model_type
                    .create(self.config.seed ^ (node.0 as u64), self.config.delta),
            ),
        }
    }

    /// Rebuild a tree around an existing `arena` (e.g. after
    /// deserialization). Every leaf gets a *fresh* `LeafState`: gradient
    /// sums and histograms must be re-accumulated from the stream.
    ///
    /// NOTE(review): leaf states here are built with `LeafState::new`,
    /// which ignores `config.feature_types`, so categorical binners are
    /// not restored — confirm this is intended.
    pub fn from_arena(
        config: TreeConfig,
        arena: TreeArena,
        n_features: Option<usize>,
        samples_seen: u64,
        rng_state: u64,
    ) -> Self {
        let root = if arena.n_nodes() > 0 {
            NodeId(0)
        } else {
            // Empty arena: synthesize a root leaf and return immediately.
            let mut arena_mut = arena;
            let root = arena_mut.add_leaf(0);
            return Self {
                arena: arena_mut,
                root,
                config: config.clone(),
                leaf_states: {
                    let mut v = vec![None; root.0 as usize + 1];
                    v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
                    v
                },
                n_features,
                samples_seen,
                split_criterion: XGBoostGain::default(),
                feature_mask: Vec::new(),
                feature_mask_bits: Vec::new(),
                rng_state,
                split_gains: vec![0.0; n_features.unwrap_or(0)],
            };
        };

        let nf = n_features.unwrap_or(0);
        let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
        for (i, slot) in leaf_states.iter_mut().enumerate() {
            if arena.is_leaf[i] {
                *slot = Some(LeafState::new(nf));
            }
        }

        Self {
            arena,
            root,
            config,
            leaf_states,
            n_features,
            samples_seen,
            split_criterion: XGBoostGain::default(),
            feature_mask: Vec::new(),
            feature_mask_bits: Vec::new(),
            rng_state,
            split_gains: vec![0.0; nf],
        }
    }

    /// Root node id.
    #[inline]
    pub fn root(&self) -> NodeId {
        self.root
    }

    /// Read-only access to the underlying node arena.
    #[inline]
    pub fn arena(&self) -> &TreeArena {
        &self.arena
    }

    /// The tree's configuration.
    #[inline]
    pub fn tree_config(&self) -> &TreeConfig {
        &self.config
    }

    /// Feature count, or `None` before the first training sample.
    #[inline]
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    /// Current feature-subsampling RNG state (for serialization).
    #[inline]
    pub fn rng_state(&self) -> u64 {
        self.rng_state
    }

    /// Accumulated `(gradient, hessian)` sums for a leaf, if it has state.
    #[inline]
    pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
        self.leaf_states
            .get(node.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|state| (state.grad_sum, state.hess_sum))
    }

    /// Walk from the root to the leaf `features` falls into.
    ///
    /// Categorical nodes route via a 64-bit membership mask (bit set →
    /// left; category values >= 64 always go right); numeric nodes send
    /// `value <= threshold` left.
    fn route_to_leaf(&self, features: &[f64]) -> NodeId {
        let mut current = self.root;
        while !self.arena.is_leaf(current) {
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }
        current
    }

    /// Output of a single leaf for `features`.
    ///
    /// Value source priority: per-leaf model if present, else the
    /// closed-form `leaf_weight` of the accumulated sums, else the value
    /// stored in the arena. A leaf below `min_hessian_sum` outputs 0.0.
    /// The adaptive bound (when active, i.e. < `f64::MAX`) takes
    /// precedence over the static `max_leaf_output` clamp.
    #[inline]
    fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
        let (raw, leaf_bound) = if let Some(state) = self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            if let Some(min_h) = self.config.min_hessian_sum {
                if state.hess_sum < min_h {
                    return 0.0;
                }
            }
            let val = if let Some(ref model) = state.leaf_model {
                model.predict(features)
            } else if state.hess_sum != 0.0 {
                leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
            } else {
                self.arena.leaf_value[leaf_id.0 as usize]
            };

            let bound = self
                .config
                .adaptive_leaf_bound
                .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));

            (val, bound)
        } else {
            // Stateless leaf (e.g. freshly deserialized): neutral output.
            (0.0, None)
        };

        if let Some(bound) = leaf_bound {
            if bound < f64::MAX {
                return raw.clamp(-bound, bound);
            }
        }
        if let Some(max) = self.config.max_leaf_output {
            raw.clamp(-max, max)
        } else {
            raw
        }
    }

    /// Prediction with logistic soft routing: at every numeric split the
    /// left/right subtree predictions are blended with weight
    /// `sigmoid((threshold - x) / bandwidth)`.
    pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
        self.predict_smooth_recursive(self.root, features, bandwidth)
    }

    /// Like `predict_smooth`, but with a per-feature bandwidth; features
    /// with a non-finite bandwidth fall back to hard routing.
    pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_smooth_auto_recursive(self.root, features, bandwidths)
    }

    /// Hard-route to a leaf, then blend its prediction with its parent's,
    /// weighted by the leaf's hessian mass: `alpha = h / (h + lambda)`.
    /// Leaves with little evidence thus lean on the parent.
    pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
        let mut current = self.root;
        let mut parent = None;
        while !self.arena.is_leaf(current) {
            parent = Some(current);
            let feat_idx = self.arena.get_feature_idx(current) as usize;
            current = if let Some(mask) = self.arena.get_categorical_mask(current) {
                let cat_val = features[feat_idx] as u64;
                if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            } else {
                let threshold = self.arena.get_threshold(current);
                if features[feat_idx] <= threshold {
                    self.arena.get_left(current)
                } else {
                    self.arena.get_right(current)
                }
            };
        }

        let leaf_pred = self.leaf_prediction(current, features);

        // Root-as-leaf: nothing to interpolate with.
        let parent_id = match parent {
            Some(p) => p,
            None => return leaf_pred,
        };

        let parent_pred = self.leaf_prediction(parent_id, features);

        let leaf_hess = self
            .leaf_states
            .get(current.0 as usize)
            .and_then(|o| o.as_ref())
            .map(|s| s.hess_sum)
            .unwrap_or(0.0);

        let alpha = leaf_hess / (leaf_hess + self.config.lambda);
        alpha * leaf_pred + (1.0 - alpha) * parent_pred
    }

    /// Prediction with sibling cross-fade: within `±bandwidth` of a
    /// numeric threshold the two subtrees are linearly interpolated.
    pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
        self.predict_sibling_recursive(self.root, features, bandwidths)
    }

    /// Recursive worker for `predict_sibling_interpolated`.
    fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        // Categorical splits are never interpolated: hard routing.
        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        // Missing / degenerate margin: fall back to hard routing.
        if !margin.is_finite() || margin <= 0.0 {
            return if features[feat_idx] <= threshold {
                self.predict_sibling_recursive(left, features, bandwidths)
            } else {
                self.predict_sibling_recursive(right, features, bandwidths)
            };
        }

        let dist = features[feat_idx] - threshold;

        if dist < -margin {
            self.predict_sibling_recursive(left, features, bandwidths)
        } else if dist > margin {
            self.predict_sibling_recursive(right, features, bandwidths)
        } else {
            // Inside the transition band: linear fade, t=0 fully left at
            // dist = -margin, t=1 fully right at dist = +margin.
            let t = (dist + margin) / (2.0 * margin);
            let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
            let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
            (1.0 - t) * left_pred + t * right_pred
        }
    }

    /// All numeric split thresholds currently in the tree, grouped by
    /// feature index (categorical splits are skipped).
    pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
        let n = self.n_features.unwrap_or(0);
        let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];

        for i in 0..self.arena.n_nodes() {
            if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
                let feat_idx = self.arena.feature_idx[i] as usize;
                if feat_idx < n {
                    thresholds[feat_idx].push(self.arena.threshold[i]);
                }
            }
        }

        thresholds
    }

    /// Recursive worker for `predict_smooth` (single global bandwidth).
    fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        // Categorical splits stay hard-routed.
        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_recursive(left, features, bandwidth)
            } else {
                self.predict_smooth_recursive(right, features, bandwidth)
            };
        }

        // Sigmoid gate: alpha → 1 well left of the threshold, → 0 right.
        let threshold = self.arena.get_threshold(node);
        let z = (threshold - features[feat_idx]) / bandwidth;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
        let right_pred = self.predict_smooth_recursive(right, features, bandwidth);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

    /// Recursive worker for `predict_smooth_auto` (per-feature bandwidth).
    fn predict_smooth_auto_recursive(
        &self,
        node: NodeId,
        features: &[f64],
        bandwidths: &[f64],
    ) -> f64 {
        if self.arena.is_leaf(node) {
            return self.leaf_prediction(node, features);
        }

        let feat_idx = self.arena.get_feature_idx(node) as usize;
        let left = self.arena.get_left(node);
        let right = self.arena.get_right(node);

        // Categorical splits stay hard-routed.
        if let Some(mask) = self.arena.get_categorical_mask(node) {
            let cat_val = features[feat_idx] as u64;
            return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let threshold = self.arena.get_threshold(node);
        let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);

        // No usable bandwidth for this feature: hard routing.
        if !bw.is_finite() {
            return if features[feat_idx] <= threshold {
                self.predict_smooth_auto_recursive(left, features, bandwidths)
            } else {
                self.predict_smooth_auto_recursive(right, features, bandwidths)
            };
        }

        let z = (threshold - features[feat_idx]) / bw;
        let alpha = 1.0 / (1.0 + math::exp(-z));

        let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
        let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);

        alpha * left_pred + (1.0 - alpha) * right_pred
    }

    /// Populate `feature_mask` with the column subsample for one split
    /// attempt.
    ///
    /// Each feature is drawn independently with probability
    /// `feature_subsample_rate`; if fewer than
    /// `ceil(n_features * rate)` survive, the mask is topped up
    /// deterministically with the lowest-index unselected features.
    /// A rate >= 1.0 selects everything without touching the RNG.
    fn generate_feature_mask(&mut self, n_features: usize) {
        self.feature_mask.clear();

        if self.config.feature_subsample_rate >= 1.0 {
            self.feature_mask.extend(0..n_features);
        } else {
            let target_count =
                crate::math::ceil((n_features as f64) * self.config.feature_subsample_rate)
                    as usize;
            let target_count = target_count.max(1).min(n_features);

            let n_words = n_features.div_ceil(64);
            self.feature_mask_bits.clear();
            self.feature_mask_bits.resize(n_words, 0u64);

            // Independent Bernoulli draw per feature.
            for i in 0..n_features {
                let r = xorshift64(&mut self.rng_state);
                let p = (r as f64) / (u64::MAX as f64);
                if p < self.config.feature_subsample_rate {
                    self.feature_mask.push(i);
                    self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                }
            }

            // Deterministic top-up so at least `target_count` features
            // are always evaluated.
            if self.feature_mask.len() < target_count {
                for i in 0..n_features {
                    if self.feature_mask.len() >= target_count {
                        break;
                    }
                    if self.feature_mask_bits[i / 64] & (1u64 << (i % 64)) == 0 {
                        self.feature_mask.push(i);
                        self.feature_mask_bits[i / 64] |= 1u64 << (i % 64);
                    }
                }
            }
        }
    }

    /// Try to split `leaf_id`; returns `true` if a split was performed.
    ///
    /// Pipeline: depth gate → grace-period gate → column subsample →
    /// one split candidate per selected feature (categorical bins are
    /// ordered by gradient/hessian ratio first, Fisher-style) → monotone
    /// constraint filter → optional adaptive-depth penalty → Hoeffding
    /// bound test → materialize the split and fresh child states.
    fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
        let depth = self.arena.get_depth(leaf_id);

        // Adaptive depth doubles the hard ceiling; the penalty term below
        // is then what actually limits growth.
        let hard_ceiling = if self.config.adaptive_depth.is_some() {
            self.config.max_depth.saturating_mul(2)
        } else {
            self.config.max_depth
        };
        let at_max_depth = depth as usize >= hard_ceiling;

        if at_max_depth {
            // At the ceiling, splitting is only re-considered every
            // `split_reeval_interval` samples (or never, if unset).
            // NOTE(review): when the interval is set, a leaf at the
            // ceiling may still split and exceed `max_depth` — confirm
            // this re-evaluation behavior is intended.
            match self.config.split_reeval_interval {
                None => return false,
                Some(interval) => {
                    let state = match self
                        .leaf_states
                        .get(leaf_id.0 as usize)
                        .and_then(|o| o.as_ref())
                    {
                        Some(s) => s,
                        None => return false,
                    };
                    let sample_count = self.arena.get_sample_count(leaf_id);
                    // `sample_count` is monotone, so this cannot underflow.
                    if sample_count - state.last_reeval_count < interval as u64 {
                        return false;
                    }
                }
            }
        }

        let n_features = match self.n_features {
            Some(n) => n,
            None => return false,
        };

        let sample_count = self.arena.get_sample_count(leaf_id);
        if sample_count < self.config.grace_period as u64 {
            return false;
        }

        self.generate_feature_mask(n_features);

        // With decay enabled, pending decay must be applied before the
        // histograms are read for split evaluation.
        if self.config.leaf_decay_alpha.is_some() {
            if let Some(state) = self
                .leaf_states
                .get_mut(leaf_id.0 as usize)
                .and_then(|o| o.as_mut())
            {
                if let Some(ref mut histograms) = state.histograms {
                    histograms.materialize_decay();
                }
            }
        }

        let state = match self
            .leaf_states
            .get(leaf_id.0 as usize)
            .and_then(|o| o.as_ref())
        {
            Some(s) => s,
            None => return false,
        };

        let histograms = match &state.histograms {
            Some(h) => h,
            None => return false,
        };

        let feature_types = &self.config.feature_types;
        // (feature index, candidate, optional categorical bin ordering)
        let mut candidates: Vec<(usize, SplitCandidate, Option<Vec<usize>>)> = Vec::new();

        for &feat_idx in &self.feature_mask {
            if feat_idx >= histograms.n_features() {
                continue;
            }
            let hist = &histograms.histograms[feat_idx];
            let total_grad = hist.total_gradient();
            let total_hess = hist.total_hessian();

            let is_categorical = feature_types
                .as_ref()
                .is_some_and(|ft| feat_idx < ft.len() && ft[feat_idx] == FeatureType::Categorical);

            if is_categorical {
                let n_bins = hist.grad_sums.len();
                if n_bins < 2 {
                    continue;
                }

                // Keep only bins with non-negligible hessian mass, then
                // order them by gradient/hessian ratio so the optimal
                // subset split reduces to a 1-D threshold search.
                let mut bin_order: Vec<usize> = (0..n_bins)
                    .filter(|&i| math::abs(hist.hess_sums[i]) > 1e-15)
                    .collect();

                if bin_order.len() < 2 {
                    continue;
                }

                bin_order.sort_by(|&a, &b| {
                    let ratio_a = hist.grad_sums[a] / hist.hess_sums[a];
                    let ratio_b = hist.grad_sums[b] / hist.hess_sums[b];
                    ratio_a
                        .partial_cmp(&ratio_b)
                        .unwrap_or(core::cmp::Ordering::Equal)
                });

                let sorted_grads: Vec<f64> = bin_order.iter().map(|&i| hist.grad_sums[i]).collect();
                let sorted_hess: Vec<f64> = bin_order.iter().map(|&i| hist.hess_sums[i]).collect();

                if let Some(candidate) = self.split_criterion.evaluate(
                    &sorted_grads,
                    &sorted_hess,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, Some(bin_order)));
                }
            } else {
                if let Some(candidate) = self.split_criterion.evaluate(
                    &hist.grad_sums,
                    &hist.hess_sums,
                    total_grad,
                    total_hess,
                    self.config.gamma,
                    self.config.lambda,
                ) {
                    candidates.push((feat_idx, candidate, None));
                }
            }
        }

        // Drop candidates whose child weights would violate a monotone
        // constraint (+1: non-decreasing, -1: non-increasing, 0: free).
        if let Some(ref mc) = self.config.monotone_constraints {
            candidates.retain(|(feat_idx, candidate, _)| {
                if *feat_idx >= mc.len() {
                    return true;
                }
                let constraint = mc[*feat_idx];
                if constraint == 0 {
                    return true;
                }

                let left_val =
                    leaf_weight(candidate.left_grad, candidate.left_hess, self.config.lambda);
                let right_val = leaf_weight(
                    candidate.right_grad,
                    candidate.right_hess,
                    self.config.lambda,
                );

                if constraint > 0 {
                    left_val <= right_val
                } else {
                    left_val >= right_val
                }
            });
        }

        if candidates.is_empty() {
            return false;
        }

        // Best candidate first (descending gain).
        candidates.sort_by(|a, b| {
            b.1.gain
                .partial_cmp(&a.1.gain)
                .unwrap_or(core::cmp::Ordering::Equal)
        });

        let best_gain = candidates[0].1.gain;
        let second_best_gain = if candidates.len() > 1 {
            candidates[1].1.gain
        } else {
            0.0
        };

        // Adaptive-depth (CIC-style) penalty: require the gain to exceed a
        // complexity cost that shrinks with effective sample size and
        // grows with gradient variance and feature count.
        if let Some(cir_factor) = self.config.adaptive_depth {
            let n = sample_count as f64;
            if n > 1.0 {
                // Decay caps the information content at 1/(1-alpha).
                let effective_n = match self.config.leaf_decay_alpha {
                    Some(alpha) => n.min(1.0 / (1.0 - alpha)),
                    None => n,
                };

                let grad_var = self
                    .leaf_states
                    .get(leaf_id.0 as usize)
                    .and_then(|o| o.as_ref())
                    .map(|leaf_state| {
                        if leaf_state.clip_grad_count > 1 {
                            leaf_state.clip_grad_m2 / (leaf_state.clip_grad_count as f64 - 1.0)
                        } else {
                            // No clipping stats: fall back to a crude
                            // proxy built from the mean gradient.
                            let mean_grad = leaf_state.grad_sum / leaf_state.hess_sum.max(1.0);
                            mean_grad * mean_grad + 1.0
                        }
                    })
                    .unwrap_or(1.0);

                let n_feat = self.n_features.unwrap_or(1) as f64;
                let penalty = cir_factor * grad_var / effective_n * n_feat;

                if best_gain <= penalty {
                    return false;
                }
            }
        }

        // Hoeffding bound: epsilon = sqrt(R² ln(1/delta) / 2n). Split when
        // the best-vs-second gap clears epsilon, or when epsilon has
        // shrunk below TAU (tie-break for near-equal candidates).
        let r_squared = 1.0;
        let n = sample_count as f64;
        let effective_n = match self.config.leaf_decay_alpha {
            Some(alpha) => n.min(1.0 / (1.0 - alpha)),
            None => n,
        };
        let ln_inv_delta = math::ln(1.0 / self.config.delta);
        let epsilon = math::sqrt(r_squared * ln_inv_delta / (2.0 * effective_n));

        let gap = best_gain - second_best_gain;
        if gap <= epsilon && epsilon >= TAU {
            // Declined: remember when a ceiling leaf was last re-evaluated
            // so the interval throttle above works.
            if at_max_depth {
                if let Some(state) = self
                    .leaf_states
                    .get_mut(leaf_id.0 as usize)
                    .and_then(|o| o.as_mut())
                {
                    state.last_reeval_count = sample_count;
                }
            }
            return false;
        }

        let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];

        // Credit the realized gain to the feature (importance tracking).
        if best_feat_idx < self.split_gains.len() {
            self.split_gains[best_feat_idx] += best_candidate.gain;
        }

        let best_hist = &histograms.histograms[best_feat_idx];

        let left_value = leaf_weight(
            best_candidate.left_grad,
            best_candidate.left_hess,
            self.config.lambda,
        );
        let right_value = leaf_weight(
            best_candidate.right_grad,
            best_candidate.right_hess,
            self.config.lambda,
        );

        let (left_id, right_id) = if let Some(ref order) = fisher_order {
            // Categorical split: the first `bin_idx + 1` bins of the
            // Fisher ordering go left; encode them as a 64-bit mask.
            let mut mask: u64 = 0;
            for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
                if sorted_pos < 64 {
                    mask |= 1u64 << sorted_pos;
                }
            }

            self.arena.split_leaf_categorical(
                leaf_id,
                best_feat_idx as u32,
                0.0,
                left_value,
                right_value,
                mask,
            )
        } else {
            // Numeric split: the threshold is the bin's upper edge.
            let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
                best_hist.edges.edges[best_candidate.bin_idx]
            } else {
                f64::MAX
            };

            self.arena.split_leaf(
                leaf_id,
                best_feat_idx as u32,
                threshold,
                left_value,
                right_value,
            )
        };

        // The parent is no longer a leaf: take its state so its histograms
        // and model can seed the children.
        let parent_state = self
            .leaf_states
            .get_mut(leaf_id.0 as usize)
            .and_then(|o| o.take());
        let nf = n_features;

        let max_child = left_id.0.max(right_id.0) as usize;
        if self.leaf_states.len() <= max_child {
            self.leaf_states.resize_with(max_child + 1, || None);
        }

        if let Some(parent) = parent_state {
            if let Some(parent_hists) = parent.histograms {
                // Children inherit the parent's bin edges (statistics
                // start from zero) and warm-cloned models.
                let edges_per_feature: Vec<BinEdges> = parent_hists
                    .histograms
                    .iter()
                    .map(|h| h.edges.clone())
                    .collect();

                let left_hists = LeafHistograms::new(&edges_per_feature);
                let right_hists = LeafHistograms::new(&edges_per_feature);

                let ft = self.config.feature_types.as_deref();
                let child_binners_l = make_binners(nf, ft);
                let child_binners_r = make_binners(nf, ft);

                let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());

                let left_state = LeafState {
                    histograms: Some(left_hists),
                    binners: child_binners_l,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: left_model,
                };

                let right_state = LeafState {
                    histograms: Some(right_hists),
                    binners: child_binners_r,
                    bins_ready: true,
                    grad_sum: 0.0,
                    hess_sum: 0.0,
                    last_reeval_count: 0,
                    clip_grad_mean: 0.0,
                    clip_grad_m2: 0.0,
                    clip_grad_count: 0,
                    output_mean: 0.0,
                    output_m2: 0.0,
                    output_count: 0,
                    leaf_model: right_model,
                };

                self.leaf_states[left_id.0 as usize] = Some(left_state);
                self.leaf_states[right_id.0 as usize] = Some(right_state);
            } else {
                // Parent never built histograms: children restart the
                // warm-up phase but keep warm-cloned models.
                let ft = self.config.feature_types.as_deref();
                let mut ls = LeafState::new_with_types(nf, ft);
                ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[left_id.0 as usize] = Some(ls);
                let mut rs = LeafState::new_with_types(nf, ft);
                rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
                self.leaf_states[right_id.0 as usize] = Some(rs);
            }
        } else {
            // No parent state at all: fully fresh children with new models.
            let ft = self.config.feature_types.as_deref();
            let mut ls = LeafState::new_with_types(nf, ft);
            ls.leaf_model = self.make_leaf_model(left_id);
            self.leaf_states[left_id.0 as usize] = Some(ls);
            let mut rs = LeafState::new_with_types(nf, ft);
            rs.leaf_model = self.make_leaf_model(right_id);
            self.leaf_states[right_id.0 as usize] = Some(rs);
        }

        true
    }
}
1361
1362impl StreamingTree for HoeffdingTree {
1363 fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
1368 self.samples_seen += 1;
1369
1370 let n_features = if let Some(n) = self.n_features {
1372 n
1373 } else {
1374 let n = features.len();
1375 self.n_features = Some(n);
1376 self.split_gains.resize(n, 0.0);
1377
1378 if let Some(state) = self
1380 .leaf_states
1381 .get_mut(self.root.0 as usize)
1382 .and_then(|o| o.as_mut())
1383 {
1384 state.binners = make_binners(n, self.config.feature_types.as_deref());
1385 }
1386 n
1387 };
1388
1389 debug_assert_eq!(
1390 features.len(),
1391 n_features,
1392 "feature count mismatch: got {} but expected {}",
1393 features.len(),
1394 n_features,
1395 );
1396
1397 let leaf_id = self.route_to_leaf(features);
1399
1400 self.arena.increment_sample_count(leaf_id);
1402 let sample_count = self.arena.get_sample_count(leaf_id);
1403
1404 let idx = leaf_id.0 as usize;
1406 if self.leaf_states.len() <= idx {
1407 self.leaf_states.resize_with(idx + 1, || None);
1408 }
1409 if self.leaf_states[idx].is_none() {
1410 self.leaf_states[idx] = Some(LeafState::new_with_types(
1411 n_features,
1412 self.config.feature_types.as_deref(),
1413 ));
1414 }
1415 let state = self.leaf_states[idx].as_mut().unwrap();
1416
1417 let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
1419 clip_gradient(state, gradient, sigma)
1420 } else {
1421 gradient
1422 };
1423
1424 if !state.bins_ready {
1426 for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
1428 binner.observe(val);
1429 }
1430
1431 if let Some(alpha) = self.config.leaf_decay_alpha {
1433 state.grad_sum = alpha * state.grad_sum + gradient;
1434 state.hess_sum = alpha * state.hess_sum + hessian;
1435 } else {
1436 state.grad_sum += gradient;
1437 state.hess_sum += hessian;
1438 }
1439
1440 let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1442 self.arena.set_leaf_value(leaf_id, lw);
1443
1444 if self.config.adaptive_leaf_bound.is_some() {
1446 update_output_stats(state, lw, self.config.leaf_decay_alpha);
1447 }
1448
1449 if let Some(ref mut model) = state.leaf_model {
1451 model.update(features, gradient, hessian, self.config.lambda);
1452 }
1453
1454 if sample_count >= self.config.grace_period as u64 {
1456 let edges_per_feature: Vec<BinEdges> = state
1457 .binners
1458 .iter()
1459 .map(|b| b.compute_edges(self.config.n_bins))
1460 .collect();
1461
1462 let mut histograms = LeafHistograms::new(&edges_per_feature);
1463
1464 if let Some(alpha) = self.config.leaf_decay_alpha {
1471 histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1472 } else {
1473 histograms.accumulate(features, gradient, hessian);
1474 }
1475
1476 state.histograms = Some(histograms);
1477 state.bins_ready = true;
1478 }
1479
1480 return;
1481 }
1482
1483 if let Some(ref mut histograms) = state.histograms {
1485 if let Some(alpha) = self.config.leaf_decay_alpha {
1486 histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1487 } else {
1488 histograms.accumulate(features, gradient, hessian);
1489 }
1490 }
1491
1492 if let Some(alpha) = self.config.leaf_decay_alpha {
1494 state.grad_sum = alpha * state.grad_sum + gradient;
1495 state.hess_sum = alpha * state.hess_sum + hessian;
1496 } else {
1497 state.grad_sum += gradient;
1498 state.hess_sum += hessian;
1499 }
1500 let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1501 self.arena.set_leaf_value(leaf_id, lw);
1502
1503 if self.config.adaptive_leaf_bound.is_some() {
1505 update_output_stats(state, lw, self.config.leaf_decay_alpha);
1506 }
1507
1508 if let Some(ref mut model) = state.leaf_model {
1510 model.update(features, gradient, hessian, self.config.lambda);
1511 }
1512
1513 if sample_count % (self.config.grace_period as u64) == 0 {
1516 self.attempt_split(leaf_id);
1517 }
1518 }
1519
1520 fn predict(&self, features: &[f64]) -> f64 {
1525 let leaf_id = self.route_to_leaf(features);
1526 self.leaf_prediction(leaf_id, features)
1527 }
1528
1529 #[inline]
1531 fn n_leaves(&self) -> usize {
1532 self.arena.n_leaves()
1533 }
1534
1535 #[inline]
1537 fn n_samples_seen(&self) -> u64 {
1538 self.samples_seen
1539 }
1540
1541 fn reset(&mut self) {
1543 self.arena.reset();
1544 let root = self.arena.add_leaf(0);
1545 self.root = root;
1546 self.leaf_states.clear();
1547
1548 let n_features = self.n_features.unwrap_or(0);
1550 self.leaf_states.resize_with(root.0 as usize + 1, || None);
1551 let mut root_state =
1552 LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
1553 root_state.leaf_model = self.make_leaf_model(root);
1554 self.leaf_states[root.0 as usize] = Some(root_state);
1555
1556 self.samples_seen = 0;
1557 self.feature_mask.clear();
1558 self.feature_mask_bits.clear();
1559 self.rng_state = self.config.seed;
1560 self.split_gains.iter_mut().for_each(|g| *g = 0.0);
1561 }
1562
    /// Accumulated split gains, one entry per slot in `self.split_gains`.
    // NOTE(review): presumably indexed by feature id (zeroed in `reset`) —
    // confirm against the code that writes into it.
    fn split_gains(&self) -> &[f64] {
        &self.split_gains
    }
1566
1567 fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
1568 let leaf_id = self.route_to_leaf(features);
1569 let value = self.leaf_prediction(leaf_id, features);
1570 if let Some(state) = self
1571 .leaf_states
1572 .get(leaf_id.0 as usize)
1573 .and_then(|o| o.as_ref())
1574 {
1575 let variance = 1.0 / (state.hess_sum + self.config.lambda);
1577 (value, variance)
1578 } else {
1579 (value, f64::INFINITY)
1580 }
1581 }
1582}
1583
// Manual member-wise clone of the whole tree (deep-copies arena, leaf
// states, masks, and gain accumulators; RNG state is carried over so the
// copy continues the same random sequence).
// NOTE(review): every field is cloned verbatim, so `#[derive(Clone)]` on the
// struct would likely suffice — confirm the struct definition permits it.
impl Clone for HoeffdingTree {
    fn clone(&self) -> Self {
        Self {
            arena: self.arena.clone(),
            root: self.root,
            config: self.config.clone(),
            leaf_states: self.leaf_states.clone(),
            n_features: self.n_features,
            samples_seen: self.samples_seen,
            split_criterion: self.split_criterion,
            feature_mask: self.feature_mask.clone(),
            feature_mask_bits: self.feature_mask_bits.clone(),
            rng_state: self.rng_state,
            split_gains: self.split_gains.clone(),
        }
    }
}
1601
// SAFETY: asserted manually, presumably because `LeafState::leaf_model`
// holds a `Box<dyn LeafModel>` whose trait object is not automatically
// `Send`/`Sync`.
// NOTE(review): this is only sound if every `LeafModel` implementation is
// safe to move/share across threads (or the trait gains `Send + Sync`
// supertraits, which would make these impls unnecessary) — confirm before
// relying on cross-thread use.
unsafe impl Send for HoeffdingTree {}
unsafe impl Sync for HoeffdingTree {}
1607
1608#[cfg(test)]
1609mod tests {
1610 use super::*;
1611 use crate::tree::builder::TreeConfig;
1612 use crate::tree::StreamingTree;
1613
    // Thin wrapper so the tests exercise the same xorshift64 generator the
    // tree itself uses.
    fn test_xorshift(state: &mut u64) -> u64 {
        xorshift64(state)
    }
1618
1619 fn test_rand_f64(state: &mut u64) -> f64 {
1621 let r = test_xorshift(state);
1622 (r as f64) / (u64::MAX as f64)
1623 }
1624
1625 #[test]
1629 fn single_sample_predict_not_nan() {
1630 let config = TreeConfig::new().grace_period(10);
1631 let mut tree = HoeffdingTree::new(config);
1632
1633 let features = vec![1.0, 2.0, 3.0];
1634 tree.train_one(&features, -0.5, 1.0);
1635
1636 let pred = tree.predict(&features);
1637 assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
1638 assert!(
1639 pred.is_finite(),
1640 "prediction should be finite, got {}",
1641 pred
1642 );
1643
1644 assert!((pred - 0.25).abs() < 1e-10, "expected ~0.25, got {}", pred);
1647 }
1648
    #[test]
    fn linear_signal_rmse_improves() {
        // One online pass over y = 2x + noise must beat the trivial
        // always-zero predictor (whose MSE is mean(y^2)).
        let config = TreeConfig::new()
            .max_depth(4)
            .n_bins(32)
            .grace_period(50)
            .lambda(0.1)
            .gamma(0.0)
            .delta(1e-3);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 12345;

        let n_train = 1000;
        let mut features_all: Vec<f64> = Vec::with_capacity(n_train);
        let mut targets: Vec<f64> = Vec::with_capacity(n_train);

        for _ in 0..n_train {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let noise = (test_rand_f64(&mut rng_state) - 0.5) * 0.5;
            let y = 2.0 * x + noise;
            features_all.push(x);
            targets.push(y);
        }

        // Baseline: RMSE of predicting 0 everywhere.
        let initial_mse: f64 = targets.iter().map(|y| y * y).sum::<f64>() / n_train as f64;
        let initial_rmse = initial_mse.sqrt();

        // Squared-error loss: gradient = pred - y, hessian = 1.
        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let gradient = pred - targets[i];
            let hessian = 1.0;
            tree.train_one(&feat, gradient, hessian);
        }

        // Re-score the training set after the single online pass.
        let mut post_mse = 0.0;
        for i in 0..n_train {
            let feat = [features_all[i]];
            let pred = tree.predict(&feat);
            let err = pred - targets[i];
            post_mse += err * err;
        }
        post_mse /= n_train as f64;
        let post_rmse = post_mse.sqrt();

        assert!(
            post_rmse < initial_rmse,
            "RMSE should decrease after training: initial={:.4}, post={:.4}",
            initial_rmse,
            post_rmse,
        );
    }
1718
1719 #[test]
1723 fn no_splits_before_grace_period() {
1724 let grace = 100;
1725 let config = TreeConfig::new()
1726 .grace_period(grace)
1727 .max_depth(4)
1728 .n_bins(16)
1729 .delta(1e-1); let mut tree = HoeffdingTree::new(config);
1732 let mut rng_state: u64 = 99999;
1733
1734 for _ in 0..(grace - 1) {
1736 let x = test_rand_f64(&mut rng_state) * 10.0;
1737 let y = 2.0 * x;
1738 let feat = [x];
1739 let pred = tree.predict(&feat);
1740 tree.train_one(&feat, pred - y, 1.0);
1741 }
1742
1743 assert_eq!(
1744 tree.n_leaves(),
1745 1,
1746 "should be exactly 1 leaf before grace_period, got {}",
1747 tree.n_leaves()
1748 );
1749 }
1750
1751 #[test]
1755 fn respects_max_depth() {
1756 let max_depth = 3;
1757 let config = TreeConfig::new()
1758 .max_depth(max_depth)
1759 .grace_period(20)
1760 .n_bins(16)
1761 .lambda(0.01)
1762 .gamma(0.0)
1763 .delta(1e-1); let mut tree = HoeffdingTree::new(config);
1766 let mut rng_state: u64 = 7777;
1767
1768 for _ in 0..5000 {
1770 let x = test_rand_f64(&mut rng_state) * 10.0;
1771 let y = if x < 2.5 {
1772 -5.0
1773 } else if x < 5.0 {
1774 -1.0
1775 } else if x < 7.5 {
1776 1.0
1777 } else {
1778 5.0
1779 };
1780 let feat = [x];
1781 let pred = tree.predict(&feat);
1782 tree.train_one(&feat, pred - y, 1.0);
1783 }
1784
1785 let max_leaves = 1usize << max_depth;
1787 assert!(
1788 tree.n_leaves() <= max_leaves,
1789 "tree has {} leaves, but max_depth={} allows at most {}",
1790 tree.n_leaves(),
1791 max_depth,
1792 max_leaves,
1793 );
1794 }
1795
    #[test]
    fn reset_returns_to_single_leaf() {
        // Grow a tree, then reset(): structure, counters, and predictions
        // must all return to the freshly-constructed state.
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .delta(1e-1);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 54321;

        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x - 5.0;
            let feat = [x];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pre_reset_samples = tree.n_samples_seen();
        assert!(pre_reset_samples > 0);

        tree.reset();

        assert_eq!(
            tree.n_leaves(),
            1,
            "after reset, should have exactly 1 leaf"
        );
        assert_eq!(
            tree.n_samples_seen(),
            0,
            "after reset, samples_seen should be 0"
        );

        // An untrained root leaf predicts 0.
        let pred = tree.predict(&[5.0]);
        assert!(
            pred.abs() < 1e-10,
            "prediction after reset should be ~0.0, got {}",
            pred
        );
    }
1843
    #[test]
    fn multi_feature_training() {
        // Smoke test: training on a 2-feature linear target keeps
        // predictions finite and counts every sample exactly once.
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(4)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 11111;

        for _ in 0..1000 {
            let x0 = test_rand_f64(&mut rng_state) * 5.0;
            let x1 = test_rand_f64(&mut rng_state) * 5.0;
            let y = x0 + 2.0 * x1;
            let feat = [x0, x1];
            let pred = tree.predict(&feat);
            tree.train_one(&feat, pred - y, 1.0);
        }

        let pred = tree.predict(&[2.5, 2.5]);
        assert!(
            pred.is_finite(),
            "multi-feature prediction should be finite"
        );
        assert_eq!(tree.n_samples_seen(), 1000);
    }
1877
    #[test]
    fn feature_subsampling_works() {
        // With only half the features eligible per split, training must
        // still complete and produce finite predictions.
        let config = TreeConfig::new()
            .grace_period(30)
            .max_depth(3)
            .n_bins(16)
            .lambda(0.1)
            .delta(1e-2)
            .feature_subsample_rate(0.5);

        let mut tree = HoeffdingTree::new(config);
        let mut rng_state: u64 = 33333;

        for _ in 0..1000 {
            let feats: Vec<f64> = (0..5)
                .map(|_| test_rand_f64(&mut rng_state) * 10.0)
                .collect();
            let y: f64 = feats.iter().sum();
            let pred = tree.predict(&feats);
            tree.train_one(&feats, pred - y, 1.0);
        }

        let pred = tree.predict(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(pred.is_finite(), "subsampled prediction should be finite");
    }
1907
1908 #[test]
1912 fn xorshift64_deterministic() {
1913 let mut s1: u64 = 42;
1914 let mut s2: u64 = 42;
1915
1916 let seq1: Vec<u64> = (0..100).map(|_| xorshift64(&mut s1)).collect();
1917 let seq2: Vec<u64> = (0..100).map(|_| xorshift64(&mut s2)).collect();
1918
1919 assert_eq!(seq1, seq2, "xorshift64 should be deterministic");
1920
1921 for &v in &seq1 {
1923 assert_ne!(v, 0, "xorshift64 should never produce 0 with non-zero seed");
1924 }
1925 }
1926
1927 #[test]
1931 fn ewma_leaf_decay_recent_data_dominates() {
1932 let alpha = (-(2.0_f64.ln()) / 50.0).exp();
1934 let config = TreeConfig::new()
1935 .grace_period(20)
1936 .max_depth(4)
1937 .n_bins(16)
1938 .lambda(1.0)
1939 .leaf_decay_alpha(alpha);
1940 let mut tree = HoeffdingTree::new(config);
1941
1942 for _ in 0..1000 {
1944 let pred = tree.predict(&[1.0, 2.0]);
1945 let grad = pred - 1.0; tree.train_one(&[1.0, 2.0], grad, 1.0);
1947 }
1948
1949 for _ in 0..100 {
1951 let pred = tree.predict(&[1.0, 2.0]);
1952 let grad = pred - 5.0;
1953 tree.train_one(&[1.0, 2.0], grad, 1.0);
1954 }
1955
1956 let pred = tree.predict(&[1.0, 2.0]);
1957 assert!(
1960 pred > 2.0,
1961 "EWMA should let recent data (target=5.0) pull prediction above 2.0, got {}",
1962 pred,
1963 );
1964 }
1965
    #[test]
    fn ewma_disabled_matches_traditional() {
        // Without leaf_decay_alpha the tree uses plain cumulative sums;
        // training on a linear target must still yield a finite prediction.
        let config_no_ewma = TreeConfig::new()
            .grace_period(20)
            .max_depth(4)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config_no_ewma);

        let mut rng_state: u64 = 99999;
        for _ in 0..200 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 3.0 * x + 1.0;
            let pred = tree.predict(&[x]);
            tree.train_one(&[x], pred - y, 1.0);
        }

        let pred = tree.predict(&[5.0]);
        assert!(
            pred.is_finite(),
            "prediction without EWMA should be finite, got {}",
            pred
        );
    }
1993
1994 #[test]
1998 fn split_reeval_at_max_depth() {
1999 let config = TreeConfig::new()
2000 .grace_period(20)
2001 .max_depth(2) .n_bins(16)
2003 .lambda(1.0)
2004 .split_reeval_interval(50);
2005 let mut tree = HoeffdingTree::new(config);
2006
2007 let mut rng_state: u64 = 54321;
2008 for _ in 0..2000 {
2010 let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2011 let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
2012 let y = 2.0 * x1 + 3.0 * x2;
2013 let pred = tree.predict(&[x1, x2]);
2014 tree.train_one(&[x1, x2], pred - y, 1.0);
2015 }
2016
2017 let leaves = tree.n_leaves();
2021 assert!(
2022 leaves >= 4,
2023 "split re-eval should allow growth beyond max_depth=2 cap (4 leaves), got {}",
2024 leaves,
2025 );
2026 }
2027
    #[test]
    fn split_reeval_disabled_matches_traditional() {
        // Without split_reeval_interval, max_depth=2 caps the tree at
        // 2^2 = 4 leaves no matter how much data arrives.
        let config = TreeConfig::new()
            .grace_period(20)
            .max_depth(2)
            .n_bins(16)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng_state: u64 = 77777;
        for _ in 0..2000 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let x2 = test_rand_f64(&mut rng_state) * 10.0 - 5.0;
            let y = 2.0 * x1 + 3.0 * x2;
            let pred = tree.predict(&[x1, x2]);
            tree.train_one(&[x1, x2], pred - y, 1.0);
        }

        let leaves = tree.n_leaves();
        assert!(
            leaves <= 4,
            "without re-eval, max_depth=2 should cap at 4 leaves, got {}",
            leaves,
        );
    }
2058
2059 #[test]
2063 fn gradient_clipping_clamps_outliers() {
2064 let config = TreeConfig::new()
2065 .grace_period(20)
2066 .max_depth(2)
2067 .n_bins(16)
2068 .gradient_clip_sigma(2.0);
2069
2070 let mut tree = HoeffdingTree::new(config);
2071
2072 let mut rng_state = 42u64;
2074 for _ in 0..50 {
2075 let x = test_rand_f64(&mut rng_state) * 2.0;
2076 let grad = x * 0.1; tree.train_one(&[x], grad, 1.0);
2078 }
2079
2080 let pred_before = tree.predict(&[1.0]);
2081
2082 tree.train_one(&[1.0], 1000.0, 1.0);
2084
2085 let pred_after = tree.predict(&[1.0]);
2086
2087 let delta = (pred_after - pred_before).abs();
2091 assert!(
2092 delta < 100.0,
2093 "gradient clipping should limit impact of outlier, but prediction changed by {}",
2094 delta,
2095 );
2096 }
2097
2098 #[test]
2102 fn clip_gradient_welford_tracks_stats() {
2103 let mut state = LeafState::new(1);
2104
2105 for i in 0..20 {
2107 let grad = 1.0 + (i as f64) * 0.1; let clipped = clip_gradient(&mut state, grad, 3.0);
2109 assert!(
2111 (clipped - grad).abs() < 1e-10,
2112 "normal gradients should not be clipped at 3-sigma"
2113 );
2114 }
2115
2116 let clipped = clip_gradient(&mut state, 100.0, 3.0);
2118 assert!(
2119 clipped < 100.0,
2120 "extreme outlier should be clipped, got {}",
2121 clipped,
2122 );
2123 assert!(
2124 clipped > 0.0,
2125 "clipped value should be positive, got {}",
2126 clipped,
2127 );
2128 }
2129
2130 #[test]
2134 fn clip_gradient_warmup_no_clipping() {
2135 let mut state = LeafState::new(1);
2136
2137 for i in 0..9 {
2139 let val = if i == 8 { 1000.0 } else { 1.0 };
2140 let clipped = clip_gradient(&mut state, val, 2.0);
2141 assert_eq!(clipped, val, "warmup should not clip");
2142 }
2143 }
2144
2145 #[test]
2149 fn adaptive_bound_warmup_returns_max() {
2150 let mut state = LeafState::new(1);
2151 for i in 0..9 {
2153 update_output_stats(&mut state, 0.5 + i as f64 * 0.01, None);
2154 }
2155 let bound = adaptive_bound(&state, 3.0, None);
2156 assert_eq!(bound, f64::MAX, "warmup should return f64::MAX");
2157 }
2158
2159 #[test]
2163 fn adaptive_bound_tightens_after_warmup() {
2164 let mut state = LeafState::new(1);
2165 for i in 0..20 {
2167 let w = 0.3 + (i as f64 - 10.0) * 0.01; update_output_stats(&mut state, w, None);
2169 }
2170 let bound = adaptive_bound(&state, 3.0, None);
2171 assert!(
2173 bound < 1.0,
2174 "3-sigma bound on outputs ~0.3 should be < 1.0, got {}",
2175 bound,
2176 );
2177 assert!(bound > 0.2, "bound should be > |mean|, got {}", bound,);
2178 }
2179
2180 #[test]
2184 fn adaptive_bound_clamps_outlier_leaf() {
2185 let mut state = LeafState::new(1);
2186 for _ in 0..20 {
2188 update_output_stats(&mut state, 0.3, None);
2189 }
2190 let bound = adaptive_bound(&state, 3.0, None);
2191 let clamped = (2.9_f64).clamp(-bound, bound);
2193 assert!(
2194 clamped < 2.9,
2195 "2.9 should be clamped by adaptive bound {}, got {}",
2196 bound,
2197 clamped,
2198 );
2199 }
2200
2201 #[test]
2205 fn adaptive_bound_with_decay_adapts() {
2206 let alpha = 0.95; let mut state = LeafState::new(1);
2208
2209 for _ in 0..30 {
2211 update_output_stats(&mut state, 0.3, Some(alpha));
2212 }
2213 let bound_phase1 = adaptive_bound(&state, 3.0, Some(alpha));
2214
2215 for _ in 0..100 {
2217 update_output_stats(&mut state, 2.0, Some(alpha));
2218 }
2219 let bound_phase2 = adaptive_bound(&state, 3.0, Some(alpha));
2220
2221 assert!(
2223 bound_phase2 > bound_phase1,
2224 "EWMA bound should adapt: phase1={}, phase2={}",
2225 bound_phase1,
2226 bound_phase2,
2227 );
2228 }
2229
2230 #[test]
2234 fn adaptive_bound_disabled_by_default() {
2235 let config = TreeConfig::default();
2236 assert!(
2237 config.adaptive_leaf_bound.is_none(),
2238 "adaptive_leaf_bound should default to None",
2239 );
2240 }
2241
2242 #[test]
2246 fn adaptive_bound_warmup_falls_back_to_global() {
2247 let mut state = LeafState::new(1);
2248 for _ in 0..5 {
2250 update_output_stats(&mut state, 0.3, None);
2251 }
2252 let bound = adaptive_bound(&state, 3.0, None);
2253 assert_eq!(bound, f64::MAX, "warmup should yield f64::MAX");
2254 }
2256
2257 #[test]
2261 fn monotonic_constraint_splits_respected() {
2262 let config = TreeConfig::new()
2265 .grace_period(30)
2266 .max_depth(4)
2267 .n_bins(16)
2268 .monotone_constraints(vec![1]); let mut tree = HoeffdingTree::new(config);
2271
2272 let mut rng_state = 42u64;
2273 for _ in 0..500 {
2274 let x = test_rand_f64(&mut rng_state) * 10.0;
2275 let grad = x * 0.5 - 2.5;
2277 tree.train_one(&[x], grad, 1.0);
2278 }
2279
2280 let pred_low = tree.predict(&[0.0]);
2283 let pred_mid = tree.predict(&[5.0]);
2284 let pred_high = tree.predict(&[10.0]);
2285
2286 assert!(
2288 pred_low <= pred_mid + 1e-10 && pred_mid <= pred_high + 1e-10,
2289 "monotonic +1 violated: pred(0)={}, pred(5)={}, pred(10)={}",
2290 pred_low,
2291 pred_mid,
2292 pred_high,
2293 );
2294 }
2295
2296 #[test]
2300 fn predict_with_variance_finite() {
2301 let config = TreeConfig::new().grace_period(10);
2302 let mut tree = HoeffdingTree::new(config);
2303
2304 for i in 0..30 {
2306 let x = i as f64 * 0.1;
2307 tree.train_one(&[x], x - 1.0, 1.0);
2308 }
2309
2310 let (value, variance) = tree.predict_with_variance(&[1.0]);
2311 assert!(value.is_finite(), "value should be finite");
2312 assert!(variance.is_finite(), "variance should be finite");
2313 assert!(variance > 0.0, "variance should be positive");
2314 }
2315
2316 #[test]
2320 fn predict_with_variance_decreases_with_data() {
2321 let config = TreeConfig::new().grace_period(10);
2322 let mut tree = HoeffdingTree::new(config);
2323
2324 for i in 0..20 {
2326 tree.train_one(&[1.0], 0.5, 1.0);
2327 if i == 0 {
2328 continue;
2329 }
2330 }
2331 let (_, var_20) = tree.predict_with_variance(&[1.0]);
2332
2333 for _ in 0..200 {
2335 tree.train_one(&[1.0], 0.5, 1.0);
2336 }
2337 let (_, var_220) = tree.predict_with_variance(&[1.0]);
2338
2339 assert!(
2340 var_220 < var_20,
2341 "variance should decrease with more data: var@20={} vs var@220={}",
2342 var_20,
2343 var_220,
2344 );
2345 }
2346
    #[test]
    fn predict_smooth_matches_hard_at_small_bandwidth() {
        // As bandwidth -> 0 the soft routing should collapse to hard routing,
        // so the two predictions must nearly coincide.
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            let hess = 1.0;
            tree.train_one(&features, grad, hess);
        }

        let features = vec![5.0, 2.5];
        let hard = tree.predict(&features);
        let smooth = tree.predict_smooth(&features, 0.001);
        assert!(
            (hard - smooth).abs() < 0.1,
            "smooth with tiny bandwidth should approximate hard: hard={}, smooth={}",
            hard,
            smooth,
        );
    }
2382
    #[test]
    fn predict_smooth_is_continuous() {
        // A tiny input nudge must not cause a jump when soft routing with a
        // moderate bandwidth is used (unlike hard routing at a threshold).
        let config = TreeConfig::new()
            .max_depth(3)
            .n_bins(16)
            .grace_period(20)
            .lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let mut rng = 42u64;
        for _ in 0..500 {
            let x = test_rand_f64(&mut rng) * 10.0;
            let y = 2.0 * x + 1.0;
            let features = vec![x, x * 0.5];
            let pred = tree.predict(&features);
            let grad = pred - y;
            tree.train_one(&features, grad, 1.0);
        }

        let bandwidth = 1.0;
        let base = tree.predict_smooth(&[5.0, 2.5], bandwidth);
        let nudged = tree.predict_smooth(&[5.001, 2.5], bandwidth);
        let diff = (base - nudged).abs();
        assert!(
            diff < 0.1,
            "smooth prediction should be continuous: base={}, nudged={}, diff={}",
            base,
            nudged,
            diff,
        );
    }
2419
    #[test]
    fn leaf_grad_hess_returns_sums() {
        // grace_period(100) keeps the root unsplit, so all 10 samples land
        // in the root leaf: grad = 10 * -0.5 = -5, hess = 10 * 1.0 = 10.
        let config = TreeConfig::new().grace_period(100).lambda(1.0);
        let mut tree = HoeffdingTree::new(config);

        let features = vec![1.0, 2.0, 3.0];

        for _ in 0..10 {
            tree.train_one(&features, -0.5, 1.0);
        }

        let root = tree.root();
        let (grad, hess) = tree
            .leaf_grad_hess(root)
            .expect("root should have leaf state");

        assert!(
            (grad - (-5.0)).abs() < 1e-10,
            "grad_sum should be -5.0, got {}",
            grad
        );
        assert!(
            (hess - 10.0).abs() < 1e-10,
            "hess_sum should be 10.0, got {}",
            hess
        );
    }
2454
2455 #[test]
2456 fn leaf_grad_hess_returns_none_for_invalid_node() {
2457 let config = TreeConfig::new();
2458 let tree = HoeffdingTree::new(config);
2459
2460 assert!(tree.leaf_grad_hess(NodeId::NONE).is_none());
2462 assert!(tree.leaf_grad_hess(NodeId(999)).is_none());
2464 }
2465
    #[test]
    fn adaptive_depth_none_identical_to_static_max_depth() {
        // With adaptive_depth unset, the tree must grow exactly like a
        // static-max_depth tree fed the identical sample stream.
        let config_static = TreeConfig::new()
            .max_depth(3)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3);

        let config_none = TreeConfig::new()
            .max_depth(3)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3);

        assert!(config_none.adaptive_depth.is_none());

        let mut tree_static = HoeffdingTree::new(config_static);
        let mut tree_none = HoeffdingTree::new(config_none);

        // One shared RNG keeps both trees on the exact same data stream.
        let mut rng_state: u64 = 42;
        for _ in 0..2000 {
            let x = test_rand_f64(&mut rng_state) * 10.0;
            let y = 2.0 * x;
            let feat = [x, x * 0.5, x * x];
            let pred_s = tree_static.predict(&feat);
            let pred_n = tree_none.predict(&feat);
            tree_static.train_one(&feat, pred_s - y, 1.0);
            tree_none.train_one(&feat, pred_n - y, 1.0);
        }

        assert_eq!(
            tree_static.arena().n_nodes(),
            tree_none.arena().n_nodes(),
            "adaptive_depth=None should produce identical tree structure to static max_depth"
        );
    }
2512
2513 #[test]
2514 fn adaptive_depth_few_samples_stays_shallow() {
2515 let config = TreeConfig::new()
2518 .max_depth(6)
2519 .n_bins(32)
2520 .grace_period(20)
2521 .lambda(0.1)
2522 .delta(1e-3)
2523 .adaptive_depth(7.5);
2524
2525 let mut tree = HoeffdingTree::new(config);
2526 let mut rng_state: u64 = 99;
2527
2528 for _ in 0..100 {
2530 let x = test_rand_f64(&mut rng_state) * 10.0;
2531 let noise = (test_rand_f64(&mut rng_state) - 0.5) * 20.0; let y = 0.1 * x + noise;
2533 let feat = [x, test_rand_f64(&mut rng_state) * 5.0];
2534 let pred = tree.predict(&feat);
2535 tree.train_one(&feat, pred - y, 1.0);
2536 }
2537
2538 let n_nodes = tree.arena().n_nodes();
2542 assert!(
2543 n_nodes <= 15,
2544 "adaptive_depth with few noisy samples should keep tree shallow, got {} nodes",
2545 n_nodes
2546 );
2547 }
2548
    #[test]
    fn adaptive_depth_many_samples_grows_deeper() {
        // Same clean linear signal, same seed: 5000 samples should shrink
        // the adaptive-depth penalty enough to grow at least as many nodes
        // as 200 samples do.
        let config_few = TreeConfig::new()
            .max_depth(6)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3)
            .adaptive_depth(7.5);

        let config_many = TreeConfig::new()
            .max_depth(6)
            .n_bins(32)
            .grace_period(20)
            .lambda(0.1)
            .delta(1e-3)
            .adaptive_depth(7.5);

        let mut tree_few = HoeffdingTree::new(config_few);
        let mut tree_many = HoeffdingTree::new(config_many);

        let mut rng_state: u64 = 42;

        for _ in 0..200 {
            let x1 = test_rand_f64(&mut rng_state) * 10.0;
            let x2 = test_rand_f64(&mut rng_state) * 5.0;
            let y = 3.0 * x1 + 2.0 * x2;
            let feat = [x1, x2];
            let pred = tree_few.predict(&feat);
            tree_few.train_one(&feat, pred - y, 1.0);
        }

        // Fresh RNG with the same seed replays the same distribution.
        let mut rng_state2: u64 = 42;
        for _ in 0..5000 {
            let x1 = test_rand_f64(&mut rng_state2) * 10.0;
            let x2 = test_rand_f64(&mut rng_state2) * 5.0;
            let y = 3.0 * x1 + 2.0 * x2;
            let feat = [x1, x2];
            let pred = tree_many.predict(&feat);
            tree_many.train_one(&feat, pred - y, 1.0);
        }

        assert!(
            tree_many.arena().n_nodes() >= tree_few.arena().n_nodes(),
            "more samples should allow deeper growth: many={} vs few={}",
            tree_many.arena().n_nodes(),
            tree_few.arena().n_nodes()
        );
    }
2605
2606 #[test]
2607 fn adaptive_depth_penalty_scales_inversely_with_n() {
2608 let cir_factor: f64 = 7.5;
2614 let grad_var: f64 = 1.0;
2615 let n_feat: f64 = 2.0;
2616
2617 let penalty_100 = cir_factor * grad_var / 100.0 * n_feat;
2618 let penalty_1000 = cir_factor * grad_var / 1000.0 * n_feat;
2619
2620 assert!(
2621 (penalty_100 - 0.15).abs() < 1e-10,
2622 "penalty at n=100 should be 0.15, got {}",
2623 penalty_100
2624 );
2625 assert!(
2626 (penalty_1000 - 0.015).abs() < 1e-10,
2627 "penalty at n=1000 should be 0.015, got {}",
2628 penalty_1000
2629 );
2630 assert!(
2631 penalty_100 > penalty_1000,
2632 "penalty should decrease with more samples"
2633 );
2634
2635 let gain = 0.05;
2637 assert!(gain <= penalty_100, "gain should fail CIR at n=100");
2638 assert!(gain > penalty_1000, "gain should pass CIR at n=1000");
2639 }
2640
2641 #[test]
2642 fn adaptive_depth_hard_ceiling_respected() {
2643 let config = TreeConfig::new()
2645 .max_depth(3)
2646 .n_bins(32)
2647 .grace_period(10)
2648 .lambda(0.01)
2649 .gamma(0.0)
2650 .delta(1e-2) .adaptive_depth(0.001); let mut tree = HoeffdingTree::new(config);
2654 let mut rng_state: u64 = 777;
2655
2656 for _ in 0..10000 {
2658 let x = test_rand_f64(&mut rng_state) * 100.0;
2659 let y = x * x; let feat = [x];
2661 let pred = tree.predict(&feat);
2662 tree.train_one(&feat, pred - y, 1.0);
2663 }
2664
2665 let max_leaves = 1usize << 6;
2667 let n_leaves = tree.arena().n_leaves();
2668 assert!(
2669 n_leaves <= max_leaves,
2670 "tree should respect hard ceiling of max_depth*2=6 ({} max leaves), got {} leaves",
2671 max_leaves,
2672 n_leaves
2673 );
2674 }
2675}