Skip to main content

irithyll_core/
packed.rs

//! 12-byte packed node format and ensemble binary layout.
//!
//! # Binary Format
//!
//! ```text
//! [EnsembleHeader: 16 bytes]
//! [TreeEntry × n_trees: 8 bytes each]
//! [PackedNode × total_nodes: 12 bytes each]
//! ```
//!
//! `TreeEntry::offset` is a byte offset measured from the start of the
//! `PackedNode` region, not from the start of the binary.
//!
//! The learning rate is baked into leaf values at export time, which removes
//! one multiply per tree during inference.
13
/// 12-byte packed decision tree node. AoS layout for cache-friendly inference.
///
/// At 12 bytes per node, one 64-byte cache line can hold at most five complete
/// nodes (60 bytes). Nodes are stored back to back with no per-line padding, so
/// whether a given node straddles a cache-line boundary depends on the buffer's
/// base alignment and the node's index.
///
/// All fields for one traversal step are adjacent in memory — no cross-vector
/// striding like the SoA `TreeArena` used during training.
///
/// # Field layout
///
/// - `value`: split threshold (internal nodes) or leaf prediction (leaf nodes).
///   For leaves, the learning rate is already baked in: the stored `f32` is the
///   leaf value multiplied by the learning rate at export time.
/// - `children`: packed left (low u16) and right (high u16) child indices.
///   For leaves, this field is unused (set to 0).
/// - `feature_flags`: bit 15 = is_leaf flag. Bits 14:0 = feature index
///   (max [`PackedNode::FEATURE_MASK`] = 32767). For leaves, the feature index is unused.
/// - `_reserved`: padding for future use (categorical flag, metadata, etc.).
///   Both constructors set it to 0.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PackedNode {
    /// Split threshold (internal) or prediction value (leaf, with lr baked in).
    pub value: f32,
    /// Packed children: left = low u16, right = high u16.
    pub children: u32,
    /// Bit 15 = is_leaf. Bits 14:0 = feature index.
    pub feature_flags: u16,
    /// Reserved for future use.
    pub _reserved: u16,
}

// The packed binary format depends on this exact layout. Fail the build (not
// just the test suite) if a field change ever alters it.
const _: () = assert!(core::mem::size_of::<PackedNode>() == 12);
const _: () = assert!(core::mem::align_of::<PackedNode>() == 4);

impl PackedNode {
    /// Bit mask for the is_leaf flag in `feature_flags`.
    pub const LEAF_FLAG: u16 = 0x8000;

    /// Bit mask for the feature index stored in bits 14:0 of `feature_flags`.
    ///
    /// This is also the largest feature index a node can encode (32767).
    pub const FEATURE_MASK: u16 = 0x7FFF;

    /// Create a leaf node with a prediction value.
    ///
    /// The leaf flag is set, the feature index and child indices are zero.
    #[inline]
    pub const fn leaf(value: f32) -> Self {
        Self {
            value,
            children: 0,
            feature_flags: Self::LEAF_FLAG,
            _reserved: 0,
        }
    }

    /// Create an internal (split) node.
    ///
    /// `feature_idx` must fit in 15 bits (`<= Self::FEATURE_MASK`). Bit 15 of
    /// `feature_flags` is the leaf flag, so it is always cleared here. In release
    /// builds an out-of-range index therefore silently turns into a *different*
    /// feature index. Debug builds catch this instead.
    ///
    /// # Panics
    ///
    /// In debug builds, if `feature_idx > Self::FEATURE_MASK`.
    #[inline]
    pub const fn split(threshold: f32, feature_idx: u16, left: u16, right: u16) -> Self {
        // Message-less form keeps this usable in const contexts on every edition.
        debug_assert!(feature_idx <= Self::FEATURE_MASK);
        Self {
            value: threshold,
            children: (left as u32) | ((right as u32) << 16),
            feature_flags: feature_idx & Self::FEATURE_MASK,
            _reserved: 0,
        }
    }

    /// Returns `true` if this is a leaf node.
    #[inline]
    pub const fn is_leaf(&self) -> bool {
        self.feature_flags & Self::LEAF_FLAG != 0
    }

    /// Feature index (bits 14:0). Only meaningful for internal nodes.
    #[inline]
    pub const fn feature_idx(&self) -> u16 {
        self.feature_flags & Self::FEATURE_MASK
    }

    /// Left child index (low 16 bits of `children`).
    #[inline]
    pub const fn left_child(&self) -> u16 {
        self.children as u16
    }

