
edgefirst_image/cpu/masks.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

use super::CPUProcessor;
use crate::Result;
use edgefirst_decoder::{DetectBox, Segmentation};
use ndarray::Axis;

impl CPUProcessor {
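    /// Overlay a ModelPack semantic-segmentation map onto `dst_slice`
    /// (interleaved pixels, `dst_rs` bytes per row, `dst_c` bytes per
    /// pixel). The per-pixel class is the argmax over the class axis,
    /// sampled nearest-neighbor over the `Segmentation` ROI; the last
    /// class index is skipped (treated as background).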
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_modelpack_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        opacity: f32,
    ) -> Result<()> {
        use ndarray_stats::QuantileExt;

        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize; 3] as shape");
        };
        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
        let get_value_at_nearest = |x: f32, y: f32| -> usize {
            let x = x.round() as usize;
            let y = y.round() as usize;
            argmax
                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
                .copied()
                .unwrap_or(0)
        };

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = (x as f32 - start_x) * scale_x;
                let seg_y = (y as f32 - start_y) * scale_y;
                let label = get_value_at_nearest(seg_x, seg_y);

                if label == seg_classes - 1 {
                    continue;
                }

                let color = self.colors[label % self.colors.len()];

                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

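                // Integer alpha blend: out = (src * alpha + dst * (255 - alpha)) / 255.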
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }

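    /// Overlay a single-channel YOLO instance mask onto `dst_slice`,
    /// nearest-neighbor sampled over the `Segmentation` ROI. Pixels whose
    /// mask value is at least 127 are alpha-blended with the color
    /// assigned to `class`.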
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_yolo_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        class: usize,
        opacity: f32,
    ) -> Result<()> {
        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize; 3] as shape");
        };
        debug_assert_eq!(classes, 1);

        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);

                if val < 127 {
                    continue;
                }

                let color = self.colors[class % self.colors.len()];

                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

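                // Integer alpha blend: out = (src * alpha + dst * (255 - alpha)) / 255.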
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }

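    /// Draw bounding-box outlines (3 px thick, clamped to the image
    /// bounds) for each detection directly into `dst_slice`. The outline
    /// color is picked via `color_mode.index(idx, d.label)` into the
    /// processor's color table.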
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_box(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        detect: &[DetectBox],
        color_mode: crate::ColorMode,
    ) -> Result<()> {
        const LINE_THICKNESS: usize = 3;

        for (idx, d) in detect.iter().enumerate() {
            use edgefirst_decoder::BoundingBox;

            let color_index = color_mode.index(idx, d.label);
            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
            let bbox = d.bbox.to_canonical();
            let bbox = BoundingBox {
                xmin: bbox.xmin.clamp(0.0, 1.0),
                ymin: bbox.ymin.clamp(0.0, 1.0),
                xmax: bbox.xmax.clamp(0.0, 1.0),
                ymax: bbox.ymax.clamp(0.0, 1.0),
            };
            let inner = [
                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
            ];

            let outer = [
                inner[0].saturating_sub(LINE_THICKNESS),
                inner[1].saturating_sub(LINE_THICKNESS),
                (inner[2] + LINE_THICKNESS).min(dst_w),
                (inner[3] + LINE_THICKNESS).min(dst_h),
            ];

            // top line
            for y in outer[1] + 1..=inner[1] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // left and right lines
            for y in inner[1]..inner[3] {
                for x in outer[0] + 1..=inner[0] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }

                for x in inner[2]..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // bottom line
            for y in inner[3]..outer[3] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }
        }
        Ok(())
    }

    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
    ///
    /// This is the CPU-side decode step of the hybrid mask rendering path:
    /// call this to get pre-decoded masks, then pass them to
    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
    /// than the fused GPU `draw_proto_masks` on all tested platforms.
    ///
    /// Optimized: fused dequantization + dot product avoids a 3.1 MB f32
    /// allocation for the full proto tensor. Uses a fast sigmoid approximation.
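    ///
    /// # Example
    ///
    /// Illustrative call shape only (not from the original docs); the
    /// `cpu`, `gl`, `detections`, `protos`, and `letterbox` bindings are
    /// assumed to come from the surrounding pipeline, and the exact
    /// `draw_decoded_masks` arguments are elided:
    ///
    /// ```ignore
    /// // CPU decode once per frame...
    /// let masks = cpu.materialize_segmentations(&detections, &protos, Some(letterbox))?;
    /// // ...then hand the pre-decoded masks to the GL overlay pass.
    /// gl.draw_decoded_masks(&masks, /* ... */)?;
    /// ```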
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute inverse letterbox scale for output-coord conversion.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Coefficients may be F32 (from quantized or f32 models) or F16
        // (from fp16 models). For the mask kernel we always need an f32
        // view (the multiply-accumulate is done in f32 for precision). Map
        // once and widen once if f16, outside the per-detection loop.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32 or F16"
                )));
            }
        };

        let per_detection = |i: usize,
                             det: &crate::DetectBox|
         -> crate::Result<edgefirst_decoder::Segmentation> {
            let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];

            let bbox = det.bbox.to_canonical();
            let xmin = bbox.xmin.clamp(0.0, 1.0);
            let ymin = bbox.ymin.clamp(0.0, 1.0);
            let xmax = bbox.xmax.clamp(0.0, 1.0);
            let ymax = bbox.ymax.clamp(0.0, 1.0);
            let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
            let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
            let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
            let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
            let roi_w = x1.saturating_sub(x0).max(1);
            let roi_h = y1.saturating_sub(y0).max(1);

            let mask = match proto_data.protos.dtype() {
                DType::I8 => {
                    let t = proto_data.protos.as_i8().expect("dtype matched I8");
                    let quant = t.quantization().ok_or_else(|| {
                        crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                    })?;
                    let m = t.map()?;
                    fused_dequant_dot_sigmoid_i8_slice(
                        m.as_slice(),
                        coeff,
                        quant,
                        proto_h,
                        proto_w,
                        y0,
                        x0,
                        roi_h,
                        roi_w,
                        num_protos,
                    )?
                }
                DType::F32 => {
                    let t = proto_data.protos.as_f32().expect("dtype matched F32");
                    let m = t.map()?;
                    fused_dot_sigmoid_f32_slice(
                        m.as_slice(),
                        coeff,
                        proto_h,
                        proto_w,
                        y0,
                        x0,
                        roi_h,
                        roi_w,
                        num_protos,
                    )
                }
                DType::F16 => {
                    let t = proto_data.protos.as_f16().expect("dtype matched F16");
                    let m = t.map()?;
                    fused_dot_sigmoid_f16_slice(
                        m.as_slice(),
                        coeff,
                        proto_h,
                        proto_w,
                        y0,
                        x0,
                        roi_h,
                        roi_w,
                        num_protos,
                    )
                }
                other => {
                    return Err(crate::Error::InvalidShape(format!(
                        "proto tensor dtype {other:?} not supported"
                    )));
                }
            };

            let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
            let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
            let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
            let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
            Ok(edgefirst_decoder::Segmentation {
                xmin: seg_xmin.clamp(0.0, 1.0),
                ymin: seg_ymin.clamp(0.0, 1.0),
                xmax: seg_xmax.clamp(0.0, 1.0),
                ymax: seg_ymax.clamp(0.0, 1.0),
                segmentation: mask,
            })
        };

        detect
            .iter()
            .enumerate()
            .map(|(i, det)| per_detection(i, det))
            .collect()
    }

    /// Produce per-detection masks at `(width, height)` pixel resolution by
    /// bilinearly sampling the proto plane within each detection's bbox
    /// (equivalent to upsampling the full proto plane once and cropping per
    /// bbox). Each `det.bbox` is assumed to be in model-input normalized
    /// coordinates (the convention used by the decoder output). When
    /// `letterbox` is `Some`, `(width, height)` are original-content pixel
    /// dimensions and the inverse letterbox transform is applied both to the
    /// bbox (for the crop region and the returned `Segmentation` metadata)
    /// and to each output pixel (for proto-plane sampling). Mask values are
    /// binary `uint8 {0, 255}` after thresholding sigmoid > 0.5.
    ///
    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
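    ///
    /// # Example
    ///
    /// Illustrative call shape only (not from the original docs); the
    /// bindings and the 1920×1080 output size are assumptions:
    ///
    /// ```ignore
    /// // Binary {0, 255} masks in original-content pixel space.
    /// let masks = cpu.materialize_scaled_segmentations(
    ///     &detections, &protos, Some(letterbox), 1920, 1080)?;
    /// ```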
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        // Mirror the row-count check from `materialize_segmentations`; a
        // mismatch would otherwise panic in the per-detection coefficient
        // slicing below.
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Widen coefficients to f32 once for the scaled-sample inner loop.
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
}

// =============================================================================
// Slice-native fused kernels.
//
// All kernels take row-major `[H, W, num_protos]` proto slices + `&[f32]`
// coefficients (widened once from the source dtype at the materialize entry
// point). Per-dtype variants exist for i8 (with on-the-fly dequant using a
// tensor-level `Quantization`), f32, and f16; f16 widens to f32 per-element
// at the FMA site via `half::f16::to_f32()`.
//
// On ARMv8.2-FP16 this compiles to `fcvt`; on Cortex-A53 and non-F16C x86 it
// becomes a soft-float helper. The explicit `+f16c` intrinsic kernel for
// x86_64 is implemented below; Stage 8 adds the aarch64 variant gated by
// `#[cfg(target_feature = "fp16")]`.
// =============================================================================

#[allow(clippy::too_many_arguments)]
fn fused_dequant_dot_sigmoid_i8_slice(
    protos: &[i8],
    coeff: &[f32],
    quant: &edgefirst_tensor::Quantization,
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> crate::Result<ndarray::Array3<u8>> {
    use edgefirst_tensor::QuantMode;
    let stride_y = proto_w * num_protos;
    // Precompute scaled coefficients + zp_offset. Stack scratch covers
    // `num_protos ≤ 64` (every production model today); larger proto counts
    // fall back to a single heap allocation per kernel call so the kernel
    // does not silently reject valid-but-larger models.
    let mut stack_scratch = [0.0_f32; 64];
    let mut heap_scratch: Vec<f32>;
    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
        &mut stack_scratch[..num_protos]
    } else {
        heap_scratch = vec![0.0_f32; num_protos];
        heap_scratch.as_mut_slice()
    };
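    // Folding dequantization into the dot product:
    //   Σ_k c_k · s_k · (p_k − zp_k) = Σ_k (c_k·s_k)·p_k − Σ_k (c_k·s_k)·zp_k
    // so the inner loop accumulates raw i8 values against `scaled_coeff`
    // and subtracts the precomputed `zp_offset` once per pixel.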
    let zp_offset: f32;
    match quant.mode() {
        QuantMode::PerTensorSymmetric { scale } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            zp_offset = 0.0;
        }
        QuantMode::PerTensor { scale, zero_point } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
        }
        QuantMode::PerChannelSymmetric { scales, axis } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = 0.0;
        }
        QuantMode::PerChannel {
            scales,
            zero_points,
            axis,
        } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = (0..num_protos)
                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
                .sum();
        }
    }

    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc = 0.0_f32;
            let mut k = 0;
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                let p0 = protos[base + k] as f32;
                let p1 = protos[base + k + 1] as f32;
                let p2 = protos[base + k + 2] as f32;
                let p3 = protos[base + k + 3] as f32;
                acc += scaled_coeff[k] * p0
                    + scaled_coeff[k + 1] * p1
                    + scaled_coeff[k + 2] * p2
                    + scaled_coeff[k + 3] * p3;
                k += 4;
            }
            while k < num_protos {
                acc += scaled_coeff[k] * protos[base + k] as f32;
                k += 1;
            }
            acc -= zp_offset;
            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    Ok(mask)
}

#[allow(clippy::too_many_arguments)]
fn fused_dot_sigmoid_f32_slice(
    protos: &[f32],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    let stride_y = proto_w * num_protos;
    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc = 0.0_f32;
            let mut k = 0;
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                acc += coeff[k] * protos[base + k]
                    + coeff[k + 1] * protos[base + k + 1]
                    + coeff[k + 2] * protos[base + k + 2]
                    + coeff[k + 3] * protos[base + k + 3];
                k += 4;
            }
            while k < num_protos {
                acc += coeff[k] * protos[base + k];
                k += 1;
            }
            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}

/// Native-f16 fused kernel.
///
/// Three code paths, selected at compile time:
///
/// 1. **x86_64 + F16C + FMA** — explicit intrinsic kernel (`_mm256_cvtph_ps`
///    8-lane f16→f32 widening, `_mm256_fmadd_ps` FMA). Guaranteed to use
///    hardware f16 conversion, not LLVM's autovectorizer (which is unreliable
///    for this pattern per rust-lang/stdarch #1349).
///
/// 2. **aarch64 + FP16** — scalar `half::f16::to_f32()` at the FMA site.
///    LLVM lowers each `.to_f32()` to a single `fcvt` instruction when
///    `target-feature=+fp16` is active (e.g. `target-cpu=cortex-a78ae`).
///    The stable f16-typed NEON intrinsics (`vcvt_f32_f16`, `vld1q_f16`)
///    require nightly as of this commit; the scalar path is equally
///    efficient at this granularity.
///
/// 3. **Fallback (Cortex-A53, targets without FP16)** — same scalar code.
///    `half::f16::to_f32()` lowers to the `__extendhfsf2` soft-float helper,
///    one call per proto load. Correctness-preserving; ~15 cycles/load
///    vs. ~3 cycles for the hardware path. Documented in
///    `docs/orin-build.md`.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sigmoid_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: target-feature gates both `vcvtph2ps` and `vfmadd*ps`;
        // the caller's slice-bounds contract is identical to the scalar arm.
        unsafe {
            fused_dot_sigmoid_f16_slice_f16c(
                protos, coeff, proto_h, proto_w, y0, x0, roi_h, roi_w, num_protos,
            )
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        let _ = proto_h;
        fused_dot_sigmoid_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}

/// Scalar native-f16 kernel. `half::f16::to_f32()` at the FMA site is
/// lowered to a single `fcvt` (aarch64+fp16) or a single `vcvtph2ps`
/// (x86_64+f16c) by LLVM, or to the soft-float helper `__extendhfsf2` on
/// targets without FP16 hardware. Loop unrolled by 4 to give the scheduler
/// room to overlap loads with FMAs.
#[allow(clippy::too_many_arguments, dead_code)]
fn fused_dot_sigmoid_f16_slice_scalar(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    let stride_y = proto_w * num_protos;
    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc = 0.0_f32;
            let mut k = 0;
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                let p0 = protos[base + k].to_f32();
                let p1 = protos[base + k + 1].to_f32();
                let p2 = protos[base + k + 2].to_f32();
                let p3 = protos[base + k + 3].to_f32();
                acc += coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
                k += 4;
            }
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }
            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}

/// x86_64 F16C + FMA explicit intrinsic kernel. Processes 8 f16 lanes per
/// inner iteration via `_mm256_cvtph_ps` (8-lane f16→f32 widen) followed by
/// `_mm256_fmadd_ps` (8-lane fused multiply-add). Horizontal reduce at the
/// end of each pixel.
///
/// # Safety
///
/// Caller must ensure the target CPU supports F16C + FMA. The workspace's
/// `.cargo/config.toml` sets these target-features on `x86_64-unknown-linux-gnu`
/// and `x86_64-apple-darwin`, making this function statically callable on
/// those targets.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sigmoid_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let tail = num_protos % 8;
    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));

    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Load 8 f16 (128 bits / 16 bytes) via a byte-level cast.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduce 8 → 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos % 8 (≤ 7 items).
            while k < num_protos && k - chunks8 * 8 < tail {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}

// Scaled mask kernels — bilinear sample + sigmoid, binarized output.

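/// Bilinearly interpolate all `num_protos` proto channels at fractional
/// proto-plane coordinates `(px, py)`, clamped to the plane edges, and dot
/// the interpolated vector with `coeff`. `load_f32` widens the storage
/// dtype (i8 / f16 / f32) to f32 per element.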
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn bilinear_dot_slice<P: Copy>(
    protos: &[P],
    stride_y: usize,
    num_protos: usize,
    coeff: &[f32],
    px: f32,
    py: f32,
    proto_w: usize,
    proto_h: usize,
    load_f32: impl Fn(&P) -> f32,
) -> f32 {
    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
    let x1 = (x0 + 1).min(proto_w - 1);
    let y1 = (y0 + 1).min(proto_h - 1);
    let fx = px - px.floor();
    let fy = py - py.floor();
    let w00 = (1.0 - fx) * (1.0 - fy);
    let w10 = fx * (1.0 - fy);
    let w01 = (1.0 - fx) * fy;
    let w11 = fx * fy;
    let b00 = y0 * stride_y + x0 * num_protos;
    let b10 = y0 * stride_y + x1 * num_protos;
    let b01 = y1 * stride_y + x0 * num_protos;
    let b11 = y1 * stride_y + x1 * num_protos;
    let mut acc = 0.0_f32;
    for p in 0..num_protos {
        let v00 = load_f32(&protos[b00 + p]);
        let v10 = load_f32(&protos[b10 + p]);
        let v01 = load_f32(&protos[b01 + p]);
        let v11 = load_f32(&protos[b11 + p]);
        let val = w00 * v00 + w10 * v10 + w01 * v01 + w11 * v11;
        acc += coeff[p] * val;
    }
    acc
}

#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_f32_slice(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[f32],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    scaled_run(
        detect,
        coeff_all,
        protos,
        proto_h,
        proto_w,
        num_protos,
        letterbox,
        width,
        height,
        1.0,
        |p, _| *p,
    )
}

#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_f16_slice(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[half::f16],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    scaled_run(
        detect,
        coeff_all,
        protos,
        proto_h,
        proto_w,
        num_protos,
        letterbox,
        width,
        height,
        1.0,
        |p: &half::f16, _| p.to_f32(),
    )
}

#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_i8_slice(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[i8],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    quant: &edgefirst_tensor::Quantization,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;
    // Only per-tensor quantization supported on the scaled-path CPU kernel
    // today. Per-channel fits naturally into a future extension (would need
    // to pass per-channel scaled coefficients into bilinear_dot_slice).
    let (scale, zp) = match quant.mode() {
        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
        QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
            return Err(crate::Error::NotSupported(format!(
                "per-channel quantization (axis={axis}) on scaled seg path \
                 not yet supported"
            )));
        }
    };
    scaled_run(
        detect,
        coeff_all,
        protos,
        proto_h,
        proto_w,
        num_protos,
        letterbox,
        width,
        height,
        scale,
        move |p: &i8, _| *p as f32 - zp,
    )
}

#[allow(clippy::too_many_arguments)]
fn scaled_run<P: Copy>(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[P],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    acc_scale: f32,
    load_f32: impl Fn(&P, f32) -> f32 + Copy,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let stride_y = proto_w * num_protos;

    detect
        .iter()
        .enumerate()
        .map(|(i, det)| {
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);
            let mut tile = ndarray::Array3::<u8>::zeros((bbox_h, bbox_w, 1));
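            // Map each output pixel center through the letterbox into
            // proto-plane coordinates; the ±0.5 terms implement the
            // pixel-center sampling convention.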
            for yi in 0..bbox_h {
                let py = (py0 + yi) as f32;
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                let sample_y = model_y_norm * proto_h as f32 - 0.5;
                for xi in 0..bbox_w {
                    let px = (px0 + xi) as f32;
                    let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                    let sample_x = model_x_norm * proto_w as f32 - 0.5;
                    let acc = bilinear_dot_slice(
                        protos,
                        stride_y,
                        num_protos,
                        coeff,
                        sample_x,
                        sample_y,
                        proto_w,
                        proto_h,
                        |p: &P| load_f32(p, 0.0),
                    );
                    let sigmoid = fast_sigmoid(acc_scale * acc);
                    tile[[yi, xi, 0]] = if sigmoid > 0.5 { 255 } else { 0 };
                }
            }
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}

fn fast_sigmoid(x: f32) -> f32 {
    if x >= 16.0 {
        return 1.0;
    }
    if x <= -16.0 {
        return 0.0;
    }
    // Fast exp(-x) via bit manipulation (Schraudolph's algorithm).
    // f32 bits: 2^23 * log2(e) * x + (127 << 23) approximates exp(x).
    const A: f32 = (1u32 << 23) as f32; // 8388608.0
    const B: f32 = A * std::f32::consts::LOG2_E; // A / ln(2)
    const C: u32 = 127 << 23; // exponent bias
    let neg_x = -x;
    let bits = (B * neg_x) as i32 + C as i32;
    let exp_neg_x = f32::from_bits(bits as u32);
    1.0 / (1.0 + exp_neg_x)
}
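
// Illustrative sanity check, not from the original source: verifies that
// `fast_sigmoid` tracks the exact sigmoid within a loose absolute
// tolerance (Schraudolph's exp is good to roughly ±6% relative error, so
// the resulting sigmoid stays within a few hundredths).
#[cfg(test)]
mod fast_sigmoid_tests {
    use super::fast_sigmoid;

    #[test]
    fn tracks_exact_sigmoid() {
        for i in -200..=200 {
            let x = i as f32 * 0.1;
            let exact = 1.0 / (1.0 + (-x).exp());
            let approx = fast_sigmoid(x);
            assert!(
                (approx - exact).abs() < 0.05,
                "x={x}: approx={approx} exact={exact}"
            );
        }
    }
}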