Skip to main content

irithyll_core/tree/
hoeffding_classifier.rs

1//! Hoeffding Tree Classifier for streaming multi-class classification.
2//!
3//! [`HoeffdingTreeClassifier`] is a streaming decision tree that grows
4//! incrementally using the VFDT algorithm (Domingos & Hulten, 2000). Unlike
5//! the existing [`super::hoeffding::HoeffdingTree`] (which is a gradient-based
6//! regressor), this classifier maintains per-leaf class distributions and
7//! splits using information gain.
8//!
9//! # Algorithm
10//!
11//! For each incoming `(features, class_label)` pair:
12//!
13//! 1. Route the sample from root to a leaf via threshold comparisons.
14//! 2. At the leaf, update class counts and per-feature histogram bins.
15//! 3. Once enough samples arrive (grace period), evaluate candidate splits
16//!    using information gain (entropy reduction).
17//! 4. Apply the Hoeffding bound: if the gap between the best and second-best
18//!    gain exceeds `epsilon = sqrt(R^2 * ln(1/delta) / (2n))`, commit the split.
19//! 5. Split the leaf into two children, partitioning class counts by the
20//!    chosen threshold.
21//!
22//! # Key differences from the regressor
23//!
24//! - Leaves store class counts, not gradient/hessian sums.
25//! - Split criterion: Information Gain (reduction in entropy).
26//! - Prediction: majority class or class probability distribution.
27//! - No loss function -- direct classification.
28//!
29//! # Examples
30//!
31//! ```ignore
32//! use irithyll::tree::hoeffding_classifier::HoeffdingClassifierConfig;
33//! use irithyll::tree::hoeffding_classifier::HoeffdingTreeClassifier;
34//!
35//! let config = HoeffdingClassifierConfig::builder()
36//!     .max_depth(6)
37//!     .delta(1e-5)
38//!     .grace_period(50)
39//!     .n_bins(16)
40//!     .build()
41//!     .unwrap();
42//!
43//! let mut tree = HoeffdingTreeClassifier::new(config);
44//!
45//! // Train: class 0 when x[0] < 5, class 1 otherwise
46//! for i in 0..200 {
47//!     let x = (i as f64) / 20.0;
48//!     let class = if x < 5.0 { 0 } else { 1 };
49//!     tree.train_one(&[x], class);
50//! }
51//!
52//! // After training, the tree can predict class labels.
53//! let pred = tree.predict_class(&[2.0]);
54//! assert!(pred == 0 || pred == 1);
55//! ```
56
57use alloc::string::{String, ToString};
58use alloc::vec;
59use alloc::vec::Vec;
60
61use crate::learner::StreamingLearner;
62use crate::math;
63
/// Tie-breaking threshold (tau) for the Hoeffding split test.
///
/// Once the Hoeffding bound `epsilon` has shrunk below this value, the best
/// split is accepted even if its lead over the runner-up does not exceed
/// `epsilon`: the bound is already tight, so the candidates are treated as
/// tied and waiting for more samples would not separate them.
const TAU: f64 = 0.05_f64;
68
69// ---------------------------------------------------------------------------
70// Configuration
71// ---------------------------------------------------------------------------
72
/// Configuration for [`HoeffdingTreeClassifier`].
///
/// Use [`HoeffdingClassifierConfig::builder()`] to construct with defaults
/// and override only the parameters you need. The builder validates the
/// values; the fields are public, so a config assembled by hand bypasses
/// that validation.
#[derive(Debug, Clone)]
pub struct HoeffdingClassifierConfig {
    /// Maximum tree depth.
    pub max_depth: usize,
    /// Hoeffding bound failure probability: a committed split is the right
    /// choice with confidence `1 - delta`. Must lie in the open interval
    /// (0, 1). A smaller value widens the bound, so more evidence is required
    /// and the tree splits less eagerly.
    pub delta: f64,
    /// Minimum samples at a leaf before considering a split.
    pub grace_period: usize,
    /// Number of histogram bins per feature for split evaluation.
    pub n_bins: usize,
    /// Number of features (0 = lazy init from first sample).
    pub n_features: usize,
    /// Maximum number of classes (0 = auto-discover).
    ///
    /// NOTE(review): the classifier currently grows its class vectors when a
    /// label at or above this value arrives, so this acts as an initial size
    /// rather than a hard cap -- confirm the intended semantics.
    pub max_classes: usize,
}

/// Builder for [`HoeffdingClassifierConfig`].
///
/// Every field starts from the defaults listed on
/// [`HoeffdingClassifierConfig::builder()`]. Call `.build()` to validate and
/// finalize.
#[derive(Debug, Clone)]
pub struct HoeffdingClassifierConfigBuilder {
    max_depth: usize,
    delta: f64,
    grace_period: usize,
    n_bins: usize,
    n_features: usize,
    max_classes: usize,
}

