Skip to main content

edgefirst_image/cpu/
masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
    /// Alpha-blend a ModelPack multi-class segmentation map over the
    /// destination image.
    ///
    /// `segmentation.segmentation` is an `[H, W, C]` score array; every
    /// destination pixel inside the mask's normalized bounding box
    /// (`xmin..xmax`, `ymin..ymax`) is assigned the per-pixel argmax class
    /// and tinted with that class's palette color. Pixels whose argmax is the
    /// last class index (`seg_classes - 1`) are skipped — presumably the
    /// background class; TODO confirm against the model/decoder convention.
    ///
    /// * `dst_w`, `dst_h` — destination image size in pixels.
    /// * `dst_rs` — destination row stride in bytes.
    /// * `dst_c` — bytes per destination pixel; only the first 3 channels
    ///   (RGB) are blended.
    /// * `dst_slice` — destination pixel buffer.
    /// * `opacity` — scales the palette color's alpha channel; exactly `1.0`
    ///   takes an integer-only fast path.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_modelpack_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        opacity: f32,
    ) -> Result<()> {
        use ndarray_stats::QuantileExt;

        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize; 3] as shape");
        };
        // Mask bounds mapped from normalized coordinates into destination
        // pixel space.
        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        // Align-corners style scale from the destination span to the mask
        // span (maps the first/last destination pixel onto the first/last
        // mask sample).
        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        // Integer loop bounds clamped to the destination image.
        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        // Per-pixel argmax over the class axis, computed once up front so the
        // pixel loop only does lookups.
        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
        // Nearest-neighbor sampler into the argmax plane; coordinates are
        // rounded then clamped to the mask edges.
        let get_value_at_nearest = |x: f32, y: f32| -> usize {
            let x = x.round() as usize;
            let y = y.round() as usize;
            argmax
                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
                .copied()
                .unwrap_or(0)
        };

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = (x as f32 - start_x) * scale_x;
                let seg_y = (y as f32 - start_y) * scale_y;
                let label = get_value_at_nearest(seg_x, seg_y);

                // Last class is left untouched (assumed background).
                if label == seg_classes - 1 {
                    continue;
                }

                let color = self.colors[label % self.colors.len()];

                // Fast path: opacity of exactly 1.0 uses the palette alpha
                // without any float math.
                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

                // Integer alpha blend: out = (src*a + dst*(255 - a)) / 255.
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }
80
81    #[allow(clippy::too_many_arguments)]
82    pub(super) fn render_yolo_segmentation(
83        &mut self,
84        dst_w: usize,
85        dst_h: usize,
86        dst_rs: usize,
87        dst_c: usize,
88        dst_slice: &mut [u8],
89        segmentation: &Segmentation,
90        class: usize,
91        opacity: f32,
92    ) -> Result<()> {
93        let seg = &segmentation.segmentation;
94        let [seg_height, seg_width, classes] = *seg.shape() else {
95            unreachable!("Array3 did not have [usize;3] as shape");
96        };
97        debug_assert_eq!(classes, 1);
98
99        let start_y = (dst_h as f32 * segmentation.ymin).round();
100        let end_y = (dst_h as f32 * segmentation.ymax).round();
101        let start_x = (dst_w as f32 * segmentation.xmin).round();
102        let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107        let start_x_u = (start_x as usize).min(dst_w);
108        let start_y_u = (start_y as usize).min(dst_h);
109        let end_x_u = (end_x as usize).min(dst_w);
110        let end_y_u = (end_y as usize).min(dst_h);
111
112        for y in start_y_u..end_y_u {
113            for x in start_x_u..end_x_u {
114                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118                if val < 127 {
119                    continue;
120                }
121
122                let color = self.colors[class % self.colors.len()];
123
124                let alpha = if opacity == 1.0 {
125                    color[3] as u16
126                } else {
127                    (color[3] as f32 * opacity).round() as u16
128                };
129
130                let dst_index = (y * dst_rs) + (x * dst_c);
131                for c in 0..3 {
132                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
134                        / 255) as u8;
135                }
136            }
137        }
138
139        Ok(())
140    }
141
142    #[allow(clippy::too_many_arguments)]
143    pub(super) fn render_box(
144        &mut self,
145        dst_w: usize,
146        dst_h: usize,
147        dst_rs: usize,
148        dst_c: usize,
149        dst_slice: &mut [u8],
150        detect: &[DetectBox],
151        color_mode: crate::ColorMode,
152    ) -> Result<()> {
153        const LINE_THICKNESS: usize = 3;
154
155        for (idx, d) in detect.iter().enumerate() {
156            use edgefirst_decoder::BoundingBox;
157
158            let color_index = color_mode.index(idx, d.label);
159            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
160            let bbox = d.bbox.to_canonical();
161            let bbox = BoundingBox {
162                xmin: bbox.xmin.clamp(0.0, 1.0),
163                ymin: bbox.ymin.clamp(0.0, 1.0),
164                xmax: bbox.xmax.clamp(0.0, 1.0),
165                ymax: bbox.ymax.clamp(0.0, 1.0),
166            };
167            let inner = [
168                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
169                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
170                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
171                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
172            ];
173
174            let outer = [
175                inner[0].saturating_sub(LINE_THICKNESS),
176                inner[1].saturating_sub(LINE_THICKNESS),
177                (inner[2] + LINE_THICKNESS).min(dst_w),
178                (inner[3] + LINE_THICKNESS).min(dst_h),
179            ];
180
181            // top line
182            for y in outer[1] + 1..=inner[1] {
183                for x in outer[0] + 1..outer[2] {
184                    let index = (y * dst_rs) + (x * dst_c);
185                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
186                }
187            }
188
189            // left and right lines
190            for y in inner[1]..inner[3] {
191                for x in outer[0] + 1..=inner[0] {
192                    let index = (y * dst_rs) + (x * dst_c);
193                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
194                }
195
196                for x in inner[2]..outer[2] {
197                    let index = (y * dst_rs) + (x * dst_c);
198                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
199                }
200            }
201
202            // bottom line
203            for y in inner[3]..outer[3] {
204                for x in outer[0] + 1..outer[2] {
205                    let index = (y * dst_rs) + (x * dst_c);
206                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
207                }
208            }
209        }
210        Ok(())
211    }
212
    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
    ///
    /// This is the CPU-side decode step of the hybrid mask rendering path:
    /// call this to get pre-decoded masks, then pass them to
    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
    /// than the fused GPU `draw_proto_masks` on all tested platforms.
    ///
    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
    ///
    /// * `detect` — detections; row `i` of `proto_data.mask_coefficients`
    ///   must correspond to `detect[i]` (checked below).
    /// * `proto_data` — rank-3 proto tensor (NHWC or NCHW layout) plus
    ///   per-detection mask coefficients of shape `[N, num_protos]`.
    /// * `letterbox` — optional normalized `[x0, y0, x1, y1]` letterbox
    ///   region; when `Some`, output coordinates are mapped back into
    ///   original-content space via the inverse transform.
    ///
    /// Dispatch order: all-integer i8×i8 kernel, then i16×i8, then the
    /// general f32 fallback; the integer paths fall through to the fallback
    /// on `NotSupported` (e.g. per-channel quantization).
    ///
    /// # Errors
    ///
    /// Returns `InvalidShape` for rank/shape/dtype problems, `Internal` when
    /// the coefficient row count disagrees with `detect.len()`, and
    /// `NotSupported` for dtype/layout combinations no kernel handles.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "proto",
            n_detections = detect.len(),
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Interpret shape based on physical layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute inverse letterbox scale for output-coord conversion.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    // Degenerate (zero/negative size) letterbox falls back to
                    // identity scaling rather than dividing by zero.
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Fast integer path: when both coefficients and protos are I8 with
        // per-tensor quantization, use the all-integer kernel (same dot
        // product infrastructure as the scaled path, but at proto resolution
        // without bilinear upsampling). Output is binary {0, 255}.
        // Falls through to the general f32 dequant path for per-channel
        // quantization or other unsupported modes.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match proto_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                lx0,
                inv_lw,
                ly0,
                inv_lh,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Fall through to the general f32 dequant path below for
                    // per-channel quantization and other unsupported modes.
                }
                Err(e) => return Err(e),
            }
        }

        // Fast i16×i8 integer path: i16 coefficients × i8 protos → i32 dot.
        if proto_data.mask_coefficients.dtype() == DType::I16
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i16()
                .expect("I16 coefficients");
            let coeff_m = coeff_t.map()?;
            // Skip the integer fast path when coefficient quantization is
            // absent — the f32 fallback below handles raw i16 by widening.
            if let Some(coeff_quant) = coeff_t.quantization() {
                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
                let proto_m = proto_t.map()?;
                let proto_quant = proto_t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                match proto_segmentations_i16_i8(
                    detect,
                    coeff_m.as_slice(),
                    coeff_quant,
                    proto_m.as_slice(),
                    proto_quant,
                    proto_h,
                    proto_w,
                    num_protos,
                    lx0,
                    inv_lw,
                    ly0,
                    inv_lh,
                    proto_data.layout,
                ) {
                    Ok(result) => return Ok(result),
                    Err(crate::Error::NotSupported(_)) => {
                        // Fall through to the general f32 dequant path.
                    }
                    Err(e) => return Err(e),
                }
            }
        }

        // Coefficients may be F32 (from f32 models), F16 (from fp16 models),
        // I8 (from quantized models — kept raw with quantization), or I16.
        // For the mask kernel we always need an f32 view (the multiply-accumulate
        // is done in f32 for precision). Map once and widen once outside the loop.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
            && proto_data.protos.dtype() != DType::I8
        {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
                    .into(),
            ));
        }
        // Owned storage keeps the widened coefficients alive past the map
        // guard's scope; the slice borrows from it.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            DType::I8 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i8()
                    .expect("dtype matched I8");
                let m = t.map()?;
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I8 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    // No quantization metadata: treat raw values as f32.
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            DType::I16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i16()
                    .expect("dtype matched I16");
                let m = t.map()?;
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I16 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    // No quantization metadata: treat raw values as f32.
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, I8, or I16"
                )));
            }
        };

        // Hoist the proto tensor map() out of the per-detection loop so the
        // map-guard is acquired once. Then dispatch per-dtype via a helper
        // that runs the per-detection kernels in parallel across detections
        // via rayon. This restores the parallelism that PR #54 added and
        // PR #51 (EDGEAI-1244 f16 refactor) inadvertently removed.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let src_slice = m.as_slice();
                // NCHW protos are transposed once into NHWC so the kernel can
                // read each pixel's proto vector contiguously.
                let transposed_storage =
                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
                        let hw = proto_h * proto_w;
                        let mut nhwc = vec![0i8; hw * num_protos];
                        for c in 0..num_protos {
                            let plane = &src_slice[c * hw..(c + 1) * hw];
                            for px in 0..hw {
                                nhwc[px * num_protos + c] = plane[px];
                            }
                        }
                        Some(nhwc)
                    } else {
                        None
                    };
                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sign_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
