Skip to main content

irithyll_core/
view_i16.rs

1//! Zero-copy, zero-alloc inference view over a quantized (int16) ensemble binary.
2//!
3//! [`QuantizedEnsembleView`] is the int16 equivalent of [`EnsembleView`](crate::EnsembleView).
4//! After validation in [`from_bytes`](QuantizedEnsembleView::from_bytes), all traversal
5//! uses integer-only comparisons. Features are quantized once per predict call (one f32
6//! multiply per comparison), then each tree traversal is pure i16. Leaf values accumulate
7//! as i32, dequantized once at the end.
8//!
9//! On Cortex-M0+ (no FPU), the `predict_prequantized` path eliminates **all** float ops
10//! from the hot loop — the only float is the final `leaf_sum / leaf_scale`.
11
12use crate::error::FormatError;
13use crate::packed::TreeEntry;
14use crate::packed_i16::{PackedNodeI16, QuantizedEnsembleHeader};
15use crate::traverse_i16;
16
/// Zero-copy view over a quantized (int16) ensemble binary.
///
/// After validation in [`from_bytes`](Self::from_bytes), all traversal uses
/// integer-only comparisons. Features are quantized once per predict call,
/// then each tree traversal is pure i16. Leaf values accumulate as i32,
/// dequantized once at the end.
///
/// # Lifetime
///
/// The view borrows the input buffer — the buffer must outlive the view.
#[derive(Clone, Copy)]
pub struct QuantizedEnsembleView<'a> {
    // Fixed-size header at the start of the buffer (magic, version, counts, base prediction).
    header: &'a QuantizedEnsembleHeader,
    // Global scale used to dequantize the accumulated i32 leaf sum back to f32.
    leaf_scale: f32,
    // Per-feature quantization scales: one f32 per declared feature.
    feature_scales: &'a [f32],
    // One entry per tree: node count and byte offset into the node region.
    tree_table: &'a [TreeEntry],
    // All trees' nodes concatenated; each tree indexes into this via its offset.
    nodes: &'a [PackedNodeI16],
}
35
36impl<'a> QuantizedEnsembleView<'a> {
37    /// Parse and validate a quantized ensemble binary.
38    ///
39    /// # Binary layout
40    ///
41    /// ```text
42    /// [QuantizedEnsembleHeader: 16 bytes]   magic="IR16", version, n_trees, n_features, base_prediction
43    /// [leaf_scale: f32]                      4 bytes — global scale for leaf dequantization
44    /// [feature_scales: f32 x n_features]     n_features x 4 bytes — per-feature quantization scales
45    /// [TreeEntry x n_trees: 8 bytes each]    same struct as f32 format
46    /// [PackedNodeI16 x total_nodes: 8 bytes each]
47    /// ```
48    ///
49    /// # Validates
50    ///
51    /// - Magic bytes match `"IR16"`
52    /// - Format version is supported
53    /// - Buffer is large enough for header + leaf_scale + feature_scales + tree table + all nodes
54    /// - Every internal node's child indices are within bounds
55    /// - Every internal node's feature index is < `n_features`
56    /// - Tree offsets are aligned to `size_of::<PackedNodeI16>()`
57    ///
58    /// # Errors
59    ///
60    /// Returns [`FormatError`] if any validation check fails.
61    pub fn from_bytes(data: &'a [u8]) -> Result<Self, FormatError> {
62        use core::mem::{align_of, size_of};
63
64        let header_size = size_of::<QuantizedEnsembleHeader>();
65        if data.len() < header_size {
66            return Err(FormatError::Truncated);
67        }
68
69        // Validate alignment — QuantizedEnsembleHeader requires 4-byte alignment.
70        if (data.as_ptr() as usize) % align_of::<QuantizedEnsembleHeader>() != 0 {
71            return Err(FormatError::Unaligned);
72        }
73
74        // SAFETY: We've checked length and alignment.
75        let header = unsafe { &*(data.as_ptr() as *const QuantizedEnsembleHeader) };
76
77        if header.magic != QuantizedEnsembleHeader::MAGIC {
78            return Err(FormatError::BadMagic);
79        }
80        if header.version != QuantizedEnsembleHeader::VERSION {
81            return Err(FormatError::UnsupportedVersion);
82        }
83
84        let n_trees = header.n_trees as usize;
85        let n_features = header.n_features as usize;
86
87        // After header: leaf_scale (4 bytes) + feature_scales (n_features * 4 bytes)
88        let leaf_scale_offset = header_size;
89        let leaf_scale_size = size_of::<f32>();
90        let feature_scales_offset = leaf_scale_offset + leaf_scale_size;
91        let feature_scales_size = n_features
92            .checked_mul(size_of::<f32>())
93            .ok_or(FormatError::Truncated)?;
94
95        let tree_table_offset = feature_scales_offset
96            .checked_add(feature_scales_size)
97            .ok_or(FormatError::Truncated)?;
98        let tree_table_size = n_trees
99            .checked_mul(size_of::<TreeEntry>())
100            .ok_or(FormatError::Truncated)?;
101
102        let nodes_base_offset = tree_table_offset
103            .checked_add(tree_table_size)
104            .ok_or(FormatError::Truncated)?;
105
106        // Check we have at least up to the tree table
107        if data.len() < nodes_base_offset {
108            return Err(FormatError::Truncated);
109        }
110
111        // Read leaf_scale
112        // SAFETY: header is 4-byte aligned, header_size is 16 (multiple of 4),
113        // so leaf_scale_offset is 4-byte aligned. f32 requires 4-byte alignment.
114        let leaf_scale_ptr = unsafe { data.as_ptr().add(leaf_scale_offset) } as *const f32;
115        let leaf_scale = unsafe { *leaf_scale_ptr };
116
117        // Read feature_scales
118        // SAFETY: leaf_scale_offset + 4 = feature_scales_offset, still 4-byte aligned.
119        let feature_scales_ptr = unsafe { data.as_ptr().add(feature_scales_offset) } as *const f32;
120        let feature_scales = unsafe { core::slice::from_raw_parts(feature_scales_ptr, n_features) };
121
122        // Read tree table
123        // SAFETY: feature_scales_offset + n_features*4 = tree_table_offset.
124        // n_features*4 is a multiple of 4, so tree_table_offset is 4-byte aligned.
125        // TreeEntry has align(4).
126        let tree_table_ptr = unsafe { data.as_ptr().add(tree_table_offset) } as *const TreeEntry;
127        let tree_table = unsafe { core::slice::from_raw_parts(tree_table_ptr, n_trees) };
128
129        // Compute total nodes and validate
130        let mut total_nodes: usize = 0;
131        for entry in tree_table {
132            total_nodes = total_nodes
133                .checked_add(entry.n_nodes as usize)
134                .ok_or(FormatError::Truncated)?;
135        }
136
137        let nodes_size = total_nodes
138            .checked_mul(size_of::<PackedNodeI16>())
139            .ok_or(FormatError::Truncated)?;
140        let total_required = nodes_base_offset
141            .checked_add(nodes_size)
142            .ok_or(FormatError::Truncated)?;
143        if data.len() < total_required {
144            return Err(FormatError::Truncated);
145        }
146
147        // Validate tree offsets and alignment (must align to 8-byte PackedNodeI16 size)
148        for entry in tree_table {
149            let node_byte_offset = entry.offset as usize;
150            if node_byte_offset % size_of::<PackedNodeI16>() != 0 {
151                return Err(FormatError::MisalignedTreeOffset);
152            }
153            let tree_bytes = (entry.n_nodes as usize)
154                .checked_mul(size_of::<PackedNodeI16>())
155                .ok_or(FormatError::Truncated)?;
156            let tree_end = node_byte_offset
157                .checked_add(tree_bytes)
158                .ok_or(FormatError::Truncated)?;
159            if tree_end > nodes_size {
160                return Err(FormatError::Truncated);
161            }
162        }
163
164        // SAFETY: nodes_base_offset is 4-byte aligned (sum of 4-byte-aligned components).
165        // PackedNodeI16 has align(4), so this is safe.
166        let nodes_ptr = unsafe { data.as_ptr().add(nodes_base_offset) } as *const PackedNodeI16;
167        let nodes = unsafe { core::slice::from_raw_parts(nodes_ptr, total_nodes) };
168
169        // Validate every node's child indices and feature indices
170        for entry in tree_table {
171            let tree_node_offset = entry.offset as usize / size_of::<PackedNodeI16>();
172            let tree_n_nodes = entry.n_nodes as usize;
173
174            for local_idx in 0..tree_n_nodes {
175                let global_idx = tree_node_offset + local_idx;
176                let node = &nodes[global_idx];
177
178                if !node.is_leaf() {
179                    let left = node.left_child() as usize;
180                    let right = node.right_child() as usize;
181
182                    // Children must be within this tree's node range
183                    if left >= tree_n_nodes || right >= tree_n_nodes {
184                        return Err(FormatError::InvalidNodeIndex);
185                    }
186
187                    if n_features > 0 && node.feature_idx() as usize >= n_features {
188                        return Err(FormatError::InvalidFeatureIndex);
189                    }
190                }
191            }
192        }
193
194        Ok(Self {
195            header,
196            leaf_scale,
197            feature_scales,
198            tree_table,
199            nodes,
200        })
201    }
202
203    /// Predict a single sample with inline feature quantization. Zero allocation.
204    ///
205    /// Each tree comparison performs one f32 multiply to quantize the feature on-the-fly.
206    /// Leaf values accumulate as i32 and are dequantized once at the end:
207    /// `base_prediction + leaf_sum / leaf_scale`.
208    ///
209    /// For the pure-integer path (no f32 ops in the hot loop), use
210    /// [`predict_prequantized`](Self::predict_prequantized).
211    ///
212    /// # Precondition
213    ///
214    /// `features.len()` **must** be `>= self.n_features()`. Passing fewer features
215    /// causes **undefined behavior** via `get_unchecked`. A `debug_assert` catches
216    /// this in debug builds.
217    pub fn predict(&self, features: &[f32]) -> f32 {
218        debug_assert!(
219            features.len() >= self.header.n_features as usize,
220            "predict: features.len() ({}) < n_features ({})",
221            features.len(),
222            self.header.n_features
223        );
224
225        let mut leaf_sum: i32 = 0;
226        for entry in self.tree_table {
227            let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
228            let end = start + entry.n_nodes as usize;
229            let tree_nodes = &self.nodes[start..end];
230            leaf_sum +=
231                traverse_i16::predict_tree_i16_inline(tree_nodes, features, self.feature_scales)
232                    as i32;
233        }
234
235        self.header.base_prediction + (leaf_sum as f32) / self.leaf_scale
236    }
237
238    /// Predict a single sample from pre-quantized features. Pure integer hot loop.
239    ///
240    /// The caller is responsible for quantizing features beforehand:
241    /// `features_i16[i] = (features_f32[i] * feature_scales[i]) as i16`
242    ///
243    /// This eliminates **all** float ops from the tree traversal — the only float
244    /// operation is the final `base_prediction + leaf_sum / leaf_scale`.
245    ///
246    /// # Precondition
247    ///
248    /// `features_i16.len()` **must** be `>= self.n_features()`. Passing fewer features
249    /// causes **undefined behavior** via `get_unchecked`. A `debug_assert` catches
250    /// this in debug builds.
251    pub fn predict_prequantized(&self, features_i16: &[i16]) -> f32 {
252        debug_assert!(
253            features_i16.len() >= self.header.n_features as usize,
254            "predict_prequantized: features_i16.len() ({}) < n_features ({})",
255            features_i16.len(),
256            self.header.n_features
257        );
258
259        let mut leaf_sum: i32 = 0;
260        for entry in self.tree_table {
261            let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
262            let end = start + entry.n_nodes as usize;
263            let tree_nodes = &self.nodes[start..end];
264            leaf_sum += traverse_i16::predict_tree_i16(tree_nodes, features_i16) as i32;
265        }
266
267        self.header.base_prediction + (leaf_sum as f32) / self.leaf_scale
268    }
269
    /// Batch predict with inline quantization. Zero allocation.
    ///
    /// Unlike [`predict_batch_prequantized`](Self::predict_batch_prequantized), this
    /// path does **not** use x4 interleaving — the inline path quantizes once per
    /// comparison, so interleaving four samples would multiply the f32 work by 4.
    /// Samples are predicted sequentially instead.
    ///
    /// # Panics
    ///
    /// Panics if `out.len() < samples.len()`.
    ///
    /// # Precondition
    ///
    /// Every sample must have `len() >= self.n_features()`. See [`predict`](Self::predict).
    pub fn predict_batch(&self, samples: &[&[f32]], out: &mut [f32]) {
        assert!(out.len() >= samples.len());

        // No x4 interleaving for inline quantization — the inline path already
        // quantizes per comparison, so interleaving would require 4x the f32 ops.
        // Use simple sequential prediction.
        for (i, &s) in samples.iter().enumerate() {
            out[i] = self.predict(s);
        }
    }
289
290    /// Batch predict from pre-quantized features. Uses x4 interleaving when `samples.len() >= 4`.
291    ///
292    /// # Panics
293    ///
294    /// Panics if `out.len() < samples.len()`.
295    ///
296    /// # Precondition
297    ///
298    /// Every sample must have `len() >= self.n_features()`. See [`predict_prequantized`](Self::predict_prequantized).
299    pub fn predict_batch_prequantized(&self, samples: &[&[i16]], out: &mut [f32]) {
300        assert!(out.len() >= samples.len());
301
302        let n = samples.len();
303        let mut i = 0;
304
305        // Process groups of 4 with interleaved traversal
306        while i + 4 <= n {
307            let batch = [samples[i], samples[i + 1], samples[i + 2], samples[i + 3]];
308            let mut sums = [0i32; 4];
309            for entry in self.tree_table {
310                let start = entry.offset as usize / core::mem::size_of::<PackedNodeI16>();
311                let end = start + entry.n_nodes as usize;
312                let tree_nodes = &self.nodes[start..end];
313                let preds = traverse_i16::predict_tree_i16_x4(tree_nodes, batch);
314                for j in 0..4 {
315                    sums[j] += preds[j] as i32;
316                }
317            }
318            for j in 0..4 {
319                out[i + j] = self.header.base_prediction + (sums[j] as f32) / self.leaf_scale;
320            }
321            i += 4;
322        }
323
324        // Remainder
325        while i < n {
326            out[i] = self.predict_prequantized(samples[i]);
327            i += 1;
328        }
329    }
330
    /// Number of trees in the ensemble (from the validated header).
    #[inline]
    pub fn n_trees(&self) -> u16 {
        self.header.n_trees
    }
336
    /// Expected number of input features; predict slices must be at least this long.
    #[inline]
    pub fn n_features(&self) -> u16 {
        self.header.n_features
    }
342
    /// Base prediction value added to every dequantized leaf sum.
    #[inline]
    pub fn base_prediction(&self) -> f32 {
        self.header.base_prediction
    }
348
    /// Global leaf scale factor for dequantization (`leaf_sum / leaf_scale`).
    #[inline]
    pub fn leaf_scale(&self) -> f32 {
        self.leaf_scale
    }
354
    /// Per-feature quantization scales (length == `n_features()`), borrowed from the buffer.
    #[inline]
    pub fn feature_scales(&self) -> &[f32] {
        self.feature_scales
    }
360
    /// Total number of packed i16 nodes across all trees in the ensemble.
    #[inline]
    pub fn total_nodes(&self) -> usize {
        self.nodes.len()
    }
366}
367
368impl<'a> core::fmt::Debug for QuantizedEnsembleView<'a> {
369    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
370        f.debug_struct("QuantizedEnsembleView")
371            .field("n_trees", &self.n_trees())
372            .field("n_features", &self.n_features())
373            .field("base_prediction", &self.base_prediction())
374            .field("leaf_scale", &self.leaf_scale())
375            .field("total_nodes", &self.total_nodes())
376            .finish()
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packed::TreeEntry;
    use crate::packed_i16::{PackedNodeI16, QuantizedEnsembleHeader};
    use alloc::{format, vec, vec::Vec};
    use core::mem::size_of;

    /// Cast a repr(C) struct to bytes.
    ///
    /// NOTE(review): this produces native-endian bytes, so these fixtures assume
    /// the reader targets the host's endianness — confirm if the format is LE-only.
    fn as_bytes<T: Sized>(val: &T) -> &[u8] {
        // SAFETY: `val` is a valid reference, so it points to `size_of::<T>()`
        // readable bytes for the duration of the borrow. NOTE(review): if T has
        // padding those bytes are uninitialized — presumably the packed header/
        // node types have none; verify against their definitions.
        unsafe { core::slice::from_raw_parts(val as *const T as *const u8, size_of::<T>()) }
    }

    /// Build a minimal valid quantized binary with one tree, one leaf.
    ///
    /// Layout: header(16) + leaf_scale(4) + feature_scales(n_features*4) + tree_entry(8) + node(8)
    fn build_single_leaf_binary(leaf_value: i16, base: f32, leaf_scale: f32) -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 1,
            _reserved: 0,
            base_prediction: base,
        };
        let feature_scale: f32 = 1.0;
        let entry = TreeEntry {
            n_nodes: 1,
            offset: 0,
        };
        let node = PackedNodeI16::leaf(leaf_value);

        // Concatenate sections in the exact on-disk order expected by from_bytes.
        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        buf.extend_from_slice(as_bytes(&feature_scale));
        buf.extend_from_slice(as_bytes(&entry));
        buf.extend_from_slice(as_bytes(&node));
        buf
    }

    /// Build a binary with one tree: root splits on feat 0 at quantized threshold 500.
    /// Feature scale = 100.0 (so raw feat 5.0 -> quantized 500).
    fn build_one_split_binary() -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 2,
            _reserved: 0,
            base_prediction: 0.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        let entry = TreeEntry {
            n_nodes: 3,
            offset: 0,
        };
        let nodes = [
            PackedNodeI16::split(500, 0, 1, 2), // threshold 500 = raw 5.0 * scale 100.0
            PackedNodeI16::leaf(-100),          // leaf = -100, dequantized = -100/100 = -1.0
            PackedNodeI16::leaf(100),           // leaf = 100, dequantized = 100/100 = 1.0
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        buf.extend_from_slice(as_bytes(&entry));
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }
        buf
    }

    /// Build a binary with two trees for testing multi-tree prediction.
    fn build_two_tree_binary() -> Vec<u8> {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 2,
            n_features: 2,
            _reserved: 0,
            base_prediction: 1.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        // Tree 0: 3 nodes, offset 0
        // Tree 1: 1 node (leaf), offset 3 * 8 = 24 bytes
        let entries = [
            TreeEntry {
                n_nodes: 3,
                offset: 0,
            },
            TreeEntry {
                n_nodes: 1,
                offset: 3 * size_of::<PackedNodeI16>() as u32,
            },
        ];
        let nodes = [
            // Tree 0
            PackedNodeI16::split(500, 0, 1, 2),
            PackedNodeI16::leaf(-100), // -1.0
            PackedNodeI16::leaf(100),  // 1.0
            // Tree 1
            PackedNodeI16::leaf(50), // 0.5
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        for e in &entries {
            buf.extend_from_slice(as_bytes(e));
        }
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }
        buf
    }

    #[test]
    fn parse_single_leaf_i16() {
        let buf = build_single_leaf_binary(42, 0.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        assert_eq!(view.n_trees(), 1);
        assert_eq!(view.n_features(), 1);
        assert_eq!(view.total_nodes(), 1);
        assert_eq!(view.leaf_scale(), 100.0);
    }

    #[test]
    fn predict_single_leaf_i16() {
        // leaf=42, base=10.0, leaf_scale=100.0
        // prediction = 10.0 + 42/100.0 = 10.42
        let buf = build_single_leaf_binary(42, 10.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        let pred = view.predict(&[0.0]);
        assert!((pred - 10.42).abs() < 1e-5, "expected 10.42, got {}", pred);
    }

    #[test]
    fn predict_one_split_left_i16() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=3.0, scale=100.0 -> quantized=300, not > 500 -> left -> leaf=-100
        // prediction = 0.0 + (-100)/100.0 = -1.0
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - (-1.0)).abs() < 1e-5, "expected -1.0, got {}", pred);
    }

    #[test]
    fn predict_one_split_right_i16() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=7.0, scale=100.0 -> quantized=700, > 500 -> right -> leaf=100
        // prediction = 0.0 + 100/100.0 = 1.0
        let pred = view.predict(&[7.0, 0.0]);
        assert!((pred - 1.0).abs() < 1e-5, "expected 1.0, got {}", pred);
    }

    #[test]
    fn predict_two_trees_i16() {
        let buf = build_two_tree_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        // feat[0]=3.0, scale=100 -> quantized=300, not > 500 -> tree0: left=-100
        // tree1: leaf=50
        // total = 1.0 + (-100 + 50)/100.0 = 1.0 + (-50)/100.0 = 1.0 - 0.5 = 0.5
        let pred = view.predict(&[3.0, 0.0]);
        assert!((pred - 0.5).abs() < 1e-5, "expected 0.5, got {}", pred);
    }

    #[test]
    fn predict_prequantized_matches_predict() {
        let buf = build_one_split_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();

        // Test left path: feat[0]=3.0, scale=100.0 -> quantized=300
        let pred_inline = view.predict(&[3.0, 0.0]);
        let pred_preq = view.predict_prequantized(&[300, 0]);
        assert!(
            (pred_inline - pred_preq).abs() < 1e-5,
            "left: inline={}, prequantized={}",
            pred_inline,
            pred_preq
        );

        // Test right path: feat[0]=7.0, scale=100.0 -> quantized=700
        let pred_inline = view.predict(&[7.0, 0.0]);
        let pred_preq = view.predict_prequantized(&[700, 0]);
        assert!(
            (pred_inline - pred_preq).abs() < 1e-5,
            "right: inline={}, prequantized={}",
            pred_inline,
            pred_preq
        );
    }

    #[test]
    fn bad_magic_rejected_i16() {
        let mut buf = build_single_leaf_binary(0, 0.0, 100.0);
        buf[0] = 0xFF; // corrupt magic
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::BadMagic
        );
    }

    #[test]
    fn truncated_rejected_i16() {
        let buf = build_single_leaf_binary(0, 0.0, 100.0);
        // Only pass the first 4 bytes — not even a full header
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf[..4]).unwrap_err(),
            FormatError::Truncated
        );
    }

    #[test]
    fn debug_format_i16() {
        let buf = build_single_leaf_binary(0, 0.0, 100.0);
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();
        let debug = format!("{:?}", view);
        assert!(
            debug.contains("QuantizedEnsembleView"),
            "missing struct name in debug: {}",
            debug
        );
        assert!(
            debug.contains("n_trees"),
            "missing n_trees in debug: {}",
            debug
        );
        assert!(
            debug.contains("leaf_scale"),
            "missing leaf_scale in debug: {}",
            debug
        );
    }

    #[test]
    fn predict_batch_matches_single_i16() {
        let buf = build_two_tree_binary();
        let view = QuantizedEnsembleView::from_bytes(&buf).unwrap();

        // 5 samples: exercises one interleaved group of 4 plus a remainder of 1.
        let samples: Vec<&[f32]> = vec![
            &[3.0, 0.0],
            &[7.0, 0.0],
            &[5.0, 0.0],
            &[0.0, 0.0],
            &[10.0, 0.0],
        ];
        let mut out = vec![0.0f32; 5];
        view.predict_batch(&samples, &mut out);

        for (i, &s) in samples.iter().enumerate() {
            let expected = view.predict(s);
            assert!(
                (out[i] - expected).abs() < 1e-6,
                "batch[{}] = {}, expected {}",
                i,
                out[i],
                expected
            );
        }
    }

    #[test]
    fn bad_version_rejected_i16() {
        let mut buf = build_single_leaf_binary(0, 0.0, 100.0);
        // version is at offset 4 (after magic u32), 2 bytes LE
        buf[4] = 99;
        buf[5] = 0;
        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::UnsupportedVersion
        );
    }

    #[test]
    fn invalid_child_index_rejected_i16() {
        let header = QuantizedEnsembleHeader {
            magic: QuantizedEnsembleHeader::MAGIC,
            version: QuantizedEnsembleHeader::VERSION,
            n_trees: 1,
            n_features: 2,
            _reserved: 0,
            base_prediction: 0.0,
        };
        let leaf_scale: f32 = 100.0;
        let feature_scales: [f32; 2] = [100.0, 100.0];
        let entry = TreeEntry {
            n_nodes: 3,
            offset: 0,
        };
        // Node 0 points to child 99, which is out of bounds (only 3 nodes)
        let nodes = [
            PackedNodeI16::split(500, 0, 1, 99),
            PackedNodeI16::leaf(-100),
            PackedNodeI16::leaf(100),
        ];

        let mut buf = Vec::new();
        buf.extend_from_slice(as_bytes(&header));
        buf.extend_from_slice(as_bytes(&leaf_scale));
        for s in &feature_scales {
            buf.extend_from_slice(as_bytes(s));
        }
        buf.extend_from_slice(as_bytes(&entry));
        for n in &nodes {
            buf.extend_from_slice(as_bytes(n));
        }

        assert_eq!(
            QuantizedEnsembleView::from_bytes(&buf).unwrap_err(),
            FormatError::InvalidNodeIndex
        );
    }
}