// edgefirst_image/cpu/masks.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8
9impl CPUProcessor {
10    #[allow(clippy::too_many_arguments)]
11    pub(super) fn render_modelpack_segmentation(
12        &mut self,
13        dst_w: usize,
14        dst_h: usize,
15        dst_rs: usize,
16        dst_c: usize,
17        dst_slice: &mut [u8],
18        segmentation: &Segmentation,
19        opacity: f32,
20    ) -> Result<()> {
21        use ndarray_stats::QuantileExt;
22
23        let seg = &segmentation.segmentation;
24        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
25            unreachable!("Array3 did not have [usize; 3] as shape");
26        };
27        let start_y = (dst_h as f32 * segmentation.ymin).round();
28        let end_y = (dst_h as f32 * segmentation.ymax).round();
29        let start_x = (dst_w as f32 * segmentation.xmin).round();
30        let end_x = (dst_w as f32 * segmentation.xmax).round();
31
32        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
33        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
34
35        let start_x_u = (start_x as usize).min(dst_w);
36        let start_y_u = (start_y as usize).min(dst_h);
37        let end_x_u = (end_x as usize).min(dst_w);
38        let end_y_u = (end_y as usize).min(dst_h);
39
40        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
41        let get_value_at_nearest = |x: f32, y: f32| -> usize {
42            let x = x.round() as usize;
43            let y = y.round() as usize;
44            argmax
45                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
46                .copied()
47                .unwrap_or(0)
48        };
49
50        for y in start_y_u..end_y_u {
51            for x in start_x_u..end_x_u {
52                let seg_x = (x as f32 - start_x) * scale_x;
53                let seg_y = (y as f32 - start_y) * scale_y;
54                let label = get_value_at_nearest(seg_x, seg_y);
55
56                if label == seg_classes - 1 {
57                    continue;
58                }
59
60                let color = self.colors[label % self.colors.len()];
61
62                let alpha = if opacity == 1.0 {
63                    color[3] as u16
64                } else {
65                    (color[3] as f32 * opacity).round() as u16
66                };
67
68                let dst_index = (y * dst_rs) + (x * dst_c);
69                for c in 0..3 {
70                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
71                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
72                        / 255) as u8;
73                }
74            }
75        }
76
77        Ok(())
78    }
79
80    #[allow(clippy::too_many_arguments)]
81    pub(super) fn render_yolo_segmentation(
82        &mut self,
83        dst_w: usize,
84        dst_h: usize,
85        dst_rs: usize,
86        dst_c: usize,
87        dst_slice: &mut [u8],
88        segmentation: &Segmentation,
89        class: usize,
90        opacity: f32,
91    ) -> Result<()> {
92        let seg = &segmentation.segmentation;
93        let [seg_height, seg_width, classes] = *seg.shape() else {
94            unreachable!("Array3 did not have [usize;3] as shape");
95        };
96        debug_assert_eq!(classes, 1);
97
98        let start_y = (dst_h as f32 * segmentation.ymin).round();
99        let end_y = (dst_h as f32 * segmentation.ymax).round();
100        let start_x = (dst_w as f32 * segmentation.xmin).round();
101        let end_x = (dst_w as f32 * segmentation.xmax).round();
102
103        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
104        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
105
106        let start_x_u = (start_x as usize).min(dst_w);
107        let start_y_u = (start_y as usize).min(dst_h);
108        let end_x_u = (end_x as usize).min(dst_w);
109        let end_y_u = (end_y as usize).min(dst_h);
110
111        for y in start_y_u..end_y_u {
112            for x in start_x_u..end_x_u {
113                let seg_x = ((x as f32 - start_x) * scale_x) as usize;
114                let seg_y = ((y as f32 - start_y) * scale_y) as usize;
115                let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
116
117                if val < 127 {
118                    continue;
119                }
120
121                let color = self.colors[class % self.colors.len()];
122
123                let alpha = if opacity == 1.0 {
124                    color[3] as u16
125                } else {
126                    (color[3] as f32 * opacity).round() as u16
127                };
128
129                let dst_index = (y * dst_rs) + (x * dst_c);
130                for c in 0..3 {
131                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
132                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
133                        / 255) as u8;
134                }
135            }
136        }
137
138        Ok(())
139    }
140
    /// Draws a `LINE_THICKNESS`-pixel rectangle outline for every detection
    /// onto the destination image buffer.
    ///
    /// Each box is canonicalised (min/max swapped if inverted) and clamped to
    /// the normalized `[0, 1]` range before conversion to pixel coordinates,
    /// so partially out-of-frame detections are clipped instead of indexing
    /// out of bounds. The stroke is painted between an `inner` rectangle (the
    /// box edge) and an `outer` rectangle (grown by the thickness, clamped to
    /// the image).
    ///
    /// * `dst_w`/`dst_h` - destination dimensions in pixels.
    /// * `dst_rs` - destination row stride in bytes.
    /// * `dst_c` - bytes per destination pixel (first 3 overwritten as RGB).
    /// * `detect` - decoded detection boxes.
    /// * `color_mode` - selects the palette entry from the detection index or
    ///   its label.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_box(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        detect: &[DetectBox],
        color_mode: crate::ColorMode,
    ) -> Result<()> {
        const LINE_THICKNESS: usize = 3;

        for (idx, d) in detect.iter().enumerate() {
            use edgefirst_decoder::BoundingBox;

            let color_index = color_mode.index(idx, d.label);
            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
            let bbox = d.bbox.to_canonical();
            let bbox = BoundingBox {
                xmin: bbox.xmin.clamp(0.0, 1.0),
                ymin: bbox.ymin.clamp(0.0, 1.0),
                xmax: bbox.xmax.clamp(0.0, 1.0),
                ymax: bbox.ymax.clamp(0.0, 1.0),
            };
            // Box edge in pixels as [xmin, ymin, xmax, ymax]; the -0.5/+0.5
            // nudges both edges outward so the stroke hugs the box.
            let inner = [
                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
            ];

            // Inner rectangle grown by the stroke width, clamped to the image
            // (min edges saturate at 0, max edges capped at the image size).
            let outer = [
                inner[0].saturating_sub(LINE_THICKNESS),
                inner[1].saturating_sub(LINE_THICKNESS),
                (inner[2] + LINE_THICKNESS).min(dst_w),
                (inner[3] + LINE_THICKNESS).min(dst_h),
            ];

            // top line
            for y in outer[1] + 1..=inner[1] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // left and right lines
            for y in inner[1]..inner[3] {
                for x in outer[0] + 1..=inner[0] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }

                for x in inner[2]..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // bottom line
            for y in inner[3]..outer[3] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }
        }
        Ok(())
    }