583
584    /// Produce per-detection masks at `(width, height)` pixel resolution by
585    /// upsampling the full proto plane once then cropping per bbox. Each
586    /// `det.bbox` is assumed to be in model-input normalized coordinates
587    /// (the convention used by the decoder output); when `letterbox` is
588    /// `Some`, `(width, height)` are original-content pixel dims and the
589    /// inverse letterbox transform is applied to both the bbox (for the
590    /// crop region and returned `Segmentation` metadata) and each output
591    /// pixel (for proto-plane sampling). Mask values are binary
592    /// `uint8 {0, 255}` after thresholding sigmoid > 0.5.
593    ///
594    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
595    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
596    pub fn materialize_scaled_segmentations(
597        &self,
598        detect: &[crate::DetectBox],
599        proto_data: &crate::ProtoData,
600        letterbox: Option<[f32; 4]>,
601        width: u32,
602        height: u32,
603    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
604        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};
605
606        let _span = tracing::trace_span!(
607            "materialize_masks",
608            mode = "scaled",
609            n_detections = detect.len(),
610            width,
611            height,
612        )
613        .entered();
614
615        if detect.is_empty() {
616            return Ok(Vec::new());
617        }
618        if width == 0 || height == 0 {
619            return Err(crate::Error::InvalidShape(
620                "Scaled mask width/height must be positive".into(),
621            ));
622        }
623        let proto_shape = proto_data.protos.shape();
624        if proto_shape.len() != 3 {
625            return Err(crate::Error::InvalidShape(format!(
626                "protos tensor must be rank-3, got {proto_shape:?}"
627            )));
628        }
629        // Interpret shape based on physical layout.
630        let (proto_h, proto_w, num_protos) = match proto_data.layout {
631            edgefirst_decoder::ProtoLayout::Nhwc => {
632                (proto_shape[0], proto_shape[1], proto_shape[2])
633            }
634            edgefirst_decoder::ProtoLayout::Nchw => {
635                (proto_shape[1], proto_shape[2], proto_shape[0])
636            }
637        };
638        let coeff_shape = proto_data.mask_coefficients.shape();
639        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
640            return Err(crate::Error::InvalidShape(format!(
641                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
642                 {proto_shape:?}"
643            )));
644        }
645        if coeff_shape[0] == 0 {
646            return Ok(Vec::new());
647        }
648        if coeff_shape[0] != detect.len() {
649            return Err(crate::Error::Internal(format!(
650                "mask_coefficients rows {} != detection count {}",
651                coeff_shape[0],
652                detect.len()
653            )));
654        }
655
656        // Fast integer path: when both coefficients and protos are I8, use
657        // the all-integer kernel (i8×i8→i32 dot product, sign-shortcut
658        // bilinear). No floating-point conversion at all.
659        // Falls through to the general f32 dequant path for per-channel
660        // quantization or other unsupported modes.
661        if proto_data.mask_coefficients.dtype() == DType::I8
662            && proto_data.protos.dtype() == DType::I8
663        {
664            let coeff_t = proto_data
665                .mask_coefficients
666                .as_i8()
667                .expect("I8 coefficients");
668            let coeff_m = coeff_t.map()?;
669            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
670                crate::Error::InvalidShape(
671                    "I8 mask_coefficients require quantization metadata".into(),
672                )
673            })?;
674            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
675            let proto_m = proto_t.map()?;
676            let proto_quant = proto_t.quantization().ok_or_else(|| {
677                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
678            })?;
679            match scaled_segmentations_i8_i8(
680                detect,
681                coeff_m.as_slice(),
682                coeff_quant,
683                proto_m.as_slice(),
684                proto_quant,
685                proto_h,
686                proto_w,
687                num_protos,
688                letterbox,
689                width,
690                height,
691                proto_data.layout,
692            ) {
693                Ok(result) => return Ok(result),
694                Err(crate::Error::NotSupported(_)) => {
695                    // Fall through to the general f32 dequant path below for
696                    // per-channel quantization and other unsupported modes.
697                }
698                Err(e) => return Err(e),
699            }
700        }
701
702        // Fast i16×i8 integer path: i16 coefficients × i8 protos.
703        if proto_data.mask_coefficients.dtype() == DType::I16
704            && proto_data.protos.dtype() == DType::I8
705        {
706            let coeff_t = proto_data
707                .mask_coefficients
708                .as_i16()
709                .expect("I16 coefficients");
710            let coeff_m = coeff_t.map()?;
711            // Skip the integer fast path when coefficient quantization is
712            // absent — the f32 fallback below handles raw i16 by widening.
713            if let Some(coeff_quant) = coeff_t.quantization() {
714                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
715                let proto_m = proto_t.map()?;
716                let proto_quant = proto_t.quantization().ok_or_else(|| {
717                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
718                })?;
719                match scaled_segmentations_i16_i8(
720                    detect,
721                    coeff_m.as_slice(),
722                    coeff_quant,
723                    proto_m.as_slice(),
724                    proto_quant,
725                    proto_h,
726                    proto_w,
727                    num_protos,
728                    letterbox,
729                    width,
730                    height,
731                    proto_data.layout,
732                ) {
733                    Ok(result) => return Ok(result),
734                    Err(crate::Error::NotSupported(_)) => {}
735                    Err(e) => return Err(e),
736                }
737            }
738        }
739
740        // Fallback: widen coefficients to f32 for the float-path kernels.
741        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
742            && proto_data.protos.dtype() != DType::I8
743        {
744            return Err(crate::Error::NotSupported(
745                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
746                    .into(),
747            ));
748        }
749        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
750            DType::F32 => {
751                let t = proto_data.mask_coefficients.as_f32().expect("F32");
752                let m = t.map()?;
753                m.as_slice().to_vec()
754            }
755            DType::F16 => {
756                let t = proto_data.mask_coefficients.as_f16().expect("F16");
757                let m = t.map()?;
758                m.as_slice().iter().map(|v| v.to_f32()).collect()
759            }
760            DType::I8 => {
761                // Dequantize I8 coefficients to f32 for the float proto path.
762                let t = proto_data.mask_coefficients.as_i8().expect("I8");
763                let m = t.map()?;
764                let q = t.quantization().ok_or_else(|| {
765                    crate::Error::InvalidShape(
766                        "I8 mask_coefficients require quantization metadata".into(),
767                    )
768                })?;
769                use edgefirst_tensor::QuantMode;
770                let (scale, zp) = match q.mode() {
771                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
772                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
773                    _ => {
774                        return Err(crate::Error::NotSupported(
775                            "per-channel mask_coefficients not supported".into(),
776                        ))
777                    }
778                };
779                m.as_slice()
780                    .iter()
781                    .map(|&v| (v as f32 - zp) * scale)
782                    .collect()
783            }
784            DType::I16 => {
785                let t = proto_data.mask_coefficients.as_i16().expect("I16");
786                let m = t.map()?;
787                if let Some(q) = t.quantization() {
788                    use edgefirst_tensor::QuantMode;
789                    let (scale, zp) = match q.mode() {
790                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
791                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
792                        other => {
793                            return Err(crate::Error::NotSupported(format!(
794                                "I16 mask_coefficients quantization mode {other:?} not supported"
795                            )))
796                        }
797                    };
798                    m.as_slice()
799                        .iter()
800                        .map(|&v| (v as f32 - zp) * scale)
801                        .collect()
802                } else {
803                    m.as_slice().iter().map(|&v| v as f32).collect()
804                }
805            }
806            other => {
807                return Err(crate::Error::InvalidShape(format!(
808                    "mask_coefficients dtype {other:?} not supported"
809                )));
810            }
811        };
812
813        match proto_data.protos.dtype() {
814            DType::F32 => {
815                let t = proto_data.protos.as_f32().expect("F32");
816                let m = t.map()?;
817                scaled_segmentations_f32_slice(
818                    detect,
819                    &coeff_f32,
820                    m.as_slice(),
821                    proto_h,
822                    proto_w,
823                    num_protos,
824                    letterbox,
825                    width,
826                    height,
827                )
828            }
829            DType::F16 => {
830                let t = proto_data.protos.as_f16().expect("F16");
831                let m = t.map()?;
832                scaled_segmentations_f16_slice(
833                    detect,
834                    &coeff_f32,
835                    m.as_slice(),
836                    proto_h,
837                    proto_w,
838                    num_protos,
839                    letterbox,
840                    width,
841                    height,
842                )
843            }
844            DType::I8 => {
845                let t = proto_data.protos.as_i8().expect("I8");
846                let m = t.map()?;
847                let quant = t.quantization().ok_or_else(|| {
848                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
849                })?;
850                let src_slice = m.as_slice();
851                let transposed_storage =
852                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
853                        let hw = proto_h * proto_w;
854                        let mut nhwc = vec![0i8; hw * num_protos];
855                        for c in 0..num_protos {
856                            let plane = &src_slice[c * hw..(c + 1) * hw];
857                            for px in 0..hw {
858                                nhwc[px * num_protos + c] = plane[px];
859                            }
860                        }
861                        Some(nhwc)
862                    } else {
863                        None
864                    };
865                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
866                scaled_segmentations_i8_slice(
867                    detect,
868                    &coeff_f32,
869                    protos_slice,
870                    proto_h,
871                    proto_w,
872                    num_protos,
873                    quant,
874                    letterbox,
875                    width,
876                    height,
877                )
878            }
879            other => Err(crate::Error::InvalidShape(format!(
880                "proto tensor dtype {other:?} not supported"
881            ))),
882        }
883    }
884}
885
886// =============================================================================
887// Slice-native fused kernels.
888//
889// All kernels take row-major `[H, W, num_protos]` proto slices + `&[f32]`
890// coefficients (widened once from the source dtype at the materialize entry
891// point). Per-dtype variants exist for i8 (with on-the-fly dequant using a
892// tensor-level `Quantization`), f32, and f16; f16 widens to f32 per-element
893// at the FMA site via `half::f16::to_f32()`.
894//
895// On ARMv8.2-FP16 this compiles to `fcvt`; on Cortex-A53 and non-F16C x86 it
896// becomes a soft-float helper. Stage 8 adds explicit intrinsic kernels
897// gated by `#[cfg(target_feature = "fp16")]` / `+f16c`.
898// =============================================================================
899
900/// Map a detection bbox in normalised letterboxed coords to its ROI in
901/// the proto plane (floor xmin/ymin, ceil xmax/ymax, clamp to plane bounds).
902/// Returns `(x0, y0, x1, y1, roi_w, roi_h)` where roi_w/h are guaranteed ≥ 1.
903fn bbox_to_proto_roi(
904    det: &DetectBox,
905    proto_w: usize,
906    proto_h: usize,
907) -> (usize, usize, usize, usize, usize, usize) {
908    let bbox = det.bbox.to_canonical();
909    let xmin = bbox.xmin.clamp(0.0, 1.0);
910    let ymin = bbox.ymin.clamp(0.0, 1.0);
911    let xmax = bbox.xmax.clamp(0.0, 1.0);
912    let ymax = bbox.ymax.clamp(0.0, 1.0);
913    let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
914    let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
915    let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
916    let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
917    let roi_w = x1.saturating_sub(x0).max(1);
918    let roi_h = y1.saturating_sub(y0).max(1);
919    (x0, y0, x1, y1, roi_w, roi_h)
920}
921
922/// Build a `Segmentation` from a per-detection mask + the ROI bounds in
923/// proto coords. Applies the inverse letterbox transform to express the
924/// segmentation bbox in original-image-content normalised space.
925#[allow(clippy::too_many_arguments)]
926fn seg_from_roi(
927    mask: ndarray::Array3<u8>,
928    x0: usize,
929    y0: usize,
930    x1: usize,
931    y1: usize,
932    proto_w: usize,
933    proto_h: usize,
934    lx0: f32,
935    inv_lw: f32,
936    ly0: f32,
937    inv_lh: f32,
938) -> edgefirst_decoder::Segmentation {
939    let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
940    let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
941    let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
942    let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
943    edgefirst_decoder::Segmentation {
944        xmin: seg_xmin.clamp(0.0, 1.0),
945        ymin: seg_ymin.clamp(0.0, 1.0),
946        xmax: seg_xmax.clamp(0.0, 1.0),
947        ymax: seg_ymax.clamp(0.0, 1.0),
948        segmentation: mask,
949    }
950}
951
952// =============================================================================
953// Integer-domain proto-resolution kernel: i8 coefficients × i8 protos → i32
954// → sign threshold → binary {0, 255}.
955//
956// Reuses the same dot product infrastructure as the scaled path (NEON sdot on
957// A55+, smull+sadalp on A53, scalar fallback on x86). Since proto-resolution
958// produces masks at the native proto grid (~30×30 per ROI), there is no
959// bilinear upsampling — just a direct sign threshold per pixel.
960// =============================================================================
961
/// Proto-resolution mask materialization using integer-domain math.
///
/// For each detection, computes the i8×i8 dot product at every proto-ROI pixel,
/// applies the zero-point correction, and thresholds at sign(logit) → {0, 255}.
/// Supports both NHWC and NCHW proto layouts.
///
/// Quantized-dot derivation: with raw values `c` (coefficients), `p` (protos)
/// and zero points `zp_c`/`zp_p`, the dequantized logit is proportional to
///   Σ c·p − zp_c·Σp − zp_p·Σc + N·zp_c·zp_p
/// (the common scale factor, assumed positive per usual quantization
/// convention, cannot flip the sign and is dropped). The loops below compute
/// this as `raw_dot − correction − bias` where `correction = zp_c·Σp` is
/// per-pixel and `bias = zp_p·Σc − N·zp_c·zp_p` is per-detection.
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Per-tensor zero points only; per-channel modes return NotSupported so
    // the caller can fall through to its f32 dequant path.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i8 path".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Precompute per-pixel proto sums for zero-point correction.
    // When zp_c == 0 the correction term vanishes, so the Vec stays empty
    // and is never indexed (guarded at every use site below).
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                // Accumulate plane-by-plane for sequential reads over the
                // channel-major storage.
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime feature detection: pick the sdot kernel on CPUs that have it,
    // otherwise the baseline NEON kernel.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    // One rayon task per detection; each task owns its own mask buffer.
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // Coefficients are stored row-major: detection i owns
            // coeff_all[i*num_protos .. (i+1)*num_protos].
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            // Per-detection bias: zp_p·Σc_raw - N·zp_c·zp_p
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: gated on runtime dotprod detection;
                                    // both slices hold num_protos elements.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_dotprod(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        } else {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: baseline NEON is available on all
                                    // aarch64; both slices hold num_protos elements.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_base(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback (x86 and others).
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major accumulation: for each channel, accumulate
                    // coeff[c] * proto[c, py, px] across the ROI. Each channel
                    // plane is contiguous, giving excellent sequential read access.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    // Apply zero-point correction and threshold.
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            // Wrap the binary mask as [roi_h, roi_w, 1] and express its bbox
            // in original-image-content normalised space.
            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1182
