// edgefirst_image/cpu/masks.rs
// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
11    #[allow(clippy::too_many_arguments)]
12    pub(super) fn render_modelpack_segmentation(
13        &mut self,
14        dst_w: usize,
15        dst_h: usize,
16        dst_rs: usize,
17        dst_c: usize,
18        dst_slice: &mut [u8],
19        segmentation: &Segmentation,
20        opacity: f32,
21    ) -> Result<()> {
22        use ndarray_stats::QuantileExt;
23
24        let seg = &segmentation.segmentation;
25        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
26            unreachable!("Array3 did not have [usize; 3] as shape");
27        };
28        let start_y = (dst_h as f32 * segmentation.ymin).round();
29        let end_y = (dst_h as f32 * segmentation.ymax).round();
30        let start_x = (dst_w as f32 * segmentation.xmin).round();
31        let end_x = (dst_w as f32 * segmentation.xmax).round();
32
33        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
34        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
35
36        let start_x_u = (start_x as usize).min(dst_w);
37        let start_y_u = (start_y as usize).min(dst_h);
38        let end_x_u = (end_x as usize).min(dst_w);
39        let end_y_u = (end_y as usize).min(dst_h);
40
41        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
42        let get_value_at_nearest = |x: f32, y: f32| -> usize {
43            let x = x.round() as usize;
44            let y = y.round() as usize;
45            argmax
46                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
47                .copied()
48                .unwrap_or(0)
49        };
50
51        for y in start_y_u..end_y_u {
52            for x in start_x_u..end_x_u {
53                let seg_x = (x as f32 - start_x) * scale_x;
54                let seg_y = (y as f32 - start_y) * scale_y;
55                let label = get_value_at_nearest(seg_x, seg_y);
56
57                if label == seg_classes - 1 {
58                    continue;
59                }
60
61                let color = self.colors[label % self.colors.len()];
62
63                let alpha = if opacity == 1.0 {
64                    color[3] as u16
65                } else {
66                    (color[3] as f32 * opacity).round() as u16
67                };
68
69                let dst_index = (y * dst_rs) + (x * dst_c);
70                for c in 0..3 {
71                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
72                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
73                        / 255) as u8;
74                }
75            }
76        }
77
78        Ok(())
79    }
80
81    #[allow(clippy::too_many_arguments)]
82    pub(super) fn render_yolo_segmentation(
83        &mut self,
84        dst_w: usize,
85        dst_h: usize,
86        dst_rs: usize,
87        dst_c: usize,
88        dst_slice: &mut [u8],
89        segmentation: &Segmentation,
90        class: usize,
91        opacity: f32,
92    ) -> Result<()> {
93        let seg = &segmentation.segmentation;
94        let [seg_height, seg_width, classes] = *seg.shape() else {
95            unreachable!("Array3 did not have [usize;3] as shape");
96        };
97        debug_assert_eq!(classes, 1);
98
99        let start_y = (dst_h as f32 * segmentation.ymin).round();
100        let end_y = (dst_h as f32 * segmentation.ymax).round();
101        let start_x = (dst_w as f32 * segmentation.xmin).round();
102        let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107        let start_x_u = (start_x as usize).min(dst_w);
108        let start_y_u = (start_y as usize).min(dst_h);
109        let end_x_u = (end_x as usize).min(dst_w);
110        let end_y_u = (end_y as usize).min(dst_h);
111
112        for y in start_y_u..end_y_u {
113            for x in start_x_u..end_x_u {
114                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118                if val < 127 {
119                    continue;
120                }
121
122                let color = self.colors[class % self.colors.len()];
123
124                let alpha = if opacity == 1.0 {
125                    color[3] as u16
126                } else {
127                    (color[3] as f32 * opacity).round() as u16
128                };
129
130                let dst_index = (y * dst_rs) + (x * dst_c);
131                for c in 0..3 {
132                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
134                        / 255) as u8;
135                }
136            }
137        }
138
139        Ok(())
140    }
141
142    #[allow(clippy::too_many_arguments)]
143    pub(super) fn render_box(
144        &mut self,
145        dst_w: usize,
146        dst_h: usize,
147        dst_rs: usize,
148        dst_c: usize,
149        dst_slice: &mut [u8],
150        detect: &[DetectBox],
151        color_mode: crate::ColorMode,
152    ) -> Result<()> {
153        const LINE_THICKNESS: usize = 3;
154
155        for (idx, d) in detect.iter().enumerate() {
156            use edgefirst_decoder::BoundingBox;
157
158            let color_index = color_mode.index(idx, d.label);
159            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
160            let bbox = d.bbox.to_canonical();
161            let bbox = BoundingBox {
162                xmin: bbox.xmin.clamp(0.0, 1.0),
163                ymin: bbox.ymin.clamp(0.0, 1.0),
164                xmax: bbox.xmax.clamp(0.0, 1.0),
165                ymax: bbox.ymax.clamp(0.0, 1.0),
166            };
167            let inner = [
168                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
169                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
170                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
171                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
172            ];
173
174            let outer = [
175                inner[0].saturating_sub(LINE_THICKNESS),
176                inner[1].saturating_sub(LINE_THICKNESS),
177                (inner[2] + LINE_THICKNESS).min(dst_w),
178                (inner[3] + LINE_THICKNESS).min(dst_h),
179            ];
180
181            // top line
182            for y in outer[1] + 1..=inner[1] {
183                for x in outer[0] + 1..outer[2] {
184                    let index = (y * dst_rs) + (x * dst_c);
185                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
186                }
187            }
188
189            // left and right lines
190            for y in inner[1]..inner[3] {
191                for x in outer[0] + 1..=inner[0] {
192                    let index = (y * dst_rs) + (x * dst_c);
193                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
194                }
195
196                for x in inner[2]..outer[2] {
197                    let index = (y * dst_rs) + (x * dst_c);
198                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
199                }
200            }
201
202            // bottom line
203            for y in inner[3]..outer[3] {
204                for x in outer[0] + 1..outer[2] {
205                    let index = (y * dst_rs) + (x * dst_c);
206                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
207                }
208            }
209        }
210        Ok(())
211    }
212
    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
    ///
    /// This is the CPU-side decode step of the hybrid mask rendering path:
    /// call this to get pre-decoded masks, then pass them to
    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
    /// than the fused GPU `draw_proto_masks` on all tested platforms.
    ///
    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
    ///
    /// One `Segmentation` is produced per entry in `detect`, restricted to
    /// the detection's bbox ROI in the proto plane. `letterbox` (as
    /// `[x0, y0, x1, y1]` normalised) maps the returned segmentation
    /// bboxes back into original-content coordinates; `None` is identity.
    ///
    /// # Errors
    /// `InvalidShape` when the protos tensor is not rank-3, the
    /// coefficient tensor is not `[N, num_protos]`, or a dtype other than
    /// F32/F16 coefficients (or I8/F32/F16 protos) is supplied; `Internal`
    /// when the coefficient row count disagrees with `detect.len()`.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Proto layout is row-major [H, W, num_protos].
        let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute inverse letterbox scale for output-coord conversion.
        // Zero/negative letterbox extents fall back to a scale of 1.0 so
        // the division below can never produce inf/NaN.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Coefficients may be F32 (from quantized or f32 models) or F16
        // (from fp16 models). For the mask kernel we always need an f32
        // view (the multiply-accumulate is done in f32 for precision). Map
        // once and widen once if f16, outside the per-detection loop.
        // `coeff_f32_storage` owns the copy; the map-guard `m` is dropped
        // at the end of each arm.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32 or F16"
                )));
            }
        };

        // Hoist the proto tensor map() out of the per-detection loop so the
        // map-guard is acquired once. Then dispatch per-dtype via a helper
        // that runs the per-detection kernels in parallel across detections
        // via rayon. This restores the parallelism that PR #54 added and
        // PR #51 (EDGEAI-1244 f16 refactor) inadvertently removed.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                // Quantized protos are dequantized on the fly inside the
                // kernel, so the metadata is mandatory here.
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        // Row `i` of the [N, num_protos] coefficient matrix.
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sigmoid_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sigmoid_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sigmoid_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
404
    /// Produce per-detection masks at `(width, height)` pixel resolution by
    /// upsampling the full proto plane once then cropping per bbox. Each
    /// `det.bbox` is assumed to be in model-input normalized coordinates
    /// (the convention used by the decoder output); when `letterbox` is
    /// `Some`, `(width, height)` are original-content pixel dims and the
    /// inverse letterbox transform is applied to both the bbox (for the
    /// crop region and returned `Segmentation` metadata) and each output
    /// pixel (for proto-plane sampling). Mask values are binary
    /// `uint8 {0, 255}` after thresholding sigmoid > 0.5.
    ///
    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
    ///
    /// # Errors
    /// `InvalidShape` when `width`/`height` is zero, the protos tensor is
    /// not rank-3, the coefficient tensor is not `[N, num_protos]`, or an
    /// unsupported dtype is supplied (I8 protos additionally require
    /// quantization metadata); `Internal` when the coefficient row count
    /// disagrees with `detect.len()`.
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Proto layout is row-major [H, W, num_protos].
        let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Widen coefficients to f32 once for the scaled-sample inner loop.
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        // Map the proto tensor once, then dispatch to the dtype-specific
        // scaled-sampling helper (defined later in this file).
        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                // Quantized protos are dequantized on the fly, so the
                // metadata is mandatory here.
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
534}
535
// =============================================================================
// Slice-native fused kernels.
//
// All kernels take row-major `[H, W, num_protos]` proto slices + `&[f32]`
// coefficients (widened once from the source dtype at the materialize entry
// point). Per-dtype variants exist for i8 (with on-the-fly dequant using a
// tensor-level `Quantization`), f32, and f16; f16 widens to f32 per-element
// at the FMA site via `half::f16::to_f32()`.
//
// On ARMv8.2-FP16 this compiles to `fcvt`; on Cortex-A53 and non-F16C x86 it
// becomes a soft-float helper. Stage 8 adds explicit intrinsic kernels
// gated by `#[cfg(target_feature = "fp16")]` / `+f16c`.
// =============================================================================
549
550/// Map a detection bbox in normalised letterboxed coords to its ROI in
551/// the proto plane (floor xmin/ymin, ceil xmax/ymax, clamp to plane bounds).
552/// Returns `(x0, y0, x1, y1, roi_w, roi_h)` where roi_w/h are guaranteed ≥ 1.
553fn bbox_to_proto_roi(
554    det: &DetectBox,
555    proto_w: usize,
556    proto_h: usize,
557) -> (usize, usize, usize, usize, usize, usize) {
558    let bbox = det.bbox.to_canonical();
559    let xmin = bbox.xmin.clamp(0.0, 1.0);
560    let ymin = bbox.ymin.clamp(0.0, 1.0);
561    let xmax = bbox.xmax.clamp(0.0, 1.0);
562    let ymax = bbox.ymax.clamp(0.0, 1.0);
563    let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
564    let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
565    let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
566    let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
567    let roi_w = x1.saturating_sub(x0).max(1);
568    let roi_h = y1.saturating_sub(y0).max(1);
569    (x0, y0, x1, y1, roi_w, roi_h)
570}
571
572/// Build a `Segmentation` from a per-detection mask + the ROI bounds in
573/// proto coords. Applies the inverse letterbox transform to express the
574/// segmentation bbox in original-image-content normalised space.
575#[allow(clippy::too_many_arguments)]
576fn seg_from_roi(
577    mask: ndarray::Array3<u8>,
578    x0: usize,
579    y0: usize,
580    x1: usize,
581    y1: usize,
582    proto_w: usize,
583    proto_h: usize,
584    lx0: f32,
585    inv_lw: f32,
586    ly0: f32,
587    inv_lh: f32,
588) -> edgefirst_decoder::Segmentation {
589    let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
590    let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
591    let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
592    let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
593    edgefirst_decoder::Segmentation {
594        xmin: seg_xmin.clamp(0.0, 1.0),
595        ymin: seg_ymin.clamp(0.0, 1.0),
596        xmax: seg_xmax.clamp(0.0, 1.0),
597        ymax: seg_ymax.clamp(0.0, 1.0),
598        segmentation: mask,
599    }
600}
601
/// Fused dequantize → dot-product → sigmoid kernel for i8 (quantized)
/// protos over one detection's ROI.
///
/// Instead of dequantizing the proto tensor up front, the quantization
/// scale is folded into the coefficients once
/// (`scaled_coeff[k] = coeff[k] * scale[k]`) and the zero-point term into
/// a single scalar, so the inner loop reads raw i8:
/// `sum_k coeff[k]*scale[k]*(q[k] - zp[k]) = sum_k scaled_coeff[k]*q[k] - zp_offset`.
///
/// Returns a `[roi_h, roi_w, 1]` u8 mask of `sigmoid(dot) * 255`.
///
/// # Errors
/// `NotSupported` for per-channel quantization on any axis other than the
/// channel axis (2).
#[allow(clippy::too_many_arguments)]
fn fused_dequant_dot_sigmoid_i8_slice(
    protos: &[i8],
    coeff: &[f32],
    quant: &edgefirst_tensor::Quantization,
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> crate::Result<ndarray::Array3<u8>> {
    use edgefirst_tensor::QuantMode;
    // Elements per proto-plane row in the flat [H, W, num_protos] slice.
    let stride_y = proto_w * num_protos;
    // Precompute scaled coefficients + zp_offset. Stack scratch covers
    // `num_protos ≤ 64` (every production model today); larger proto counts
    // fall back to a single heap allocation per kernel call so the kernel
    // does not silently reject valid-but-larger models.
    let mut stack_scratch = [0.0_f32; 64];
    let mut heap_scratch: Vec<f32>;
    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
        &mut stack_scratch[..num_protos]
    } else {
        heap_scratch = vec![0.0_f32; num_protos];
        heap_scratch.as_mut_slice()
    };
    let zp_offset: f32;
    match quant.mode() {
        QuantMode::PerTensorSymmetric { scale } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            // Symmetric quantization: zero point is 0, no offset term.
            zp_offset = 0.0;
        }
        QuantMode::PerTensor { scale, zero_point } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            // Shared zero point factors out: zp * Σ_k scaled_coeff[k].
            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
        }
        QuantMode::PerChannelSymmetric { scales, axis } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = 0.0;
        }
        QuantMode::PerChannel {
            scales,
            zero_points,
            axis,
        } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            // Per-channel zero points: Σ_k scaled_coeff[k] * zp[k].
            zp_offset = (0..num_protos)
                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
                .sum();
        }
    }

    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
    for y in 0..roi_h {
        for x in 0..roi_w {
            // Flat index of this cell's first proto value.
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc = 0.0_f32;
            let mut k = 0;
            // Manual 4-wide unroll of the multiply-accumulate.
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                let p0 = protos[base + k] as f32;
                let p1 = protos[base + k + 1] as f32;
                let p2 = protos[base + k + 2] as f32;
                let p3 = protos[base + k + 3] as f32;
                acc += scaled_coeff[k] * p0
                    + scaled_coeff[k + 1] * p1
                    + scaled_coeff[k + 2] * p2
                    + scaled_coeff[k + 3] * p3;
                k += 4;
            }
            // Tail when num_protos is not a multiple of 4.
            while k < num_protos {
                acc += scaled_coeff[k] * protos[base + k] as f32;
                k += 1;
            }
            // Apply the folded zero-point correction once per cell.
            acc -= zp_offset;
            let sigmoid = fast_sigmoid(acc);
            // Round-to-nearest into the u8 mask (`as u8` saturates).
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    Ok(mask)
}
704
705#[allow(clippy::too_many_arguments)]
706fn fused_dot_sigmoid_f32_slice(
707    protos: &[f32],
708    coeff: &[f32],
709    _proto_h: usize,
710    proto_w: usize,
711    y0: usize,
712    x0: usize,
713    roi_h: usize,
714    roi_w: usize,
715    num_protos: usize,
716) -> ndarray::Array3<u8> {
717    let stride_y = proto_w * num_protos;
718    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
719    for y in 0..roi_h {
720        for x in 0..roi_w {
721            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
722            let mut acc = 0.0_f32;
723            let mut k = 0;
724            let chunks = num_protos / 4;
725            for _ in 0..chunks {
726                acc += coeff[k] * protos[base + k]
727                    + coeff[k + 1] * protos[base + k + 1]
728                    + coeff[k + 2] * protos[base + k + 2]
729                    + coeff[k + 3] * protos[base + k + 3];
730                k += 4;
731            }
732            while k < num_protos {
733                acc += coeff[k] * protos[base + k];
734                k += 1;
735            }
736            let sigmoid = fast_sigmoid(acc);
737            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
738        }
739    }
740    mask
741}
742
/// Native-f16 fused kernel.
///
/// Three code paths, selected at compile time:
///
/// 1. **x86_64 + F16C + FMA** — explicit intrinsic kernel (`_mm256_cvtph_ps`
///    8-lane f16→f32 widening, `_mm256_fmadd_ps` FMA). Guaranteed to use
///    hardware f16 conversion, not LLVM's autovectorizer (which is unreliable
///    for this pattern per rust-lang/stdarch #1349).
///
/// 2. **aarch64 + FP16** — scalar `half::f16::to_f32()` at the FMA site.
///    LLVM lowers each `.to_f32()` to a single `fcvt` instruction when
///    `target-feature=+fp16` is active (e.g. `target-cpu=cortex-a78ae`).
///    The stable f16-typed NEON intrinsics (`vcvt_f32_f16`, `vld1q_f16`)
///    require nightly as of this commit; the scalar path is equally
///    efficient at this granularity.
///
/// 3. **Fallback (Cortex-A53, targets without FP16)** — same scalar code.
///    `half::f16::to_f32()` lowers to `__extendhfsf2` soft-float helper,
///    one call per proto load. Correctness-preserving; ~15 cycles/load
///    vs. ~3 cycles for the hardware path. Documented in
///    `docs/orin-build.md`.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sigmoid_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: target-feature gates both `vcvtph2ps` and `vfmadd*ps`;
        // the caller's slice-bounds contract is identical to the scalar arm.
        unsafe {
            fused_dot_sigmoid_f16_slice_f16c(
                protos, coeff, proto_h, proto_w, y0, x0, roi_h, roi_w, num_protos,
            )
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        // `proto_h` is only consumed by the intrinsic arm; discard it here
        // to silence the unused-variable warning on the scalar path.
        let _ = proto_h;
        fused_dot_sigmoid_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
800
801/// Scalar native-f16 kernel. `half::f16::to_f32()` at the FMA site is
802/// lowered to a single `fcvt` (aarch64+fp16) or a single `vcvtps_ps`
803/// (x86_64+f16c) by LLVM, or to the soft-float helper `__extendhfsf2` on
804/// targets without FP16 hardware. Loop unrolled by 4 to give the scheduler
805/// room to overlap loads with FMAs.
806#[allow(clippy::too_many_arguments, dead_code)]
807fn fused_dot_sigmoid_f16_slice_scalar(
808    protos: &[half::f16],
809    coeff: &[f32],
810    proto_w: usize,
811    y0: usize,
812    x0: usize,
813    roi_h: usize,
814    roi_w: usize,
815    num_protos: usize,
816) -> ndarray::Array3<u8> {
817    let stride_y = proto_w * num_protos;
818    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
819    for y in 0..roi_h {
820        for x in 0..roi_w {
821            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
822            let mut acc = 0.0_f32;
823            let mut k = 0;
824            let chunks = num_protos / 4;
825            for _ in 0..chunks {
826                let p0 = protos[base + k].to_f32();
827                let p1 = protos[base + k + 1].to_f32();
828                let p2 = protos[base + k + 2].to_f32();
829                let p3 = protos[base + k + 3].to_f32();
830                acc += coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
831                k += 4;
832            }
833            while k < num_protos {
834                acc += coeff[k] * protos[base + k].to_f32();
835                k += 1;
836            }
837            let sigmoid = fast_sigmoid(acc);
838            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
839        }
840    }
841    mask
842}
843
/// x86_64 F16C + FMA explicit intrinsic kernel. Processes 8 f16 lanes per
/// inner iteration via `_mm256_cvtph_ps` (8-lane f16→f32 widen) followed by
/// `_mm256_fmadd_ps` (8-lane fused multiply-add). Horizontal reduce at the
/// end of each pixel.
///
/// # Safety
///
/// Caller must ensure the target CPU supports F16C + FMA. The workspace's
/// `.cargo/config.toml` sets these target-features on `x86_64-unknown-linux-gnu`
/// and `x86_64-apple-darwin`, making this function statically callable on
/// those targets.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sigmoid_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));

    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Load 8 f16 (128 bits / 16 bytes) via a byte-level cast.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduce 8 → 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos % 8 (≤ 7 items). `k < num_protos`
            // is the whole bound: k enters the tail at chunks8 * 8 and
            // num_protos == chunks8 * 8 + tail, so the former extra guard
            // `k - chunks8 * 8 < tail` was always equivalent — dropped.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}
921
922#[allow(clippy::too_many_arguments)]
923fn scaled_segmentations_f32_slice(
924    detect: &[crate::DetectBox],
925    coeff_all: &[f32],
926    protos: &[f32],
927    proto_h: usize,
928    proto_w: usize,
929    num_protos: usize,
930    letterbox: Option<[f32; 4]>,
931    width: u32,
932    height: u32,
933) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
934    scaled_run(
935        detect,
936        coeff_all,
937        protos,
938        proto_h,
939        proto_w,
940        num_protos,
941        letterbox,
942        width,
943        height,
944        1.0,
945        |p, _| *p,
946    )
947}
948
949#[allow(clippy::too_many_arguments)]
950fn scaled_segmentations_f16_slice(
951    detect: &[crate::DetectBox],
952    coeff_all: &[f32],
953    protos: &[half::f16],
954    proto_h: usize,
955    proto_w: usize,
956    num_protos: usize,
957    letterbox: Option<[f32; 4]>,
958    width: u32,
959    height: u32,
960) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
961    scaled_run(
962        detect,
963        coeff_all,
964        protos,
965        proto_h,
966        proto_w,
967        num_protos,
968        letterbox,
969        width,
970        height,
971        1.0,
972        |p: &half::f16, _| p.to_f32(),
973    )
974}
975
976#[allow(clippy::too_many_arguments)]
977fn scaled_segmentations_i8_slice(
978    detect: &[crate::DetectBox],
979    coeff_all: &[f32],
980    protos: &[i8],
981    proto_h: usize,
982    proto_w: usize,
983    num_protos: usize,
984    quant: &edgefirst_tensor::Quantization,
985    letterbox: Option<[f32; 4]>,
986    width: u32,
987    height: u32,
988) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
989    use edgefirst_tensor::QuantMode;
990    // Only per-tensor quantization supported on the scaled-path CPU kernel
991    // today. Per-channel fits naturally into a future extension (would need
992    // per-channel scaled coefficients in scaled_run's dot-product precompute).
993    let (scale, zp) = match quant.mode() {
994        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
995        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
996        QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
997            return Err(crate::Error::NotSupported(format!(
998                "per-channel quantization (axis={axis}) on scaled seg path \
999                 not yet supported"
1000            )));
1001        }
1002    };
1003    scaled_run(
1004        detect,
1005        coeff_all,
1006        protos,
1007        proto_h,
1008        proto_w,
1009        num_protos,
1010        letterbox,
1011        width,
1012        height,
1013        scale,
1014        move |p: &i8, _| *p as f32 - zp,
1015    )
1016}
1017
/// Shared implementation behind the f32 / f16 / i8 scaled-segmentation
/// entry points.
///
/// `load_f32` converts one proto element of type `P` to f32; its second
/// f32 argument is passed as `0.0` here and ignored by every current
/// caller — NOTE(review): looks vestigial, candidate for removal once the
/// three wrappers are confirmed to be the only callers. `acc_scale`
/// multiplies each finished dot product (the i8 path folds its per-tensor
/// quantization scale in this way).
///
/// Returns one `Segmentation` per detection, each holding a {0, 255}
/// thresholded mask tile covering that detection's output bbox, with the
/// bbox coordinates stored normalized to the output frame.
#[allow(clippy::too_many_arguments)]
fn scaled_run<P: Copy + Sync>(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[P],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    acc_scale: f32,
    load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    // Letterbox corners are normalized [x0, y0, x1, y1]; the EPSILON floor
    // keeps the spans strictly positive so the divisions below cannot hit
    // zero on a degenerate letterbox.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    // Proto tensor is laid out [proto_h, proto_w, num_protos]; stride_y is
    // the element count of one proto row.
    let stride_y = proto_w * num_protos;

    // Parallelise across detections. Each detection produces an
    // independent ndarray::Array3<u8> tile from a read-only proto slice +
    // its own coeff slice; no shared mutable state.
    //
    // Algorithm (restores the spirit of PR #54's batched-GEMM optimisation
    // that PR #51's f16 dispatch refactor inadvertently removed):
    //
    //   1. Map the output bbox back to a proto-plane ROI (with 1-px margin
    //      so the bilinear sampling at the output edges has neighbours).
    //   2. Precompute *f32 logits* at every proto pixel inside that ROI by
    //      doing a single K-wide dot product per proto pixel — once, not
    //      once per output pixel.
    //   3. For each output pixel, bilinear-interpolate the scalar f32 logit
    //      from the 4 surrounding proto-roi pixels, apply sigmoid, and
    //      threshold to {0, 255}.
    //
    // For typical YOLO-seg: proto_roi ~ 30×30 = 900 px × K=32 = 28.8K dot
    // ops vs the legacy "bilinear sample then dot at every output pixel"
    // which costs bbox_h × bbox_w × 4 × K = ~1.3M ops at 100×100 output
    // bbox. ~45× fewer FMAs at this size; the bilinear upsample of a
    // scalar plane (no inner K loop) is comparatively negligible.
    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // Coefficients for detection i are a contiguous K-wide slice.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            // max(1) guards zero-area bboxes so the tile is never empty.
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Step 1 — proto-plane ROI for this detection's output bbox.
            // Map the four output bbox corners back to proto coords and
            // expand by 1 pixel in each direction so the bilinear sampler
            // at the bbox boundary has both neighbours.
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            // Floor min, ceil max+1 to include both bilinear neighbours.
            // Start indices are used as direct bases into `protos`, so clamp
            // them to the last valid index, not to the exclusive upper bound.
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Step 2 — precompute f32 logits at every proto-roi pixel.
            // logits[(py - proto_y0) * roi_w + (px - proto_x0)] = dot(coeff, proto[py, px, :])
            let mut logits = vec![0.0_f32; roi_h * roi_w];
            for ly_idx in 0..roi_h {
                let py = proto_y0 + ly_idx;
                let row_base = py * stride_y + proto_x0 * num_protos;
                for lx_idx in 0..roi_w {
                    let pix_base = row_base + lx_idx * num_protos;
                    let mut acc = 0.0_f32;
                    // K-wide dot product; the 0.0 second argument to
                    // load_f32 is ignored by all current loader closures.
                    for k in 0..num_protos {
                        acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
                    }
                    logits[ly_idx * roi_w + lx_idx] = acc_scale * acc;
                }
            }

            // Step 3 — bilinear upsample logits → output bbox, sigmoid + threshold.
            let mut tile = ndarray::Array3::<u8>::zeros((bbox_h, bbox_w, 1));
            for yi in 0..bbox_h {
                let py_o = (py0 + yi) as f32;
                let sample_y = sample_y_at(py_o) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                // Clamp both neighbours into the ROI; y_frac is clamped so
                // out-of-ROI samples degrade to edge replication.
                let y_lo = (y_floor as isize)
                    .max(0)
                    .min(roi_h.saturating_sub(1) as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                for xi in 0..bbox_w {
                    let px_o = (px0 + xi) as f32;
                    let sample_x = sample_x_at(px_o) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize)
                        .max(0)
                        .min(roi_w.saturating_sub(1) as isize)
                        as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
                    // Bilinear interp on the scalar logit plane.
                    let l00 = row_lo[x_lo];
                    let l01 = row_lo[x_hi];
                    let l10 = row_hi[x_lo];
                    let l11 = row_hi[x_hi];
                    let l0 = l00 + (l01 - l00) * x_frac;
                    let l1 = l10 + (l11 - l10) * x_frac;
                    let logit = l0 + (l1 - l0) * y_frac;
                    let sigmoid = fast_sigmoid(logit);
                    tile[[yi, xi, 0]] = if sigmoid > 0.5 { 255 } else { 0 };
                }
            }
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
1173
/// Approximate logistic function: exact saturation outside [-16, 16],
/// Schraudolph's bit-trick exponential inside the band.
fn fast_sigmoid(x: f32) -> f32 {
    if x >= 16.0 {
        1.0
    } else if x <= -16.0 {
        0.0
    } else {
        // Schraudolph: exp(v) ≈ f32::from_bits(2^23 * log2(e) * v + bias),
        // where bias = 127 << 23 is the f32 exponent bias in bit position.
        const SHIFT: f32 = (1u32 << 23) as f32; // 8388608.0
        const SCALE: f32 = SHIFT * std::f32::consts::LOG2_E; // SHIFT / ln(2)
        const BIAS: u32 = 127 << 23;
        let bits = (SCALE * -x) as i32 + BIAS as i32;
        let exp_neg_x = f32::from_bits(bits as u32);
        1.0 / (1.0 + exp_neg_x)
    }
}