1use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
    /// Alpha-blends a ModelPack semantic-segmentation map over the destination
    /// image using nearest-neighbour sampling.
    ///
    /// `dst_slice` holds packed pixel rows of `dst_h` x `dst_w`: `dst_rs`
    /// bytes per row, `dst_c` bytes per pixel, first three channels blended
    /// as RGB. The class map in `segmentation` (shape HxWxC, classes last)
    /// is stretched into the normalized rectangle
    /// [`xmin`, `xmax`] x [`ymin`, `ymax`]. The highest class index is
    /// treated as background and left untouched. `opacity` scales the
    /// palette alpha; colors come from `self.colors`, indexed by the argmax
    /// label modulo palette size.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_modelpack_segmentation(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        segmentation: &Segmentation,
        opacity: f32,
    ) -> Result<()> {
        use ndarray_stats::QuantileExt;

        let seg = &segmentation.segmentation;
        let [seg_height, seg_width, seg_classes] = *seg.shape() else {
            unreachable!("Array3 did not have [usize; 3] as shape");
        };
        // Destination rectangle covered by the mask, in (float) pixels.
        let start_y = (dst_h as f32 * segmentation.ymin).round();
        let end_y = (dst_h as f32 * segmentation.ymax).round();
        let start_x = (dst_w as f32 * segmentation.xmin).round();
        let end_x = (dst_w as f32 * segmentation.xmax).round();

        // Destination-to-source scale for nearest-neighbour sampling.
        // NOTE(review): a 1-pixel-wide rectangle makes the denominator zero
        // (inf/NaN scale); the `as usize` casts below saturate, so sampling
        // still stays in-bounds — confirm this degenerate case is intended.
        let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
        let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);

        // Integer pixel bounds, clamped to the destination image.
        let start_x_u = (start_x as usize).min(dst_w);
        let start_y_u = (start_y as usize).min(dst_h);
        let end_x_u = (end_x as usize).min(dst_w);
        let end_y_u = (end_y as usize).min(dst_h);

        // Collapse the class axis once up front: per-cell winning label.
        let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
        // Nearest-neighbour lookup, clamped to the map edges; out-of-range
        // indices fall back to label 0.
        let get_value_at_nearest = |x: f32, y: f32| -> usize {
            let x = x.round() as usize;
            let y = y.round() as usize;
            argmax
                .get([y.min(seg_height - 1), x.min(seg_width - 1)])
                .copied()
                .unwrap_or(0)
        };

        for y in start_y_u..end_y_u {
            for x in start_x_u..end_x_u {
                let seg_x = (x as f32 - start_x) * scale_x;
                let seg_y = (y as f32 - start_y) * scale_y;
                let label = get_value_at_nearest(seg_x, seg_y);

                // Last class is background: leave the pixel unmodified.
                if label == seg_classes - 1 {
                    continue;
                }

                let color = self.colors[label % self.colors.len()];

                // Skip the float round-trip when fully opaque.
                let alpha = if opacity == 1.0 {
                    color[3] as u16
                } else {
                    (color[3] as f32 * opacity).round() as u16
                };

                // Standard integer alpha blend over the first 3 channels.
                let dst_index = (y * dst_rs) + (x * dst_c);
                for c in 0..3 {
                    dst_slice[dst_index + c] = ((color[c] as u16 * alpha
                        + dst_slice[dst_index + c] as u16 * (255 - alpha))
                        / 255) as u8;
                }
            }
        }

        Ok(())
    }