1183// =============================================================================
/// Proto-resolution mask materialization for i16 coefficients × i8 protos.
///
/// Same integer-domain scheme as `proto_segmentations_i8_i8`: for each
/// detection, compute the dot product at every proto-ROI pixel, apply the
/// zero-point correction `zp_c·Σp` (per pixel) and bias
/// `zp_p·Σc − N·zp_c·zp_p` (per detection), then threshold at
/// sign(logit) → {0, 255}. Supports NHWC and NCHW proto layouts. Unlike the
/// i8×i8 path there is no dotprod variant — the aarch64 branch always uses
/// the widening i16×i8 NEON kernel.
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i16_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i16],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i16_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Per-tensor zero points only; per-channel modes return NotSupported so
    // the caller can fall through to its f32 dequant path.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i16 path".into(),
            ))
        }
    };
    // NOTE(review): message says "i8 path" — the protos here are i8 but the
    // coeff message above says "i16 path"; confirm which wording is intended.
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Precompute per-pixel proto sums for zero-point correction.
    // When zp_c == 0 the correction vanishes, so the Vec stays empty and is
    // never indexed (guarded at every use site below).
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                // Accumulate plane-by-plane for sequential reads over the
                // channel-major storage.
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // One rayon task per detection; each task owns its own mask buffer.
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // Coefficients are stored row-major: detection i owns
            // coeff_all[i*num_protos .. (i+1)*num_protos].
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            // Per-detection bias: zp_p·Σc_raw - N·zp_c·zp_p
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                // SAFETY: baseline NEON is available on all
                                // aarch64; both slices hold num_protos elements.
                                let raw_dot = unsafe {
                                    dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
                                };
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback (x86 and others).
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major accumulation: for each channel, accumulate
                    // coeff[c] * proto[c, py, px] across the ROI. Each channel
                    // plane is contiguous, giving excellent sequential read access.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    // Apply zero-point correction and threshold.
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            // Wrap the binary mask as [roi_h, roi_w, 1] and express its bbox
            // in original-image-content normalised space.
            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1365
