Skip to main content

irithyll_core/tree/
predict.rs

//! Cache-optimized tree traversal for prediction.
//!
//! Standalone functions that operate on a [`TreeArena`] for traversal, prediction,
//! and structural queries. These are decoupled from [`crate::tree::StreamingTree`]
//! so they can be reused by the ensemble layer and tested independently.
6
7use alloc::vec;
8use alloc::vec::Vec;
9
10use crate::tree::node::{NodeId, TreeArena};
11
12/// Traverse the tree from the given root to a leaf, returning the leaf's [`NodeId`].
13///
14/// At each internal node the traversal goes **left** when
15/// `features[feature_idx] <= threshold` and **right** otherwise.
16///
17/// # Panics
18///
19/// Panics (via bounds check) if `root` or any child index is out of range,
20/// or if a feature index referenced by an internal node exceeds `features.len()`.
21#[inline]
22pub fn traverse_to_leaf(arena: &TreeArena, root: NodeId, features: &[f64]) -> NodeId {
23    let mut current = root;
24    loop {
25        let idx = current.idx();
26        if arena.is_leaf[idx] {
27            return current;
28        }
29        let feat_idx = arena.feature_idx[idx] as usize;
30        if features[feat_idx] <= arena.threshold[idx] {
31            current = arena.left[idx];
32        } else {
33            current = arena.right[idx];
34        }
35    }
36}
37
38/// Predict the leaf value for a single feature vector, starting from `root`.
39///
40/// Equivalent to `arena.leaf_value[traverse_to_leaf(arena, root, features).idx()]`.
41#[inline]
42pub fn predict_from_root(arena: &TreeArena, root: NodeId, features: &[f64]) -> f64 {
43    let leaf = traverse_to_leaf(arena, root, features);
44    arena.leaf_value[leaf.idx()]
45}
46
47/// Batch prediction: compute the leaf value for each row in `feature_matrix`.
48///
49/// Returns a `Vec<f64>` with one prediction per row, in the same order as the
50/// input matrix.
51pub fn predict_batch(arena: &TreeArena, root: NodeId, feature_matrix: &[Vec<f64>]) -> Vec<f64> {
52    feature_matrix
53        .iter()
54        .map(|features| predict_from_root(arena, root, features))
55        .collect()
56}
57
58/// Collect every leaf [`NodeId`] reachable from `root` via depth-first traversal.
59///
60/// The returned order is deterministic (left subtree leaves appear before right
61/// subtree leaves) but should not be relied upon for semantic meaning.
62pub fn collect_leaves(arena: &TreeArena, root: NodeId) -> Vec<NodeId> {
63    let mut leaves = Vec::new();
64    let mut stack = vec![root];
65    while let Some(node) = stack.pop() {
66        let idx = node.idx();
67        if arena.is_leaf[idx] {
68            leaves.push(node);
69        } else {
70            // Push right first so left is popped first (DFS left-to-right).
71            stack.push(arena.right[idx]);
72            stack.push(arena.left[idx]);
73        }
74    }
75    leaves
76}
77
78/// Compute the maximum depth of any leaf reachable from `root`.
79///
80/// Returns `0` for a single-leaf tree (root is a leaf at depth 0).
81/// Returns `0` if the tree is empty (no leaves found).
82pub fn tree_depth(arena: &TreeArena, root: NodeId) -> u16 {
83    let leaves = collect_leaves(arena, root);
84    leaves
85        .iter()
86        .map(|id| arena.depth[id.idx()])
87        .max()
88        .unwrap_or(0)
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::node::{NodeId, TreeArena};

    /// Helper: create an empty [`TreeArena`] with every column vector empty.
    fn empty_arena() -> TreeArena {
        TreeArena {
            feature_idx: Vec::new(),
            threshold: Vec::new(),
            left: Vec::new(),
            right: Vec::new(),
            leaf_value: Vec::new(),
            is_leaf: Vec::new(),
            depth: Vec::new(),
            sample_count: Vec::new(),
            categorical_mask: Vec::new(),
        }
    }

    /// Push a leaf node into the arena, returning its [`NodeId`].
    ///
    /// One entry is appended to every column so the vectors stay the same
    /// length; the new id is the length of `is_leaf` before the push.
    fn push_leaf(arena: &mut TreeArena, value: f64, depth: u16) -> NodeId {
        let id = NodeId(arena.is_leaf.len() as u32);
        arena.feature_idx.push(0);
        arena.threshold.push(0.0);
        arena.left.push(NodeId::NONE);
        arena.right.push(NodeId::NONE);
        arena.leaf_value.push(value);
        arena.is_leaf.push(true);
        arena.depth.push(depth);
        arena.sample_count.push(0);
        arena.categorical_mask.push(None);
        id
    }

    /// Convert an existing leaf into an internal node by updating its fields
    /// in place. Useful for building trees top-down where the root is allocated
    /// first as a leaf and then split.
    fn convert_to_split(
        arena: &mut TreeArena,
        node: NodeId,
        feature: u32,
        threshold: f64,
        left: NodeId,
        right: NodeId,
    ) {
        let idx = node.idx();
        arena.feature_idx[idx] = feature;
        arena.threshold[idx] = threshold;
        arena.left[idx] = left;
        arena.right[idx] = right;
        arena.is_leaf[idx] = false;
    }

    // ------------------------------------------------------------------
    // 1. Single leaf tree: predict returns leaf value.
    // ------------------------------------------------------------------

    #[test]
    fn single_leaf_returns_value() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);

        // Any feature vector should land on the only leaf.
        assert_eq!(predict_from_root(&arena, root, &[1.0, 2.0, 3.0]), 0.0);
    }

    #[test]
    fn single_leaf_with_nonzero_value() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 42.5, 0);

        // A leaf root never reads the features, so an empty slice is fine.
        assert_eq!(predict_from_root(&arena, root, &[]), 42.5);
    }

    // ------------------------------------------------------------------
    // 2. One split: verify left/right routing.
    // ------------------------------------------------------------------

    /// Build:
    /// ```text
    ///        [0] feat=0, thr=5.0
    ///        /                \
    ///   [1] leaf=-1.0    [2] leaf=+1.0
    /// ```
    fn build_one_split() -> (TreeArena, NodeId) {
        let mut arena = empty_arena();

        // Allocate root first as leaf, then children, then convert root.
        let root = push_leaf(&mut arena, 0.0, 0);
        let left_child = push_leaf(&mut arena, -1.0, 1);
        let right_child = push_leaf(&mut arena, 1.0, 1);
        convert_to_split(&mut arena, root, 0, 5.0, left_child, right_child);

        (arena, root)
    }

    #[test]
    fn one_split_goes_left() {
        let (arena, root) = build_one_split();
        // feature[0] = 3.0 <= 5.0 -> left leaf (-1.0)
        assert_eq!(predict_from_root(&arena, root, &[3.0]), -1.0);
    }

    #[test]
    fn one_split_goes_right() {
        let (arena, root) = build_one_split();
        // feature[0] = 7.0 > 5.0 -> right leaf (+1.0)
        assert_eq!(predict_from_root(&arena, root, &[7.0]), 1.0);
    }

    #[test]
    fn one_split_equal_goes_left() {
        let (arena, root) = build_one_split();
        // feature[0] = 5.0 == threshold -> left (<=)
        assert_eq!(predict_from_root(&arena, root, &[5.0]), -1.0);
    }

    #[test]
    fn nan_feature_routes_right() {
        let (arena, root) = build_one_split();
        // NaN <= 5.0 is false, so the traversal falls through to the right leaf.
        assert_eq!(predict_from_root(&arena, root, &[f64::NAN]), 1.0);
    }

    // ------------------------------------------------------------------
    // 3. Two-level tree: three different feature vectors hit three leaves.
    // ------------------------------------------------------------------

    /// Build:
    /// ```text
    ///             [0] feat=0, thr=5.0
    ///            /                   \
    ///    [1] feat=1, thr=2.0     [2] leaf=10.0
    ///       /            \
    /// [3] leaf=-5.0  [4] leaf=3.0
    /// ```
    fn build_two_level() -> (TreeArena, NodeId) {
        let mut arena = empty_arena();

        let root = push_leaf(&mut arena, 0.0, 0); // id 0
        let inner = push_leaf(&mut arena, 0.0, 1); // id 1
        let right_leaf = push_leaf(&mut arena, 10.0, 1); // id 2
        let left_left = push_leaf(&mut arena, -5.0, 2); // id 3
        let left_right = push_leaf(&mut arena, 3.0, 2); // id 4

        convert_to_split(&mut arena, root, 0, 5.0, inner, right_leaf);
        convert_to_split(&mut arena, inner, 1, 2.0, left_left, left_right);

        (arena, root)
    }

    #[test]
    fn two_level_reaches_left_left() {
        let (arena, root) = build_two_level();
        // feat[0]=1.0 <= 5.0 -> go left; feat[1]=0.5 <= 2.0 -> go left => -5.0
        assert_eq!(predict_from_root(&arena, root, &[1.0, 0.5]), -5.0);
    }

    #[test]
    fn two_level_reaches_left_right() {
        let (arena, root) = build_two_level();
        // feat[0]=4.0 <= 5.0 -> go left; feat[1]=3.0 > 2.0 -> go right => 3.0
        assert_eq!(predict_from_root(&arena, root, &[4.0, 3.0]), 3.0);
    }

    #[test]
    fn two_level_reaches_right_leaf() {
        let (arena, root) = build_two_level();
        // feat[0]=8.0 > 5.0 -> go right => 10.0
        assert_eq!(predict_from_root(&arena, root, &[8.0, 999.0]), 10.0);
    }

    // ------------------------------------------------------------------
    // 4. Batch prediction: verify consistency with individual predictions.
    // ------------------------------------------------------------------

    #[test]
    fn batch_matches_individual() {
        let (arena, root) = build_two_level();

        let rows = vec![
            vec![1.0, 0.5],
            vec![4.0, 3.0],
            vec![8.0, 0.0],
            vec![5.0, 2.0], // both exactly on threshold -> left, left => -5.0
        ];

        let batch = predict_batch(&arena, root, &rows);

        // One prediction per row, with the values the tree shape dictates.
        assert_eq!(batch.len(), rows.len());
        assert_eq!(batch, vec![-5.0, 3.0, 10.0, -5.0]);

        for (i, row) in rows.iter().enumerate() {
            let individual = predict_from_root(&arena, root, row);
            assert_eq!(
                batch[i], individual,
                "batch[{}] = {} but individual = {} for features {:?}",
                i, batch[i], individual, row
            );
        }
    }

    #[test]
    fn batch_empty_input() {
        let (arena, root) = build_one_split();
        let result = predict_batch(&arena, root, &[]);
        assert!(result.is_empty());
    }

    // ------------------------------------------------------------------
    // 5. collect_leaves: correct leaf count for different tree shapes.
    // ------------------------------------------------------------------

    #[test]
    fn collect_leaves_single_leaf() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);
        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0].idx(), root.idx());
    }

    #[test]
    fn collect_leaves_one_split() {
        let (arena, root) = build_one_split();
        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 2);

        // Left child (id 1) must come before right child (id 2).
        let ids: Vec<usize> = leaves.iter().map(|id| id.idx()).collect();
        assert_eq!(ids, vec![1, 2]);
    }

    #[test]
    fn collect_leaves_two_level() {
        let (arena, root) = build_two_level();
        let leaves = collect_leaves(&arena, root);
        // Three leaves: left-left(-5.0), left-right(3.0), right(10.0)
        assert_eq!(leaves.len(), 3);

        // Verify DFS left-to-right order by checking leaf values.
        let values: Vec<f64> = leaves.iter().map(|id| arena.leaf_value[id.idx()]).collect();
        assert_eq!(values, vec![-5.0, 3.0, 10.0]);
    }

    #[test]
    fn collect_leaves_balanced_depth2() {
        // Build a perfectly balanced tree of depth 2 with 4 leaves.
        //
        //             [0] feat=0, thr=5.0
        //            /                   \
        //    [1] feat=1, thr=2.0    [2] feat=1, thr=8.0
        //       /        \              /        \
        //  [3] 1.0    [4] 2.0     [5] 3.0    [6] 4.0
        let mut arena = empty_arena();

        let root = push_leaf(&mut arena, 0.0, 0);
        let left = push_leaf(&mut arena, 0.0, 1);
        let right = push_leaf(&mut arena, 0.0, 1);
        let ll = push_leaf(&mut arena, 1.0, 2);
        let lr = push_leaf(&mut arena, 2.0, 2);
        let rl = push_leaf(&mut arena, 3.0, 2);
        let rr = push_leaf(&mut arena, 4.0, 2);

        convert_to_split(&mut arena, root, 0, 5.0, left, right);
        convert_to_split(&mut arena, left, 1, 2.0, ll, lr);
        convert_to_split(&mut arena, right, 1, 8.0, rl, rr);

        let leaves = collect_leaves(&arena, root);
        assert_eq!(leaves.len(), 4);

        let values: Vec<f64> = leaves.iter().map(|id| arena.leaf_value[id.idx()]).collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
    }

    // ------------------------------------------------------------------
    // 6. tree_depth: verify for balanced and unbalanced trees.
    // ------------------------------------------------------------------

    #[test]
    fn depth_single_leaf() {
        let mut arena = empty_arena();
        let root = push_leaf(&mut arena, 0.0, 0);
        assert_eq!(tree_depth(&arena, root), 0);
    }

    #[test]
    fn depth_one_split() {
        let (arena, root) = build_one_split();
        assert_eq!(tree_depth(&arena, root), 1);
    }

    #[test]
    fn depth_two_level_unbalanced() {
        let (arena, root) = build_two_level();
        // Deepest leaf is at depth 2 (left-left and left-right).
        // Right leaf is at depth 1.
        assert_eq!(tree_depth(&arena, root), 2);
    }

    #[test]
    fn depth_left_skewed() {
        // Build a left-skewed chain of depth 4.
        //
        //  [0] split
        //  /       \
        // [1] split [2] leaf (depth=1)
        //  /    \
        // [3] split [4] leaf (depth=2)
        //  /    \
        // [5] split [6] leaf (depth=3)
        //  /    \
        // [7] leaf [8] leaf (depth=4)
        let mut arena = empty_arena();

        let n0 = push_leaf(&mut arena, 0.0, 0);
        let n1 = push_leaf(&mut arena, 0.0, 1);
        let n2 = push_leaf(&mut arena, 0.0, 1);
        let n3 = push_leaf(&mut arena, 0.0, 2);
        let n4 = push_leaf(&mut arena, 0.0, 2);
        let n5 = push_leaf(&mut arena, 0.0, 3);
        let n6 = push_leaf(&mut arena, 0.0, 3);
        let n7 = push_leaf(&mut arena, 0.0, 4);
        let n8 = push_leaf(&mut arena, 0.0, 4);

        convert_to_split(&mut arena, n0, 0, 1.0, n1, n2);
        convert_to_split(&mut arena, n1, 0, 2.0, n3, n4);
        convert_to_split(&mut arena, n3, 0, 3.0, n5, n6);
        convert_to_split(&mut arena, n5, 0, 4.0, n7, n8);

        assert_eq!(tree_depth(&arena, n0), 4);
        // 5 leaves total: n2(d=1), n4(d=2), n6(d=3), n7(d=4), n8(d=4)
        assert_eq!(collect_leaves(&arena, n0).len(), 5);
    }

    // ------------------------------------------------------------------
    // 7. Edge case: feature exactly equal to threshold goes left.
    // ------------------------------------------------------------------

    #[test]
    fn threshold_equality_goes_left() {
        let (arena, root) = build_one_split();
        let leaf = traverse_to_leaf(&arena, root, &[5.0]);
        // Left child is id 1, right child is id 2.
        assert_eq!(leaf.idx(), 1, "value == threshold must route left");
        assert_eq!(arena.leaf_value[leaf.idx()], -1.0);
    }

    #[test]
    fn threshold_equality_two_level() {
        let (arena, root) = build_two_level();
        // Both thresholds hit exactly: feat[0]=5.0 <= 5.0 -> left,
        // then feat[1]=2.0 <= 2.0 -> left => leaf at id 3, value -5.0.
        assert_eq!(predict_from_root(&arena, root, &[5.0, 2.0]), -5.0);
    }

    // ------------------------------------------------------------------
    // Extra: traverse_to_leaf returns the correct NodeId directly.
    // ------------------------------------------------------------------

    #[test]
    fn traverse_returns_correct_node_id() {
        let (arena, root) = build_two_level();

        // Route to left-left leaf (id 3).
        let leaf = traverse_to_leaf(&arena, root, &[0.0, 0.0]);
        assert_eq!(leaf.idx(), 3);

        // Route to left-right leaf (id 4).
        let leaf = traverse_to_leaf(&arena, root, &[0.0, 5.0]);
        assert_eq!(leaf.idx(), 4);

        // Route to right leaf (id 2).
        let leaf = traverse_to_leaf(&arena, root, &[10.0, 0.0]);
        assert_eq!(leaf.idx(), 2);
    }
}