Skip to main content

edgefirst_image/cpu/
masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8
9impl CPUProcessor {
10    pub(super) fn render_modelpack_segmentation(
11        &mut self,
12        dst_w: usize,
13        dst_h: usize,
14        dst_rs: usize,
15        dst_c: usize,
16        dst_slice: &mut [u8],
17        segmentation: &Segmentation,
18    ) -> Result<()> {
19        use ndarray_stats::QuantileExt;
20
21        let seg = &segmentation.segmentation;
22        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
23            unreachable!("Array3 did not have [usize; 3] as shape");
24        };
25        let start_y = (dst_h as f32 * segmentation.ymin).round();
26        let end_y = (dst_h as f32 * segmentation.ymax).round();
27        let start_x = (dst_w as f32 * segmentation.xmin).round();
28        let end_x = (dst_w as f32 * segmentation.xmax).round();
29
30        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
31        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
32
33        let start_x_u = (start_x as usize).min(dst_w);
34        let start_y_u = (start_y as usize).min(dst_h);
35        let end_x_u = (end_x as usize).min(dst_w);
36        let end_y_u = (end_y as usize).min(dst_h);
37
38        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
39        let get_value_at_nearest = |x: f32, y: f32| -> usize {
40            let x = x.round() as usize;
41            let y = y.round() as usize;
42            argmax
43                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
44                .copied()
45                .unwrap_or(0)
46        };
47
48        for y in start_y_u..end_y_u {
49            for x in start_x_u..end_x_u {
50                let seg_x = (x as f32 - start_x) * scale_x;
51                let seg_y = (y as f32 - start_y) * scale_y;
52                let label = get_value_at_nearest(seg_x, seg_y);
53
54                if label == seg_classes - 1 {
55                    continue;
56                }
57
58                let color = self.colors[label % self.colors.len()];
59
60                let alpha = color[3] as u16;
61
62                let dst_index = (y * dst_rs) + (x * dst_c);
63                for c in 0..3 {
64                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
65                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
66                        / 255) as u8;
67                }
68            }
69        }
70
71        Ok(())
72    }
73
74    #[allow(clippy::too_many_arguments)]
75    pub(super) fn render_yolo_segmentation(
76        &mut self,
77        dst_w: usize,
78        dst_h: usize,
79        dst_rs: usize,
80        dst_c: usize,
81        dst_slice: &mut [u8],
82        segmentation: &Segmentation,
83        class: usize,
84    ) -> Result<()> {
85        let seg = &segmentation.segmentation;
86        let [seg_height, seg_width, classes] = *seg.shape() else {
87            unreachable!("Array3 did not have [usize;3] as shape");
88        };
89        debug_assert_eq!(classes, 1);
90
91        let start_y = (dst_h as f32 * segmentation.ymin).round();
92        let end_y = (dst_h as f32 * segmentation.ymax).round();
93        let start_x = (dst_w as f32 * segmentation.xmin).round();
94        let end_x = (dst_w as f32 * segmentation.xmax).round();
95
96        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
97        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
98
99        let start_x_u = (start_x as usize).min(dst_w);
100        let start_y_u = (start_y as usize).min(dst_h);
101        let end_x_u = (end_x as usize).min(dst_w);
102        let end_y_u = (end_y as usize).min(dst_h);
103
104        for y in start_y_u..end_y_u {
105            for x in start_x_u..end_x_u {
106                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
107                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
108                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
109
110                if val < 127 {
111                    continue;
112                }
113
114                let color = self.colors[class % self.colors.len()];
115
116                let alpha = color[3] as u16;
117
118                let dst_index = (y * dst_rs) + (x * dst_c);
119                for c in 0..3 {
120                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
121                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
122                        / 255) as u8;
123                }
124            }
125        }
126
127        Ok(())
128    }
129
130    pub(super) fn render_box(
131        &mut self,
132        dst_w: usize,
133        dst_h: usize,
134        dst_rs: usize,
135        dst_c: usize,
136        dst_slice: &mut [u8],
137        detect: &[DetectBox],
138    ) -> Result<()> {
139        const LINE_THICKNESS: usize = 3;
140
141        for d in detect {
142            use edgefirst_decoder::BoundingBox;
143
144            let label = d.label;
145            let [r, g, b, _] = self.colors[label % self.colors.len()];
146            let bbox = d.bbox.to_canonical();
147            let bbox = BoundingBox {
148                xmin: bbox.xmin.clamp(0.0, 1.0),
149                ymin: bbox.ymin.clamp(0.0, 1.0),
150                xmax: bbox.xmax.clamp(0.0, 1.0),
151                ymax: bbox.ymax.clamp(0.0, 1.0),
152            };
153            let inner = [
154                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
155                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
156                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
157                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
158            ];
159
160            let outer = [
161                inner[0].saturating_sub(LINE_THICKNESS),
162                inner[1].saturating_sub(LINE_THICKNESS),
163                (inner[2] + LINE_THICKNESS).min(dst_w),
164                (inner[3] + LINE_THICKNESS).min(dst_h),
165            ];
166
167            // top line
168            for y in outer[1] + 1..=inner[1] {
169                for x in outer[0] + 1..outer[2] {
170                    let index = (y * dst_rs) + (x * dst_c);
171                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
172                }
173            }
174
175            // left and right lines
176            for y in inner[1]..inner[3] {
177                for x in outer[0] + 1..=inner[0] {
178                    let index = (y * dst_rs) + (x * dst_c);
179                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
180                }
181
182                for x in inner[2]..outer[2] {
183                    let index = (y * dst_rs) + (x * dst_c);
184                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
185                }
186            }
187
188            // bottom line
189            for y in inner[3]..outer[3] {
190                for x in outer[0] + 1..outer[2] {
191                    let index = (y * dst_rs) + (x * dst_c);
192                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
193                }
194            }
195        }
196        Ok(())
197    }
198
199    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
200    ///
201    /// This is the CPU-side decode step of the hybrid mask rendering path:
202    /// call this to get pre-decoded masks, then pass them to
203    /// [`draw_masks`](crate::ImageProcessorTrait::draw_masks) for GPU overlay.
204    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
205    /// than the fused GPU `draw_masks_proto` on all tested platforms.
206    pub fn materialize_segmentations(
207        &self,
208        detect: &[crate::DetectBox],
209        proto_data: &crate::ProtoData,
210    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
211        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
212            return Ok(Vec::new());
213        }
214
215        let protos_cow = proto_data.protos.as_f32();
216        let protos = protos_cow.as_ref();
217        let proto_h = protos.shape()[0];
218        let proto_w = protos.shape()[1];
219        let num_protos = protos.shape()[2];
220
221        detect
222            .iter()
223            .zip(proto_data.mask_coefficients.iter())
224            .map(|(det, coeff)| {
225                // Clamp bbox to [0, 1]
226                let xmin = det.bbox.xmin.clamp(0.0, 1.0);
227                let ymin = det.bbox.ymin.clamp(0.0, 1.0);
228                let xmax = det.bbox.xmax.clamp(0.0, 1.0);
229                let ymax = det.bbox.ymax.clamp(0.0, 1.0);
230
231                // Map to proto-space pixel coordinates (clamp to valid range)
232                let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
233                let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
234                let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
235                let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
236
237                let roi_w = x1.saturating_sub(x0).max(1);
238                let roi_h = y1.saturating_sub(y0).max(1);
239
240                // Extract proto ROI and compute mask_coeff @ protos
241                let roi = protos.slice(ndarray::s![y0..y0 + roi_h, x0..x0 + roi_w, ..]);
242                let coeff_arr = ndarray::Array2::from_shape_vec((1, num_protos), coeff.clone())
243                    .map_err(|e| crate::Error::Internal(format!("mask coeff shape: {e}")))?;
244                let protos_2d = roi
245                    .to_shape((roi_h * roi_w, num_protos))
246                    .map_err(|e| crate::Error::Internal(format!("proto reshape: {e}")))?
247                    .reversed_axes();
248                let mask = coeff_arr.dot(&protos_2d);
249                let mask = mask
250                    .into_shape_with_order((roi_h, roi_w, 1))
251                    .map_err(|e| crate::Error::Internal(format!("mask reshape: {e}")))?
252                    .mapv(|x: f32| {
253                        let sigmoid = 1.0 / (1.0 + (-x).exp());
254                        (sigmoid * 255.0).round() as u8
255                    });
256
257                Ok(edgefirst_decoder::Segmentation {
258                    xmin: x0 as f32 / proto_w as f32,
259                    ymin: y0 as f32 / proto_h as f32,
260                    xmax: x1 as f32 / proto_w as f32,
261                    ymax: y1 as f32 / proto_h as f32,
262                    segmentation: mask,
263                })
264            })
265            .collect::<crate::Result<Vec<_>>>()
266    }
267
268    /// Renders per-instance grayscale masks from raw prototype data at full
269    /// output resolution. Used internally by [`decode_masks_atlas`] to generate
270    /// per-detection mask crops that are then packed into the atlas.
271    pub(super) fn render_masks_from_protos(
272        &mut self,
273        detect: &[crate::DetectBox],
274        proto_data: crate::ProtoData,
275        output_width: usize,
276        output_height: usize,
277    ) -> Result<Vec<crate::MaskResult>> {
278        use crate::FunctionTimer;
279
280        let _timer = FunctionTimer::new("CPUProcessor::render_masks_from_protos");
281
282        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
283            return Ok(Vec::new());
284        }
285
286        let protos_cow = proto_data.protos.as_f32();
287        let protos = protos_cow.as_ref();
288        let proto_h = protos.shape()[0];
289        let proto_w = protos.shape()[1];
290        let num_protos = protos.shape()[2];
291
292        let mut results = Vec::with_capacity(detect.len());
293
294        for (det, coeff) in detect.iter().zip(proto_data.mask_coefficients.iter()) {
295            let start_x = (output_width as f32 * det.bbox.xmin).round() as usize;
296            let start_y = (output_height as f32 * det.bbox.ymin).round() as usize;
297            // Use span-based rounding to match the numpy reference convention.
298            let bbox_w = ((det.bbox.xmax - det.bbox.xmin) * output_width as f32)
299                .round()
300                .max(1.0) as usize;
301            let bbox_h = ((det.bbox.ymax - det.bbox.ymin) * output_height as f32)
302                .round()
303                .max(1.0) as usize;
304            let bbox_w = bbox_w.min(output_width.saturating_sub(start_x));
305            let bbox_h = bbox_h.min(output_height.saturating_sub(start_y));
306
307            let mut pixels = vec![0u8; bbox_w * bbox_h];
308
309            for row in 0..bbox_h {
310                let y = start_y + row;
311                for col in 0..bbox_w {
312                    let x = start_x + col;
313                    let px = (x as f32 / output_width as f32) * proto_w as f32 - 0.5;
314                    let py = (y as f32 / output_height as f32) * proto_h as f32 - 0.5;
315                    let acc = bilinear_dot(protos, coeff, num_protos, px, py, proto_w, proto_h);
316                    let mask = 1.0 / (1.0 + (-acc).exp());
317                    pixels[row * bbox_w + col] = if mask > 0.5 { 255 } else { 0 };
318                }
319            }
320
321            results.push(crate::MaskResult {
322                x: start_x,
323                y: start_y,
324                w: bbox_w,
325                h: bbox_h,
326                pixels,
327            });
328        }
329
330        Ok(results)
331    }
332}
333
334/// Bilinear interpolation of proto values at `(px, py)` combined with dot
335/// product against `coeff`. Returns the scalar accumulator before sigmoid.
336///
337/// Samples the four nearest proto texels, weights by bilinear coefficients,
338/// and simultaneously computes the dot product with the mask coefficients.
339#[inline]
340pub(super) fn bilinear_dot(
341    protos: &ndarray::Array3<f32>,
342    coeff: &[f32],
343    num_protos: usize,
344    px: f32,
345    py: f32,
346    proto_w: usize,
347    proto_h: usize,
348) -> f32 {
349    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
350    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
351    let x1 = (x0 + 1).min(proto_w - 1);
352    let y1 = (y0 + 1).min(proto_h - 1);
353
354    let fx = px - px.floor();
355    let fy = py - py.floor();
356
357    let w00 = (1.0 - fx) * (1.0 - fy);
358    let w10 = fx * (1.0 - fy);
359    let w01 = (1.0 - fx) * fy;
360    let w11 = fx * fy;
361
362    let mut acc = 0.0f32;
363    for p in 0..num_protos {
364        let val = w00 * protos[[y0, x0, p]]
365            + w10 * protos[[y0, x1, p]]
366            + w01 * protos[[y1, x0, p]]
367            + w11 * protos[[y1, x1, p]];
368        acc += coeff[p] * val;
369    }
370    acc
371}