1366// =============================================================================
1367
1368// Sign-threshold proto-resolution kernels (f32/f16/i8 protos with f32 coeffs).
1369//
1370// These replace the sigmoid-computing kernels for the non-i8×i8 fallback paths.
1371// Since downstream always thresholds at > 127, computing sigmoid is wasteful;
1372// sign(dot) > 0 ⟺ sigmoid(dot) > 0.5 gives the same binary result.
1373// =============================================================================
1374
1375/// f32 protos × f32 coefficients → sign threshold → binary {0, 255}.
1376#[allow(clippy::too_many_arguments)]
1377fn fused_dot_sign_f32_slice(
1378    protos: &[f32],
1379    coeff: &[f32],
1380    _proto_h: usize,
1381    proto_w: usize,
1382    y0: usize,
1383    x0: usize,
1384    roi_h: usize,
1385    roi_w: usize,
1386    num_protos: usize,
1387) -> ndarray::Array3<u8> {
1388    let stride_y = proto_w * num_protos;
1389    let mut mask_buf = vec![0u8; roi_h * roi_w];
1390    for y in 0..roi_h {
1391        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1392        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1393        for (x, out_px) in out_row.iter_mut().enumerate() {
1394            let base = row_base + x * num_protos;
1395            let mut acc = 0.0_f32;
1396            let mut k = 0;
1397            let chunks = num_protos / 4;
1398            for _ in 0..chunks {
1399                acc += coeff[k] * protos[base + k]
1400                    + coeff[k + 1] * protos[base + k + 1]
1401                    + coeff[k + 2] * protos[base + k + 2]
1402                    + coeff[k + 3] * protos[base + k + 3];
1403                k += 4;
1404            }
1405            while k < num_protos {
1406                acc += coeff[k] * protos[base + k];
1407                k += 1;
1408            }
1409            if acc > 0.0 {
1410                *out_px = 255;
1411            }
1412        }
1413    }
1414    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1415        .expect("mask_buf length matches roi_h * roi_w")
1416}
1417
/// f16 protos × f32 coefficients → sign threshold → binary {0, 255}.
///
/// Two code paths:
///
/// 1. **x86_64 + F16C + FMA** — explicit intrinsic kernel using
///    `_mm256_cvtph_ps` (8-lane f16→f32 widening) + `_mm256_fmadd_ps`.
///
/// 2. **Scalar fallback** — loop-unrolled by 4 with `half::f16::to_f32()`.
///
/// Selection is made at compile time via `target_feature` cfg (no runtime
/// dispatch): the intrinsic kernel only exists when the build enables
/// f16c + fma, and exactly one of the two branches below is compiled in.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sign_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: target-feature gates guarantee F16C + FMA support.
        unsafe {
            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
1458
1459/// Scalar f16 sign-threshold kernel — loop-unrolled by 4.
1460#[allow(clippy::too_many_arguments)]
1461fn fused_dot_sign_f16_slice_scalar(
1462    protos: &[half::f16],
1463    coeff: &[f32],
1464    proto_w: usize,
1465    y0: usize,
1466    x0: usize,
1467    roi_h: usize,
1468    roi_w: usize,
1469    num_protos: usize,
1470) -> ndarray::Array3<u8> {
1471    let stride_y = proto_w * num_protos;
1472    let mut mask_buf = vec![0u8; roi_h * roi_w];
1473    for y in 0..roi_h {
1474        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1475        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1476        for (x, out_px) in out_row.iter_mut().enumerate() {
1477            let base = row_base + x * num_protos;
1478            let mut acc = 0.0_f32;
1479            let mut k = 0;
1480            let chunks = num_protos / 4;
1481            for _ in 0..chunks {
1482                acc += coeff[k] * protos[base + k].to_f32()
1483                    + coeff[k + 1] * protos[base + k + 1].to_f32()
1484                    + coeff[k + 2] * protos[base + k + 2].to_f32()
1485                    + coeff[k + 3] * protos[base + k + 3].to_f32();
1486                k += 4;
1487            }
1488            while k < num_protos {
1489                acc += coeff[k] * protos[base + k].to_f32();
1490                k += 1;
1491            }
1492            if acc > 0.0 {
1493                *out_px = 255;
1494            }
1495        }
1496    }
1497    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1498        .expect("mask_buf length matches roi_h * roi_w")
1499}
1500
/// x86_64 F16C + FMA intrinsic kernel for f16 sign-threshold.
///
/// Uses `_mm256_cvtph_ps` for hardware 8-lane f16→f32 widening and
/// `_mm256_fmadd_ps` for fused multiply-add. Only the sign of the
/// accumulated dot product is checked (no sigmoid needed).
///
/// Writes 255 into the output pixel when the dot product is strictly
/// positive; the buffer is zero-initialised, so everything else stays 0.
///
/// # Safety
///
/// Caller must ensure the target CPU supports F16C + FMA.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    // Full 8-lane SIMD iterations per pixel; remainder handled in scalar.
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Unaligned 128-bit load of 8 raw f16 values, then widen
                // to 8 f32 lanes in hardware.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                // acc_v += coeffs_v * widened (fused, one rounding).
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduce 8 → 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos % 8.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1577
1578/// i8 protos (with quant) × f32 coefficients → sign threshold → binary {0, 255}.
1579/// Fallback for per-channel quant or mixed-dtype cases where the i8×i8 fast path
1580/// doesn't apply.
1581#[allow(clippy::too_many_arguments)]
1582fn fused_dequant_dot_sign_i8_slice(
1583    protos: &[i8],
1584    coeff: &[f32],
1585    quant: &edgefirst_tensor::Quantization,
1586    _proto_h: usize,
1587    proto_w: usize,
1588    y0: usize,
1589    x0: usize,
1590    roi_h: usize,
1591    roi_w: usize,
1592    num_protos: usize,
1593) -> crate::Result<ndarray::Array3<u8>> {
1594    use edgefirst_tensor::QuantMode;
1595    let stride_y = proto_w * num_protos;
1596
1597    // Precompute scaled coefficients + zp_offset (same as the old sigmoid kernel).
1598    let mut stack_scratch = [0.0_f32; 64];
1599    let mut heap_scratch: Vec<f32>;
1600    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
1601        &mut stack_scratch[..num_protos]
1602    } else {
1603        heap_scratch = vec![0.0_f32; num_protos];
1604        heap_scratch.as_mut_slice()
1605    };
1606    let zp_offset: f32;
1607    match quant.mode() {
1608        QuantMode::PerTensorSymmetric { scale } => {
1609            for k in 0..num_protos {
1610                scaled_coeff[k] = coeff[k] * scale;
1611            }
1612            zp_offset = 0.0;
1613        }
1614        QuantMode::PerTensor { scale, zero_point } => {
1615            for k in 0..num_protos {
1616                scaled_coeff[k] = coeff[k] * scale;
1617            }
1618            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
1619        }
1620        QuantMode::PerChannelSymmetric { scales, axis } => {
1621            if axis != 2 {
1622                return Err(crate::Error::NotSupported(format!(
1623                    "per-channel quantization on axis {axis} not supported \
1624                     (only channel axis 2 is implemented on this kernel)"
1625                )));
1626            }
1627            for k in 0..num_protos {
1628                scaled_coeff[k] = coeff[k] * scales[k];
1629            }
1630            zp_offset = 0.0;
1631        }
1632        QuantMode::PerChannel {
1633            scales,
1634            zero_points,
1635            axis,
1636        } => {
1637            if axis != 2 {
1638                return Err(crate::Error::NotSupported(format!(
1639                    "per-channel quantization on axis {axis} not supported \
1640                     (only channel axis 2 is implemented on this kernel)"
1641                )));
1642            }
1643            for k in 0..num_protos {
1644                scaled_coeff[k] = coeff[k] * scales[k];
1645            }
1646            zp_offset = (0..num_protos)
1647                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
1648                .sum();
1649        }
1650    }
1651
1652    let mut mask_buf = vec![0u8; roi_h * roi_w];
1653    for y in 0..roi_h {
1654        let row_base = (y0 + y) * stride_y + (x0) * num_protos;
1655        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1656        for (x, out_px) in out_row.iter_mut().enumerate() {
1657            let base = row_base + x * num_protos;
1658            let mut acc = 0.0_f32;
1659            let mut k = 0;
1660            let chunks = num_protos / 4;
1661            for _ in 0..chunks {
1662                let p0 = protos[base + k] as f32;
1663                let p1 = protos[base + k + 1] as f32;
1664                let p2 = protos[base + k + 2] as f32;
1665                let p3 = protos[base + k + 3] as f32;
1666                acc += scaled_coeff[k] * p0
1667                    + scaled_coeff[k + 1] * p1
1668                    + scaled_coeff[k + 2] * p2
1669                    + scaled_coeff[k + 3] * p3;
1670                k += 4;
1671            }
1672            while k < num_protos {
1673                acc += scaled_coeff[k] * protos[base + k] as f32;
1674                k += 1;
1675            }
1676            if acc > zp_offset {
1677                *out_px = 255;
1678            }
1679        }
1680    }
1681    Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1682        .expect("mask_buf length matches roi_h * roi_w"))
1683}
1684
1685#[allow(clippy::too_many_arguments)]
1686fn scaled_segmentations_f32_slice(
1687    detect: &[crate::DetectBox],
1688    coeff_all: &[f32],
1689    protos: &[f32],
1690    proto_h: usize,
1691    proto_w: usize,
1692    num_protos: usize,
1693    letterbox: Option<[f32; 4]>,
1694    width: u32,
1695    height: u32,
1696) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1697    scaled_run(
1698        detect,
1699        coeff_all,
1700        protos,
1701        proto_h,
1702        proto_w,
1703        num_protos,
1704        letterbox,
1705        width,
1706        height,
1707        1.0,
1708        |p, _| *p,
1709    )
1710}
1711
1712#[allow(clippy::too_many_arguments)]
1713fn scaled_segmentations_f16_slice(
1714    detect: &[crate::DetectBox],
1715    coeff_all: &[f32],
1716    protos: &[half::f16],
1717    proto_h: usize,
1718    proto_w: usize,
1719    num_protos: usize,
1720    letterbox: Option<[f32; 4]>,
1721    width: u32,
1722    height: u32,
1723) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1724    scaled_run(
1725        detect,
1726        coeff_all,
1727        protos,
1728        proto_h,
1729        proto_w,
1730        num_protos,
1731        letterbox,
1732        width,
1733        height,
1734        1.0,
1735        |p: &half::f16, _| p.to_f32(),
1736    )
1737}
1738
1739#[allow(clippy::too_many_arguments)]
1740fn scaled_segmentations_i8_slice(
1741    detect: &[crate::DetectBox],
1742    coeff_all: &[f32],
1743    protos: &[i8],
1744    proto_h: usize,
1745    proto_w: usize,
1746    num_protos: usize,
1747    quant: &edgefirst_tensor::Quantization,
1748    letterbox: Option<[f32; 4]>,
1749    width: u32,
1750    height: u32,
1751) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1752    use edgefirst_tensor::QuantMode;
1753    // Only per-tensor quantization supported on the scaled-path CPU kernel
1754    // today. Per-channel fits naturally into a future extension (would need
1755    // per-channel scaled coefficients in scaled_run's dot-product precompute).
1756    let (scale, zp) = match quant.mode() {
1757        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1758        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1759        QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1760            return Err(crate::Error::NotSupported(format!(
1761                "per-channel quantization (axis={axis}) on scaled seg path \
1762                 not yet supported"
1763            )));
1764        }
1765    };
1766    scaled_run(
1767        detect,
1768        coeff_all,
1769        protos,
1770        proto_h,
1771        proto_w,
1772        num_protos,
1773        letterbox,
1774        width,
1775        height,
1776        scale,
1777        move |p: &i8, _| *p as f32 - zp,
1778    )
1779}
1780
1781// =============================================================================
1782// Integer-domain kernel: i8 coefficients × i8 protos → i32 → sign threshold.
1783//
1784// Eliminates all f32 conversion by working directly with raw quantized values.
1785// The math:
1786//   sign(dot(dequant(coeff), dequant(proto)))
1787//   = sign(Σ (c_raw - zp_c) · (p_raw - zp_p))
1788//   = sign(Σ c_raw·p_raw - zp_c·Σp_raw - zp_p·Σc_raw + N·zp_c·zp_p)
1789//   = sign(sdot(c_raw, p_raw) - zp_c·proto_sum[pixel] - bias_per_det)
1790//
1791// where bias_per_det = zp_p·Σc_raw - N·zp_c·zp_p  (precomputed once per det).
1792// =============================================================================
1793
1794/// Compute i8×i8 dot product (32 elements) → i32.
1795/// Platform-agnostic scalar fallback.
1796#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
1797#[inline(always)]
1798fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
1799    let mut acc: i32 = 0;
1800    let chunks = n / 4;
1801    let mut k = 0;
1802    for _ in 0..chunks {
1803        acc += coeff[k] as i32 * proto[k] as i32
1804            + coeff[k + 1] as i32 * proto[k + 1] as i32
1805            + coeff[k + 2] as i32 * proto[k + 2] as i32
1806            + coeff[k + 3] as i32 * proto[k + 3] as i32;
1807        k += 4;
1808    }
1809    while k < n {
1810        acc += coeff[k] as i32 * proto[k] as i32;
1811        k += 1;
1812    }
1813    acc
1814}
1815
/// NEON i8×i8→i32 dot product using smull+sadalp (works on ALL aarch64, A53+).
///
/// Processes 16 lanes per iteration (widening multiply to i16, then
/// pairwise-accumulate into i32), with an 8-lane step and a scalar loop
/// for the remainder.
///
/// # Safety
///
/// `coeff` and `proto` must each point to at least `n` readable `i8`
/// elements.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Widening multiply + pairwise accumulate (all aarch64).
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    // Handle remaining elements (for num_protos=32, full_chunks=2, remainder=0)
    let remainder = n - offset;
    if remainder >= 8 {
        // One 8-lane widening multiply for a remainder of 8..=15.
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Horizontal add of the 4 i32 partial sums, then scalar tail (< 8).
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1850
/// NEON i8×i8→i32 dot product using sdot (ARMv8.2-A dotprod, A55+).
/// Each `sdot` processes 16 i8 lanes → 4 i32 partial sums in one instruction,
/// replacing the 3-instruction smull+smull2+sadalp sequence.
///
/// # Safety
///
/// `coeff` and `proto` must each point to at least `n` readable `i8`
/// elements, and the CPU must support the dotprod extension (callers in
/// this file gate on `is_aarch64_feature_detected!("dotprod")`).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Enable dotprod extension locally so the assembler accepts sdot
        // even when compiling for baseline aarch64 (A53). At runtime we only
        // reach this path when HWCAP confirms dotprod support.
        let result: int32x4_t;
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    // Horizontal add of the 4 i32 partial sums.
    let mut scalar_acc = vaddvq_s32(acc);
    // Tail: handle remainder (unlikely for num_protos=32, but correct)
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1887
/// Compute the i16×i8 dot product of the first `n` elements → i32.
/// Platform-agnostic scalar fallback.
///
/// Panics if either slice is shorter than `n` (same contract as an
/// indexed loop over `0..n`).
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i16_i8_scalar(coeff: &[i16], proto: &[i8], n: usize) -> i32 {
    // Hoist the bounds checks with a pair of subslices, then let the
    // iterator pipeline do the widening multiply-accumulate. Integer
    // addition is associative, so any accumulation order is equivalent.
    let (c, p) = (&coeff[..n], &proto[..n]);
    c.iter()
        .zip(p)
        .map(|(&a, &b)| i32::from(a) * i32::from(b))
        .sum()
}
1908
/// NEON i16×i8→i32 dot product using widening multiply-accumulate.
/// Processes 8 elements per iteration (vs 16 for i8×i8 dotprod).
///
/// # Safety
///
/// `coeff` must point to at least `n` readable `i16` elements and `proto`
/// to at least `n` readable `i8` elements.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i16_i8_neon(coeff: *const i16, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 8;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s16(coeff.add(offset));
        // Sign-extend the 8 i8 proto lanes to i16 before multiplying.
        let p_raw = vld1_s8(proto.add(offset));
        let p = vmovl_s8(p_raw);
        // Widening multiply-accumulate: i16×i16 → i32 partial sums.
        acc = vmlal_s16(acc, vget_low_s16(c), vget_low_s16(p));
        acc = vmlal_high_s16(acc, c, p);
        offset += 8;
    }
    // Horizontal add of the 4 i32 partial sums, then scalar tail.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1933