211
    /// Materialize segmentation masks from proto data into `Vec<Segmentation>`.
    ///
    /// This is the CPU-side decode step of the hybrid mask rendering path:
    /// call this to get pre-decoded masks, then pass them to
    /// [`draw_decoded_masks`](crate::ImageProcessorTrait::draw_decoded_masks) for GPU overlay.
    /// Benchmarks show this hybrid path (CPU decode + GL overlay) is faster
    /// than the fused GPU `draw_proto_masks` on all tested platforms.
    ///
    /// Masks come back at proto resolution, cropped to each detection's
    /// bounding box; mask values are `u8` sigmoid confidences in `[0, 255]`.
    ///
    /// Optimized: fused dequantization + dot product avoids a 3.1MB f32
    /// allocation for the full proto tensor. Uses fast sigmoid approximation.
    ///
    /// # Errors
    ///
    /// Returns an internal error when a detection's mask-coefficient vector
    /// length does not match the proto tensor's channel count.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_decoder::ProtoTensor;

        // Nothing to decode without detections or coefficients.
        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
            return Ok(Vec::new());
        }

        // Extract proto tensor metadata for the fused kernel.
        let (proto_h, proto_w, num_protos) = match &proto_data.protos {
            ProtoTensor::Quantized { protos, .. } => {
                (protos.shape()[0], protos.shape()[1], protos.shape()[2])
            }
            ProtoTensor::Float(arr) => (arr.shape()[0], arr.shape()[1], arr.shape()[2]),
        };

        // Precompute inverse letterbox scale for output-coord conversion.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    // Guard degenerate letterboxes against divide-by-zero.
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        detect
            .iter()
            .zip(proto_data.mask_coefficients.iter())
            .map(|(det, coeff)| {
                // Canonicalise bbox (swap min/max if inverted) then clamp to [0,1].
                // Without canonicalisation a degenerate box (ymax < ymin) causes a
                // near-zero ROI height after saturating_sub, making the mask invisible.
                let bbox = det.bbox.to_canonical();
                let xmin = bbox.xmin.clamp(0.0, 1.0);
                let ymin = bbox.ymin.clamp(0.0, 1.0);
                let xmax = bbox.xmax.clamp(0.0, 1.0);
                let ymax = bbox.ymax.clamp(0.0, 1.0);

                // Map to proto-space pixel coordinates (clamp to valid range)
                let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
                let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
                let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
                let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);

                // Force a non-empty ROI so every detection yields a mask.
                let roi_w = x1.saturating_sub(x0).max(1);
                let roi_h = y1.saturating_sub(y0).max(1);

                if coeff.len() != num_protos {
                    return Err(crate::Error::Internal(format!(
                        "mask coeff length {} != proto channels {num_protos}",
                        coeff.len()
                    )));
                }

                // Fused dequant + dot product + sigmoid, directly producing u8 mask.
                // Avoids allocating a full f32 proto tensor and the to_shape copy.
                let mask = match &proto_data.protos {
                    ProtoTensor::Quantized {
                        protos,
                        quantization,
                    } => {
                        let scale = quantization.scale;
                        let zp = quantization.zero_point as f32;
                        fused_dequant_dot_sigmoid_i8(
                            protos, coeff, scale, zp, y0, x0, roi_h, roi_w, num_protos,
                        )
                    }
                    ProtoTensor::Float(protos) => {
                        fused_dot_sigmoid_f32(protos, coeff, y0, x0, roi_h, roi_w, num_protos)
                    }
                };

                // Convert proto-space normalised coords to output-image-normalised
                // coords by applying the inverse letterbox transform.  When no
                // letterbox was used (lx0=0, inv_lw=1, ly0=0, inv_lh=1) this is a
                // no-op and the coords are identical to the proto-normalised values.
                let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
                let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
                let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
                let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;

                Ok(edgefirst_decoder::Segmentation {
                    xmin: seg_xmin.clamp(0.0, 1.0),
                    ymin: seg_ymin.clamp(0.0, 1.0),
                    xmax: seg_xmax.clamp(0.0, 1.0),
                    ymax: seg_ymax.clamp(0.0, 1.0),
                    segmentation: mask,
                })
            })
            .collect::<crate::Result<Vec<_>>>()
    }