80
81 #[allow(clippy::too_many_arguments)]
82 pub(super) fn render_yolo_segmentation(
83 &mut self,
84 dst_w: usize,
85 dst_h: usize,
86 dst_rs: usize,
87 dst_c: usize,
88 dst_slice: &mut [u8],
89 segmentation: &Segmentation,
90 class: usize,
91 opacity: f32,
92 ) -> Result<()> {
93 let seg = &segmentation.segmentation;
94 let [seg_height, seg_width, classes] = *seg.shape() else {
95 unreachable!("Array3 did not have [usize;3] as shape");
96 };
97 debug_assert_eq!(classes, 1);
98
99 let start_y = (dst_h as f32 * segmentation.ymin).round();
100 let end_y = (dst_h as f32 * segmentation.ymax).round();
101 let start_x = (dst_w as f32 * segmentation.xmin).round();
102 let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107 let start_x_u = (start_x as usize).min(dst_w);
108 let start_y_u = (start_y as usize).min(dst_h);
109 let end_x_u = (end_x as usize).min(dst_w);
110 let end_y_u = (end_y as usize).min(dst_h);
111
112 for y in start_y_u..end_y_u {
113 for x in start_x_u..end_x_u {
114 let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115 let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116 let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118 if val < 127 {
119 continue;
120 }
121
122 let color = self.colors[class % self.colors.len()];
123
124 let alpha = if opacity == 1.0 {
125 color[3] as u16
126 } else {
127 (color[3] as f32 * opacity).round() as u16
128 };
129
130 let dst_index = (y * dst_rs) + (x * dst_c);
131 for c in 0..3 {
132 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133 + dst_slice[dst_index + c] as u16 * (255 - alpha))
134 / 255) as u8;
135 }
136 }
137 }
138
139 Ok(())
140 }
141
    /// Draws a `LINE_THICKNESS`-pixel rectangle outline for every detection.
    ///
    /// `dst_slice` holds packed pixel rows (`dst_rs` bytes per row, `dst_c`
    /// bytes per pixel, first three channels written as RGB). Box
    /// coordinates are normalized and clamped to [0, 1] before
    /// rasterization; the outline color is chosen from `self.colors` by
    /// `color_mode` based on the detection index or label.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_box(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        detect: &[DetectBox],
        color_mode: crate::ColorMode,
    ) -> Result<()> {
        const LINE_THICKNESS: usize = 3;

        for (idx, d) in detect.iter().enumerate() {
            use edgefirst_decoder::BoundingBox;

            let color_index = color_mode.index(idx, d.label);
            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
            let bbox = d.bbox.to_canonical();
            // Clamp so partially off-screen boxes still render.
            let bbox = BoundingBox {
                xmin: bbox.xmin.clamp(0.0, 1.0),
                ymin: bbox.ymin.clamp(0.0, 1.0),
                xmax: bbox.xmax.clamp(0.0, 1.0),
                ymax: bbox.ymax.clamp(0.0, 1.0),
            };
            // Inner edge of the outline as [xmin, ymin, xmax, ymax] pixels;
            // the -0.5/+0.5 bias rounds the edges outward. Negative floats
            // saturate to 0 under `as usize`.
            let inner = [
                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
            ];

            // Outer edge: inner expanded by the thickness, clamped to the
            // image bounds.
            let outer = [
                inner[0].saturating_sub(LINE_THICKNESS),
                inner[1].saturating_sub(LINE_THICKNESS),
                (inner[2] + LINE_THICKNESS).min(dst_w),
                (inner[3] + LINE_THICKNESS).min(dst_h),
            ];

            // Top bar (full outer width).
            for y in outer[1] + 1..=inner[1] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // Left and right bars between the top and bottom bars.
            for y in inner[1]..inner[3] {
                for x in outer[0] + 1..=inner[0] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }

                for x in inner[2]..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // Bottom bar (full outer width).
            for y in inner[3]..outer[3] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }
        }
        Ok(())
    }
212
    /// Materializes binary instance masks at proto resolution for every
    /// detection by combining the prototype tensor with per-detection mask
    /// coefficients (YOLO-style prototype masks).
    ///
    /// `letterbox` is `[x0, y0, x1, y1]` in normalized coordinates of the
    /// letterboxed content region; mask bounds are mapped back into the
    /// original image space. Returns one `Segmentation` per detection, in
    /// the same order as `detect`.
    ///
    /// # Errors
    /// - `InvalidShape` when tensor ranks, dtypes, or quantization metadata
    ///   are inconsistent.
    /// - `Internal` when coefficient rows do not match the detection count.
    /// - `NotSupported` for unsupported layout/quantization combinations.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "proto",
            n_detections = detect.len(),
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Proto tensor dimensions depend on the declared layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        // Coefficients must be one row of `num_protos` values per detection.
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute the inverse letterbox transform (offset + reciprocal
        // span); a degenerate span falls back to the identity scale.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Fast path: fully-quantized i8 protos and coefficients can run an
        // integer dot-product kernel without dequantizing up front.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match proto_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                lx0,
                inv_lw,
                ly0,
                inv_lh,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Quantization mode not handled by the fast path; fall
                    // through to the generic dequantizing path below.
                }
                Err(e) => return Err(e),
            }
        }

        // The generic path only understands NHWC protos.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout requires I8 protos and coefficients with per-tensor quantization"
                    .into(),
            ));
        }
        // Dequantize/widen the coefficients to f32 regardless of stored
        // dtype; storage keeps the Vec alive while we borrow a slice.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            DType::I8 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i8()
                    .expect("dtype matched I8");
                let m = t.map()?;
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I8 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    // No quantization metadata: interpret raw values as-is.
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, or I8"
                )));
            }
        };

        // One mask per detection, computed in parallel with the fused
        // dot-product/sign kernel matching the proto dtype.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sign_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