/// Compute the logit grid using the dotprod (sdot) path.
/// Separated into its own function so the compiler inlines the sdot asm fully.
///
/// For every ROI pixel writes `raw_dot - zp_c*proto_sum - bias` into
/// `logits` (row-major `roi_h × roi_w`), where `raw_dot` is the raw i8×i8
/// dot of `coeff` against the pixel's contiguous NHWC proto vector. See the
/// integer-domain derivation in the module comment above for the algebra.
/// `proto_sums` is only read when `zp_c != 0` (callers pass it empty when
/// the coefficient zero point is zero).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            // NHWC: one contiguous num_protos vector per pixel.
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: both pointers cover exactly num_protos elements;
            // callers reach this path only after runtime dotprod detection.
            let raw_dot =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // Zero-point correction only needed when zp_c != 0.
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
1971
/// Compute the logit grid using the base NEON path (smull+sadalp).
/// Separated into its own function so the compiler inlines the NEON code fully.
///
/// Identical math to `compute_logits_dotprod` — writes
/// `raw_dot - zp_c*proto_sum - bias` per ROI pixel — but uses the
/// baseline-aarch64 kernel for CPUs without the dotprod extension.
/// `proto_sums` is only read when `zp_c != 0` (callers pass it empty when
/// the coefficient zero point is zero).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            // NHWC: one contiguous num_protos vector per pixel.
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: both pointers cover exactly num_protos elements.
            let raw_dot =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // Zero-point correction only needed when zp_c != 0.
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
2009
2010#[allow(clippy::too_many_arguments)]
2011fn scaled_segmentations_i8_i8(
2012    detect: &[crate::DetectBox],
2013    coeff_all: &[i8],
2014    coeff_quant: &edgefirst_tensor::Quantization,
2015    protos: &[i8],
2016    proto_quant: &edgefirst_tensor::Quantization,
2017    proto_h: usize,
2018    proto_w: usize,
2019    num_protos: usize,
2020    letterbox: Option<[f32; 4]>,
2021    width: u32,
2022    height: u32,
2023    layout: edgefirst_decoder::ProtoLayout,
2024) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2025    use edgefirst_tensor::QuantMode;
2026
2027    let _span = tracing::trace_span!(
2028        "mask_i8_fastpath",
2029        n = detect.len(),
2030        proto_h,
2031        proto_w,
2032        num_protos,
2033        width,
2034        height,
2035        ?layout,
2036    )
2037    .entered();
2038
2039    let zp_c: i32 = match coeff_quant.mode() {
2040        QuantMode::PerTensor { zero_point, .. } => zero_point,
2041        QuantMode::PerTensorSymmetric { .. } => 0,
2042        _ => {
2043            return Err(crate::Error::NotSupported(
2044                "per-channel coeff quantization not supported".into(),
2045            ))
2046        }
2047    };
2048    let zp_p: i32 = match proto_quant.mode() {
2049        QuantMode::PerTensor { zero_point, .. } => zero_point,
2050        QuantMode::PerTensorSymmetric { .. } => 0,
2051        _ => {
2052            return Err(crate::Error::NotSupported(
2053                "per-channel proto quantization not supported".into(),
2054            ))
2055        }
2056    };
2057
2058    let (lx0, lw, ly0, lh) = match letterbox {
2059        Some([lx0, ly0, lx1, ly1]) => {
2060            let lw = (lx1 - lx0).max(f32::EPSILON);
2061            let lh = (ly1 - ly0).max(f32::EPSILON);
2062            (lx0, lw, ly0, lh)
2063        }
2064        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2065    };
2066    let out_w = width as usize;
2067    let out_h = height as usize;
2068    let hw = proto_h * proto_w;
2069
2070    // Precompute proto_sum for the entire proto tensor (zero-point correction).
2071    let proto_sums: Vec<i32> = if zp_c != 0 {
2072        match layout {
2073            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
2074                .map(|px_idx| {
2075                    let base = px_idx * num_protos;
2076                    let mut s: i32 = 0;
2077                    for k in 0..num_protos {
2078                        s += protos[base + k] as i32;
2079                    }
2080                    s
2081                })
2082                .collect(),
2083            edgefirst_decoder::ProtoLayout::Nchw => {
2084                let mut sums = vec![0i32; hw];
2085                for c in 0..num_protos {
2086                    let plane = &protos[c * hw..];
2087                    for (px, s) in sums.iter_mut().enumerate() {
2088                        *s += plane[px] as i32;
2089                    }
2090                }
2091                sums
2092            }
2093        }
2094    } else {
2095        Vec::new()
2096    };
2097
2098    // Detect dotprod support once, outside the hot loop.
2099    #[cfg(target_arch = "aarch64")]
2100    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");
2101
2102    // For NHWC layout, stride for row navigation.
2103    let stride_y = proto_w * num_protos;
2104
2105    detect
2106        .par_iter()
2107        .enumerate()
2108        .map(|(i, det)| {
2109            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2110            let bbox = det.bbox.to_canonical();
2111            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2112            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2113            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2114            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2115            let px0 = (xmin * out_w as f32).round() as usize;
2116            let py0 = (ymin * out_h as f32).round() as usize;
2117            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2118            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2119            let bbox_w = px1.saturating_sub(px0).max(1);
2120            let bbox_h = py1.saturating_sub(py0).max(1);
2121
2122            // Map output bbox → proto ROI.
2123            let sample_x_at = |px: f32| -> f32 {
2124                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2125                model_x_norm * proto_w as f32 - 0.5
2126            };
2127            let sample_y_at = |py: f32| -> f32 {
2128                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2129                model_y_norm * proto_h as f32 - 0.5
2130            };
2131            let s_x_min = sample_x_at(px0 as f32);
2132            let s_x_max = sample_x_at((px1 as f32) - 1.0);
2133            let s_y_min = sample_y_at(py0 as f32);
2134            let s_y_max = sample_y_at((py1 as f32) - 1.0);
2135            let proto_x0 = (s_x_min.floor() as isize)
2136                .max(0)
2137                .min(proto_w.saturating_sub(1) as isize) as usize;
2138            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2139            let proto_y0 = (s_y_min.floor() as isize)
2140                .max(0)
2141                .min(proto_h.saturating_sub(1) as isize) as usize;
2142            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2143            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2144            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2145
2146            // Per-detection bias.
2147            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
2148            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
2149
2150            // Step 2: Compute i32 logits at each proto-ROI pixel.
2151            let mut logits = vec![0_i32; roi_h * roi_w];
2152            match layout {
2153                edgefirst_decoder::ProtoLayout::Nhwc => {
2154                    #[cfg(target_arch = "aarch64")]
2155                    {
2156                        if use_dotprod {
2157                            compute_logits_dotprod(
2158                                &mut logits,
2159                                coeff,
2160                                protos,
2161                                &proto_sums,
2162                                proto_w,
2163                                proto_x0,
2164                                proto_y0,
2165                                roi_w,
2166                                roi_h,
2167                                stride_y,
2168                                num_protos,
2169                                zp_c,
2170                                bias,
2171                            );
2172                        } else {
2173                            compute_logits_base(
2174                                &mut logits,
2175                                coeff,
2176                                protos,
2177                                &proto_sums,
2178                                proto_w,
2179                                proto_x0,
2180                                proto_y0,
2181                                roi_w,
2182                                roi_h,
2183                                stride_y,
2184                                num_protos,
2185                                zp_c,
2186                                bias,
2187                            );
2188                        }
2189                    }
2190                    #[cfg(not(target_arch = "aarch64"))]
2191                    {
2192                        for ly_idx in 0..roi_h {
2193                            let py = proto_y0 + ly_idx;
2194                            let row_base = py * stride_y + proto_x0 * num_protos;
2195                            for lx_idx in 0..roi_w {
2196                                let pix_base = row_base + lx_idx * num_protos;
2197                                let proto_px = &protos[pix_base..pix_base + num_protos];
2198                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
2199                                let correction = if zp_c != 0 {
2200                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2201                                } else {
2202                                    0
2203                                };
2204                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2205                            }
2206                        }
2207                    }
2208                }
2209                edgefirst_decoder::ProtoLayout::Nchw => {
2210                    // Channel-major accumulation: contiguous reads per channel plane.
2211                    for c in 0..num_protos {
2212                        let plane = &protos[c * hw..];
2213                        let coeff_c = coeff[c] as i32;
2214                        for ly_idx in 0..roi_h {
2215                            let py = proto_y0 + ly_idx;
2216                            let row_start = py * proto_w + proto_x0;
2217                            let out_row_start = ly_idx * roi_w;
2218                            for lx_idx in 0..roi_w {
2219                                logits[out_row_start + lx_idx] +=
2220                                    coeff_c * plane[row_start + lx_idx] as i32;
2221                            }
2222                        }
2223                    }
2224                    // Apply zero-point correction and per-detection bias.
2225                    for ly_idx in 0..roi_h {
2226                        let py = proto_y0 + ly_idx;
2227                        for lx_idx in 0..roi_w {
2228                            let idx = ly_idx * roi_w + lx_idx;
2229                            let correction = if zp_c != 0 {
2230                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2231                            } else {
2232                                0
2233                            };
2234                            logits[idx] -= correction + bias;
2235                        }
2236                    }
2237                }
2238            }
2239
2240            // Step 3: Bilinear upsample i32 logits → binary mask with
2241            // sign-shortcut (skip interpolation when all 4 neighbors agree).
2242            let roi_last_x = roi_w.saturating_sub(1);
2243            let roi_last_y = roi_h.saturating_sub(1);
2244
2245            // X-coordinate LUT with fixed-point fraction (scale 1024).
2246            const FRAC_BITS: i32 = 10;
2247            const FRAC_SCALE: i32 = 1 << FRAC_BITS; // 1024
2248            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
2249                .map(|xi| {
2250                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2251                    let x_floor = sample_x.floor();
2252                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
2253                    let x_hi = (x_lo + 1).min(roi_w - 1);
2254                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2255                    (x_lo, x_hi, x_frac)
2256                })
2257                .collect();
2258
2259            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2260            for yi in 0..bbox_h {
2261                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2262                let y_floor = sample_y.floor();
2263                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2264                let y_hi = (y_lo + 1).min(roi_h - 1);
2265                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2266                let y_frac_inv = FRAC_SCALE - y_frac;
2267                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2268                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2269                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2270
2271                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2272                    let tl = row_lo[x_lo];
2273                    let tr = row_lo[x_hi];
2274                    let bl = row_hi[x_lo];
2275                    let br = row_hi[x_hi];
2276
2277                    // Sign-shortcut: if all 4 corners have the same sign,
2278                    // the bilinear interpolation (positive-weight combination)
2279                    // preserves that sign. Skip arithmetic for ~80% of pixels.
2280                    if (tl & tr & bl & br) < 0 {
2281                        // All negative → output 0 (already zero).
2282                        continue;
2283                    }
2284                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
2285                        // All strictly positive → output 255.
2286                        out_row[xi] = 255;
2287                        continue;
2288                    }
2289
2290                    // Boundary pixel: fixed-point bilinear in i64.
2291                    let x_frac_inv = FRAC_SCALE - x_frac;
2292                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
2293                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
2294                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
2295                    out_row[xi] = if logit > 0 { 255 } else { 0 };
2296                }
2297            }
2298
2299            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2300                .expect("tile_buf length matches bbox_h * bbox_w");
2301            Ok(edgefirst_decoder::Segmentation {
2302                xmin,
2303                ymin,
2304                xmax,
2305                ymax,
2306                segmentation: tile,
2307            })
2308        })
2309        .collect()
2310}
2311
/// Integer fast path for scaled mask materialization: i16 mask
/// coefficients × i8 proto planes, evaluated entirely in fixed-point.
///
/// Because the final per-pixel decision is a sign test (`logit > 0`), the
/// positive quantization scales can be dropped; only the zero points need
/// correcting. Per detection this (1) maps the output bbox to a proto-plane
/// ROI, (2) computes one i32 dot product per ROI pixel (NEON on aarch64,
/// scalar elsewhere), and (3) bilinearly upsamples the i32 logits into a
/// `{0, 255}` mask tile of shape `[bbox_h, bbox_w, 1]`.
///
/// # Errors
/// Returns `Error::NotSupported` when either tensor uses a quantization
/// mode other than per-tensor / per-tensor-symmetric.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_i16_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i16],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i16_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        width,
        height,
        ?layout,
    )
    .entered();

    // Extract zero points; symmetric quantization means zp = 0.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported".into(),
            ))
        }
    };

    // Letterbox window in normalized coordinates; identity when absent.
    // Clamped to EPSILON so the divisions below cannot produce inf/NaN.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let hw = proto_h * proto_w;

    // Precompute proto_sum for the entire proto tensor (zero-point correction).
    // proto_sums[py * proto_w + px] = Σ_k protos[py, px, k]; shared read-only
    // across all detections. Skipped entirely when zp_c == 0.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    let mut s: i32 = 0;
                    for k in 0..num_protos {
                        s += protos[base + k] as i32;
                    }
                    s
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // For NHWC layout, stride for row navigation.
    let stride_y = proto_w * num_protos;

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // Each detection owns one row of `num_protos` coefficients.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let bbox = det.bbox.to_canonical();
            // Undo the letterbox so the bbox is normalized to the visible image.
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            // Degenerate boxes still produce a 1×1 tile.
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Map output bbox → proto ROI.
            // Pixel-center sampling: output pixel `px` maps back through the
            // letterbox into proto coordinates (the -0.5 aligns centers).
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            // Floor the min and ceil+1 the max so both bilinear neighbours
            // fall inside the ROI; start indices clamp to the last valid
            // index, end indices to the exclusive bound.
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Per-detection bias, from expanding the dequantized product
            //   Σ (c - zp_c)(p - zp_p)
            //     = Σ c·p  -  zp_c·Σp  -  (zp_p·Σc - K·zp_c·zp_p).
            // The zp_c·Σp term is per-pixel (via `proto_sums`); this part is
            // constant per detection.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            // Step 2: Compute i32 logits at each proto-ROI pixel.
            let mut logits = vec![0_i32; roi_h * roi_w];
            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    #[cfg(target_arch = "aarch64")]
                    {
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                // SAFETY: `coeff` and `proto_px` are each
                                // sliced to exactly `num_protos` elements, so
                                // both pointers are valid for `num_protos`
                                // contiguous reads.
                                let raw_dot = unsafe {
                                    dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
                                };
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Scalar mirror of the NEON loop above.
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major accumulation: contiguous reads per channel plane.
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_start = py * proto_w + proto_x0;
                            let out_row_start = ly_idx * roi_w;
                            for lx_idx in 0..roi_w {
                                logits[out_row_start + lx_idx] +=
                                    coeff_c * plane[row_start + lx_idx] as i32;
                            }
                        }
                    }
                    // Apply zero-point correction and per-detection bias.
                    for ly_idx in 0..roi_h {
                        let py = proto_y0 + ly_idx;
                        for lx_idx in 0..roi_w {
                            let idx = ly_idx * roi_w + lx_idx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                            } else {
                                0
                            };
                            logits[idx] -= correction + bias;
                        }
                    }
                }
            }

            // Step 3: Bilinear upsample i32 logits → binary mask with
            // sign-shortcut (skip interpolation when all 4 neighbors agree).
            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // X-coordinate LUT with fixed-point fraction (scale 1024).
            const FRAC_BITS: i32 = 10;
            const FRAC_SCALE: i32 = 1 << FRAC_BITS; // 1024
            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                let y_frac_inv = FRAC_SCALE - y_frac;
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];

                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let tl = row_lo[x_lo];
                    let tr = row_lo[x_hi];
                    let bl = row_hi[x_lo];
                    let br = row_hi[x_hi];

                    // Sign-shortcut: if all 4 corners have the same sign,
                    // the bilinear interpolation (positive-weight combination)
                    // preserves that sign. Skip arithmetic for ~80% of pixels.
                    // Bitwise AND keeps the i32 sign bit only when all four
                    // sign bits are set, i.e. all four values are negative.
                    if (tl & tr & bl & br) < 0 {
                        // All negative → output 0 (already zero).
                        continue;
                    }
                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
                        // All strictly positive → output 255.
                        out_row[xi] = 255;
                        continue;
                    }

                    // Boundary pixel: fixed-point bilinear in i64.
                    // i64 avoids overflow: i32 logit × 1024 × 1024 fits.
                    let x_frac_inv = FRAC_SCALE - x_frac;
                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
                    out_row[xi] = if logit > 0 { 255 } else { 0 };
                }
            }

            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
