/// Tree node packed into one `i16`, one `u16` and one `u32`.
///
/// A node is either a leaf (built with [`PackedNodeI16::leaf`]) or a split
/// (built with [`PackedNodeI16::split`]); the two are told apart by
/// [`PackedNodeI16::LEAF_FLAG`] in `feature_flags`.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PackedNodeI16 {
    /// The value given to `leaf`, or the threshold given to `split`.
    pub value: i16,
    /// Bit 15 is the leaf marker; bits 0..=14 carry the feature index.
    pub feature_flags: u16,
    /// Left child in the low 16 bits, right child in the high 16 bits.
    /// Nodes built with `leaf` store zero here.
    pub children: u32,
}

impl PackedNodeI16 {
    /// Top bit of `feature_flags`; set on leaves, clear on splits.
    pub const LEAF_FLAG: u16 = 1 << 15;

    /// Builds a leaf carrying `value`, with the leaf flag set and no children.
    #[inline]
    pub const fn leaf(value: i16) -> Self {
        Self {
            value,
            feature_flags: Self::LEAF_FLAG,
            children: 0,
        }
    }

    /// Builds a split node.
    ///
    /// `threshold` is stored in `value`. Only the low 15 bits of
    /// `feature_idx` are kept, so the leaf flag is never set on a split.
    /// `left` and `right` are packed into the low and high halves of
    /// `children` respectively.
    #[inline]
    pub const fn split(threshold: i16, feature_idx: u16, left: u16, right: u16) -> Self {
        let low_half = left as u32;
        let high_half = (right as u32) << u16::BITS;
        Self {
            value: threshold,
            feature_flags: feature_idx & !Self::LEAF_FLAG,
            children: high_half | low_half,
        }
    }

    /// Returns `true` when the leaf flag is set.
    #[inline]
    pub const fn is_leaf(&self) -> bool {
        // LEAF_FLAG is a single bit, so "equals the flag" is the same as "non-zero".
        (self.feature_flags & Self::LEAF_FLAG) == Self::LEAF_FLAG
    }

    /// Returns the feature index, i.e. `feature_flags` with the leaf flag cleared.
    #[inline]
    pub const fn feature_idx(&self) -> u16 {
        self.feature_flags & !Self::LEAF_FLAG
    }

    /// Returns the left child: the low 16 bits of `children`.
    #[inline]
    pub const fn left_child(&self) -> u16 {
        (self.children & (u16::MAX as u32)) as u16
    }

    /// Returns the right child: the high 16 bits of `children`.
    #[inline]
    pub const fn right_child(&self) -> u16 {
        (self.children >> u16::BITS) as u16
    }
}
/// Fixed-layout header record for a quantized ensemble.
///
/// The fields add up to 16 bytes (4 + 2 + 2 + 2 + 2 + 4) and the struct is
/// `repr(C)` with 4-byte alignment.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantizedEnsembleHeader {
    /// Format tag; presumably expected to equal [`QuantizedEnsembleHeader::MAGIC`] — confirm against the reader.
    pub magic: u32,
    /// Format version; presumably compared with [`QuantizedEnsembleHeader::VERSION`] — confirm against the reader.
    pub version: u16,
    /// Tree count (per the field name).
    pub n_trees: u16,
    /// Feature count (per the field name).
    pub n_features: u16,
    /// Reserved 16-bit slot; no use is visible in this file.
    pub _reserved: u16,
    /// Base prediction as an `f32` (per the field name).
    pub base_prediction: f32,
}

impl QuantizedEnsembleHeader {
    /// The ASCII bytes `I`, `R`, `1`, `6` read as a little-endian `u32`.
    pub const MAGIC: u32 = u32::from_le_bytes([b'I', b'R', b'1', b'6']);