    /// Right child index (high 16 bits of `children`).
    #[inline]
    pub const fn right_child(&self) -> u16 {
        (self.children >> 16) as u16
    }
}
92
/// Header for the packed ensemble binary format. 16 bytes, 4-byte aligned.
///
/// Appears at the start of every packed binary. Followed by `n_trees` [`TreeEntry`]
/// records and then contiguous [`PackedNode`] arrays.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EnsembleHeader {
    /// Magic bytes: `"IRIT"` in ASCII (little-endian u32: `0x54495249`).
    pub magic: u32,
    /// Binary format version. Currently `1`.
    pub version: u16,
    /// Number of trees in the ensemble.
    pub n_trees: u16,
    /// Expected number of input features.
    pub n_features: u16,
    /// Reserved padding.
    pub _reserved: u16,
    /// Base prediction, narrowed from `f64` to `f32`. Added to the sum of tree
    /// predictions.
    pub base_prediction: f32,
}

// The header is part of the on-disk format. Enforce its size and alignment at
// compile time rather than relying on the unit tests alone.
const _: () = assert!(core::mem::size_of::<EnsembleHeader>() == 16);
const _: () = assert!(core::mem::align_of::<EnsembleHeader>() == 4);

impl EnsembleHeader {
    /// Magic value: "IRIT" in little-endian ASCII.
    pub const MAGIC: u32 = u32::from_le_bytes(*b"IRIT");
    /// Current format version.
    pub const VERSION: u16 = 1;
}
120
/// Tree table entry: metadata for one tree in the ensemble. 8 bytes.
///
/// The `n_nodes` field gives the number of [`PackedNode`]s in this tree.
/// The `offset` field is the byte offset from the start of the node data
/// region to this tree's first node. It is measured in bytes, not in nodes.
/// NOTE(review): with contiguous 12-byte nodes, this presumably is always a
/// multiple of `size_of::<PackedNode>()`. Confirm against the exporter.
#[repr(C, align(4))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TreeEntry {
    /// Number of nodes in this tree.
    pub n_nodes: u32,
    /// Byte offset from nodes_base to this tree's first PackedNode.
    pub offset: u32,
}

// Tree table records are read straight out of the packed binary. Pin their
// layout at compile time.
const _: () = assert!(core::mem::size_of::<TreeEntry>() == 8);
const _: () = assert!(core::mem::align_of::<TreeEntry>() == 4);
134
#[cfg(test)]
mod tests {
    use super::{EnsembleHeader, PackedNode, TreeEntry};
    use core::mem;

    #[test]
    fn packed_node_is_12_bytes() {
        assert_eq!(12, mem::size_of::<PackedNode>());
    }

    #[test]
    fn packed_node_alignment_is_4() {
        assert_eq!(4, mem::align_of::<PackedNode>());
    }

    #[test]
    fn ensemble_header_is_16_bytes() {
        assert_eq!(16, mem::size_of::<EnsembleHeader>());
    }

    #[test]
    fn tree_entry_is_8_bytes() {
        assert_eq!(8, mem::size_of::<TreeEntry>());
    }

    #[test]
    fn leaf_node_roundtrip() {
        let leaf = PackedNode::leaf(0.42);
        assert!(leaf.is_leaf());
        assert_eq!(0.42, leaf.value);
        assert_eq!(0, leaf.children);
    }

    #[test]
    fn split_node_roundtrip() {
        let split = PackedNode::split(1.5, 7, 1, 2);
        assert!(!split.is_leaf());
        assert_eq!(1.5, split.value);
        assert_eq!(
            (7, 1, 2),
            (split.feature_idx(), split.left_child(), split.right_child())
        );
    }

    #[test]
    fn max_feature_index() {
        // Bits 14:0 hold the index, so the widest encodable value is 2^15 - 1.
        let widest = PackedNode::split(0.0, 0x7FFF, 0, 0);
        assert_eq!(0x7FFF, widest.feature_idx());
        assert!(!widest.is_leaf());
    }

    #[test]
    fn max_child_indices() {
        let node = PackedNode::split(0.0, 0, u16::MAX, u16::MAX);
        assert_eq!([u16::MAX; 2], [node.left_child(), node.right_child()]);
    }

    #[test]
    fn five_nodes_per_cache_line() {
        const CACHE_LINE: usize = 64;
        let node_bytes = mem::size_of::<PackedNode>();
        // Five 12-byte nodes (60 bytes) fit in one line; a sixth (72) would not.
        assert!(5 * node_bytes <= CACHE_LINE);
        assert!(6 * node_bytes > CACHE_LINE);
    }

    #[test]
    fn header_magic_is_irit() {
        // The magic is the ASCII bytes 'I' 'R' 'I' 'T' read as a little-endian u32.
        assert_eq!(b"IRIT", &EnsembleHeader::MAGIC.to_le_bytes());
    }
}
207}