2593
/// Generic f32 scaled-mask path: renders each detection's binary mask tile
/// at output resolution `width`×`height` from f32 coefficients and proto
/// elements of type `P` read through `load_f32`. The indexing
/// (`stride_y = proto_w * num_protos`, channel-fastest per pixel) assumes
/// an HWC proto layout.
///
/// `acc_scale` must be finite and strictly positive: the final decision is
/// a sign threshold, so a positive scale factor can be skipped entirely
/// (validated per detection inside the parallel loop).
///
/// # Errors
/// Returns `Error::NotSupported` when `acc_scale` is non-finite or ≤ 0
/// (only reported when at least one detection is processed).
#[allow(clippy::too_many_arguments)]
fn scaled_run<P: Copy + Sync>(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[P],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    acc_scale: f32,
    load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    // Letterbox window in normalized coordinates; identity when absent.
    // EPSILON floor keeps the divisions below finite.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let stride_y = proto_w * num_protos;

    // Parallelise across detections. Each detection produces an
    // independent ndarray::Array3<u8> tile from a read-only proto slice +
    // its own coeff slice; no shared mutable state.
    //
    // Algorithm (restores the spirit of PR #54's batched-GEMM optimisation
    // that PR #51's f16 dispatch refactor inadvertently removed):
    //
    //   1. Map the output bbox back to a proto-plane ROI (with 1-px margin
    //      so the bilinear sampling at the output edges has neighbours).
    //   2. Precompute *f32 logits* at every proto pixel inside that ROI by
    //      doing a single K-wide dot product per proto pixel — once, not
    //      once per output pixel.
    //   3. For each output pixel, bilinear-interpolate the scalar f32 logit
    //      from the 4 surrounding proto-roi pixels, apply sigmoid, and
    //      threshold to {0, 255}.
    //
    // For typical YOLO-seg: proto_roi ~ 30×30 = 900 px × K=32 = 28.8K dot
    // ops vs the legacy "bilinear sample then dot at every output pixel"
    // which costs bbox_h × bbox_w × 4 × K = ~1.3M ops at 100×100 output
    // bbox. ~45× fewer FMAs at this size; the bilinear upsample of a
    // scalar plane (no inner K loop) is comparatively negligible.
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // One row of `num_protos` coefficients per detection.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let bbox = det.bbox.to_canonical();
            // Undo the letterbox so the bbox is normalized to the visible image.
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            // Degenerate boxes still produce a 1×1 tile.
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Step 1 — proto-plane ROI for this detection's output bbox.
            // Map the four output bbox corners back to proto coords and
            // expand by 1 pixel in each direction so the bilinear sampler
            // at the bbox boundary has both neighbours.
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            // Floor min, ceil max+1 to include both bilinear neighbours.
            // Start indices are used as direct bases into `protos`, so clamp
            // them to the last valid index, not to the exclusive upper bound.
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Step 2 — precompute f32 logits at every proto-roi pixel.
            // logits[(py - proto_y0) * roi_w + (px - proto_x0)] = dot(coeff, proto[py, px, :])
            //
            // Since the final threshold is `logit > 0` (O1) and bilinear
            // interpolation is a positive-weight linear combination,
            // `acc_scale * interp(logits) > 0 ⟺ interp(logits) > 0` when
            // acc_scale > 0. We therefore skip the per-pixel `acc_scale *`
            // multiply entirely, storing raw dot products.
            if !acc_scale.is_finite() || acc_scale <= 0.0 {
                return Err(crate::Error::NotSupported(format!(
                    "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
                )));
            }
            let _ = acc_scale; // Scale-invariant: only sign matters.
            let mut logits = vec![0.0_f32; roi_h * roi_w];
            for ly_idx in 0..roi_h {
                let py = proto_y0 + ly_idx;
                let row_base = py * stride_y + proto_x0 * num_protos;
                for lx_idx in 0..roi_w {
                    let pix_base = row_base + lx_idx * num_protos;
                    let mut acc = 0.0_f32;
                    // 4-wide unroll to help auto-vectorization.
                    let mut k = 0;
                    let chunks = num_protos / 4;
                    for _ in 0..chunks {
                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
                            + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
                            + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
                            + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
                        k += 4;
                    }
                    // Tail loop for num_protos not divisible by 4.
                    while k < num_protos {
                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
                        k += 1;
                    }
                    logits[ly_idx * roi_w + lx_idx] = acc;
                }
            }

            // Step 3 — bilinear upsample logits → binary mask.
            //
            // O1: sigmoid(x) > 0.5 ⟺ x > 0 (sigmoid is strictly monotonic,
            // and acc_scale > 0 preserves sign). The sign threshold replaces
            // the old fast_sigmoid approximation, saving ~15 cycles/pixel.
            //
            // O5: Pre-compute bilinear sample coordinates. sample_x_at /
            // sample_y_at depend only on pixel index, not on logit values.
            // Building lookup tables avoids redundant float ops in the inner
            // loop (floor, clamp, isize cast per pixel).
            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // X-coordinate LUT (shared across all rows).
            let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
                    let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
                    let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            // Write the output tile through a contiguous slice to avoid
            // ndarray's per-element bounds checks + stride arithmetic.
            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let (xl, xh) = (x_lo as usize, x_hi as usize);
                    let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
                    let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
                    let logit = l0 + (l1 - l0) * y_frac;
                    out_row[xi] = if logit > 0.0 { 255 } else { 0 };
                }
            }
            // Wrap into the expected Array3<u8> shape [bbox_h, bbox_w, 1].
            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
