/// A decision-tree node packed into 12 bytes (`repr(C)`, 4-byte aligned).
///
/// One node type serves both roles, distinguished by [`PackedNode::LEAF_FLAG`] in
/// `feature_flags`:
///
/// * **Leaf** (built by [`PackedNode::leaf`]): `value` holds the leaf value, `children` is
///   `0`, and `feature_flags` is exactly `LEAF_FLAG`.
/// * **Split** (built by [`PackedNode::split`]): `value` holds the threshold, `children`
///   packs the left child index into its low 16 bits and the right child index into its
///   high 16 bits, and `feature_flags` holds the feature index in its low 15 bits with the
///   leaf bit clear.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PackedNode {
    /// Leaf value for leaf nodes; split threshold for split nodes.
    pub value: f32,
    /// Left child index in bits 0..16, right child index in bits 16..32 (zero for leaves).
    pub children: u32,
    /// Leaf flag in bit 15, feature index in bits 0..15.
    pub feature_flags: u16,
    /// Reserved; the constructors always write `0`.
    pub _reserved: u16,
}

impl PackedNode {
    /// Bit of `feature_flags` that marks a node as a leaf.
    pub const LEAF_FLAG: u16 = 0x8000;

    /// Bits of `feature_flags` that hold a split node's feature index (`0x7FFF`).
    ///
    /// This is also the largest feature index a split node can encode, because the
    /// remaining bit is taken by [`Self::LEAF_FLAG`].
    pub const FEATURE_MASK: u16 = !Self::LEAF_FLAG;

    /// Builds a leaf node carrying `value`, with no children and feature index 0.
    #[inline]
    pub const fn leaf(value: f32) -> Self {
        Self {
            value,
            children: 0,
            feature_flags: Self::LEAF_FLAG,
            _reserved: 0,
        }
    }

    /// Builds a split node with the given threshold, feature index and child indices.
    ///
    /// `left` is stored in the low 16 bits of `children` and `right` in the high 16 bits.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `feature_idx` is greater than
    /// [`Self::FEATURE_MASK`]. Such an index overlaps [`Self::LEAF_FLAG`] and cannot be
    /// represented. Without debug assertions it is masked to its low 15 bits, as before,
    /// which silently aliases a different feature.
    #[inline]
    pub const fn split(threshold: f32, feature_idx: u16, left: u16, right: u16) -> Self {
        // Catch out-of-range feature indices early instead of letting the mask below
        // silently map e.g. 0x8000 onto feature 0.
        debug_assert!(
            feature_idx <= Self::FEATURE_MASK,
            "feature index does not fit in 15 bits"
        );
        Self {
            value: threshold,
            children: (left as u32) | ((right as u32) << 16),
            // Masking keeps the leaf bit clear so a split node is never mistaken for a leaf.
            feature_flags: feature_idx & Self::FEATURE_MASK,
            _reserved: 0,
        }
    }

    /// Returns `true` if the leaf bit is set in `feature_flags`.
    #[inline]
    pub const fn is_leaf(&self) -> bool {
        self.feature_flags & Self::LEAF_FLAG != 0
    }

    /// Feature index: the low 15 bits of `feature_flags`. Returns 0 for nodes built with
    /// [`Self::leaf`].
    #[inline]
    pub const fn feature_idx(&self) -> u16 {
        self.feature_flags & Self::FEATURE_MASK
    }

    /// Left child index: the low 16 bits of `children`.
    #[inline]
    pub const fn left_child(&self) -> u16 {
        // Truncating cast intentionally keeps only the low half.
        self.children as u16
    }

    /// Right child index: the high 16 bits of `children`.
    #[inline]
    pub const fn right_child(&self) -> u16 {
        (self.children >> 16) as u16
    }
}
/// Fixed-size header record for a serialized tree ensemble.
///
/// `repr(C)` with 4-byte alignment: fields keep their declaration order, and the
/// 4 + 2 + 2 + 2 + 2 + 4 byte fields add up to 16 bytes with no padding.
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(C, align(4))]
pub struct EnsembleHeader {
    /// Format signature; presumably checked against [`EnsembleHeader::MAGIC`] by the loader.
    pub magic: u32,
    /// Format version; presumably checked against [`EnsembleHeader::VERSION`].
    pub version: u16,
    /// Number of trees.
    pub n_trees: u16,
    /// Number of features.
    pub n_features: u16,
    /// Reserved.
    pub _reserved: u16,
    /// Base prediction value (how it combines with tree outputs is not defined here).
    pub base_prediction: f32,
}

