irithyll_core/packed.rs
1//! 12-byte packed node format and ensemble binary layout.
2//!
3//! # Binary Format
4//!
5//! ```text
6//! [EnsembleHeader: 16 bytes]
7//! [TreeEntry × n_trees: 8 bytes each]
8//! [PackedNode × total_nodes: 12 bytes each]
9//! ```
10//!
11//! Learning rate is baked into leaf values at export time, eliminating one
12//! multiply per tree during inference.
13
/// A single decision-tree node squeezed into 12 bytes (array-of-structs layout).
///
/// Five nodes fit inside one 64-byte cache line (5 × 12 = 60 bytes). Everything
/// a traversal step reads sits side by side in the same node, so inference does
/// not stride across separate vectors the way the SoA `TreeArena` used during
/// training does.
///
/// # Field layout
///
/// - `value`: the split threshold on internal nodes, the prediction on leaves.
///   Leaf values already include the learning rate:
///   `value = lr * leaf_f64 as f32`.
/// - `children`: two child indices in one word — left in the low u16, right in
///   the high u16. Leaves do not use it (it is set to 0).
/// - `feature_flags`: bit 15 is the is_leaf flag, bits 14:0 are the feature
///   index (at most 32767). Leaves do not use the feature index.
/// - `_reserved`: spare bytes kept for later use (categorical flag, metadata, etc.).
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C, align(4))]
pub struct PackedNode {
    /// Internal node: split threshold. Leaf: prediction with the learning rate baked in.
    pub value: f32,
    /// Both child indices in one word: left in bits 15:0, right in bits 31:16.
    pub children: u32,
    /// Bit 15 flags a leaf; bits 14:0 carry the feature index.
    pub feature_flags: u16,
    /// Spare field, reserved for future use.
    pub _reserved: u16,
}
41
42impl PackedNode {
43 /// Bit mask for the is_leaf flag in `feature_flags`.
44 pub const LEAF_FLAG: u16 = 0x8000;
45
46 /// Create a leaf node with a prediction value.
47 #[inline]
48 pub const fn leaf(value: f32) -> Self {
49 Self {
50 value,
51 children: 0,
52 feature_flags: Self::LEAF_FLAG,
53 _reserved: 0,
54 }
55 }
56
57 /// Create an internal (split) node.
58 #[inline]
59 pub const fn split(threshold: f32, feature_idx: u16, left: u16, right: u16) -> Self {
60 Self {
61 value: threshold,
62 children: (left as u32) | ((right as u32) << 16),
63 feature_flags: feature_idx & 0x7FFF,
64 _reserved: 0,
65 }
66 }
67
68 /// Returns `true` if this is a leaf node.
69 #[inline]
70 pub const fn is_leaf(&self) -> bool {
71 self.feature_flags & Self::LEAF_FLAG != 0
72 }
73
74 /// Feature index (bits 14:0). Only meaningful for internal nodes.
75 #[inline]
76 pub const fn feature_idx(&self) -> u16 {
77 self.feature_flags & 0x7FFF
78 }
79
80 /// Left child index (low 16 bits of `children`).
81 #[inline]
82 pub const fn left_child(&self) -> u16 {
83 self.children as u16
84 }
85
86 /// Right child index (high 16 bits of `children`).
87 #[inline]
88 pub const fn right_child(&self) -> u16 {
89 (self.children >> 16) as u16
90 }
91}
92
/// Fixed 16-byte preamble of the packed ensemble binary format (4-byte aligned).
///
/// Every packed binary begins with this header. After it come `n_trees`
/// [`TreeEntry`] records, and after those the [`PackedNode`] arrays, stored
/// contiguously.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C, align(4))]
pub struct EnsembleHeader {
    /// Format tag: ASCII `"IRIT"` read as a little-endian u32 (`0x54495249`).
    pub magic: u32,
    /// Version of the binary format. `1` at present.
    pub version: u16,
    /// How many trees the ensemble holds.
    pub n_trees: u16,
    /// How many input features the ensemble expects.
    pub n_features: u16,
    /// Padding, reserved.
    pub _reserved: u16,
    /// Base prediction (an f64 quantized to f32), added to the summed tree predictions.
    pub base_prediction: f32,
}
113
114impl EnsembleHeader {
115 /// Magic value: "IRIT" in little-endian ASCII.
116 pub const MAGIC: u32 = u32::from_le_bytes(*b"IRIT");
117 /// Current format version.
118 pub const VERSION: u16 = 1;
119}
120
/// One 8-byte row of the tree table, holding the metadata of a single tree.
///
/// `n_nodes` counts the [`PackedNode`]s that make up the tree. `offset` is the
/// distance in bytes from the beginning of the node data region to the tree's
/// first node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C, align(4))]
pub struct TreeEntry {
    /// How many nodes the tree contains.
    pub n_nodes: u32,
    /// Offset in bytes from nodes_base to the tree's first PackedNode.
    pub offset: u32,
}
134
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, size_of};

    /// Cache-line width, in bytes, assumed by the packing test below.
    const CACHE_LINE_BYTES: usize = 64;

    // --- Layout: sizes and alignment of the on-disk records ---

    #[test]
    fn packed_node_is_12_bytes() {
        let node_bytes = size_of::<PackedNode>();
        assert_eq!(node_bytes, 12);
    }

    #[test]
    fn packed_node_alignment_is_4() {
        let node_align = align_of::<PackedNode>();
        assert_eq!(node_align, 4);
    }

    #[test]
    fn ensemble_header_is_16_bytes() {
        let header_bytes = size_of::<EnsembleHeader>();
        assert_eq!(header_bytes, 16);
    }

    #[test]
    fn tree_entry_is_8_bytes() {
        let entry_bytes = size_of::<TreeEntry>();
        assert_eq!(entry_bytes, 8);
    }

    // --- Constructors and accessors ---

    #[test]
    fn leaf_node_roundtrip() {
        let leaf = PackedNode::leaf(0.42);
        assert!(leaf.is_leaf());
        let PackedNode {
            value, children, ..
        } = leaf;
        assert_eq!(value, 0.42);
        assert_eq!(children, 0);
    }

    #[test]
    fn split_node_roundtrip() {
        let internal = PackedNode::split(1.5, 7, 1, 2);
        assert!(!internal.is_leaf());
        assert_eq!(internal.feature_idx(), 7);
        assert_eq!(internal.value, 1.5);
        let kids = (internal.left_child(), internal.right_child());
        assert_eq!(kids.0, 1);
        assert_eq!(kids.1, 2);
    }

    #[test]
    fn max_feature_index() {
        // The feature index has 15 bits, so 32767 is the largest value.
        let widest = PackedNode::split(0.0, 0x7FFF, 0, 0);
        assert_eq!(widest.feature_idx(), 0x7FFF);
        assert!(!widest.is_leaf());
    }

    #[test]
    fn max_child_indices() {
        let far = PackedNode::split(0.0, 0, u16::MAX, u16::MAX);
        assert_eq!(far.left_child(), u16::MAX);
        assert_eq!(far.right_child(), u16::MAX);
    }

    #[test]
    fn five_nodes_per_cache_line() {
        let node_bytes = size_of::<PackedNode>();
        // Five nodes take 60 bytes and fit in one 64-byte line.
        assert!(5 * node_bytes <= CACHE_LINE_BYTES);
        // A sixth node would bring the total to 72 bytes, which does not fit.
        assert!(6 * node_bytes > CACHE_LINE_BYTES);
    }

    #[test]
    fn header_magic_is_irit() {
        // Little-endian byte order: I=0x49, R=0x52, I=0x49, T=0x54.
        let magic_bytes = EnsembleHeader::MAGIC.to_le_bytes();
        assert_eq!(magic_bytes, *b"IRIT");
    }
}