2785
2786#[cfg(test)]
2787mod tests {
2788    use super::CPUProcessor;
2789    use edgefirst_decoder::{BoundingBox, DetectBox, ProtoData, ProtoLayout};
2790    use edgefirst_tensor::{Quantization, Tensor, TensorDyn};
2791
    // Small 4×4×8 proto grid keeps the tests fast while still exercising
    // the ROI-mapping and bilinear-upsample paths.
    const PROTO_H: usize = 4;
    const PROTO_W: usize = 4;
    const NUM_PROTOS: usize = 8;
2795
2796    fn det(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> DetectBox {
2797        DetectBox {
2798            bbox: BoundingBox {
2799                xmin,
2800                ymin,
2801                xmax,
2802                ymax,
2803            },
2804            score: 0.9,
2805            label: 0,
2806        }
2807    }
2808
2809    fn make_i8_quant(shape: &[usize], data: &[i8], scale: f32, zp: i32) -> TensorDyn {
2810        let t = Tensor::<i8>::from_slice(data, shape).unwrap();
2811        let t = t
2812            .with_quantization(Quantization::per_tensor(scale, zp))
2813            .unwrap();
2814        TensorDyn::I8(t)
2815    }
2816
2817    fn make_i16_quant(shape: &[usize], data: &[i16], scale: f32, zp: i32) -> TensorDyn {
2818        let t = Tensor::<i16>::from_slice(data, shape).unwrap();
2819        let t = t
2820            .with_quantization(Quantization::per_tensor(scale, zp))
2821            .unwrap();
2822        TensorDyn::I16(t)
2823    }
2824
2825    fn make_i16_raw(shape: &[usize], data: &[i16]) -> TensorDyn {
2826        let t = Tensor::<i16>::from_slice(data, shape).unwrap();
2827        TensorDyn::I16(t)
2828    }
2829
2830    fn make_f32(shape: &[usize], data: &[f32]) -> TensorDyn {
2831        let t = Tensor::<f32>::from_slice(data, shape).unwrap();
2832        TensorDyn::F32(t)
2833    }
2834
2835    fn gen_protos_i8(h: usize, w: usize, k: usize) -> Vec<i8> {
2836        (0..h * w * k).map(|i| (i % 127) as i8).collect()
2837    }
2838
2839    fn gen_coeffs_i16(n: usize, k: usize) -> Vec<i16> {
2840        (0..n * k)
2841            .map(|i| ((i as i32 % 201) - 100) as i16)
2842            .collect()
2843    }
2844
2845    fn gen_coeffs_i8(n: usize, k: usize) -> Vec<i8> {
2846        (0..n * k).map(|i| ((i as i32 % 201) - 100) as i8).collect()
2847    }
2848
2849    // ── Proto-resolution: i16×i8 fast path (quantized) ─────────────
2850
2851    #[test]
2852    fn materialize_proto_i16_i8_quant_produces_masks() {
2853        let cpu = CPUProcessor::new();
2854        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
2855        let protos = make_i8_quant(
2856            &[PROTO_H, PROTO_W, NUM_PROTOS],
2857            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2858            0.02,
2859            0,
2860        );
2861        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 0);
2862        let proto_data = ProtoData {
2863            mask_coefficients: coeffs,
2864            protos,
2865            layout: ProtoLayout::Nhwc,
2866        };
2867        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
2868        assert!(result.is_ok(), "materialize failed: {:?}", result.err());
2869        let segs = result.unwrap();
2870        assert_eq!(segs.len(), 1);
2871        let seg = &segs[0];
2872        assert!(seg.segmentation.shape()[0] > 0);
2873        assert!(seg.segmentation.shape()[1] > 0);
2874    }
2875
2876    // ── Proto-resolution: i16 missing quant → f32 fallback ─────────
2877
2878    #[test]
2879    fn materialize_proto_i16_no_quant_falls_back_to_f32() {
2880        let cpu = CPUProcessor::new();
2881        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
2882        let protos = make_i8_quant(
2883            &[PROTO_H, PROTO_W, NUM_PROTOS],
2884            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2885            0.02,
2886            0,
2887        );
2888        // I16 coefficients WITHOUT quantization — fast path should be
2889        // skipped, f32 fallback should widen raw i16 values.
2890        let coeffs = make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS));
2891        let proto_data = ProtoData {
2892            mask_coefficients: coeffs,
2893            protos,
2894            layout: ProtoLayout::Nhwc,
2895        };
2896        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
2897        assert!(
2898            result.is_ok(),
2899            "missing coeff quant should fall back to f32 path, got: {:?}",
2900            result.err()
2901        );
2902        assert_eq!(result.unwrap().len(), 1);
2903    }
2904
2905    // ── Scaled: i16×i8 fast path (quantized) ───────────────────────
2906
2907    #[test]
2908    fn materialize_scaled_i16_i8_quant_produces_masks() {
2909        let cpu = CPUProcessor::new();
2910        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
2911        let protos = make_i8_quant(
2912            &[PROTO_H, PROTO_W, NUM_PROTOS],
2913            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2914            0.02,
2915            0,
2916        );
2917        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 0);
2918        let proto_data = ProtoData {
2919            mask_coefficients: coeffs,
2920            protos,
2921            layout: ProtoLayout::Nhwc,
2922        };
2923        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
2924        assert!(
2925            result.is_ok(),
2926            "materialize_scaled failed: {:?}",
2927            result.err()
2928        );
2929        let segs = result.unwrap();
2930        assert_eq!(segs.len(), 1);
2931        let seg = &segs[0];
2932        assert!(seg.segmentation.shape()[0] > 0);
2933        assert!(seg.segmentation.shape()[1] > 0);
2934    }
2935
2936    // ── Scaled: i16 missing quant → f32 fallback ───────────────────
2937
2938    #[test]
2939    fn materialize_scaled_i16_no_quant_falls_back_to_f32() {
2940        let cpu = CPUProcessor::new();
2941        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
2942        let protos = make_i8_quant(
2943            &[PROTO_H, PROTO_W, NUM_PROTOS],
2944            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
2945            0.02,
2946            0,
2947        );
2948        let coeffs = make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS));
2949        let proto_data = ProtoData {
2950            mask_coefficients: coeffs,
2951            protos,
2952            layout: ProtoLayout::Nhwc,
2953        };
2954        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
2955        assert!(
2956            result.is_ok(),
2957            "missing coeff quant should fall back to f32 path, got: {:?}",
2958            result.err()
2959        );
2960        assert_eq!(result.unwrap().len(), 1);
2961    }
2962
2963    // ── i16×i8 parity with f32 reference ───────────────────────────
2964
2965    #[test]
2966    fn materialize_proto_i16_i8_matches_f32_reference() {
2967        let cpu = CPUProcessor::new();
2968        let detect = vec![det(0.1, 0.1, 0.9, 0.9), det(0.3, 0.3, 0.7, 0.7)];
2969        let n_det = detect.len();
2970        let scale_c = 0.01_f32;
2971        let scale_p = 0.02_f32;
2972        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
2973        let raw_coeffs = gen_coeffs_i16(n_det, NUM_PROTOS);
2974
2975        // Build f32 reference (dequantized manually).
2976        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
2977        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
2978        let proto_data_f32 = ProtoData {
2979            mask_coefficients: make_f32(&[n_det, NUM_PROTOS], &coeffs_f32),
2980            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
2981            layout: ProtoLayout::Nhwc,
2982        };
2983
2984        let proto_data_int = ProtoData {
2985            mask_coefficients: make_i16_quant(&[n_det, NUM_PROTOS], &raw_coeffs, scale_c, 0),
2986            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
2987            layout: ProtoLayout::Nhwc,
2988        };
2989
2990        let segs_f32 = cpu
2991            .materialize_segmentations(&detect, &proto_data_f32, None)
2992            .unwrap();
2993        let segs_int = cpu
2994            .materialize_segmentations(&detect, &proto_data_int, None)
2995            .unwrap();
2996
2997        assert_eq!(segs_f32.len(), segs_int.len());
2998        for (sf, si) in segs_f32.iter().zip(segs_int.iter()) {
2999            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
3000            let total = sf.segmentation.len();
3001            let mismatches = sf
3002                .segmentation
3003                .iter()
3004                .zip(si.segmentation.iter())
3005                .filter(|(a, b)| a != b)
3006                .count();
3007            let pct = mismatches as f64 / total as f64 * 100.0;
3008            assert!(
3009                pct < 5.0,
3010                "mask mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
3011            );
3012        }
3013    }
3014
3015    // ── Multiple detections ────────────────────────────────────────
3016
3017    #[test]
3018    fn materialize_proto_i16_multiple_detections() {
3019        let cpu = CPUProcessor::new();
3020        let detect = vec![
3021            det(0.0, 0.0, 0.5, 0.5),
3022            det(0.5, 0.5, 1.0, 1.0),
3023            det(0.1, 0.1, 0.3, 0.3),
3024        ];
3025        let protos = make_i8_quant(
3026            &[PROTO_H, PROTO_W, NUM_PROTOS],
3027            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3028            0.02,
3029            0,
3030        );
3031        let coeffs = make_i16_quant(&[3, NUM_PROTOS], &gen_coeffs_i16(3, NUM_PROTOS), 0.01, 0);
3032        let proto_data = ProtoData {
3033            mask_coefficients: coeffs,
3034            protos,
3035            layout: ProtoLayout::Nhwc,
3036        };
3037        let segs = cpu
3038            .materialize_segmentations(&detect, &proto_data, None)
3039            .unwrap();
3040        assert_eq!(segs.len(), 3);
3041    }
3042
3043    // ── Empty detections ───────────────────────────────────────────
3044
3045    #[test]
3046    fn materialize_proto_i16_empty_detections() {
3047        let cpu = CPUProcessor::new();
3048        let detect: Vec<DetectBox> = vec![];
3049        let protos = make_i8_quant(
3050            &[PROTO_H, PROTO_W, NUM_PROTOS],
3051            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3052            0.02,
3053            0,
3054        );
3055        let coeffs = make_i16_quant(&[0, NUM_PROTOS], &[], 0.01, 0);
3056        let proto_data = ProtoData {
3057            mask_coefficients: coeffs,
3058            protos,
3059            layout: ProtoLayout::Nhwc,
3060        };
3061        let segs = cpu
3062            .materialize_segmentations(&detect, &proto_data, None)
3063            .unwrap();
3064        assert!(segs.is_empty());
3065    }
3066
3067    // ── Scaled parity ──────────────────────────────────────────────
3068
3069    #[test]
3070    fn materialize_scaled_i16_i8_matches_f32_reference() {
3071        let cpu = CPUProcessor::new();
3072        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3073        let scale_c = 0.01_f32;
3074        let scale_p = 0.02_f32;
3075        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
3076        let raw_coeffs = gen_coeffs_i16(1, NUM_PROTOS);
3077
3078        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
3079        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
3080        let proto_data_f32 = ProtoData {
3081            mask_coefficients: make_f32(&[1, NUM_PROTOS], &coeffs_f32),
3082            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
3083            layout: ProtoLayout::Nhwc,
3084        };
3085        let proto_data_int = ProtoData {
3086            mask_coefficients: make_i16_quant(&[1, NUM_PROTOS], &raw_coeffs, scale_c, 0),
3087            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
3088            layout: ProtoLayout::Nhwc,
3089        };
3090
3091        let (w, h) = (64_u32, 64_u32);
3092        let segs_f32 = cpu
3093            .materialize_scaled_segmentations(&detect, &proto_data_f32, None, w, h)
3094            .unwrap();
3095        let segs_int = cpu
3096            .materialize_scaled_segmentations(&detect, &proto_data_int, None, w, h)
3097            .unwrap();
3098
3099        assert_eq!(segs_f32.len(), segs_int.len());
3100        for (sf, si) in segs_f32.iter().zip(segs_int.iter()) {
3101            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
3102            let total = sf.segmentation.len();
3103            let mismatches = sf
3104                .segmentation
3105                .iter()
3106                .zip(si.segmentation.iter())
3107                .filter(|(a, b)| a != b)
3108                .count();
3109            let pct = mismatches as f64 / total as f64 * 100.0;
3110            assert!(
3111                pct < 5.0,
3112                "scaled mask mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
3113            );
3114        }
3115    }
3116
3117    // ── i8×i8 existing path still works (regression) ───────────────
3118
3119    #[test]
3120    fn materialize_proto_i8_i8_regression() {
3121        let cpu = CPUProcessor::new();
3122        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3123        let protos = make_i8_quant(
3124            &[PROTO_H, PROTO_W, NUM_PROTOS],
3125            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3126            0.02,
3127            0,
3128        );
3129        let coeffs = make_i8_quant(&[1, NUM_PROTOS], &gen_coeffs_i8(1, NUM_PROTOS), 0.01, 0);
3130        let proto_data = ProtoData {
3131            mask_coefficients: coeffs,
3132            protos,
3133            layout: ProtoLayout::Nhwc,
3134        };
3135        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
3136        assert!(result.is_ok(), "i8×i8 regression: {:?}", result.err());
3137        assert_eq!(result.unwrap().len(), 1);
3138    }
3139
3140    // ── Non-zero zero_point ────────────────────────────────────────
3141
3142    #[test]
3143    fn materialize_proto_i16_nonzero_zp() {
3144        let cpu = CPUProcessor::new();
3145        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3146        let protos = make_i8_quant(
3147            &[PROTO_H, PROTO_W, NUM_PROTOS],
3148            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3149            0.02,
3150            -10,
3151        );
3152        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 5);
3153        let proto_data = ProtoData {
3154            mask_coefficients: coeffs,
3155            protos,
3156            layout: ProtoLayout::Nhwc,
3157        };
3158        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
3159        assert!(result.is_ok(), "nonzero zp failed: {:?}", result.err());
3160        assert_eq!(result.unwrap().len(), 1);
3161    }
3162
3163    // ── Scaled: non-zero zero_point ────────────────────────────────
3164
3165    #[test]
3166    fn materialize_scaled_i16_nonzero_zp() {
3167        let cpu = CPUProcessor::new();
3168        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
3169        let protos = make_i8_quant(
3170            &[PROTO_H, PROTO_W, NUM_PROTOS],
3171            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
3172            0.02,
3173            -10,
3174        );
3175        let coeffs = make_i16_quant(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS), 0.01, 5);
3176        let proto_data = ProtoData {
3177            mask_coefficients: coeffs,
3178            protos,
3179            layout: ProtoLayout::Nhwc,
3180        };
3181        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
3182        assert!(
3183            result.is_ok(),
3184            "scaled nonzero zp failed: {:?}",
3185            result.err()
3186        );
3187        assert_eq!(result.unwrap().len(), 1);
3188    }
3189}