Skip to main content

irithyll_core/
view.rs

1//! Zero-copy, zero-alloc inference view over a packed ensemble binary.
2//!
3//! [`EnsembleView`] is constructed from a `&[u8]` buffer and validates the
4//! entire structure on creation. After validation, all predictions are pure
5//! pointer arithmetic with no allocation — suitable for embedded targets.
6
7use crate::error::FormatError;
8use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
9use crate::traverse;
10
11/// Zero-copy view over a packed ensemble binary.
12///
13/// All validation happens in [`from_bytes`](EnsembleView::from_bytes). After
14/// construction, predictions use `get_unchecked` for zero-overhead indexing
15/// (all bounds have been verified).
16///
17/// # Lifetime
18///
19/// The view borrows the input buffer — the buffer must outlive the view.
20#[derive(Clone, Copy)]
21pub struct EnsembleView<'a> {
22    header: &'a EnsembleHeader,
23    tree_table: &'a [TreeEntry],
24    nodes: &'a [PackedNode],
25}
26
27impl<'a> EnsembleView<'a> {
28    /// Parse and validate a packed ensemble binary.
29    ///
30    /// Validates:
31    /// - Magic bytes match `"IRIT"`
32    /// - Format version is supported
33    /// - Buffer is large enough for header + tree table + all nodes
34    /// - Every internal node's child indices are within bounds
35    /// - Every internal node's feature index is < `n_features`
36    ///
37    /// # Errors
38    ///
39    /// Returns [`FormatError`] if any validation check fails.
40    pub fn from_bytes(data: &'a [u8]) -> Result<Self, FormatError> {
41        use core::mem::{align_of, size_of};
42
43        let header_size = size_of::<EnsembleHeader>();
44        if data.len() < header_size {
45            return Err(FormatError::Truncated);
46        }
47
48        // Validate alignment — EnsembleHeader requires 4-byte alignment.
49        // If the buffer isn't aligned, we can't safely cast. On embedded targets
50        // this would be a hard fault.
51        if (data.as_ptr() as usize) % align_of::<EnsembleHeader>() != 0 {
52            return Err(FormatError::Unaligned);
53        }
54
55        // SAFETY: We've checked length and alignment.
56        let header = unsafe { &*(data.as_ptr() as *const EnsembleHeader) };
57
58        if header.magic != EnsembleHeader::MAGIC {
59            return Err(FormatError::BadMagic);
60        }
61        if header.version != EnsembleHeader::VERSION {
62            return Err(FormatError::UnsupportedVersion);
63        }
64
65        let n_trees = header.n_trees as usize;
66        let tree_table_size = n_trees * size_of::<TreeEntry>();
67        let tree_table_offset = header_size;
68
69        if data.len() < tree_table_offset + tree_table_size {
70            return Err(FormatError::Truncated);
71        }
72
73        // SAFETY: Alignment of TreeEntry is 4, same as EnsembleHeader, and
74        // header_size (16) is a multiple of 4, so tree_table_ptr is aligned.
75        let tree_table_ptr = unsafe { data.as_ptr().add(tree_table_offset) } as *const TreeEntry;
76        let tree_table = unsafe { core::slice::from_raw_parts(tree_table_ptr, n_trees) };
77
78        // Compute total nodes and validate tree table
79        let nodes_base_offset = tree_table_offset + tree_table_size;
80        let mut total_nodes: usize = 0;
81        for entry in tree_table {
82            total_nodes = total_nodes
83                .checked_add(entry.n_nodes as usize)
84                .ok_or(FormatError::Truncated)?;
85        }
86
87        let nodes_size = total_nodes
88            .checked_mul(size_of::<PackedNode>())
89            .ok_or(FormatError::Truncated)?;
90        let total_required = nodes_base_offset
91            .checked_add(nodes_size)
92            .ok_or(FormatError::Truncated)?;
93        if data.len() < total_required {
94            return Err(FormatError::Truncated);
95        }
96
97        // Validate tree offsets and alignment
98        for entry in tree_table {
99            let node_byte_offset = entry.offset as usize;
100            // Tree offset must be aligned to PackedNode size
101            if node_byte_offset % size_of::<PackedNode>() != 0 {
102                return Err(FormatError::MisalignedTreeOffset);
103            }
104            let tree_bytes = (entry.n_nodes as usize)
105                .checked_mul(size_of::<PackedNode>())
106                .ok_or(FormatError::Truncated)?;
107            let tree_end = node_byte_offset
108                .checked_add(tree_bytes)
109                .ok_or(FormatError::Truncated)?;
110            if tree_end > nodes_size {
111                return Err(FormatError::Truncated);
112            }
113        }
114
115        let nodes_ptr = unsafe { data.as_ptr().add(nodes_base_offset) } as *const PackedNode;
116        let nodes = unsafe { core::slice::from_raw_parts(nodes_ptr, total_nodes) };
117
118        // Validate every node's child indices and feature indices
119        let n_features = header.n_features as usize;
120        for (tree_idx, entry) in tree_table.iter().enumerate() {
121            let tree_node_offset = entry.offset as usize / size_of::<PackedNode>();
122            let tree_n_nodes = entry.n_nodes as usize;
123
124            for local_idx in 0..tree_n_nodes {
125                let global_idx = tree_node_offset + local_idx;
126                let node = &nodes[global_idx];
127
128                if !node.is_leaf() {
129                    let left = node.left_child() as usize;
130                    let right = node.right_child() as usize;
131
132                    // Children must be within this tree's node range
133                    if left >= tree_n_nodes || right >= tree_n_nodes {
134                        return Err(FormatError::InvalidNodeIndex);
135                    }
136
137                    if n_features > 0 && node.feature_idx() as usize >= n_features {
138                        return Err(FormatError::InvalidFeatureIndex);
139                    }
140                }
141            }
142
143            let _ = tree_idx; // suppress unused warning
144        }
145
146        Ok(Self {
147            header,
148            tree_table,
149            nodes,
150        })
151    }
152
153    /// Predict a single sample. Zero allocation.
154    ///
155    /// Returns `base_prediction + sum(tree_predictions)`.
156    ///
157    /// # Precondition
158    ///
159    /// `features.len()` **must** be `>= self.n_features()`. Passing fewer
160    /// features than the model expects causes **undefined behavior** (out-of-bounds
161    /// read via `get_unchecked` in the traversal hot path). A `debug_assert`
162    /// catches this in debug builds.
163    pub fn predict(&self, features: &[f32]) -> f32 {
164        debug_assert!(
165            features.len() >= self.header.n_features as usize,
166            "predict: features.len() ({}) < n_features ({})",
167            features.len(),
168            self.header.n_features
169        );
170        let mut sum = self.header.base_prediction;
171        for entry in self.tree_table {
172            let start = entry.offset as usize / core::mem::size_of::<PackedNode>();
173            let end = start + entry.n_nodes as usize;
174            let tree_nodes = &self.nodes[start..end];
175            sum += traverse::predict_tree(tree_nodes, features);
176        }
177        sum
178    }
179
180    /// Batch predict. Uses x4 interleaving when `samples.len() >= 4`.
181    ///
182    /// # Panics
183    ///
184    /// Panics if `out.len() < samples.len()`.
185    ///
186    /// # Precondition
187    ///
188    /// Every sample must have `len() >= self.n_features()`. See [`predict`](Self::predict).
189    pub fn predict_batch(&self, samples: &[&[f32]], out: &mut [f32]) {
190        assert!(out.len() >= samples.len());
191
192        let n = samples.len();
193        let mut i = 0;
194
195        // Process groups of 4 with interleaved traversal
196        while i + 4 <= n {
197            let batch = [samples[i], samples[i + 1], samples[i + 2], samples[i + 3]];
198            // Initialize with base prediction
199            let mut sums = [self.header.base_prediction; 4];
200            for entry in self.tree_table {
201                let start = entry.offset as usize / core::mem::size_of::<PackedNode>();
202                let end = start + entry.n_nodes as usize;
203                let tree_nodes = &self.nodes[start..end];
204                let preds = traverse::predict_tree_x4(tree_nodes, batch);
205                for j in 0..4 {
206                    sums[j] += preds[j];
207                }
208            }
209            out[i] = sums[0];
210            out[i + 1] = sums[1];
211            out[i + 2] = sums[2];
212            out[i + 3] = sums[3];
213            i += 4;
214        }
215
216        // Remainder
217        while i < n {
218            out[i] = self.predict(samples[i]);
219            i += 1;
220        }
221    }
222
223    /// Number of trees in the ensemble.
224    #[inline]
225    pub fn n_trees(&self) -> u16 {
226        self.header.n_trees
227    }
228
229    /// Expected number of input features.
230    #[inline]
231    pub fn n_features(&self) -> u16 {
232        self.header.n_features
233    }
234
235    /// Base prediction value.
236    #[inline]
237    pub fn base_prediction(&self) -> f32 {
238        self.header.base_prediction
239    }
240
241    /// Total number of packed nodes across all trees.
242    #[inline]
243    pub fn total_nodes(&self) -> usize {
244        self.nodes.len()
245    }
246}
247
248impl<'a> core::fmt::Debug for EnsembleView<'a> {
249    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
250        f.debug_struct("EnsembleView")
251            .field("n_trees", &self.n_trees())
252            .field("n_features", &self.n_features())
253            .field("base_prediction", &self.base_prediction())
254            .field("total_nodes", &self.total_nodes())
255            .finish()
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::{EnsembleHeader, PackedNode, TreeEntry};
    use alloc::{format, vec, vec::Vec};
    use core::mem::size_of;
    use core::ops::{Deref, DerefMut};

    /// Byte buffer whose start is always 4-byte aligned.
    ///
    /// `Vec<u8>` only guarantees 1-byte alignment. Test binaries built in a
    /// plain `Vec<u8>` could be rejected by `from_bytes` with `Unaligned`,
    /// depending on the allocator. Backing the bytes with `u32` storage makes
    /// the 4-byte alignment that `from_bytes` requires deterministic.
    struct AlignedBuf {
        words: Vec<u32>,
        len: usize,
    }

    impl AlignedBuf {
        /// Copy `bytes` into freshly allocated, 4-byte-aligned storage.
        fn new(bytes: &[u8]) -> Self {
            let mut buf = Self {
                words: vec![0u32; (bytes.len() + 3) / 4],
                len: bytes.len(),
            };
            buf.copy_from_slice(bytes);
            buf
        }
    }

    impl Deref for AlignedBuf {
        type Target = [u8];

        fn deref(&self) -> &[u8] {
            // SAFETY: `words` owns `4 * words.len() >= len` zero-initialized
            // bytes, and `u8` has alignment 1.
            unsafe { core::slice::from_raw_parts(self.words.as_ptr() as *const u8, self.len) }
        }
    }

    impl DerefMut for AlignedBuf {
        fn deref_mut(&mut self) -> &mut [u8] {
            // SAFETY: same as `deref`, and `&mut self` guarantees exclusivity.
            unsafe {
                core::slice::from_raw_parts_mut(self.words.as_mut_ptr() as *mut u8, self.len)
            }
        }
    }

    /// Cast a repr(C) struct to bytes.
    fn as_bytes<T: Sized>(val: &T) -> &[u8] {
        unsafe { core::slice::from_raw_parts(val as *const T as *const u8, size_of::<T>()) }
    }

    /// Header with the current magic/version and the given counts.
    fn make_header(n_trees: u16, n_features: u16, base_prediction: f32) -> EnsembleHeader {
        EnsembleHeader {
            magic: EnsembleHeader::MAGIC,
            version: EnsembleHeader::VERSION,
            n_trees,
            n_features,
            _reserved: 0,
            base_prediction,
        }
    }

    /// Serialize header, tree table and nodes back to back into an aligned buffer.
    fn build_binary(
        header: &EnsembleHeader,
        entries: &[TreeEntry],
        nodes: &[PackedNode],
    ) -> AlignedBuf {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(as_bytes(header));
        for entry in entries {
            bytes.extend_from_slice(as_bytes(entry));
        }
        for node in nodes {
            bytes.extend_from_slice(as_bytes(node));
        }
        AlignedBuf::new(&bytes)
    }

    /// Build a minimal valid packed binary with one tree, one leaf.
    fn build_single_leaf_binary(leaf_value: f32, base: f32) -> AlignedBuf {
        build_binary(
            &make_header(1, 1, base),
            &[TreeEntry {
                n_nodes: 1,
                offset: 0,
            }],
            &[PackedNode::leaf(leaf_value)],
        )
    }

    /// Build a binary with one tree: root splits on feat 0 at threshold 5.0.
    fn build_one_split_binary() -> AlignedBuf {
        build_binary(
            &make_header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                PackedNode::split(5.0, 0, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        )
    }

    /// Build a binary with two trees for testing multi-tree prediction.
    fn build_two_tree_binary() -> AlignedBuf {
        build_binary(
            &make_header(2, 2, 1.0),
            &[
                // Tree 0: 3 nodes at byte offset 0.
                TreeEntry {
                    n_nodes: 3,
                    offset: 0,
                },
                // Tree 1: a single leaf right after tree 0's three nodes.
                TreeEntry {
                    n_nodes: 1,
                    offset: 3 * size_of::<PackedNode>() as u32,
                },
            ],
            &[
                // Tree 0
                PackedNode::split(5.0, 0, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
                // Tree 1
                PackedNode::leaf(0.5),
            ],
        )
    }

    #[test]
    fn parse_single_leaf() {
        let buf = build_single_leaf_binary(42.0, 0.0);
        let view = EnsembleView::from_bytes(&buf).unwrap();
        assert_eq!(view.n_trees(), 1);
        assert_eq!(view.n_features(), 1);
        assert_eq!(view.total_nodes(), 1);
    }

    #[test]
    fn predict_single_leaf() {
        let buf = build_single_leaf_binary(42.0, 10.0);
        let view = EnsembleView::from_bytes(&buf).unwrap();
        // prediction = base(10.0) + leaf(42.0) = 52.0
        let pred = view.predict(&[0.0]);
        assert!((pred - 52.0).abs() < 1e-6);
    }

    #[test]
    fn predict_one_split_left() {
        let buf = build_one_split_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=3.0, not > 5.0 -> left -> leaf=-1.0; base=0.0
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn predict_one_split_right() {
        let buf = build_one_split_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=7.0, > 5.0 -> right -> leaf=1.0; base=0.0
        let pred = view.predict(&[7.0, 0.0]);
        assert!((pred - 1.0).abs() < 1e-6);
    }

    #[test]
    fn predict_two_trees() {
        let buf = build_two_tree_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=3.0 -> tree0: left=-1.0; tree1: leaf=0.5; base=1.0
        // total = 1.0 + (-1.0) + 0.5 = 0.5
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - 0.5).abs() < 1e-6);
    }

    #[test]
    fn empty_ensemble_predicts_base() {
        let buf = build_binary(&make_header(0, 1, 2.5), &[], &[]);
        let view = EnsembleView::from_bytes(&buf).unwrap();
        assert_eq!(view.n_trees(), 0);
        assert_eq!(view.total_nodes(), 0);
        assert!((view.predict(&[0.0]) - 2.5).abs() < 1e-6);
    }

    #[test]
    fn predict_batch_matches_single() {
        let buf = build_two_tree_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();

        let samples: Vec<&[f32]> = vec![
            &[3.0, 0.0],
            &[7.0, 0.0],
            &[5.0, 0.0],
            &[0.0, 0.0],
            &[10.0, 0.0],
        ];
        let mut out = vec![0.0f32; 5];
        view.predict_batch(&samples, &mut out);

        for (i, &s) in samples.iter().enumerate() {
            let expected = view.predict(s);
            assert!(
                (out[i] - expected).abs() < 1e-6,
                "batch[{}] = {}, expected {}",
                i,
                out[i],
                expected
            );
        }
    }

    #[test]
    fn predict_batch_exact_groups_leave_extra_output_untouched() {
        let buf = build_two_tree_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();

        let rows: [[f32; 2]; 8] = [
            [3.0, 0.0],
            [7.0, 0.0],
            [5.0, 0.0],
            [0.0, 0.0],
            [10.0, 0.0],
            [5.5, 0.0],
            [-1.0, 0.0],
            [6.0, 0.0],
        ];
        let samples: Vec<&[f32]> = rows.iter().map(|r| &r[..]).collect();
        let mut out = vec![f32::NAN; 10];
        view.predict_batch(&samples, &mut out);

        for (i, &s) in samples.iter().enumerate() {
            assert!((out[i] - view.predict(s)).abs() < 1e-6, "batch[{}]", i);
        }
        // Output slots beyond `samples.len()` must not be written.
        assert!(out[8].is_nan() && out[9].is_nan());
    }

    #[test]
    fn predict_batch_empty_is_noop() {
        let buf = build_two_tree_binary();
        let view = EnsembleView::from_bytes(&buf).unwrap();
        let samples: Vec<&[f32]> = Vec::new();
        let mut out: Vec<f32> = Vec::new();
        view.predict_batch(&samples, &mut out);
        assert!(out.is_empty());
    }

    #[test]
    fn bad_magic_is_rejected() {
        let mut buf = build_single_leaf_binary(0.0, 0.0);
        buf[0] = 0xFF; // corrupt magic
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::BadMagic
        );
    }

    #[test]
    fn truncated_buffer_is_rejected() {
        let buf = build_single_leaf_binary(0.0, 0.0);
        assert_eq!(
            EnsembleView::from_bytes(&buf[..4]).unwrap_err(),
            FormatError::Truncated
        );
    }

    #[test]
    fn truncated_node_region_is_rejected() {
        let buf = build_single_leaf_binary(0.0, 0.0);
        assert_eq!(
            EnsembleView::from_bytes(&buf[..buf.len() - 1]).unwrap_err(),
            FormatError::Truncated
        );
    }

    #[test]
    fn unaligned_buffer_is_rejected() {
        let buf = build_single_leaf_binary(0.0, 0.0);
        // `buf` starts 4-byte aligned, so skipping one byte misaligns it.
        assert_eq!(
            EnsembleView::from_bytes(&buf[1..]).unwrap_err(),
            FormatError::Unaligned
        );
    }

    #[test]
    fn bad_version_is_rejected() {
        let mut buf = build_single_leaf_binary(0.0, 0.0);
        // version is at offset 4 (after magic u32), 2 bytes LE
        buf[4] = 99;
        buf[5] = 0;
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::UnsupportedVersion
        );
    }

    #[test]
    fn misaligned_tree_offset_is_rejected() {
        let buf = build_binary(
            &make_header(1, 1, 0.0),
            &[TreeEntry {
                n_nodes: 1,
                offset: 4, // not a multiple of size_of::<PackedNode>()
            }],
            &[PackedNode::leaf(0.0), PackedNode::leaf(0.0)],
        );
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::MisalignedTreeOffset
        );
    }

    #[test]
    fn tree_past_node_region_is_rejected() {
        // Only one node is declared in total, but the tree starts one node in.
        let buf = build_binary(
            &make_header(1, 1, 0.0),
            &[TreeEntry {
                n_nodes: 1,
                offset: size_of::<PackedNode>() as u32,
            }],
            &[PackedNode::leaf(0.0), PackedNode::leaf(0.0)],
        );
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::Truncated
        );
    }

    #[test]
    fn invalid_child_index_is_rejected() {
        // Node 0 points to child 99, which is out of bounds (only 3 nodes)
        let buf = build_binary(
            &make_header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                PackedNode::split(5.0, 0, 1, 99), // right child out of bounds
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        );
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::InvalidNodeIndex
        );
    }

    #[test]
    fn invalid_feature_index_is_rejected() {
        // Split on feature 7 while the header declares only 2 features.
        let buf = build_binary(
            &make_header(1, 2, 0.0),
            &[TreeEntry {
                n_nodes: 3,
                offset: 0,
            }],
            &[
                PackedNode::split(5.0, 7, 1, 2),
                PackedNode::leaf(-1.0),
                PackedNode::leaf(1.0),
            ],
        );
        assert_eq!(
            EnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::InvalidFeatureIndex
        );
    }

    #[test]
    fn debug_format() {
        let buf = build_single_leaf_binary(0.0, 0.0);
        let view = EnsembleView::from_bytes(&buf).unwrap();
        let debug = format!("{:?}", view);
        assert!(debug.contains("EnsembleView"));
        assert!(debug.contains("n_trees"));
    }
}