impl HoeffdingClassifierConfig {
    /// Create a builder with default parameters.
    ///
    /// Defaults:
    /// - `max_depth`: 10
    /// - `delta`: 1e-7
    /// - `grace_period`: 200
    /// - `n_bins`: 32
    /// - `n_features`: 0 (lazy init)
    /// - `max_classes`: 0 (auto-discover)
    pub fn builder() -> HoeffdingClassifierConfigBuilder {
        HoeffdingClassifierConfigBuilder {
            max_depth: 10,
            delta: 1e-7,
            grace_period: 200,
            n_bins: 32,
            n_features: 0,
            max_classes: 0,
        }
    }
}

impl HoeffdingClassifierConfigBuilder {
    /// Set the maximum tree depth.
    pub fn max_depth(mut self, d: usize) -> Self {
        self.max_depth = d;
        self
    }

    /// Set the Hoeffding bound failure probability `delta` (must be in (0, 1)).
    pub fn delta(mut self, d: f64) -> Self {
        self.delta = d;
        self
    }

    /// Set the minimum samples before split evaluation.
    pub fn grace_period(mut self, g: usize) -> Self {
        self.grace_period = g;
        self
    }

    /// Set the number of histogram bins per feature.
    pub fn n_bins(mut self, b: usize) -> Self {
        self.n_bins = b;
        self
    }

    /// Set the number of features (0 for lazy init).
    pub fn n_features(mut self, f: usize) -> Self {
        self.n_features = f;
        self
    }

    /// Set the maximum number of classes (0 for auto-discover).
    pub fn max_classes(mut self, c: usize) -> Self {
        self.max_classes = c;
        self
    }

    /// Build the configuration, validating all parameters.
    ///
    /// # Errors
    ///
    /// Returns `Err(String)` if:
    /// - `max_depth` is 0
    /// - `delta` is not in (0, 1) (this includes NaN)
    /// - `grace_period` is 0
    /// - `n_bins` is less than 2
    pub fn build(self) -> Result<HoeffdingClassifierConfig, String> {
        let Self {
            max_depth,
            delta,
            grace_period,
            n_bins,
            n_features,
            max_classes,
        } = self;

        if max_depth == 0 {
            return Err("max_depth must be >= 1".to_string());
        }
        // NaN compares false against everything, so the two range checks alone
        // would let it through; reject it explicitly.
        if delta.is_nan() || delta <= 0.0 || delta >= 1.0 {
            return Err("delta must be in (0, 1)".to_string());
        }
        if grace_period == 0 {
            return Err("grace_period must be >= 1".to_string());
        }
        if n_bins < 2 {
            return Err("n_bins must be >= 2".to_string());
        }

        Ok(HoeffdingClassifierConfig {
            max_depth,
            delta,
            grace_period,
            n_bins,
            n_features,
            max_classes,
        })
    }
}
197
198// ---------------------------------------------------------------------------
199// Internal structures
200// ---------------------------------------------------------------------------
201
/// Split-evaluation statistics kept by a leaf.
///
/// Holds the leaf's overall class distribution together with a per-feature,
/// per-bin, per-class histogram from which information gain is computed.
#[derive(Debug, Clone)]
struct LeafStats {
    /// Number of samples of each class recorded in these statistics.
    class_counts: Vec<u64>,

    /// Class counts broken down by feature and histogram bin:
    /// `feature_histograms[feature][bin][class] = count`.
    feature_histograms: Vec<Vec<Vec<u64>>>,

    /// Bin boundaries for each feature; empty until they have been computed.
    bin_boundaries: Vec<Vec<f64>>,

    /// Observed `(min, max)` for each feature, used to place the boundaries.
    feature_ranges: Vec<(f64, f64)>,

    /// Number of samples recorded in these statistics.
    n_samples: u64,
}

impl LeafStats {
    /// Build empty statistics sized for `n_features` features, `n_bins`
    /// histogram bins per feature and `n_classes` classes.
    fn new(n_features: usize, n_bins: usize, n_classes: usize) -> Self {
        let empty_bin = vec![0u64; n_classes];
        let empty_feature = vec![empty_bin; n_bins];

        Self {
            class_counts: vec![0u64; n_classes],
            feature_histograms: vec![empty_feature; n_features],
            bin_boundaries: vec![Vec::new(); n_features],
            // Inverted sentinel range (min = MAX, max = MIN) so that observed
            // values can tighten it from both sides.
            feature_ranges: vec![(f64::MAX, f64::MIN); n_features],
            n_samples: 0,
        }
    }

