1pub mod leaf;
22pub mod split_logic;
23
24use alloc::vec;
25use alloc::vec::Vec;
26
27use crate::histogram::bins::LeafHistograms;
28use crate::math;
29use crate::tree::builder::TreeConfig;
30use crate::tree::leaf_model::LeafModelType;
31use crate::tree::node::{NodeId, TreeArena};
32use crate::tree::split::{leaf_weight, XGBoostGain};
33use crate::tree::StreamingTree;
34
35use leaf::{adaptive_bound, clip_gradient, make_binners, update_output_stats, LeafState};
36
37pub struct HoeffdingTree {
49 pub(crate) arena: TreeArena,
51
52 pub(crate) root: NodeId,
54
55 pub(crate) config: TreeConfig,
57
58 pub(crate) leaf_states: Vec<Option<LeafState>>,
61
62 pub(crate) n_features: Option<usize>,
64
65 pub(crate) samples_seen: u64,
67
68 pub(crate) split_criterion: XGBoostGain,
70
71 pub(crate) feature_mask: Vec<usize>,
73
74 pub(crate) feature_mask_bits: Vec<u64>,
77
78 pub(crate) rng_state: u64,
80
81 pub(crate) split_gains: Vec<f64>,
84
85 pub(crate) node_bandwidths: Vec<f64>,
88}
89
90impl HoeffdingTree {
91 pub fn new(config: TreeConfig) -> Self {
96 let mut arena = TreeArena::new();
97 let root = arena.add_leaf(0);
98
99 let mut leaf_states = vec![None; root.0 as usize + 1];
103 let root_model = match config.leaf_model_type {
104 LeafModelType::ClosedForm => None,
105 _ => Some(config.leaf_model_type.create(config.seed, config.delta)),
106 };
107 leaf_states[root.0 as usize] = Some(LeafState {
108 histograms: None,
109 binners: Vec::new(),
110 bins_ready: false,
111 grad_sum: 0.0,
112 hess_sum: 0.0,
113 last_reeval_count: 0,
114 clip_grad_mean: 0.0,
115 clip_grad_m2: 0.0,
116 clip_grad_count: 0,
117 output_mean: 0.0,
118 output_m2: 0.0,
119 output_count: 0,
120 leaf_model: root_model,
121 });
122
123 let seed = config.seed;
124 Self {
125 arena,
126 root,
127 config,
128 leaf_states,
129 n_features: None,
130 samples_seen: 0,
131 split_criterion: XGBoostGain::default(),
132 feature_mask: Vec::new(),
133 feature_mask_bits: Vec::new(),
134 rng_state: seed,
135 split_gains: Vec::new(),
136 node_bandwidths: Vec::new(),
137 }
138 }
139
140 fn make_leaf_model(
146 &self,
147 node: NodeId,
148 ) -> Option<alloc::boxed::Box<dyn crate::tree::leaf_model::LeafModel>> {
149 match self.config.leaf_model_type {
150 LeafModelType::ClosedForm => None,
151 _ => Some(
152 self.config
153 .leaf_model_type
154 .create(self.config.seed ^ (node.0 as u64), self.config.delta),
155 ),
156 }
157 }
158
159 pub fn from_arena(
168 config: TreeConfig,
169 arena: TreeArena,
170 n_features: Option<usize>,
171 samples_seen: u64,
172 rng_state: u64,
173 ) -> Self {
174 let root = if arena.n_nodes() > 0 {
175 NodeId(0)
176 } else {
177 let mut arena_mut = arena;
179 let root = arena_mut.add_leaf(0);
180 return Self {
181 arena: arena_mut,
182 root,
183 config: config.clone(),
184 leaf_states: {
185 let mut v = vec![None; root.0 as usize + 1];
186 v[root.0 as usize] = Some(LeafState::new(n_features.unwrap_or(0)));
187 v
188 },
189 n_features,
190 samples_seen,
191 split_criterion: XGBoostGain::default(),
192 feature_mask: Vec::new(),
193 feature_mask_bits: Vec::new(),
194 rng_state,
195 split_gains: vec![0.0; n_features.unwrap_or(0)],
196 node_bandwidths: Vec::new(),
197 };
198 };
199
200 let nf = n_features.unwrap_or(0);
202 let mut leaf_states: Vec<Option<LeafState>> = vec![None; arena.n_nodes()];
203 for (i, slot) in leaf_states.iter_mut().enumerate() {
204 if arena.is_leaf[i] {
205 *slot = Some(LeafState::new(nf));
206 }
207 }
208
209 Self {
210 arena,
211 root,
212 config,
213 leaf_states,
214 n_features,
215 samples_seen,
216 split_criterion: XGBoostGain::default(),
217 feature_mask: Vec::new(),
218 feature_mask_bits: Vec::new(),
219 rng_state,
220 split_gains: vec![0.0; nf],
221 node_bandwidths: Vec::new(),
222 }
223 }
224
225 #[inline]
227 pub fn root(&self) -> NodeId {
228 self.root
229 }
230
231 #[inline]
233 pub fn arena(&self) -> &TreeArena {
234 &self.arena
235 }
236
237 #[inline]
239 pub fn tree_config(&self) -> &TreeConfig {
240 &self.config
241 }
242
243 #[inline]
245 pub fn n_features(&self) -> Option<usize> {
246 self.n_features
247 }
248
249 #[inline]
251 pub fn rng_state(&self) -> u64 {
252 self.rng_state
253 }
254
255 #[inline]
265 pub fn leaf_grad_hess(&self, node: NodeId) -> Option<(f64, f64)> {
266 self.leaf_states
267 .get(node.0 as usize)
268 .and_then(|o| o.as_ref())
269 .map(|state| (state.grad_sum, state.hess_sum))
270 }
271
272 pub(crate) fn route_to_leaf(&self, features: &[f64]) -> NodeId {
274 let mut current = self.root;
275 while !self.arena.is_leaf(current) {
276 let feat_idx = self.arena.get_feature_idx(current) as usize;
277 current = if let Some(mask) = self.arena.get_categorical_mask(current) {
278 let cat_val = features[feat_idx] as u64;
285 if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
286 self.arena.get_left(current)
287 } else {
288 self.arena.get_right(current)
289 }
290 } else {
291 let threshold = self.arena.get_threshold(current);
293 if features[feat_idx] <= threshold {
294 self.arena.get_left(current)
295 } else {
296 self.arena.get_right(current)
297 }
298 };
299 }
300 current
301 }
302
303 #[inline]
308 fn leaf_prediction(&self, leaf_id: NodeId, features: &[f64]) -> f64 {
309 let (raw, leaf_bound) = if let Some(state) = self
310 .leaf_states
311 .get(leaf_id.0 as usize)
312 .and_then(|o| o.as_ref())
313 {
314 if let Some(min_h) = self.config.min_hessian_sum {
316 if state.hess_sum < min_h {
317 return 0.0;
318 }
319 }
320 let val = if let Some(ref model) = state.leaf_model {
321 model.predict(features)
322 } else if state.hess_sum != 0.0 {
323 leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda)
324 } else {
325 self.arena.leaf_value[leaf_id.0 as usize]
326 };
327
328 let bound = self
330 .config
331 .adaptive_leaf_bound
332 .map(|k| adaptive_bound(state, k, self.config.leaf_decay_alpha));
333
334 (val, bound)
335 } else {
336 (0.0, None)
337 };
338
339 if let Some(bound) = leaf_bound {
341 if bound < f64::MAX {
342 return raw.clamp(-bound, bound);
343 }
344 }
345 if let Some(max) = self.config.max_leaf_output {
346 raw.clamp(-max, max)
347 } else {
348 raw
349 }
350 }
351
352 pub fn predict_smooth(&self, features: &[f64], bandwidth: f64) -> f64 {
368 self.predict_smooth_recursive(self.root, features, bandwidth)
369 }
370
371 pub fn predict_smooth_auto(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
376 self.predict_smooth_auto_recursive(self.root, features, bandwidths)
377 }
378
379 pub fn predict_interpolated(&self, features: &[f64]) -> f64 {
391 let mut current = self.root;
392 let mut parent = None;
393 while !self.arena.is_leaf(current) {
394 parent = Some(current);
395 let feat_idx = self.arena.get_feature_idx(current) as usize;
396 current = if let Some(mask) = self.arena.get_categorical_mask(current) {
397 let cat_val = features[feat_idx] as u64;
398 if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
399 self.arena.get_left(current)
400 } else {
401 self.arena.get_right(current)
402 }
403 } else {
404 let threshold = self.arena.get_threshold(current);
405 if features[feat_idx] <= threshold {
406 self.arena.get_left(current)
407 } else {
408 self.arena.get_right(current)
409 }
410 };
411 }
412
413 let leaf_pred = self.leaf_prediction(current, features);
414
415 let parent_id = match parent {
417 Some(p) => p,
418 None => return leaf_pred,
419 };
420
421 let parent_pred = self.leaf_prediction(parent_id, features);
423
424 let leaf_hess = self
426 .leaf_states
427 .get(current.0 as usize)
428 .and_then(|o| o.as_ref())
429 .map(|s| s.hess_sum)
430 .unwrap_or(0.0);
431
432 let alpha = leaf_hess / (leaf_hess + self.config.lambda);
433 alpha * leaf_pred + (1.0 - alpha) * parent_pred
434 }
435
436 pub fn predict_sibling_interpolated(&self, features: &[f64], bandwidths: &[f64]) -> f64 {
452 self.predict_sibling_recursive(self.root, features, bandwidths)
453 }
454
455 fn predict_sibling_recursive(&self, node: NodeId, features: &[f64], bandwidths: &[f64]) -> f64 {
456 if self.arena.is_leaf(node) {
457 return self.leaf_prediction(node, features);
458 }
459
460 let feat_idx = self.arena.get_feature_idx(node) as usize;
461 let left = self.arena.get_left(node);
462 let right = self.arena.get_right(node);
463
464 if let Some(mask) = self.arena.get_categorical_mask(node) {
466 let cat_val = features[feat_idx] as u64;
467 return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
468 self.predict_sibling_recursive(left, features, bandwidths)
469 } else {
470 self.predict_sibling_recursive(right, features, bandwidths)
471 };
472 }
473
474 let threshold = self.arena.get_threshold(node);
475 let margin = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
476
477 if !margin.is_finite() || margin <= 0.0 {
479 return if features[feat_idx] <= threshold {
480 self.predict_sibling_recursive(left, features, bandwidths)
481 } else {
482 self.predict_sibling_recursive(right, features, bandwidths)
483 };
484 }
485
486 let dist = features[feat_idx] - threshold;
487
488 if dist < -margin {
489 self.predict_sibling_recursive(left, features, bandwidths)
491 } else if dist > margin {
492 self.predict_sibling_recursive(right, features, bandwidths)
494 } else {
495 let t = (dist + margin) / (2.0 * margin); let left_pred = self.predict_sibling_recursive(left, features, bandwidths);
498 let right_pred = self.predict_sibling_recursive(right, features, bandwidths);
499 (1.0 - t) * left_pred + t * right_pred
500 }
501 }
502
503 pub fn collect_split_thresholds_per_feature(&self) -> Vec<Vec<f64>> {
508 let n = self.n_features.unwrap_or(0);
509 let mut thresholds: Vec<Vec<f64>> = vec![Vec::new(); n];
510
511 for i in 0..self.arena.n_nodes() {
512 if !self.arena.is_leaf[i] && self.arena.categorical_mask[i].is_none() {
513 let feat_idx = self.arena.feature_idx[i] as usize;
514 if feat_idx < n {
515 thresholds[feat_idx].push(self.arena.threshold[i]);
516 }
517 }
518 }
519
520 thresholds
521 }
522
523 fn compute_node_bandwidth(&self, node: NodeId, all_thresholds: &[Vec<f64>]) -> f64 {
525 let feat_idx = self.arena.get_feature_idx(node) as usize;
526 let threshold = self.arena.get_threshold(node);
527
528 let thresholds = if feat_idx < all_thresholds.len() {
529 &all_thresholds[feat_idx]
530 } else {
531 return f64::INFINITY;
532 };
533
534 let below = thresholds.iter().rev().find(|&&t| t < threshold - 1e-15);
536 let above = thresholds.iter().find(|&&t| t > threshold + 1e-15);
537
538 match (below, above) {
539 (Some(&b), Some(&a)) => (threshold - b).min(a - threshold),
540 (Some(&b), None) => threshold - b,
541 (None, Some(&a)) => a - threshold,
542 (None, None) => f64::INFINITY,
543 }
544 }
545
546 pub fn recompute_bandwidths(&mut self) {
548 let n = self.arena.n_nodes();
549 self.node_bandwidths.resize(n, f64::INFINITY);
550
551 let mut all_thresholds = self.collect_split_thresholds_per_feature();
553 for v in &mut all_thresholds {
554 v.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
555 }
556
557 for i in 0..n {
558 let nid = NodeId(i as u32);
559 if !self.arena.is_leaf(nid) {
560 self.node_bandwidths[i] = self.compute_node_bandwidth(nid, &all_thresholds);
561 } else {
562 self.node_bandwidths[i] = f64::INFINITY;
563 }
564 }
565 }
566
567 pub fn predict_soft_routed(&self, features: &[f64]) -> f64 {
570 self.predict_soft_recursive(self.root, features)
571 }
572
573 fn predict_soft_recursive(&self, node: NodeId, features: &[f64]) -> f64 {
574 if self.arena.is_leaf(node) {
575 return self.leaf_prediction(node, features);
576 }
577
578 let feat_idx = self.arena.get_feature_idx(node) as usize;
579 let left = self.arena.get_left(node);
580 let right = self.arena.get_right(node);
581
582 if let Some(mask) = self.arena.get_categorical_mask(node) {
584 let cat_val = features[feat_idx] as u64;
585 return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
586 self.predict_soft_recursive(left, features)
587 } else {
588 self.predict_soft_recursive(right, features)
589 };
590 }
591
592 let threshold = self.arena.get_threshold(node);
593 let margin = self
594 .node_bandwidths
595 .get(node.0 as usize)
596 .copied()
597 .unwrap_or(f64::INFINITY);
598
599 let left_pred = self.predict_soft_recursive(left, features);
600 let right_pred = self.predict_soft_recursive(right, features);
601
602 if !margin.is_finite() || margin <= 0.0 {
604 let dist = features[feat_idx] - threshold;
605 let scale = math::abs(threshold) * 0.01 + 1e-10;
606 let z = (-dist / scale).clamp(-500.0, 500.0);
607 let t = 1.0 / (1.0 + math::exp(z));
608 return (1.0 - t) * left_pred + t * right_pred;
609 }
610
611 let dist = features[feat_idx] - threshold;
613 let t = ((dist + margin) / (2.0 * margin)).clamp(0.0, 1.0);
614 (1.0 - t) * left_pred + t * right_pred
615 }
616
617 fn predict_smooth_recursive(&self, node: NodeId, features: &[f64], bandwidth: f64) -> f64 {
619 if self.arena.is_leaf(node) {
620 return self.leaf_prediction(node, features);
622 }
623
624 let feat_idx = self.arena.get_feature_idx(node) as usize;
625 let left = self.arena.get_left(node);
626 let right = self.arena.get_right(node);
627
628 if let Some(mask) = self.arena.get_categorical_mask(node) {
630 let cat_val = features[feat_idx] as u64;
631 return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
632 self.predict_smooth_recursive(left, features, bandwidth)
633 } else {
634 self.predict_smooth_recursive(right, features, bandwidth)
635 };
636 }
637
638 let threshold = self.arena.get_threshold(node);
640 let z = (threshold - features[feat_idx]) / bandwidth;
641 let alpha = 1.0 / (1.0 + math::exp(-z));
642
643 let left_pred = self.predict_smooth_recursive(left, features, bandwidth);
644 let right_pred = self.predict_smooth_recursive(right, features, bandwidth);
645
646 alpha * left_pred + (1.0 - alpha) * right_pred
647 }
648
649 fn predict_smooth_auto_recursive(
651 &self,
652 node: NodeId,
653 features: &[f64],
654 bandwidths: &[f64],
655 ) -> f64 {
656 if self.arena.is_leaf(node) {
657 return self.leaf_prediction(node, features);
658 }
659
660 let feat_idx = self.arena.get_feature_idx(node) as usize;
661 let left = self.arena.get_left(node);
662 let right = self.arena.get_right(node);
663
664 if let Some(mask) = self.arena.get_categorical_mask(node) {
666 let cat_val = features[feat_idx] as u64;
667 return if cat_val < 64 && (mask >> cat_val) & 1 == 1 {
668 self.predict_smooth_auto_recursive(left, features, bandwidths)
669 } else {
670 self.predict_smooth_auto_recursive(right, features, bandwidths)
671 };
672 }
673
674 let threshold = self.arena.get_threshold(node);
675 let bw = bandwidths.get(feat_idx).copied().unwrap_or(f64::INFINITY);
676
677 if !bw.is_finite() {
679 return if features[feat_idx] <= threshold {
680 self.predict_smooth_auto_recursive(left, features, bandwidths)
681 } else {
682 self.predict_smooth_auto_recursive(right, features, bandwidths)
683 };
684 }
685
686 let z = (threshold - features[feat_idx]) / bw;
688 let alpha = 1.0 / (1.0 + math::exp(-z));
689
690 let left_pred = self.predict_smooth_auto_recursive(left, features, bandwidths);
691 let right_pred = self.predict_smooth_auto_recursive(right, features, bandwidths);
692
693 alpha * left_pred + (1.0 - alpha) * right_pred
694 }
695
696 pub(crate) fn attempt_split(&mut self, leaf_id: NodeId) -> bool {
700 let depth = self.arena.get_depth(leaf_id);
701
702 let hard_ceiling = if self.config.adaptive_depth.is_some() {
705 self.config.max_depth.saturating_mul(2)
706 } else {
707 self.config.max_depth
708 };
709 let at_max_depth = depth as usize >= hard_ceiling;
710
711 if at_max_depth {
712 match self.config.split_reeval_interval {
715 None => return false,
716 Some(interval) => {
717 let state = match self
718 .leaf_states
719 .get(leaf_id.0 as usize)
720 .and_then(|o| o.as_ref())
721 {
722 Some(s) => s,
723 None => return false,
724 };
725 let sample_count = self.arena.get_sample_count(leaf_id);
726 if sample_count - state.last_reeval_count < interval as u64 {
727 return false;
728 }
729 }
731 }
732 }
733
734 let n_features = match self.n_features {
735 Some(n) => n,
736 None => return false,
737 };
738
739 let sample_count = self.arena.get_sample_count(leaf_id);
740 if sample_count < self.config.grace_period as u64 {
741 return false;
742 }
743
744 let (feature_mask, feature_mask_bits) = split_logic::generate_feature_mask(
746 core::mem::take(&mut self.feature_mask),
747 core::mem::take(&mut self.feature_mask_bits),
748 &mut self.rng_state,
749 self.config.feature_subsample_rate,
750 n_features,
751 );
752 self.feature_mask = feature_mask;
753 self.feature_mask_bits = feature_mask_bits;
754
755 if self.config.leaf_decay_alpha.is_some() {
757 if let Some(state) = self
758 .leaf_states
759 .get_mut(leaf_id.0 as usize)
760 .and_then(|o| o.as_mut())
761 {
762 if let Some(ref mut histograms) = state.histograms {
763 histograms.materialize_decay();
764 }
765 }
766 }
767
768 let state = match self
770 .leaf_states
771 .get(leaf_id.0 as usize)
772 .and_then(|o| o.as_ref())
773 {
774 Some(s) => s,
775 None => return false,
776 };
777
778 let histograms = match &state.histograms {
779 Some(h) => h,
780 None => return false,
781 };
782
783 let ctx = split_logic::private::SplitContext {
784 config: &self.config,
785 n_features: self.n_features,
786 n_feature_mask: &self.feature_mask,
787 split_criterion: &self.split_criterion,
788 rng_state: &mut self.rng_state,
789 };
790
791 let candidates = split_logic::private::evaluate_split_candidates(
792 histograms,
793 self.config.feature_types.as_deref(),
794 &ctx,
795 );
796
797 if candidates.is_empty() {
798 return false;
799 }
800
801 let best_gain = candidates[0].1.gain;
802 let second_best_gain = if candidates.len() > 1 {
803 candidates[1].1.gain
804 } else {
805 0.0
806 };
807
808 let ctx = split_logic::private::SplitContext {
810 config: &self.config,
811 n_features: self.n_features,
812 n_feature_mask: &self.feature_mask,
813 split_criterion: &self.split_criterion,
814 rng_state: &mut self.rng_state,
815 };
816
817 if !split_logic::private::should_split_hoeffding(
818 best_gain,
819 second_best_gain,
820 sample_count,
821 &ctx,
822 ) {
823 if at_max_depth {
824 if let Some(state) = self
825 .leaf_states
826 .get_mut(leaf_id.0 as usize)
827 .and_then(|o| o.as_mut())
828 {
829 state.last_reeval_count = sample_count;
830 }
831 }
832 return false;
833 }
834
835 let (best_feat_idx, ref best_candidate, ref fisher_order) = candidates[0];
837
838 if best_feat_idx < self.split_gains.len() {
840 self.split_gains[best_feat_idx] += best_candidate.gain;
841 }
842
843 let best_hist = &histograms.histograms[best_feat_idx];
844
845 let left_value = leaf_weight(
846 best_candidate.left_grad,
847 best_candidate.left_hess,
848 self.config.lambda,
849 );
850 let right_value = leaf_weight(
851 best_candidate.right_grad,
852 best_candidate.right_hess,
853 self.config.lambda,
854 );
855
856 let (left_id, right_id) = if let Some(ref order) = fisher_order {
858 let mut mask: u64 = 0;
859 for &sorted_pos in order.iter().take(best_candidate.bin_idx + 1) {
860 if sorted_pos < 64 {
861 mask |= 1u64 << sorted_pos;
862 }
863 }
864
865 self.arena.split_leaf_categorical(
866 leaf_id,
867 best_feat_idx as u32,
868 0.0,
869 left_value,
870 right_value,
871 mask,
872 )
873 } else {
874 let threshold = if best_candidate.bin_idx < best_hist.edges.edges.len() {
875 best_hist.edges.edges[best_candidate.bin_idx]
876 } else {
877 f64::MAX
878 };
879
880 self.arena.split_leaf(
881 leaf_id,
882 best_feat_idx as u32,
883 threshold,
884 left_value,
885 right_value,
886 )
887 };
888
889 let parent_state = self
890 .leaf_states
891 .get_mut(leaf_id.0 as usize)
892 .and_then(|o| o.take());
893 let nf = n_features;
894
895 let max_child = left_id.0.max(right_id.0) as usize;
897 if self.leaf_states.len() <= max_child {
898 self.leaf_states.resize_with(max_child + 1, || None);
899 }
900
901 if let Some(parent) = parent_state {
902 if let Some(parent_hists) = parent.histograms {
903 let edges_per_feature: Vec<crate::histogram::BinEdges> = parent_hists
904 .histograms
905 .iter()
906 .map(|h| h.edges.clone())
907 .collect();
908
909 let left_hists = LeafHistograms::new(&edges_per_feature);
910 let right_hists = LeafHistograms::new(&edges_per_feature);
911
912 let ft = self.config.feature_types.as_deref();
913 let child_binners_l = make_binners(nf, ft);
914 let child_binners_r = make_binners(nf, ft);
915
916 let left_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
917 let right_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
918
919 let left_state = LeafState {
920 histograms: Some(left_hists),
921 binners: child_binners_l,
922 bins_ready: true,
923 grad_sum: 0.0,
924 hess_sum: 0.0,
925 last_reeval_count: 0,
926 clip_grad_mean: 0.0,
927 clip_grad_m2: 0.0,
928 clip_grad_count: 0,
929 output_mean: 0.0,
930 output_m2: 0.0,
931 output_count: 0,
932 leaf_model: left_model,
933 };
934
935 let right_state = LeafState {
936 histograms: Some(right_hists),
937 binners: child_binners_r,
938 bins_ready: true,
939 grad_sum: 0.0,
940 hess_sum: 0.0,
941 last_reeval_count: 0,
942 clip_grad_mean: 0.0,
943 clip_grad_m2: 0.0,
944 clip_grad_count: 0,
945 output_mean: 0.0,
946 output_m2: 0.0,
947 output_count: 0,
948 leaf_model: right_model,
949 };
950
951 self.leaf_states[left_id.0 as usize] = Some(left_state);
952 self.leaf_states[right_id.0 as usize] = Some(right_state);
953 } else {
954 let ft = self.config.feature_types.as_deref();
955 let mut ls = LeafState::new_with_types(nf, ft);
956 ls.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
957 self.leaf_states[left_id.0 as usize] = Some(ls);
958 let mut rs = LeafState::new_with_types(nf, ft);
959 rs.leaf_model = parent.leaf_model.as_ref().map(|m| m.clone_warm());
960 self.leaf_states[right_id.0 as usize] = Some(rs);
961 }
962 } else {
963 let ft = self.config.feature_types.as_deref();
964 let mut ls = LeafState::new_with_types(nf, ft);
965 ls.leaf_model = self.make_leaf_model(left_id);
966 self.leaf_states[left_id.0 as usize] = Some(ls);
967 let mut rs = LeafState::new_with_types(nf, ft);
968 rs.leaf_model = self.make_leaf_model(right_id);
969 self.leaf_states[right_id.0 as usize] = Some(rs);
970 }
971
972 self.recompute_bandwidths();
973 true
974 }
975}
976
977impl StreamingTree for HoeffdingTree {
978 fn train_one(&mut self, features: &[f64], gradient: f64, hessian: f64) {
979 self.samples_seen += 1;
980
981 let n_features = if let Some(n) = self.n_features {
982 n
983 } else {
984 let n = features.len();
985 self.n_features = Some(n);
986 self.split_gains.resize(n, 0.0);
987
988 if let Some(state) = self
989 .leaf_states
990 .get_mut(self.root.0 as usize)
991 .and_then(|o| o.as_mut())
992 {
993 state.binners = make_binners(n, self.config.feature_types.as_deref());
994 }
995 n
996 };
997
998 debug_assert_eq!(
999 features.len(),
1000 n_features,
1001 "feature count mismatch: got {} but expected {}",
1002 features.len(),
1003 n_features,
1004 );
1005
1006 let leaf_id = self.route_to_leaf(features);
1007 self.arena.increment_sample_count(leaf_id);
1008 let sample_count = self.arena.get_sample_count(leaf_id);
1009
1010 let idx = leaf_id.0 as usize;
1011 if self.leaf_states.len() <= idx {
1012 self.leaf_states.resize_with(idx + 1, || None);
1013 }
1014 if self.leaf_states[idx].is_none() {
1015 self.leaf_states[idx] = Some(LeafState::new_with_types(
1016 n_features,
1017 self.config.feature_types.as_deref(),
1018 ));
1019 }
1020 let state = self.leaf_states[idx].as_mut().unwrap();
1021
1022 let gradient = if let Some(sigma) = self.config.gradient_clip_sigma {
1023 clip_gradient(state, gradient, sigma)
1024 } else {
1025 gradient
1026 };
1027
1028 if !state.bins_ready {
1029 for (binner, &val) in state.binners.iter_mut().zip(features.iter()) {
1030 binner.observe(val);
1031 }
1032
1033 if let Some(alpha) = self.config.leaf_decay_alpha {
1034 state.grad_sum = alpha * state.grad_sum + gradient;
1035 state.hess_sum = alpha * state.hess_sum + hessian;
1036 } else {
1037 state.grad_sum += gradient;
1038 state.hess_sum += hessian;
1039 }
1040
1041 let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1042 self.arena.set_leaf_value(leaf_id, lw);
1043
1044 if self.config.adaptive_leaf_bound.is_some() {
1045 update_output_stats(state, lw, self.config.leaf_decay_alpha);
1046 }
1047
1048 if let Some(ref mut model) = state.leaf_model {
1049 model.update(features, gradient, hessian, self.config.lambda);
1050 }
1051
1052 if sample_count >= self.config.grace_period as u64 {
1053 let edges_per_feature: Vec<crate::histogram::BinEdges> = state
1054 .binners
1055 .iter()
1056 .map(|b| b.compute_edges(self.config.n_bins))
1057 .collect();
1058
1059 let mut histograms = LeafHistograms::new(&edges_per_feature);
1060
1061 if let Some(alpha) = self.config.leaf_decay_alpha {
1062 histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1063 } else {
1064 histograms.accumulate(features, gradient, hessian);
1065 }
1066
1067 state.histograms = Some(histograms);
1068 state.bins_ready = true;
1069 }
1070
1071 return;
1072 }
1073
1074 if let Some(ref mut histograms) = state.histograms {
1075 if let Some(alpha) = self.config.leaf_decay_alpha {
1076 histograms.accumulate_with_decay(features, gradient, hessian, alpha);
1077 } else {
1078 histograms.accumulate(features, gradient, hessian);
1079 }
1080 }
1081
1082 if let Some(alpha) = self.config.leaf_decay_alpha {
1083 state.grad_sum = alpha * state.grad_sum + gradient;
1084 state.hess_sum = alpha * state.hess_sum + hessian;
1085 } else {
1086 state.grad_sum += gradient;
1087 state.hess_sum += hessian;
1088 }
1089 let lw = leaf_weight(state.grad_sum, state.hess_sum, self.config.lambda);
1090 self.arena.set_leaf_value(leaf_id, lw);
1091
1092 if self.config.adaptive_leaf_bound.is_some() {
1093 update_output_stats(state, lw, self.config.leaf_decay_alpha);
1094 }
1095
1096 if let Some(ref mut model) = state.leaf_model {
1097 model.update(features, gradient, hessian, self.config.lambda);
1098 }
1099
1100 if sample_count % (self.config.grace_period as u64) == 0 {
1101 self.attempt_split(leaf_id);
1102 }
1103 }
1104
1105 fn predict(&self, features: &[f64]) -> f64 {
1106 let leaf_id = self.route_to_leaf(features);
1107 self.leaf_prediction(leaf_id, features)
1108 }
1109
1110 #[inline]
1111 fn n_leaves(&self) -> usize {
1112 self.arena.n_leaves()
1113 }
1114
1115 #[inline]
1116 fn n_samples_seen(&self) -> u64 {
1117 self.samples_seen
1118 }
1119
1120 fn reset(&mut self) {
1121 self.arena.reset();
1122 let root = self.arena.add_leaf(0);
1123 self.root = root;
1124 self.leaf_states.clear();
1125
1126 let n_features = self.n_features.unwrap_or(0);
1127 self.leaf_states.resize_with(root.0 as usize + 1, || None);
1128 let mut root_state =
1129 LeafState::new_with_types(n_features, self.config.feature_types.as_deref());
1130 root_state.leaf_model = self.make_leaf_model(root);
1131 self.leaf_states[root.0 as usize] = Some(root_state);
1132
1133 self.samples_seen = 0;
1134 self.feature_mask.clear();
1135 self.feature_mask_bits.clear();
1136 self.rng_state = self.config.seed;
1137 self.split_gains.iter_mut().for_each(|g| *g = 0.0);
1138 self.node_bandwidths.clear();
1139 }
1140
1141 fn split_gains(&self) -> &[f64] {
1142 &self.split_gains
1143 }
1144
1145 fn predict_with_variance(&self, features: &[f64]) -> (f64, f64) {
1146 let leaf_id = self.route_to_leaf(features);
1147 let value = self.leaf_prediction(leaf_id, features);
1148 if let Some(state) = self
1149 .leaf_states
1150 .get(leaf_id.0 as usize)
1151 .and_then(|o| o.as_ref())
1152 {
1153 let variance = 1.0 / (state.hess_sum + self.config.lambda);
1154 (value, variance)
1155 } else {
1156 (value, f64::INFINITY)
1157 }
1158 }
1159}
1160
1161impl Clone for HoeffdingTree {
1162 fn clone(&self) -> Self {
1163 Self {
1164 arena: self.arena.clone(),
1165 root: self.root,
1166 config: self.config.clone(),
1167 leaf_states: self.leaf_states.clone(),
1168 n_features: self.n_features,
1169 samples_seen: self.samples_seen,
1170 split_criterion: self.split_criterion,
1171 feature_mask: self.feature_mask.clone(),
1172 feature_mask_bits: self.feature_mask_bits.clone(),
1173 rng_state: self.rng_state,
1174 split_gains: self.split_gains.clone(),
1175 node_bandwidths: self.node_bandwidths.clone(),
1176 }
1177 }
1178}
1179
1180unsafe impl Send for HoeffdingTree {}
1184unsafe impl Sync for HoeffdingTree {}
1185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::builder::TreeConfig;
    use crate::tree::StreamingTree;

    /// After a single sample with gradient -0.5 and hessian 1.0 the tree must
    /// predict a finite value of about 0.25.
    #[test]
    fn single_sample_predict_not_nan() {
        let mut tree = HoeffdingTree::new(TreeConfig::new().grace_period(10));

        let sample = [1.0, 2.0, 3.0];
        tree.train_one(&sample, -0.5, 1.0);
        let pred = tree.predict(&sample);

        assert!(!pred.is_nan(), "prediction should not be NaN, got {}", pred);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );

        let error = (pred - 0.25).abs();
        assert!(error < 1e-10, "expected ~0.25, got {}", pred);
    }
}