impl EnsembleHeader {
    /// Signature: the ASCII bytes `I`, `R`, `I`, `T` read as a little-endian `u32`.
    pub const MAGIC: u32 = u32::from_le_bytes([b'I', b'R', b'I', b'T']);

    /// Current format version.
    pub const VERSION: u16 = 1;
}
/// Per-tree index entry holding a node count and an offset.
///
/// `repr(C)` with 4-byte alignment; two `u32` fields give an 8-byte record.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(C, align(4))]
pub struct TreeEntry {
    /// Number of nodes in the tree.
    pub n_nodes: u32,
    /// Offset of the tree's data.
    /// NOTE(review): units (node index vs. byte offset) and base are not established
    /// here — confirm against the loader.
    pub offset: u32,
}
// Layout and encoding tests for the packed on-disk structures above.
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, size_of};

    // --- Size / alignment of the on-disk records ---

    #[test]
    fn packed_node_is_12_bytes() {
        assert_eq!(size_of::<PackedNode>(), 12);
    }

    #[test]
    fn packed_node_alignment_is_4() {
        assert_eq!(align_of::<PackedNode>(), 4);
    }

    #[test]
    fn ensemble_header_is_16_bytes() {
        assert_eq!(size_of::<EnsembleHeader>(), 16);
    }

    #[test]
    fn ensemble_header_alignment_is_4() {
        assert_eq!(align_of::<EnsembleHeader>(), 4);
    }

    #[test]
    fn tree_entry_is_8_bytes() {
        assert_eq!(size_of::<TreeEntry>(), 8);
    }

    #[test]
    fn tree_entry_alignment_is_4() {
        assert_eq!(align_of::<TreeEntry>(), 4);
    }

    // --- PackedNode encoding ---

    #[test]
    fn leaf_node_roundtrip() {
        let node = PackedNode::leaf(0.42);
        assert!(node.is_leaf());
        assert_eq!(node.value, 0.42);
        assert_eq!(node.children, 0);
    }

    #[test]
    fn leaf_node_has_no_feature_or_children() {
        let node = PackedNode::leaf(-3.0);
        assert_eq!(node.feature_idx(), 0);
        assert_eq!(node.left_child(), 0);
        assert_eq!(node.right_child(), 0);
        assert_eq!(node._reserved, 0);
    }

    #[test]
    fn split_node_roundtrip() {
        let node = PackedNode::split(1.5, 7, 1, 2);
        assert!(!node.is_leaf());
        assert_eq!(node.feature_idx(), 7);
        assert_eq!(node.value, 1.5);
        assert_eq!(node.left_child(), 1);
        assert_eq!(node.right_child(), 2);
        assert_eq!(node._reserved, 0);
    }

    #[test]
    fn max_feature_index() {
        let node = PackedNode::split(0.0, 0x7FFF, 0, 0);
        assert_eq!(node.feature_idx(), 0x7FFF);
        assert!(!node.is_leaf());
    }

    #[test]
    fn max_child_indices() {
        let node = PackedNode::split(0.0, 0, u16::MAX, u16::MAX);
        assert_eq!(node.left_child(), u16::MAX);
        assert_eq!(node.right_child(), u16::MAX);
    }

    // 12-byte nodes: five fit in a 64-byte cache line, six do not.
    #[test]
    fn five_nodes_per_cache_line() {
        assert!(5 * size_of::<PackedNode>() <= 64);
        assert!(6 * size_of::<PackedNode>() > 64);
    }

    // --- Header constants ---

    #[test]
    fn header_magic_is_irit() {
        let bytes = EnsembleHeader::MAGIC.to_le_bytes();
        assert_eq!(&bytes, b"IRIT");
    }
}