Skip to main content

edgefirst_image/cpu/
masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
11    #[allow(clippy::too_many_arguments)]
12    pub(super) fn render_modelpack_segmentation(
13        &mut self,
14        dst_w: usize,
15        dst_h: usize,
16        dst_rs: usize,
17        dst_c: usize,
18        dst_slice: &mut [u8],
19        segmentation: &Segmentation,
20        opacity: f32,
21    ) -> Result<()> {
22        use ndarray_stats::QuantileExt;
23
24        let seg = &segmentation.segmentation;
25        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
26            unreachable!("Array3 did not have [usize; 3] as shape");
27        };
28        let start_y = (dst_h as f32 * segmentation.ymin).round();
29        let end_y = (dst_h as f32 * segmentation.ymax).round();
30        let start_x = (dst_w as f32 * segmentation.xmin).round();
31        let end_x = (dst_w as f32 * segmentation.xmax).round();
32
33        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
34        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
35
36        let start_x_u = (start_x as usize).min(dst_w);
37        let start_y_u = (start_y as usize).min(dst_h);
38        let end_x_u = (end_x as usize).min(dst_w);
39        let end_y_u = (end_y as usize).min(dst_h);
40
41        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
42        let get_value_at_nearest = |x: f32, y: f32| -> usize {
43            let x = x.round() as usize;
44            let y = y.round() as usize;
45            argmax
46                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
47                .copied()
48                .unwrap_or(0)
49        };
50
51        for y in start_y_u..end_y_u {
52            for x in start_x_u..end_x_u {
53                let seg_x = (x as f32 - start_x) * scale_x;
54                let seg_y = (y as f32 - start_y) * scale_y;
55                let label = get_value_at_nearest(seg_x, seg_y);
56
57                if label == seg_classes - 1 {
58                    continue;
59                }
60
61                let color = self.colors[label % self.colors.len()];
62
63                let alpha = if opacity == 1.0 {
64                    color[3] as u16
65                } else {
66                    (color[3] as f32 * opacity).round() as u16
67                };
68
69                let dst_index = (y * dst_rs) + (x * dst_c);
70                for c in 0..3 {
71                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
72                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
73                        / 255) as u8;
74                }
75            }
76        }
77
78        Ok(())
79    }
80
81    #[allow(clippy::too_many_arguments)]
82    pub(super) fn render_yolo_segmentation(
83        &mut self,
84        dst_w: usize,
85        dst_h: usize,
86        dst_rs: usize,
87        dst_c: usize,
88        dst_slice: &mut [u8],
89        segmentation: &Segmentation,
90        class: usize,
91        opacity: f32,
92    ) -> Result<()> {
93        let seg = &segmentation.segmentation;
94        let [seg_height, seg_width, classes] = *seg.shape() else {
95            unreachable!("Array3 did not have [usize;3] as shape");
96        };
97        debug_assert_eq!(classes, 1);
98
99        let start_y = (dst_h as f32 * segmentation.ymin).round();
100        let end_y = (dst_h as f32 * segmentation.ymax).round();
101        let start_x = (dst_w as f32 * segmentation.xmin).round();
102        let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107        let start_x_u = (start_x as usize).min(dst_w);
108        let start_y_u = (start_y as usize).min(dst_h);
109        let end_x_u = (end_x as usize).min(dst_w);
110        let end_y_u = (end_y as usize).min(dst_h);
111
112        for y in start_y_u..end_y_u {
113            for x in start_x_u..end_x_u {
114                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118                if val < 127 {
119                    continue;
120                }
121
122                let color = self.colors[class % self.colors.len()];
123
124                let alpha = if opacity == 1.0 {
125                    color[3] as u16
126                } else {
127                    (color[3] as f32 * opacity).round() as u16
128                };
129
130                let dst_index = (y * dst_rs) + (x * dst_c);
131                for c in 0..3 {
132                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
134                        / 255) as u8;
135                }
136            }
137        }
138
139        Ok(())
140    }
141
142    #[allow(clippy::too_many_arguments)]
143    pub(super) fn render_box(
144        &mut self,
145        dst_w: usize,
146        dst_h: usize,
147        dst_rs: usize,
148        dst_c: usize,
149        dst_slice: &mut [u8],
150        detect: &[DetectBox],
151        color_mode: crate::ColorMode,
152    ) -> Result<()> {
153        const LINE_THICKNESS: usize = 3;
154
155        for (idx, d) in detect.iter().enumerate() {
156            use edgefirst_decoder::BoundingBox;
157
158            let color_index = color_mode.index(idx, d.label);
159            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
160            let bbox = d.bbox.to_canonical();
161            let bbox = BoundingBox {
162                xmin: bbox.xmin.clamp(0.0, 1.0),
163                ymin: bbox.ymin.clamp(0.0, 1.0),
164                xmax: bbox.xmax.clamp(0.0, 1.0),
165                ymax: bbox.ymax.clamp(0.0, 1.0),
166            };
167            let inner = [
168                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
169                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
170                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
171                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
172            ];
173
174            let outer = [
175                inner[0].saturating_sub(LINE_THICKNESS),
176                inner[1].saturating_sub(LINE_THICKNESS),
177                (inner[2] + LINE_THICKNESS).min(dst_w),
178                (inner[3] + LINE_THICKNESS).min(dst_h),
179            ];
180
181            // top line
182            for y in outer[1] + 1..=inner[1] {
183                for x in outer[0] + 1..outer[2] {
184                    let index = (y * dst_rs) + (x * dst_c);
185                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
186                }
187            }
188
189            // left and right lines
190            for y in inner[1]..inner[3] {
191                for x in outer[0] + 1..=inner[0] {
192                    let index = (y * dst_rs) + (x * dst_c);
193                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
194                }
195
196                for x in inner[2]..outer[2] {
197                    let index = (y * dst_rs) + (x * dst_c);
198                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
199                }
200            }
201
202            // bottom line
203            for y in inner[3]..outer[3] {
204                for x in outer[0] + 1..outer[2] {
205                    let index = (y * dst_rs) + (x * dst_c);
206                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
207                }
208            }
209        }
210        Ok(())
211    }
212
213    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
214    ///
215    /// This is the CPU-side decode step of the hybrid mask rendering path:
216    /// call this to get pre-decoded masks, then pass them to
217    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
218    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
219    /// than the fused GPU `draw_proto_masks` on all tested platforms.
220    ///
221    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
222    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
223    pub fn materialize_segmentations(
224        &self,
225        detect: &[crate::DetectBox],
226        proto_data: &crate::ProtoData,
227        letterbox: Option<[f32; 4]>,
228    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
229        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};
230
231        let _span = tracing::trace_span!(
232            "image.materialize_masks",
233            mode = "proto",
234            n_detections = detect.len(),
235        )
236        .entered();
237
238        if detect.is_empty() {
239            return Ok(Vec::new());
240        }
241        let proto_shape = proto_data.protos.shape();
242        if proto_shape.len() != 3 {
243            return Err(crate::Error::InvalidShape(format!(
244                "protos tensor must be rank-3, got {proto_shape:?}"
245            )));
246        }
247        // Interpret shape based on physical layout.
248        let (proto_h, proto_w, num_protos) = match proto_data.layout {
249            edgefirst_decoder::ProtoLayout::Nhwc => {
250                (proto_shape[0], proto_shape[1], proto_shape[2])
251            }
252            edgefirst_decoder::ProtoLayout::Nchw => {
253                (proto_shape[1], proto_shape[2], proto_shape[0])
254            }
255        };
256        let coeff_shape = proto_data.mask_coefficients.shape();
257        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
258            return Err(crate::Error::InvalidShape(format!(
259                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
260                 {proto_shape:?} (expected [N, {num_protos}])"
261            )));
262        }
263        if coeff_shape[0] == 0 {
264            return Ok(Vec::new());
265        }
266        if coeff_shape[0] != detect.len() {
267            return Err(crate::Error::Internal(format!(
268                "mask_coefficients rows {} != detection count {}",
269                coeff_shape[0],
270                detect.len()
271            )));
272        }
273
274        // Precompute inverse letterbox scale for output-coord conversion.
275        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
276            Some([lx0, ly0, lx1, ly1]) => {
277                let lw = lx1 - lx0;
278                let lh = ly1 - ly0;
279                (
280                    lx0,
281                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
282                    ly0,
283                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
284                )
285            }
286            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
287        };
288
289        // Fast integer path: when both coefficients and protos are I8 with
290        // per-tensor quantization, use the all-integer kernel (same dot
291        // product infrastructure as the scaled path, but at proto resolution
292        // without bilinear upsampling). Output is binary {0, 255}.
293        // Falls through to the general f32 dequant path for per-channel
294        // quantization or other unsupported modes.
295        if proto_data.mask_coefficients.dtype() == DType::I8
296            && proto_data.protos.dtype() == DType::I8
297        {
298            let coeff_t = proto_data
299                .mask_coefficients
300                .as_i8()
301                .expect("I8 coefficients");
302            let coeff_m = coeff_t.map()?;
303            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
304                crate::Error::InvalidShape(
305                    "I8 mask_coefficients require quantization metadata".into(),
306                )
307            })?;
308            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
309            let proto_m = proto_t.map()?;
310            let proto_quant = proto_t.quantization().ok_or_else(|| {
311                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
312            })?;
313            match proto_segmentations_i8_i8(
314                detect,
315                coeff_m.as_slice(),
316                coeff_quant,
317                proto_m.as_slice(),
318                proto_quant,
319                proto_h,
320                proto_w,
321                num_protos,
322                lx0,
323                inv_lw,
324                ly0,
325                inv_lh,
326                proto_data.layout,
327            ) {
328                Ok(result) => return Ok(result),
329                Err(crate::Error::NotSupported(_)) => {
330                    // Fall through to the general f32 dequant path below for
331                    // per-channel quantization and other unsupported modes.
332                }
333                Err(e) => return Err(e),
334            }
335        }
336
337        // Fast i16×i8 integer path: i16 coefficients × i8 protos → i32 dot.
338        if proto_data.mask_coefficients.dtype() == DType::I16
339            && proto_data.protos.dtype() == DType::I8
340        {
341            let coeff_t = proto_data
342                .mask_coefficients
343                .as_i16()
344                .expect("I16 coefficients");
345            let coeff_m = coeff_t.map()?;
346            // Skip the integer fast path when coefficient quantization is
347            // absent — the f32 fallback below handles raw i16 by widening.
348            if let Some(coeff_quant) = coeff_t.quantization() {
349                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
350                let proto_m = proto_t.map()?;
351                let proto_quant = proto_t.quantization().ok_or_else(|| {
352                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
353                })?;
354                match proto_segmentations_i16_i8(
355                    detect,
356                    coeff_m.as_slice(),
357                    coeff_quant,
358                    proto_m.as_slice(),
359                    proto_quant,
360                    proto_h,
361                    proto_w,
362                    num_protos,
363                    lx0,
364                    inv_lw,
365                    ly0,
366                    inv_lh,
367                    proto_data.layout,
368                ) {
369                    Ok(result) => return Ok(result),
370                    Err(crate::Error::NotSupported(_)) => {
371                        // Fall through to the general f32 dequant path.
372                    }
373                    Err(e) => return Err(e),
374                }
375            }
376        }
377
378        // Coefficients may be F32 (from f32 models), F16 (from fp16 models),
379        // I8 (from quantized models — kept raw with quantization), or I16.
380        // For the mask kernel we always need an f32 view (the multiply-accumulate
381        // is done in f32 for precision). Map once and widen once outside the loop.
382        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
383            && proto_data.protos.dtype() != DType::I8
384        {
385            return Err(crate::Error::NotSupported(
386                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
387                    .into(),
388            ));
389        }
390        let coeff_f32_storage: Vec<f32>;
391        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
392            DType::F32 => {
393                let t = proto_data
394                    .mask_coefficients
395                    .as_f32()
396                    .expect("dtype matched F32");
397                let m = t.map()?;
398                coeff_f32_storage = m.as_slice().to_vec();
399                &coeff_f32_storage[..]
400            }
401            DType::F16 => {
402                let t = proto_data
403                    .mask_coefficients
404                    .as_f16()
405                    .expect("dtype matched F16");
406                let m = t.map()?;
407                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
408                &coeff_f32_storage[..]
409            }
410            DType::I8 => {
411                let t = proto_data
412                    .mask_coefficients
413                    .as_i8()
414                    .expect("dtype matched I8");
415                let m = t.map()?;
416                coeff_f32_storage = if let Some(q) = t.quantization() {
417                    use edgefirst_tensor::QuantMode;
418                    let (scale, zp) = match q.mode() {
419                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
420                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
421                        other => {
422                            return Err(crate::Error::NotSupported(format!(
423                                "I8 mask_coefficients quantization mode {other:?} not supported"
424                            )));
425                        }
426                    };
427                    m.as_slice()
428                        .iter()
429                        .map(|&v| (v as f32 - zp) * scale)
430                        .collect()
431                } else {
432                    m.as_slice().iter().map(|&v| v as f32).collect()
433                };
434                &coeff_f32_storage[..]
435            }
436            DType::I16 => {
437                let t = proto_data
438                    .mask_coefficients
439                    .as_i16()
440                    .expect("dtype matched I16");
441                let m = t.map()?;
442                coeff_f32_storage = if let Some(q) = t.quantization() {
443                    use edgefirst_tensor::QuantMode;
444                    let (scale, zp) = match q.mode() {
445                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
446                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
447                        other => {
448                            return Err(crate::Error::NotSupported(format!(
449                                "I16 mask_coefficients quantization mode {other:?} not supported"
450                            )));
451                        }
452                    };
453                    m.as_slice()
454                        .iter()
455                        .map(|&v| (v as f32 - zp) * scale)
456                        .collect()
457                } else {
458                    m.as_slice().iter().map(|&v| v as f32).collect()
459                };
460                &coeff_f32_storage[..]
461            }
462            other => {
463                return Err(crate::Error::InvalidShape(format!(
464                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, I8, or I16"
465                )));
466            }
467        };
468
469        // Hoist the proto tensor map() out of the per-detection loop so the
470        // map-guard is acquired once. Then dispatch per-dtype via a helper
471        // that runs the per-detection kernels in parallel across detections
472        // via rayon. This restores the parallelism that PR #54 added and
473        // PR #51 (EDGEAI-1244 f16 refactor) inadvertently removed.
474        match proto_data.protos.dtype() {
475            DType::I8 => {
476                let t = proto_data.protos.as_i8().expect("dtype matched I8");
477                let quant = t.quantization().ok_or_else(|| {
478                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
479                })?;
480                let m = t.map()?;
481                let src_slice = m.as_slice();
482                let transposed_storage =
483                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
484                        let hw = proto_h * proto_w;
485                        let mut nhwc = vec![0i8; hw * num_protos];
486                        for c in 0..num_protos {
487                            let plane = &src_slice[c * hw..(c + 1) * hw];
488                            for px in 0..hw {
489                                nhwc[px * num_protos + c] = plane[px];
490                            }
491                        }
492                        Some(nhwc)
493                    } else {
494                        None
495                    };
496                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
497                detect
498                    .par_iter()
499                    .enumerate()
500                    .map(|(i, det)| {
501                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
502                        let (x0, y0, x1, y1, roi_w, roi_h) =
503                            bbox_to_proto_roi(det, proto_w, proto_h);
504                        let mask = fused_dequant_dot_sign_i8_slice(
505                            protos_slice,
506                            coeff,
507                            quant,
508                            proto_h,
509                            proto_w,
510                            y0,
511                            x0,
512                            roi_h,
513                            roi_w,
514                            num_protos,
515                        )?;
516                        Ok(seg_from_roi(
517                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
518                        ))
519                    })
520                    .collect()
521            }
522            DType::F32 => {
523                let t = proto_data.protos.as_f32().expect("dtype matched F32");
524                let m = t.map()?;
525                let protos_slice = m.as_slice();
526                detect
527                    .par_iter()
528                    .enumerate()
529                    .map(|(i, det)| {
530                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
531                        let (x0, y0, x1, y1, roi_w, roi_h) =
532                            bbox_to_proto_roi(det, proto_w, proto_h);
533                        let mask = fused_dot_sign_f32_slice(
534                            protos_slice,
535                            coeff,
536                            proto_h,
537                            proto_w,
538                            y0,
539                            x0,
540                            roi_h,
541                            roi_w,
542                            num_protos,
543                        );
544                        Ok(seg_from_roi(
545                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
546                        ))
547                    })
548                    .collect()
549            }
550            DType::F16 => {
551                let t = proto_data.protos.as_f16().expect("dtype matched F16");
552                let m = t.map()?;
553                let protos_slice = m.as_slice();
554                detect
555                    .par_iter()
556                    .enumerate()
557                    .map(|(i, det)| {
558                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
559                        let (x0, y0, x1, y1, roi_w, roi_h) =
560                            bbox_to_proto_roi(det, proto_w, proto_h);
561                        let mask = fused_dot_sign_f16_slice(
562                            protos_slice,
563                            coeff,
564                            proto_h,
565                            proto_w,
566                            y0,
567                            x0,
568                            roi_h,
569                            roi_w,
570                            num_protos,
571                        );
572                        Ok(seg_from_roi(
573                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
574                        ))
575                    })
576                    .collect()
577            }
578            other => Err(crate::Error::InvalidShape(format!(
579                "proto tensor dtype {other:?} not supported"
580            ))),
581        }
582    }
583
584    /// Produce per-detection masks at `(width, height)` pixel resolution by
585    /// upsampling the full proto plane once then cropping per bbox. Each
586    /// `det.bbox` is assumed to be in model-input normalized coordinates
587    /// (the convention used by the decoder output); when `letterbox` is
588    /// `Some`, `(width, height)` are original-content pixel dims and the
589    /// inverse letterbox transform is applied to both the bbox (for the
590    /// crop region and returned `Segmentation` metadata) and each output
591    /// pixel (for proto-plane sampling). Mask values are binary
592    /// `uint8 {0, 255}` after thresholding sigmoid > 0.5.
593    ///
594    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
595    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
596    pub fn materialize_scaled_segmentations(
597        &self,
598        detect: &[crate::DetectBox],
599        proto_data: &crate::ProtoData,
600        letterbox: Option<[f32; 4]>,
601        width: u32,
602        height: u32,
603    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
604        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};
605
606        let _span = tracing::trace_span!(
607            "image.materialize_masks",
608            mode = "scaled",
609            n_detections = detect.len(),
610            width,
611            height,
612        )
613        .entered();
614
615        if detect.is_empty() {
616            return Ok(Vec::new());
617        }
618        if width == 0 || height == 0 {
619            return Err(crate::Error::InvalidShape(
620                "Scaled mask width/height must be positive".into(),
621            ));
622        }
623        let proto_shape = proto_data.protos.shape();
624        if proto_shape.len() != 3 {
625            return Err(crate::Error::InvalidShape(format!(
626                "protos tensor must be rank-3, got {proto_shape:?}"
627            )));
628        }
629        // Interpret shape based on physical layout.
630        let (proto_h, proto_w, num_protos) = match proto_data.layout {
631            edgefirst_decoder::ProtoLayout::Nhwc => {
632                (proto_shape[0], proto_shape[1], proto_shape[2])
633            }
634            edgefirst_decoder::ProtoLayout::Nchw => {
635                (proto_shape[1], proto_shape[2], proto_shape[0])
636            }
637        };
638        let coeff_shape = proto_data.mask_coefficients.shape();
639        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
640            return Err(crate::Error::InvalidShape(format!(
641                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
642                 {proto_shape:?}"
643            )));
644        }
645        if coeff_shape[0] == 0 {
646            return Ok(Vec::new());
647        }
648        if coeff_shape[0] != detect.len() {
649            return Err(crate::Error::Internal(format!(
650                "mask_coefficients rows {} != detection count {}",
651                coeff_shape[0],
652                detect.len()
653            )));
654        }
655
656        // Fast integer path: when both coefficients and protos are I8, use
657        // the all-integer kernel (i8×i8→i32 dot product, sign-shortcut
658        // bilinear). No floating-point conversion at all.
659        // Falls through to the general f32 dequant path for per-channel
660        // quantization or other unsupported modes.
661        if proto_data.mask_coefficients.dtype() == DType::I8
662            && proto_data.protos.dtype() == DType::I8
663        {
664            let coeff_t = proto_data
665                .mask_coefficients
666                .as_i8()
667                .expect("I8 coefficients");
668            let coeff_m = coeff_t.map()?;
669            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
670                crate::Error::InvalidShape(
671                    "I8 mask_coefficients require quantization metadata".into(),
672                )
673            })?;
674            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
675            let proto_m = proto_t.map()?;
676            let proto_quant = proto_t.quantization().ok_or_else(|| {
677                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
678            })?;
679            match scaled_segmentations_i8_i8(
680                detect,
681                coeff_m.as_slice(),
682                coeff_quant,
683                proto_m.as_slice(),
684                proto_quant,
685                proto_h,
686                proto_w,
687                num_protos,
688                letterbox,
689                width,
690                height,
691                proto_data.layout,
692            ) {
693                Ok(result) => return Ok(result),
694                Err(crate::Error::NotSupported(_)) => {
695                    // Fall through to the general f32 dequant path below for
696                    // per-channel quantization and other unsupported modes.
697                }
698                Err(e) => return Err(e),
699            }
700        }
701
702        // Fast i16×i8 integer path: i16 coefficients × i8 protos.
703        if proto_data.mask_coefficients.dtype() == DType::I16
704            && proto_data.protos.dtype() == DType::I8
705        {
706            let coeff_t = proto_data
707                .mask_coefficients
708                .as_i16()
709                .expect("I16 coefficients");
710            let coeff_m = coeff_t.map()?;
711            // Skip the integer fast path when coefficient quantization is
712            // absent — the f32 fallback below handles raw i16 by widening.
713            if let Some(coeff_quant) = coeff_t.quantization() {
714                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
715                let proto_m = proto_t.map()?;
716                let proto_quant = proto_t.quantization().ok_or_else(|| {
717                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
718                })?;
719                match scaled_segmentations_i16_i8(
720                    detect,
721                    coeff_m.as_slice(),
722                    coeff_quant,
723                    proto_m.as_slice(),
724                    proto_quant,
725                    proto_h,
726                    proto_w,
727                    num_protos,
728                    letterbox,
729                    width,
730                    height,
731                    proto_data.layout,
732                ) {
733                    Ok(result) => return Ok(result),
734                    Err(crate::Error::NotSupported(_)) => {}
735                    Err(e) => return Err(e),
736                }
737            }
738        }
739
740        // Fallback: widen coefficients to f32 for the float-path kernels.
741        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
742            && proto_data.protos.dtype() != DType::I8
743        {
744            return Err(crate::Error::NotSupported(
745                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
746                    .into(),
747            ));
748        }
749        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
750            DType::F32 => {
751                let t = proto_data.mask_coefficients.as_f32().expect("F32");
752                let m = t.map()?;
753                m.as_slice().to_vec()
754            }
755            DType::F16 => {
756                let t = proto_data.mask_coefficients.as_f16().expect("F16");
757                let m = t.map()?;
758                m.as_slice().iter().map(|v| v.to_f32()).collect()
759            }
760            DType::I8 => {
761                // Dequantize I8 coefficients to f32 for the float proto path.
762                let t = proto_data.mask_coefficients.as_i8().expect("I8");
763                let m = t.map()?;
764                let q = t.quantization().ok_or_else(|| {
765                    crate::Error::InvalidShape(
766                        "I8 mask_coefficients require quantization metadata".into(),
767                    )
768                })?;
769                use edgefirst_tensor::QuantMode;
770                let (scale, zp) = match q.mode() {
771                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
772                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
773                    _ => {
774                        return Err(crate::Error::NotSupported(
775                            "per-channel mask_coefficients not supported".into(),
776                        ))
777                    }
778                };
779                m.as_slice()
780                    .iter()
781                    .map(|&v| (v as f32 - zp) * scale)
782                    .collect()
783            }
784            DType::I16 => {
785                let t = proto_data.mask_coefficients.as_i16().expect("I16");
786                let m = t.map()?;
787                if let Some(q) = t.quantization() {
788                    use edgefirst_tensor::QuantMode;
789                    let (scale, zp) = match q.mode() {
790                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
791                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
792                        other => {
793                            return Err(crate::Error::NotSupported(format!(
794                                "I16 mask_coefficients quantization mode {other:?} not supported"
795                            )))
796                        }
797                    };
798                    m.as_slice()
799                        .iter()
800                        .map(|&v| (v as f32 - zp) * scale)
801                        .collect()
802                } else {
803                    m.as_slice().iter().map(|&v| v as f32).collect()
804                }
805            }
806            other => {
807                return Err(crate::Error::InvalidShape(format!(
808                    "mask_coefficients dtype {other:?} not supported"
809                )));
810            }
811        };
812
813        match proto_data.protos.dtype() {
814            DType::F32 => {
815                let t = proto_data.protos.as_f32().expect("F32");
816                let m = t.map()?;
817                scaled_segmentations_f32_slice(
818                    detect,
819                    &coeff_f32,
820                    m.as_slice(),
821                    proto_h,
822                    proto_w,
823                    num_protos,
824                    letterbox,
825                    width,
826                    height,
827                )
828            }
829            DType::F16 => {
830                let t = proto_data.protos.as_f16().expect("F16");
831                let m = t.map()?;
832                scaled_segmentations_f16_slice(
833                    detect,
834                    &coeff_f32,
835                    m.as_slice(),
836                    proto_h,
837                    proto_w,
838                    num_protos,
839                    letterbox,
840                    width,
841                    height,
842                )
843            }
844            DType::I8 => {
845                let t = proto_data.protos.as_i8().expect("I8");
846                let m = t.map()?;
847                let quant = t.quantization().ok_or_else(|| {
848                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
849                })?;
850                let src_slice = m.as_slice();
851                let transposed_storage =
852                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
853                        let hw = proto_h * proto_w;
854                        let mut nhwc = vec![0i8; hw * num_protos];
855                        for c in 0..num_protos {
856                            let plane = &src_slice[c * hw..(c + 1) * hw];
857                            for px in 0..hw {
858                                nhwc[px * num_protos + c] = plane[px];
859                            }
860                        }
861                        Some(nhwc)
862                    } else {
863                        None
864                    };
865                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
866                scaled_segmentations_i8_slice(
867                    detect,
868                    &coeff_f32,
869                    protos_slice,
870                    proto_h,
871                    proto_w,
872                    num_protos,
873                    quant,
874                    letterbox,
875                    width,
876                    height,
877                )
878            }
879            other => Err(crate::Error::InvalidShape(format!(
880                "proto tensor dtype {other:?} not supported"
881            ))),
882        }
883    }
884}
885
886// =============================================================================
887// Slice-native fused kernels.
888//
889// All kernels take row-major `[H, W, num_protos]` proto slices + `&[f32]`
890// coefficients (widened once from the source dtype at the materialize entry
891// point). Per-dtype variants exist for i8 (with on-the-fly dequant using a
892// tensor-level `Quantization`), f32, and f16; f16 widens to f32 per-element
893// at the FMA site via `half::f16::to_f32()`.
894//
895// On ARMv8.2-FP16 this compiles to `fcvt`; on Cortex-A53 and non-F16C x86 it
896// becomes a soft-float helper.
897// =============================================================================
898
899/// Map a detection bbox in normalised letterboxed coords to its ROI in
900/// the proto plane (floor xmin/ymin, ceil xmax/ymax, clamp to plane bounds).
901/// Returns `(x0, y0, x1, y1, roi_w, roi_h)` where roi_w/h are guaranteed ≥ 1.
902fn bbox_to_proto_roi(
903    det: &DetectBox,
904    proto_w: usize,
905    proto_h: usize,
906) -> (usize, usize, usize, usize, usize, usize) {
907    let bbox = det.bbox.to_canonical();
908    let xmin = bbox.xmin.clamp(0.0, 1.0);
909    let ymin = bbox.ymin.clamp(0.0, 1.0);
910    let xmax = bbox.xmax.clamp(0.0, 1.0);
911    let ymax = bbox.ymax.clamp(0.0, 1.0);
912    let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
913    let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
914    let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
915    let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
916    let roi_w = x1.saturating_sub(x0).max(1);
917    let roi_h = y1.saturating_sub(y0).max(1);
918    (x0, y0, x1, y1, roi_w, roi_h)
919}
920
921/// Build a `Segmentation` from a per-detection mask + the ROI bounds in
922/// proto coords. Applies the inverse letterbox transform to express the
923/// segmentation bbox in original-image-content normalised space.
924#[allow(clippy::too_many_arguments)]
925fn seg_from_roi(
926    mask: ndarray::Array3<u8>,
927    x0: usize,
928    y0: usize,
929    x1: usize,
930    y1: usize,
931    proto_w: usize,
932    proto_h: usize,
933    lx0: f32,
934    inv_lw: f32,
935    ly0: f32,
936    inv_lh: f32,
937) -> edgefirst_decoder::Segmentation {
938    let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
939    let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
940    let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
941    let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
942    edgefirst_decoder::Segmentation {
943        xmin: seg_xmin.clamp(0.0, 1.0),
944        ymin: seg_ymin.clamp(0.0, 1.0),
945        xmax: seg_xmax.clamp(0.0, 1.0),
946        ymax: seg_ymax.clamp(0.0, 1.0),
947        segmentation: mask,
948    }
949}
950
951// =============================================================================
952// Integer-domain proto-resolution kernel: i8 coefficients × i8 protos → i32
953// → sign threshold → binary {0, 255}.
954//
// Reuses the same dot product infrastructure as the scaled path (NEON sdot on
// A55+, smull+sadalp on A53, scalar fallback on non-aarch64 targets) for the
// NHWC layout; NCHW protos are accumulated plane-by-plane with plain scalar
// code. Since proto-resolution produces masks at the native proto grid
// (~30×30 per ROI), there is no bilinear upsampling — just a direct sign
// threshold per pixel.
959// =============================================================================
960
961/// Proto-resolution mask materialization using integer-domain math.
962///
963/// For each detection, computes the i8×i8 dot product at every proto-ROI pixel,
964/// applies the zero-point correction, and thresholds at sign(logit) → {0, 255}.
965/// Supports both NHWC and NCHW proto layouts.
966#[allow(clippy::too_many_arguments)]
967fn proto_segmentations_i8_i8(
968    detect: &[crate::DetectBox],
969    coeff_all: &[i8],
970    coeff_quant: &edgefirst_tensor::Quantization,
971    protos: &[i8],
972    proto_quant: &edgefirst_tensor::Quantization,
973    proto_h: usize,
974    proto_w: usize,
975    num_protos: usize,
976    lx0: f32,
977    inv_lw: f32,
978    ly0: f32,
979    inv_lh: f32,
980    layout: edgefirst_decoder::ProtoLayout,
981) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
982    use edgefirst_tensor::QuantMode;
983
984    let _span = tracing::trace_span!(
985        "image.materialize_masks.kernel_i8",
986        n = detect.len(),
987        proto_h,
988        proto_w,
989        num_protos,
990        ?layout,
991    )
992    .entered();
993
994    let zp_c: i32 = match coeff_quant.mode() {
995        QuantMode::PerTensor { zero_point, .. } => zero_point,
996        QuantMode::PerTensorSymmetric { .. } => 0,
997        _ => {
998            return Err(crate::Error::NotSupported(
999                "per-channel coeff quantization not supported on proto-res i8 path".into(),
1000            ))
1001        }
1002    };
1003    let zp_p: i32 = match proto_quant.mode() {
1004        QuantMode::PerTensor { zero_point, .. } => zero_point,
1005        QuantMode::PerTensorSymmetric { .. } => 0,
1006        _ => {
1007            return Err(crate::Error::NotSupported(
1008                "per-channel proto quantization not supported on proto-res i8 path".into(),
1009            ))
1010        }
1011    };
1012
1013    let hw = proto_h * proto_w;
1014
1015    // Precompute per-pixel proto sums for zero-point correction.
1016    let proto_sums: Vec<i32> = if zp_c != 0 {
1017        match layout {
1018            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
1019                .map(|px_idx| {
1020                    let base = px_idx * num_protos;
1021                    protos[base..base + num_protos]
1022                        .iter()
1023                        .map(|&v| v as i32)
1024                        .sum()
1025                })
1026                .collect(),
1027            edgefirst_decoder::ProtoLayout::Nchw => {
1028                let mut sums = vec![0i32; hw];
1029                for c in 0..num_protos {
1030                    let plane = &protos[c * hw..];
1031                    for (px, s) in sums.iter_mut().enumerate() {
1032                        *s += plane[px] as i32;
1033                    }
1034                }
1035                sums
1036            }
1037        }
1038    } else {
1039        Vec::new()
1040    };
1041
1042    #[cfg(target_arch = "aarch64")]
1043    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");
1044
1045    detect
1046        .par_iter()
1047        .enumerate()
1048        .map(|(i, det)| {
1049            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1050            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);
1051
1052            // Per-detection bias: zp_p·Σc_raw - N·zp_c·zp_p
1053            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
1054            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
1055
1056            let mut mask_buf = vec![0u8; roi_h * roi_w];
1057
1058            match layout {
1059                edgefirst_decoder::ProtoLayout::Nhwc => {
1060                    let stride_y = proto_w * num_protos;
1061                    #[cfg(target_arch = "aarch64")]
1062                    {
1063                        if use_dotprod {
1064                            for ly in 0..roi_h {
1065                                let py = y0 + ly;
1066                                let row_base = py * stride_y + x0 * num_protos;
1067                                for lx in 0..roi_w {
1068                                    let pix_base = row_base + lx * num_protos;
1069                                    let proto_px = &protos[pix_base..pix_base + num_protos];
1070                                    let raw_dot = unsafe {
1071                                        dot_i8_neon_dotprod(
1072                                            coeff.as_ptr(),
1073                                            proto_px.as_ptr(),
1074                                            num_protos,
1075                                        )
1076                                    };
1077                                    let correction = if zp_c != 0 {
1078                                        zp_c * proto_sums[py * proto_w + x0 + lx]
1079                                    } else {
1080                                        0
1081                                    };
1082                                    let logit = raw_dot - correction - bias;
1083                                    if logit > 0 {
1084                                        mask_buf[ly * roi_w + lx] = 255;
1085                                    }
1086                                }
1087                            }
1088                        } else {
1089                            for ly in 0..roi_h {
1090                                let py = y0 + ly;
1091                                let row_base = py * stride_y + x0 * num_protos;
1092                                for lx in 0..roi_w {
1093                                    let pix_base = row_base + lx * num_protos;
1094                                    let proto_px = &protos[pix_base..pix_base + num_protos];
1095                                    let raw_dot = unsafe {
1096                                        dot_i8_neon_base(
1097                                            coeff.as_ptr(),
1098                                            proto_px.as_ptr(),
1099                                            num_protos,
1100                                        )
1101                                    };
1102                                    let correction = if zp_c != 0 {
1103                                        zp_c * proto_sums[py * proto_w + x0 + lx]
1104                                    } else {
1105                                        0
1106                                    };
1107                                    let logit = raw_dot - correction - bias;
1108                                    if logit > 0 {
1109                                        mask_buf[ly * roi_w + lx] = 255;
1110                                    }
1111                                }
1112                            }
1113                        }
1114                    }
1115                    #[cfg(not(target_arch = "aarch64"))]
1116                    {
1117                        for ly in 0..roi_h {
1118                            let py = y0 + ly;
1119                            let row_base = py * stride_y + x0 * num_protos;
1120                            for lx in 0..roi_w {
1121                                let pix_base = row_base + lx * num_protos;
1122                                let proto_px = &protos[pix_base..pix_base + num_protos];
1123                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
1124                                let correction = if zp_c != 0 {
1125                                    zp_c * proto_sums[py * proto_w + x0 + lx]
1126                                } else {
1127                                    0
1128                                };
1129                                let logit = raw_dot - correction - bias;
1130                                if logit > 0 {
1131                                    mask_buf[ly * roi_w + lx] = 255;
1132                                }
1133                            }
1134                        }
1135                    }
1136                }
1137                edgefirst_decoder::ProtoLayout::Nchw => {
1138                    // Channel-major accumulation: for each channel, accumulate
1139                    // coeff[c] * proto[c, py, px] across the ROI. Each channel
1140                    // plane is contiguous, giving excellent sequential read access.
1141                    let mut accum = vec![0i32; roi_h * roi_w];
1142                    for c in 0..num_protos {
1143                        let plane = &protos[c * hw..];
1144                        let coeff_c = coeff[c] as i32;
1145                        for ly in 0..roi_h {
1146                            let py = y0 + ly;
1147                            let row_start = py * proto_w + x0;
1148                            let out_row_start = ly * roi_w;
1149                            for lx in 0..roi_w {
1150                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
1151                            }
1152                        }
1153                    }
1154                    // Apply zero-point correction and threshold.
1155                    for ly in 0..roi_h {
1156                        let py = y0 + ly;
1157                        for lx in 0..roi_w {
1158                            let idx = ly * roi_w + lx;
1159                            let correction = if zp_c != 0 {
1160                                zp_c * proto_sums[py * proto_w + x0 + lx]
1161                            } else {
1162                                0
1163                            };
1164                            let logit = accum[idx] - correction - bias;
1165                            if logit > 0 {
1166                                mask_buf[idx] = 255;
1167                            }
1168                        }
1169                    }
1170                }
1171            }
1172
1173            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1174                .expect("mask_buf length matches roi_h * roi_w");
1175            Ok(seg_from_roi(
1176                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
1177            ))
1178        })
1179        .collect()
1180}
1181
1182// =============================================================================
1183#[allow(clippy::too_many_arguments)]
1184fn proto_segmentations_i16_i8(
1185    detect: &[crate::DetectBox],
1186    coeff_all: &[i16],
1187    coeff_quant: &edgefirst_tensor::Quantization,
1188    protos: &[i8],
1189    proto_quant: &edgefirst_tensor::Quantization,
1190    proto_h: usize,
1191    proto_w: usize,
1192    num_protos: usize,
1193    lx0: f32,
1194    inv_lw: f32,
1195    ly0: f32,
1196    inv_lh: f32,
1197    layout: edgefirst_decoder::ProtoLayout,
1198) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1199    use edgefirst_tensor::QuantMode;
1200
1201    let _span = tracing::trace_span!(
1202        "image.materialize_masks.kernel_i16xi8",
1203        n = detect.len(),
1204        proto_h,
1205        proto_w,
1206        num_protos,
1207        ?layout,
1208    )
1209    .entered();
1210
1211    let zp_c: i32 = match coeff_quant.mode() {
1212        QuantMode::PerTensor { zero_point, .. } => zero_point,
1213        QuantMode::PerTensorSymmetric { .. } => 0,
1214        _ => {
1215            return Err(crate::Error::NotSupported(
1216                "per-channel coeff quantization not supported on proto-res i16 path".into(),
1217            ))
1218        }
1219    };
1220    let zp_p: i32 = match proto_quant.mode() {
1221        QuantMode::PerTensor { zero_point, .. } => zero_point,
1222        QuantMode::PerTensorSymmetric { .. } => 0,
1223        _ => {
1224            return Err(crate::Error::NotSupported(
1225                "per-channel proto quantization not supported on proto-res i8 path".into(),
1226            ))
1227        }
1228    };
1229
1230    let hw = proto_h * proto_w;
1231
1232    // Precompute per-pixel proto sums for zero-point correction.
1233    let proto_sums: Vec<i32> = if zp_c != 0 {
1234        match layout {
1235            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
1236                .map(|px_idx| {
1237                    let base = px_idx * num_protos;
1238                    protos[base..base + num_protos]
1239                        .iter()
1240                        .map(|&v| v as i32)
1241                        .sum()
1242                })
1243                .collect(),
1244            edgefirst_decoder::ProtoLayout::Nchw => {
1245                let mut sums = vec![0i32; hw];
1246                for c in 0..num_protos {
1247                    let plane = &protos[c * hw..];
1248                    for (px, s) in sums.iter_mut().enumerate() {
1249                        *s += plane[px] as i32;
1250                    }
1251                }
1252                sums
1253            }
1254        }
1255    } else {
1256        Vec::new()
1257    };
1258
1259    detect
1260        .par_iter()
1261        .enumerate()
1262        .map(|(i, det)| {
1263            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1264            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);
1265
1266            // Per-detection bias: zp_p·Σc_raw - N·zp_c·zp_p
1267            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
1268            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
1269
1270            let mut mask_buf = vec![0u8; roi_h * roi_w];
1271
1272            match layout {
1273                edgefirst_decoder::ProtoLayout::Nhwc => {
1274                    let stride_y = proto_w * num_protos;
1275                    #[cfg(target_arch = "aarch64")]
1276                    {
1277                        for ly in 0..roi_h {
1278                            let py = y0 + ly;
1279                            let row_base = py * stride_y + x0 * num_protos;
1280                            for lx in 0..roi_w {
1281                                let pix_base = row_base + lx * num_protos;
1282                                let proto_px = &protos[pix_base..pix_base + num_protos];
1283                                let raw_dot = unsafe {
1284                                    dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
1285                                };
1286                                let correction = if zp_c != 0 {
1287                                    zp_c * proto_sums[py * proto_w + x0 + lx]
1288                                } else {
1289                                    0
1290                                };
1291                                let logit = raw_dot - correction - bias;
1292                                if logit > 0 {
1293                                    mask_buf[ly * roi_w + lx] = 255;
1294                                }
1295                            }
1296                        }
1297                    }
1298                    #[cfg(not(target_arch = "aarch64"))]
1299                    {
1300                        for ly in 0..roi_h {
1301                            let py = y0 + ly;
1302                            let row_base = py * stride_y + x0 * num_protos;
1303                            for lx in 0..roi_w {
1304                                let pix_base = row_base + lx * num_protos;
1305                                let proto_px = &protos[pix_base..pix_base + num_protos];
1306                                let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
1307                                let correction = if zp_c != 0 {
1308                                    zp_c * proto_sums[py * proto_w + x0 + lx]
1309                                } else {
1310                                    0
1311                                };
1312                                let logit = raw_dot - correction - bias;
1313                                if logit > 0 {
1314                                    mask_buf[ly * roi_w + lx] = 255;
1315                                }
1316                            }
1317                        }
1318                    }
1319                }
1320                edgefirst_decoder::ProtoLayout::Nchw => {
1321                    // Channel-major accumulation: for each channel, accumulate
1322                    // coeff[c] * proto[c, py, px] across the ROI. Each channel
1323                    // plane is contiguous, giving excellent sequential read access.
1324                    let mut accum = vec![0i32; roi_h * roi_w];
1325                    for c in 0..num_protos {
1326                        let plane = &protos[c * hw..];
1327                        let coeff_c = coeff[c] as i32;
1328                        for ly in 0..roi_h {
1329                            let py = y0 + ly;
1330                            let row_start = py * proto_w + x0;
1331                            let out_row_start = ly * roi_w;
1332                            for lx in 0..roi_w {
1333                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
1334                            }
1335                        }
1336                    }
1337                    // Apply zero-point correction and threshold.
1338                    for ly in 0..roi_h {
1339                        let py = y0 + ly;
1340                        for lx in 0..roi_w {
1341                            let idx = ly * roi_w + lx;
1342                            let correction = if zp_c != 0 {
1343                                zp_c * proto_sums[py * proto_w + x0 + lx]
1344                            } else {
1345                                0
1346                            };
1347                            let logit = accum[idx] - correction - bias;
1348                            if logit > 0 {
1349                                mask_buf[idx] = 255;
1350                            }
1351                        }
1352                    }
1353                }
1354            }
1355
1356            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1357                .expect("mask_buf length matches roi_h * roi_w");
1358            Ok(seg_from_roi(
1359                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
1360            ))
1361        })
1362        .collect()
1363}
1364
1365// =============================================================================
1366
1367// Sign-threshold proto-resolution kernels (f32/f16/i8 protos with f32 coeffs).
1368//
1369// These replace the sigmoid-computing kernels for the non-i8×i8 fallback paths.
1370// Since downstream always thresholds at > 127, computing sigmoid is wasteful;
1371// sign(dot) > 0 ⟺ sigmoid(dot) > 0.5 gives the same binary result.
1372// =============================================================================
1373
1374/// f32 protos × f32 coefficients → sign threshold → binary {0, 255}.
1375#[allow(clippy::too_many_arguments)]
1376fn fused_dot_sign_f32_slice(
1377    protos: &[f32],
1378    coeff: &[f32],
1379    _proto_h: usize,
1380    proto_w: usize,
1381    y0: usize,
1382    x0: usize,
1383    roi_h: usize,
1384    roi_w: usize,
1385    num_protos: usize,
1386) -> ndarray::Array3<u8> {
1387    let stride_y = proto_w * num_protos;
1388    let mut mask_buf = vec![0u8; roi_h * roi_w];
1389    for y in 0..roi_h {
1390        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1391        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1392        for (x, out_px) in out_row.iter_mut().enumerate() {
1393            let base = row_base + x * num_protos;
1394            let mut acc = 0.0_f32;
1395            let mut k = 0;
1396            let chunks = num_protos / 4;
1397            for _ in 0..chunks {
1398                acc += coeff[k] * protos[base + k]
1399                    + coeff[k + 1] * protos[base + k + 1]
1400                    + coeff[k + 2] * protos[base + k + 2]
1401                    + coeff[k + 3] * protos[base + k + 3];
1402                k += 4;
1403            }
1404            while k < num_protos {
1405                acc += coeff[k] * protos[base + k];
1406                k += 1;
1407            }
1408            if acc > 0.0 {
1409                *out_px = 255;
1410            }
1411        }
1412    }
1413    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1414        .expect("mask_buf length matches roi_h * roi_w")
1415}
1416
1417/// f16 protos × f32 coefficients → sign threshold → binary {0, 255}.
1418///
1419/// Two code paths:
1420///
1421/// 1. **x86_64 + F16C + FMA** — explicit intrinsic kernel using
1422///    `_mm256_cvtph_ps` (8-lane f16→f32 widening) + `_mm256_fmadd_ps`.
1423///
1424/// 2. **Scalar fallback** — loop-unrolled by 4 with `half::f16::to_f32()`.
1425#[allow(clippy::too_many_arguments)]
1426fn fused_dot_sign_f16_slice(
1427    protos: &[half::f16],
1428    coeff: &[f32],
1429    _proto_h: usize,
1430    proto_w: usize,
1431    y0: usize,
1432    x0: usize,
1433    roi_h: usize,
1434    roi_w: usize,
1435    num_protos: usize,
1436) -> ndarray::Array3<u8> {
1437    #[cfg(all(
1438        target_arch = "x86_64",
1439        target_feature = "f16c",
1440        target_feature = "fma"
1441    ))]
1442    {
1443        // SAFETY: target-feature gates guarantee F16C + FMA support.
1444        unsafe {
1445            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
1446        }
1447    }
1448    #[cfg(not(all(
1449        target_arch = "x86_64",
1450        target_feature = "f16c",
1451        target_feature = "fma"
1452    )))]
1453    {
1454        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
1455    }
1456}
1457
1458/// Scalar f16 sign-threshold kernel — loop-unrolled by 4.
1459#[allow(clippy::too_many_arguments)]
1460fn fused_dot_sign_f16_slice_scalar(
1461    protos: &[half::f16],
1462    coeff: &[f32],
1463    proto_w: usize,
1464    y0: usize,
1465    x0: usize,
1466    roi_h: usize,
1467    roi_w: usize,
1468    num_protos: usize,
1469) -> ndarray::Array3<u8> {
1470    let stride_y = proto_w * num_protos;
1471    let mut mask_buf = vec![0u8; roi_h * roi_w];
1472    for y in 0..roi_h {
1473        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1474        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1475        for (x, out_px) in out_row.iter_mut().enumerate() {
1476            let base = row_base + x * num_protos;
1477            let mut acc = 0.0_f32;
1478            let mut k = 0;
1479            let chunks = num_protos / 4;
1480            for _ in 0..chunks {
1481                acc += coeff[k] * protos[base + k].to_f32()
1482                    + coeff[k + 1] * protos[base + k + 1].to_f32()
1483                    + coeff[k + 2] * protos[base + k + 2].to_f32()
1484                    + coeff[k + 3] * protos[base + k + 3].to_f32();
1485                k += 4;
1486            }
1487            while k < num_protos {
1488                acc += coeff[k] * protos[base + k].to_f32();
1489                k += 1;
1490            }
1491            if acc > 0.0 {
1492                *out_px = 255;
1493            }
1494        }
1495    }
1496    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1497        .expect("mask_buf length matches roi_h * roi_w")
1498}
1499
/// x86_64 F16C + FMA intrinsic kernel for f16 sign-threshold.
///
/// Uses `_mm256_cvtph_ps` for hardware 8-lane f16→f32 widening and
/// `_mm256_fmadd_ps` for fused multiply-add. Only the sign of the
/// accumulated dot product is checked (no sigmoid needed). Channels past the
/// last full group of 8 go through a scalar tail.
///
/// Returns a `(roi_h, roi_w, 1)` mask: 255 where the dot product is strictly
/// positive, 0 elsewhere.
///
/// # Panics
///
/// Panics if `coeff` has fewer than `num_protos` elements, or if any ROI
/// pixel's channel vector lies outside `protos`. This matches the
/// bounds-checked scalar kernel instead of reading out of bounds.
///
/// # Safety
///
/// Caller must ensure the target CPU supports F16C, FMA and AVX (the
/// features enabled by `#[target_feature]` below).
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        __m128i, _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps,
        _mm256_fmadd_ps, _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32,
        _mm_hadd_ps, _mm_loadu_si128,
    };

    // The vector loop reads `coeff` through a raw pointer. Slicing here turns
    // a too-short coefficient vector into a panic rather than an
    // out-of-bounds read.
    let coeff = &coeff[..num_protos];
    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            // Bounds-check this pixel's channel vector once. Every 8-lane load
            // below covers `k..k + 8` with `k + 8 <= num_protos`, so it stays
            // inside `px` (and `coeff`).
            let px = &protos[base..base + num_protos];
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // 8 × f16 = 16 bytes → one unaligned 128-bit load.
                let raw = _mm_loadu_si128(px.as_ptr().add(k).cast::<__m128i>());
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduce 8 → 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos % 8.
            while k < num_protos {
                acc += coeff[k] * px[k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1576
1577/// i8 protos (with quant) × f32 coefficients → sign threshold → binary {0, 255}.
1578/// Fallback for per-channel quant or mixed-dtype cases where the i8×i8 fast path
1579/// doesn't apply.
1580#[allow(clippy::too_many_arguments)]
1581fn fused_dequant_dot_sign_i8_slice(
1582    protos: &[i8],
1583    coeff: &[f32],
1584    quant: &edgefirst_tensor::Quantization,
1585    _proto_h: usize,
1586    proto_w: usize,
1587    y0: usize,
1588    x0: usize,
1589    roi_h: usize,
1590    roi_w: usize,
1591    num_protos: usize,
1592) -> crate::Result<ndarray::Array3<u8>> {
1593    use edgefirst_tensor::QuantMode;
1594    let stride_y = proto_w * num_protos;
1595
1596    // Precompute scaled coefficients + zp_offset (same as the old sigmoid kernel).
1597    let mut stack_scratch = [0.0_f32; 64];
1598    let mut heap_scratch: Vec<f32>;
1599    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
1600        &mut stack_scratch[..num_protos]
1601    } else {
1602        heap_scratch = vec![0.0_f32; num_protos];
1603        heap_scratch.as_mut_slice()
1604    };
1605    let zp_offset: f32;
1606    match quant.mode() {
1607        QuantMode::PerTensorSymmetric { scale } => {
1608            for k in 0..num_protos {
1609                scaled_coeff[k] = coeff[k] * scale;
1610            }
1611            zp_offset = 0.0;
1612        }
1613        QuantMode::PerTensor { scale, zero_point } => {
1614            for k in 0..num_protos {
1615                scaled_coeff[k] = coeff[k] * scale;
1616            }
1617            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
1618        }
1619        QuantMode::PerChannelSymmetric { scales, axis } => {
1620            if axis != 2 {
1621                return Err(crate::Error::NotSupported(format!(
1622                    "per-channel quantization on axis {axis} not supported \
1623                     (only channel axis 2 is implemented on this kernel)"
1624                )));
1625            }
1626            for k in 0..num_protos {
1627                scaled_coeff[k] = coeff[k] * scales[k];
1628            }
1629            zp_offset = 0.0;
1630        }
1631        QuantMode::PerChannel {
1632            scales,
1633            zero_points,
1634            axis,
1635        } => {
1636            if axis != 2 {
1637                return Err(crate::Error::NotSupported(format!(
1638                    "per-channel quantization on axis {axis} not supported \
1639                     (only channel axis 2 is implemented on this kernel)"
1640                )));
1641            }
1642            for k in 0..num_protos {
1643                scaled_coeff[k] = coeff[k] * scales[k];
1644            }
1645            zp_offset = (0..num_protos)
1646                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
1647                .sum();
1648        }
1649    }
1650
1651    let mut mask_buf = vec![0u8; roi_h * roi_w];
1652    for y in 0..roi_h {
1653        let row_base = (y0 + y) * stride_y + (x0) * num_protos;
1654        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1655        for (x, out_px) in out_row.iter_mut().enumerate() {
1656            let base = row_base + x * num_protos;
1657            let mut acc = 0.0_f32;
1658            let mut k = 0;
1659            let chunks = num_protos / 4;
1660            for _ in 0..chunks {
1661                let p0 = protos[base + k] as f32;
1662                let p1 = protos[base + k + 1] as f32;
1663                let p2 = protos[base + k + 2] as f32;
1664                let p3 = protos[base + k + 3] as f32;
1665                acc += scaled_coeff[k] * p0
1666                    + scaled_coeff[k + 1] * p1
1667                    + scaled_coeff[k + 2] * p2
1668                    + scaled_coeff[k + 3] * p3;
1669                k += 4;
1670            }
1671            while k < num_protos {
1672                acc += scaled_coeff[k] * protos[base + k] as f32;
1673                k += 1;
1674            }
1675            if acc > zp_offset {
1676                *out_px = 255;
1677            }
1678        }
1679    }
1680    Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1681        .expect("mask_buf length matches roi_h * roi_w"))
1682}
1683
1684#[allow(clippy::too_many_arguments)]
1685fn scaled_segmentations_f32_slice(
1686    detect: &[crate::DetectBox],
1687    coeff_all: &[f32],
1688    protos: &[f32],
1689    proto_h: usize,
1690    proto_w: usize,
1691    num_protos: usize,
1692    letterbox: Option<[f32; 4]>,
1693    width: u32,
1694    height: u32,
1695) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1696    scaled_run(
1697        detect,
1698        coeff_all,
1699        protos,
1700        proto_h,
1701        proto_w,
1702        num_protos,
1703        letterbox,
1704        width,
1705        height,
1706        1.0,
1707        |p, _| *p,
1708    )
1709}
1710
1711#[allow(clippy::too_many_arguments)]
1712fn scaled_segmentations_f16_slice(
1713    detect: &[crate::DetectBox],
1714    coeff_all: &[f32],
1715    protos: &[half::f16],
1716    proto_h: usize,
1717    proto_w: usize,
1718    num_protos: usize,
1719    letterbox: Option<[f32; 4]>,
1720    width: u32,
1721    height: u32,
1722) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1723    scaled_run(
1724        detect,
1725        coeff_all,
1726        protos,
1727        proto_h,
1728        proto_w,
1729        num_protos,
1730        letterbox,
1731        width,
1732        height,
1733        1.0,
1734        |p: &half::f16, _| p.to_f32(),
1735    )
1736}
1737
1738#[allow(clippy::too_many_arguments)]
1739fn scaled_segmentations_i8_slice(
1740    detect: &[crate::DetectBox],
1741    coeff_all: &[f32],
1742    protos: &[i8],
1743    proto_h: usize,
1744    proto_w: usize,
1745    num_protos: usize,
1746    quant: &edgefirst_tensor::Quantization,
1747    letterbox: Option<[f32; 4]>,
1748    width: u32,
1749    height: u32,
1750) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1751    use edgefirst_tensor::QuantMode;
1752    // Only per-tensor quantization supported on the scaled-path CPU kernel
1753    // today. Per-channel fits naturally into a future extension (would need
1754    // per-channel scaled coefficients in scaled_run's dot-product precompute).
1755    let (scale, zp) = match quant.mode() {
1756        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1757        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1758        QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1759            return Err(crate::Error::NotSupported(format!(
1760                "per-channel quantization (axis={axis}) on scaled seg path \
1761                 not yet supported"
1762            )));
1763        }
1764    };
1765    scaled_run(
1766        detect,
1767        coeff_all,
1768        protos,
1769        proto_h,
1770        proto_w,
1771        num_protos,
1772        letterbox,
1773        width,
1774        height,
1775        scale,
1776        move |p: &i8, _| *p as f32 - zp,
1777    )
1778}
1779
1780// =============================================================================
1781// Integer-domain kernel: i8 coefficients × i8 protos → i32 → sign threshold.
1782//
1783// Eliminates all f32 conversion by working directly with raw quantized values.
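// The math (s_c and s_p are the per-tensor dequant scales, assumed positive, so
// their common factor s_c·s_p drops out of the sign):
//   sign(dot(dequant(coeff), dequant(proto)))
//   = sign(s_c·s_p · Σ (c_raw - zp_c) · (p_raw - zp_p))
//   = sign(Σ (c_raw - zp_c) · (p_raw - zp_p))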
1787//   = sign(Σ c_raw·p_raw - zp_c·Σp_raw - zp_p·Σc_raw + N·zp_c·zp_p)
1788//   = sign(sdot(c_raw, p_raw) - zp_c·proto_sum[pixel] - bias_per_det)
1789//
1790// where bias_per_det = zp_p·Σc_raw - N·zp_c·zp_p  (precomputed once per det).
1791// =============================================================================
1792
/// Compute the i8×i8 dot product of the first `n` elements → i32.
///
/// Platform-agnostic scalar fallback. `n` is arbitrary; it is not limited to
/// 32. Products are summed four at a time (each group of four is summed, then
/// added to the running total), followed by a tail of up to three single
/// products.
///
/// # Panics
///
/// Panics if `coeff` or `proto` has fewer than `n` elements.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
    // Slice once so the bounds check happens here rather than per element.
    let (coeff, proto) = (&coeff[..n], &proto[..n]);
    let coeff_quads = coeff.chunks_exact(4);
    let proto_quads = proto.chunks_exact(4);
    let (coeff_tail, proto_tail) = (coeff_quads.remainder(), proto_quads.remainder());

    let mut acc: i32 = 0;
    for (c, p) in coeff_quads.zip(proto_quads) {
        acc += i32::from(c[0]) * i32::from(p[0])
            + i32::from(c[1]) * i32::from(p[1])
            + i32::from(c[2]) * i32::from(p[2])
            + i32::from(c[3]) * i32::from(p[3]);
    }
    for (&c, &p) in coeff_tail.iter().zip(proto_tail) {
        acc += i32::from(c) * i32::from(p);
    }
    acc
}
1814
/// NEON i8×i8→i32 dot product using smull+sadalp (works on ALL aarch64, A53+).
///
/// The main loop handles 16 lanes per iteration. `vmull_s8`/`vmull_high_s8`
/// widen the products to i16; an i8×i8 product is at most 16384 in
/// magnitude, so it cannot overflow i16. `vpadalq_s16` then pairwise-adds
/// them into four i32 accumulators. A remaining group of 8 lanes is handled
/// the same way. Any final elements are summed in scalar code after the
/// horizontal add.
///
/// # Safety
///
/// `coeff` and `proto` must each be valid for reads of `n` consecutive `i8`
/// values. The function reads exactly indices `0..n` of each.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Widening multiply + pairwise accumulate (all aarch64).
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    // Handle remaining elements (for num_protos=32, full_chunks=2, remainder=0)
    let remainder = n - offset;
    if remainder >= 8 {
        // One 8-lane half-vector step; reads offset..offset + 8 <= n.
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Horizontal add of the four i32 lanes, then the scalar tail (< 8 elements).
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1849
/// NEON i8×i8→i32 dot product using sdot (ARMv8.2-A dotprod, A55+).
/// Each `sdot` processes 16 i8 lanes → 4 i32 partial sums in one instruction,
/// replacing the 3-instruction smull+smull2+sadalp sequence.
///
/// Elements past the last full group of 16 are summed in scalar code after
/// the horizontal add.
///
/// # Safety
///
/// - `coeff` and `proto` must each be valid for reads of `n` consecutive `i8`
///   values. The function reads exactly indices `0..n` of each.
/// - The running CPU must implement the dotprod extension.
///   `.arch_extension dotprod` only makes the assembler accept `sdot`; it
///   does not make the instruction available at runtime.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Enable dotprod extension locally so the assembler accepts sdot
        // even when compiling for baseline aarch64 (A53). At runtime we only
        // reach this path when HWCAP confirms dotprod support.
        let result: int32x4_t;
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    let mut scalar_acc = vaddvq_s32(acc);
    // Tail: handle remainder (unlikely for num_protos=32, but correct)
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1886
/// Scalar i16×i8 dot product over the first `n` elements, accumulated in i32.
///
/// Platform-agnostic fallback. Products are summed in groups of four (each
/// group summed, then added to the running total), followed by a tail of up
/// to three single products.
///
/// # Panics
///
/// Panics if `coeff` or `proto` has fewer than `n` elements.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i16_i8_scalar(coeff: &[i16], proto: &[i8], n: usize) -> i32 {
    let lhs = &coeff[..n];
    let rhs = &proto[..n];
    let lhs_quads = lhs.chunks_exact(4);
    let rhs_quads = rhs.chunks_exact(4);
    let (lhs_tail, rhs_tail) = (lhs_quads.remainder(), rhs_quads.remainder());

    let mut total: i32 = 0;
    for (c, p) in lhs_quads.zip(rhs_quads) {
        total += i32::from(c[0]) * i32::from(p[0])
            + i32::from(c[1]) * i32::from(p[1])
            + i32::from(c[2]) * i32::from(p[2])
            + i32::from(c[3]) * i32::from(p[3]);
    }
    for (&c, &p) in lhs_tail.iter().zip(rhs_tail) {
        total += i32::from(c) * i32::from(p);
    }
    total
}
1907
/// NEON i16×i8→i32 dot product using widening multiply-accumulate.
/// Processes 8 elements per iteration (vs 16 for i8×i8 dotprod).
///
/// The i8 protos are sign-extended to i16 (`vmovl_s8`). Both halves are then
/// multiply-accumulated into four i32 lanes (`vmlal_s16` / `vmlal_high_s16`).
/// Elements past the last full group of 8 are summed in scalar code.
///
/// # Safety
///
/// `coeff` must be valid for reads of `n` consecutive `i16` values and
/// `proto` for reads of `n` consecutive `i8` values. The function reads
/// exactly indices `0..n` of each.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i16_i8_neon(coeff: *const i16, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 8;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s16(coeff.add(offset));
        let p_raw = vld1_s8(proto.add(offset));
        // Sign-extend 8 × i8 → 8 × i16 so both operands share a lane width.
        let p = vmovl_s8(p_raw);
        acc = vmlal_s16(acc, vget_low_s16(c), vget_low_s16(p));
        acc = vmlal_high_s16(acc, c, p);
        offset += 8;
    }
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1932
/// Compute the logit grid using the dotprod (sdot) path.
/// Separated into its own function so the compiler inlines the sdot asm fully.
///
/// For every proto-ROI pixel `(lx_idx, ly_idx)` this writes
///
/// `dot(coeff, proto_px) - zp_c * proto_sums[pixel] - bias`
///
/// to `logits[ly_idx * roi_w + lx_idx]`. Here `proto_px` is the channel vector
/// of proto pixel `(proto_x0 + lx_idx, proto_y0 + ly_idx)` in the NHWC
/// `protos` buffer with row stride `stride_y`. `proto_sums` is indexed by
/// `py * proto_w + px` and is only read when `zp_c != 0`.
///
/// The running CPU must support the dotprod extension. The caller is expected
/// to check this at runtime before choosing this path, and debug builds
/// re-check it here.
///
/// # Panics
///
/// Panics if `coeff` has fewer than `num_protos` elements, or if any ROI index
/// falls outside `protos`, `proto_sums` or `logits`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // `dot_i8_neon_dotprod` executes `sdot`, which faults on cores without
    // the dotprod extension.
    debug_assert!(std::arch::is_aarch64_feature_detected!("dotprod"));
    // The unsafe kernel reads `num_protos` elements through `coeff.as_ptr()`.
    // Slicing here makes a short coefficient slice panic instead of reading
    // out of bounds.
    let coeff = &coeff[..num_protos];
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: `coeff` and `proto_px` are both exactly `num_protos`
            // long and the kernel reads only indices `0..num_protos` of each;
            // dotprod support is established by the caller.
            let raw_dot =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
1970
/// Compute the logit grid using the base NEON path (smull+sadalp).
/// Separated into its own function so the compiler inlines the NEON code fully.
///
/// For every proto-ROI pixel `(lx_idx, ly_idx)` this writes
///
/// `dot(coeff, proto_px) - zp_c * proto_sums[pixel] - bias`
///
/// to `logits[ly_idx * roi_w + lx_idx]`. Here `proto_px` is the channel vector
/// of proto pixel `(proto_x0 + lx_idx, proto_y0 + ly_idx)` in the NHWC
/// `protos` buffer with row stride `stride_y`. `proto_sums` is indexed by
/// `py * proto_w + px` and is only read when `zp_c != 0`.
///
/// # Panics
///
/// Panics if `coeff` has fewer than `num_protos` elements, or if any ROI index
/// falls outside `protos`, `proto_sums` or `logits`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // The unsafe kernel reads `num_protos` elements through `coeff.as_ptr()`.
    // Slicing here makes a short coefficient slice panic instead of reading
    // out of bounds.
    let coeff = &coeff[..num_protos];
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: `coeff` and `proto_px` are both exactly `num_protos`
            // long and the kernel reads only indices `0..num_protos` of each;
            // smull/sadalp are baseline aarch64 NEON.
            let raw_dot =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
2008
2009#[allow(clippy::too_many_arguments)]
2010fn scaled_segmentations_i8_i8(
2011    detect: &[crate::DetectBox],
2012    coeff_all: &[i8],
2013    coeff_quant: &edgefirst_tensor::Quantization,
2014    protos: &[i8],
2015    proto_quant: &edgefirst_tensor::Quantization,
2016    proto_h: usize,
2017    proto_w: usize,
2018    num_protos: usize,
2019    letterbox: Option<[f32; 4]>,
2020    width: u32,
2021    height: u32,
2022    layout: edgefirst_decoder::ProtoLayout,
2023) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2024    use edgefirst_tensor::QuantMode;
2025
2026    let _span = tracing::trace_span!(
2027        "image.materialize_masks.kernel_i8_scaled",
2028        n = detect.len(),
2029        proto_h,
2030        proto_w,
2031        num_protos,
2032        width,
2033        height,
2034        ?layout,
2035    )
2036    .entered();
2037
2038    let zp_c: i32 = match coeff_quant.mode() {
2039        QuantMode::PerTensor { zero_point, .. } => zero_point,
2040        QuantMode::PerTensorSymmetric { .. } => 0,
2041        _ => {
2042            return Err(crate::Error::NotSupported(
2043                "per-channel coeff quantization not supported".into(),
2044            ))
2045        }
2046    };
2047    let zp_p: i32 = match proto_quant.mode() {
2048        QuantMode::PerTensor { zero_point, .. } => zero_point,
2049        QuantMode::PerTensorSymmetric { .. } => 0,
2050        _ => {
2051            return Err(crate::Error::NotSupported(
2052                "per-channel proto quantization not supported".into(),
2053            ))
2054        }
2055    };
2056
2057    let (lx0, lw, ly0, lh) = match letterbox {
2058        Some([lx0, ly0, lx1, ly1]) => {
2059            let lw = (lx1 - lx0).max(f32::EPSILON);
2060            let lh = (ly1 - ly0).max(f32::EPSILON);
2061            (lx0, lw, ly0, lh)
2062        }
2063        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2064    };
2065    let out_w = width as usize;
2066    let out_h = height as usize;
2067    let hw = proto_h * proto_w;
2068
2069    // Precompute proto_sum for the entire proto tensor (zero-point correction).
2070    let proto_sums: Vec<i32> = if zp_c != 0 {
2071        match layout {
2072            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
2073                .map(|px_idx| {
2074                    let base = px_idx * num_protos;
2075                    let mut s: i32 = 0;
2076                    for k in 0..num_protos {
2077                        s += protos[base + k] as i32;
2078                    }
2079                    s
2080                })
2081                .collect(),
2082            edgefirst_decoder::ProtoLayout::Nchw => {
2083                let mut sums = vec![0i32; hw];
2084                for c in 0..num_protos {
2085                    let plane = &protos[c * hw..];
2086                    for (px, s) in sums.iter_mut().enumerate() {
2087                        *s += plane[px] as i32;
2088                    }
2089                }
2090                sums
2091            }
2092        }
2093    } else {
2094        Vec::new()
2095    };
2096
2097    // Detect dotprod support once, outside the hot loop.
2098    #[cfg(target_arch = "aarch64")]
2099    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");
2100
2101    // For NHWC layout, stride for row navigation.
2102    let stride_y = proto_w * num_protos;
2103
2104    detect
2105        .par_iter()
2106        .enumerate()
2107        .map(|(i, det)| {
2108            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2109            let bbox = det.bbox.to_canonical();
2110            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2111            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2112            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2113            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2114            let px0 = (xmin * out_w as f32).round() as usize;
2115            let py0 = (ymin * out_h as f32).round() as usize;
2116            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2117            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2118            let bbox_w = px1.saturating_sub(px0).max(1);
2119            let bbox_h = py1.saturating_sub(py0).max(1);
2120
2121            // Map output bbox → proto ROI.
2122            let sample_x_at = |px: f32| -> f32 {
2123                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2124                model_x_norm * proto_w as f32 - 0.5
2125            };
2126            let sample_y_at = |py: f32| -> f32 {
2127                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2128                model_y_norm * proto_h as f32 - 0.5
2129            };
2130            let s_x_min = sample_x_at(px0 as f32);
2131            let s_x_max = sample_x_at((px1 as f32) - 1.0);
2132            let s_y_min = sample_y_at(py0 as f32);
2133            let s_y_max = sample_y_at((py1 as f32) - 1.0);
2134            let proto_x0 = (s_x_min.floor() as isize)
2135                .max(0)
2136                .min(proto_w.saturating_sub(1) as isize) as usize;
2137            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2138            let proto_y0 = (s_y_min.floor() as isize)
2139                .max(0)
2140                .min(proto_h.saturating_sub(1) as isize) as usize;
2141            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2142            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2143            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2144
2145            // Per-detection bias.
2146            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
2147            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
2148
2149            // Step 2: Compute i32 logits at each proto-ROI pixel.
2150            let mut logits = vec![0_i32; roi_h * roi_w];
2151            match layout {
2152                edgefirst_decoder::ProtoLayout::Nhwc => {
2153                    #[cfg(target_arch = "aarch64")]
2154                    {
2155                        if use_dotprod {
2156                            compute_logits_dotprod(
2157                                &mut logits,
2158                                coeff,
2159                                protos,
2160                                &proto_sums,
2161                                proto_w,
2162                                proto_x0,
2163                                proto_y0,
2164                                roi_w,
2165                                roi_h,
2166                                stride_y,
2167                                num_protos,
2168                                zp_c,
2169                                bias,
2170                            );
2171                        } else {
2172                            compute_logits_base(
2173                                &mut logits,
2174                                coeff,
2175                                protos,
2176                                &proto_sums,
2177                                proto_w,
2178                                proto_x0,
2179                                proto_y0,
2180                                roi_w,
2181                                roi_h,
2182                                stride_y,
2183                                num_protos,
2184                                zp_c,
2185                                bias,
2186                            );
2187                        }
2188                    }
2189                    #[cfg(not(target_arch = "aarch64"))]
2190                    {
2191                        for ly_idx in 0..roi_h {
2192                            let py = proto_y0 + ly_idx;
2193                            let row_base = py * stride_y + proto_x0 * num_protos;
2194                            for lx_idx in 0..roi_w {
2195                                let pix_base = row_base + lx_idx * num_protos;
2196                                let proto_px = &protos[pix_base..pix_base + num_protos];
2197                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
2198                                let correction = if zp_c != 0 {
2199                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2200                                } else {
2201                                    0
2202                                };
2203                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2204                            }
2205                        }
2206                    }
2207                }
2208                edgefirst_decoder::ProtoLayout::Nchw => {
2209                    // Channel-major accumulation: contiguous reads per channel plane.
2210                    for c in 0..num_protos {
2211                        let plane = &protos[c * hw..];
2212                        let coeff_c = coeff[c] as i32;
2213                        for ly_idx in 0..roi_h {
2214                            let py = proto_y0 + ly_idx;
2215                            let row_start = py * proto_w + proto_x0;
2216                            let out_row_start = ly_idx * roi_w;
2217                            for lx_idx in 0..roi_w {
2218                                logits[out_row_start + lx_idx] +=
2219                                    coeff_c * plane[row_start + lx_idx] as i32;
2220                            }
2221                        }
2222                    }
2223                    // Apply zero-point correction and per-detection bias.
2224                    for ly_idx in 0..roi_h {
2225                        let py = proto_y0 + ly_idx;
2226                        for lx_idx in 0..roi_w {
2227                            let idx = ly_idx * roi_w + lx_idx;
2228                            let correction = if zp_c != 0 {
2229                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2230                            } else {
2231                                0
2232                            };
2233                            logits[idx] -= correction + bias;
2234                        }
2235                    }
2236                }
2237            }
2238
2239            // Step 3: Bilinear upsample i32 logits → binary mask with
2240            // sign-shortcut (skip interpolation when all 4 neighbors agree).
2241            let roi_last_x = roi_w.saturating_sub(1);
2242            let roi_last_y = roi_h.saturating_sub(1);
2243
2244            // X-coordinate LUT with fixed-point fraction (scale 1024).
2245            const FRAC_BITS: i32 = 10;
2246            const FRAC_SCALE: i32 = 1 << FRAC_BITS; // 1024
2247            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
2248                .map(|xi| {
2249                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2250                    let x_floor = sample_x.floor();
2251                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
2252                    let x_hi = (x_lo + 1).min(roi_w - 1);
2253                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2254                    (x_lo, x_hi, x_frac)
2255                })
2256                .collect();
2257
2258            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2259            for yi in 0..bbox_h {
2260                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2261                let y_floor = sample_y.floor();
2262                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2263                let y_hi = (y_lo + 1).min(roi_h - 1);
2264                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2265                let y_frac_inv = FRAC_SCALE - y_frac;
2266                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2267                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2268                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2269
2270                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2271                    let tl = row_lo[x_lo];
2272                    let tr = row_lo[x_hi];
2273                    let bl = row_hi[x_lo];
2274                    let br = row_hi[x_hi];
2275
2276                    // Sign-shortcut: if all 4 corners have the same sign,
2277                    // the bilinear interpolation (positive-weight combination)
2278                    // preserves that sign. Skip arithmetic for ~80% of pixels.
2279                    if (tl & tr & bl & br) < 0 {
2280                        // All negative → output 0 (already zero).
2281                        continue;
2282                    }
2283                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
2284                        // All strictly positive → output 255.
2285                        out_row[xi] = 255;
2286                        continue;
2287                    }
2288
2289                    // Boundary pixel: fixed-point bilinear in i64.
2290                    let x_frac_inv = FRAC_SCALE - x_frac;
2291                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
2292                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
2293                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
2294                    out_row[xi] = if logit > 0 { 255 } else { 0 };
2295                }
2296            }
2297
2298            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2299                .expect("tile_buf length matches bbox_h * bbox_w");
2300            Ok(edgefirst_decoder::Segmentation {
2301                xmin,
2302                ymin,
2303                xmax,
2304                ymax,
2305                segmentation: tile,
2306            })
2307        })
2308        .collect()
2309}
2310
2311#[allow(clippy::too_many_arguments)]
2312fn scaled_segmentations_i16_i8(
2313    detect: &[crate::DetectBox],
2314    coeff_all: &[i16],
2315    coeff_quant: &edgefirst_tensor::Quantization,
2316    protos: &[i8],
2317    proto_quant: &edgefirst_tensor::Quantization,
2318    proto_h: usize,
2319    proto_w: usize,
2320    num_protos: usize,
2321    letterbox: Option<[f32; 4]>,
2322    width: u32,
2323    height: u32,
2324    layout: edgefirst_decoder::ProtoLayout,
2325) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2326    use edgefirst_tensor::QuantMode;
2327
2328    let _span = tracing::trace_span!(
2329        "image.materialize_masks.kernel_i16xi8_scaled",
2330        n = detect.len(),
2331        proto_h,
2332        proto_w,
2333        num_protos,
2334        width,
2335        height,
2336        ?layout,
2337    )
2338    .entered();
2339
2340    let zp_c: i32 = match coeff_quant.mode() {
2341        QuantMode::PerTensor { zero_point, .. } => zero_point,
2342        QuantMode::PerTensorSymmetric { .. } => 0,
2343        _ => {
2344            return Err(crate::Error::NotSupported(
2345                "per-channel coeff quantization not supported".into(),
2346            ))
2347        }
2348    };
2349    let zp_p: i32 = match proto_quant.mode() {
2350        QuantMode::PerTensor { zero_point, .. } => zero_point,
2351        QuantMode::PerTensorSymmetric { .. } => 0,
2352        _ => {
2353            return Err(crate::Error::NotSupported(
2354                "per-channel proto quantization not supported".into(),
2355            ))
2356        }
2357    };
2358
2359    let (lx0, lw, ly0, lh) = match letterbox {
2360        Some([lx0, ly0, lx1, ly1]) => {
2361            let lw = (lx1 - lx0).max(f32::EPSILON);
2362            let lh = (ly1 - ly0).max(f32::EPSILON);
2363            (lx0, lw, ly0, lh)
2364        }
2365        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2366    };
2367    let out_w = width as usize;
2368    let out_h = height as usize;
2369    let hw = proto_h * proto_w;
2370
2371    // Precompute proto_sum for the entire proto tensor (zero-point correction).
2372    let proto_sums: Vec<i32> = if zp_c != 0 {
2373        match layout {
2374            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
2375                .map(|px_idx| {
2376                    let base = px_idx * num_protos;
2377                    let mut s: i32 = 0;
2378                    for k in 0..num_protos {
2379                        s += protos[base + k] as i32;
2380                    }
2381                    s
2382                })
2383                .collect(),
2384            edgefirst_decoder::ProtoLayout::Nchw => {
2385                let mut sums = vec![0i32; hw];
2386                for c in 0..num_protos {
2387                    let plane = &protos[c * hw..];
2388                    for (px, s) in sums.iter_mut().enumerate() {
2389                        *s += plane[px] as i32;
2390                    }
2391                }
2392                sums
2393            }
2394        }
2395    } else {
2396        Vec::new()
2397    };
2398
2399    // For NHWC layout, stride for row navigation.
2400    let stride_y = proto_w * num_protos;
2401
2402    detect
2403        .par_iter()
2404        .enumerate()
2405        .map(|(i, det)| {
2406            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2407            let bbox = det.bbox.to_canonical();
2408            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2409            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2410            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2411            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2412            let px0 = (xmin * out_w as f32).round() as usize;
2413            let py0 = (ymin * out_h as f32).round() as usize;
2414            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2415            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2416            let bbox_w = px1.saturating_sub(px0).max(1);
2417            let bbox_h = py1.saturating_sub(py0).max(1);
2418
2419            // Map output bbox → proto ROI.
2420            let sample_x_at = |px: f32| -> f32 {
2421                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2422                model_x_norm * proto_w as f32 - 0.5
2423            };
2424            let sample_y_at = |py: f32| -> f32 {
2425                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2426                model_y_norm * proto_h as f32 - 0.5
2427            };
2428            let s_x_min = sample_x_at(px0 as f32);
2429            let s_x_max = sample_x_at((px1 as f32) - 1.0);
2430            let s_y_min = sample_y_at(py0 as f32);
2431            let s_y_max = sample_y_at((py1 as f32) - 1.0);
2432            let proto_x0 = (s_x_min.floor() as isize)
2433                .max(0)
2434                .min(proto_w.saturating_sub(1) as isize) as usize;
2435            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2436            let proto_y0 = (s_y_min.floor() as isize)
2437                .max(0)
2438                .min(proto_h.saturating_sub(1) as isize) as usize;
2439            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2440            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2441            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2442
2443            // Per-detection bias.
2444            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
2445            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
2446
2447            // Step 2: Compute i32 logits at each proto-ROI pixel.
2448            let mut logits = vec![0_i32; roi_h * roi_w];
2449            match layout {
2450                edgefirst_decoder::ProtoLayout::Nhwc => {
2451                    #[cfg(target_arch = "aarch64")]
2452                    {
2453                        for ly_idx in 0..roi_h {
2454                            let py = proto_y0 + ly_idx;
2455                            let row_base = py * stride_y + proto_x0 * num_protos;
2456                            for lx_idx in 0..roi_w {
2457                                let pix_base = row_base + lx_idx * num_protos;
2458                                let proto_px = &protos[pix_base..pix_base + num_protos];
2459                                let raw_dot = unsafe {
2460                                    dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
2461                                };
2462                                let correction = if zp_c != 0 {
2463                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2464                                } else {
2465                                    0
2466                                };
2467                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2468                            }
2469                        }
2470                    }
2471                    #[cfg(not(target_arch = "aarch64"))]
2472                    {
2473                        for ly_idx in 0..roi_h {
2474                            let py = proto_y0 + ly_idx;
2475                            let row_base = py * stride_y + proto_x0 * num_protos;
2476                            for lx_idx in 0..roi_w {
2477                                let pix_base = row_base + lx_idx * num_protos;
2478                                let proto_px = &protos[pix_base..pix_base + num_protos];
2479                                let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
2480                                let correction = if zp_c != 0 {
2481                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2482                                } else {
2483                                    0
2484                                };
2485                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2486                            }
2487                        }
2488                    }
2489                }
2490                edgefirst_decoder::ProtoLayout::Nchw => {
2491                    // Channel-major accumulation: contiguous reads per channel plane.
2492                    for c in 0..num_protos {
2493                        let plane = &protos[c * hw..];
2494                        let coeff_c = coeff[c] as i32;
2495                        for ly_idx in 0..roi_h {
2496                            let py = proto_y0 + ly_idx;
2497                            let row_start = py * proto_w + proto_x0;
2498                            let out_row_start = ly_idx * roi_w;
2499                            for lx_idx in 0..roi_w {
2500                                logits[out_row_start + lx_idx] +=
2501                                    coeff_c * plane[row_start + lx_idx] as i32;
2502                            }
2503                        }
2504                    }
2505                    // Apply zero-point correction and per-detection bias.
2506                    for ly_idx in 0..roi_h {
2507                        let py = proto_y0 + ly_idx;
2508                        for lx_idx in 0..roi_w {
2509                            let idx = ly_idx * roi_w + lx_idx;
2510                            let correction = if zp_c != 0 {
2511                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2512                            } else {
2513                                0
2514                            };
2515                            logits[idx] -= correction + bias;
2516                        }
2517                    }
2518                }
2519            }
2520
2521            // Step 3: Bilinear upsample i32 logits → binary mask with
2522            // sign-shortcut (skip interpolation when all 4 neighbors agree).
2523            let roi_last_x = roi_w.saturating_sub(1);
2524            let roi_last_y = roi_h.saturating_sub(1);
2525
2526            // X-coordinate LUT with fixed-point fraction (scale 1024).
2527            const FRAC_BITS: i32 = 10;
2528            const FRAC_SCALE: i32 = 1 << FRAC_BITS; // 1024
2529            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
2530                .map(|xi| {
2531                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2532                    let x_floor = sample_x.floor();
2533                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
2534                    let x_hi = (x_lo + 1).min(roi_w - 1);
2535                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2536                    (x_lo, x_hi, x_frac)
2537                })
2538                .collect();
2539
2540            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2541            for yi in 0..bbox_h {
2542                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2543                let y_floor = sample_y.floor();
2544                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2545                let y_hi = (y_lo + 1).min(roi_h - 1);
2546                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2547                let y_frac_inv = FRAC_SCALE - y_frac;
2548                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2549                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2550                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2551
2552                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2553                    let tl = row_lo[x_lo];
2554                    let tr = row_lo[x_hi];
2555                    let bl = row_hi[x_lo];
2556                    let br = row_hi[x_hi];
2557
2558                    // Sign-shortcut: if all 4 corners have the same sign,
2559                    // the bilinear interpolation (positive-weight combination)
2560                    // preserves that sign. Skip arithmetic for ~80% of pixels.
2561                    if (tl & tr & bl & br) < 0 {
2562                        // All negative → output 0 (already zero).
2563                        continue;
2564                    }
2565                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
2566                        // All strictly positive → output 255.
2567                        out_row[xi] = 255;
2568                        continue;
2569                    }
2570
2571                    // Boundary pixel: fixed-point bilinear in i64.
2572                    let x_frac_inv = FRAC_SCALE - x_frac;
2573                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
2574                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
2575                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
2576                    out_row[xi] = if logit > 0 { 255 } else { 0 };
2577                }
2578            }
2579
2580            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2581                .expect("tile_buf length matches bbox_h * bbox_w");
2582            Ok(edgefirst_decoder::Segmentation {
2583                xmin,
2584                ymin,
2585                xmax,
2586                ymax,
2587                segmentation: tile,
2588            })
2589        })
2590        .collect()
2591}
2592
2593#[allow(clippy::too_many_arguments)]
2594fn scaled_run<P: Copy + Sync>(
2595    detect: &[crate::DetectBox],
2596    coeff_all: &[f32],
2597    protos: &[P],
2598    proto_h: usize,
2599    proto_w: usize,
2600    num_protos: usize,
2601    letterbox: Option<[f32; 4]>,
2602    width: u32,
2603    height: u32,
2604    acc_scale: f32,
2605    load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
2606) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2607    let (lx0, lw, ly0, lh) = match letterbox {
2608        Some([lx0, ly0, lx1, ly1]) => {
2609            let lw = (lx1 - lx0).max(f32::EPSILON);
2610            let lh = (ly1 - ly0).max(f32::EPSILON);
2611            (lx0, lw, ly0, lh)
2612        }
2613        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2614    };
2615    let out_w = width as usize;
2616    let out_h = height as usize;
2617    let stride_y = proto_w * num_protos;
2618
2619    // Parallelise across detections. Each detection produces an
2620    // independent ndarray::Array3<u8> tile from a read-only proto slice +
2621    // its own coeff slice; no shared mutable state.
2622    //
2623    // Algorithm (restores the spirit of PR #54's batched-GEMM optimisation
2624    // that PR #51's f16 dispatch refactor inadvertently removed):
2625    //
2626    //   1. Map the output bbox back to a proto-plane ROI (with 1-px margin
2627    //      so the bilinear sampling at the output edges has neighbours).
2628    //   2. Precompute *f32 logits* at every proto pixel inside that ROI by
2629    //      doing a single K-wide dot product per proto pixel — once, not
2630    //      once per output pixel.
2631    //   3. For each output pixel, bilinear-interpolate the scalar f32 logit
2632    //      from the 4 surrounding proto-roi pixels, apply sigmoid, and
2633    //      threshold to {0, 255}.
2634    //
2635    // For typical YOLO-seg: proto_roi ~ 30×30 = 900 px × K=32 = 28.8K dot
2636    // ops vs the legacy "bilinear sample then dot at every output pixel"
2637    // which costs bbox_h × bbox_w × 4 × K = ~1.3M ops at 100×100 output
2638    // bbox. ~45× fewer FMAs at this size; the bilinear upsample of a
2639    // scalar plane (no inner K loop) is comparatively negligible.
2640    detect
2641        .par_iter()
2642        .enumerate()
2643        .map(|(i, det)| {
2644            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2645            let bbox = det.bbox.to_canonical();
2646            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2647            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2648            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2649            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2650            let px0 = (xmin * out_w as f32).round() as usize;
2651            let py0 = (ymin * out_h as f32).round() as usize;
2652            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2653            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2654            let bbox_w = px1.saturating_sub(px0).max(1);
2655            let bbox_h = py1.saturating_sub(py0).max(1);
2656
2657            // Step 1 — proto-plane ROI for this detection's output bbox.
2658            // Map the four output bbox corners back to proto coords and
2659            // expand by 1 pixel in each direction so the bilinear sampler
2660            // at the bbox boundary has both neighbours.
2661            let sample_x_at = |px: f32| -> f32 {
2662                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2663                model_x_norm * proto_w as f32 - 0.5
2664            };
2665            let sample_y_at = |py: f32| -> f32 {
2666                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2667                model_y_norm * proto_h as f32 - 0.5
2668            };
2669            let s_x_min = sample_x_at(px0 as f32);
2670            let s_x_max = sample_x_at((px1 as f32) - 1.0);
2671            let s_y_min = sample_y_at(py0 as f32);
2672            let s_y_max = sample_y_at((py1 as f32) - 1.0);
2673            // Floor min, ceil max+1 to include both bilinear neighbours.
2674            // Start indices are used as direct bases into `protos`, so clamp
2675            // them to the last valid index, not to the exclusive upper bound.
2676            let proto_x0 = (s_x_min.floor() as isize)
2677                .max(0)
2678                .min(proto_w.saturating_sub(1) as isize) as usize;
2679            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2680            let proto_y0 = (s_y_min.floor() as isize)
2681                .max(0)
2682                .min(proto_h.saturating_sub(1) as isize) as usize;
2683            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2684            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2685            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2686
2687            // Step 2 — precompute f32 logits at every proto-roi pixel.
2688            // logits[(py - proto_y0) * roi_w + (px - proto_x0)] = dot(coeff, proto[py, px, :])
2689            //
2690            // Since the final threshold is `logit > 0` (O1) and bilinear
2691            // interpolation is a positive-weight linear combination,
2692            // `acc_scale * interp(logits) > 0 ⟺ interp(logits) > 0` when
2693            // acc_scale > 0. We therefore skip the per-pixel `acc_scale *`
2694            // multiply entirely, storing raw dot products.
2695            if !acc_scale.is_finite() || acc_scale <= 0.0 {
2696                return Err(crate::Error::NotSupported(format!(
2697                    "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
2698                )));
2699            }
2700            let _ = acc_scale; // Scale-invariant: only sign matters.
2701            let mut logits = vec![0.0_f32; roi_h * roi_w];
2702            for ly_idx in 0..roi_h {
2703                let py = proto_y0 + ly_idx;
2704                let row_base = py * stride_y + proto_x0 * num_protos;
2705                for lx_idx in 0..roi_w {
2706                    let pix_base = row_base + lx_idx * num_protos;
2707                    let mut acc = 0.0_f32;
2708                    // 4-wide unroll to help auto-vectorization.
2709                    let mut k = 0;
2710                    let chunks = num_protos / 4;
2711                    for _ in 0..chunks {
2712                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
2713                            + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
2714                            + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
2715                            + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
2716                        k += 4;
2717                    }
2718                    while k < num_protos {
2719                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
2720                        k += 1;
2721                    }
2722                    logits[ly_idx * roi_w + lx_idx] = acc;
2723                }
2724            }
2725
2726            // Step 3 — bilinear upsample logits → binary mask.
2727            //
2728            // O1: sigmoid(x) > 0.5 ⟺ x > 0 (sigmoid is strictly monotonic,
2729            // and acc_scale > 0 preserves sign). The sign threshold replaces
2730            // the old fast_sigmoid approximation, saving ~15 cycles/pixel.
2731            //
2732            // O5: Pre-compute bilinear sample coordinates. sample_x_at /
2733            // sample_y_at depend only on pixel index, not on logit values.
2734            // Building lookup tables avoids redundant float ops in the inner
2735            // loop (floor, clamp, isize cast per pixel).
2736            let roi_last_x = roi_w.saturating_sub(1);
2737            let roi_last_y = roi_h.saturating_sub(1);
2738
2739            // X-coordinate LUT (shared across all rows).
2740            let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
2741                .map(|xi| {
2742                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2743                    let x_floor = sample_x.floor();
2744                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
2745                    let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
2746                    let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
2747                    (x_lo, x_hi, x_frac)
2748                })
2749                .collect();
2750
2751            // Write the output tile through a contiguous slice to avoid
2752            // ndarray's per-element bounds checks + stride arithmetic.
2753            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2754            for yi in 0..bbox_h {
2755                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2756                let y_floor = sample_y.floor();
2757                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2758                let y_hi = (y_lo + 1).min(roi_h - 1);
2759                let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
2760                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2761                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2762                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2763                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2764                    let (xl, xh) = (x_lo as usize, x_hi as usize);
2765                    let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
2766                    let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
2767                    let logit = l0 + (l1 - l0) * y_frac;
2768                    out_row[xi] = if logit > 0.0 { 255 } else { 0 };
2769                }
2770            }
2771            // Wrap into the expected Array3<u8> shape [bbox_h, bbox_w, 1].
2772            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2773                .expect("tile_buf length matches bbox_h * bbox_w");
2774            Ok(edgefirst_decoder::Segmentation {
2775                xmin,
2776                ymin,
2777                xmax,
2778                ymax,
2779                segmentation: tile,
2780            })
2781        })
2782        .collect()
2783}
2784
2785#[cfg(test)]
2786mod tests {
2787    use super::CPUProcessor;
2788    use edgefirst_decoder::{BoundingBox, DetectBox, ProtoData, ProtoLayout};
2789    use edgefirst_tensor::{Quantization, Tensor, TensorDyn};
2790
2791    const PROTO_H: usize = 4;
2792    const PROTO_W: usize = 4;
2793    const NUM_PROTOS: usize = 8;
2794
2795    fn det(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> DetectBox {
2796        DetectBox {
2797            bbox: BoundingBox {
2798                xmin,
2799                ymin,
2800                xmax,
2801                ymax,
2802            },
2803            score: 0.9,
2804            label: 0,
2805        }
2806    }
2807
2808    fn make_i8_quant(shape: &[usize], data: &[i8], scale: f32, zp: i32) -> TensorDyn {
2809        let t = Tensor::<i8>::from_slice(data, shape).unwrap();
2810        let t = t
2811            .with_quantization(Quantization::per_tensor(scale, zp))
2812            .unwrap();
2813        TensorDyn::I8(t)
2814    }
2815
2816    fn make_i16_quant(shape: &[usize], data: &[i16], scale: f32, zp: i32) -> TensorDyn {
2817        let t = Tensor::<i16>::from_slice(data, shape).unwrap();
2818        let t = t
2819            .with_quantization(Quantization::per_tensor(scale, zp))
2820            .unwrap();
2821        TensorDyn::I16(t)
2822    }
2823
2824    fn make_i16_raw(shape: &[usize], data: &[i16]) -> TensorDyn {
2825        let t = Tensor::<i16>::from_slice(data, shape).unwrap();
2826        TensorDyn::I16(t)
2827    }
2828
2829    fn make_f32(shape: &[usize], data: &[f32]) -> TensorDyn {
2830        let t = Tensor::<f32>::from_slice(data, shape).unwrap();
2831        TensorDyn::F32(t)
2832    }
2833
2834    fn gen_protos_i8(h: usize, w: usize, k: usize) -> Vec<i8> {
2835        (0..h * w * k).map(|i| (i % 127) as i8).collect()
2836    }
2837
2838    fn gen_coeffs_i16(n: usize, k: usize) -> Vec<i16> {
2839        (0..n * k)
2840            .map(|i| ((i as i32 % 201) - 100) as i16)
2841            .collect()
2842    }
2843
2844    fn gen_coeffs_i8(n: usize, k: usize) -> Vec<i8> {
2845        (0..n * k).map(|i| ((i as i32 % 201) - 100) as i8).collect()
2846    }
2847
2848    // ── Proto-resolution: i16×i8 fast path (quantized) ─────────────
2849
2850    #[test]
2851    fn materialize_proto_i16_i8_quant_produces_masks() {
2852        let cpu = CPUProcessor::new();
2853        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
2854        let protos = make_i8_quant(
2855            &[PROTO_H, PROTO_W, NUM_PROTOS],
2856            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2857            0.02,
2858            0,
2859        );
2860        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 0);
2861        let proto_data = ProtoData {
2862            mask_coefficients: coeffs,
2863            protos,
2864            layout: ProtoLayout::Nhwc,
2865        };
2866        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
2867        assert!(result.is_ok(), "materialize failed: {:?}", result.err());
2868        let segs = result.unwrap();
2869        assert_eq!(segs.len(), 1);
2870        let seg = &segs[0];
2871        assert!(seg.segmentation.shape()[0] > 0);
2872        assert!(seg.segmentation.shape()[1] > 0);
2873    }
2874
2875    // ── Proto-resolution: i16 missing quant → f32 fallback ─────────
2876
2877    #[test]
2878    fn materialize_proto_i16_no_quant_falls_back_to_f32() {
2879        let cpu = CPUProcessor::new();
2880        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
2881        let protos = make_i8_quant(
2882            &[PROTO_H, PROTO_W, NUM_PROTOS],
2883            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2884            0.02,
2885            0,
2886        );
2887        // I16 coefficients WITHOUT quantization — fast path should be
2888        // skipped, f32 fallback should widen raw i16 values.
2889        let coeffs = make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS));
2890        let proto_data = ProtoData {
2891            mask_coefficients: coeffs,
2892            protos,
2893            layout: ProtoLayout::Nhwc,
2894        };
2895        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
2896        assert!(
2897            result.is_ok(),
2898            "missing coeff quant should fall back to f32 path, got: {:?}",
2899            result.err()
2900        );
2901        assert_eq!(result.unwrap().len(), 1);
2902    }
2903
2904    // ── Scaled: i16×i8 fast path (quantized) ───────────────────────
2905
2906    #[test]
2907    fn materialize_scaled_i16_i8_quant_produces_masks() {
2908        let cpu = CPUProcessor::new();
2909        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
2910        let protos = make_i8_quant(
2911            &[PROTO_H, PROTO_W, NUM_PROTOS],
2912            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2913            0.02,
2914            0,
2915        );
2916        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 0);
2917        let proto_data = ProtoData {
2918            mask_coefficients: coeffs,
2919            protos,
2920            layout: ProtoLayout::Nhwc,
2921        };
2922        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
2923        assert!(
2924            result.is_ok(),
2925            "materialize_scaled failed: {:?}",
2926            result.err()
2927        );
2928        let segs = result.unwrap();
2929        assert_eq!(segs.len(), 1);
2930        let seg = &segs[0];
2931        assert!(seg.segmentation.shape()[0] > 0);
2932        assert!(seg.segmentation.shape()[1] > 0);
2933    }
2934
2935    // ── Scaled: i16 missing quant → f32 fallback ───────────────────
2936
2937    #[test]
2938    fn materialize_scaled_i16_no_quant_falls_back_to_f32() {
2939        let cpu = CPUProcessor::new();
2940        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
2941        let protos = make_i8_quant(
2942            &[PROTO_H, PROTO_W, NUM_PROTOS],
2943            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2944            0.02,
2945            0,
2946        );
2947        let coeffs = make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS));
2948        let proto_data = ProtoData {
2949            mask_coefficients: coeffs,
2950            protos,
2951            layout: ProtoLayout::Nhwc,
2952        };
2953        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
2954        assert!(
2955            result.is_ok(),
2956            "missing coeff quant should fall back to f32 path, got: {:?}",
2957            result.err()
2958        );
2959        assert_eq!(result.unwrap().len(), 1);
2960    }
2961
2962    // ── i16×i8 parity with f32 reference ───────────────────────────
2963
2964    #[test]
2965    fn materialize_proto_i16_i8_matches_f32_reference() {
2966        let cpu = CPUProcessor::new();
2967        let detect = vec![det(0.1, 0.1, 0.9, 0.9), det(0.3, 0.3, 0.7, 0.7)];
2968        let n_det = detect.len();
2969        let scale_c = 0.01_f32;
2970        let scale_p = 0.02_f32;
2971        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
2972        let raw_coeffs = gen_coeffs_i16(n_det, NUM_PROTOS);
2973
2974        // Build f32 reference (dequantized manually).
2975        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
2976        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
2977        let proto_data_f32 = ProtoData {
2978            mask_coefficients: make_f32(&[n_det, NUM_PROTOS], &coeffs_f32),
2979            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
2980            layout: ProtoLayout::Nhwc,
2981        };
2982
2983        let proto_data_int = ProtoData {
2984            mask_coefficients: make_i16_quant(&[n_det, NUM_PROTOS], &raw_coeffs, scale_c, 0),
2985            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
2986            layout: ProtoLayout::Nhwc,
2987        };
2988
2989        let segs_f32 = cpu
2990            .materialize_segmentations(&detect, &proto_data_f32, None)
2991            .unwrap();
2992        let segs_int = cpu
2993            .materialize_segmentations(&detect, &proto_data_int, None)
2994            .unwrap();
2995
2996        assert_eq!(segs_f32.len(), segs_int.len());
2997        for (sf, si) in segs_f32.iter().zip(segs_int.iter()) {
2998            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
2999            let total = sf.segmentation.len();
3000            let mismatches = sf
3001                .segmentation
3002                .iter()
3003                .zip(si.segmentation.iter())
3004                .filter(|(a, b)| a != b)
3005                .count();
3006            let pct = mismatches as f64 / total as f64 * 100.0;
3007            assert!(
3008                pct < 5.0,
3009                "mask mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
3010            );
3011        }
3012    }
3013
3014    // ── Multiple detections ────────────────────────────────────────
3015
3016    #[test]
3017    fn materialize_proto_i16_multiple_detections() {
3018        let cpu = CPUProcessor::new();
3019        let detect = vec![
3020            det(0.0, 0.0, 0.5, 0.5),
3021            det(0.5, 0.5, 1.0, 1.0),
3022            det(0.1, 0.1, 0.3, 0.3),
3023        ];
3024        let protos = make_i8_quant(
3025            &[PROTO_H, PROTO_W, NUM_PROTOS],
3026            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3027            0.02,
3028            0,
3029        );
3030        let coeffs = make_i16_quant(&[3, NUM_PROTOS], &gen_coeffs_i16(3, NUM_PROTOS), 0.01, 0);
3031        let proto_data = ProtoData {
3032            mask_coefficients: coeffs,
3033            protos,
3034            layout: ProtoLayout::Nhwc,
3035        };
3036        let segs = cpu
3037            .materialize_segmentations(&detect, &proto_data, None)
3038            .unwrap();
3039        assert_eq!(segs.len(), 3);
3040    }
3041
3042    // ── Empty detections ───────────────────────────────────────────
3043
3044    #[test]
3045    fn materialize_proto_i16_empty_detections() {
3046        let cpu = CPUProcessor::new();
3047        let detect: Vec<DetectBox> = vec![];
3048        let protos = make_i8_quant(
3049            &[PROTO_H, PROTO_W, NUM_PROTOS],
3050            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3051            0.02,
3052            0,
3053        );
3054        let coeffs = make_i16_quant(&[0, NUM_PROTOS], &[], 0.01, 0);
3055        let proto_data = ProtoData {
3056            mask_coefficients: coeffs,
3057            protos,
3058            layout: ProtoLayout::Nhwc,
3059        };
3060        let segs = cpu
3061            .materialize_segmentations(&detect, &proto_data, None)
3062            .unwrap();
3063        assert!(segs.is_empty());
3064    }
3065
3066    // ── Scaled parity ──────────────────────────────────────────────
3067
3068    #[test]
3069    fn materialize_scaled_i16_i8_matches_f32_reference() {
3070        let cpu = CPUProcessor::new();
3071        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3072        let scale_c = 0.01_f32;
3073        let scale_p = 0.02_f32;
3074        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
3075        let raw_coeffs = gen_coeffs_i16(1, NUM_PROTOS);
3076
3077        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
3078        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
3079        let proto_data_f32 = ProtoData {
3080            mask_coefficients: make_f32(&[1, NUM_PROTOS], &coeffs_f32),
3081            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
3082            layout: ProtoLayout::Nhwc,
3083        };
3084        let proto_data_int = ProtoData {
3085            mask_coefficients: make_i16_quant(&[1, NUM_PROTOS], &raw_coeffs, scale_c, 0),
3086            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
3087            layout: ProtoLayout::Nhwc,
3088        };
3089
3090        let (w, h) = (64_u32, 64_u32);
3091        let segs_f32 = cpu
3092            .materialize_scaled_segmentations(&detect, &proto_data_f32, None, w, h)
3093            .unwrap();
3094        let segs_int = cpu
3095            .materialize_scaled_segmentations(&detect, &proto_data_int, None, w, h)
3096            .unwrap();
3097
3098        assert_eq!(segs_f32.len(), segs_int.len());
3099        for (sf, si) in segs_f32.iter().zip(segs_int.iter()) {
3100            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
3101            let total = sf.segmentation.len();
3102            let mismatches = sf
3103                .segmentation
3104                .iter()
3105                .zip(si.segmentation.iter())
3106                .filter(|(a, b)| a != b)
3107                .count();
3108            let pct = mismatches as f64 / total as f64 * 100.0;
3109            assert!(
3110                pct < 5.0,
3111                "scaled mask mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
3112            );
3113        }
3114    }
3115
3116    // ── i8×i8 existing path still works (regression) ───────────────
3117
3118    #[test]
3119    fn materialize_proto_i8_i8_regression() {
3120        let cpu = CPUProcessor::new();
3121        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3122        let protos = make_i8_quant(
3123            &[PROTO_H, PROTO_W, NUM_PROTOS],
3124            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3125            0.02,
3126            0,
3127        );
3128        let coeffs = make_i8_quant(&[1, NUM_PROTOS], &gen_coeffs_i8(1, NUM_PROTOS), 0.01, 0);
3129        let proto_data = ProtoData {
3130            mask_coefficients: coeffs,
3131            protos,
3132            layout: ProtoLayout::Nhwc,
3133        };
3134        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
3135        assert!(result.is_ok(), "i8×i8 regression: {:?}", result.err());
3136        assert_eq!(result.unwrap().len(), 1);
3137    }
3138
3139    // ── Non-zero zero_point ────────────────────────────────────────
3140
3141    #[test]
3142    fn materialize_proto_i16_nonzero_zp() {
3143        let cpu = CPUProcessor::new();
3144        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3145        let protos = make_i8_quant(
3146            &[PROTO_H, PROTO_W, NUM_PROTOS],
3147            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3148            0.02,
3149            -10,
3150        );
3151        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 5);
3152        let proto_data = ProtoData {
3153            mask_coefficients: coeffs,
3154            protos,
3155            layout: ProtoLayout::Nhwc,
3156        };
3157        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
3158        assert!(result.is_ok(), "nonzero zp failed: {:?}", result.err());
3159        assert_eq!(result.unwrap().len(), 1);
3160    }
3161
3162    // ── Scaled: non-zero zero_point ────────────────────────────────
3163
3164    #[test]
3165    fn materialize_scaled_i16_nonzero_zp() {
3166        let cpu = CPUProcessor::new();
3167        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3168        let protos = make_i8_quant(
3169            &[PROTO_H, PROTO_W, NUM_PROTOS],
3170            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3171            0.02,
3172            -10,
3173        );
3174        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 5);
3175        let proto_data = ProtoData {
3176            mask_coefficients: coeffs,
3177            protos,
3178            layout: ProtoLayout::Nhwc,
3179        };
3180        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
3181        assert!(
3182            result.is_ok(),
3183            "scaled nonzero zp failed: {:?}",
3184            result.err()
3185        );
3186        assert_eq!(result.unwrap().len(), 1);
3187    }
3188}