    /// Version number defined for this header type.
    pub const VERSION: u16 = 1;
}
pub use crate::packed::TreeEntry as QuantizedTreeEntry;
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, size_of};

    // ---- layout ----

    /// A packed node occupies exactly eight bytes.
    #[test]
    fn packed_node_i16_is_8_bytes() {
        let node_bytes = size_of::<PackedNodeI16>();
        assert_eq!(node_bytes, 8);
    }

    /// A packed node is aligned to four bytes.
    #[test]
    fn packed_node_i16_alignment_is_4() {
        let node_align = align_of::<PackedNodeI16>();
        assert_eq!(node_align, 4);
    }

    /// The ensemble header occupies exactly sixteen bytes.
    #[test]
    fn quantized_header_is_16_bytes() {
        let header_bytes = size_of::<QuantizedEnsembleHeader>();
        assert_eq!(header_bytes, 16);
    }

    // ---- construction round-trips ----

    /// `leaf` marks the node as a leaf, keeps the value and stores no children.
    #[test]
    fn leaf_node_i16_roundtrip() {
        let leaf = PackedNodeI16::leaf(1234);
        assert!(leaf.is_leaf());
        assert_eq!(leaf.value, 1234);
        assert_eq!(leaf.children, 0);
    }

    /// `split` keeps threshold, feature index and both children.
    #[test]
    fn split_node_i16_roundtrip() {
        let split = PackedNodeI16::split(5000, 7, 1, 2);
        assert!(!split.is_leaf());
        assert_eq!(split.feature_idx(), 7);
        assert_eq!(split.value, 5000);
        assert_eq!(split.left_child(), 1);
        assert_eq!(split.right_child(), 2);
    }

    /// The largest 15-bit feature index survives and does not read as a leaf.
    #[test]
    fn max_feature_index_i16() {
        let widest_feature = PackedNodeI16::split(0, 0x7FFF, 0, 0);
        assert_eq!(widest_feature.feature_idx(), 0x7FFF);
        assert!(!widest_feature.is_leaf());
    }

    /// Both children can hold `u16::MAX` without interfering with each other.
    #[test]
    fn max_child_indices_i16() {
        let far_children = PackedNodeI16::split(0, 0, u16::MAX, u16::MAX);
        assert_eq!(far_children.left_child(), u16::MAX);
        assert_eq!(far_children.right_child(), u16::MAX);
    }

    /// Eight nodes take at most 64 bytes.
    #[test]
    fn eight_nodes_per_cache_line() {
        let eight_nodes = size_of::<PackedNodeI16>() * 8;
        assert!(eight_nodes <= 64);
    }

    /// The magic constant spells `IR16` when written little-endian.
    #[test]
    fn header_magic_is_ir16() {
        let magic_bytes: [u8; 4] = QuantizedEnsembleHeader::MAGIC.to_le_bytes();
        assert_eq!(magic_bytes, *b"IR16");
    }

    // ---- negative values ----

    /// A negative leaf value is stored unchanged.
    #[test]
    fn leaf_negative_value_roundtrip() {
        let leaf = PackedNodeI16::leaf(-32000);
        assert!(leaf.is_leaf());
        assert_eq!(leaf.value, -32000);
    }

    /// A negative threshold does not disturb the other packed fields.
    #[test]
    fn split_negative_threshold_roundtrip() {
        let split = PackedNodeI16::split(-16000, 3, 10, 20);
        assert!(!split.is_leaf());
        assert_eq!(split.value, -16000);
        assert_eq!(split.feature_idx(), 3);
        assert_eq!(split.left_child(), 10);
        assert_eq!(split.right_child(), 20);
    }

    // ---- re-exported tree entry ----

    /// The re-exported tree entry is constructible by field and is eight bytes.
    #[test]
    fn tree_entry_reused_from_packed() {
        let entry: QuantizedTreeEntry = QuantizedTreeEntry {
            n_nodes: 5,
            offset: 40,
        };
        assert_eq!(entry.n_nodes, 5);
        assert_eq!(entry.offset, 40);
        let entry_bytes = size_of::<QuantizedTreeEntry>();
        assert_eq!(entry_bytes, 8);
    }
}