1use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
11 #[allow(clippy::too_many_arguments)]
12 pub(super) fn render_modelpack_segmentation(
13 &mut self,
14 dst_w: usize,
15 dst_h: usize,
16 dst_rs: usize,
17 dst_c: usize,
18 dst_slice: &mut [u8],
19 segmentation: &Segmentation,
20 opacity: f32,
21 ) -> Result<()> {
22 use ndarray_stats::QuantileExt;
23
24 let seg = &segmentation.segmentation;
25 let [seg_height, seg_width, seg_classes] = *seg.shape() else {
26 unreachable!("Array3 did not have [usize; 3] as shape");
27 };
28 let start_y = (dst_h as f32 * segmentation.ymin).round();
29 let end_y = (dst_h as f32 * segmentation.ymax).round();
30 let start_x = (dst_w as f32 * segmentation.xmin).round();
31 let end_x = (dst_w as f32 * segmentation.xmax).round();
32
33 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
34 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
35
36 let start_x_u = (start_x as usize).min(dst_w);
37 let start_y_u = (start_y as usize).min(dst_h);
38 let end_x_u = (end_x as usize).min(dst_w);
39 let end_y_u = (end_y as usize).min(dst_h);
40
41 let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
42 let get_value_at_nearest = |x: f32, y: f32| -> usize {
43 let x = x.round() as usize;
44 let y = y.round() as usize;
45 argmax
46 .get([y.min(seg_height - 1), x.min(seg_width - 1)])
47 .copied()
48 .unwrap_or(0)
49 };
50
51 for y in start_y_u..end_y_u {
52 for x in start_x_u..end_x_u {
53 let seg_x = (x as f32 - start_x) * scale_x;
54 let seg_y = (y as f32 - start_y) * scale_y;
55 let label = get_value_at_nearest(seg_x, seg_y);
56
57 if label == seg_classes - 1 {
58 continue;
59 }
60
61 let color = self.colors[label % self.colors.len()];
62
63 let alpha = if opacity == 1.0 {
64 color[3] as u16
65 } else {
66 (color[3] as f32 * opacity).round() as u16
67 };
68
69 let dst_index = (y * dst_rs) + (x * dst_c);
70 for c in 0..3 {
71 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
72 + dst_slice[dst_index + c] as u16 * (255 - alpha))
73 / 255) as u8;
74 }
75 }
76 }
77
78 Ok(())
79 }
80
81 #[allow(clippy::too_many_arguments)]
82 pub(super) fn render_yolo_segmentation(
83 &mut self,
84 dst_w: usize,
85 dst_h: usize,
86 dst_rs: usize,
87 dst_c: usize,
88 dst_slice: &mut [u8],
89 segmentation: &Segmentation,
90 class: usize,
91 opacity: f32,
92 ) -> Result<()> {
93 let seg = &segmentation.segmentation;
94 let [seg_height, seg_width, classes] = *seg.shape() else {
95 unreachable!("Array3 did not have [usize;3] as shape");
96 };
97 debug_assert_eq!(classes, 1);
98
99 let start_y = (dst_h as f32 * segmentation.ymin).round();
100 let end_y = (dst_h as f32 * segmentation.ymax).round();
101 let start_x = (dst_w as f32 * segmentation.xmin).round();
102 let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107 let start_x_u = (start_x as usize).min(dst_w);
108 let start_y_u = (start_y as usize).min(dst_h);
109 let end_x_u = (end_x as usize).min(dst_w);
110 let end_y_u = (end_y as usize).min(dst_h);
111
112 for y in start_y_u..end_y_u {
113 for x in start_x_u..end_x_u {
114 let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115 let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116 let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118 if val < 127 {
119 continue;
120 }
121
122 let color = self.colors[class % self.colors.len()];
123
124 let alpha = if opacity == 1.0 {
125 color[3] as u16
126 } else {
127 (color[3] as f32 * opacity).round() as u16
128 };
129
130 let dst_index = (y * dst_rs) + (x * dst_c);
131 for c in 0..3 {
132 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133 + dst_slice[dst_index + c] as u16 * (255 - alpha))
134 / 255) as u8;
135 }
136 }
137 }
138
139 Ok(())
140 }
141
142 #[allow(clippy::too_many_arguments)]
143 pub(super) fn render_box(
144 &mut self,
145 dst_w: usize,
146 dst_h: usize,
147 dst_rs: usize,
148 dst_c: usize,
149 dst_slice: &mut [u8],
150 detect: &[DetectBox],
151 color_mode: crate::ColorMode,
152 ) -> Result<()> {
153 const LINE_THICKNESS: usize = 3;
154
155 for (idx, d) in detect.iter().enumerate() {
156 use edgefirst_decoder::BoundingBox;
157
158 let color_index = color_mode.index(idx, d.label);
159 let [r, g, b, _] = self.colors[color_index % self.colors.len()];
160 let bbox = d.bbox.to_canonical();
161 let bbox = BoundingBox {
162 xmin: bbox.xmin.clamp(0.0, 1.0),
163 ymin: bbox.ymin.clamp(0.0, 1.0),
164 xmax: bbox.xmax.clamp(0.0, 1.0),
165 ymax: bbox.ymax.clamp(0.0, 1.0),
166 };
167 let inner = [
168 ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
169 ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
170 ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
171 ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
172 ];
173
174 let outer = [
175 inner[0].saturating_sub(LINE_THICKNESS),
176 inner[1].saturating_sub(LINE_THICKNESS),
177 (inner[2] + LINE_THICKNESS).min(dst_w),
178 (inner[3] + LINE_THICKNESS).min(dst_h),
179 ];
180
181 for y in outer[1] + 1..=inner[1] {
183 for x in outer[0] + 1..outer[2] {
184 let index = (y * dst_rs) + (x * dst_c);
185 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
186 }
187 }
188
189 for y in inner[1]..inner[3] {
191 for x in outer[0] + 1..=inner[0] {
192 let index = (y * dst_rs) + (x * dst_c);
193 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
194 }
195
196 for x in inner[2]..outer[2] {
197 let index = (y * dst_rs) + (x * dst_c);
198 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
199 }
200 }
201
202 for y in inner[3]..outer[3] {
204 for x in outer[0] + 1..outer[2] {
205 let index = (y * dst_rs) + (x * dst_c);
206 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
207 }
208 }
209 }
210 Ok(())
211 }
212
213 pub fn materialize_segmentations(
224 &self,
225 detect: &[crate::DetectBox],
226 proto_data: &crate::ProtoData,
227 letterbox: Option<[f32; 4]>,
228 ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
229 use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};
230
231 if detect.is_empty() {
232 return Ok(Vec::new());
233 }
234 let proto_shape = proto_data.protos.shape();
235 if proto_shape.len() != 3 {
236 return Err(crate::Error::InvalidShape(format!(
237 "protos tensor must be rank-3, got {proto_shape:?}"
238 )));
239 }
240 let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
241 let coeff_shape = proto_data.mask_coefficients.shape();
242 if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
243 return Err(crate::Error::InvalidShape(format!(
244 "mask_coefficients shape {coeff_shape:?} incompatible with protos \
245 {proto_shape:?} (expected [N, {num_protos}])"
246 )));
247 }
248 if coeff_shape[0] == 0 {
249 return Ok(Vec::new());
250 }
251 if coeff_shape[0] != detect.len() {
252 return Err(crate::Error::Internal(format!(
253 "mask_coefficients rows {} != detection count {}",
254 coeff_shape[0],
255 detect.len()
256 )));
257 }
258
259 let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
261 Some([lx0, ly0, lx1, ly1]) => {
262 let lw = lx1 - lx0;
263 let lh = ly1 - ly0;
264 (
265 lx0,
266 if lw > 0.0 { 1.0 / lw } else { 1.0 },
267 ly0,
268 if lh > 0.0 { 1.0 / lh } else { 1.0 },
269 )
270 }
271 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
272 };
273
274 let coeff_f32_storage: Vec<f32>;
279 let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
280 DType::F32 => {
281 let t = proto_data
282 .mask_coefficients
283 .as_f32()
284 .expect("dtype matched F32");
285 let m = t.map()?;
286 coeff_f32_storage = m.as_slice().to_vec();
287 &coeff_f32_storage[..]
288 }
289 DType::F16 => {
290 let t = proto_data
291 .mask_coefficients
292 .as_f16()
293 .expect("dtype matched F16");
294 let m = t.map()?;
295 coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
296 &coeff_f32_storage[..]
297 }
298 other => {
299 return Err(crate::Error::InvalidShape(format!(
300 "mask_coefficients dtype {other:?} not supported; expected F32 or F16"
301 )));
302 }
303 };
304
305 match proto_data.protos.dtype() {
311 DType::I8 => {
312 let t = proto_data.protos.as_i8().expect("dtype matched I8");
313 let quant = t.quantization().ok_or_else(|| {
314 crate::Error::InvalidShape("I8 protos require quantization metadata".into())
315 })?;
316 let m = t.map()?;
317 let protos_slice = m.as_slice();
318 detect
319 .par_iter()
320 .enumerate()
321 .map(|(i, det)| {
322 let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
323 let (x0, y0, x1, y1, roi_w, roi_h) =
324 bbox_to_proto_roi(det, proto_w, proto_h);
325 let mask = fused_dequant_dot_sigmoid_i8_slice(
326 protos_slice,
327 coeff,
328 quant,
329 proto_h,
330 proto_w,
331 y0,
332 x0,
333 roi_h,
334 roi_w,
335 num_protos,
336 )?;
337 Ok(seg_from_roi(
338 mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
339 ))
340 })
341 .collect()
342 }
343 DType::F32 => {
344 let t = proto_data.protos.as_f32().expect("dtype matched F32");
345 let m = t.map()?;
346 let protos_slice = m.as_slice();
347 detect
348 .par_iter()
349 .enumerate()
350 .map(|(i, det)| {
351 let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
352 let (x0, y0, x1, y1, roi_w, roi_h) =
353 bbox_to_proto_roi(det, proto_w, proto_h);
354 let mask = fused_dot_sigmoid_f32_slice(
355 protos_slice,
356 coeff,
357 proto_h,
358 proto_w,
359 y0,
360 x0,
361 roi_h,
362 roi_w,
363 num_protos,
364 );
365 Ok(seg_from_roi(
366 mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
367 ))
368 })
369 .collect()
370 }
371 DType::F16 => {
372 let t = proto_data.protos.as_f16().expect("dtype matched F16");
373 let m = t.map()?;
374 let protos_slice = m.as_slice();
375 detect
376 .par_iter()
377 .enumerate()
378 .map(|(i, det)| {
379 let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
380 let (x0, y0, x1, y1, roi_w, roi_h) =
381 bbox_to_proto_roi(det, proto_w, proto_h);
382 let mask = fused_dot_sigmoid_f16_slice(
383 protos_slice,
384 coeff,
385 proto_h,
386 proto_w,
387 y0,
388 x0,
389 roi_h,
390 roi_w,
391 num_protos,
392 );
393 Ok(seg_from_roi(
394 mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
395 ))
396 })
397 .collect()
398 }
399 other => Err(crate::Error::InvalidShape(format!(
400 "proto tensor dtype {other:?} not supported"
401 ))),
402 }
403 }
404
405 pub fn materialize_scaled_segmentations(
418 &self,
419 detect: &[crate::DetectBox],
420 proto_data: &crate::ProtoData,
421 letterbox: Option<[f32; 4]>,
422 width: u32,
423 height: u32,
424 ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
425 use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};
426
427 if detect.is_empty() {
428 return Ok(Vec::new());
429 }
430 if width == 0 || height == 0 {
431 return Err(crate::Error::InvalidShape(
432 "Scaled mask width/height must be positive".into(),
433 ));
434 }
435 let proto_shape = proto_data.protos.shape();
436 if proto_shape.len() != 3 {
437 return Err(crate::Error::InvalidShape(format!(
438 "protos tensor must be rank-3, got {proto_shape:?}"
439 )));
440 }
441 let (proto_h, proto_w, num_protos) = (proto_shape[0], proto_shape[1], proto_shape[2]);
442 let coeff_shape = proto_data.mask_coefficients.shape();
443 if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
444 return Err(crate::Error::InvalidShape(format!(
445 "mask_coefficients shape {coeff_shape:?} incompatible with protos \
446 {proto_shape:?}"
447 )));
448 }
449 if coeff_shape[0] == 0 {
450 return Ok(Vec::new());
451 }
452 if coeff_shape[0] != detect.len() {
453 return Err(crate::Error::Internal(format!(
454 "mask_coefficients rows {} != detection count {}",
455 coeff_shape[0],
456 detect.len()
457 )));
458 }
459
460 let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
462 DType::F32 => {
463 let t = proto_data.mask_coefficients.as_f32().expect("F32");
464 let m = t.map()?;
465 m.as_slice().to_vec()
466 }
467 DType::F16 => {
468 let t = proto_data.mask_coefficients.as_f16().expect("F16");
469 let m = t.map()?;
470 m.as_slice().iter().map(|v| v.to_f32()).collect()
471 }
472 other => {
473 return Err(crate::Error::InvalidShape(format!(
474 "mask_coefficients dtype {other:?} not supported"
475 )));
476 }
477 };
478
479 match proto_data.protos.dtype() {
480 DType::F32 => {
481 let t = proto_data.protos.as_f32().expect("F32");
482 let m = t.map()?;
483 scaled_segmentations_f32_slice(
484 detect,
485 &coeff_f32,
486 m.as_slice(),
487 proto_h,
488 proto_w,
489 num_protos,
490 letterbox,
491 width,
492 height,
493 )
494 }
495 DType::F16 => {
496 let t = proto_data.protos.as_f16().expect("F16");
497 let m = t.map()?;
498 scaled_segmentations_f16_slice(
499 detect,
500 &coeff_f32,
501 m.as_slice(),
502 proto_h,
503 proto_w,
504 num_protos,
505 letterbox,
506 width,
507 height,
508 )
509 }
510 DType::I8 => {
511 let t = proto_data.protos.as_i8().expect("I8");
512 let m = t.map()?;
513 let quant = t.quantization().ok_or_else(|| {
514 crate::Error::InvalidShape("I8 protos require quantization metadata".into())
515 })?;
516 scaled_segmentations_i8_slice(
517 detect,
518 &coeff_f32,
519 m.as_slice(),
520 proto_h,
521 proto_w,
522 num_protos,
523 quant,
524 letterbox,
525 width,
526 height,
527 )
528 }
529 other => Err(crate::Error::InvalidShape(format!(
530 "proto tensor dtype {other:?} not supported"
531 ))),
532 }
533 }
534}
535
536fn bbox_to_proto_roi(
554 det: &DetectBox,
555 proto_w: usize,
556 proto_h: usize,
557) -> (usize, usize, usize, usize, usize, usize) {
558 let bbox = det.bbox.to_canonical();
559 let xmin = bbox.xmin.clamp(0.0, 1.0);
560 let ymin = bbox.ymin.clamp(0.0, 1.0);
561 let xmax = bbox.xmax.clamp(0.0, 1.0);
562 let ymax = bbox.ymax.clamp(0.0, 1.0);
563 let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
564 let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
565 let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
566 let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
567 let roi_w = x1.saturating_sub(x0).max(1);
568 let roi_h = y1.saturating_sub(y0).max(1);
569 (x0, y0, x1, y1, roi_w, roi_h)
570}
571
572#[allow(clippy::too_many_arguments)]
576fn seg_from_roi(
577 mask: ndarray::Array3<u8>,
578 x0: usize,
579 y0: usize,
580 x1: usize,
581 y1: usize,
582 proto_w: usize,
583 proto_h: usize,
584 lx0: f32,
585 inv_lw: f32,
586 ly0: f32,
587 inv_lh: f32,
588) -> edgefirst_decoder::Segmentation {
589 let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
590 let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
591 let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
592 let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
593 edgefirst_decoder::Segmentation {
594 xmin: seg_xmin.clamp(0.0, 1.0),
595 ymin: seg_ymin.clamp(0.0, 1.0),
596 xmax: seg_xmax.clamp(0.0, 1.0),
597 ymax: seg_ymax.clamp(0.0, 1.0),
598 segmentation: mask,
599 }
600}
601
602#[allow(clippy::too_many_arguments)]
603fn fused_dequant_dot_sigmoid_i8_slice(
604 protos: &[i8],
605 coeff: &[f32],
606 quant: &edgefirst_tensor::Quantization,
607 _proto_h: usize,
608 proto_w: usize,
609 y0: usize,
610 x0: usize,
611 roi_h: usize,
612 roi_w: usize,
613 num_protos: usize,
614) -> crate::Result<ndarray::Array3<u8>> {
615 use edgefirst_tensor::QuantMode;
616 let stride_y = proto_w * num_protos;
617 let mut stack_scratch = [0.0_f32; 64];
622 let mut heap_scratch: Vec<f32>;
623 let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
624 &mut stack_scratch[..num_protos]
625 } else {
626 heap_scratch = vec![0.0_f32; num_protos];
627 heap_scratch.as_mut_slice()
628 };
629 let zp_offset: f32;
630 match quant.mode() {
631 QuantMode::PerTensorSymmetric { scale } => {
632 for k in 0..num_protos {
633 scaled_coeff[k] = coeff[k] * scale;
634 }
635 zp_offset = 0.0;
636 }
637 QuantMode::PerTensor { scale, zero_point } => {
638 for k in 0..num_protos {
639 scaled_coeff[k] = coeff[k] * scale;
640 }
641 zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
642 }
643 QuantMode::PerChannelSymmetric { scales, axis } => {
644 if axis != 2 {
645 return Err(crate::Error::NotSupported(format!(
646 "per-channel quantization on axis {axis} not supported \
647 (only channel axis 2 is implemented on this kernel)"
648 )));
649 }
650 for k in 0..num_protos {
651 scaled_coeff[k] = coeff[k] * scales[k];
652 }
653 zp_offset = 0.0;
654 }
655 QuantMode::PerChannel {
656 scales,
657 zero_points,
658 axis,
659 } => {
660 if axis != 2 {
661 return Err(crate::Error::NotSupported(format!(
662 "per-channel quantization on axis {axis} not supported \
663 (only channel axis 2 is implemented on this kernel)"
664 )));
665 }
666 for k in 0..num_protos {
667 scaled_coeff[k] = coeff[k] * scales[k];
668 }
669 zp_offset = (0..num_protos)
670 .map(|k| scaled_coeff[k] * zero_points[k] as f32)
671 .sum();
672 }
673 }
674
675 let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
676 for y in 0..roi_h {
677 for x in 0..roi_w {
678 let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
679 let mut acc = 0.0_f32;
680 let mut k = 0;
681 let chunks = num_protos / 4;
682 for _ in 0..chunks {
683 let p0 = protos[base + k] as f32;
684 let p1 = protos[base + k + 1] as f32;
685 let p2 = protos[base + k + 2] as f32;
686 let p3 = protos[base + k + 3] as f32;
687 acc += scaled_coeff[k] * p0
688 + scaled_coeff[k + 1] * p1
689 + scaled_coeff[k + 2] * p2
690 + scaled_coeff[k + 3] * p3;
691 k += 4;
692 }
693 while k < num_protos {
694 acc += scaled_coeff[k] * protos[base + k] as f32;
695 k += 1;
696 }
697 acc -= zp_offset;
698 let sigmoid = fast_sigmoid(acc);
699 mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
700 }
701 }
702 Ok(mask)
703}
704
705#[allow(clippy::too_many_arguments)]
706fn fused_dot_sigmoid_f32_slice(
707 protos: &[f32],
708 coeff: &[f32],
709 _proto_h: usize,
710 proto_w: usize,
711 y0: usize,
712 x0: usize,
713 roi_h: usize,
714 roi_w: usize,
715 num_protos: usize,
716) -> ndarray::Array3<u8> {
717 let stride_y = proto_w * num_protos;
718 let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
719 for y in 0..roi_h {
720 for x in 0..roi_w {
721 let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
722 let mut acc = 0.0_f32;
723 let mut k = 0;
724 let chunks = num_protos / 4;
725 for _ in 0..chunks {
726 acc += coeff[k] * protos[base + k]
727 + coeff[k + 1] * protos[base + k + 1]
728 + coeff[k + 2] * protos[base + k + 2]
729 + coeff[k + 3] * protos[base + k + 3];
730 k += 4;
731 }
732 while k < num_protos {
733 acc += coeff[k] * protos[base + k];
734 k += 1;
735 }
736 let sigmoid = fast_sigmoid(acc);
737 mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
738 }
739 }
740 mask
741}
742
743#[allow(clippy::too_many_arguments)]
765fn fused_dot_sigmoid_f16_slice(
766 protos: &[half::f16],
767 coeff: &[f32],
768 proto_h: usize,
769 proto_w: usize,
770 y0: usize,
771 x0: usize,
772 roi_h: usize,
773 roi_w: usize,
774 num_protos: usize,
775) -> ndarray::Array3<u8> {
776 #[cfg(all(
777 target_arch = "x86_64",
778 target_feature = "f16c",
779 target_feature = "fma"
780 ))]
781 {
782 unsafe {
785 fused_dot_sigmoid_f16_slice_f16c(
786 protos, coeff, proto_h, proto_w, y0, x0, roi_h, roi_w, num_protos,
787 )
788 }
789 }
790 #[cfg(not(all(
791 target_arch = "x86_64",
792 target_feature = "f16c",
793 target_feature = "fma"
794 )))]
795 {
796 let _ = proto_h;
797 fused_dot_sigmoid_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
798 }
799}
800
801#[allow(clippy::too_many_arguments, dead_code)]
807fn fused_dot_sigmoid_f16_slice_scalar(
808 protos: &[half::f16],
809 coeff: &[f32],
810 proto_w: usize,
811 y0: usize,
812 x0: usize,
813 roi_h: usize,
814 roi_w: usize,
815 num_protos: usize,
816) -> ndarray::Array3<u8> {
817 let stride_y = proto_w * num_protos;
818 let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
819 for y in 0..roi_h {
820 for x in 0..roi_w {
821 let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
822 let mut acc = 0.0_f32;
823 let mut k = 0;
824 let chunks = num_protos / 4;
825 for _ in 0..chunks {
826 let p0 = protos[base + k].to_f32();
827 let p1 = protos[base + k + 1].to_f32();
828 let p2 = protos[base + k + 2].to_f32();
829 let p3 = protos[base + k + 3].to_f32();
830 acc += coeff[k] * p0 + coeff[k + 1] * p1 + coeff[k + 2] * p2 + coeff[k + 3] * p3;
831 k += 4;
832 }
833 while k < num_protos {
834 acc += coeff[k] * protos[base + k].to_f32();
835 k += 1;
836 }
837 let sigmoid = fast_sigmoid(acc);
838 mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
839 }
840 }
841 mask
842}
843
/// AVX/F16C/FMA kernel matching `fused_dot_sigmoid_f16_slice_scalar`.
///
/// Eight `f16` protos at a time are widened with `_mm256_cvtph_ps` and
/// accumulated with `_mm256_fmadd_ps`. The lanes are then reduced horizontally,
/// and the `num_protos % 8` tail is finished in scalar code. Because of FMA and
/// the different summation order, results can differ slightly from the scalar
/// kernel.
///
/// # Safety
/// The CPU must support AVX, F16C and FMA. This function is compiled only when
/// `f16c` and `fma` are statically enabled for the build.
///
/// # Panics
/// Panics if `coeff` holds fewer than `num_protos` values, or if the ROI
/// reaches past the end of `protos`. Both are checked before any unchecked
/// SIMD load, so malformed input panics instead of reading out of bounds.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sigmoid_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask = ndarray::Array3::<u8>::zeros((roi_h, roi_w, 1));
    if roi_h == 0 || roi_w == 0 {
        return mask;
    }

    // The 8-wide loads below bypass slice bounds checks, so verify once that
    // every element they can touch exists. The last ROI cell has the highest
    // base offset and reads `num_protos` values from it.
    assert!(
        coeff.len() >= num_protos,
        "coeff has {} values but num_protos is {num_protos}",
        coeff.len()
    );
    let roi_end = (y0 + roi_h - 1) * stride_y + (x0 + roi_w - 1) * num_protos + num_protos;
    assert!(
        roi_end <= protos.len(),
        "proto ROI needs {roi_end} values but protos has {}",
        protos.len()
    );

    for y in 0..roi_h {
        for x in 0..roi_w {
            let base = (y0 + y) * stride_y + (x0 + x) * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // In bounds: base + k + 8 <= base + num_protos <= roi_end.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduction of the 8 lanes.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for the final `num_protos % 8` channels.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            let sigmoid = fast_sigmoid(acc);
            mask[[y, x, 0]] = (sigmoid * 255.0 + 0.5) as u8;
        }
    }
    mask
}
921
922#[allow(clippy::too_many_arguments)]
923fn scaled_segmentations_f32_slice(
924 detect: &[crate::DetectBox],
925 coeff_all: &[f32],
926 protos: &[f32],
927 proto_h: usize,
928 proto_w: usize,
929 num_protos: usize,
930 letterbox: Option<[f32; 4]>,
931 width: u32,
932 height: u32,
933) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
934 scaled_run(
935 detect,
936 coeff_all,
937 protos,
938 proto_h,
939 proto_w,
940 num_protos,
941 letterbox,
942 width,
943 height,
944 1.0,
945 |p, _| *p,
946 )
947}
948
949#[allow(clippy::too_many_arguments)]
950fn scaled_segmentations_f16_slice(
951 detect: &[crate::DetectBox],
952 coeff_all: &[f32],
953 protos: &[half::f16],
954 proto_h: usize,
955 proto_w: usize,
956 num_protos: usize,
957 letterbox: Option<[f32; 4]>,
958 width: u32,
959 height: u32,
960) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
961 scaled_run(
962 detect,
963 coeff_all,
964 protos,
965 proto_h,
966 proto_w,
967 num_protos,
968 letterbox,
969 width,
970 height,
971 1.0,
972 |p: &half::f16, _| p.to_f32(),
973 )
974}
975
976#[allow(clippy::too_many_arguments)]
977fn scaled_segmentations_i8_slice(
978 detect: &[crate::DetectBox],
979 coeff_all: &[f32],
980 protos: &[i8],
981 proto_h: usize,
982 proto_w: usize,
983 num_protos: usize,
984 quant: &edgefirst_tensor::Quantization,
985 letterbox: Option<[f32; 4]>,
986 width: u32,
987 height: u32,
988) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
989 use edgefirst_tensor::QuantMode;
990 let (scale, zp) = match quant.mode() {
994 QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
995 QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
996 QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
997 return Err(crate::Error::NotSupported(format!(
998 "per-channel quantization (axis={axis}) on scaled seg path \
999 not yet supported"
1000 )));
1001 }
1002 };
1003 scaled_run(
1004 detect,
1005 coeff_all,
1006 protos,
1007 proto_h,
1008 proto_w,
1009 num_protos,
1010 letterbox,
1011 width,
1012 height,
1013 scale,
1014 move |p: &i8, _| *p as f32 - zp,
1015 )
1016}
1017
1018#[allow(clippy::too_many_arguments)]
1019fn scaled_run<P: Copy + Sync>(
1020 detect: &[crate::DetectBox],
1021 coeff_all: &[f32],
1022 protos: &[P],
1023 proto_h: usize,
1024 proto_w: usize,
1025 num_protos: usize,
1026 letterbox: Option<[f32; 4]>,
1027 width: u32,
1028 height: u32,
1029 acc_scale: f32,
1030 load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
1031) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1032 let (lx0, lw, ly0, lh) = match letterbox {
1033 Some([lx0, ly0, lx1, ly1]) => {
1034 let lw = (lx1 - lx0).max(f32::EPSILON);
1035 let lh = (ly1 - ly0).max(f32::EPSILON);
1036 (lx0, lw, ly0, lh)
1037 }
1038 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
1039 };
1040 let out_w = width as usize;
1041 let out_h = height as usize;
1042 let stride_y = proto_w * num_protos;
1043
1044 detect
1066 .par_iter()
1067 .enumerate()
1068 .map(|(i, det)| {
1069 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1070 let bbox = det.bbox.to_canonical();
1071 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
1072 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
1073 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
1074 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
1075 let px0 = (xmin * out_w as f32).round() as usize;
1076 let py0 = (ymin * out_h as f32).round() as usize;
1077 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
1078 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
1079 let bbox_w = px1.saturating_sub(px0).max(1);
1080 let bbox_h = py1.saturating_sub(py0).max(1);
1081
1082 let sample_x_at = |px: f32| -> f32 {
1087 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
1088 model_x_norm * proto_w as f32 - 0.5
1089 };
1090 let sample_y_at = |py: f32| -> f32 {
1091 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
1092 model_y_norm * proto_h as f32 - 0.5
1093 };
1094 let s_x_min = sample_x_at(px0 as f32);
1095 let s_x_max = sample_x_at((px1 as f32) - 1.0);
1096 let s_y_min = sample_y_at(py0 as f32);
1097 let s_y_max = sample_y_at((py1 as f32) - 1.0);
1098 let proto_x0 = (s_x_min.floor() as isize)
1102 .max(0)
1103 .min(proto_w.saturating_sub(1) as isize) as usize;
1104 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
1105 let proto_y0 = (s_y_min.floor() as isize)
1106 .max(0)
1107 .min(proto_h.saturating_sub(1) as isize) as usize;
1108 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
1109 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
1110 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
1111
1112 let mut logits = vec![0.0_f32; roi_h * roi_w];
1115 for ly_idx in 0..roi_h {
1116 let py = proto_y0 + ly_idx;
1117 let row_base = py * stride_y + proto_x0 * num_protos;
1118 for lx_idx in 0..roi_w {
1119 let pix_base = row_base + lx_idx * num_protos;
1120 let mut acc = 0.0_f32;
1121 for k in 0..num_protos {
1122 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
1123 }
1124 logits[ly_idx * roi_w + lx_idx] = acc_scale * acc;
1125 }
1126 }
1127
1128 let mut tile = ndarray::Array3::<u8>::zeros((bbox_h, bbox_w, 1));
1130 for yi in 0..bbox_h {
1131 let py_o = (py0 + yi) as f32;
1132 let sample_y = sample_y_at(py_o) - proto_y0 as f32;
1133 let y_floor = sample_y.floor();
1134 let y_lo = (y_floor as isize)
1135 .max(0)
1136 .min(roi_h.saturating_sub(1) as isize) as usize;
1137 let y_hi = (y_lo + 1).min(roi_h - 1);
1138 let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
1139 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
1140 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
1141 for xi in 0..bbox_w {
1142 let px_o = (px0 + xi) as f32;
1143 let sample_x = sample_x_at(px_o) - proto_x0 as f32;
1144 let x_floor = sample_x.floor();
1145 let x_lo = (x_floor as isize)
1146 .max(0)
1147 .min(roi_w.saturating_sub(1) as isize)
1148 as usize;
1149 let x_hi = (x_lo + 1).min(roi_w - 1);
1150 let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
1151 let l00 = row_lo[x_lo];
1153 let l01 = row_lo[x_hi];
1154 let l10 = row_hi[x_lo];
1155 let l11 = row_hi[x_hi];
1156 let l0 = l00 + (l01 - l00) * x_frac;
1157 let l1 = l10 + (l11 - l10) * x_frac;
1158 let logit = l0 + (l1 - l0) * y_frac;
1159 let sigmoid = fast_sigmoid(logit);
1160 tile[[yi, xi, 0]] = if sigmoid > 0.5 { 255 } else { 0 };
1161 }
1162 }
1163 Ok(edgefirst_decoder::Segmentation {
1164 xmin,
1165 ymin,
1166 xmax,
1167 ymax,
1168 segmentation: tile,
1169 })
1170 })
1171 .collect()
1172}
1173
/// Approximate logistic sigmoid `1 / (1 + e^-x)`.
///
/// `e^-x` is estimated with a bit trick: `-x * log2(e)` is scaled into the f32
/// exponent field and the IEEE-754 bias is added (Schraudolph-style). The
/// result is cheap but only roughly accurate, which suits 8-bit masks and
/// thresholding.
///
/// Inputs at or beyond ±16 return exactly 1.0 or 0.0. This also keeps the
/// integer arithmetic below within `i32` range.
fn fast_sigmoid(x: f32) -> f32 {
    match x {
        x if x >= 16.0 => 1.0,
        x if x <= -16.0 => 0.0,
        x => {
            // 2^23: one unit in the exponent field of an f32.
            const MANTISSA_SCALE: f32 = (1u32 << 23) as f32;
            // Converts a natural exponent into exponent-field units.
            const EXP_FIELD_SCALE: f32 = MANTISSA_SCALE * std::f32::consts::LOG2_E;
            // Bit pattern of 1.0f32 (exponent bias 127).
            const ONE_BITS: u32 = 127 << 23;

            let bits = (EXP_FIELD_SCALE * -x) as i32 + ONE_BITS as i32;
            let exp_neg_x = f32::from_bits(bits as u32);
            1.0 / (1.0 + exp_neg_x)
        }
    }
}