// edgefirst_image/cpu/masks.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

use super::CPUProcessor;
use crate::Result;
use edgefirst_decoder::{DetectBox, Segmentation};
use ndarray::Axis;
use rayon::prelude::*;

impl CPUProcessor {
    /// Alpha-blend a ModelPack multi-class segmentation overlay into
    /// `dst_slice`.
    ///
    /// Each destination pixel inside the segmentation's bbox samples the
    /// mask with nearest-neighbour interpolation and takes the palette
    /// colour of the argmax class at that cell. The last class index
    /// (`seg_classes - 1`) is treated as background and skipped.
    ///
    /// * `dst_w` / `dst_h` — destination image size in pixels.
    /// * `dst_rs` — destination row stride in bytes.
    /// * `dst_c` — destination bytes per pixel; exactly 3 bytes are
    ///   blended per pixel (any 4th/alpha byte is left untouched).
    ///   NOTE(review): assumes `dst_c >= 3` — TODO confirm at call sites.
    /// * `opacity` — extra opacity factor multiplied into the palette
    ///   colour's own alpha channel (1.0 = use palette alpha as-is).
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_modelpack_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        opacity: f32,
    ) -> Result<()> {
        use ndarray_stats::QuantileExt;

        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize; 3] as shape");
        };
        // Destination-pixel bounds of the mask's normalised bbox.
        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        // Align-corners mapping: destination span of M pixels maps onto the
        // mask's N cells via (N-1)/(M-1).
        // NOTE(review): a 1-pixel-wide/tall span makes the denominator 0 and
        // the scale infinite; the index clamp in `get_value_at_nearest`
        // below keeps the lookup in bounds in that case.
        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        // Per-cell argmax over the class axis, computed once up front.
        // NOTE(review): `argmax().unwrap()` panics if `seg_classes == 0`.
        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
        let get_value_at_nearest = |x: f32, y: f32| -> usize {
            let x = x.round() as usize;
            let y = y.round() as usize;
            argmax
                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
                .copied()
                .unwrap_or(0)
        };

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = (x as f32 - start_x) * scale_x;
                let seg_y = (y as f32 - start_y) * scale_y;
                let label = get_value_at_nearest(seg_x, seg_y);

                // Last class is background — leave the pixel untouched.
                if label == seg_classes - 1 {
                    continue;
                }

                let color = self.colors[label % self.colors.len()];

                // Effective alpha = palette alpha × caller opacity.
                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

                // Standard integer alpha blend over the 3 colour channels.
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }

    /// Alpha-blend a single-class (YOLO-style) binary instance mask into
    /// `dst_slice`, using the palette colour for `class`.
    ///
    /// The mask is a `[H, W, 1]` `u8` plane; values are thresholded at 127
    /// (pixels with `val < 127` are skipped). Sampling is nearest-neighbour
    /// via truncation. Destination parameters have the same meaning as in
    /// [`Self::render_modelpack_segmentation`], including the `dst_c >= 3`
    /// assumption.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_yolo_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        class: usize,
        opacity: f32,
    ) -> Result<()> {
        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize;3] as shape");
        };
        // Single-class mask plane expected; checked only in debug builds.
        debug_assert_eq!(classes, 1);

        // Destination-pixel bounds of the mask's normalised bbox.
        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        // Align-corners scale; see the note in
        // `render_modelpack_segmentation` about 1-pixel spans. Here an
        // out-of-range index falls back to 0 via `get(..).unwrap_or(&0)`.
        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);

                // Threshold the mask: below 127 is "not this instance".
                if val < 127 {
                    continue;
                }

                let color = self.colors[class % self.colors.len()];

                // Effective alpha = palette alpha × caller opacity.
                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

                // Standard integer alpha blend over the 3 colour channels.
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }

    /// Draw opaque rectangle outlines (3 px line thickness) for each
    /// detection box into `dst_slice`.
    ///
    /// The line colour is taken from the palette via
    /// `color_mode.index(idx, d.label)` (per-detection or per-class
    /// colouring is decided by `color_mode`); the palette alpha channel is
    /// ignored — box edges are written opaquely. Boxes are clamped to
    /// `[0, 1]` normalised coords before rasterisation, and the outline is
    /// expanded outward from the box edge by `LINE_THICKNESS`.
    /// NOTE(review): as with the mask renderers, 3 bytes per pixel are
    /// written, so `dst_c >= 3` is assumed.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_box(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        detect: &[DetectBox],
        color_mode: crate::ColorMode,
    ) -> Result<()> {
        const LINE_THICKNESS: usize = 3;

        for (idx, d) in detect.iter().enumerate() {
            use edgefirst_decoder::BoundingBox;

            let color_index = color_mode.index(idx, d.label);
            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
            let bbox = d.bbox.to_canonical();
            let bbox = BoundingBox {
                xmin: bbox.xmin.clamp(0.0, 1.0),
                ymin: bbox.ymin.clamp(0.0, 1.0),
                xmax: bbox.xmax.clamp(0.0, 1.0),
                ymax: bbox.ymax.clamp(0.0, 1.0),
            };
            // Inner edge of the outline in destination pixels; the ±0.5
            // nudge snaps min edges inward and max edges outward.
            let inner = [
                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
            ];

            // Outer edge: inner expanded by the line thickness, clamped to
            // the image bounds.
            let outer = [
                inner[0].saturating_sub(LINE_THICKNESS),
                inner[1].saturating_sub(LINE_THICKNESS),
                (inner[2] + LINE_THICKNESS).min(dst_w),
                (inner[3] + LINE_THICKNESS).min(dst_h),
            ];

            // top line
            for y in outer[1] + 1..=inner[1] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // left and right lines
            for y in inner[1]..inner[3] {
                for x in outer[0] + 1..=inner[0] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }

                for x in inner[2]..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // bottom line
            for y in inner[3]..outer[3] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }
        }
        Ok(())
    }

    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
    ///
    /// This is the CPU-side decode step of the hybrid mask rendering path:
    /// call this to get pre-decoded masks, then pass them to
    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
    /// than the fused GPU `draw_proto_masks` on all tested platforms.
    ///
    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
    ///
    /// # Errors
    /// Returns `InvalidShape` for mis-shaped or unsupported-dtype tensors,
    /// `Internal` when coefficient rows don't match `detect.len()`, and
    /// `NotSupported` for NCHW layouts outside the i8×i8 fast path.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "proto",
            n_detections = detect.len(),
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Interpret shape based on physical layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute inverse letterbox scale for output-coord conversion.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    // Guard degenerate letterboxes (zero/negative extent).
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Fast integer path: when both coefficients and protos are I8 with
        // per-tensor quantization, use the all-integer kernel (same dot
        // product infrastructure as the scaled path, but at proto resolution
        // without bilinear upsampling). Output is binary {0, 255}.
        // Falls through to the general f32 dequant path for per-channel
        // quantization or other unsupported modes.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match proto_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                lx0,
                inv_lw,
                ly0,
                inv_lh,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Fall through to the general f32 dequant path below for
                    // per-channel quantization and other unsupported modes.
                }
                Err(e) => return Err(e),
            }
        }

        // Coefficients may be F32 (from f32 models), F16 (from fp16 models),
        // or I8 (from quantized models — kept raw with quantization). For the
        // mask kernel we always need an f32 view (the multiply-accumulate is
        // done in f32 for precision). Map once and widen once outside the loop.
        // NCHW layout is only supported in the i8×i8 integer fast path above
        // with per-tensor quantization. Reject here for all other combinations.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout requires I8 protos and coefficients with per-tensor quantization"
                    .into(),
            ));
        }
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            DType::I8 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i8()
                    .expect("dtype matched I8");
                let m = t.map()?;
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I8 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    // NOTE(review): I8 without quantization metadata is
                    // interpreted as raw values widened to f32.
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, or I8"
                )));
            }
        };

        // Hoist the proto tensor map() out of the per-detection loop so the
        // map-guard is acquired once. Then dispatch per-dtype via a helper
        // that runs the per-detection kernels in parallel across detections
        // via rayon. This restores the parallelism that PR #54 added and
        // PR #51 (EDGEAI-1244 f16 refactor) inadvertently removed.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        // Row i of the coefficient matrix belongs to det i.
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sign_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }

    /// Produce per-detection masks at `(width, height)` pixel resolution by
    /// upsampling the full proto plane once then cropping per bbox. Each
    /// `det.bbox` is assumed to be in model-input normalized coordinates
    /// (the convention used by the decoder output); when `letterbox` is
    /// `Some`, `(width, height)` are original-content pixel dims and the
    /// inverse letterbox transform is applied to both the bbox (for the
    /// crop region and returned `Segmentation` metadata) and each output
    /// pixel (for proto-plane sampling). Mask values are binary
    /// `uint8 {0, 255}` after thresholding sigmoid > 0.5.
    ///
    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
    ///
    /// # Errors
    /// Same shape/dtype error cases as
    /// [`Self::materialize_segmentations`], plus `InvalidShape` when
    /// `width` or `height` is zero.
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "scaled",
            n_detections = detect.len(),
            width,
            height,
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Interpret shape based on physical layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Fast integer path: when both coefficients and protos are I8, use
        // the all-integer kernel (i8×i8→i32 dot product, sign-shortcut
        // bilinear). No floating-point conversion at all.
        // Falls through to the general f32 dequant path for per-channel
        // quantization or other unsupported modes.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match scaled_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                letterbox,
                width,
                height,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Fall through to the general f32 dequant path below for
                    // per-channel quantization and other unsupported modes.
                }
                Err(e) => return Err(e),
            }
        }

        // Fallback: widen coefficients to f32 for the float-path kernels.
        // NCHW layout is only supported in the i8×i8 integer fast path above
        // with per-tensor quantization. Reject here for all other combinations.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout requires I8 protos and coefficients with per-tensor quantization"
                    .into(),
            ));
        }
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            DType::I8 => {
                // Dequantize I8 coefficients to f32 for the float proto path.
                let t = proto_data.mask_coefficients.as_i8().expect("I8");
                let m = t.map()?;
                let q = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape(
                        "I8 mask_coefficients require quantization metadata".into(),
                    )
                })?;
                use edgefirst_tensor::QuantMode;
                let (scale, zp) = match q.mode() {
                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                    _ => {
                        return Err(crate::Error::NotSupported(
                            "per-channel mask_coefficients not supported".into(),
                        ))
                    }
                };
                m.as_slice()
                    .iter()
                    .map(|&v| (v as f32 - zp) * scale)
                    .collect()
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        // Dispatch to the per-dtype scaled kernels (these handle the
        // upsample + crop + threshold internally).
        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
}
727
// =============================================================================
// Slice-native fused kernels.
//
// All kernels take row-major `[H, W, num_protos]` proto slices + `&[f32]`
// coefficients (widened once from the source dtype at the materialize entry
// point). Per-dtype variants exist for i8 (with on-the-fly dequant using a
// tensor-level `Quantization`), f32, and f16; f16 widens to f32 per-element
// at the FMA site via `half::f16::to_f32()`.
//
// On ARMv8.2-FP16 this compiles to `fcvt`; on Cortex-A53 and non-F16C x86 it
// becomes a soft-float helper. Stage 8 adds explicit intrinsic kernels
// gated by `#[cfg(target_feature = "fp16")]` / `+f16c`.
// =============================================================================
741
742/// Map a detection bbox in normalised letterboxed coords to its ROI in
743/// the proto plane (floor xmin/ymin, ceil xmax/ymax, clamp to plane bounds).
744/// Returns `(x0, y0, x1, y1, roi_w, roi_h)` where roi_w/h are guaranteed ≥ 1.
745fn bbox_to_proto_roi(
746    det: &DetectBox,
747    proto_w: usize,
748    proto_h: usize,
749) -> (usize, usize, usize, usize, usize, usize) {
750    let bbox = det.bbox.to_canonical();
751    let xmin = bbox.xmin.clamp(0.0, 1.0);
752    let ymin = bbox.ymin.clamp(0.0, 1.0);
753    let xmax = bbox.xmax.clamp(0.0, 1.0);
754    let ymax = bbox.ymax.clamp(0.0, 1.0);
755    let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
756    let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
757    let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
758    let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
759    let roi_w = x1.saturating_sub(x0).max(1);
760    let roi_h = y1.saturating_sub(y0).max(1);
761    (x0, y0, x1, y1, roi_w, roi_h)
762}
763
764/// Build a `Segmentation` from a per-detection mask + the ROI bounds in
765/// proto coords. Applies the inverse letterbox transform to express the
766/// segmentation bbox in original-image-content normalised space.
767#[allow(clippy::too_many_arguments)]
768fn seg_from_roi(
769    mask: ndarray::Array3<u8>,
770    x0: usize,
771    y0: usize,
772    x1: usize,
773    y1: usize,
774    proto_w: usize,
775    proto_h: usize,
776    lx0: f32,
777    inv_lw: f32,
778    ly0: f32,
779    inv_lh: f32,
780) -> edgefirst_decoder::Segmentation {
781    let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
782    let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
783    let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
784    let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
785    edgefirst_decoder::Segmentation {
786        xmin: seg_xmin.clamp(0.0, 1.0),
787        ymin: seg_ymin.clamp(0.0, 1.0),
788        xmax: seg_xmax.clamp(0.0, 1.0),
789        ymax: seg_ymax.clamp(0.0, 1.0),
790        segmentation: mask,
791    }
792}
793
794// =============================================================================
795// Integer-domain proto-resolution kernel: i8 coefficients × i8 protos → i32
796// → sign threshold → binary {0, 255}.
797//
798// Reuses the same dot product infrastructure as the scaled path (NEON sdot on
799// A55+, smull+sadalp on A53, scalar fallback on x86). Since proto-resolution
800// produces masks at the native proto grid (~30×30 per ROI), there is no
801// bilinear upsampling — just a direct sign threshold per pixel.
802// =============================================================================
803
/// Proto-resolution mask materialization using integer-domain math.
///
/// For each detection, computes the i8×i8 dot product at every proto-ROI pixel,
/// applies the zero-point correction, and thresholds at sign(logit) → {0, 255}.
/// Supports both NHWC and NCHW proto layouts.
///
/// Follows the identity from the banner comment above:
///   logit = raw_dot - zp_c·proto_sum[pixel] - bias_per_det
/// with `bias_per_det = zp_p·Σc_raw - N·zp_c·zp_p` precomputed once per
/// detection and `proto_sum` precomputed once per pixel (only when zp_c ≠ 0).
///
/// Returns one proto-resolution `Segmentation` per detection (no upsampling).
/// Errors with `NotSupported` when either tensor carries per-channel
/// quantization metadata.
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Coefficient-tensor zero point (symmetric quantization implies zp = 0).
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i8 path".into(),
            ))
        }
    };
    // Proto-tensor zero point (symmetric quantization implies zp = 0).
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Precompute per-pixel proto sums for zero-point correction.
    // Skipped entirely when zp_c == 0, since the correction term vanishes.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                // Channel-major input: sweep plane by plane so each pass
                // reads a contiguous channel plane sequentially.
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime HWCAP check: the sdot kernel below is only entered when the
    // CPU actually implements the dotprod extension.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    // Each detection produces an independent mask — parallelize across them.
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // Coefficient row i of the [n_det, num_protos] coefficient matrix.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            // Per-detection bias: zp_p·Σc_raw - N·zp_c·zp_p
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            // Binary mask for this ROI, row-major, 0 or 255.
            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: both slices hold num_protos
                                    // elements, and `use_dotprod` confirmed
                                    // HWCAP support for the sdot instruction.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_dotprod(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        } else {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: both slices hold num_protos
                                    // elements; the base kernel uses only
                                    // baseline-aarch64 NEON instructions.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_base(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback (x86 and everything else).
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major accumulation: for each channel, accumulate
                    // coeff[c] * proto[c, py, px] across the ROI. Each channel
                    // plane is contiguous, giving excellent sequential read access.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    // Apply zero-point correction and threshold.
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1024
1025// =============================================================================
1026// Sign-threshold proto-resolution kernels (f32/f16/i8 protos with f32 coeffs).
1027//
1028// These replace the sigmoid-computing kernels for the non-i8×i8 fallback paths.
1029// Since downstream always thresholds at > 127, computing sigmoid is wasteful;
1030// sign(dot) > 0 ⟺ sigmoid(dot) > 0.5 gives the same binary result.
1031// =============================================================================
1032
1033/// f32 protos × f32 coefficients → sign threshold → binary {0, 255}.
1034#[allow(clippy::too_many_arguments)]
1035fn fused_dot_sign_f32_slice(
1036    protos: &[f32],
1037    coeff: &[f32],
1038    _proto_h: usize,
1039    proto_w: usize,
1040    y0: usize,
1041    x0: usize,
1042    roi_h: usize,
1043    roi_w: usize,
1044    num_protos: usize,
1045) -> ndarray::Array3<u8> {
1046    let stride_y = proto_w * num_protos;
1047    let mut mask_buf = vec![0u8; roi_h * roi_w];
1048    for y in 0..roi_h {
1049        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1050        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1051        for (x, out_px) in out_row.iter_mut().enumerate() {
1052            let base = row_base + x * num_protos;
1053            let mut acc = 0.0_f32;
1054            let mut k = 0;
1055            let chunks = num_protos / 4;
1056            for _ in 0..chunks {
1057                acc += coeff[k] * protos[base + k]
1058                    + coeff[k + 1] * protos[base + k + 1]
1059                    + coeff[k + 2] * protos[base + k + 2]
1060                    + coeff[k + 3] * protos[base + k + 3];
1061                k += 4;
1062            }
1063            while k < num_protos {
1064                acc += coeff[k] * protos[base + k];
1065                k += 1;
1066            }
1067            if acc > 0.0 {
1068                *out_px = 255;
1069            }
1070        }
1071    }
1072    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1073        .expect("mask_buf length matches roi_h * roi_w")
1074}
1075
/// f16 protos × f32 coefficients → sign threshold → binary {0, 255}.
///
/// Two code paths:
///
/// 1. **x86_64 + F16C + FMA** — explicit intrinsic kernel using
///    `_mm256_cvtph_ps` (8-lane f16→f32 widening) + `_mm256_fmadd_ps`.
///
/// 2. **Scalar fallback** — loop-unrolled by 4 with `half::f16::to_f32()`.
///
/// NOTE(review): dispatch is purely compile-time (`#[cfg(target_feature)]`),
/// not runtime feature detection — the intrinsic path is only built when the
/// crate is compiled with `+f16c,+fma`; a generic x86_64 build always takes
/// the scalar path even on CPUs that support F16C.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sign_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: target-feature gates guarantee F16C + FMA support.
        unsafe {
            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
1116
1117/// Scalar f16 sign-threshold kernel — loop-unrolled by 4.
1118#[allow(clippy::too_many_arguments)]
1119fn fused_dot_sign_f16_slice_scalar(
1120    protos: &[half::f16],
1121    coeff: &[f32],
1122    proto_w: usize,
1123    y0: usize,
1124    x0: usize,
1125    roi_h: usize,
1126    roi_w: usize,
1127    num_protos: usize,
1128) -> ndarray::Array3<u8> {
1129    let stride_y = proto_w * num_protos;
1130    let mut mask_buf = vec![0u8; roi_h * roi_w];
1131    for y in 0..roi_h {
1132        let row_base = (y0 + y) * stride_y + x0 * num_protos;
1133        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1134        for (x, out_px) in out_row.iter_mut().enumerate() {
1135            let base = row_base + x * num_protos;
1136            let mut acc = 0.0_f32;
1137            let mut k = 0;
1138            let chunks = num_protos / 4;
1139            for _ in 0..chunks {
1140                acc += coeff[k] * protos[base + k].to_f32()
1141                    + coeff[k + 1] * protos[base + k + 1].to_f32()
1142                    + coeff[k + 2] * protos[base + k + 2].to_f32()
1143                    + coeff[k + 3] * protos[base + k + 3].to_f32();
1144                k += 4;
1145            }
1146            while k < num_protos {
1147                acc += coeff[k] * protos[base + k].to_f32();
1148                k += 1;
1149            }
1150            if acc > 0.0 {
1151                *out_px = 255;
1152            }
1153        }
1154    }
1155    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1156        .expect("mask_buf length matches roi_h * roi_w")
1157}
1158
/// x86_64 F16C + FMA intrinsic kernel for f16 sign-threshold.
///
/// Uses `_mm256_cvtph_ps` for hardware 8-lane f16→f32 widening and
/// `_mm256_fmadd_ps` for fused multiply-add. Only the sign of the
/// accumulated dot product is checked (no sigmoid needed).
///
/// # Safety
///
/// Caller must ensure the target CPU supports F16C + FMA.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Unaligned 16-byte load of 8 packed f16 values, then widen
                // to 8 f32 lanes in hardware.
                // NOTE(review): in-bounds only if the proto slice covers
                // base + num_protos — a caller invariant this kernel cannot
                // itself check (raw pointer load bypasses slice bounds).
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduce 8 → 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos % 8.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1235
1236/// i8 protos (with quant) × f32 coefficients → sign threshold → binary {0, 255}.
1237/// Fallback for per-channel quant or mixed-dtype cases where the i8×i8 fast path
1238/// doesn't apply.
1239#[allow(clippy::too_many_arguments)]
1240fn fused_dequant_dot_sign_i8_slice(
1241    protos: &[i8],
1242    coeff: &[f32],
1243    quant: &edgefirst_tensor::Quantization,
1244    _proto_h: usize,
1245    proto_w: usize,
1246    y0: usize,
1247    x0: usize,
1248    roi_h: usize,
1249    roi_w: usize,
1250    num_protos: usize,
1251) -> crate::Result<ndarray::Array3<u8>> {
1252    use edgefirst_tensor::QuantMode;
1253    let stride_y = proto_w * num_protos;
1254
1255    // Precompute scaled coefficients + zp_offset (same as the old sigmoid kernel).
1256    let mut stack_scratch = [0.0_f32; 64];
1257    let mut heap_scratch: Vec<f32>;
1258    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
1259        &mut stack_scratch[..num_protos]
1260    } else {
1261        heap_scratch = vec![0.0_f32; num_protos];
1262        heap_scratch.as_mut_slice()
1263    };
1264    let zp_offset: f32;
1265    match quant.mode() {
1266        QuantMode::PerTensorSymmetric { scale } => {
1267            for k in 0..num_protos {
1268                scaled_coeff[k] = coeff[k] * scale;
1269            }
1270            zp_offset = 0.0;
1271        }
1272        QuantMode::PerTensor { scale, zero_point } => {
1273            for k in 0..num_protos {
1274                scaled_coeff[k] = coeff[k] * scale;
1275            }
1276            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
1277        }
1278        QuantMode::PerChannelSymmetric { scales, axis } => {
1279            if axis != 2 {
1280                return Err(crate::Error::NotSupported(format!(
1281                    "per-channel quantization on axis {axis} not supported \
1282                     (only channel axis 2 is implemented on this kernel)"
1283                )));
1284            }
1285            for k in 0..num_protos {
1286                scaled_coeff[k] = coeff[k] * scales[k];
1287            }
1288            zp_offset = 0.0;
1289        }
1290        QuantMode::PerChannel {
1291            scales,
1292            zero_points,
1293            axis,
1294        } => {
1295            if axis != 2 {
1296                return Err(crate::Error::NotSupported(format!(
1297                    "per-channel quantization on axis {axis} not supported \
1298                     (only channel axis 2 is implemented on this kernel)"
1299                )));
1300            }
1301            for k in 0..num_protos {
1302                scaled_coeff[k] = coeff[k] * scales[k];
1303            }
1304            zp_offset = (0..num_protos)
1305                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
1306                .sum();
1307        }
1308    }
1309
1310    let mut mask_buf = vec![0u8; roi_h * roi_w];
1311    for y in 0..roi_h {
1312        let row_base = (y0 + y) * stride_y + (x0) * num_protos;
1313        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1314        for (x, out_px) in out_row.iter_mut().enumerate() {
1315            let base = row_base + x * num_protos;
1316            let mut acc = 0.0_f32;
1317            let mut k = 0;
1318            let chunks = num_protos / 4;
1319            for _ in 0..chunks {
1320                let p0 = protos[base + k] as f32;
1321                let p1 = protos[base + k + 1] as f32;
1322                let p2 = protos[base + k + 2] as f32;
1323                let p3 = protos[base + k + 3] as f32;
1324                acc += scaled_coeff[k] * p0
1325                    + scaled_coeff[k + 1] * p1
1326                    + scaled_coeff[k + 2] * p2
1327                    + scaled_coeff[k + 3] * p3;
1328                k += 4;
1329            }
1330            while k < num_protos {
1331                acc += scaled_coeff[k] * protos[base + k] as f32;
1332                k += 1;
1333            }
1334            if acc > zp_offset {
1335                *out_px = 255;
1336            }
1337        }
1338    }
1339    Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1340        .expect("mask_buf length matches roi_h * roi_w"))
1341}
1342
1343#[allow(clippy::too_many_arguments)]
1344fn scaled_segmentations_f32_slice(
1345    detect: &[crate::DetectBox],
1346    coeff_all: &[f32],
1347    protos: &[f32],
1348    proto_h: usize,
1349    proto_w: usize,
1350    num_protos: usize,
1351    letterbox: Option<[f32; 4]>,
1352    width: u32,
1353    height: u32,
1354) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1355    scaled_run(
1356        detect,
1357        coeff_all,
1358        protos,
1359        proto_h,
1360        proto_w,
1361        num_protos,
1362        letterbox,
1363        width,
1364        height,
1365        1.0,
1366        |p, _| *p,
1367    )
1368}
1369
1370#[allow(clippy::too_many_arguments)]
1371fn scaled_segmentations_f16_slice(
1372    detect: &[crate::DetectBox],
1373    coeff_all: &[f32],
1374    protos: &[half::f16],
1375    proto_h: usize,
1376    proto_w: usize,
1377    num_protos: usize,
1378    letterbox: Option<[f32; 4]>,
1379    width: u32,
1380    height: u32,
1381) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1382    scaled_run(
1383        detect,
1384        coeff_all,
1385        protos,
1386        proto_h,
1387        proto_w,
1388        num_protos,
1389        letterbox,
1390        width,
1391        height,
1392        1.0,
1393        |p: &half::f16, _| p.to_f32(),
1394    )
1395}
1396
1397#[allow(clippy::too_many_arguments)]
1398fn scaled_segmentations_i8_slice(
1399    detect: &[crate::DetectBox],
1400    coeff_all: &[f32],
1401    protos: &[i8],
1402    proto_h: usize,
1403    proto_w: usize,
1404    num_protos: usize,
1405    quant: &edgefirst_tensor::Quantization,
1406    letterbox: Option<[f32; 4]>,
1407    width: u32,
1408    height: u32,
1409) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1410    use edgefirst_tensor::QuantMode;
1411    // Only per-tensor quantization supported on the scaled-path CPU kernel
1412    // today. Per-channel fits naturally into a future extension (would need
1413    // per-channel scaled coefficients in scaled_run's dot-product precompute).
1414    let (scale, zp) = match quant.mode() {
1415        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1416        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1417        QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1418            return Err(crate::Error::NotSupported(format!(
1419                "per-channel quantization (axis={axis}) on scaled seg path \
1420                 not yet supported"
1421            )));
1422        }
1423    };
1424    scaled_run(
1425        detect,
1426        coeff_all,
1427        protos,
1428        proto_h,
1429        proto_w,
1430        num_protos,
1431        letterbox,
1432        width,
1433        height,
1434        scale,
1435        move |p: &i8, _| *p as f32 - zp,
1436    )
1437}
1438
1439// =============================================================================
1440// Integer-domain kernel: i8 coefficients × i8 protos → i32 → sign threshold.
1441//
1442// Eliminates all f32 conversion by working directly with raw quantized values.
1443// The math:
1444//   sign(dot(dequant(coeff), dequant(proto)))
1445//   = sign(Σ (c_raw - zp_c) · (p_raw - zp_p))
1446//   = sign(Σ c_raw·p_raw - zp_c·Σp_raw - zp_p·Σc_raw + N·zp_c·zp_p)
1447//   = sign(sdot(c_raw, p_raw) - zp_c·proto_sum[pixel] - bias_per_det)
1448//
1449// where bias_per_det = zp_p·Σc_raw - N·zp_c·zp_p  (precomputed once per det).
1450// =============================================================================
1451
/// Compute the i8×i8 dot product of the first `n` elements → i32.
/// Platform-agnostic scalar fallback (aarch64 builds use the NEON kernels,
/// hence the `dead_code` allowance there).
///
/// Integer addition is associative, so the iterator form is exactly
/// equivalent to the previous manual 4-wide unroll, while the up-front
/// `[..n]` subslices let LLVM elide per-element bounds checks and
/// auto-vectorize. Panics (like the original) if either slice is shorter
/// than `n`.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
    coeff[..n]
        .iter()
        .zip(&proto[..n])
        .map(|(&c, &p)| c as i32 * p as i32)
        .sum()
}
1473
/// NEON i8×i8→i32 dot product using smull+sadalp (works on ALL aarch64, A53+).
///
/// Processes 16 lanes per iteration (widening multiply to i16, pairwise
/// accumulate into 4 i32 lanes), then an 8-lane step and a scalar tail.
///
/// # Safety
///
/// `coeff` and `proto` must each be valid for reads of `n` consecutive `i8`
/// values.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Widening multiply + pairwise accumulate (all aarch64).
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    // Handle remaining elements (for num_protos=32, full_chunks=2, remainder=0)
    let remainder = n - offset;
    if remainder >= 8 {
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Horizontal add of the 4 i32 lanes, then a scalar tail for n % 8.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1508
/// NEON i8×i8→i32 dot product using sdot (ARMv8.2-A dotprod, A55+).
/// Each `sdot` processes 16 i8 lanes → 4 i32 partial sums in one instruction,
/// replacing the 3-instruction smull+smull2+sadalp sequence.
///
/// # Safety
///
/// `coeff` and `proto` must each be valid for reads of `n` consecutive `i8`
/// values, and the executing CPU must implement the dotprod extension
/// (callers gate on `is_aarch64_feature_detected!("dotprod")`).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Enable dotprod extension locally so the assembler accepts sdot
        // even when compiling for baseline aarch64 (A53). At runtime we only
        // reach this path when HWCAP confirms dotprod support.
        let result: int32x4_t;
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    // Horizontal add of the 4 i32 partial sums.
    let mut scalar_acc = vaddvq_s32(acc);
    // Tail: handle remainder (unlikely for num_protos=32, but correct)
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1545
/// Fill `logits` (row-major, `roi_h` × `roi_w`) with integer mask logits for
/// one detection using the sdot (dotprod) kernel: for each proto pixel in the
/// ROI, `raw_dot(coeff, proto_px) - zp_c·proto_sum - bias`.
/// Kept as a standalone function so the compiler inlines the sdot asm fully.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // Walk the output grid row by row; output row `row` maps to proto row
    // `proto_y0 + row`, starting at proto column `proto_x0`.
    for (row, out_row) in logits[..roi_h * roi_w].chunks_exact_mut(roi_w).enumerate() {
        let py = proto_y0 + row;
        let src_base = py * stride_y + proto_x0 * num_protos;
        let sum_base = py * proto_w + proto_x0;
        for (col, slot) in out_row.iter_mut().enumerate() {
            let pix = src_base + col * num_protos;
            let proto_px = &protos[pix..pix + num_protos];
            // SAFETY: both slices hold at least `num_protos` elements.
            let raw =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // `proto_sums` may be empty when zp_c == 0 — only index it when
            // a correction is actually required.
            let correction = match zp_c {
                0 => 0,
                zp => zp * proto_sums[sum_base + col],
            };
            *slot = raw - correction - bias;
        }
    }
}
1583
/// Fill `logits` (row-major, `roi_h` × `roi_w`) with integer mask logits for
/// one detection using the baseline NEON kernel (smull+sadalp): for each
/// proto pixel, `raw_dot(coeff, proto_px) - zp_c·proto_sum - bias`.
/// Kept as a standalone function so the compiler inlines the NEON code fully.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // Walk the output grid row by row; output row `row` maps to proto row
    // `proto_y0 + row`, starting at proto column `proto_x0`.
    for (row, out_row) in logits[..roi_h * roi_w].chunks_exact_mut(roi_w).enumerate() {
        let py = proto_y0 + row;
        let src_base = py * stride_y + proto_x0 * num_protos;
        let sum_base = py * proto_w + proto_x0;
        for (col, slot) in out_row.iter_mut().enumerate() {
            let pix = src_base + col * num_protos;
            let proto_px = &protos[pix..pix + num_protos];
            // SAFETY: both slices hold at least `num_protos` elements.
            let raw =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // `proto_sums` may be empty when zp_c == 0 — only index it when
            // a correction is actually required.
            let correction = match zp_c {
                0 => 0,
                zp => zp * proto_sums[sum_base + col],
            };
            *slot = raw - correction - bias;
        }
    }
}
1621
/// Fully-quantized instance-mask fast path: i8 coefficients × i8 protos with
/// all accumulation in i32 — no dequantization anywhere.
///
/// The mask threshold only needs the *sign* of the dequantized logit. For
/// per-tensor quantization the dequantized dot product expands as
///   Σ(c − zp_c)(p − zp_p) = Σc·p − zp_c·Σp − zp_p·Σc + K·zp_c·zp_p
/// so per proto pixel we compute the raw integer Σc·p and subtract
/// `correction` (zp_c·Σp, precomputed per pixel in `proto_sums`) and `bias`
/// (the two terms constant per detection).
/// NOTE(review): this assumes the positive quantization scales never flip the
/// logit sign — the f32 path (`scaled_run`) validates `acc_scale > 0`
/// explicitly; confirm the same invariant for the scales feeding this path.
///
/// Returns one binary mask tile (u8 0/255, shape [h, w, 1]) per detection.
///
/// # Errors
/// `NotSupported` for per-channel coefficient or proto quantization.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        width,
        height,
        ?layout,
    )
    .entered();

    // Zero point of the coefficient tensor (0 when symmetric). Per-channel
    // modes are rejected: the algebra above requires a single zero point.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported".into(),
            ))
        }
    };
    // Zero point of the proto tensor, same restriction.
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported".into(),
            ))
        }
    };

    // Letterbox rectangle [x0, y0, x1, y1] in normalized model coordinates;
    // EPSILON floor guards against division by a degenerate (zero-size) box.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let hw = proto_h * proto_w;

    // Precompute proto_sum for the entire proto tensor (zero-point correction).
    // Only needed when zp_c != 0; otherwise the correction term vanishes and
    // we skip both the precompute and the per-pixel lookup.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            // NHWC: channels are contiguous per pixel.
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    let mut s: i32 = 0;
                    for k in 0..num_protos {
                        s += protos[base + k] as i32;
                    }
                    s
                })
                .collect(),
            // NCHW: accumulate plane by plane for contiguous reads.
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Detect dotprod support once, outside the hot loop.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    // For NHWC layout, stride for row navigation.
    let stride_y = proto_w * num_protos;

    // One independent mask tile per detection; rayon parallelises across
    // detections (read-only protos + per-detection coeff slice, no sharing).
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // This detection's K mask coefficients.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let bbox = det.bbox.to_canonical();
            // Undo the letterbox: normalize the bbox into [0, 1] frame space.
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            // Output-pixel bbox; clamped to the frame, never degenerate.
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Map output bbox → proto ROI.
            // Pixel-center sampling: output pixel px samples proto coordinate
            // (center-aligned, hence the two 0.5 shifts).
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            // Floor the min / ceil+1 the max so both bilinear neighbours are
            // inside the ROI. Start indices are used as direct bases into
            // `protos`, so clamp them to the last valid index; end indices
            // clamp to the exclusive upper bound.
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Per-detection bias: zp_p·Σc + the constant K·zp_c·zp_p term
            // from the zero-point expansion (see function docs).
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            // Step 2: Compute i32 logits at each proto-ROI pixel.
            let mut logits = vec![0_i32; roi_h * roi_w];
            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            compute_logits_dotprod(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        } else {
                            compute_logits_base(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        }
                    }
                    // Portable scalar fallback for non-aarch64 targets.
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major accumulation: contiguous reads per channel plane.
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_start = py * proto_w + proto_x0;
                            let out_row_start = ly_idx * roi_w;
                            for lx_idx in 0..roi_w {
                                logits[out_row_start + lx_idx] +=
                                    coeff_c * plane[row_start + lx_idx] as i32;
                            }
                        }
                    }
                    // Apply zero-point correction and per-detection bias.
                    for ly_idx in 0..roi_h {
                        let py = proto_y0 + ly_idx;
                        for lx_idx in 0..roi_w {
                            let idx = ly_idx * roi_w + lx_idx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                            } else {
                                0
                            };
                            logits[idx] -= correction + bias;
                        }
                    }
                }
            }

            // Step 3: Bilinear upsample i32 logits → binary mask with
            // sign-shortcut (skip interpolation when all 4 neighbors agree).
            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // X-coordinate LUT with fixed-point fraction (scale 1024).
            const FRAC_BITS: i32 = 10;
            const FRAC_SCALE: i32 = 1 << FRAC_BITS; // 1024
            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                let y_frac_inv = FRAC_SCALE - y_frac;
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];

                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let tl = row_lo[x_lo];
                    let tr = row_lo[x_hi];
                    let bl = row_hi[x_lo];
                    let br = row_hi[x_hi];

                    // Sign-shortcut: if all 4 corners have the same sign,
                    // the bilinear interpolation (positive-weight combination)
                    // preserves that sign. Skip arithmetic for ~80% of pixels.
                    // Bitwise AND keeps the sign bit set only if it is set in
                    // all four corners, i.e. all four are negative.
                    if (tl & tr & bl & br) < 0 {
                        // All negative → output 0 (already zero).
                        continue;
                    }
                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
                        // All strictly positive → output 255.
                        out_row[xi] = 255;
                        continue;
                    }

                    // Boundary pixel: fixed-point bilinear in i64 (product of
                    // the two Q10 fractions scales by 1024², safe in i64).
                    let x_frac_inv = FRAC_SCALE - x_frac;
                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
                    out_row[xi] = if logit > 0 { 255 } else { 0 };
                }
            }

            // Wrap the flat buffer as the expected [bbox_h, bbox_w, 1] array.
            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
