irithyll_core/
traverse.rs1use crate::packed::PackedNode;
7
8#[inline(always)]
19pub fn predict_tree(nodes: &[PackedNode], features: &[f32]) -> f32 {
20 let mut idx = 0u32;
21 loop {
22 let node = unsafe { nodes.get_unchecked(idx as usize) };
24 if node.is_leaf() {
25 return node.value;
26 }
27 let feat_idx = node.feature_idx() as usize;
28 let feat_val = unsafe { *features.get_unchecked(feat_idx) };
29
30 let go_right = (feat_val > node.value) as u32;
35 let left = node.left_child() as u32;
36 let right = node.right_child() as u32;
37 idx = left + go_right * right.wrapping_sub(left);
38 }
39}
40
41#[inline]
47pub fn predict_tree_x4(nodes: &[PackedNode], features: [&[f32]; 4]) -> [f32; 4] {
48 let mut idx = [0u32; 4];
49 let mut done = [false; 4];
50 let mut result = [0.0f32; 4];
51
52 loop {
53 let mut all_done = true;
54 for s in 0..4 {
55 if done[s] {
56 continue;
57 }
58 let node = unsafe { nodes.get_unchecked(idx[s] as usize) };
59 if node.is_leaf() {
60 result[s] = node.value;
61 done[s] = true;
62 continue;
63 }
64 all_done = false;
65 let feat_idx = node.feature_idx() as usize;
66 let feat_val = unsafe { *features[s].get_unchecked(feat_idx) };
67 let go_right = (feat_val > node.value) as u32;
68 let left = node.left_child() as u32;
69 let right = node.right_child() as u32;
70 idx[s] = left + go_right * right.wrapping_sub(left);
71 }
72 if all_done {
73 return result;
74 }
75 }
76}
77
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::PackedNode;

    /// One split on feature 0 with threshold 5.0 and two leaves:
    ///
    /// ```text
    ///        [0] f0 > 5.0 ?
    ///        /            \
    ///   [1] -1.0       [2] 1.0
    /// ```
    fn stump() -> [PackedNode; 3] {
        let root = PackedNode::split(5.0, 0, 1, 2);
        let left_leaf = PackedNode::leaf(-1.0);
        let right_leaf = PackedNode::leaf(1.0);
        [root, left_leaf, right_leaf]
    }

    /// Root splits on feature 0 (threshold 5.0); its left child splits again on
    /// feature 1 (threshold 2.0):
    ///
    /// ```text
    ///              [0] f0 > 5.0 ?
    ///              /            \
    ///     [1] f1 > 2.0 ?      [2] 10.0
    ///      /          \
    ///  [3] -5.0     [4] 3.0
    /// ```
    fn depth_two_tree() -> [PackedNode; 5] {
        let root = PackedNode::split(5.0, 0, 1, 2);
        let inner = PackedNode::split(2.0, 1, 3, 4);
        [
            root,
            inner,
            PackedNode::leaf(10.0),
            PackedNode::leaf(-5.0),
            PackedNode::leaf(3.0),
        ]
    }

    #[test]
    fn single_leaf_tree() {
        // A tree consisting only of a leaf returns that leaf's value.
        let only_leaf = [PackedNode::leaf(42.0)];
        let got = predict_tree(&only_leaf, &[1.0, 2.0]);
        assert_eq!(got, 42.0);
    }

    #[test]
    fn simple_tree_goes_left() {
        // 3.0 is not greater than the 5.0 threshold, so the left leaf wins.
        let tree = stump();
        let got = predict_tree(&tree, &[3.0]);
        assert_eq!(got, -1.0);
    }

    #[test]
    fn simple_tree_goes_right() {
        // 7.0 exceeds the 5.0 threshold, so the right leaf wins.
        let tree = stump();
        let got = predict_tree(&tree, &[7.0]);
        assert_eq!(got, 1.0);
    }

    #[test]
    fn simple_tree_equal_goes_left() {
        // The comparison is strict: a feature equal to the threshold goes left.
        let tree = stump();
        let got = predict_tree(&tree, &[5.0]);
        assert_eq!(got, -1.0);
    }

    #[test]
    fn two_level_left_left() {
        // f0 = 1.0 goes left at the root, f1 = 0.5 goes left at the inner node.
        let tree = depth_two_tree();
        let got = predict_tree(&tree, &[1.0, 0.5]);
        assert_eq!(got, -5.0);
    }

    #[test]
    fn two_level_left_right() {
        // f0 = 4.0 goes left at the root, f1 = 3.0 goes right at the inner node.
        let tree = depth_two_tree();
        let got = predict_tree(&tree, &[4.0, 3.0]);
        assert_eq!(got, 3.0);
    }

    #[test]
    fn two_level_right() {
        // f0 = 8.0 goes right at the root; feature 1 is never consulted.
        let tree = depth_two_tree();
        let got = predict_tree(&tree, &[8.0, 999.0]);
        assert_eq!(got, 10.0);
    }

    #[test]
    fn predict_x4_matches_single() {
        // Every lane of the batched traversal must agree with the
        // single-sample traversal on the same input.
        let tree = depth_two_tree();
        let samples: [&[f32]; 4] = [&[1.0, 0.5], &[4.0, 3.0], &[8.0, 0.0], &[5.0, 2.0]];

        let batch = predict_tree_x4(&tree, samples);

        for (lane, sample) in samples.iter().enumerate() {
            assert_eq!(batch[lane], predict_tree(&tree, sample));
        }
    }
}