501
    /// Materializes instance masks resampled to a caller-specified
    /// `width` x `height`, combining the prototype tensor with per-detection
    /// mask coefficients and undoing the optional letterbox transform.
    ///
    /// Unlike `materialize_segmentations`, output masks are produced at the
    /// requested resolution by the per-dtype `scaled_*` kernels.
    ///
    /// # Errors
    /// - `InvalidShape` for bad dimensions/dtypes or a zero output size.
    /// - `Internal` for a detection/coefficient count mismatch.
    /// - `NotSupported` for unsupported layout/quantization combinations.
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "scaled",
            n_detections = detect.len(),
            width,
            height,
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Proto tensor dimensions depend on the declared layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        // Coefficients must be one row of `num_protos` values per detection.
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Fast path: fully-quantized i8 protos and coefficients.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match scaled_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                letterbox,
                width,
                height,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Quantization mode not handled by the fast path; fall
                    // through to the generic dequantizing path below.
                }
                Err(e) => return Err(e),
            }
        }

        // The generic path only understands NHWC protos.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout requires I8 protos and coefficients with per-tensor quantization"
                    .into(),
            ));
        }
        // Dequantize/widen the coefficients to f32 regardless of stored dtype.
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            DType::I8 => {
                let t = proto_data.mask_coefficients.as_i8().expect("I8");
                let m = t.map()?;
                let q = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape(
                        "I8 mask_coefficients require quantization metadata".into(),
                    )
                })?;
                use edgefirst_tensor::QuantMode;
                let (scale, zp) = match q.mode() {
                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                    _ => {
                        return Err(crate::Error::NotSupported(
                            "per-channel mask_coefficients not supported".into(),
                        ))
                    }
                };
                m.as_slice()
                    .iter()
                    .map(|&v| (v as f32 - zp) * scale)
                    .collect()
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        // Dispatch on proto dtype to the matching scaled kernel.
        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
726}
727
728fn bbox_to_proto_roi(
746 det: &DetectBox,
747 proto_w: usize,
748 proto_h: usize,
749) -> (usize, usize, usize, usize, usize, usize) {
750 let bbox = det.bbox.to_canonical();
751 let xmin = bbox.xmin.clamp(0.0, 1.0);
752 let ymin = bbox.ymin.clamp(0.0, 1.0);
753 let xmax = bbox.xmax.clamp(0.0, 1.0);
754 let ymax = bbox.ymax.clamp(0.0, 1.0);
755 let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
756 let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
757 let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
758 let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
759 let roi_w = x1.saturating_sub(x0).max(1);
760 let roi_h = y1.saturating_sub(y0).max(1);
761 (x0, y0, x1, y1, roi_w, roi_h)
762}
763
764#[allow(clippy::too_many_arguments)]
768fn seg_from_roi(
769 mask: ndarray::Array3<u8>,
770 x0: usize,
771 y0: usize,
772 x1: usize,
773 y1: usize,
774 proto_w: usize,
775 proto_h: usize,
776 lx0: f32,
777 inv_lw: f32,
778 ly0: f32,
779 inv_lh: f32,
780) -> edgefirst_decoder::Segmentation {
781 let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
782 let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
783 let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
784 let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
785 edgefirst_decoder::Segmentation {
786 xmin: seg_xmin.clamp(0.0, 1.0),
787 ymin: seg_ymin.clamp(0.0, 1.0),
788 xmax: seg_xmax.clamp(0.0, 1.0),
789 ymax: seg_ymax.clamp(0.0, 1.0),
790 segmentation: mask,
791 }
792}
793
/// Integer fast path computing proto-resolution masks from i8 protos and i8
/// coefficients with per-tensor (or symmetric) quantization.
///
/// Uses the identity
/// `sum((c - zp_c)(p - zp_p)) = sum(c*p) - zp_c*sum(p) - zp_p*sum(c) + N*zp_c*zp_p`
/// so the inner loop is a raw i8 dot product; the zero-point terms are
/// applied afterwards as `correction` (per pixel) and `bias` (per
/// detection). Only the sign of the logit matters, so the positive scale
/// factors never need to be applied.
///
/// # Errors
/// `NotSupported` for per-channel quantization on either tensor (the caller
/// falls back to the generic dequantizing path).
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Zero points of both tensors; only per-tensor modes are supported here.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i8 path".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Per-pixel channel sums, needed only for the `zp_c * sum(p)` correction
    // when the coefficient zero point is non-zero.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                // Accumulate plane by plane to stay cache-friendly.
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime check for the SDOT (dot product) instruction on AArch64.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            // Constant part of the zero-point expansion for this detection.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: dotprod availability was
                                    // verified at runtime above; both
                                    // pointers cover num_protos elements.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_dotprod(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        } else {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY: baseline NEON is mandatory on
                                    // aarch64; both pointers cover
                                    // num_protos elements.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_base(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback for non-AArch64 targets.
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Planar layout: accumulate one channel plane at a time
                    // into an i32 buffer, then threshold in a second pass.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1024
