Skip to main content

irithyll_core/
traverse.rs

1//! Branch-free tree traversal for packed nodes.
2//!
3//! The hot loop uses `get_unchecked` (bounds validated at `EnsembleView` construction)
4//! and branchless child selection (`cmov` on x86, `csel` on ARM) for maximum throughput.
5
6use crate::packed::PackedNode;
7
8/// Traverse a single tree and return the leaf prediction value.
9///
10/// # Safety contract (upheld by `EnsembleView::from_bytes`)
11///
12/// - All node child indices are within `nodes.len()`.
13/// - All feature indices are within `features.len()`.
14/// - The tree is acyclic and every path terminates at a leaf.
15///
16/// These invariants are validated once during `EnsembleView::from_bytes()`. After
17/// validation, traversal uses `get_unchecked` for zero-overhead indexing.
18#[inline(always)]
19pub fn predict_tree(nodes: &[PackedNode], features: &[f32]) -> f32 {
20    let mut idx = 0u32;
21    loop {
22        // SAFETY: All indices validated during EnsembleView construction.
23        let node = unsafe { nodes.get_unchecked(idx as usize) };
24        if node.is_leaf() {
25            return node.value;
26        }
27        let feat_idx = node.feature_idx() as usize;
28        let feat_val = unsafe { *features.get_unchecked(feat_idx) };
29
30        // Branchless child selection:
31        // go_right = 1 if feat_val > threshold, 0 otherwise
32        // idx = left + go_right * (right - left)
33        // This compiles to cmov (x86) or csel (ARM) — no branch prediction miss.
34        let go_right = (feat_val > node.value) as u32;
35        let left = node.left_child() as u32;
36        let right = node.right_child() as u32;
37        idx = left + go_right * right.wrapping_sub(left);
38    }
39}
40
41/// Predict 4 samples through one tree simultaneously.
42///
43/// Exploits CPU out-of-order execution with 4 independent traversal states.
44/// Each sample follows its own path through the tree — the CPU can overlap
45/// memory loads across samples since the data dependencies are independent.
46#[inline]
47pub fn predict_tree_x4(nodes: &[PackedNode], features: [&[f32]; 4]) -> [f32; 4] {
48    let mut idx = [0u32; 4];
49    let mut done = [false; 4];
50    let mut result = [0.0f32; 4];
51
52    loop {
53        let mut all_done = true;
54        for s in 0..4 {
55            if done[s] {
56                continue;
57            }
58            let node = unsafe { nodes.get_unchecked(idx[s] as usize) };
59            if node.is_leaf() {
60                result[s] = node.value;
61                done[s] = true;
62                continue;
63            }
64            all_done = false;
65            let feat_idx = node.feature_idx() as usize;
66            let feat_val = unsafe { *features[s].get_unchecked(feat_idx) };
67            let go_right = (feat_val > node.value) as u32;
68            let left = node.left_child() as u32;
69            let right = node.right_child() as u32;
70            idx[s] = left + go_right * right.wrapping_sub(left);
71        }
72        if all_done {
73            return result;
74        }
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::PackedNode;

    /// Depth-1 fixture: one split on feature 0 at threshold 5.0.
    ///
    /// ```text
    ///     [0] feat=0, thr=5.0
    ///     /                \
    /// [1] leaf=-1.0    [2] leaf=1.0
    /// ```
    fn simple_tree() -> [PackedNode; 3] {
        let root = PackedNode::split(5.0, 0, 1, 2);
        let left_leaf = PackedNode::leaf(-1.0);
        let right_leaf = PackedNode::leaf(1.0);
        [root, left_leaf, right_leaf]
    }

    /// Depth-2 fixture: the root's left child is itself a split on feature 1.
    ///
    /// ```text
    ///          [0] feat=0, thr=5.0
    ///          /                \
    ///   [1] feat=1, thr=2.0   [2] leaf=10.0
    ///    /           \
    /// [3] leaf=-5.0  [4] leaf=3.0
    /// ```
    fn two_level_tree() -> [PackedNode; 5] {
        let root = PackedNode::split(5.0, 0, 1, 2);
        let inner = PackedNode::split(2.0, 1, 3, 4);
        [
            root,
            inner,
            PackedNode::leaf(10.0),
            PackedNode::leaf(-5.0),
            PackedNode::leaf(3.0),
        ]
    }

    #[test]
    fn single_leaf_tree() {
        // A tree consisting of only a leaf returns that leaf's value,
        // regardless of the features supplied.
        let only_leaf = [PackedNode::leaf(42.0)];
        let prediction = predict_tree(&only_leaf, &[1.0, 2.0]);
        assert_eq!(prediction, 42.0);
    }

    #[test]
    fn simple_tree_goes_left() {
        // 3.0 > 5.0 is false -> go_right = 0 -> left child (index 1) -> -1.0
        let tree = simple_tree();
        let prediction = predict_tree(&tree, &[3.0]);
        assert_eq!(prediction, -1.0);
    }

    #[test]
    fn simple_tree_goes_right() {
        // 7.0 > 5.0 is true -> go_right = 1 -> right child (index 2) -> 1.0
        let tree = simple_tree();
        let prediction = predict_tree(&tree, &[7.0]);
        assert_eq!(prediction, 1.0);
    }

    #[test]
    fn simple_tree_equal_goes_left() {
        // A feature exactly on the threshold is not strictly greater, so it
        // goes left. This matches the SGBT convention (<= threshold goes left).
        let tree = simple_tree();
        let prediction = predict_tree(&tree, &[5.0]);
        assert_eq!(prediction, -1.0);
    }

    #[test]
    fn two_level_left_left() {
        // feat[0]=1.0 is not > 5.0 -> node 1; feat[1]=0.5 is not > 2.0 -> node 3 -> -5.0
        let tree = two_level_tree();
        let prediction = predict_tree(&tree, &[1.0, 0.5]);
        assert_eq!(prediction, -5.0);
    }

    #[test]
    fn two_level_left_right() {
        // feat[0]=4.0 is not > 5.0 -> node 1; feat[1]=3.0 > 2.0 -> node 4 -> 3.0
        let tree = two_level_tree();
        let prediction = predict_tree(&tree, &[4.0, 3.0]);
        assert_eq!(prediction, 3.0);
    }

    #[test]
    fn two_level_right() {
        // feat[0]=8.0 > 5.0 -> node 2, a leaf -> 10.0; feat[1] is never read.
        let tree = two_level_tree();
        let prediction = predict_tree(&tree, &[8.0, 999.0]);
        assert_eq!(prediction, 10.0);
    }

    #[test]
    fn predict_x4_matches_single() {
        // Every lane of the batched traversal must agree with the scalar one.
        let tree = two_level_tree();
        let samples: [&[f32]; 4] = [&[1.0, 0.5], &[4.0, 3.0], &[8.0, 0.0], &[5.0, 2.0]];

        let batch = predict_tree_x4(&tree, samples);

        for (lane, &sample) in samples.iter().enumerate() {
            assert_eq!(batch[lane], predict_tree(&tree, sample));
        }
    }
}