// irithyll_core/tree/node.rs

//! SoA arena-allocated node storage for streaming trees.
//!
//! Instead of per-node heap allocations, all node fields are stored in parallel
//! vectors indexed by [`NodeId`]. This layout is cache-friendly for batch
//! traversal and avoids pointer-chasing overhead inherent to linked structures.
//!
//! # Layout
//!
//! Every node occupies the same index across all vectors. Internal nodes use
//! `feature_idx`, `threshold`, `left`, and `right` (plus `categorical_mask` for
//! categorical splits). Leaf nodes use `leaf_value`. The `is_leaf` flag
//! discriminates between the two.

13use alloc::vec::Vec;
14
/// Index of a node inside a [`TreeArena`].
///
/// A thin `Copy` wrapper around `u32`. The sentinel value [`NodeId::NONE`]
/// (backed by `u32::MAX`) represents an absent or unset node reference, so it
/// must never be used to index the arena.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub u32);
21
22impl NodeId {
23    /// Sentinel value indicating "no node".
24    pub const NONE: NodeId = NodeId(u32::MAX);
25
26    /// Returns `true` if this is the sentinel [`NONE`](Self::NONE) value.
27    #[inline]
28    pub fn is_none(self) -> bool {
29        self.0 == u32::MAX
30    }
31
32    /// Convert to `usize` for indexing into the arena vectors.
33    #[inline]
34    pub fn idx(self) -> usize {
35        self.0 as usize
36    }
37}
38
39/// Structure-of-Arrays arena storage for tree nodes.
40///
41/// All node properties are stored in parallel vectors for cache efficiency.
42/// Internal nodes have `feature_idx` + `threshold` + children.
43/// Leaf nodes have `leaf_value`.
44///
45/// # Invariants
46///
47/// - All vectors have the same length at all times.
48/// - A leaf node has `is_leaf[id] == true`, `left[id] == NONE`, `right[id] == NONE`.
49/// - An internal node has `is_leaf[id] == false` and valid `left`/`right` children.
50/// - `depth[id]` is set at allocation time and never changes.
51#[derive(Debug, Clone)]
52pub struct TreeArena {
53    /// Feature index used for splitting (only meaningful for internal nodes).
54    pub feature_idx: Vec<u32>,
55    /// Split threshold (samples with feature <= threshold go left).
56    pub threshold: Vec<f64>,
57    /// Left child [`NodeId`] (`NONE` for leaves).
58    pub left: Vec<NodeId>,
59    /// Right child [`NodeId`] (`NONE` for leaves).
60    pub right: Vec<NodeId>,
61    /// Leaf prediction value.
62    pub leaf_value: Vec<f64>,
63    /// Whether this node is a leaf.
64    pub is_leaf: Vec<bool>,
65    /// Depth of this node in the tree (root = 0).
66    pub depth: Vec<u16>,
67    /// Number of samples routed through this node.
68    pub sample_count: Vec<u64>,
69    /// Categorical split bitmask. For categorical splits, bit `i` set means
70    /// category `i` routes left. `None` for continuous splits (threshold-based).
71    pub categorical_mask: Vec<Option<u64>>,
72}
73
74impl TreeArena {
75    /// Create an empty arena with no pre-allocated capacity.
76    pub fn new() -> Self {
77        Self {
78            feature_idx: Vec::new(),
79            threshold: Vec::new(),
80            left: Vec::new(),
81            right: Vec::new(),
82            leaf_value: Vec::new(),
83            is_leaf: Vec::new(),
84            depth: Vec::new(),
85            sample_count: Vec::new(),
86            categorical_mask: Vec::new(),
87        }
88    }
89
90    /// Create an empty arena with pre-allocated capacity for `cap` nodes.
91    ///
92    /// This avoids reallocation when the maximum tree size is known upfront.
93    pub fn with_capacity(cap: usize) -> Self {
94        Self {
95            feature_idx: Vec::with_capacity(cap),
96            threshold: Vec::with_capacity(cap),
97            left: Vec::with_capacity(cap),
98            right: Vec::with_capacity(cap),
99            leaf_value: Vec::with_capacity(cap),
100            is_leaf: Vec::with_capacity(cap),
101            depth: Vec::with_capacity(cap),
102            sample_count: Vec::with_capacity(cap),
103            categorical_mask: Vec::with_capacity(cap),
104        }
105    }
106
107    /// Allocate a new leaf node with value 0.0 at the given depth.
108    ///
109    /// Returns the [`NodeId`] of the newly created leaf.
110    pub fn add_leaf(&mut self, depth: u16) -> NodeId {
111        let id = self.feature_idx.len() as u32;
112        self.feature_idx.push(0);
113        self.threshold.push(0.0);
114        self.left.push(NodeId::NONE);
115        self.right.push(NodeId::NONE);
116        self.leaf_value.push(0.0);
117        self.is_leaf.push(true);
118        self.depth.push(depth);
119        self.sample_count.push(0);
120        self.categorical_mask.push(None);
121        NodeId(id)
122    }
123
124    /// Convert a leaf node into an internal (split) node, creating two new leaf
125    /// children.
126    ///
127    /// The parent's `is_leaf` flag is cleared and its `feature_idx`/`threshold`
128    /// are set. Two fresh leaf nodes are allocated at `depth = parent.depth + 1`
129    /// with the specified initial values.
130    ///
131    /// # Panics
132    ///
133    /// Panics if `leaf_id` does not reference a current leaf node.
134    ///
135    /// # Returns
136    ///
137    /// `(left_id, right_id)` -- the [`NodeId`]s of the two new children.
138    pub fn split_leaf(
139        &mut self,
140        leaf_id: NodeId,
141        feature_idx: u32,
142        threshold: f64,
143        left_value: f64,
144        right_value: f64,
145    ) -> (NodeId, NodeId) {
146        let i = leaf_id.idx();
147        assert!(
148            self.is_leaf[i],
149            "split_leaf called on non-leaf node {:?}",
150            leaf_id
151        );
152
153        let child_depth = self.depth[i] + 1;
154
155        // Allocate left child.
156        let left_id = self.add_leaf(child_depth);
157        self.leaf_value[left_id.idx()] = left_value;
158
159        // Allocate right child.
160        let right_id = self.add_leaf(child_depth);
161        self.leaf_value[right_id.idx()] = right_value;
162
163        // Convert the parent from a leaf into an internal node.
164        self.is_leaf[i] = false;
165        self.feature_idx[i] = feature_idx;
166        self.threshold[i] = threshold;
167        self.left[i] = left_id;
168        self.right[i] = right_id;
169
170        (left_id, right_id)
171    }
172
173    /// Split a leaf using a categorical bitmask instead of a threshold.
174    ///
175    /// The `mask` is a `u64` where bit `i` set means category `i` goes left.
176    /// The `threshold` is still stored as the midpoint of the split partition
177    /// for backward compatibility, but routing uses the bitmask.
178    ///
179    /// # Panics
180    ///
181    /// Panics if `leaf_id` does not reference a current leaf node.
182    pub fn split_leaf_categorical(
183        &mut self,
184        leaf_id: NodeId,
185        feature_idx: u32,
186        threshold: f64,
187        left_value: f64,
188        right_value: f64,
189        mask: u64,
190    ) -> (NodeId, NodeId) {
191        let (left_id, right_id) =
192            self.split_leaf(leaf_id, feature_idx, threshold, left_value, right_value);
193        self.categorical_mask[leaf_id.idx()] = Some(mask);
194        (left_id, right_id)
195    }
196
197    /// Return the categorical bitmask for a split node, if it's a categorical split.
198    #[inline]
199    pub fn get_categorical_mask(&self, id: NodeId) -> Option<u64> {
200        self.categorical_mask[id.idx()]
201    }
202
203    /// Returns `true` if the node at `id` is a leaf.
204    #[inline]
205    pub fn is_leaf(&self, id: NodeId) -> bool {
206        self.is_leaf[id.idx()]
207    }
208
209    /// Return the leaf prediction value.
210    ///
211    /// # Panics
212    ///
213    /// Panics if `id` references an internal (non-leaf) node.
214    #[inline]
215    pub fn predict(&self, id: NodeId) -> f64 {
216        let i = id.idx();
217        assert!(self.is_leaf[i], "predict called on internal node {:?}", id);
218        self.leaf_value[i]
219    }
220
221    /// Set the leaf prediction value for a leaf node.
222    ///
223    /// # Panics
224    ///
225    /// Panics if `id` does not reference a leaf node.
226    #[inline]
227    pub fn set_leaf_value(&mut self, id: NodeId, value: f64) {
228        let i = id.idx();
229        assert!(
230            self.is_leaf[i],
231            "set_leaf_value called on internal node {:?}",
232            id
233        );
234        self.leaf_value[i] = value;
235    }
236
237    /// Return the depth of the node.
238    #[inline]
239    pub fn get_depth(&self, id: NodeId) -> u16 {
240        self.depth[id.idx()]
241    }
242
243    /// Return the feature index used for splitting at this node.
244    #[inline]
245    pub fn get_feature_idx(&self, id: NodeId) -> u32 {
246        self.feature_idx[id.idx()]
247    }
248
249    /// Return the split threshold for this node.
250    #[inline]
251    pub fn get_threshold(&self, id: NodeId) -> f64 {
252        self.threshold[id.idx()]
253    }
254
255    /// Return the left child [`NodeId`].
256    #[inline]
257    pub fn get_left(&self, id: NodeId) -> NodeId {
258        self.left[id.idx()]
259    }
260
261    /// Return the right child [`NodeId`].
262    #[inline]
263    pub fn get_right(&self, id: NodeId) -> NodeId {
264        self.right[id.idx()]
265    }
266
267    /// Return the sample count for this node.
268    #[inline]
269    pub fn get_sample_count(&self, id: NodeId) -> u64 {
270        self.sample_count[id.idx()]
271    }
272
273    /// Increment the sample count for this node by one.
274    #[inline]
275    pub fn increment_sample_count(&mut self, id: NodeId) {
276        self.sample_count[id.idx()] += 1;
277    }
278
279    /// Total number of allocated nodes (internal + leaf).
280    #[inline]
281    pub fn n_nodes(&self) -> usize {
282        self.is_leaf.len()
283    }
284
285    /// Count of leaf nodes currently in the arena.
286    pub fn n_leaves(&self) -> usize {
287        self.is_leaf.iter().filter(|&&b| b).count()
288    }
289
290    /// Clear all storage, returning the arena to an empty state.
291    ///
292    /// Allocated memory is retained (capacity is preserved) so the arena can
293    /// be reused without reallocating.
294    pub fn reset(&mut self) {
295        self.feature_idx.clear();
296        self.threshold.clear();
297        self.left.clear();
298        self.right.clear();
299        self.leaf_value.clear();
300        self.is_leaf.clear();
301        self.depth.clear();
302        self.sample_count.clear();
303        self.categorical_mask.clear();
304    }
305}
306
307impl Default for TreeArena {
308    fn default() -> Self {
309        Self::new()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly allocated leaf should be marked as a leaf with value 0.0,
    /// depth as requested, zero sample count, no mask, and NONE children.
    #[test]
    fn single_leaf() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        assert_eq!(root, NodeId(0));
        assert!(arena.is_leaf(root));
        assert_eq!(arena.predict(root), 0.0);
        assert_eq!(arena.get_depth(root), 0);
        assert_eq!(arena.get_sample_count(root), 0);
        assert_eq!(arena.get_left(root), NodeId::NONE);
        assert_eq!(arena.get_right(root), NodeId::NONE);
        assert_eq!(arena.get_categorical_mask(root), None);
    }

    /// Splitting a leaf should convert the parent to an internal node and
    /// create two leaf children at depth + 1 with the specified values.
    #[test]
    fn split_leaf_basic() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let (left, right) = arena.split_leaf(root, 3, 1.5, -0.25, 0.75);

        // Parent is no longer a leaf.
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 3);
        assert_eq!(arena.get_threshold(root), 1.5);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);

        // Children are leaves at depth 1 with the specified values.
        assert!(arena.is_leaf(left));
        assert_eq!(arena.predict(left), -0.25);
        assert_eq!(arena.get_depth(left), 1);

        assert!(arena.is_leaf(right));
        assert_eq!(arena.predict(right), 0.75);
        assert_eq!(arena.get_depth(right), 1);
    }

    /// Splitting a child node should grow the tree to three levels (depths 0, 1, 2).
    #[test]
    fn split_child_three_levels() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        // Level 1: split root.
        let (left, right) = arena.split_leaf(root, 0, 5.0, 0.0, 0.0);

        // Level 2: split the left child.
        let (ll, lr) = arena.split_leaf(left, 1, 2.0, -1.0, 1.0);

        // Root is internal.
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_depth(root), 0);

        // Left child is now internal too.
        assert!(!arena.is_leaf(left));
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_feature_idx(left), 1);
        assert_eq!(arena.get_threshold(left), 2.0);
        assert_eq!(arena.get_left(left), ll);
        assert_eq!(arena.get_right(left), lr);

        // Right child from the first split is still a leaf.
        assert!(arena.is_leaf(right));
        assert_eq!(arena.get_depth(right), 1);

        // Grandchildren are leaves at depth 2.
        assert!(arena.is_leaf(ll));
        assert_eq!(arena.get_depth(ll), 2);
        assert_eq!(arena.predict(ll), -1.0);

        assert!(arena.is_leaf(lr));
        assert_eq!(arena.get_depth(lr), 2);
        assert_eq!(arena.predict(lr), 1.0);
    }

    /// A categorical split should behave like `split_leaf` and additionally
    /// record the bitmask on the parent. Leaves and threshold splits carry no
    /// mask.
    #[test]
    fn split_leaf_categorical_records_mask() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        let (left, right) = arena.split_leaf_categorical(root, 2, 0.5, -1.0, 1.0, 0b1010);

        // Parent became an internal node carrying both threshold and mask.
        assert!(!arena.is_leaf(root));
        assert_eq!(arena.get_feature_idx(root), 2);
        assert_eq!(arena.get_threshold(root), 0.5);
        assert_eq!(arena.get_left(root), left);
        assert_eq!(arena.get_right(root), right);
        assert_eq!(arena.get_categorical_mask(root), Some(0b1010));

        // Children are plain leaves without a mask.
        assert!(arena.is_leaf(left));
        assert_eq!(arena.predict(left), -1.0);
        assert_eq!(arena.get_depth(left), 1);
        assert_eq!(arena.get_categorical_mask(left), None);

        assert!(arena.is_leaf(right));
        assert_eq!(arena.predict(right), 1.0);
        assert_eq!(arena.get_depth(right), 1);
        assert_eq!(arena.get_categorical_mask(right), None);

        // A threshold split never sets a mask.
        let _ = arena.split_leaf(left, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.get_categorical_mask(left), None);
    }

    /// Calling `split_leaf_categorical` on an internal node should panic.
    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_categorical_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        // root is already internal -- a categorical split should panic too.
        let _ = arena.split_leaf_categorical(root, 1, 2.0, 0.0, 0.0, 0b1);
    }

    /// `n_nodes` and `n_leaves` should track allocations and splits correctly.
    #[test]
    fn node_and_leaf_counting() {
        let mut arena = TreeArena::new();

        // Empty arena.
        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);

        // Single root leaf.
        let root = arena.add_leaf(0);
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);

        // Split root -> 3 nodes total (1 internal + 2 leaves).
        let (_left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 3);
        assert_eq!(arena.n_leaves(), 2);

        // Split right child -> 5 nodes total (2 internal + 3 leaves).
        let _ = arena.split_leaf(right, 1, 2.0, 0.0, 0.0);
        assert_eq!(arena.n_nodes(), 5);
        assert_eq!(arena.n_leaves(), 3);
    }

    /// `NodeId::NONE` should report `is_none() == true` and should not collide
    /// with valid node indices in any reasonably-sized arena.
    #[test]
    fn node_id_none_sentinel() {
        let none = NodeId::NONE;
        assert!(none.is_none());
        assert_eq!(none.0, u32::MAX);

        let valid = NodeId(0);
        assert!(!valid.is_none());
        assert_ne!(valid, NodeId::NONE);
    }

    /// `reset` should clear all vectors, bringing node/leaf counts to zero,
    /// while preserving allocated capacity for reuse.
    #[test]
    fn reset_clears_everything() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        assert_eq!(arena.n_nodes(), 3);

        arena.reset();

        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);
        assert!(arena.categorical_mask.is_empty());

        // Capacity should be preserved (at least 3 from prior usage).
        assert!(arena.feature_idx.capacity() >= 3);
        assert!(arena.is_leaf.capacity() >= 3);

        // Should be able to reuse the arena normally after reset.
        let new_root = arena.add_leaf(0);
        assert_eq!(new_root, NodeId(0));
        assert_eq!(arena.n_nodes(), 1);
        assert_eq!(arena.n_leaves(), 1);
    }

    /// Sample count starts at zero and increments correctly per node.
    #[test]
    fn sample_count_tracking() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);

        assert_eq!(arena.get_sample_count(root), 0);

        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 1);

        arena.increment_sample_count(root);
        arena.increment_sample_count(root);
        assert_eq!(arena.get_sample_count(root), 3);

        // Split and verify children start at zero independently.
        let (left, right) = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);
        assert_eq!(arena.get_sample_count(left), 0);
        assert_eq!(arena.get_sample_count(right), 0);

        // Parent retains its count after split.
        assert_eq!(arena.get_sample_count(root), 3);

        arena.increment_sample_count(left);
        assert_eq!(arena.get_sample_count(left), 1);
        assert_eq!(arena.get_sample_count(right), 0);
    }

    /// `with_capacity` should pre-allocate every column without adding any nodes.
    #[test]
    fn with_capacity_preallocates() {
        let arena = TreeArena::with_capacity(64);

        // No nodes allocated yet.
        assert_eq!(arena.n_nodes(), 0);
        assert_eq!(arena.n_leaves(), 0);

        // But capacity is reserved.
        assert!(arena.feature_idx.capacity() >= 64);
        assert!(arena.threshold.capacity() >= 64);
        assert!(arena.left.capacity() >= 64);
        assert!(arena.right.capacity() >= 64);
        assert!(arena.leaf_value.capacity() >= 64);
        assert!(arena.is_leaf.capacity() >= 64);
        assert!(arena.depth.capacity() >= 64);
        assert!(arena.sample_count.capacity() >= 64);
        assert!(arena.categorical_mask.capacity() >= 64);
    }

    /// `set_leaf_value` should update the prediction value of a leaf.
    #[test]
    fn set_leaf_value_updates() {
        let mut arena = TreeArena::new();
        let leaf = arena.add_leaf(0);

        assert_eq!(arena.predict(leaf), 0.0);

        arena.set_leaf_value(leaf, 42.5);
        assert_eq!(arena.predict(leaf), 42.5);

        arena.set_leaf_value(leaf, -3.25);
        assert_eq!(arena.predict(leaf), -3.25);
    }

    /// Calling `predict` on an internal node should panic.
    #[test]
    #[should_panic(expected = "predict called on internal node")]
    fn predict_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        // root is now internal -- this should panic.
        let _ = arena.predict(root);
    }

    /// Calling `set_leaf_value` on an internal node should panic.
    #[test]
    #[should_panic(expected = "set_leaf_value called on internal node")]
    fn set_leaf_value_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        // root is now internal -- this should panic.
        arena.set_leaf_value(root, 1.0);
    }

    /// Calling `split_leaf` on an internal node should panic.
    #[test]
    #[should_panic(expected = "split_leaf called on non-leaf node")]
    fn split_leaf_panics_on_internal_node() {
        let mut arena = TreeArena::new();
        let root = arena.add_leaf(0);
        let _ = arena.split_leaf(root, 0, 1.0, 0.0, 0.0);

        // root is already internal -- splitting again should panic.
        let _ = arena.split_leaf(root, 1, 2.0, 0.0, 0.0);
    }

    /// Default trait implementation should produce the same result as `new()`.
    #[test]
    fn default_matches_new() {
        let a = TreeArena::new();
        let b = TreeArena::default();

        assert_eq!(a.n_nodes(), b.n_nodes());
        assert_eq!(a.n_leaves(), b.n_leaves());
    }
}