1025#[allow(clippy::too_many_arguments)]
1035fn fused_dot_sign_f32_slice(
1036 protos: &[f32],
1037 coeff: &[f32],
1038 _proto_h: usize,
1039 proto_w: usize,
1040 y0: usize,
1041 x0: usize,
1042 roi_h: usize,
1043 roi_w: usize,
1044 num_protos: usize,
1045) -> ndarray::Array3<u8> {
1046 let stride_y = proto_w * num_protos;
1047 let mut mask_buf = vec![0u8; roi_h * roi_w];
1048 for y in 0..roi_h {
1049 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1050 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1051 for (x, out_px) in out_row.iter_mut().enumerate() {
1052 let base = row_base + x * num_protos;
1053 let mut acc = 0.0_f32;
1054 let mut k = 0;
1055 let chunks = num_protos / 4;
1056 for _ in 0..chunks {
1057 acc += coeff[k] * protos[base + k]
1058 + coeff[k + 1] * protos[base + k + 1]
1059 + coeff[k + 2] * protos[base + k + 2]
1060 + coeff[k + 3] * protos[base + k + 3];
1061 k += 4;
1062 }
1063 while k < num_protos {
1064 acc += coeff[k] * protos[base + k];
1065 k += 1;
1066 }
1067 if acc > 0.0 {
1068 *out_px = 255;
1069 }
1070 }
1071 }
1072 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1073 .expect("mask_buf length matches roi_h * roi_w")
1074}
1075
/// Dispatches the f16 fused dot-product/sign kernel to the best
/// implementation: an F16C+FMA SIMD path on x86_64 when those target
/// features are enabled at compile time, otherwise a scalar fallback.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sign_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: the surrounding cfg guarantees f16c/fma are enabled for
        // this compilation target, satisfying the callee's requirements.
        unsafe {
            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
1116
1117#[allow(clippy::too_many_arguments)]
1119fn fused_dot_sign_f16_slice_scalar(
1120 protos: &[half::f16],
1121 coeff: &[f32],
1122 proto_w: usize,
1123 y0: usize,
1124 x0: usize,
1125 roi_h: usize,
1126 roi_w: usize,
1127 num_protos: usize,
1128) -> ndarray::Array3<u8> {
1129 let stride_y = proto_w * num_protos;
1130 let mut mask_buf = vec![0u8; roi_h * roi_w];
1131 for y in 0..roi_h {
1132 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1133 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1134 for (x, out_px) in out_row.iter_mut().enumerate() {
1135 let base = row_base + x * num_protos;
1136 let mut acc = 0.0_f32;
1137 let mut k = 0;
1138 let chunks = num_protos / 4;
1139 for _ in 0..chunks {
1140 acc += coeff[k] * protos[base + k].to_f32()
1141 + coeff[k + 1] * protos[base + k + 1].to_f32()
1142 + coeff[k + 2] * protos[base + k + 2].to_f32()
1143 + coeff[k + 3] * protos[base + k + 3].to_f32();
1144 k += 4;
1145 }
1146 while k < num_protos {
1147 acc += coeff[k] * protos[base + k].to_f32();
1148 k += 1;
1149 }
1150 if acc > 0.0 {
1151 *out_px = 255;
1152 }
1153 }
1154 }
1155 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1156 .expect("mask_buf length matches roi_h * roi_w")
1157}
1158
/// F16C/FMA implementation of the f16 fused dot-product/sign kernel:
/// converts eight f16 protos per step to f32 and accumulates with fused
/// multiply-add, then horizontally reduces the vector accumulator and
/// finishes any remainder with scalar math.
///
/// # Safety
/// The CPU must support F16C, FMA, and AVX; this is guaranteed when the
/// matching `target_feature`s are enabled at compile time, as enforced by
/// the surrounding `cfg`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Load 8 f16 protos, widen to f32 (VCVTPH2PS), and
                // fused-multiply-add against the matching coefficients.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduction of the 8-lane accumulator to a scalar.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for num_protos not divisible by 8.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1235
/// Dequantizing fused dot-product/sign kernel for i8 NHWC protos with f32
/// coefficients.
///
/// The quantization scales are folded into the coefficients up front
/// (`scaled_coeff[k] = coeff[k] * scale`), so the inner loop is a plain f32
/// dot product over raw i8 values; the accumulated zero-point term
/// (`zp_offset`) is applied once by shifting the sign threshold rather than
/// subtracting per element. Returns a `(roi_h, roi_w, 1)` mask with 255
/// where the dequantized logit is positive.
///
/// # Errors
/// `NotSupported` for per-channel quantization on any axis other than the
/// channel axis (2).
#[allow(clippy::too_many_arguments)]
fn fused_dequant_dot_sign_i8_slice(
    protos: &[i8],
    coeff: &[f32],
    quant: &edgefirst_tensor::Quantization,
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> crate::Result<ndarray::Array3<u8>> {
    use edgefirst_tensor::QuantMode;
    let stride_y = proto_w * num_protos;

    // Proto counts are typically small (e.g. 32); use a stack buffer for
    // the scaled coefficients and only heap-allocate past 64.
    let mut stack_scratch = [0.0_f32; 64];
    let mut heap_scratch: Vec<f32>;
    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
        &mut stack_scratch[..num_protos]
    } else {
        heap_scratch = vec![0.0_f32; num_protos];
        heap_scratch.as_mut_slice()
    };
    // Fold scales into the coefficients and collect the zero-point offset.
    let zp_offset: f32;
    match quant.mode() {
        QuantMode::PerTensorSymmetric { scale } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            zp_offset = 0.0;
        }
        QuantMode::PerTensor { scale, zero_point } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
        }
        QuantMode::PerChannelSymmetric { scales, axis } => {
            // NHWC protos quantize per channel on axis 2 only.
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = 0.0;
        }
        QuantMode::PerChannel {
            scales,
            zero_points,
            axis,
        } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = (0..num_protos)
                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
                .sum();
        }
    }

    let mut mask_buf = vec![0u8; roi_h * roi_w];
    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + (x0) * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            // Four-wide unrolled dot product over raw i8 values.
            let mut acc = 0.0_f32;
            let mut k = 0;
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                let p0 = protos[base + k] as f32;
                let p1 = protos[base + k + 1] as f32;
                let p2 = protos[base + k + 2] as f32;
                let p3 = protos[base + k + 3] as f32;
                acc += scaled_coeff[k] * p0
                    + scaled_coeff[k + 1] * p1
                    + scaled_coeff[k + 2] * p2
                    + scaled_coeff[k + 3] * p3;
                k += 4;
            }
            while k < num_protos {
                acc += scaled_coeff[k] * protos[base + k] as f32;
                k += 1;
            }
            // Comparing against zp_offset is equivalent to subtracting it
            // from the logit and comparing against zero.
            if acc > zp_offset {
                *out_px = 255;
            }
        }
    }
    Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w"))
}
1342
1343#[allow(clippy::too_many_arguments)]
1344fn scaled_segmentations_f32_slice(
1345 detect: &[crate::DetectBox],
1346 coeff_all: &[f32],
1347 protos: &[f32],
1348 proto_h: usize,
1349 proto_w: usize,
1350 num_protos: usize,
1351 letterbox: Option<[f32; 4]>,
1352 width: u32,
1353 height: u32,
1354) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1355 scaled_run(
1356 detect,
1357 coeff_all,
1358 protos,
1359 proto_h,
1360 proto_w,
1361 num_protos,
1362 letterbox,
1363 width,
1364 height,
1365 1.0,
1366 |p, _| *p,
1367 )
1368}
1369
1370#[allow(clippy::too_many_arguments)]
1371fn scaled_segmentations_f16_slice(
1372 detect: &[crate::DetectBox],
1373 coeff_all: &[f32],
1374 protos: &[half::f16],
1375 proto_h: usize,
1376 proto_w: usize,
1377 num_protos: usize,
1378 letterbox: Option<[f32; 4]>,
1379 width: u32,
1380 height: u32,
1381) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1382 scaled_run(
1383 detect,
1384 coeff_all,
1385 protos,
1386 proto_h,
1387 proto_w,
1388 num_protos,
1389 letterbox,
1390 width,
1391 height,
1392 1.0,
1393 |p: &half::f16, _| p.to_f32(),
1394 )
1395}
1396
1397#[allow(clippy::too_many_arguments)]
1398fn scaled_segmentations_i8_slice(
1399 detect: &[crate::DetectBox],
1400 coeff_all: &[f32],
1401 protos: &[i8],
1402 proto_h: usize,
1403 proto_w: usize,
1404 num_protos: usize,
1405 quant: &edgefirst_tensor::Quantization,
1406 letterbox: Option<[f32; 4]>,
1407 width: u32,
1408 height: u32,
1409) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1410 use edgefirst_tensor::QuantMode;
1411 let (scale, zp) = match quant.mode() {
1415 QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1416 QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1417 QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1418 return Err(crate::Error::NotSupported(format!(
1419 "per-channel quantization (axis={axis}) on scaled seg path \
1420 not yet supported"
1421 )));
1422 }
1423 };
1424 scaled_run(
1425 detect,
1426 coeff_all,
1427 protos,
1428 proto_h,
1429 proto_w,
1430 num_protos,
1431 letterbox,
1432 width,
1433 height,
1434 scale,
1435 move |p: &i8, _| *p as f32 - zp,
1436 )
1437}
1438
/// Widening dot product of the first `n` elements of two i8 slices,
/// accumulated in i32. Panics if either slice is shorter than `n`
/// (matching the bounds-checked indexing of the previous version).
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
    // Slicing to `n` up front hoists the bounds checks; integer addition
    // is associative so the iterator sum matches the old unrolled loop.
    coeff[..n]
        .iter()
        .zip(&proto[..n])
        .map(|(&c, &p)| c as i32 * p as i32)
        .sum()
}
1473
/// Widening i8 dot product using baseline NEON (`smull` + pairwise
/// accumulate): 16 lanes per iteration, an optional 8-lane step for the
/// remainder, then a scalar tail.
///
/// # Safety
/// `coeff` and `proto` must each point to at least `n` readable i8
/// elements; this function performs raw pointer loads.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    // Four i32 accumulator lanes, summed horizontally at the end.
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        // Load 16 i8 lanes from each input.
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Widening multiplies: i8 x i8 -> i16 for low and high halves.
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        // Pairwise-add the i16 products into the i32 accumulator.
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    // One 8-lane step if at least 8 elements remain.
    let remainder = n - offset;
    if remainder >= 8 {
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Horizontal add of the vector lanes, then a scalar tail loop.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1508
/// Widening i8 dot product using the `sdot` instruction: each `sdot`
/// consumes 16 lanes per operand and accumulates groups of four i8*i8
/// products into four i32 lanes; the tail is handled in scalar code.
///
/// The instruction is emitted via inline asm (with `.arch_extension
/// dotprod`) so the crate does not have to be compiled with the
/// `dotprod` target feature enabled globally.
///
/// # Safety
/// `coeff` and `proto` must each point to at least `n` readable i8
/// elements, and the running CPU must support the `dotprod` feature
/// (callers gate on `is_aarch64_feature_detected!("dotprod")`).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        let result: int32x4_t;
        // sdot: acc lane i += sum of the four i8*i8 products in group i.
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    // Horizontal add of the four accumulator lanes, then scalar tail.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1545
/// Fills `logits` for a prototype-space ROI using the NEON `sdot`
/// kernel: one widening i8 dot product per ROI pixel, minus the
/// zero-point correction (`zp_c * proto_sums[pixel]`) and the
/// per-detection `bias`.
///
/// Caller must ensure `logits` holds `roi_h * roi_w` entries, `coeff`
/// holds `num_protos` values, the ROI lies within the NHWC prototype
/// tensor, `proto_sums` is populated whenever `zp_c != 0`, and the CPU
/// supports the `dotprod` feature (checked at the call site).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        // Flat offset of the first prototype vector in this ROI row.
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY (upheld here): `proto_px` has exactly `num_protos`
            // elements; `coeff` must hold at least as many (callers pass
            // a `num_protos`-length slice). dotprod support is a caller
            // precondition.
            let raw_dot =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // Zero-point correction only applies for asymmetric coeffs.
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
1583
/// Twin of `compute_logits_dotprod` for aarch64 CPUs without the
/// `dotprod` extension: identical logit computation, but the inner dot
/// product uses the baseline NEON `smull`/pairwise-accumulate kernel.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let proto_row = py * stride_y + proto_x0 * num_protos;
        let sums_row = py * proto_w + proto_x0;
        let out_row = &mut logits[ly_idx * roi_w..(ly_idx + 1) * roi_w];
        for (lx_idx, slot) in out_row.iter_mut().enumerate() {
            let pix = proto_row + lx_idx * num_protos;
            let proto_px = &protos[pix..pix + num_protos];
            // SAFETY: `proto_px` has exactly `num_protos` elements;
            // `coeff` must hold at least as many (callers pass a
            // `num_protos`-length slice), so the raw-pointer loads in
            // the kernel stay in bounds.
            let raw_dot =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            let correction = if zp_c == 0 {
                0
            } else {
                zp_c * proto_sums[sums_row + lx_idx]
            };
            *slot = raw_dot - correction - bias;
        }
    }
}
1621
/// Fully-quantized fast path: computes detection masks from i8 mask
/// coefficients and an i8 prototype tensor without dequantizing either,
/// using integer dot products and fixed-point bilinear resampling.
///
/// The sign of the dequantized logit equals the sign of
/// `sum_k (c_k - zp_c) * (p_k - zp_p)`, which expands to
/// `raw_dot - zp_c * sum(p) - (zp_p * sum(c) - n * zp_c * zp_p)`.
/// The per-pixel `sum(p)` term is precomputed in `proto_sums` and the
/// rest is the per-detection `bias`. Quantization scales are strictly
/// positive so they never affect the sign and are ignored entirely.
///
/// `layout` selects NHWC (interleaved; SIMD dot product per pixel on
/// aarch64) or NCHW (planar; per-channel accumulation) storage. Only
/// per-tensor quantization is supported for both tensors.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        width,
        height,
        ?layout,
    )
    .entered();

    // Extract the two zero points; per-channel modes are rejected.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported".into(),
            ))
        }
    };

    // Letterbox rectangle in normalized coordinates; EPSILON guards the
    // divisions below against a degenerate (zero-area) rectangle.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let hw = proto_h * proto_w;

    // Per-pixel sums of the prototype vector, needed only when the
    // coefficient zero point is nonzero (the `zp_c * sum(p)` term).
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    let mut s: i32 = 0;
                    for k in 0..num_protos {
                        s += protos[base + k] as i32;
                    }
                    s
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                // Planar layout: accumulate one channel plane at a time.
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime dispatch between the sdot and baseline NEON kernels.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    let stride_y = proto_w * num_protos;

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // This detection's coefficient vector.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            // Map the box into un-letterboxed normalized coordinates.
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            // Output-resolution pixel rectangle of the tile (>= 1x1).
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Map an output pixel center to fractional prototype
            // coordinates through the letterbox transform.
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            // Prototype-space ROI covering every sample point, clamped
            // to the tensor bounds (exclusive upper edges).
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Constant part of the zero-point correction for this
            // detection: zp_p * sum(c) - n * zp_c * zp_p.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            // Integer logits over the prototype-space ROI.
            let mut logits = vec![0_i32; roi_h * roi_w];
            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    #[cfg(target_arch = "aarch64")]
                    {
                        // NEON kernels, selected by CPU feature.
                        if use_dotprod {
                            compute_logits_dotprod(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        } else {
                            compute_logits_base(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback.
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Planar layout: accumulate one channel plane at a
                    // time across the whole ROI.
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_start = py * proto_w + proto_x0;
                            let out_row_start = ly_idx * roi_w;
                            for lx_idx in 0..roi_w {
                                logits[out_row_start + lx_idx] +=
                                    coeff_c * plane[row_start + lx_idx] as i32;
                            }
                        }
                    }
                    // Second pass: zero-point correction and bias.
                    for ly_idx in 0..roi_h {
                        let py = proto_y0 + ly_idx;
                        for lx_idx in 0..roi_w {
                            let idx = ly_idx * roi_w + lx_idx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                            } else {
                                0
                            };
                            logits[idx] -= correction + bias;
                        }
                    }
                }
            }

            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // Fixed-point bilinear resampling (10 fractional bits); only
            // the sign of the interpolated logit matters.
            const FRAC_BITS: i32 = 10;
            const FRAC_SCALE: i32 = 1 << FRAC_BITS;
            // Horizontal sample positions are identical for every row of
            // the tile, so precompute them once.
            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                let y_frac_inv = FRAC_SCALE - y_frac;
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];

                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let tl = row_lo[x_lo];
                    let tr = row_lo[x_hi];
                    let bl = row_hi[x_lo];
                    let br = row_hi[x_hi];

                    // Fast path: the AND's sign bit survives only when
                    // all four corners are negative, so the blend (with
                    // non-negative weights) is negative — pixel stays 0.
                    if (tl & tr & bl & br) < 0 {
                        continue;
                    }
                    // Fast path: all corners positive — blend positive.
                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
                        out_row[xi] = 255;
                        continue;
                    }

                    // Mixed signs: full fixed-point bilinear blend, done
                    // in i64 to avoid overflow of logit * frac products.
                    let x_frac_inv = FRAC_SCALE - x_frac;
                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
                    out_row[xi] = if logit > 0 { 255 } else { 0 };
                }
            }

            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