323
324    /// Produce per-detection masks at `(width, height)` pixel resolution by
325    /// upsampling the full proto plane once then cropping per bbox. Each
326    /// `det.bbox` is assumed to be in model-input normalized coordinates
327    /// (the convention used by the decoder output); when `letterbox` is
328    /// `Some`, `(width, height)` are original-content pixel dims and the
329    /// inverse letterbox transform is applied to both the bbox (for the
330    /// crop region and returned `Segmentation` metadata) and each output
331    /// pixel (for proto-plane sampling). Mask values are binary
332    /// `uint8 {0, 255}` after thresholding sigmoid > 0.5.
333    ///
334    /// Used by [`ImageProcessor::materialize_masks`] when the caller selects
335    /// [`MaskResolution::Scaled`](crate::MaskResolution::Scaled).
336    pub fn materialize_scaled_segmentations(
337        &self,
338        detect: &[crate::DetectBox],
339        proto_data: &crate::ProtoData,
340        letterbox: Option<[f32; 4]>,
341        width: u32,
342        height: u32,
343    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
344        use edgefirst_decoder::ProtoTensor;
345
346        if detect.is_empty() || proto_data.mask_coefficients.is_empty() {
347            return Ok(Vec::new());
348        }
349        if width == 0 || height == 0 {
350            return Err(crate::Error::InvalidShape(
351                "Scaled mask width/height must be positive".into(),
352            ));
353        }
354
355        match &proto_data.protos {
356            ProtoTensor::Float(protos) => scaled_segmentations_float(
357                detect,
358                &proto_data.mask_coefficients,
359                protos,
360                letterbox,
361                width,
362                height,
363            ),
364            ProtoTensor::Quantized {
365                protos,
366                quantization,
367            } => scaled_segmentations_quant_i8(
368                detect,
369                &proto_data.mask_coefficients,
370                protos,
371                *quantization,
372                letterbox,
373                width,
374                height,
375            ),
376        }
377    }
378}
379
/// Float-proto implementation of
/// `CPUProcessor::materialize_scaled_segmentations`: renders each detection's
/// mask at `(width, height)` pixel resolution by bilinearly sampling the f32
/// proto plane for every output pixel inside the bbox and thresholding
/// `sigmoid > 0.5` into binary `u8 {0, 255}` values.
fn scaled_segmentations_float(
    detect: &[crate::DetectBox],
    mask_coefficients: &[Vec<f32>],
    protos: &ndarray::Array3<f32>,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    let (proto_h, proto_w, num_protos) = (protos.shape()[0], protos.shape()[1], protos.shape()[2]);

    // letterbox = [lx0, ly0, lx1, ly1] in model-input normalized coords.
    // Two uses for the same (lx0, lw) pair:
    //   * Forward (sampling): each output pixel maps to model-input normalized
    //     via model_norm = lx0 + out_norm * lw, then to proto pixels.
    //   * Inverse (bbox frame): det.bbox arrives in model-input normalized
    //     (decoder output); the crop region and returned Segmentation are in
    //     output-content normalized, obtained via (bbox - lx0) / lw.
    // When letterbox is None, (lx0=0, lw=1) makes both directions the identity.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            // EPSILON floor guards against a degenerate zero-size letterbox.
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };

    let out_w = width as usize;
    let out_h = height as usize;

    detect
        .iter()
        .zip(mask_coefficients.iter())
        .map(|(det, coeff)| {
            if coeff.len() != num_protos {
                return Err(crate::Error::Internal(format!(
                    "mask coeff length {} != proto channels {num_protos}",
                    coeff.len()
                )));
            }

            // Canonicalise, then inverse-letterbox into output-content
            // normalized space for bbox crop + Segmentation metadata.
            // Matches the end-of-pipeline transform in the Proto path
            // (see `materialize_segmentations`).
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);

            // Output-space pixel crop of the bbox; forced non-empty so the
            // returned tile always has at least one pixel.
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            let mut tile = ndarray::Array3::<u8>::zeros((bbox_h, bbox_w, 1));

            // Map each output pixel within bbox back to proto-plane
            // coordinates. Center-of-pixel offset (-0.5) matches torch
            // align_corners=False, the Ultralytics retina convention.
            for yi in 0..bbox_h {
                let py = (py0 + yi) as f32;
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                let sample_y = model_y_norm * proto_h as f32 - 0.5;
                for xi in 0..bbox_w {
                    let px = (px0 + xi) as f32;
                    let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                    let sample_x = model_x_norm * proto_w as f32 - 0.5;
                    // Fused bilinear sample + coefficient dot product.
                    let acc = bilinear_dot(
                        protos, coeff, num_protos, sample_x, sample_y, proto_w, proto_h,
                    );
                    let sigmoid = fast_sigmoid(acc);
                    tile[[yi, xi, 0]] = if sigmoid > 0.5 { 255 } else { 0 };
                }
            }

            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect::<crate::Result<Vec<_>>>()
}
469
/// Quantized i8 proto variant of [`scaled_segmentations_float`]. Dequantizes
/// inline during bilinear sample + dot product:
///   dequant = (i8 - zero_point) * scale
/// Factoring scale out of the dot product:
///   sigmoid(scale * Σ coef_k * (i8 - zp))
/// so we can run a single scale multiply on the accumulator.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_quant_i8(
    detect: &[crate::DetectBox],
    mask_coefficients: &[Vec<f32>],
    protos: &ndarray::Array3<i8>,
    quant: edgefirst_decoder::Quantization,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    let (proto_h, proto_w, num_protos) = (protos.shape()[0], protos.shape()[1], protos.shape()[2]);

    // See scaled_segmentations_float for the forward / inverse letterbox
    // convention — det.bbox arrives in model-input normalized; crop region
    // and Segmentation metadata are output-content normalized via
    // (bbox - lx0) / lw; the sampling transform goes the other direction.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            // EPSILON floor guards against a degenerate zero-size letterbox.
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };

    let out_w = width as usize;
    let out_h = height as usize;
    let scale = quant.scale;
    let zp = quant.zero_point as f32;

    detect
        .iter()
        .zip(mask_coefficients.iter())
        .map(|(det, coeff)| {
            if coeff.len() != num_protos {
                return Err(crate::Error::Internal(format!(
                    "mask coeff length {} != proto channels {num_protos}",
                    coeff.len()
                )));
            }

            // Canonicalise the bbox, then inverse-letterbox it into
            // output-content normalized coordinates.
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);

            // Output-space pixel crop of the bbox; forced non-empty.
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            let mut tile = ndarray::Array3::<u8>::zeros((bbox_h, bbox_w, 1));

            // Center-of-pixel (-0.5) sampling, as in the float variant.
            for yi in 0..bbox_h {
                let py = (py0 + yi) as f32;
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                let sample_y = model_y_norm * proto_h as f32 - 0.5;
                for xi in 0..bbox_w {
                    let px = (px0 + xi) as f32;
                    let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                    let sample_x = model_x_norm * proto_w as f32 - 0.5;
                    let acc = bilinear_dot_quant_i8(
                        protos, coeff, num_protos, sample_x, sample_y, proto_w, proto_h, zp,
                    );
                    // Single multiply applies the factored-out quantization
                    // scale before the sigmoid threshold.
                    let sigmoid = fast_sigmoid(scale * acc);
                    tile[[yi, xi, 0]] = if sigmoid > 0.5 { 255 } else { 0 };
                }
            }

            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect::<crate::Result<Vec<_>>>()
}
558
559/// Bilinear sample + zero-point-subtracting dot product over an i8 proto
560/// tensor. Returns the scaled-but-not-yet-sigmoid accumulator; caller applies
561/// the quantization scale and sigmoid.
562#[inline]
563#[allow(clippy::too_many_arguments)]
564fn bilinear_dot_quant_i8(
565    protos: &ndarray::Array3<i8>,
566    coeff: &[f32],
567    num_protos: usize,
568    px: f32,
569    py: f32,
570    proto_w: usize,
571    proto_h: usize,
572    zp: f32,
573) -> f32 {
574    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
575    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
576    let x1 = (x0 + 1).min(proto_w - 1);
577    let y1 = (y0 + 1).min(proto_h - 1);
578
579    let fx = px - px.floor();
580    let fy = py - py.floor();
581    let w00 = (1.0 - fx) * (1.0 - fy);
582    let w10 = fx * (1.0 - fy);
583    let w01 = (1.0 - fx) * fy;
584    let w11 = fx * fy;
585
586    let mut acc = 0.0f32;
587    for p in 0..num_protos {
588        let v00 = protos[[y0, x0, p]] as f32 - zp;
589        let v10 = protos[[y0, x1, p]] as f32 - zp;
590        let v01 = protos[[y1, x0, p]] as f32 - zp;
591        let v11 = protos[[y1, x1, p]] as f32 - zp;
592        let val = w00 * v00 + w10 * v10 + w01 * v01 + w11 * v11;
593        acc += coeff[p] * val;
594    }
595    acc
596}
597
598/// Bilinear interpolation of proto values at `(px, py)` combined with dot
599/// product against `coeff`. Returns the scalar accumulator before sigmoid.
600///
601/// Samples the four nearest proto texels, weights by bilinear coefficients,
602/// and simultaneously computes the dot product with the mask coefficients.
603#[inline]
604pub(super) fn bilinear_dot(
605    protos: &ndarray::Array3<f32>,
606    coeff: &[f32],
607    num_protos: usize,
608    px: f32,
609    py: f32,
610    proto_w: usize,
611    proto_h: usize,
612) -> f32 {
613    let x0 = (px.floor() as isize).clamp(0, proto_w as isize - 1) as usize;
614    let y0 = (py.floor() as isize).clamp(0, proto_h as isize - 1) as usize;
615    let x1 = (x0 + 1).min(proto_w - 1);
616    let y1 = (y0 + 1).min(proto_h - 1);
617
618    let fx = px - px.floor();
619    let fy = py - py.floor();
620
621    let w00 = (1.0 - fx) * (1.0 - fy);
622    let w10 = fx * (1.0 - fy);
623    let w01 = (1.0 - fx) * fy;
624    let w11 = fx * fy;
625
626    let mut acc = 0.0f32;
627    for p in 0..num_protos {
628        let val = w00 * protos[[y0, x0, p]]
629            + w10 * protos[[y0, x1, p]]
630            + w01 * protos[[y1, x0, p]]
631            + w11 * protos[[y1, x1, p]];
632        acc += coeff[p] * val;
633    }
634    acc
635}
636
/// Fast sigmoid approximation: `1 / (1 + exp(-x))`.
///
/// The exponential is built by writing an approximate f32 bit pattern
/// directly (Schraudolph's trick, as in `fast_math::exp_raw`). Max relative
/// error < 1.1% for normal results — plenty for mask thresholding and u8
/// quantization.
#[inline(always)]
fn fast_sigmoid(x: f32) -> f32 {
    // Saturate outside [-16, 16], where the bit-trick would over/underflow.
    if x >= 16.0 {
        1.0
    } else if x <= -16.0 {
        0.0
    } else {
        // exp(v) ≈ from_bits(2^23 * log2(e) * v + (127 << 23)); evaluate at
        // v = -x to get exp(-x) without calling libm.
        const SHIFT: f32 = (1u32 << 23) as f32; // 8388608.0
        const SLOPE: f32 = SHIFT * std::f32::consts::LOG2_E; // SHIFT / ln(2)
        const BIAS: i32 = 127 << 23; // f32 exponent bias
        let raw = (SLOPE * -x) as i32 + BIAS;
        let exp_neg_x = f32::from_bits(raw as u32);
        1.0 / (1.0 + exp_neg_x)
    }
}
660
/// Fused dequantization + dot product + sigmoid for quantized (i8) protos.
///
/// For each pixel in the ROI, computes:
///   acc = sum_k(coeff[k] * (proto[y, x, k] as f32 - zp) * scale)
///   mask[y, x] = fast_sigmoid(acc) * 255
///
/// This avoids allocating a full f32 proto tensor (3.1MB for 160x160x32)
/// and the hidden `to_shape` copy on non-contiguous ROI slices.
///
/// The ROI (`y0`, `x0`, `roi_h`, `roi_w`) must lie fully inside the proto
/// plane and `coeff.len()` must equal `num_protos`; the caller
/// (`materialize_segmentations`) guarantees both via clamping and a length
/// check — the unsafe raw-pointer reads below rely on it.
#[allow(clippy::too_many_arguments)]
fn fused_dequant_dot_sigmoid_i8(
    protos: &ndarray::Array3<i8>,
    coeff: &[f32],
    scale: f32,
    zp: f32,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    debug_assert!(
        protos.strides().iter().all(|&s| s >= 0),
        "negative strides unsupported"
    );
    // Pre-scale coefficients: coeff[k] * scale, so the inner loop is
    // just fma: acc += scaled_coeff[k] * (proto_i8 - zp)
    let scaled_coeff: Vec<f32> = coeff.iter().map(|&c| c * scale).collect();
    // Pre-compute coeff_sum * (-zp * scale) offset:
    // sum_k(coeff[k] * (proto - zp) * scale) = sum_k(scaled_coeff[k] * proto) - zp * sum_k(scaled_coeff[k])
    // But since zp is a constant per-pixel term, factor it out:
    // acc = sum_k(scaled_coeff[k] * proto_i8_as_f32) - zp * sum_k(scaled_coeff[k])
    let zp_offset: f32 = zp * scaled_coeff.iter().sum::<f32>();

    // ndarray strides are in elements, matching the `ptr::add` units below.
    let proto_stride_y = protos.strides()[0] as usize;
    let proto_stride_x = protos.strides()[1] as usize;
    let proto_stride_k = protos.strides()[2] as usize;
    let proto_ptr = protos.as_ptr();

    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));

    for y in 0..roi_h {
        for x in 0..roi_w {
            // Base pointer for protos[y0+y, x0+x, 0]
            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;

            let mut acc = 0.0f32;
            let mut k = 0;

            // Process 4 protos at a time for better ILP
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
                unsafe {
                    let p0 = *proto_ptr.add(base + k * proto_stride_k) as f32;
                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k) as f32;
                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k) as f32;
                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k) as f32;
                    acc += scaled_coeff[k] * p0
                        + scaled_coeff[k + 1] * p1
                        + scaled_coeff[k + 2] * p2
                        + scaled_coeff[k + 3] * p3;
                }
                k += 4;
            }
            // Remainder
            while k < num_protos {
                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
                unsafe {
                    let p = *proto_ptr.add(base + k * proto_stride_k) as f32;
                    acc += scaled_coeff[k] * p;
                }
                k += 1;
            }

            // Apply the factored-out zero-point term once per pixel.
            acc -= zp_offset;
            let sigmoid = fast_sigmoid(acc);
            // Round-to-nearest u8 quantization of the sigmoid confidence.
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}
744
/// Fused dot product + sigmoid for f32 protos (no dequantization needed).
///
/// For each pixel in the ROI, computes the coefficient dot product over the
/// proto channels and stores `fast_sigmoid(acc) * 255` into a `u8` mask.
/// The ROI (`y0`, `x0`, `roi_h`, `roi_w`) must lie fully inside the proto
/// plane and `coeff.len()` must equal `num_protos`; the caller
/// (`materialize_segmentations`) guarantees both — the unsafe raw-pointer
/// reads below rely on it.
fn fused_dot_sigmoid_f32(
    protos: &ndarray::Array3<f32>,
    coeff: &[f32],
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    debug_assert!(
        protos.strides().iter().all(|&s| s >= 0),
        "negative strides unsupported"
    );
    // ndarray strides are in elements, matching the `ptr::add` units below.
    let proto_stride_y = protos.strides()[0] as usize;
    let proto_stride_x = protos.strides()[1] as usize;
    let proto_stride_k = protos.strides()[2] as usize;
    let proto_ptr = protos.as_ptr();

    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));

    for y in 0..roi_h {
        for x in 0..roi_w {
            // Base pointer offset for protos[y0+y, x0+x, 0].
            let base = (y0 + y) * proto_stride_y + (x0 + x) * proto_stride_x;

            let mut acc = 0.0f32;
            let mut k = 0;
            // Process 4 protos at a time for better instruction-level
            // parallelism; the remainder loop below handles the tail.
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
                // y0+y < proto_h, x0+x < proto_w, k+3 < num_protos <= protos.shape()[2].
                unsafe {
                    let p0 = *proto_ptr.add(base + k * proto_stride_k);
                    let p1 = *proto_ptr.add(base + (k + 1) * proto_stride_k);
                    let p2 = *proto_ptr.add(base + (k + 2) * proto_stride_k);
                    let p3 = *proto_ptr.add(base + (k + 3) * proto_stride_k);
                    acc +=
                        coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
                }
                k += 4;
            }
            while k < num_protos {
                // SAFETY: bounds are guaranteed by ROI clamping in the caller:
                // y0+y < proto_h, x0+x < proto_w, k < num_protos <= protos.shape()[2].
                unsafe {
                    let p = *proto_ptr.add(base + k * proto_stride_k);
                    acc += coeff[k] * p;
                }
                k += 1;
            }

            let sigmoid = fast_sigmoid(acc);
            // Round-to-nearest u8 quantization of the sigmoid confidence.
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}
802
#[cfg(test)]
mod scaled_tests {
    //! Tests for `materialize_scaled_segmentations`: binary mask output,
    //! pixel-level parity against a hand-rolled bilinear + sigmoid
    //! reference, letterbox coordinate handling, and quantized-vs-float
    //! agreement.
    use super::*;
    use edgefirst_decoder::{BoundingBox, DetectBox, ProtoData, ProtoTensor};
    use ndarray::Array3;

