use crate::packed_i16::PackedNodeI16;
/// Walks one decision tree over pre-quantized features and returns the value
/// stored in the leaf that is reached.
///
/// Traversal starts at `nodes[0]`. For a leaf node, `value` is the prediction
/// and is returned as-is. For a split node, `value` is the threshold: the
/// feature selected by `feature_idx()` moves the walk to `right_child()` when
/// it is strictly greater than the threshold, and to `left_child()` otherwise
/// (ties go left).
///
/// # Preconditions
///
/// Slice accesses are unchecked in release builds, so the caller must pass a
/// well-formed tree:
///
/// * `nodes` is non-empty and every child index reachable from the root is
///   `< nodes.len()`;
/// * every `feature_idx()` of a reachable split node is
///   `< features_quantized.len()`;
/// * every path from the root ends in a leaf (otherwise the loop never exits).
///
/// Violations are reported by `debug_assert!` in debug builds only.
// NOTE(review): this is a safe `fn` whose soundness depends on the invariants
// above; confirm that tree construction validates them.
#[inline(always)]
pub fn predict_tree_i16(nodes: &[PackedNodeI16], features_quantized: &[i16]) -> i16 {
    let mut idx = 0u32;
    loop {
        debug_assert!(
            (idx as usize) < nodes.len(),
            "node index {} out of bounds ({} nodes)",
            idx,
            nodes.len()
        );
        // SAFETY: relies on the documented precondition that every reachable
        // node index is within `nodes`.
        let node = unsafe { nodes.get_unchecked(idx as usize) };
        if node.is_leaf() {
            return node.value;
        }

        let feat_idx = node.feature_idx() as usize;
        debug_assert!(
            feat_idx < features_quantized.len(),
            "feature index {} out of bounds ({} features)",
            feat_idx,
            features_quantized.len()
        );
        // SAFETY: relies on the documented precondition that every split
        // node's feature index is within `features_quantized`.
        let feat_val = unsafe { *features_quantized.get_unchecked(feat_idx) };

        let go_right = (feat_val > node.value) as u32;
        let left = node.left_child() as u32;
        let right = node.right_child() as u32;
        // Arithmetic select: `left` when go_right == 0, `right` when it is 1.
        // Both the subtraction and the addition must wrap: when the right
        // child is stored before the left one (`right < left`) the difference
        // wraps around, and a plain `+` would then panic on overflow in
        // builds with overflow checks enabled.
        idx = left.wrapping_add(go_right * right.wrapping_sub(left));
    }
}
/// Evaluates the same tree for four pre-quantized feature vectors and returns
/// the four leaf values, in lane order.
///
/// Each lane `s` walks the tree independently using `features[s]`, with the
/// same rule as a single-sample walk: start at `nodes[0]`; at a split node go
/// to `right_child()` when the selected feature is strictly greater than the
/// node's `value`, otherwise to `left_child()`; at a leaf, `value` is the
/// result for that lane. Lanes advance one node per outer-loop pass, and a
/// lane that has reached its leaf is skipped on later passes.
///
/// # Preconditions
///
/// Slice accesses are unchecked in release builds, so the caller must pass a
/// well-formed tree: `nodes` is non-empty, every reachable child index is
/// `< nodes.len()`, every reachable split node's `feature_idx()` is
/// `< features[s].len()` for each lane, and every path ends in a leaf.
/// Violations are reported by `debug_assert!` in debug builds only.
// NOTE(review): this is a safe `fn` whose soundness depends on the invariants
// above; confirm that tree construction validates them.
#[inline]
pub fn predict_tree_i16_x4(nodes: &[PackedNodeI16], features: [&[i16]; 4]) -> [i16; 4] {
    let mut idx = [0u32; 4];
    let mut done = [false; 4];
    let mut result = [0i16; 4];
    loop {
        // Stays true only if no lane took a split step during this pass.
        let mut all_done = true;
        for s in 0..4 {
            if done[s] {
                continue;
            }

            debug_assert!(
                (idx[s] as usize) < nodes.len(),
                "lane {}: node index {} out of bounds ({} nodes)",
                s,
                idx[s],
                nodes.len()
            );
            // SAFETY: relies on the documented precondition that every
            // reachable node index is within `nodes`.
            let node = unsafe { nodes.get_unchecked(idx[s] as usize) };
            if node.is_leaf() {
                result[s] = node.value;
                done[s] = true;
                continue;
            }
            all_done = false;

            let feat_idx = node.feature_idx() as usize;
            debug_assert!(
                feat_idx < features[s].len(),
                "lane {}: feature index {} out of bounds ({} features)",
                s,
                feat_idx,
                features[s].len()
            );
            // SAFETY: relies on the documented precondition that every split
            // node's feature index is within each lane's feature slice.
            let feat_val = unsafe { *features[s].get_unchecked(feat_idx) };

            let go_right = (feat_val > node.value) as u32;
            let left = node.left_child() as u32;
            let right = node.right_child() as u32;
            // Arithmetic select: `left` when go_right == 0, `right` when it
            // is 1. The addition must wrap as well as the subtraction: with
            // `right < left` the difference wraps around, and a plain `+`
            // would panic on overflow in builds with overflow checks enabled.
            idx[s] = left.wrapping_add(go_right * right.wrapping_sub(left));
        }
        if all_done {
            return result;
        }
    }
}
/// Walks one decision tree over raw `f32` features, quantizing each feature
/// only when a split node actually reads it, and returns the value stored in
/// the leaf that is reached.
///
/// At a split node the selected feature is multiplied by the matching entry of
/// `feature_scales` and converted with `as i16`. That cast truncates toward
/// zero, saturates at `i16::MIN` / `i16::MAX`, and maps NaN to `0`. The
/// quantized value is then compared with the node's `value`: strictly greater
/// goes to `right_child()`, otherwise (including a tie) to `left_child()`.
/// For a leaf, `value` is the prediction.
///
/// # Preconditions
///
/// Slice accesses are unchecked in release builds, so the caller must pass a
/// well-formed tree: `nodes` is non-empty, every reachable child index is
/// `< nodes.len()`, every reachable split node's `feature_idx()` is within
/// both `features` and `feature_scales`, and every path ends in a leaf.
/// Violations are reported by `debug_assert!` in debug builds only.
// NOTE(review): this is a safe `fn` whose soundness depends on the invariants
// above; confirm that tree construction validates them.
#[inline(always)]
pub fn predict_tree_i16_inline(
    nodes: &[PackedNodeI16],
    features: &[f32],
    feature_scales: &[f32],
) -> i16 {
    let mut idx = 0u32;
    loop {
        debug_assert!(
            (idx as usize) < nodes.len(),
            "node index {} out of bounds ({} nodes)",
            idx,
            nodes.len()
        );
        // SAFETY: relies on the documented precondition that every reachable
        // node index is within `nodes`.
        let node = unsafe { nodes.get_unchecked(idx as usize) };
        if node.is_leaf() {
            return node.value;
        }

        let feat_idx = node.feature_idx() as usize;
        debug_assert!(
            feat_idx < features.len() && feat_idx < feature_scales.len(),
            "feature index {} out of bounds ({} features, {} scales)",
            feat_idx,
            features.len(),
            feature_scales.len()
        );
        // SAFETY: relies on the documented precondition that every split
        // node's feature index is within `features`.
        let feat_val = unsafe { *features.get_unchecked(feat_idx) };
        // SAFETY: same precondition, applied to `feature_scales`.
        let scale = unsafe { *feature_scales.get_unchecked(feat_idx) };
        // Float-to-int `as` truncates toward zero, saturates, and turns NaN into 0.
        let quantized = (feat_val * scale) as i16;

        let go_right = (quantized > node.value) as u32;
        let left = node.left_child() as u32;
        let right = node.right_child() as u32;
        // Arithmetic select: `left` when go_right == 0, `right` when it is 1.
        // The addition must wrap as well as the subtraction: with
        // `right < left` the difference wraps around, and a plain `+` would
        // panic on overflow in builds with overflow checks enabled.
        idx = left.wrapping_add(go_right * right.wrapping_sub(left));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed_i16::PackedNodeI16;

    /// Depth-one fixture: the root is built with `split(500, 0, 1, 2)` and is
    /// followed by leaves `-100` (index 1) and `100` (index 2). The assertions
    /// below treat index 1 as the "not greater than 500" outcome on feature 0.
    fn simple_tree() -> [PackedNodeI16; 3] {
        let root = PackedNodeI16::split(500, 0, 1, 2);
        let low_leaf = PackedNodeI16::leaf(-100);
        let high_leaf = PackedNodeI16::leaf(100);
        [root, low_leaf, high_leaf]
    }

    /// Depth-two fixture: the root sends feature 0 values above 500 to the
    /// leaf `1000` (index 2); everything else reaches a second split at
    /// index 1, which picks between leaves `-500` (index 3) and `300`
    /// (index 4) on feature 1 with threshold 200.
    fn two_level_tree() -> [PackedNodeI16; 5] {
        let root = PackedNodeI16::split(500, 0, 1, 2);
        let inner = PackedNodeI16::split(200, 1, 3, 4);
        let right_leaf = PackedNodeI16::leaf(1000);
        let inner_low = PackedNodeI16::leaf(-500);
        let inner_high = PackedNodeI16::leaf(300);
        [root, inner, right_leaf, inner_low, inner_high]
    }

    // A tree that is just one leaf returns that leaf regardless of features.
    #[test]
    fn single_leaf_tree_i16() {
        let tree = [PackedNodeI16::leaf(4200)];
        let got = predict_tree_i16(&tree, &[100, 200]);
        assert_eq!(got, 4200);
    }

    #[test]
    fn simple_tree_goes_left_i16() {
        let tree = simple_tree();
        let got = predict_tree_i16(&tree, &[300]);
        assert_eq!(got, -100);
    }

    #[test]
    fn simple_tree_goes_right_i16() {
        let tree = simple_tree();
        let got = predict_tree_i16(&tree, &[700]);
        assert_eq!(got, 100);
    }

    // A feature equal to the threshold must not count as "greater".
    #[test]
    fn simple_tree_equal_goes_left_i16() {
        let tree = simple_tree();
        let got = predict_tree_i16(&tree, &[500]);
        assert_eq!(got, -100);
    }

    #[test]
    fn two_level_left_left_i16() {
        let tree = two_level_tree();
        let got = predict_tree_i16(&tree, &[100, 50]);
        assert_eq!(got, -500);
    }

    #[test]
    fn two_level_left_right_i16() {
        let tree = two_level_tree();
        let got = predict_tree_i16(&tree, &[400, 300]);
        assert_eq!(got, 300);
    }

    #[test]
    fn two_level_right_i16() {
        let tree = two_level_tree();
        let got = predict_tree_i16(&tree, &[800, 9999]);
        assert_eq!(got, 1000);
    }

    // Each lane of the batched walk must agree with the single-sample walk.
    #[test]
    fn predict_x4_matches_single_i16() {
        let tree = two_level_tree();
        let lanes: [&[i16]; 4] = [&[100, 50], &[400, 300], &[800, 0], &[500, 200]];
        let batch = predict_tree_i16_x4(&tree, lanes);
        for (lane, &sample) in lanes.iter().enumerate() {
            assert_eq!(batch[lane], predict_tree_i16(&tree, sample));
        }
    }

    // With a scale of 1.0 the on-the-fly path must match pre-quantized input.
    #[test]
    fn inline_matches_prequantized() {
        let tree = simple_tree();
        let raw: &[f32] = &[300.0];
        let unit_scale: &[f32] = &[1.0];
        let quantized: &[i16] = &[300];
        let from_inline = predict_tree_i16_inline(&tree, raw, unit_scale);
        let from_prequantized = predict_tree_i16(&tree, quantized);
        assert_eq!(from_inline, from_prequantized);
    }

    // 5.0 * 200.0 = 1000 lands above the 500 threshold; 1.0 * 200.0 = 200 does not.
    #[test]
    fn inline_with_scaling() {
        let tree = simple_tree();
        let scale: &[f32] = &[200.0];
        let above: &[f32] = &[5.0];
        assert_eq!(predict_tree_i16_inline(&tree, above, scale), 100);
        let below: &[f32] = &[1.0];
        assert_eq!(predict_tree_i16_inline(&tree, below, scale), -100);
    }
}