1923
1924#[allow(clippy::too_many_arguments)]
1925fn scaled_run<P: Copy + Sync>(
1926 detect: &[crate::DetectBox],
1927 coeff_all: &[f32],
1928 protos: &[P],
1929 proto_h: usize,
1930 proto_w: usize,
1931 num_protos: usize,
1932 letterbox: Option<[f32; 4]>,
1933 width: u32,
1934 height: u32,
1935 acc_scale: f32,
1936 load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
1937) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1938 let (lx0, lw, ly0, lh) = match letterbox {
1939 Some([lx0, ly0, lx1, ly1]) => {
1940 let lw = (lx1 - lx0).max(f32::EPSILON);
1941 let lh = (ly1 - ly0).max(f32::EPSILON);
1942 (lx0, lw, ly0, lh)
1943 }
1944 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
1945 };
1946 let out_w = width as usize;
1947 let out_h = height as usize;
1948 let stride_y = proto_w * num_protos;
1949
1950 detect
1972 .par_iter()
1973 .enumerate()
1974 .map(|(i, det)| {
1975 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1976 let bbox = det.bbox.to_canonical();
1977 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
1978 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
1979 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
1980 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
1981 let px0 = (xmin * out_w as f32).round() as usize;
1982 let py0 = (ymin * out_h as f32).round() as usize;
1983 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
1984 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
1985 let bbox_w = px1.saturating_sub(px0).max(1);
1986 let bbox_h = py1.saturating_sub(py0).max(1);
1987
1988 let sample_x_at = |px: f32| -> f32 {
1993 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
1994 model_x_norm * proto_w as f32 - 0.5
1995 };
1996 let sample_y_at = |py: f32| -> f32 {
1997 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
1998 model_y_norm * proto_h as f32 - 0.5
1999 };
2000 let s_x_min = sample_x_at(px0 as f32);
2001 let s_x_max = sample_x_at((px1 as f32) - 1.0);
2002 let s_y_min = sample_y_at(py0 as f32);
2003 let s_y_max = sample_y_at((py1 as f32) - 1.0);
2004 let proto_x0 = (s_x_min.floor() as isize)
2008 .max(0)
2009 .min(proto_w.saturating_sub(1) as isize) as usize;
2010 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2011 let proto_y0 = (s_y_min.floor() as isize)
2012 .max(0)
2013 .min(proto_h.saturating_sub(1) as isize) as usize;
2014 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2015 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2016 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2017
2018 if !acc_scale.is_finite() || acc_scale <= 0.0 {
2027 return Err(crate::Error::NotSupported(format!(
2028 "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
2029 )));
2030 }
2031 let _ = acc_scale; let mut logits = vec![0.0_f32; roi_h * roi_w];
2033 for ly_idx in 0..roi_h {
2034 let py = proto_y0 + ly_idx;
2035 let row_base = py * stride_y + proto_x0 * num_protos;
2036 for lx_idx in 0..roi_w {
2037 let pix_base = row_base + lx_idx * num_protos;
2038 let mut acc = 0.0_f32;
2039 let mut k = 0;
2041 let chunks = num_protos / 4;
2042 for _ in 0..chunks {
2043 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
2044 + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
2045 + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
2046 + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
2047 k += 4;
2048 }
2049 while k < num_protos {
2050 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
2051 k += 1;
2052 }
2053 logits[ly_idx * roi_w + lx_idx] = acc;
2054 }
2055 }
2056
2057 let roi_last_x = roi_w.saturating_sub(1);
2068 let roi_last_y = roi_h.saturating_sub(1);
2069
2070 let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
2072 .map(|xi| {
2073 let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2074 let x_floor = sample_x.floor();
2075 let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
2076 let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
2077 let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
2078 (x_lo, x_hi, x_frac)
2079 })
2080 .collect();
2081
2082 let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2085 for yi in 0..bbox_h {
2086 let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2087 let y_floor = sample_y.floor();
2088 let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2089 let y_hi = (y_lo + 1).min(roi_h - 1);
2090 let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
2091 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2092 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2093 let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2094 for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2095 let (xl, xh) = (x_lo as usize, x_hi as usize);
2096 let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
2097 let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
2098 let logit = l0 + (l1 - l0) * y_frac;
2099 out_row[xi] = if logit > 0.0 { 255 } else { 0 };
2100 }
2101 }
2102 let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2104 .expect("tile_buf length matches bbox_h * bbox_w");
2105 Ok(edgefirst_decoder::Segmentation {
2106 xmin,
2107 ymin,
2108 xmax,
2109 ymax,
2110 segmentation: tile,
2111 })
2112 })
2113 .collect()
2114}