    /// Fresh processor instance for each test.
    fn make_cpu() -> CPUProcessor {
        CPUProcessor::new()
    }

    /// A synthetic proto plane where channel 0 is a centred gaussian peaking
    /// at 10.0 and tailing off to ~0.0 at the edges; other channels zero.
    /// Mask coefficients (1.0, 0, 0, ...) select only channel 0.
    fn synthetic_proto_data(proto_h: usize, proto_w: usize, num_protos: usize) -> ProtoData {
        let mut protos = Array3::<f32>::zeros((proto_h, proto_w, num_protos));
        let cy = (proto_h as f32 - 1.0) / 2.0;
        let cx = (proto_w as f32 - 1.0) / 2.0;
        for y in 0..proto_h {
            for x in 0..proto_w {
                // Normalized offset from the plane centre, in [-1, 1] per axis.
                let dy = (y as f32 - cy) / cy;
                let dx = (x as f32 - cx) / cx;
                protos[[y, x, 0]] = 10.0 * (-(dx * dx + dy * dy)).exp();
            }
        }
        let mut coeffs = vec![0.0_f32; num_protos];
        coeffs[0] = 1.0;
        ProtoData {
            mask_coefficients: vec![coeffs],
            protos: ProtoTensor::Float(protos),
        }
    }

