use crate::packed::PackedNode;
/// Evaluates one packed decision tree for a single feature vector.
///
/// Traversal starts at `nodes[0]`. At every split node the feature
/// `features[node.feature_idx()]` is compared against the node's threshold
/// (stored in `node.value`):
/// - a strictly greater feature value goes to the right child;
/// - anything else goes to the left child, including equality and a NaN
///   feature value (`NaN > x` is always false).
///
/// The `value` of the first leaf reached is returned.
///
/// # Panics
///
/// Panics if:
/// - a node index reached during traversal is out of bounds for `nodes`
///   (including when `nodes` is empty);
/// - a split references a feature index that is out of bounds for `features`.
///
/// A tree whose child links form a cycle is not detected and makes this
/// function loop forever.
#[inline(always)]
pub fn predict_tree(nodes: &[PackedNode], features: &[f32]) -> f32 {
    let mut idx = 0usize;
    loop {
        // Checked indexing: both the tree and the feature vector come from the
        // caller, so an unchecked access here would let safe code trigger UB.
        let node = &nodes[idx];
        if node.is_leaf() {
            return node.value;
        }
        let feat_val = features[node.feature_idx() as usize];

        // Branchless child selection: `left + go_right * (right - left)`.
        // Both the subtraction and the addition must wrap. When the right child
        // is stored before the left one, `right - left` wraps around, and only
        // a wrapping add brings the sum back to `right`. A plain `+` would
        // panic on overflow in debug builds.
        let go_right = u32::from(feat_val > node.value);
        let left = node.left_child() as u32;
        let right = node.right_child() as u32;
        idx = left.wrapping_add(go_right * right.wrapping_sub(left)) as usize;
    }
}
/// Evaluates one packed decision tree for four feature vectors at once.
///
/// Each lane follows exactly the same rules as a single-row traversal:
/// - start at `nodes[0]`;
/// - at a split, go right when `features[lane][node.feature_idx()] > node.value`,
///   otherwise go left (equality and NaN go left);
/// - stop at the first leaf and report its `value`.
///
/// The four traversals are advanced one step at a time in lock-step.
/// Presumably this is so independent node loads can overlap; confirm with a
/// profile before relying on it.
///
/// Returns the leaf value reached by each lane, in lane order.
///
/// # Panics
///
/// Panics if:
/// - any lane reaches a node index that is out of bounds for `nodes`
///   (including when `nodes` is empty);
/// - any lane's split references a feature index out of bounds for its
///   feature slice.
///
/// A tree whose child links form a cycle is not detected and makes this
/// function loop forever.
#[inline]
pub fn predict_tree_x4(nodes: &[PackedNode], features: [&[f32]; 4]) -> [f32; 4] {
    let mut idx = [0usize; 4];
    let mut done = [false; 4];
    let mut result = [0.0f32; 4];
    loop {
        let mut all_done = true;
        for lane in 0..4 {
            if done[lane] {
                continue;
            }
            // Checked indexing: the tree and feature slices are caller-supplied,
            // so an unchecked access here would let safe code trigger UB.
            let node = &nodes[idx[lane]];
            if node.is_leaf() {
                result[lane] = node.value;
                done[lane] = true;
                continue;
            }
            all_done = false;
            let feat_val = features[lane][node.feature_idx() as usize];

            // Branchless `left + go_right * (right - left)`. Both operations must
            // wrap so that a right child stored before its left sibling does not
            // overflow (which panics in debug builds).
            let go_right = u32::from(feat_val > node.value);
            let left = node.left_child() as u32;
            let right = node.right_child() as u32;
            idx[lane] = left.wrapping_add(go_right * right.wrapping_sub(left)) as usize;
        }
        if all_done {
            return result;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::PackedNode;

    /// One split on feature 0 at threshold 5.0.
    ///
    /// Features `> 5.0` reach leaf 1.0 (node 2); everything else reaches
    /// leaf -1.0 (node 1).
    fn simple_tree() -> [PackedNode; 3] {
        [
            PackedNode::split(5.0, 0, 1, 2),
            PackedNode::leaf(-1.0),
            PackedNode::leaf(1.0),
        ]
    }

    /// Two-level tree.
    ///
    /// - node 0: feature 0 `> 5.0` → node 2 (leaf 10.0), else → node 1.
    /// - node 1: feature 1 `> 2.0` → node 4 (leaf 3.0), else → node 3 (leaf -5.0).
    fn two_level_tree() -> [PackedNode; 5] {
        [
            PackedNode::split(5.0, 0, 1, 2),
            PackedNode::split(2.0, 1, 3, 4),
            PackedNode::leaf(10.0),
            PackedNode::leaf(-5.0),
            PackedNode::leaf(3.0),
        ]
    }

    #[test]
    fn single_leaf_tree() {
        let nodes = [PackedNode::leaf(42.0)];
        assert_eq!(predict_tree(&nodes, &[1.0, 2.0]), 42.0);
    }

    #[test]
    fn simple_tree_goes_left() {
        let nodes = simple_tree();
        assert_eq!(predict_tree(&nodes, &[3.0]), -1.0);
    }

    #[test]
    fn simple_tree_goes_right() {
        let nodes = simple_tree();
        assert_eq!(predict_tree(&nodes, &[7.0]), 1.0);
    }

    #[test]
    fn simple_tree_equal_goes_left() {
        let nodes = simple_tree();
        assert_eq!(predict_tree(&nodes, &[5.0]), -1.0);
    }

    /// `NaN > threshold` is false, so a missing/NaN feature routes left.
    #[test]
    fn simple_tree_nan_goes_left() {
        let nodes = simple_tree();
        assert_eq!(predict_tree(&nodes, &[f32::NAN]), -1.0);
    }

    #[test]
    fn two_level_left_left() {
        let nodes = two_level_tree();
        assert_eq!(predict_tree(&nodes, &[1.0, 0.5]), -5.0);
    }

    #[test]
    fn two_level_left_right() {
        let nodes = two_level_tree();
        assert_eq!(predict_tree(&nodes, &[4.0, 3.0]), 3.0);
    }

    #[test]
    fn two_level_right() {
        let nodes = two_level_tree();
        assert_eq!(predict_tree(&nodes, &[8.0, 999.0]), 10.0);
    }

    #[test]
    fn predict_x4_matches_single() {
        let nodes = two_level_tree();
        let f0: &[f32] = &[1.0, 0.5];
        let f1: &[f32] = &[4.0, 3.0];
        let f2: &[f32] = &[8.0, 0.0];
        let f3: &[f32] = &[5.0, 2.0];
        let batch = predict_tree_x4(&nodes, [f0, f1, f2, f3]);
        assert_eq!(batch[0], predict_tree(&nodes, f0));
        assert_eq!(batch[1], predict_tree(&nodes, f1));
        assert_eq!(batch[2], predict_tree(&nodes, f2));
        assert_eq!(batch[3], predict_tree(&nodes, f3));
    }

    /// Lanes finish at different depths (lane 2 stops after one split, the
    /// others after two); check exact values, not just agreement with the
    /// single-row path.
    #[test]
    fn predict_x4_expected_values() {
        let nodes = two_level_tree();
        let f0: &[f32] = &[1.0, 0.5];
        let f1: &[f32] = &[4.0, 3.0];
        let f2: &[f32] = &[8.0, 0.0];
        let f3: &[f32] = &[5.0, 2.0];
        assert_eq!(
            predict_tree_x4(&nodes, [f0, f1, f2, f3]),
            [-5.0, 3.0, 10.0, -5.0]
        );
    }

    /// Root is a leaf: every lane is done on the first pass, and no feature
    /// is ever read, so empty feature slices are fine.
    #[test]
    fn predict_x4_single_leaf_tree() {
        let nodes = [PackedNode::leaf(42.0)];
        let empty: &[f32] = &[];
        assert_eq!(predict_tree_x4(&nodes, [empty; 4]), [42.0; 4]);
    }
}