    /// Grow every per-class vector so it holds at least `n_classes` entries.
    ///
    /// New slots start at zero. Nothing happens when `class_counts` already
    /// has that many entries.
    fn ensure_class_capacity(&mut self, n_classes: usize) {
        if self.class_counts.len() >= n_classes {
            return;
        }
        self.class_counts.resize(n_classes, 0);
        self.feature_histograms
            .iter_mut()
            .flatten()
            .for_each(|bin_counts| bin_counts.resize(n_classes, 0));
    }
}
253
254/// A single node in the classifier tree arena.
255///
256/// Internal (split) nodes have `split_feature` and `split_threshold` set.
257/// Leaf nodes have `leaf_stats` populated for ongoing training.
258/// All nodes maintain `class_counts` for graceful prediction even at split
259/// nodes (useful when traversal ends early due to missing data).
260#[derive(Debug, Clone)]
261struct ClassifierNode {
262    /// Feature index for the split. `None` for leaf nodes.
263    split_feature: Option<usize>,
264
265    /// Threshold value for the split. `None` for leaf nodes.
266    split_threshold: Option<f64>,
267
268    /// Index of the left child in the arena (samples where feature < threshold).
269    left: Option<usize>,
270
271    /// Index of the right child in the arena (samples where feature >= threshold).
272    right: Option<usize>,
273
274    /// Depth of this node in the tree (root = 0).
275    depth: usize,
276
277    /// Per-class sample counts at this node (accumulated during training).
278    class_counts: Vec<u64>,
279
280    /// Total samples routed through this node.
281    n_samples: u64,
282
283    /// Leaf-specific statistics for split evaluation. `None` for split nodes.
284    leaf_stats: Option<LeafStats>,
285}
286
287impl ClassifierNode {
288    /// Create a new leaf node at the given depth.
289    fn new_leaf(depth: usize, n_features: usize, n_bins: usize, n_classes: usize) -> Self {
290        Self {
291            split_feature: None,
292            split_threshold: None,
293            left: None,
294            right: None,
295            depth,
296            class_counts: vec![0u64; n_classes],
297            n_samples: 0,
298            leaf_stats: Some(LeafStats::new(n_features, n_bins, n_classes)),
299        }
300    }
301
302    /// Returns `true` if this node is a leaf (has no split).
303    #[inline]
304    fn is_leaf(&self) -> bool {
305        self.split_feature.is_none()
306    }
307
308    /// Ensure class vectors are large enough to accommodate `n_classes`.
309    fn ensure_class_capacity(&mut self, n_classes: usize) {
310        if self.class_counts.len() < n_classes {
311            self.class_counts.resize(n_classes, 0);
312        }
313        if let Some(ref mut stats) = self.leaf_stats {
314            stats.ensure_class_capacity(n_classes);
315        }
316    }
317}
318
319// ---------------------------------------------------------------------------
320// HoeffdingTreeClassifier
321// ---------------------------------------------------------------------------
322
323/// A streaming decision tree classifier based on the VFDT algorithm.
324///
325/// Grows incrementally by maintaining per-leaf class distributions and
326/// histogram bins. Splits are committed when the Hoeffding bound guarantees
327/// the best information-gain split is statistically superior to the
328/// runner-up.
329///
330/// # Thread Safety
331///
332/// `HoeffdingTreeClassifier` is `Send + Sync`, making it usable in async
333/// and multi-threaded pipelines.
334#[derive(Debug, Clone)]
335pub struct HoeffdingTreeClassifier {
336    config: HoeffdingClassifierConfig,
337    /// Arena-based tree storage. Index 0 is always the root.
338    nodes: Vec<ClassifierNode>,
339    /// Number of features (lazy-initialized from first sample if config says 0).
340    n_features: usize,
341    /// Number of discovered classes.
342    n_classes: usize,
343    /// Total samples trained on.
344    n_samples: u64,
345}
346
347impl HoeffdingTreeClassifier {
348    /// Create a new classifier from the given configuration.
349    ///
350    /// If `config.n_features > 0`, the tree is immediately initialized with a
351    /// root leaf. Otherwise, initialization is deferred until the first training
352    /// sample arrives.
353    pub fn new(config: HoeffdingClassifierConfig) -> Self {
354        let n_features = config.n_features;
355        let n_classes = config.max_classes;
356
357        let mut tree = Self {
358            config,
359            nodes: Vec::new(),
360            n_features,
361            n_classes,
362            n_samples: 0,
363        };
364
365        // If features are known up front, create the root immediately.
366        if n_features > 0 {
367            let root =
368                ClassifierNode::new_leaf(0, n_features, tree.config.n_bins, n_classes.max(2));
369            tree.nodes.push(root);
370            if tree.n_classes == 0 {
371                tree.n_classes = 2; // sensible minimum
372            }
373        }
374
375        tree
376    }
377
378    /// Train on a single observation: route to leaf, update stats, maybe split.
379    ///
380    /// # Arguments
381    ///
382    /// * `features` -- feature vector for this observation.
383    /// * `class` -- the class label (0-indexed).
384    pub fn train_one(&mut self, features: &[f64], class: usize) {
385        // Lazy initialization on first sample.
386        if self.nodes.is_empty() {
387            self.n_features = features.len();
388            if self.n_classes == 0 {
389                self.n_classes = (class + 1).max(2);
390            }
391            let root =
392                ClassifierNode::new_leaf(0, self.n_features, self.config.n_bins, self.n_classes);
393            self.nodes.push(root);
394        }
395
396        // Auto-discover classes.
397        if class >= self.n_classes {
398            self.n_classes = class + 1;
399            for node in &mut self.nodes {
400                node.ensure_class_capacity(self.n_classes);
401            }
402        }
403
404        self.n_samples += 1;
405
406        // Route to leaf.
407        let leaf_idx = self.route_to_leaf(features);
408
409        // Update leaf stats.
410        self.update_leaf(leaf_idx, features, class);
411
412        // Evaluate split if grace period met.
413        let node = &self.nodes[leaf_idx];
414        let n = node.n_samples;
415        let gp = self.config.grace_period as u64;
416        if n >= gp && n % gp == 0 && node.depth < self.config.max_depth {
417            self.try_split(leaf_idx);
418        }
419    }
420
421    /// Predict the majority class for the given feature vector.
422    ///
423    /// Returns the class index with the highest count at the reached leaf.
424    /// If no samples have been seen, returns 0.
425    pub fn predict_class(&self, features: &[f64]) -> usize {
426        if self.nodes.is_empty() {
427            return 0;
428        }
429        let leaf_idx = self.route_to_leaf(features);
430        let node = &self.nodes[leaf_idx];
431        majority_class(&node.class_counts)
432    }
433
434    /// Predict the class probability distribution for the given feature vector.
435    ///
436    /// Returns a `Vec<f64>` of length `n_classes` where each entry is the
437    /// estimated probability (fraction of samples) for that class. The
438    /// probabilities sum to 1.0 (or all zeros if no samples seen).
439    pub fn predict_proba(&self, features: &[f64]) -> Vec<f64> {
440        if self.nodes.is_empty() {
441            return vec![0.0; self.n_classes.max(1)];
442        }
443        let leaf_idx = self.route_to_leaf(features);
444        let node = &self.nodes[leaf_idx];
445        class_probabilities(&node.class_counts)
446    }
447
448    /// Number of leaf nodes in the tree.
449    pub fn n_leaves(&self) -> usize {
450        self.nodes.iter().filter(|n| n.is_leaf()).count()
451    }
452
453    /// Total number of nodes (leaves + splits) in the tree.
454    pub fn n_nodes(&self) -> usize {
455        self.nodes.len()
456    }
457
458    /// Maximum depth reached by any node in the tree.
459    pub fn max_depth_seen(&self) -> usize {
460        self.nodes.iter().map(|n| n.depth).max().unwrap_or(0)
461    }
462
463    /// Number of discovered classes.
464    pub fn n_classes(&self) -> usize {
465        self.n_classes
466    }
467
468    /// Total number of training samples seen since creation or last reset.
469    pub fn n_samples_seen(&self) -> u64 {
470        self.n_samples
471    }
472
473    /// Reset the tree to its initial (untrained) state.
474    ///
475    /// Clears all nodes and counters. If `n_features` was configured up front,
476    /// the root leaf is re-created; otherwise initialization is deferred again.
477    pub fn reset(&mut self) {
478        self.nodes.clear();
479        self.n_samples = 0;
480        let n_features = self.config.n_features;
481        let n_classes = self.config.max_classes;
482        self.n_features = n_features;
483        self.n_classes = n_classes;
484
485        if n_features > 0 {
486            let root =
487                ClassifierNode::new_leaf(0, n_features, self.config.n_bins, n_classes.max(2));
488            self.nodes.push(root);
489            if self.n_classes == 0 {
490                self.n_classes = 2;
491            }
492        }
493    }
494
495    // -----------------------------------------------------------------------
496    // Private helpers
497    // -----------------------------------------------------------------------
498
499    /// Route a feature vector from root to a leaf, returning the leaf's arena index.
500    fn route_to_leaf(&self, features: &[f64]) -> usize {
501        let mut idx = 0;
502        loop {
503            let node = &self.nodes[idx];
504            if node.is_leaf() {
505                return idx;
506            }
507            let feat = node.split_feature.unwrap();
508            let thresh = node.split_threshold.unwrap();
509            if feat < features.len() && features[feat] < thresh {
510                idx = node.left.unwrap();
511            } else {
512                idx = node.right.unwrap();
513            }
514        }
515    }
516
517    /// Update the leaf node at `leaf_idx` with a new observation.
518    fn update_leaf(&mut self, leaf_idx: usize, features: &[f64], class: usize) {
519        let n_bins = self.config.n_bins;
520        let node = &mut self.nodes[leaf_idx];
521        node.n_samples += 1;
522        node.class_counts[class] += 1;
523
524        let stats = node.leaf_stats.as_mut().expect("leaf must have stats");
525        stats.n_samples += 1;
526        stats.class_counts[class] += 1;
527
528        let half_grace = (self.config.grace_period / 2).max(1) as u64;
529
530        for (f_idx, &val) in features.iter().enumerate().take(self.n_features) {
531            // Track observed min/max for this feature.
532            let (ref mut lo, ref mut hi) = stats.feature_ranges[f_idx];
533            if val < *lo {
534                *lo = val;
535            }
536            if val > *hi {
537                *hi = val;
538            }
539
540            // Initialize uniform bin boundaries after enough samples.
541            if stats.bin_boundaries[f_idx].is_empty() && stats.n_samples >= half_grace {
542                let lo_val = stats.feature_ranges[f_idx].0;
543                let hi_val = stats.feature_ranges[f_idx].1;
544                if math::abs(hi_val - lo_val) > 1e-15 {
545                    let boundaries: Vec<f64> = (1..n_bins)
546                        .map(|i| lo_val + (hi_val - lo_val) * (i as f64) / (n_bins as f64))
547                        .collect();
548                    stats.bin_boundaries[f_idx] = boundaries;
549                }
550            }
551
552            // If boundaries are available, update the histogram.
553            if !stats.bin_boundaries[f_idx].is_empty() {
554                let bin = find_bin(&stats.bin_boundaries[f_idx], val);
555                if bin < stats.feature_histograms[f_idx].len() {
556                    stats.feature_histograms[f_idx][bin][class] += 1;
557                }
558            }
559        }
560    }
561
562    /// Attempt to split the leaf at `leaf_idx` using information gain + Hoeffding bound.
563    fn try_split(&mut self, leaf_idx: usize) {
564        let n_classes = self.n_classes;
565
566        // Compute parent entropy.
567        let node = &self.nodes[leaf_idx];
568        let stats = match node.leaf_stats.as_ref() {
569            Some(s) => s,
570            None => return,
571        };
572        let parent_entropy = entropy(&stats.class_counts);
573        let n_total = stats.n_samples as f64;
574        if n_total < 1.0 {
575            return;
576        }
577
578        // Find best and second-best information gain across all features.
579        let mut best_gain = f64::NEG_INFINITY;
580        let mut second_best_gain = f64::NEG_INFINITY;
581        let mut best_feature = 0usize;
582        let mut best_bin = 0usize;
583
584        for f_idx in 0..self.n_features {
585            if stats.bin_boundaries[f_idx].is_empty() {
586                continue;
587            }
588            let n_bins_actual = stats.feature_histograms[f_idx].len();
589            // Try each bin boundary as a split point.
590            for b in 0..n_bins_actual.saturating_sub(1) {
591                // Accumulate left counts (bins 0..=b) and right counts (bins b+1..end).
592                let mut left_counts = vec![0u64; n_classes];
593                let mut right_counts = vec![0u64; n_classes];
594                let mut n_left = 0u64;
595                let mut n_right = 0u64;
596
597                for bin_idx in 0..n_bins_actual {
598                    let bin_counts = &stats.feature_histograms[f_idx][bin_idx];
599                    for c in 0..n_classes.min(bin_counts.len()) {
600                        if bin_idx <= b {
601                            left_counts[c] += bin_counts[c];
602                            n_left += bin_counts[c];
603                        } else {
604                            right_counts[c] += bin_counts[c];
605                            n_right += bin_counts[c];
606                        }
607                    }
608                }
609
610                if n_left == 0 || n_right == 0 {
611                    continue;
612                }
613
614                let n_split = (n_left + n_right) as f64;
615                let left_entropy = entropy(&left_counts);
616                let right_entropy = entropy(&right_counts);
617                let weighted_child_entropy = (n_left as f64 / n_split) * left_entropy
618                    + (n_right as f64 / n_split) * right_entropy;
619                let gain = parent_entropy - weighted_child_entropy;
620
621                if gain > best_gain {
622                    second_best_gain = best_gain;
623                    best_gain = gain;
624                    best_feature = f_idx;
625                    best_bin = b;
626                } else if gain > second_best_gain {
627                    second_best_gain = gain;
628                }
629            }
630        }
631
632        // Nothing to split on.
633        if best_gain <= 0.0 {
634            return;
635        }
636
637        // Compute Hoeffding bound.
638        let r = if n_classes > 1 {
639            math::log2(n_classes as f64)
640        } else {
641            1.0
642        };
643        let epsilon = math::sqrt(r * r * math::ln(1.0 / self.config.delta) / (2.0 * n_total));
644
645        // Check if the best split is statistically significantly better.
646        let delta_g = best_gain - second_best_gain.max(0.0);
647        if delta_g <= epsilon && epsilon >= TAU {
648            return; // Not enough evidence yet.
649        }
650
651        // Commit the split.
652        let stats = self.nodes[leaf_idx].leaf_stats.as_ref().unwrap();
653        let threshold = if best_bin < stats.bin_boundaries[best_feature].len() {
654            stats.bin_boundaries[best_feature][best_bin]
655        } else {
656            // Fallback: midpoint of feature range.
657            let (lo, hi) = stats.feature_ranges[best_feature];
658            (lo + hi) / 2.0
659        };
660
661        let depth = self.nodes[leaf_idx].depth;
662        let n_bins = self.config.n_bins;
663
664        // Build left and right child class counts from the histogram.
665        let stats = self.nodes[leaf_idx].leaf_stats.as_ref().unwrap();
666        let n_bins_actual = stats.feature_histograms[best_feature].len();
667        let mut left_class_counts = vec![0u64; n_classes];
668        let mut right_class_counts = vec![0u64; n_classes];
669        let mut n_left = 0u64;
670        let mut n_right = 0u64;
671
672        for bin_idx in 0..n_bins_actual {
673            let bin_counts = &stats.feature_histograms[best_feature][bin_idx];
674            for c in 0..n_classes.min(bin_counts.len()) {
675                if bin_idx <= best_bin {
676                    left_class_counts[c] += bin_counts[c];
677                    n_left += bin_counts[c];
678                } else {
679                    right_class_counts[c] += bin_counts[c];
680                    n_right += bin_counts[c];
681                }
682            }
683        }
684
685        // Create child nodes.
686        let mut left_node = ClassifierNode::new_leaf(depth + 1, self.n_features, n_bins, n_classes);
687        left_node.class_counts = left_class_counts;
688        left_node.n_samples = n_left;
689
690        let mut right_node =
691            ClassifierNode::new_leaf(depth + 1, self.n_features, n_bins, n_classes);
692        right_node.class_counts = right_class_counts;
693        right_node.n_samples = n_right;
694
695        let left_idx = self.nodes.len();
696        let right_idx = left_idx + 1;
697        self.nodes.push(left_node);
698        self.nodes.push(right_node);
699
700        // Convert the current leaf into a split node.
701        let node = &mut self.nodes[leaf_idx];
702        node.split_feature = Some(best_feature);
703        node.split_threshold = Some(threshold);
704        node.left = Some(left_idx);
705        node.right = Some(right_idx);
706        node.leaf_stats = None; // Free leaf memory.
707    }
708}
709
710// ---------------------------------------------------------------------------
711// StreamingLearner impl
712// ---------------------------------------------------------------------------
713
714impl StreamingLearner for HoeffdingTreeClassifier {
715    /// Train on a single observation.
716    ///
717    /// The `target` is cast to `usize` for the class label. Weight is currently
718    /// unused (all samples contribute equally).
719    #[inline]
720    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
721        HoeffdingTreeClassifier::train_one(self, features, target as usize);
722    }
723
724    /// Predict the majority class as a floating-point value.
725    #[inline]
726    fn predict(&self, features: &[f64]) -> f64 {
727        self.predict_class(features) as f64
728    }
729
730    #[inline]
731    fn n_samples_seen(&self) -> u64 {
732        self.n_samples
733    }
734
735    #[inline]
736    fn reset(&mut self) {
737        HoeffdingTreeClassifier::reset(self);
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Free functions
743// ---------------------------------------------------------------------------
744
745/// Find the bin index for value `x` using binary search on sorted boundaries.
746///
747/// Returns the index of the first boundary that is >= x, clamped to `[0, n_bins-1]`.
748/// This places values below the first boundary into bin 0, values between
749/// boundary[i-1] and boundary[i] into bin i, and values above the last
750/// boundary into the last bin.
751#[inline]
/// Locate the histogram bin for `x` by binary search over sorted boundaries.
///
/// The result is the index of a boundary equal to `x` if there is one,
/// otherwise the position at which `x` would be inserted. It therefore lies
/// in `0..=boundaries.len()`: values at or below the first boundary land in
/// bin 0 and values above the last boundary land in the last bin.
/// Incomparable values (NaN) are treated as equal to whichever boundary they
/// are compared with.
#[inline]
fn find_bin(boundaries: &[f64], x: f64) -> usize {
    use core::cmp::Ordering;

    boundaries
        .binary_search_by(|boundary| boundary.partial_cmp(&x).unwrap_or(Ordering::Equal))
        .unwrap_or_else(|insert_at| insert_at)
}
758
759/// Compute the Shannon entropy (base 2) of a class count distribution.
760///
761/// Returns 0.0 for empty or single-class distributions.
762fn entropy(counts: &[u64]) -> f64 {
763    let total: u64 = counts.iter().sum();
764    if total == 0 {
765        return 0.0;
766    }
767    let total_f = total as f64;
768    let mut h = 0.0;
769    for &c in counts {
770        if c > 0 {
771            let p = c as f64 / total_f;
772            h -= p * math::log2(p);
773        }
774    }
775    h
776}
777
778/// Return the index of the class with the highest count.
779///
780/// Ties are broken by lowest index. Returns 0 if all counts are zero.
/// Return the index of the class with the highest count.
///
/// Ties are broken by lowest index. Returns 0 if all counts are zero or if
/// `counts` is empty.
fn majority_class(counts: &[u64]) -> usize {
    counts
        .iter()
        .enumerate()
        // `max_by_key` keeps the *last* maximum it encounters, so walk the
        // classes from highest to lowest index to let the lowest index win.
        .rev()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(class, _)| class)
}
789
790/// Convert class counts to a probability distribution.
791///
792/// Returns a vector of probabilities summing to 1.0. If total is zero,
793/// returns uniform zeros.
/// Turn class counts into relative frequencies.
///
/// Each entry is `count / total`, so the entries sum to 1.0. When the counts
/// sum to zero the result is all zeros, with at least one entry even for an
/// empty input.
fn class_probabilities(counts: &[u64]) -> Vec<f64> {
    match counts.iter().sum::<u64>() {
        0 => vec![0.0; counts.len().max(1)],
        total => {
            let denom = total as f64;
            counts.iter().map(|&count| count as f64 / denom).collect()
        }
    }
}
802
803// ---------------------------------------------------------------------------
804// Tests
805// ---------------------------------------------------------------------------
806
#[cfg(test)]
mod tests {
    //! Unit tests for the streaming Hoeffding tree classifier and its
    //! configuration builder.

