Skip to main content

edgefirst_image/cpu/
masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8
9impl CPUProcessor {
10    #[allow(clippy::too_many_arguments)]
11    pub(super) fn render_modelpack_segmentation(
12        &mut self,
13        dst_w: usize,
14        dst_h: usize,
15        dst_rs: usize,
16        dst_c: usize,
17        dst_slice: &mut [u8],
18        segmentation: &Segmentation,
19        opacity: f32,
20    ) -> Result<()> {
21        use ndarray_stats::QuantileExt;
22
23        let seg = &segmentation.segmentation;
24        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
25            unreachable!("Array3 did not have [usize; 3] as shape");
26        };
27        let start_y = (dst_h as f32 * segmentation.ymin).round();
28        let end_y = (dst_h as f32 * segmentation.ymax).round();
29        let start_x = (dst_w as f32 * segmentation.xmin).round();
30        let end_x = (dst_w as f32 * segmentation.xmax).round();
31
32        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
33        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
34
35        let start_x_u = (start_x as usize).min(dst_w);
36        let start_y_u = (start_y as usize).min(dst_h);
37        let end_x_u = (end_x as usize).min(dst_w);
38        let end_y_u = (end_y as usize).min(dst_h);
39
40        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
41        let get_value_at_nearest = |x: f32, y: f32| -> usize {
42            let x = x.round() as usize;
43            let y = y.round() as usize;
44            argmax
45                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
46                .copied()
47                .unwrap_or(0)
48        };
49
50        for y in start_y_u..end_y_u {
51            for x in start_x_u..end_x_u {
52                let seg_x = (x as f32 - start_x) * scale_x;
53                let seg_y = (y as f32 - start_y) * scale_y;
54                let label = get_value_at_nearest(seg_x, seg_y);
55
56                if label == seg_classes - 1 {
57                    continue;
58                }
59
60                let color = self.colors[label % self.colors.len()];
61
62                let alpha = if opacity == 1.0 {
63                    color[3] as u16
64                } else {
65                    (color[3] as f32 * opacity).round() as u16
66                };
67
68                let dst_index = (y * dst_rs) + (x * dst_c);
69                for c in 0..3 {
70                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
71                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
72                        / 255) as u8;
73                }
74            }
75        }
76
77        Ok(())
78    }
79
80    #[allow(clippy::too_many_arguments)]
81    pub(super) fn render_yolo_segmentation(
82        &mut self,
83        dst_w: usize,
84        dst_h: usize,
85        dst_rs: usize,
86        dst_c: usize,
87        dst_slice: &mut [u8],
88        segmentation: &Segmentation,
89        class: usize,
90        opacity: f32,
91    ) -> Result<()> {
92        let seg = &segmentation.segmentation;
93        let [seg_height, seg_width, classes] = *seg.shape() else {
94            unreachable!("Array3 did not have [usize;3] as shape");
95        };
96        debug_assert_eq!(classes, 1);
97
98        let start_y = (dst_h as f32 * segmentation.ymin).round();
99        let end_y = (dst_h as f32 * segmentation.ymax).round();
100        let start_x = (dst_w as f32 * segmentation.xmin).round();
101        let end_x = (dst_w as f32 * segmentation.xmax).round();
102
103        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
104        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
105
106        let start_x_u = (start_x as usize).min(dst_w);
107        let start_y_u = (start_y as usize).min(dst_h);
108        let end_x_u = (end_x as usize).min(dst_w);
109        let end_y_u = (end_y as usize).min(dst_h);
110
111        for y in start_y_u..end_y_u {
112            for x in start_x_u..end_x_u {
113                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
114                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
115                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
116
117                if val < 127 {
118                    continue;
119                }
120
121                let color = self.colors[class % self.colors.len()];
122
123                let alpha = if opacity == 1.0 {
124                    color[3] as u16
125                } else {
126                    (color[3] as f32 * opacity).round() as u16
127                };
128
129                let dst_index = (y * dst_rs) + (x * dst_c);
130                for c in 0..3 {
131                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
132                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
133                        / 255) as u8;
134                }
135            }
136        }
137
138        Ok(())
139    }
140
141    #[allow(clippy::too_many_arguments)]
142    pub(super) fn render_box(
143        &mut self,
144        dst_w: usize,
145        dst_h: usize,
146        dst_rs: usize,
147        dst_c: usize,
148        dst_slice: &mut [u8],
149        detect: &[DetectBox],
150        color_mode: crate::ColorMode,
151    ) -> Result<()> {
152        const LINE_THICKNESS: usize = 3;
153
154        for (idx, d) in detect.iter().enumerate() {
155            use edgefirst_decoder::BoundingBox;
156
157            let color_index = color_mode.index(idx, d.label);
158            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
159            let bbox = d.bbox.to_canonical();
160            let bbox = BoundingBox {
161                xmin: bbox.xmin.clamp(0.0, 1.0),
162                ymin: bbox.ymin.clamp(0.0, 1.0),
163                xmax: bbox.xmax.clamp(0.0, 1.0),
164                ymax: bbox.ymax.clamp(0.0, 1.0),
165            };
166            let inner = [
167                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
168                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
169                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
170                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
171            ];
172
173            let outer = [
174                inner[0].saturating_sub(LINE_THICKNESS),
175                inner[1].saturating_sub(LINE_THICKNESS),
176                (inner[2] + LINE_THICKNESS).min(dst_w),
177                (inner[3] + LINE_THICKNESS).min(dst_h),
178            ];
179
180            // top line
181            for y in outer[1] + 1..=inner[1] {
182                for x in outer[0] + 1..outer[2] {
183                    let index = (y * dst_rs) + (x * dst_c);
184                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
185                }
186            }
187
188            // left and right lines
189            for y in inner[1]..inner[3] {
190                for x in outer[0] + 1..=inner[0] {
191                    let index = (y * dst_rs) + (x * dst_c);
192                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
193                }
194
195                for x in inner[2]..outer[2] {
196                    let index = (y * dst_rs) + (x * dst_c);
197                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
198                }
199            }
200
201            // bottom line
202            for y in inner[3]..outer[3] {
203                for x in outer[0] + 1..outer[2] {
204                    let index = (y * dst_rs) + (x * dst_c);
205                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
206                }
207            }
208        }
209        Ok(())
210    }
211
212    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
213    ///
214    /// This is the CPU-side decode step of the hybrid mask rendering path:
215    /// call this to get pre-decoded masks, then pass them to
216    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
217    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
218    /// than the fused GPU `draw_proto_masks` on all tested platforms.
219    ///
220    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
221    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
222    pub fn materialize_segmentations(
223        &self,
224        detect: &[crate::DetectBox],
225        proto_data: &crate::ProtoData,
226        letterbox: Option<[f32; 4]>,
227    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
228        use edgefirst_decoder::ProtoTensor;
229
230        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
231            return Ok(Vec::new());
232        }
233
234        // Extract proto tensor metadata for the fused kernel.
235        let (proto_h, proto_w, num_protos) = match &proto_data.protos {
236            ProtoTensor::Quantized { protos, .. } => {
237                (protos.shape()[0], protos.shape()[1], protos.shape()[2])
238            }
239            ProtoTensor::Float(arr) => (arr.shape()[0], arr.shape()[1], arr.shape()[2]),
240        };
241
242        // Precompute inverse letterbox scale for output-coord conversion.
243        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
244            Some([lx0, ly0, lx1, ly1]) => {
245                let lw = lx1 - lx0;
246                let lh = ly1 - ly0;
247                (
248                    lx0,
249                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
250                    ly0,
251                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
252                )
253            }
254            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
255        };
256
257        detect
258            .iter()
259            .zip(proto_data.mask_coefficients.iter())
260            .map(|(det, coeff)| {
261                // Canonicalise bbox (swap min/max if inverted) then clamp to [0,1].
262                // Without canonicalisation a degenerate box (ymax < ymin) causes a
263                // near-zero ROI height after saturating_sub, making the mask invisible.
264                let bbox = det.bbox.to_canonical();
265                let xmin = bbox.xmin.clamp(0.0, 1.0);
266                let ymin = bbox.ymin.clamp(0.0, 1.0);
267                let xmax = bbox.xmax.clamp(0.0, 1.0);
268                let ymax = bbox.ymax.clamp(0.0, 1.0);
269
270                // Map to proto-space pixel coordinates (clamp to valid range)
271                let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
272                let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
273                let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
274                let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
275
276                let roi_w = x1.saturating_sub(x0).max(1);
277                let roi_h = y1.saturating_sub(y0).max(1);
278
279                if coeff.len() != num_protos {
280                    return Err(crate::Error::Internal(format!(
281                        "mask coeff length {} != proto channels {num_protos}",
282                        coeff.len()
283                    )));
284                }
285
286                // Fused dequant + dot product + sigmoid, directly producing u8 mask.
287                // Avoids allocating a full f32 proto tensor and the to_shape copy.
288                let mask = match &proto_data.protos {
289                    ProtoTensor::Quantized {
290                        protos,
291                        quantization,
292                    } => {
293                        let scale = quantization.scale;
294                        let zp = quantization.zero_point as f32;
295                        fused_dequant_dot_sigmoid_i8(
296                            protos, coeff, scale, zp, y0, x0, roi_h, roi_w, num_protos,
297                        )
298                    }
299                    ProtoTensor::Float(protos) => {
300                        fused_dot_sigmoid_f32(protos, coeff, y0, x0, roi_h, roi_w, num_protos)
301                    }
302                };
303
304                // Convert proto-space normalised coords to output-image-normalised
305                // coords by applying the inverse letterbox transform.  When no
306                // letterbox was used (lx0=0, inv_lw=1, ly0=0, inv_lh=1) this is a
307                // no-op and the coords are identical to the proto-normalised values.
308                let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
309                let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
310                let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
311                let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
312
313                Ok(edgefirst_decoder::Segmentation {
314                    xmin: seg_xmin.clamp(0.0, 1.0),
315                    ymin: seg_ymin.clamp(0.0, 1.0),
316                    xmax: seg_xmax.clamp(0.0, 1.0),
317                    ymax: seg_ymax.clamp(0.0, 1.0),
318                    segmentation: mask,
319                })
320            })
321            .collect::<crate::Result<Vec<_>>>()
322    }
323}
324
325/// Bilinear interpolation of proto values at `(px, py)` combined with dot
326/// product against `coeff`. Returns the scalar accumulator before sigmoid.
327///
328/// Samples the four nearest proto texels, weights by bilinear coefficients,
329/// and simultaneously computes the dot product with the mask coefficients.
330#[inline]
331pub(super) fn bilinear_dot(
332    protos: &ndarray::Array3<f32>,
333    coeff: &[f32],
334    num_protos: usize,
335    px: f32,
336    py: f32,
337    proto_w: usize,
338    proto_h: usize,
339) -> f32 {
340    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
341    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
342    let x1 = (x0 + 1).min(proto_w - 1);
343    let y1 = (y0 + 1).min(proto_h - 1);
344
345    let fx = px - px.floor();
346    let fy = py - py.floor();
347
348    let w00 = (1.0 - fx) * (1.0 - fy);
349    let w10 = fx * (1.0 - fy);
350    let w01 = (1.0 - fx) * fy;
351    let w11 = fx * fy;
352
353    let mut acc = 0.0f32;
354    for p in 0..num_protos {
355        let val = w00 * protos[[y0, x0, p]]
356            + w10 * protos[[y0, x1, p]]
357            + w01 * protos[[y1, x0, p]]
358            + w11 * protos[[y1, x1, p]];
359        acc += coeff[p] * val;
360    }
361    acc
362}
363
/// Fast approximation of the logistic sigmoid `1 / (1 + exp(-x))`.
///
/// `exp(-x)` is evaluated with Schraudolph's bit trick (same approach as
/// `fast_math::exp_raw`): `2^23 · log2(e) · (-x)` plus the f32 exponent bias
/// is reinterpreted as the bit pattern of an `f32`. This linearises the
/// mantissa. Without a correction constant it overestimates `exp` by up to
/// about 6 %, so the sigmoid is never meaningfully overestimated.
///
/// The resulting absolute error is at most roughly 0.015 (about 4 counts once
/// scaled to `u8`), which is adequate for mask thresholding and 8-bit
/// quantisation. Relative error is not small everywhere: it approaches ~6 %
/// for strongly negative inputs, where the sigmoid itself is tiny.
///
/// Inputs `>= 16` return exactly `1.0` and inputs `<= -16` exactly `0.0`.
#[inline(always)]
fn fast_sigmoid(x: f32) -> f32 {
    if x >= 16.0 {
        return 1.0;
    }
    if x <= -16.0 {
        return 0.0;
    }
    // Fast exp(-x) via bit manipulation (Schraudolph's algorithm).
    // f32 bits: 2^23 * log2(e) * y + (127 << 23) approximates exp(y).
    // Within (-16, 16) the integer stays positive and far from i32 overflow.
    const A: f32 = (1u32 << 23) as f32; // 8388608.0
    const B: f32 = A * std::f32::consts::LOG2_E; // A / ln(2)
    const C: u32 = 127 << 23; // exponent bias
    let neg_x = -x;
    let bits = (B * neg_x) as i32 + C as i32;
    let exp_neg_x = f32::from_bits(bits as u32);
    1.0 / (1.0 + exp_neg_x)
}
387
388/// Fused dequantization + dot product + sigmoid for quantized (i8) protos.
389///
390/// For each pixel in the ROI, computes:
391///   acc = sum_k(coeff[k] * (proto[y, x, k] as f32 - zp) * scale)
392///   mask[y, x] = fast_sigmoid(acc) * 255
393///
394/// This avoids allocating a full f32 proto tensor (3.1MB for 160x160x32)
395/// and the hidden `to_shape` copy on non-contiguous ROI slices.
396#[allow(clippy::too_many_arguments)]
397fn fused_dequant_dot_sigmoid_i8(
398    protos: &ndarray::Array3<i8>,
399    coeff: &[f32],
400    scale: f32,
401    zp: f32,
402    y0: usize,
403    x0: usize,
404    roi_h: usize,
405    roi_w: usize,
406    num_protos: usize,
407) -> ndarray::Array3<u8> {
408    debug_assert!(
409        protos.strides().iter().all(|&s| s >= 0),
410        "negative strides unsupported"
411    );
412    // Pre-scale coefficients: coeff[k] * scale, so the inner loop is
413    // just fma: acc += scaled_coeff[k] * (proto_i8 - zp)
414    let scaled_coeff: Vec<f32> = coeff.iter().map(|&c| c * scale).collect();
415    // Pre-compute coeff_sum * (-zp * scale) offset:
416    // sum_k(coeff[k] * (proto - zp) * scale) = sum_k(scaled_coeff[k] * proto) - zp * sum_k(scaled_coeff[k])
417    // But since zp is a constant per-pixel term, factor it out:
418    // acc = sum_k(scaled_coeff[k] * proto_i8_as_f32) - zp * sum_k(scaled_coeff[k])
419    let zp_offset: f32 = zp * scaled_coeff.iter().sum::<f32>();
420
421    let proto_stride_y = protos.strides()[0] as usize;
422    let proto_stride_x = protos.strides()[1] as usize;
423    let proto_stride_k = protos.strides()[2] as usize;
424    let proto_ptr = protos.as_ptr();
425
426    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
427
428    for y in 0..roi_h {
429        for x in 0..roi_w {
430            // Base pointer for protos[y0+y, x0+x, 0]
431            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;
432
433            let mut acc = 0.0f32;
434            let mut k = 0;
435
436            // Process 4 protos at a time for better ILP
437            let chunks = num_protos / 4;
438            for _ in 0..chunks {
439                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
440                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
441                unsafe {
442                    let p0 = *proto_ptr.add(base + k * proto_stride_k) as f32;
443                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k) as f32;
444                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k) as f32;
445                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k) as f32;
446                    acc += scaled_coeff[k] * p0
447                        + scaled_coeff[k + 1] * p1
448                        + scaled_coeff[k + 2] * p2
449                        + scaled_coeff[k + 3] * p3;
450                }
451                k += 4;
452            }
453            // Remainder
454            while k < num_protos {
455                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
456                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
457                unsafe {
458                    let p = *proto_ptr.add(base + k * proto_stride_k) as f32;
459                    acc += scaled_coeff[k] * p;
460                }
461                k += 1;
462            }
463
464            acc -= zp_offset;
465            let sigmoid = fast_sigmoid(acc);
466            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
467        }
468    }
469    mask
470}
471
472/// Fused dot product + sigmoid for f32 protos (no dequantization needed).
473fn fused_dot_sigmoid_f32(
474    protos: &ndarray::Array3<f32>,
475    coeff: &[f32],
476    y0: usize,
477    x0: usize,
478    roi_h: usize,
479    roi_w: usize,
480    num_protos: usize,
481) -> ndarray::Array3<u8> {
482    debug_assert!(
483        protos.strides().iter().all(|&s| s >= 0),
484        "negative strides unsupported"
485    );
486    let proto_stride_y = protos.strides()[0] as usize;
487    let proto_stride_x = protos.strides()[1] as usize;
488    let proto_stride_k = protos.strides()[2] as usize;
489    let proto_ptr = protos.as_ptr();
490
491    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
492
493    for y in 0..roi_h {
494        for x in 0..roi_w {
495            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;
496
497            let mut acc = 0.0f32;
498            let mut k = 0;
499            let chunks = num_protos / 4;
500            for _ in 0..chunks {
501                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
502                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
503                unsafe {
504                    let p0 = *proto_ptr.add(base + k * proto_stride_k);
505                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k);
506                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k);
507                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k);
508                    acc +=
509                        coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
510                }
511                k += 4;
512            }
513            while k < num_protos {
514                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
515                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
516                unsafe {
517                    let p = *proto_ptr.add(base + k * proto_stride_k);
518                    acc += coeff[k] * p;
519                }
520                k += 1;
521            }
522
523            let sigmoid = fast_sigmoid(acc);
524            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
525        }
526    }
527    mask
528}