Skip to main content

edgefirst_image/cpu/
masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8
9impl CPUProcessor {
10    #[allow(clippy::too_many_arguments)]
11    pub(super) fn render_modelpack_segmentation(
12        &mut self,
13        dst_w: usize,
14        dst_h: usize,
15        dst_rs: usize,
16        dst_c: usize,
17        dst_slice: &mut [u8],
18        segmentation: &Segmentation,
19        opacity: f32,
20    ) -> Result<()> {
21        use ndarray_stats::QuantileExt;
22
23        let seg = &segmentation.segmentation;
24        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
25            unreachable!("Array3 did not have [usize; 3] as shape");
26        };
27        let start_y = (dst_h as f32 * segmentation.ymin).round();
28        let end_y = (dst_h as f32 * segmentation.ymax).round();
29        let start_x = (dst_w as f32 * segmentation.xmin).round();
30        let end_x = (dst_w as f32 * segmentation.xmax).round();
31
32        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
33        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
34
35        let start_x_u = (start_x as usize).min(dst_w);
36        let start_y_u = (start_y as usize).min(dst_h);
37        let end_x_u = (end_x as usize).min(dst_w);
38        let end_y_u = (end_y as usize).min(dst_h);
39
40        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
41        let get_value_at_nearest = |x: f32, y: f32| -> usize {
42            let x = x.round() as usize;
43            let y = y.round() as usize;
44            argmax
45                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
46                .copied()
47                .unwrap_or(0)
48        };
49
50        for y in start_y_u..end_y_u {
51            for x in start_x_u..end_x_u {
52                let seg_x = (x as f32 - start_x) * scale_x;
53                let seg_y = (y as f32 - start_y) * scale_y;
54                let label = get_value_at_nearest(seg_x, seg_y);
55
56                if label == seg_classes - 1 {
57                    continue;
58                }
59
60                let color = self.colors[label % self.colors.len()];
61
62                let alpha = if opacity == 1.0 {
63                    color[3] as u16
64                } else {
65                    (color[3] as f32 * opacity).round() as u16
66                };
67
68                let dst_index = (y * dst_rs) + (x * dst_c);
69                for c in 0..3 {
70                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
71                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
72                        / 255) as u8;
73                }
74            }
75        }
76
77        Ok(())
78    }
79
80    #[allow(clippy::too_many_arguments)]
81    pub(super) fn render_yolo_segmentation(
82        &mut self,
83        dst_w: usize,
84        dst_h: usize,
85        dst_rs: usize,
86        dst_c: usize,
87        dst_slice: &mut [u8],
88        segmentation: &Segmentation,
89        class: usize,
90        opacity: f32,
91    ) -> Result<()> {
92        let seg = &segmentation.segmentation;
93        let [seg_height, seg_width, classes] = *seg.shape() else {
94            unreachable!("Array3 did not have [usize;3] as shape");
95        };
96        debug_assert_eq!(classes, 1);
97
98        let start_y = (dst_h as f32 * segmentation.ymin).round();
99        let end_y = (dst_h as f32 * segmentation.ymax).round();
100        let start_x = (dst_w as f32 * segmentation.xmin).round();
101        let end_x = (dst_w as f32 * segmentation.xmax).round();
102
103        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
104        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
105
106        let start_x_u = (start_x as usize).min(dst_w);
107        let start_y_u = (start_y as usize).min(dst_h);
108        let end_x_u = (end_x as usize).min(dst_w);
109        let end_y_u = (end_y as usize).min(dst_h);
110
111        for y in start_y_u..end_y_u {
112            for x in start_x_u..end_x_u {
113                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
114                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
115                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
116
117                if val < 127 {
118                    continue;
119                }
120
121                let color = self.colors[class % self.colors.len()];
122
123                let alpha = if opacity == 1.0 {
124                    color[3] as u16
125                } else {
126                    (color[3] as f32 * opacity).round() as u16
127                };
128
129                let dst_index = (y * dst_rs) + (x * dst_c);
130                for c in 0..3 {
131                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
132                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
133                        / 255) as u8;
134                }
135            }
136        }
137
138        Ok(())
139    }
140
141    pub(super) fn render_box(
142        &mut self,
143        dst_w: usize,
144        dst_h: usize,
145        dst_rs: usize,
146        dst_c: usize,
147        dst_slice: &mut [u8],
148        detect: &[DetectBox],
149    ) -> Result<()> {
150        const LINE_THICKNESS: usize = 3;
151
152        for d in detect {
153            use edgefirst_decoder::BoundingBox;
154
155            let label = d.label;
156            let [r, g, b, _] = self.colors[label % self.colors.len()];
157            let bbox = d.bbox.to_canonical();
158            let bbox = BoundingBox {
159                xmin: bbox.xmin.clamp(0.0, 1.0),
160                ymin: bbox.ymin.clamp(0.0, 1.0),
161                xmax: bbox.xmax.clamp(0.0, 1.0),
162                ymax: bbox.ymax.clamp(0.0, 1.0),
163            };
164            let inner = [
165                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
166                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
167                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
168                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
169            ];
170
171            let outer = [
172                inner[0].saturating_sub(LINE_THICKNESS),
173                inner[1].saturating_sub(LINE_THICKNESS),
174                (inner[2] + LINE_THICKNESS).min(dst_w),
175                (inner[3] + LINE_THICKNESS).min(dst_h),
176            ];
177
178            // top line
179            for y in outer[1] + 1..=inner[1] {
180                for x in outer[0] + 1..outer[2] {
181                    let index = (y * dst_rs) + (x * dst_c);
182                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
183                }
184            }
185
186            // left and right lines
187            for y in inner[1]..inner[3] {
188                for x in outer[0] + 1..=inner[0] {
189                    let index = (y * dst_rs) + (x * dst_c);
190                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
191                }
192
193                for x in inner[2]..outer[2] {
194                    let index = (y * dst_rs) + (x * dst_c);
195                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
196                }
197            }
198
199            // bottom line
200            for y in inner[3]..outer[3] {
201                for x in outer[0] + 1..outer[2] {
202                    let index = (y * dst_rs) + (x * dst_c);
203                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
204                }
205            }
206        }
207        Ok(())
208    }
209
210    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
211    ///
212    /// This is the CPU-side decode step of the hybrid mask rendering path:
213    /// call this to get pre-decoded masks, then pass them to
214    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
215    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
216    /// than the fused GPU `draw_proto_masks` on all tested platforms.
217    ///
218    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
219    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
220    pub fn materialize_segmentations(
221        &self,
222        detect: &[crate::DetectBox],
223        proto_data: &crate::ProtoData,
224    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
225        use edgefirst_decoder::ProtoTensor;
226
227        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
228            return Ok(Vec::new());
229        }
230
231        // Extract proto tensor metadata for the fused kernel.
232        let (proto_h, proto_w, num_protos) = match &proto_data.protos {
233            ProtoTensor::Quantized { protos, .. } => {
234                (protos.shape()[0], protos.shape()[1], protos.shape()[2])
235            }
236            ProtoTensor::Float(arr) => (arr.shape()[0], arr.shape()[1], arr.shape()[2]),
237        };
238
239        detect
240            .iter()
241            .zip(proto_data.mask_coefficients.iter())
242            .map(|(det, coeff)| {
243                // Clamp bbox to [0, 1]
244                let xmin = det.bbox.xmin.clamp(0.0, 1.0);
245                let ymin = det.bbox.ymin.clamp(0.0, 1.0);
246                let xmax = det.bbox.xmax.clamp(0.0, 1.0);
247                let ymax = det.bbox.ymax.clamp(0.0, 1.0);
248
249                // Map to proto-space pixel coordinates (clamp to valid range)
250                let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
251                let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
252                let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
253                let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
254
255                let roi_w = x1.saturating_sub(x0).max(1);
256                let roi_h = y1.saturating_sub(y0).max(1);
257
258                if coeff.len() != num_protos {
259                    return Err(crate::Error::Internal(format!(
260                        "mask coeff length {} != proto channels {num_protos}",
261                        coeff.len()
262                    )));
263                }
264
265                // Fused dequant + dot product + sigmoid, directly producing u8 mask.
266                // Avoids allocating a full f32 proto tensor and the to_shape copy.
267                let mask = match &proto_data.protos {
268                    ProtoTensor::Quantized {
269                        protos,
270                        quantization,
271                    } => {
272                        let scale = quantization.scale;
273                        let zp = quantization.zero_point as f32;
274                        fused_dequant_dot_sigmoid_i8(
275                            protos, coeff, scale, zp, y0, x0, roi_h, roi_w, num_protos,
276                        )
277                    }
278                    ProtoTensor::Float(protos) => {
279                        fused_dot_sigmoid_f32(protos, coeff, y0, x0, roi_h, roi_w, num_protos)
280                    }
281                };
282
283                Ok(edgefirst_decoder::Segmentation {
284                    xmin: x0 as f32 / proto_w as f32,
285                    ymin: y0 as f32 / proto_h as f32,
286                    xmax: x1 as f32 / proto_w as f32,
287                    ymax: y1 as f32 / proto_h as f32,
288                    segmentation: mask,
289                })
290            })
291            .collect::<crate::Result<Vec<_>>>()
292    }
293}
294
295/// Bilinear interpolation of proto values at `(px, py)` combined with dot
296/// product against `coeff`. Returns the scalar accumulator before sigmoid.
297///
298/// Samples the four nearest proto texels, weights by bilinear coefficients,
299/// and simultaneously computes the dot product with the mask coefficients.
300#[inline]
301pub(super) fn bilinear_dot(
302    protos: &ndarray::Array3<f32>,
303    coeff: &[f32],
304    num_protos: usize,
305    px: f32,
306    py: f32,
307    proto_w: usize,
308    proto_h: usize,
309) -> f32 {
310    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
311    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
312    let x1 = (x0 + 1).min(proto_w - 1);
313    let y1 = (y0 + 1).min(proto_h - 1);
314
315    let fx = px - px.floor();
316    let fy = py - py.floor();
317
318    let w00 = (1.0 - fx) * (1.0 - fy);
319    let w10 = fx * (1.0 - fy);
320    let w01 = (1.0 - fx) * fy;
321    let w11 = fx * fy;
322
323    let mut acc = 0.0f32;
324    for p in 0..num_protos {
325        let val = w00 * protos[[y0, x0, p]]
326            + w10 * protos[[y0, x1, p]]
327            + w01 * protos[[y1, x0, p]]
328            + w11 * protos[[y1, x1, p]];
329        acc += coeff[p] * val;
330    }
331    acc
332}
333
/// Cheap approximation of the logistic function `1 / (1 + exp(-x))`.
///
/// `exp(-x)` is not evaluated with libm. Instead `-x * log2(e)` is scaled by
/// `2^23`, offset by the f32 exponent bias and reinterpreted as the bit
/// pattern of a float (Schraudolph's trick). That yields the exact power of
/// two for the integer part and a linear ramp for the fractional part, so
/// `exp` is overestimated by up to roughly 6% and the returned sigmoid is
/// low by at most about 0.015. The result is exact at `x == 0`.
///
/// Inputs at or beyond `±16` return exactly `1.0` / `0.0`; a NaN input
/// returns `0.5`.
#[inline(always)]
fn fast_sigmoid(x: f32) -> f32 {
    /// One unit of the f32 exponent field: `2^23`.
    const MANTISSA_SCALE: f32 = (1u32 << 23) as f32;
    /// Converts a natural exponent into exponent-field units.
    const SLOPE: f32 = MANTISSA_SCALE * std::f32::consts::LOG2_E;
    /// Bit pattern of `1.0f32` (biased exponent 127, zero mantissa).
    const ONE_BITS: i32 = 127 << 23;

    match x {
        saturated if saturated >= 16.0 => 1.0,
        saturated if saturated <= -16.0 => 0.0,
        v => {
            // Within (-16, 16) the sum stays a positive i32, so the
            // reinterpretation always produces a finite positive float.
            let raw = (SLOPE * -v) as i32 + ONE_BITS;
            let approx_exp = f32::from_bits(raw as u32);
            1.0 / (1.0 + approx_exp)
        }
    }
}
357
358/// Fused dequantization + dot product + sigmoid for quantized (i8) protos.
359///
360/// For each pixel in the ROI, computes:
361///   acc = sum_k(coeff[k] * (proto[y, x, k] as f32 - zp) * scale)
362///   mask[y, x] = fast_sigmoid(acc) * 255
363///
364/// This avoids allocating a full f32 proto tensor (3.1MB for 160x160x32)
365/// and the hidden `to_shape` copy on non-contiguous ROI slices.
366#[allow(clippy::too_many_arguments)]
367fn fused_dequant_dot_sigmoid_i8(
368    protos: &ndarray::Array3<i8>,
369    coeff: &[f32],
370    scale: f32,
371    zp: f32,
372    y0: usize,
373    x0: usize,
374    roi_h: usize,
375    roi_w: usize,
376    num_protos: usize,
377) -> ndarray::Array3<u8> {
378    debug_assert!(
379        protos.strides().iter().all(|&s| s >= 0),
380        "negative strides unsupported"
381    );
382    // Pre-scale coefficients: coeff[k] * scale, so the inner loop is
383    // just fma: acc += scaled_coeff[k] * (proto_i8 - zp)
384    let scaled_coeff: Vec<f32> = coeff.iter().map(|&c| c * scale).collect();
385    // Pre-compute coeff_sum * (-zp * scale) offset:
386    // sum_k(coeff[k] * (proto - zp) * scale) = sum_k(scaled_coeff[k] * proto) - zp * sum_k(scaled_coeff[k])
387    // But since zp is a constant per-pixel term, factor it out:
388    // acc = sum_k(scaled_coeff[k] * proto_i8_as_f32) - zp * sum_k(scaled_coeff[k])
389    let zp_offset: f32 = zp * scaled_coeff.iter().sum::<f32>();
390
391    let proto_stride_y = protos.strides()[0] as usize;
392    let proto_stride_x = protos.strides()[1] as usize;
393    let proto_stride_k = protos.strides()[2] as usize;
394    let proto_ptr = protos.as_ptr();
395
396    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
397
398    for y in 0..roi_h {
399        for x in 0..roi_w {
400            // Base pointer for protos[y0+y, x0+x, 0]
401            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;
402
403            let mut acc = 0.0f32;
404            let mut k = 0;
405
406            // Process 4 protos at a time for better ILP
407            let chunks = num_protos / 4;
408            for _ in 0..chunks {
409                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
410                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
411                unsafe {
412                    let p0 = *proto_ptr.add(base + k * proto_stride_k) as f32;
413                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k) as f32;
414                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k) as f32;
415                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k) as f32;
416                    acc += scaled_coeff[k] * p0
417                        + scaled_coeff[k + 1] * p1
418                        + scaled_coeff[k + 2] * p2
419                        + scaled_coeff[k + 3] * p3;
420                }
421                k += 4;
422            }
423            // Remainder
424            while k < num_protos {
425                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
426                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
427                unsafe {
428                    let p = *proto_ptr.add(base + k * proto_stride_k) as f32;
429                    acc += scaled_coeff[k] * p;
430                }
431                k += 1;
432            }
433
434            acc -= zp_offset;
435            let sigmoid = fast_sigmoid(acc);
436            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
437        }
438    }
439    mask
440}
441
442/// Fused dot product + sigmoid for f32 protos (no dequantization needed).
443fn fused_dot_sigmoid_f32(
444    protos: &ndarray::Array3<f32>,
445    coeff: &[f32],
446    y0: usize,
447    x0: usize,
448    roi_h: usize,
449    roi_w: usize,
450    num_protos: usize,
451) -> ndarray::Array3<u8> {
452    debug_assert!(
453        protos.strides().iter().all(|&s| s >= 0),
454        "negative strides unsupported"
455    );
456    let proto_stride_y = protos.strides()[0] as usize;
457    let proto_stride_x = protos.strides()[1] as usize;
458    let proto_stride_k = protos.strides()[2] as usize;
459    let proto_ptr = protos.as_ptr();
460
461    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
462
463    for y in 0..roi_h {
464        for x in 0..roi_w {
465            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;
466
467            let mut acc = 0.0f32;
468            let mut k = 0;
469            let chunks = num_protos / 4;
470            for _ in 0..chunks {
471                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
472                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
473                unsafe {
474                    let p0 = *proto_ptr.add(base + k * proto_stride_k);
475                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k);
476                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k);
477                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k);
478                    acc +=
479                        coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
480                }
481                k += 4;
482            }
483            while k < num_protos {
484                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
485                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
486                unsafe {
487                    let p = *proto_ptr.add(base + k * proto_stride_k);
488                    acc += coeff[k] * p;
489                }
490                k += 1;
491            }
492
493            let sigmoid = fast_sigmoid(acc);
494            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
495        }
496    }
497    mask
498}