    use super::*;

    /// Tiny deterministic PRNG used to generate reproducible test data.
    ///
    /// Advances `state` by one xorshift64 step (shifts 13, 7, 17) and returns
    /// the new state scaled by `u64::MAX`, i.e. a value in `[0, 1]`. A zero
    /// state stays zero forever, so callers must seed with a non-zero value.
    fn xorshift64(state: &mut u64) -> f64 {
        *state ^= *state << 13;
        *state ^= *state >> 7;
        *state ^= *state << 17;
        (*state as f64) / (u64::MAX as f64)
    }

    /// Baseline configuration reused by most tests; the short grace period
    /// lets split evaluation kick in after only 50 samples.
    fn test_config() -> HoeffdingClassifierConfig {
        let built = HoeffdingClassifierConfig::builder()
            .max_depth(6)
            .delta(1e-5)
            .grace_period(50)
            .n_bins(16)
            .build();
        built.unwrap()
    }

    /// An untrained tree has no nodes; the first sample creates one root leaf.
    #[test]
    fn single_sample_creates_root_leaf() {
        let mut tree = HoeffdingTreeClassifier::new(test_config());

        // Nothing has been trained yet, so the tree must report zero nodes.
        assert_eq!(tree.n_nodes(), 0);

        tree.train_one(&[1.0, 2.0, 3.0], 0);

        // Exactly one node, which is a leaf at depth 0, holding one sample.
        assert_eq!(tree.n_nodes(), 1);
        assert_eq!(tree.n_leaves(), 1);
        assert_eq!(tree.n_samples_seen(), 1);
        assert_eq!(tree.max_depth_seen(), 0);
    }