    #[test]
    fn scaled_central_bbox_produces_foreground_blob() {
        let cpu = make_cpu();
        let proto_data = synthetic_proto_data(16, 16, 4);

        // bbox covers the centre 50% of the plane.
        let detect = vec![DetectBox {
            bbox: BoundingBox::new(0.25, 0.25, 0.75, 0.75),
            score: 0.9,
            label: 0,
        }];

        let out = cpu
            .materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64)
            .expect("scaled mask rendering must succeed");

        assert_eq!(out.len(), 1);
        let seg = &out[0];
        // 50% of the 64-px output in each axis → a 32×32 single-channel tile.
        assert_eq!(seg.segmentation.shape(), &[32, 32, 1]);
        let values: Vec<u8> = seg.segmentation.iter().copied().collect();
        assert!(
            values.iter().all(|&v| v == 0 || v == 255),
            "scaled mask must be binary {{0, 255}}, found other values"
        );
        // The gaussian peaks at 10.0 (sigmoid ≈ 1) over most of the centred
        // bbox, so the tile should be overwhelmingly foreground.
        let fg_count = values.iter().filter(|&&v| v == 255).count();
        let fg_frac = fg_count as f32 / values.len() as f32;
        assert!(
            fg_frac > 0.80,
            "expected >80% foreground for centred gaussian, got {fg_frac:.2}"
        );
    }

    /// Realistic 160×160 → 640×640 parity vs hand-rolled bilinear + sigmoid
    /// reference. Tolerate <0.5% pixel mismatch due to fast_sigmoid's ~1.1%
    /// approximation error near the decision boundary.
    #[test]
    fn scaled_realistic_160x160_to_640x640_matches_reference() {
        // LCG with Numerical Recipes constants: deterministic pseudo-random
        // values in [-1, 1] without pulling in an rng crate.
        let mut rng_seed: u32 = 0xD00D_F00D;
        let next = |s: &mut u32| -> f32 {
            *s = s.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (*s as f32) / (u32::MAX as f32) * 2.0 - 1.0
        };

        let proto_h = 160;
        let proto_w = 160;
        let num_protos = 32;
        let mut protos = Array3::<f32>::zeros((proto_h, proto_w, num_protos));
        for y in 0..proto_h {
            for x in 0..proto_w {
                for k in 0..num_protos {
                    protos[[y, x, k]] = next(&mut rng_seed) * 3.0;
                }
            }
        }
        let mut coeffs = vec![0.0_f32; num_protos];
        for c in &mut coeffs {
            *c = next(&mut rng_seed);
        }
        let proto_data = ProtoData {
            mask_coefficients: vec![coeffs.clone()],
            protos: ProtoTensor::Float(protos.clone()),
        };

        let detect = vec![DetectBox {
            bbox: BoundingBox::new(0.1, 0.2, 0.6, 0.9),
            score: 0.9,
            label: 0,
        }];

        let cpu = make_cpu();
        let out = cpu
            .materialize_scaled_segmentations(&detect, &proto_data, None, 640, 640)
            .unwrap();
        assert_eq!(out.len(), 1);
        let hal_tile = &out[0].segmentation;

        // Recompute the expected ROI tile bounds in output pixels from the
        // normalized bbox, mirroring the implementation's rounding.
        let px0 = (0.1_f32 * 640.0_f32).round() as usize;
        let py0 = (0.2_f32 * 640.0_f32).round() as usize;
        let px1 = (0.6_f32 * 640.0_f32).round() as usize;
        let py1 = (0.9_f32 * 640.0_f32).round() as usize;
        let bbox_h = py1 - py0;
        let bbox_w = px1 - px0;
        assert_eq!(hal_tile.shape(), &[bbox_h, bbox_w, 1]);

        // Output-pixel → proto-plane scale; sampling below uses the
        // half-pixel-centre convention ((p + 0.5) * s - 0.5).
        let sx = proto_w as f32 / 640.0_f32;
        let sy = proto_h as f32 / 640.0_f32;
        let mut mismatches = 0_usize;
        for yi in 0..bbox_h {
            let py = (py0 + yi) as f32;
            let sy_coord = (py + 0.5) * sy - 0.5;
            for xi in 0..bbox_w {
                let px = (px0 + xi) as f32;
                let sx_coord = (px + 0.5) * sx - 0.5;
                // Bilinear neighbourhood, clamped to the proto plane edges.
                let x0 = sx_coord.floor().clamp(0.0, proto_w as f32 - 1.0) as usize;
                let y0 = sy_coord.floor().clamp(0.0, proto_h as f32 - 1.0) as usize;
                let x1 = (x0 + 1).min(proto_w - 1);
                let y1 = (y0 + 1).min(proto_h - 1);
                let fx = sx_coord - sx_coord.floor();
                let fy = sy_coord - sy_coord.floor();
                let w00 = (1.0 - fx) * (1.0 - fy);
                let w10 = fx * (1.0 - fy);
                let w01 = (1.0 - fx) * fy;
                let w11 = fx * fy;
                // Interpolate each proto channel, then dot with coefficients.
                let mut acc = 0.0_f32;
                for p in 0..num_protos {
                    let val = w00 * protos[[y0, x0, p]]
                        + w10 * protos[[y0, x1, p]]
                        + w01 * protos[[y1, x0, p]]
                        + w11 * protos[[y1, x1, p]];
                    acc += coeffs[p] * val;
                }
                // Exact sigmoid as reference (the implementation uses
                // fast_sigmoid, hence the mismatch tolerance below).
                let sigmoid = 1.0_f32 / (1.0 + (-acc).exp());
                let expected: u8 = if sigmoid > 0.5 { 255 } else { 0 };
                if hal_tile[[yi, xi, 0]] != expected {
                    mismatches += 1;
                }
            }
        }
        let total = bbox_h * bbox_w;
        let mismatch_rate = mismatches as f32 / total as f32;
        assert!(
            mismatch_rate < 0.005,
            "mismatch rate {mismatch_rate:.4} > 0.5% tolerance \
             ({mismatches}/{total} pixels)"
        );
    }

    /// Letterbox path: the inverse letterbox transform must shift the
    /// bbox sample into the proto plane correctly when (width, height)
    /// are original-content pixel dims. A centred gaussian on the proto
    /// plane, aligned with the content region via letterbox, should
    /// still produce a centred foreground blob in the original-content
    /// coordinate frame.
    #[test]
    fn scaled_letterbox_original_content_coords() {
        let cpu = make_cpu();
        let proto_data = synthetic_proto_data(16, 16, 4);

        // Bbox in *model-input* normalized coords — the frame used by the
        // decoder output and consumed by `materialize_scaled_segmentations`.
        // Under the letterbox below, this maps to output-content normalized
        // (0.25, 0.3, 0.75, 0.7) and centres on the proto-plane centre.
        let detect = vec![DetectBox {
            bbox: BoundingBox::new(0.25, 0.3875, 0.75, 0.6125),
            score: 0.9,
            label: 0,
        }];

        // Letterbox: content fills full width, centred 56.25% vertically.
        // Caller requests Scaled(640, 360) in original-content space.
        let letterbox = Some([0.0_f32, 0.21875, 1.0, 0.78125]);
        let out = cpu
            .materialize_scaled_segmentations(&detect, &proto_data, letterbox, 640, 360)
            .expect("letterbox scaled rendering must succeed");

        assert_eq!(out.len(), 1);
        let seg = &out[0];
        // Bbox spans 50% of the width and 40% of the height of the
        // original-content frame (after inverse letterbox).
        let bbox_w = (0.5_f32 * 640.0_f32).round() as usize; // 320
        let bbox_h = (0.4_f32 * 360.0_f32).round() as usize; // 144
        assert_eq!(seg.segmentation.shape(), &[bbox_h, bbox_w, 1]);
        let uniq: std::collections::BTreeSet<u8> = seg.segmentation.iter().copied().collect();
        assert!(
            uniq.iter().all(|&v| v == 0 || v == 255),
            "letterbox scaled mask must be binary {{0, 255}}, got {uniq:?}"
        );

        // Middle-25%-square of the bbox should be solidly foreground
        // (the proto gaussian is centred; the bbox is centred on the
        // original-content frame; letterbox preserves centre alignment).
        let cy = bbox_h / 2;
        let cx = bbox_w / 2;
        let patch_h = bbox_h / 4;
        let patch_w = bbox_w / 4;
        let mut fg = 0_usize;
        let mut total = 0_usize;
        for y in cy.saturating_sub(patch_h / 2)..=(cy + patch_h / 2).min(bbox_h - 1) {
            for x in cx.saturating_sub(patch_w / 2)..=(cx + patch_w / 2).min(bbox_w - 1) {
                if seg.segmentation[[y, x, 0]] == 255 {
                    fg += 1;
                }
                total += 1;
            }
        }
        let fg_frac = fg as f32 / total as f32;
        assert!(
            fg_frac > 0.95,
            "centre of letterboxed bbox should be >95% foreground, got {fg_frac:.2}"
        );
    }

    /// Quantized i8 protos must produce the same binary mask as a float
    /// equivalent, modulo quantization rounding error.
    #[test]
    fn scaled_quantized_proto_produces_same_result_as_float() {
        use edgefirst_decoder::Quantization;

        let proto_h = 32;
        let proto_w = 32;
        let num_protos = 8;
        // Smooth deterministic proto values in roughly [-3, 3].
        let mut protos_f32 = Array3::<f32>::zeros((proto_h, proto_w, num_protos));
        for y in 0..proto_h {
            for x in 0..proto_w {
                for k in 0..num_protos {
                    let v = ((y + x + k * 3) as f32 * 0.05).sin() * 3.0;
                    protos_f32[[y, x, k]] = v;
                }
            }
        }
        // Quantize the float protos to i8 with scale 0.1 / zero-point 0 so
        // the two paths represent (nearly) the same tensor.
        let scale = 0.1_f32;
        let zp = 0_i32;
        let protos_i8 = protos_f32.mapv(|v| (v / scale).round().clamp(-127.0, 127.0) as i8);
        let coeffs: Vec<f32> = (0..num_protos).map(|k| (k as f32 - 3.5) * 0.3).collect();

        let pd_float = ProtoData {
            mask_coefficients: vec![coeffs.clone()],
            protos: ProtoTensor::Float(protos_f32.clone()),
        };
        let pd_quant = ProtoData {
            mask_coefficients: vec![coeffs.clone()],
            protos: ProtoTensor::Quantized {
                protos: protos_i8,
                quantization: Quantization {
                    scale,
                    zero_point: zp,
                },
            },
        };

        let detect = vec![DetectBox {
            bbox: BoundingBox::new(0.1, 0.1, 0.9, 0.9),
            score: 0.9,
            label: 0,
        }];

        let cpu = make_cpu();
        let out_f = cpu
            .materialize_scaled_segmentations(&detect, &pd_float, None, 320, 320)
            .unwrap();
        let out_q = cpu
            .materialize_scaled_segmentations(&detect, &pd_quant, None, 320, 320)
            .unwrap();
        // Only pixels whose logit sits near the 0.5 decision boundary can
        // flip due to quantization rounding; allow at most 0.5% of them.
        let mismatches = out_f[0]
            .segmentation
            .iter()
            .zip(out_q[0].segmentation.iter())
            .filter(|(a, b)| a != b)
            .count();
        let total = out_f[0].segmentation.len();
        assert!(
            mismatches < total / 200,
            "quantized and float diverged at {mismatches}/{total} pixels (>0.5%)"
        );
    }
}