1923
1924#[allow(clippy::too_many_arguments)]
1925fn scaled_run<P: Copy + Sync>(
1926    detect: &[crate::DetectBox],
1927    coeff_all: &[f32],
1928    protos: &[P],
1929    proto_h: usize,
1930    proto_w: usize,
1931    num_protos: usize,
1932    letterbox: Option<[f32; 4]>,
1933    width: u32,
1934    height: u32,
1935    acc_scale: f32,
1936    load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
1937) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1938    let (lx0, lw, ly0, lh) = match letterbox {
1939        Some([lx0, ly0, lx1, ly1]) => {
1940            let lw = (lx1 - lx0).max(f32::EPSILON);
1941            let lh = (ly1 - ly0).max(f32::EPSILON);
1942            (lx0, lw, ly0, lh)
1943        }
1944        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
1945    };
1946    let out_w = width as usize;
1947    let out_h = height as usize;
1948    let stride_y = proto_w * num_protos;
1949
1950    // Parallelise across detections. Each detection produces an
1951    // independent ndarray::Array3<u8> tile from a read-only proto slice +
1952    // its own coeff slice; no shared mutable state.
1953    //
1954    // Algorithm (restores the spirit of PR #54's batched-GEMM optimisation
1955    // that PR #51's f16 dispatch refactor inadvertently removed):
1956    //
1957    //   1. Map the output bbox back to a proto-plane ROI (with 1-px margin
1958    //      so the bilinear sampling at the output edges has neighbours).
1959    //   2. Precompute *f32 logits* at every proto pixel inside that ROI by
1960    //      doing a single K-wide dot product per proto pixel — once, not
1961    //      once per output pixel.
1962    //   3. For each output pixel, bilinear-interpolate the scalar f32 logit
1963    //      from the 4 surrounding proto-roi pixels, apply sigmoid, and
1964    //      threshold to {0, 255}.
1965    //
1966    // For typical YOLO-seg: proto_roi ~ 30×30 = 900 px × K=32 = 28.8K dot
1967    // ops vs the legacy "bilinear sample then dot at every output pixel"
1968    // which costs bbox_h × bbox_w × 4 × K = ~1.3M ops at 100×100 output
1969    // bbox. ~45× fewer FMAs at this size; the bilinear upsample of a
1970    // scalar plane (no inner K loop) is comparatively negligible.
1971    detect
1972        .par_iter()
1973        .enumerate()
1974        .map(|(i, det)| {
1975            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1976            let bbox = det.bbox.to_canonical();
1977            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
1978            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
1979            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
1980            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
1981            let px0 = (xmin * out_w as f32).round() as usize;
1982            let py0 = (ymin * out_h as f32).round() as usize;
1983            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
1984            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
1985            let bbox_w = px1.saturating_sub(px0).max(1);
1986            let bbox_h = py1.saturating_sub(py0).max(1);
1987
1988            // Step 1 — proto-plane ROI for this detection's output bbox.
1989            // Map the four output bbox corners back to proto coords and
1990            // expand by 1 pixel in each direction so the bilinear sampler
1991            // at the bbox boundary has both neighbours.
1992            let sample_x_at = |px: f32| -> f32 {
1993                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
1994                model_x_norm * proto_w as f32 - 0.5
1995            };
1996            let sample_y_at = |py: f32| -> f32 {
1997                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
1998                model_y_norm * proto_h as f32 - 0.5
1999            };
2000            let s_x_min = sample_x_at(px0 as f32);
2001            let s_x_max = sample_x_at((px1 as f32) - 1.0);
2002            let s_y_min = sample_y_at(py0 as f32);
2003            let s_y_max = sample_y_at((py1 as f32) - 1.0);
2004            // Floor min, ceil max+1 to include both bilinear neighbours.
2005            // Start indices are used as direct bases into `protos`, so clamp
2006            // them to the last valid index, not to the exclusive upper bound.
2007            let proto_x0 = (s_x_min.floor() as isize)
2008                .max(0)
2009                .min(proto_w.saturating_sub(1) as isize) as usize;
2010            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2011            let proto_y0 = (s_y_min.floor() as isize)
2012                .max(0)
2013                .min(proto_h.saturating_sub(1) as isize) as usize;
2014            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2015            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2016            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2017
2018            // Step 2 — precompute f32 logits at every proto-roi pixel.
2019            // logits[(py - proto_y0) * roi_w + (px - proto_x0)] = dot(coeff, proto[py, px, :])
2020            //
2021            // Since the final threshold is `logit > 0` (O1) and bilinear
2022            // interpolation is a positive-weight linear combination,
2023            // `acc_scale * interp(logits) > 0 ⟺ interp(logits) > 0` when
2024            // acc_scale > 0. We therefore skip the per-pixel `acc_scale *`
2025            // multiply entirely, storing raw dot products.
2026            if !acc_scale.is_finite() || acc_scale <= 0.0 {
2027                return Err(crate::Error::NotSupported(format!(
2028                    "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
2029                )));
2030            }
2031            let _ = acc_scale; // Scale-invariant: only sign matters.
2032            let mut logits = vec![0.0_f32; roi_h * roi_w];
2033            for ly_idx in 0..roi_h {
2034                let py = proto_y0 + ly_idx;
2035                let row_base = py * stride_y + proto_x0 * num_protos;
2036                for lx_idx in 0..roi_w {
2037                    let pix_base = row_base + lx_idx * num_protos;
2038                    let mut acc = 0.0_f32;
2039                    // 4-wide unroll to help auto-vectorization.
2040                    let mut k = 0;
2041                    let chunks = num_protos / 4;
2042                    for _ in 0..chunks {
2043                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
2044                            + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
2045                            + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
2046                            + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
2047                        k += 4;
2048                    }
2049                    while k < num_protos {
2050                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
2051                        k += 1;
2052                    }
2053                    logits[ly_idx * roi_w + lx_idx] = acc;
2054                }
2055            }
2056
2057            // Step 3 — bilinear upsample logits → binary mask.
2058            //
2059            // O1: sigmoid(x) > 0.5 ⟺ x > 0 (sigmoid is strictly monotonic,
2060            // and acc_scale > 0 preserves sign). The sign threshold replaces
2061            // the old fast_sigmoid approximation, saving ~15 cycles/pixel.
2062            //
2063            // O5: Pre-compute bilinear sample coordinates. sample_x_at /
2064            // sample_y_at depend only on pixel index, not on logit values.
2065            // Building lookup tables avoids redundant float ops in the inner
2066            // loop (floor, clamp, isize cast per pixel).
2067            let roi_last_x = roi_w.saturating_sub(1);
2068            let roi_last_y = roi_h.saturating_sub(1);
2069
2070            // X-coordinate LUT (shared across all rows).
2071            let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
2072                .map(|xi| {
2073                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2074                    let x_floor = sample_x.floor();
2075                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
2076                    let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
2077                    let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
2078                    (x_lo, x_hi, x_frac)
2079                })
2080                .collect();
2081
2082            // Write the output tile through a contiguous slice to avoid
2083            // ndarray's per-element bounds checks + stride arithmetic.
2084            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2085            for yi in 0..bbox_h {
2086                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2087                let y_floor = sample_y.floor();
2088                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2089                let y_hi = (y_lo + 1).min(roi_h - 1);
2090                let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
2091                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2092                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2093                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2094                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2095                    let (xl, xh) = (x_lo as usize, x_hi as usize);
2096                    let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
2097                    let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
2098                    let logit = l0 + (l1 - l0) * y_frac;
2099                    out_row[xi] = if logit > 0.0 { 255 } else { 0 };
2100                }
2101            }
2102            // Wrap into the expected Array3<u8> shape [bbox_h, bbox_w, 1].
2103            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2104                .expect("tile_buf length matches bbox_h * bbox_w");
2105            Ok(edgefirst_decoder::Segmentation {
2106                xmin,
2107                ymin,
2108                xmax,
2109                ymax,
2110                segmentation: tile,
2111            })
2112        })
2113        .collect()
2114}