    /// With 30 samples of class 0 and 5 of class 1 at the same point, the
    /// predicted class is the majority one.
    #[test]
    fn predict_class_returns_majority() {
        let mut tree = HoeffdingTreeClassifier::new(test_config());

        // (label, number of repetitions), fed in this order.
        for (label, repeats) in [(0, 30), (1, 5)] {
            for _ in 0..repeats {
                tree.train_one(&[1.0, 2.0], label);
            }
        }

        let predicted = tree.predict_class(&[1.0, 2.0]);
        assert_eq!(predicted, 0, "expected majority class 0, got {}", predicted);
    }

    /// Class probabilities form a valid distribution: they add up to one and
    /// none of them is negative.
    #[test]
    fn predict_proba_sums_to_one() {
        let mut tree = HoeffdingTreeClassifier::new(test_config());

        // Three classes with 20, 10 and 5 samples respectively.
        for (label, repeats) in [(0, 20), (1, 10), (2, 5)] {
            for _ in 0..repeats {
                tree.train_one(&[1.0, 2.0], label);
            }
        }

        let proba = tree.predict_proba(&[1.0, 2.0]);

        let sum = proba.iter().sum::<f64>();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "probabilities should sum to 1.0, got {}",
            sum
        );

        // Every individual entry must be non-negative.
        for (i, p) in proba.iter().copied().enumerate() {
            assert!(p >= 0.0, "probability for class {} is negative: {}", i, p);
        }
    }

    /// A clean threshold on feature 0 makes the tree split and classify
    /// points far from the boundary correctly.
    #[test]
    fn tree_splits_on_separable_data() {
        let config = HoeffdingClassifierConfig::builder()
            .max_depth(6)
            .delta(1e-3)
            .grace_period(50)
            .n_bins(16)
            .build()
            .unwrap();
        let mut tree = HoeffdingTreeClassifier::new(config);

        // Label is 0 when x[0] < 5 and 1 otherwise; x[1] is pure noise.
        let mut seed: u64 = 42;
        for _ in 0..1000 {
            let signal = xorshift64(&mut seed) * 10.0;
            let noise = xorshift64(&mut seed) * 10.0;
            let label = if signal < 5.0 { 0 } else { 1 };
            tree.train_one(&[signal, noise], label);
        }

        // At least one split must have happened.
        let node_count = tree.n_nodes();
        assert!(
            node_count > 1,
            "expected tree to split, but has only {} node(s)",
            node_count
        );
        let leaf_count = tree.n_leaves();
        assert!(
            leaf_count >= 2,
            "expected at least 2 leaves, got {}",
            leaf_count
        );

        // Points well on either side of the threshold get the matching class.
        let low_side = tree.predict_class(&[1.0, 5.0]);
        assert_eq!(low_side, 0, "expected class 0 for x[0]=1.0");
        let high_side = tree.predict_class(&[9.0, 5.0]);
        assert_eq!(high_side, 1, "expected class 1 for x[0]=9.0");
    }

    /// The configured `max_depth` caps the observed depth even when the data
    /// encourages further splitting.
    #[test]
    fn max_depth_limits_growth() {
        let config = HoeffdingClassifierConfig::builder()
            .max_depth(2)
            .delta(1e-3)
            .grace_period(30)
            .n_bins(16)
            .build()
            .unwrap();
        let mut tree = HoeffdingTreeClassifier::new(config);

        // Three classes separated by thresholds at 3 and 6 on feature 0.
        let mut seed: u64 = 123;
        for _ in 0..5000 {
            let signal = xorshift64(&mut seed) * 10.0;
            let noise = xorshift64(&mut seed) * 10.0;
            let label = match signal {
                v if v < 3.0 => 0,
                v if v < 6.0 => 1,
                _ => 2,
            };
            tree.train_one(&[signal, noise], label);
        }

        // Depth must never exceed the configured limit of 2.
        let depth = tree.max_depth_seen();
        assert!(depth <= 2, "max depth should be <= 2, got {}", depth);
    }

    /// The classifier can be driven through a `dyn StreamingLearner` object:
    /// train, predict and reset all work via the trait.
    #[test]
    fn streaming_learner_trait_works() {
        let mut tree = HoeffdingTreeClassifier::new(test_config());

        // Everything below goes through the trait object, not the concrete type.
        let learner: &mut dyn StreamingLearner = &mut tree;
        learner.train(&[1.0, 2.0], 0.0);
        learner.train(&[3.0, 4.0], 1.0);

        assert_eq!(learner.n_samples_seen(), 2);

        let pred = learner.predict(&[1.0, 2.0]);
        assert!(pred.is_finite(), "prediction should be finite");
        // The prediction must be one of the two labels that were trained.
        assert!(
            [0.0, 1.0].contains(&pred),
            "prediction should be a class label, got {}",
            pred
        );

        learner.reset();
        assert_eq!(learner.n_samples_seen(), 0);
    }

    /// `reset` returns the tree to its untrained state: no samples, no nodes.
    #[test]
    fn reset_clears_state() {
        let mut tree = HoeffdingTreeClassifier::new(test_config());

        // 200 samples with alternating labels, enough to potentially split.
        for step in 0..200 {
            let label = if step % 2 == 0 { 0 } else { 1 };
            let x = step as f64;
            tree.train_one(&[x, x * 0.5], label);
        }
        assert_eq!(tree.n_samples_seen(), 200);
        assert!(tree.n_nodes() >= 1);

        tree.reset();

        assert_eq!(tree.n_samples_seen(), 0);
        // The tree is back to the empty state it has before any training.
        assert_eq!(tree.n_nodes(), 0);
    }

    /// With `max_classes(0)` the class count grows as new labels show up,
    /// and `predict_proba` returns one entry per known class.
    #[test]
    fn auto_discovers_classes() {
        let config = HoeffdingClassifierConfig::builder()
            .max_depth(4)
            .delta(1e-5)
            .grace_period(50)
            .n_bins(16)
            .max_classes(0) // auto-discover
            .build()
            .unwrap();
        let mut tree = HoeffdingTreeClassifier::new(config);

        // First sample carries label 0.
        tree.train_one(&[1.0], 0);
        let after_first = tree.n_classes();
        assert!(
            after_first >= 2,
            "should have at least 2 classes after first sample"
        );

        // Jump straight to label 3; labels 1 and 2 are never observed.
        tree.train_one(&[2.0], 3);
        let after_jump = tree.n_classes();
        assert!(
            after_jump >= 4,
            "should have at least 4 classes after seeing class 3, got {}",
            after_jump
        );

        // The probability vector has exactly one slot per known class.
        let proba = tree.predict_proba(&[1.5]);
        assert_eq!(
            proba.len(),
            tree.n_classes(),
            "proba length should match n_classes"
        );
    }

    /// The builder rejects out-of-range parameters and accepts a sane config.
    #[test]
    fn config_builder_validates() {
        // A depth limit of zero is rejected.
        assert!(
            HoeffdingClassifierConfig::builder()
                .max_depth(0)
                .build()
                .is_err(),
            "max_depth=0 should be rejected"
        );

        // Each of these delta values is rejected.
        for (delta, message) in [
            (0.0, "delta=0.0 should be rejected"),
            (1.0, "delta=1.0 should be rejected"),
            (-0.5, "delta=-0.5 should be rejected"),
        ] {
            let outcome = HoeffdingClassifierConfig::builder().delta(delta).build();
            assert!(outcome.is_err(), "{}", message);
        }

        // A grace period of zero is rejected.
        assert!(
            HoeffdingClassifierConfig::builder()
                .grace_period(0)
                .build()
                .is_err(),
            "grace_period=0 should be rejected"
        );

        // A single histogram bin is rejected.
        assert!(
            HoeffdingClassifierConfig::builder()
                .n_bins(1)
                .build()
                .is_err(),
            "n_bins=1 should be rejected"
        );

        // A fully in-range configuration builds successfully.
        let outcome = HoeffdingClassifierConfig::builder()
            .max_depth(5)
            .delta(0.01)
            .grace_period(100)
            .n_bins(8)
            .build();
        assert!(outcome.is_ok(), "valid config should build successfully");
    }
}