1use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
11 #[allow(clippy::too_many_arguments)]
12 pub(super) fn render_modelpack_segmentation(
13 &mut self,
14 dst_w: usize,
15 dst_h: usize,
16 dst_rs: usize,
17 dst_c: usize,
18 dst_slice: &mut [u8],
19 segmentation: &Segmentation,
20 opacity: f32,
21 ) -> Result<()> {
22 use ndarray_stats::QuantileExt;
23
24 let seg = &segmentation.segmentation;
25 let [seg_height, seg_width, seg_classes] = *seg.shape() else {
26 unreachable!("Array3 did not have [usize; 3] as shape");
27 };
28 let start_y = (dst_h as f32 * segmentation.ymin).round();
29 let end_y = (dst_h as f32 * segmentation.ymax).round();
30 let start_x = (dst_w as f32 * segmentation.xmin).round();
31 let end_x = (dst_w as f32 * segmentation.xmax).round();
32
33 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
34 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
35
36 let start_x_u = (start_x as usize).min(dst_w);
37 let start_y_u = (start_y as usize).min(dst_h);
38 let end_x_u = (end_x as usize).min(dst_w);
39 let end_y_u = (end_y as usize).min(dst_h);
40
41 let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
42 let get_value_at_nearest = |x: f32, y: f32| -> usize {
43 let x = x.round() as usize;
44 let y = y.round() as usize;
45 argmax
46 .get([y.min(seg_height - 1), x.min(seg_width - 1)])
47 .copied()
48 .unwrap_or(0)
49 };
50
51 for y in start_y_u..end_y_u {
52 for x in start_x_u..end_x_u {
53 let seg_x = (x as f32 - start_x) * scale_x;
54 let seg_y = (y as f32 - start_y) * scale_y;
55 let label = get_value_at_nearest(seg_x, seg_y);
56
57 if label == seg_classes - 1 {
58 continue;
59 }
60
61 let color = self.colors[label % self.colors.len()];
62
63 let alpha = if opacity == 1.0 {
64 color[3] as u16
65 } else {
66 (color[3] as f32 * opacity).round() as u16
67 };
68
69 let dst_index = (y * dst_rs) + (x * dst_c);
70 for c in 0..3 {
71 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
72 + dst_slice[dst_index + c] as u16 * (255 - alpha))
73 / 255) as u8;
74 }
75 }
76 }
77
78 Ok(())
79 }
80
81 #[allow(clippy::too_many_arguments)]
82 pub(super) fn render_yolo_segmentation(
83 &mut self,
84 dst_w: usize,
85 dst_h: usize,
86 dst_rs: usize,
87 dst_c: usize,
88 dst_slice: &mut [u8],
89 segmentation: &Segmentation,
90 class: usize,
91 opacity: f32,
92 ) -> Result<()> {
93 let seg = &segmentation.segmentation;
94 let [seg_height, seg_width, classes] = *seg.shape() else {
95 unreachable!("Array3 did not have [usize;3] as shape");
96 };
97 debug_assert_eq!(classes, 1);
98
99 let start_y = (dst_h as f32 * segmentation.ymin).round();
100 let end_y = (dst_h as f32 * segmentation.ymax).round();
101 let start_x = (dst_w as f32 * segmentation.xmin).round();
102 let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107 let start_x_u = (start_x as usize).min(dst_w);
108 let start_y_u = (start_y as usize).min(dst_h);
109 let end_x_u = (end_x as usize).min(dst_w);
110 let end_y_u = (end_y as usize).min(dst_h);
111
112 for y in start_y_u..end_y_u {
113 for x in start_x_u..end_x_u {
114 let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115 let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116 let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118 if val < 127 {
119 continue;
120 }
121
122 let color = self.colors[class % self.colors.len()];
123
124 let alpha = if opacity == 1.0 {
125 color[3] as u16
126 } else {
127 (color[3] as f32 * opacity).round() as u16
128 };
129
130 let dst_index = (y * dst_rs) + (x * dst_c);
131 for c in 0..3 {
132 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133 + dst_slice[dst_index + c] as u16 * (255 - alpha))
134 / 255) as u8;
135 }
136 }
137 }
138
139 Ok(())
140 }
141
    /// Draws a 3-pixel-thick rectangular outline for every detection box
    /// directly into the destination pixel buffer.
    ///
    /// * `dst_w`/`dst_h` — destination dimensions in pixels.
    /// * `dst_rs` — destination row stride in bytes.
    /// * `dst_c` — bytes per destination pixel (first 3 channels written).
    /// * `detect` — detections with normalized bounding boxes.
    /// * `color_mode` — maps (detection index, label) to a palette index.
    #[allow(clippy::too_many_arguments)]
    pub(super) fn render_box(
        &mut self,
        dst_w: usize,
        dst_h: usize,
        dst_rs: usize,
        dst_c: usize,
        dst_slice: &mut [u8],
        detect: &[DetectBox],
        color_mode: crate::ColorMode,
    ) -> Result<()> {
        const LINE_THICKNESS: usize = 3;

        for (idx, d) in detect.iter().enumerate() {
            use edgefirst_decoder::BoundingBox;

            let color_index = color_mode.index(idx, d.label);
            let [r, g, b, _] = self.colors[color_index % self.colors.len()];
            // Clamp to [0, 1] so pixel coordinates below stay in-image.
            let bbox = d.bbox.to_canonical();
            let bbox = BoundingBox {
                xmin: bbox.xmin.clamp(0.0, 1.0),
                ymin: bbox.ymin.clamp(0.0, 1.0),
                xmax: bbox.xmax.clamp(0.0, 1.0),
                ymax: bbox.ymax.clamp(0.0, 1.0),
            };
            // Inner edge of the outline: [xmin, ymin, xmax, ymax] in pixels.
            // The -0.5/+0.5 nudges expand the box outward by half a pixel
            // before rounding (negative results saturate to 0 via `as usize`).
            let inner = [
                ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
                ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
                ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
            ];

            // Outer edge: inner edge pushed out by the line thickness,
            // clamped to the image bounds.
            let outer = [
                inner[0].saturating_sub(LINE_THICKNESS),
                inner[1].saturating_sub(LINE_THICKNESS),
                (inner[2] + LINE_THICKNESS).min(dst_w),
                (inner[3] + LINE_THICKNESS).min(dst_h),
            ];

            // Top band: rows (outer_top, inner_top], full outline width.
            for y in outer[1] + 1..=inner[1] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // Side bands: rows [inner_top, inner_bottom), left and right
            // vertical strips only.
            for y in inner[1]..inner[3] {
                for x in outer[0] + 1..=inner[0] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }

                for x in inner[2]..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }

            // Bottom band: rows [inner_bottom, outer_bottom), full width.
            for y in inner[3]..outer[3] {
                for x in outer[0] + 1..outer[2] {
                    let index = (y * dst_rs) + (x * dst_c);
                    dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
                }
            }
        }
        Ok(())
    }
212
    /// Materializes one binary (0/255) mask per detection at prototype
    /// resolution, cropped to each detection's bounding box.
    ///
    /// Combines the prototype tensor (`proto_data.protos`, rank-3) with the
    /// per-detection coefficient matrix (`proto_data.mask_coefficients`,
    /// shape `[N, num_protos]`) and thresholds the resulting logits at zero.
    ///
    /// Dispatch order:
    /// 1. I8 coeffs + I8 protos  -> integer fast path (`proto_segmentations_i8_i8`)
    /// 2. I16 coeffs + I8 protos -> integer fast path (`proto_segmentations_i16_i8`)
    /// 3. otherwise              -> dequantize coefficients to f32 and use the
    ///    f32/f16/i8 proto paths below.
    /// A fast path that returns `NotSupported` falls through to the next option.
    ///
    /// `letterbox` is `[x0, y0, x1, y1]` in normalized coordinates; mask
    /// bounds are mapped back through its inverse. NOTE(review): assumed to
    /// describe the letterboxed content region — confirm against caller.
    ///
    /// # Errors
    /// `InvalidShape` for rank/shape/dtype problems or missing quantization
    /// metadata; `Internal` when coefficient rows != `detect.len()`;
    /// `NotSupported` for unsupported layout/quantization combinations.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "proto",
            n_detections = detect.len(),
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Normalize (H, W, C) regardless of the stored layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Precompute the inverse letterbox transform; degenerate (zero-size)
        // letterbox dimensions fall back to an identity scale.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Fast path 1: fully-quantized I8 coefficients and I8 protos.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match proto_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                lx0,
                inv_lw,
                ly0,
                inv_lh,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                // NotSupported: fall through to the generic path below.
                Err(crate::Error::NotSupported(_)) => {}
                Err(e) => return Err(e),
            }
        }

        // Fast path 2: I16 coefficients with I8 protos (requires coefficient
        // quantization metadata; otherwise falls through).
        if proto_data.mask_coefficients.dtype() == DType::I16
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i16()
                .expect("I16 coefficients");
            let coeff_m = coeff_t.map()?;
            if let Some(coeff_quant) = coeff_t.quantization() {
                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
                let proto_m = proto_t.map()?;
                let proto_quant = proto_t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                match proto_segmentations_i16_i8(
                    detect,
                    coeff_m.as_slice(),
                    coeff_quant,
                    proto_m.as_slice(),
                    proto_quant,
                    proto_h,
                    proto_w,
                    num_protos,
                    lx0,
                    inv_lw,
                    ly0,
                    inv_lh,
                    proto_data.layout,
                ) {
                    Ok(result) => return Ok(result),
                    // NotSupported: fall through to the generic path below.
                    Err(crate::Error::NotSupported(_)) => {}
                    Err(e) => return Err(e),
                }
            }
        }

        // Generic path: only I8 protos can be transposed out of NCHW below.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
            && proto_data.protos.dtype() != DType::I8
        {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
                    .into(),
            ));
        }
        // Dequantize/convert coefficients to f32 once, whatever their dtype.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            DType::I8 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i8()
                    .expect("dtype matched I8");
                let m = t.map()?;
                // With quantization metadata: (v - zero_point) * scale;
                // without: interpret raw integer values directly as f32.
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I8 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            DType::I16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i16()
                    .expect("dtype matched I16");
                let m = t.map()?;
                // Same dequantization rule as the I8 arm above.
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I16 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, I8, or I16"
                )));
            }
        };

        // One rayon task per detection; each computes its ROI mask and maps
        // it back to normalized image coordinates via seg_from_roi.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let src_slice = m.as_slice();
                // NCHW protos are transposed once into NHWC so the per-pixel
                // dot product reads contiguous channels.
                let transposed_storage =
                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
                        let hw = proto_h * proto_w;
                        let mut nhwc = vec![0i8; hw * num_protos];
                        for c in 0..num_protos {
                            let plane = &src_slice[c * hw..(c + 1) * hw];
                            for px in 0..hw {
                                nhwc[px * num_protos + c] = plane[px];
                            }
                        }
                        Some(nhwc)
                    } else {
                        None
                    };
                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sign_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
583
    /// Like `materialize_segmentations`, but resamples each detection's mask
    /// to a fixed `width` x `height` output instead of proto resolution.
    ///
    /// Dispatch order mirrors the proto-resolution variant: I8/I8 then I16/I8
    /// integer fast paths (falling through on `NotSupported`), then an f32
    /// coefficient conversion feeding the f32/f16/i8 proto kernels.
    ///
    /// `letterbox` is `[x0, y0, x1, y1]` in normalized coordinates and is
    /// forwarded to the scaled kernels. NOTE(review): assumed to describe the
    /// letterboxed content region — confirm against caller.
    ///
    /// # Errors
    /// `InvalidShape` for zero output dimensions, rank/shape/dtype problems,
    /// or missing quantization metadata; `Internal` when coefficient rows do
    /// not match `detect.len()`; `NotSupported` for unsupported layout or
    /// quantization combinations.
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        let _span = tracing::trace_span!(
            "materialize_masks",
            mode = "scaled",
            n_detections = detect.len(),
            width,
            height,
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Normalize (H, W, C) regardless of the stored layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Fast path 1: fully-quantized I8 coefficients and I8 protos.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match scaled_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                letterbox,
                width,
                height,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                // NotSupported: fall through to the generic path below.
                Err(crate::Error::NotSupported(_)) => {}
                Err(e) => return Err(e),
            }
        }

        // Fast path 2: I16 coefficients with I8 protos (requires coefficient
        // quantization metadata; otherwise falls through).
        if proto_data.mask_coefficients.dtype() == DType::I16
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i16()
                .expect("I16 coefficients");
            let coeff_m = coeff_t.map()?;
            if let Some(coeff_quant) = coeff_t.quantization() {
                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
                let proto_m = proto_t.map()?;
                let proto_quant = proto_t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                match scaled_segmentations_i16_i8(
                    detect,
                    coeff_m.as_slice(),
                    coeff_quant,
                    proto_m.as_slice(),
                    proto_quant,
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                    proto_data.layout,
                ) {
                    Ok(result) => return Ok(result),
                    // NotSupported: fall through to the generic path below.
                    Err(crate::Error::NotSupported(_)) => {}
                    Err(e) => return Err(e),
                }
            }
        }

        // Generic path: only I8 protos can be transposed out of NCHW below.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
            && proto_data.protos.dtype() != DType::I8
        {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
                    .into(),
            ));
        }
        // Dequantize/convert coefficients to f32 once, whatever their dtype.
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            DType::I8 => {
                // Unlike the I16 arm, I8 coefficients here REQUIRE
                // quantization metadata.
                let t = proto_data.mask_coefficients.as_i8().expect("I8");
                let m = t.map()?;
                let q = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape(
                        "I8 mask_coefficients require quantization metadata".into(),
                    )
                })?;
                use edgefirst_tensor::QuantMode;
                let (scale, zp) = match q.mode() {
                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                    _ => {
                        return Err(crate::Error::NotSupported(
                            "per-channel mask_coefficients not supported".into(),
                        ))
                    }
                };
                m.as_slice()
                    .iter()
                    .map(|&v| (v as f32 - zp) * scale)
                    .collect()
            }
            DType::I16 => {
                let t = proto_data.mask_coefficients.as_i16().expect("I16");
                let m = t.map()?;
                // With quantization metadata: (v - zero_point) * scale;
                // without: interpret raw integer values directly as f32.
                if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I16 mask_coefficients quantization mode {other:?} not supported"
                            )))
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                }
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        // Hand off to the dtype-specific scaled kernel.
        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let src_slice = m.as_slice();
                // NCHW protos are transposed once into NHWC so the per-pixel
                // dot product reads contiguous channels.
                let transposed_storage =
                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
                        let hw = proto_h * proto_w;
                        let mut nhwc = vec![0i8; hw * num_protos];
                        for c in 0..num_protos {
                            let plane = &src_slice[c * hw..(c + 1) * hw];
                            for px in 0..hw {
                                nhwc[px * num_protos + c] = plane[px];
                            }
                        }
                        Some(nhwc)
                    } else {
                        None
                    };
                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    protos_slice,
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
884}
885
886fn bbox_to_proto_roi(
904 det: &DetectBox,
905 proto_w: usize,
906 proto_h: usize,
907) -> (usize, usize, usize, usize, usize, usize) {
908 let bbox = det.bbox.to_canonical();
909 let xmin = bbox.xmin.clamp(0.0, 1.0);
910 let ymin = bbox.ymin.clamp(0.0, 1.0);
911 let xmax = bbox.xmax.clamp(0.0, 1.0);
912 let ymax = bbox.ymax.clamp(0.0, 1.0);
913 let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
914 let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
915 let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
916 let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
917 let roi_w = x1.saturating_sub(x0).max(1);
918 let roi_h = y1.saturating_sub(y0).max(1);
919 (x0, y0, x1, y1, roi_w, roi_h)
920}
921
922#[allow(clippy::too_many_arguments)]
926fn seg_from_roi(
927 mask: ndarray::Array3<u8>,
928 x0: usize,
929 y0: usize,
930 x1: usize,
931 y1: usize,
932 proto_w: usize,
933 proto_h: usize,
934 lx0: f32,
935 inv_lw: f32,
936 ly0: f32,
937 inv_lh: f32,
938) -> edgefirst_decoder::Segmentation {
939 let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
940 let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
941 let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
942 let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
943 edgefirst_decoder::Segmentation {
944 xmin: seg_xmin.clamp(0.0, 1.0),
945 ymin: seg_ymin.clamp(0.0, 1.0),
946 xmax: seg_xmax.clamp(0.0, 1.0),
947 ymax: seg_ymax.clamp(0.0, 1.0),
948 segmentation: mask,
949 }
950}
951
/// Integer fast path: I8 coefficients x I8 protos, per-tensor quantization.
///
/// For each detection, computes the sign of the dequantized mask logit
/// without ever converting to float. With per-tensor zero points `zp_c`
/// (coefficients) and `zp_p` (protos), the dequantized dot product is
/// proportional to:
///   sum(c*p) - zp_c*sum(p) - zp_p*sum(c) + N*zp_c*zp_p
/// which is evaluated as `raw_dot - correction - bias` below, where
/// `correction = zp_c * sum(p)` (precomputed per pixel when `zp_c != 0`) and
/// `bias = zp_p * sum(c) - N * zp_c * zp_p` (per detection). The sign of
/// this integer matches the sign of the real logit (quantization scales are
/// assumed positive — TODO confirm that invariant in edgefirst_tensor).
///
/// Pixels with positive logit are set to 255 in the output mask; the mask is
/// cropped to the detection's bbox ROI on the proto grid.
///
/// # Errors
/// `NotSupported` if either tensor uses per-channel quantization (callers
/// treat this as "fall back to the generic path").
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Only per-tensor zero points are supported on this path.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i8 path".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Per-pixel sum of proto values across channels, needed only when the
    // coefficient zero point is non-zero (see the `correction` term below).
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime-detect the SDOT extension once for all detections.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            // Per-detection constant part of the expanded dot product.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    // Channels are contiguous per pixel: one dot product per
                    // pixel over `num_protos` lanes.
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    let raw_dot = unsafe {
                                        dot_i8_neon_dotprod(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        } else {
                            // Same loop, baseline NEON kernel (no SDOT).
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    let raw_dot = unsafe {
                                        dot_i8_neon_base(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar fallback for non-aarch64 targets.
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Planar layout: accumulate one channel plane at a time
                    // into an i32 ROI buffer, then threshold in a second pass.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1182
1183#[allow(clippy::too_many_arguments)]
1185fn proto_segmentations_i16_i8(
1186 detect: &[crate::DetectBox],
1187 coeff_all: &[i16],
1188 coeff_quant: &edgefirst_tensor::Quantization,
1189 protos: &[i8],
1190 proto_quant: &edgefirst_tensor::Quantization,
1191 proto_h: usize,
1192 proto_w: usize,
1193 num_protos: usize,
1194 lx0: f32,
1195 inv_lw: f32,
1196 ly0: f32,
1197 inv_lh: f32,
1198 layout: edgefirst_decoder::ProtoLayout,
1199) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1200 use edgefirst_tensor::QuantMode;
1201
1202 let _span = tracing::trace_span!(
1203 "mask_i16_i8_fastpath",
1204 n = detect.len(),
1205 proto_h,
1206 proto_w,
1207 num_protos,
1208 ?layout,
1209 )
1210 .entered();
1211
1212 let zp_c: i32 = match coeff_quant.mode() {
1213 QuantMode::PerTensor { zero_point, .. } => zero_point,
1214 QuantMode::PerTensorSymmetric { .. } => 0,
1215 _ => {
1216 return Err(crate::Error::NotSupported(
1217 "per-channel coeff quantization not supported on proto-res i16 path".into(),
1218 ))
1219 }
1220 };
1221 let zp_p: i32 = match proto_quant.mode() {
1222 QuantMode::PerTensor { zero_point, .. } => zero_point,
1223 QuantMode::PerTensorSymmetric { .. } => 0,
1224 _ => {
1225 return Err(crate::Error::NotSupported(
1226 "per-channel proto quantization not supported on proto-res i8 path".into(),
1227 ))
1228 }
1229 };
1230
1231 let hw = proto_h * proto_w;
1232
1233 let proto_sums: Vec<i32> = if zp_c != 0 {
1235 match layout {
1236 edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
1237 .map(|px_idx| {
1238 let base = px_idx * num_protos;
1239 protos[base..base + num_protos]
1240 .iter()
1241 .map(|&v| v as i32)
1242 .sum()
1243 })
1244 .collect(),
1245 edgefirst_decoder::ProtoLayout::Nchw => {
1246 let mut sums = vec![0i32; hw];
1247 for c in 0..num_protos {
1248 let plane = &protos[c * hw..];
1249 for (px, s) in sums.iter_mut().enumerate() {
1250 *s += plane[px] as i32;
1251 }
1252 }
1253 sums
1254 }
1255 }
1256 } else {
1257 Vec::new()
1258 };
1259
1260 detect
1261 .par_iter()
1262 .enumerate()
1263 .map(|(i, det)| {
1264 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1265 let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);
1266
1267 let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
1269 let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
1270
1271 let mut mask_buf = vec![0u8; roi_h * roi_w];
1272
1273 match layout {
1274 edgefirst_decoder::ProtoLayout::Nhwc => {
1275 let stride_y = proto_w * num_protos;
1276 #[cfg(target_arch = "aarch64")]
1277 {
1278 for ly in 0..roi_h {
1279 let py = y0 + ly;
1280 let row_base = py * stride_y + x0 * num_protos;
1281 for lx in 0..roi_w {
1282 let pix_base = row_base + lx * num_protos;
1283 let proto_px = &protos[pix_base..pix_base + num_protos];
1284 let raw_dot = unsafe {
1285 dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
1286 };
1287 let correction = if zp_c != 0 {
1288 zp_c * proto_sums[py * proto_w + x0 + lx]
1289 } else {
1290 0
1291 };
1292 let logit = raw_dot - correction - bias;
1293 if logit > 0 {
1294 mask_buf[ly * roi_w + lx] = 255;
1295 }
1296 }
1297 }
1298 }
1299 #[cfg(not(target_arch = "aarch64"))]
1300 {
1301 for ly in 0..roi_h {
1302 let py = y0 + ly;
1303 let row_base = py * stride_y + x0 * num_protos;
1304 for lx in 0..roi_w {
1305 let pix_base = row_base + lx * num_protos;
1306 let proto_px = &protos[pix_base..pix_base + num_protos];
1307 let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
1308 let correction = if zp_c != 0 {
1309 zp_c * proto_sums[py * proto_w + x0 + lx]
1310 } else {
1311 0
1312 };
1313 let logit = raw_dot - correction - bias;
1314 if logit > 0 {
1315 mask_buf[ly * roi_w + lx] = 255;
1316 }
1317 }
1318 }
1319 }
1320 }
1321 edgefirst_decoder::ProtoLayout::Nchw => {
1322 let mut accum = vec![0i32; roi_h * roi_w];
1326 for c in 0..num_protos {
1327 let plane = &protos[c * hw..];
1328 let coeff_c = coeff[c] as i32;
1329 for ly in 0..roi_h {
1330 let py = y0 + ly;
1331 let row_start = py * proto_w + x0;
1332 let out_row_start = ly * roi_w;
1333 for lx in 0..roi_w {
1334 accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
1335 }
1336 }
1337 }
1338 for ly in 0..roi_h {
1340 let py = y0 + ly;
1341 for lx in 0..roi_w {
1342 let idx = ly * roi_w + lx;
1343 let correction = if zp_c != 0 {
1344 zp_c * proto_sums[py * proto_w + x0 + lx]
1345 } else {
1346 0
1347 };
1348 let logit = accum[idx] - correction - bias;
1349 if logit > 0 {
1350 mask_buf[idx] = 255;
1351 }
1352 }
1353 }
1354 }
1355 }
1356
1357 let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1358 .expect("mask_buf length matches roi_h * roi_w");
1359 Ok(seg_from_roi(
1360 mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
1361 ))
1362 })
1363 .collect()
1364}
1365
1366#[allow(clippy::too_many_arguments)]
1377fn fused_dot_sign_f32_slice(
1378 protos: &[f32],
1379 coeff: &[f32],
1380 _proto_h: usize,
1381 proto_w: usize,
1382 y0: usize,
1383 x0: usize,
1384 roi_h: usize,
1385 roi_w: usize,
1386 num_protos: usize,
1387) -> ndarray::Array3<u8> {
1388 let stride_y = proto_w * num_protos;
1389 let mut mask_buf = vec![0u8; roi_h * roi_w];
1390 for y in 0..roi_h {
1391 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1392 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1393 for (x, out_px) in out_row.iter_mut().enumerate() {
1394 let base = row_base + x * num_protos;
1395 let mut acc = 0.0_f32;
1396 let mut k = 0;
1397 let chunks = num_protos / 4;
1398 for _ in 0..chunks {
1399 acc += coeff[k] * protos[base + k]
1400 + coeff[k + 1] * protos[base + k + 1]
1401 + coeff[k + 2] * protos[base + k + 2]
1402 + coeff[k + 3] * protos[base + k + 3];
1403 k += 4;
1404 }
1405 while k < num_protos {
1406 acc += coeff[k] * protos[base + k];
1407 k += 1;
1408 }
1409 if acc > 0.0 {
1410 *out_px = 255;
1411 }
1412 }
1413 }
1414 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1415 .expect("mask_buf length matches roi_h * roi_w")
1416}
1417
#[allow(clippy::too_many_arguments)]
/// Computes `sign(coeff · proto_pixel)` over an ROI of an f16 prototype
/// tensor, writing 255 where the dot product is positive.
///
/// Compile-time dispatch: x86_64 builds compiled with the `f16c` and `fma`
/// target features use the vectorized kernel; every other target falls
/// back to the scalar kernel.
fn fused_dot_sign_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: this arm only exists when the binary itself is compiled
        // with f16c+fma as target features, so the callee's CPU-feature
        // requirement is satisfied unconditionally.
        unsafe {
            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
1458
1459#[allow(clippy::too_many_arguments)]
1461fn fused_dot_sign_f16_slice_scalar(
1462 protos: &[half::f16],
1463 coeff: &[f32],
1464 proto_w: usize,
1465 y0: usize,
1466 x0: usize,
1467 roi_h: usize,
1468 roi_w: usize,
1469 num_protos: usize,
1470) -> ndarray::Array3<u8> {
1471 let stride_y = proto_w * num_protos;
1472 let mut mask_buf = vec![0u8; roi_h * roi_w];
1473 for y in 0..roi_h {
1474 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1475 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1476 for (x, out_px) in out_row.iter_mut().enumerate() {
1477 let base = row_base + x * num_protos;
1478 let mut acc = 0.0_f32;
1479 let mut k = 0;
1480 let chunks = num_protos / 4;
1481 for _ in 0..chunks {
1482 acc += coeff[k] * protos[base + k].to_f32()
1483 + coeff[k + 1] * protos[base + k + 1].to_f32()
1484 + coeff[k + 2] * protos[base + k + 2].to_f32()
1485 + coeff[k + 3] * protos[base + k + 3].to_f32();
1486 k += 4;
1487 }
1488 while k < num_protos {
1489 acc += coeff[k] * protos[base + k].to_f32();
1490 k += 1;
1491 }
1492 if acc > 0.0 {
1493 *out_px = 255;
1494 }
1495 }
1496 }
1497 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1498 .expect("mask_buf length matches roi_h * roi_w")
1499}
1500
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
/// AVX/F16C kernel behind `fused_dot_sign_f16_slice`: for each ROI pixel,
/// widens the f16 prototype vector 8 lanes at a time (`_mm256_cvtph_ps`)
/// and accumulates the dot product against `coeff` with fused
/// multiply-add, writing 255 where the result is strictly positive.
///
/// # Safety
/// The executing CPU must support f16c, fma and avx (guaranteed here by
/// the surrounding `cfg`, which requires them as compile-time target
/// features). The unaligned 16-byte loads bypass slice bounds checks, so
/// `protos` must contain the complete `num_protos`-long vector for every
/// ROI pixel addressed by `y0/x0/roi_h/roi_w`, and `coeff` must hold at
/// least `num_protos` entries.
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // Load 8 f16 values (16 bytes, unaligned) and widen them
                // to 8 f32 lanes, then FMA against the matching coeffs.
                let p_ptr = protos
                    .as_ptr()
                    .add(base + k)
                    .cast::<core::arch::x86_64::__m128i>();
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduction of the 8 accumulator lanes: add the
            // two 128-bit halves, then two hadd steps collapse 4 -> 2 -> 1
            // and lane 0 is extracted.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail covers the last num_protos % 8 elements.
            while k < num_protos {
                acc += coeff[k] * protos[base + k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1577
#[allow(clippy::too_many_arguments)]
/// Fused dequantize + dot + sign kernel for i8 NHWC prototypes: writes 255
/// where a pixel's dequantized mask logit is positive, 0 elsewhere.
///
/// Rather than dequantizing the prototype tensor, the quantization scale
/// is folded into the coefficients up front
/// (`scaled_coeff[k] = coeff[k] * scale[k]`) and the zero-point
/// contribution is collapsed into the constant `zp_offset`, so the final
/// test `acc > zp_offset` is equivalent to
/// `sum_k scaled_coeff[k] * (proto[k] - zp[k]) > 0` with no per-element
/// dequantization work.
///
/// # Errors
/// Per-channel quantization is only accepted on the channel-last axis
/// (axis 2, matching the NHWC indexing below); any other axis returns
/// `Error::NotSupported`.
fn fused_dequant_dot_sign_i8_slice(
    protos: &[i8],
    coeff: &[f32],
    quant: &edgefirst_tensor::Quantization,
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> crate::Result<ndarray::Array3<u8>> {
    use edgefirst_tensor::QuantMode;
    let stride_y = proto_w * num_protos;

    // Scratch for the scale-folded coefficients: kept on the stack when
    // num_protos fits in 64 slots, spilled to a Vec otherwise.
    let mut stack_scratch = [0.0_f32; 64];
    let mut heap_scratch: Vec<f32>;
    let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
        &mut stack_scratch[..num_protos]
    } else {
        heap_scratch = vec![0.0_f32; num_protos];
        heap_scratch.as_mut_slice()
    };
    // Constant zero-point term; the threshold comparison below subtracts
    // it implicitly.
    let zp_offset: f32;
    match quant.mode() {
        QuantMode::PerTensorSymmetric { scale } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            zp_offset = 0.0;
        }
        QuantMode::PerTensor { scale, zero_point } => {
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scale;
            }
            // Single shared zero point: zp * sum(scaled_coeff).
            zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
        }
        QuantMode::PerChannelSymmetric { scales, axis } => {
            // Only channel-last layouts line up with scaled_coeff[k]
            // indexing per prototype channel.
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            zp_offset = 0.0;
        }
        QuantMode::PerChannel {
            scales,
            zero_points,
            axis,
        } => {
            if axis != 2 {
                return Err(crate::Error::NotSupported(format!(
                    "per-channel quantization on axis {axis} not supported \
                     (only channel axis 2 is implemented on this kernel)"
                )));
            }
            for k in 0..num_protos {
                scaled_coeff[k] = coeff[k] * scales[k];
            }
            // Per-channel zero points: sum_k scaled_coeff[k] * zp[k].
            zp_offset = (0..num_protos)
                .map(|k| scaled_coeff[k] * zero_points[k] as f32)
                .sum();
        }
    }

    let mut mask_buf = vec![0u8; roi_h * roi_w];
    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + (x0) * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            let mut acc = 0.0_f32;
            let mut k = 0;
            // 4-way unrolled dot of the scale-folded coefficients with
            // this pixel's raw i8 prototype vector.
            let chunks = num_protos / 4;
            for _ in 0..chunks {
                let p0 = protos[base + k] as f32;
                let p1 = protos[base + k + 1] as f32;
                let p2 = protos[base + k + 2] as f32;
                let p3 = protos[base + k + 3] as f32;
                acc += scaled_coeff[k] * p0
                    + scaled_coeff[k + 1] * p1
                    + scaled_coeff[k + 2] * p2
                    + scaled_coeff[k + 3] * p3;
                k += 4;
            }
            while k < num_protos {
                acc += scaled_coeff[k] * protos[base + k] as f32;
                k += 1;
            }
            // acc > zp_offset  <=>  sum(scaled_coeff[k]*(p[k]-zp[k])) > 0.
            if acc > zp_offset {
                *out_px = 255;
            }
        }
    }
    Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w"))
}
1684
1685#[allow(clippy::too_many_arguments)]
1686fn scaled_segmentations_f32_slice(
1687 detect: &[crate::DetectBox],
1688 coeff_all: &[f32],
1689 protos: &[f32],
1690 proto_h: usize,
1691 proto_w: usize,
1692 num_protos: usize,
1693 letterbox: Option<[f32; 4]>,
1694 width: u32,
1695 height: u32,
1696) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1697 scaled_run(
1698 detect,
1699 coeff_all,
1700 protos,
1701 proto_h,
1702 proto_w,
1703 num_protos,
1704 letterbox,
1705 width,
1706 height,
1707 1.0,
1708 |p, _| *p,
1709 )
1710}
1711
1712#[allow(clippy::too_many_arguments)]
1713fn scaled_segmentations_f16_slice(
1714 detect: &[crate::DetectBox],
1715 coeff_all: &[f32],
1716 protos: &[half::f16],
1717 proto_h: usize,
1718 proto_w: usize,
1719 num_protos: usize,
1720 letterbox: Option<[f32; 4]>,
1721 width: u32,
1722 height: u32,
1723) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1724 scaled_run(
1725 detect,
1726 coeff_all,
1727 protos,
1728 proto_h,
1729 proto_w,
1730 num_protos,
1731 letterbox,
1732 width,
1733 height,
1734 1.0,
1735 |p: &half::f16, _| p.to_f32(),
1736 )
1737}
1738
1739#[allow(clippy::too_many_arguments)]
1740fn scaled_segmentations_i8_slice(
1741 detect: &[crate::DetectBox],
1742 coeff_all: &[f32],
1743 protos: &[i8],
1744 proto_h: usize,
1745 proto_w: usize,
1746 num_protos: usize,
1747 quant: &edgefirst_tensor::Quantization,
1748 letterbox: Option<[f32; 4]>,
1749 width: u32,
1750 height: u32,
1751) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1752 use edgefirst_tensor::QuantMode;
1753 let (scale, zp) = match quant.mode() {
1757 QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1758 QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1759 QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1760 return Err(crate::Error::NotSupported(format!(
1761 "per-channel quantization (axis={axis}) on scaled seg path \
1762 not yet supported"
1763 )));
1764 }
1765 };
1766 scaled_run(
1767 detect,
1768 coeff_all,
1769 protos,
1770 proto_h,
1771 proto_w,
1772 num_protos,
1773 letterbox,
1774 width,
1775 height,
1776 scale,
1777 move |p: &i8, _| *p as f32 - zp,
1778 )
1779}
1780
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
/// Scalar i8 x i8 dot product over the first `n` elements of both slices,
/// widening each operand to i32 before multiplying. Panics (like the
/// indexed original) if either slice is shorter than `n`.
fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
    // Slicing up front hoists the bounds checks out of the loop; integer
    // addition is order-independent, so a plain iterator sum is exact.
    coeff[..n]
        .iter()
        .zip(&proto[..n])
        .map(|(&c, &p)| i32::from(c) * i32::from(p))
        .sum()
}
1815
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// Baseline NEON i8 x i8 dot product over the first `n` elements, for
/// cores without the dotprod extension. Processes 16 lanes per iteration:
/// `vmull_s8` widens the products to i16 and `vpadalq_s16` pairwise
/// add-accumulates them into four i32 lanes; an optional 8-lane step and a
/// scalar loop handle the tail.
///
/// # Safety
/// `coeff` and `proto` must each be valid for reads of at least `n`
/// elements.
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Widening multiplies: low and high 8 lanes each yield i16
        // products.
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        // Pairwise add-accumulate the i16 products into the i32 lanes.
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    // One 8-lane step if at least 8 elements remain.
    let remainder = n - offset;
    if remainder >= 8 {
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Horizontal reduction of the vector accumulator, then a scalar tail
    // for the final n % 8 elements.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1850
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON i8 x i8 dot product using the SDOT instruction, 16 lanes per
/// iteration. SDOT is issued via inline asm (with `.arch_extension
/// dotprod`) rather than the intrinsic, so the crate can be built without
/// the `dotprod` target feature; callers select this kernel only after
/// runtime detection (see `is_aarch64_feature_detected!("dotprod")` at the
/// call site).
///
/// # Safety
/// `coeff` and `proto` must each be valid for reads of at least `n`
/// elements, and the executing CPU must support the dotprod extension.
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        let result: int32x4_t;
        // SDOT: each i32 lane accumulates the dot product of four i8
        // pairs; pure/nomem/nostack lets the compiler schedule it freely.
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    // Horizontal reduction, then scalar tail for the last n % 16 elements.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1887
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
/// Scalar mixed-precision dot product: i16 coefficients times i8
/// prototypes over the first `n` elements, widening both to i32. Panics
/// (like the indexed original) if either slice is shorter than `n`.
fn dot_i16_i8_scalar(coeff: &[i16], proto: &[i8], n: usize) -> i32 {
    // Up-front slicing hoists bounds checks; integer addition is
    // order-independent, so the iterator sum matches the unrolled form.
    coeff[..n]
        .iter()
        .zip(&proto[..n])
        .map(|(&c, &p)| i32::from(c) * i32::from(p))
        .sum()
}
1908
#[cfg(target_arch = "aarch64")]
#[inline(always)]
/// NEON mixed-precision dot product: i16 coefficients times i8 prototypes
/// over the first `n` elements, 8 lanes per iteration. The i8 operand is
/// widened to i16 (`vmovl_s8`) and both halves are multiply-accumulated
/// into i32 lanes with `vmlal_s16`/`vmlal_high_s16`; leftovers (`n % 8`)
/// are handled scalar after the horizontal reduction.
///
/// # Safety
/// `coeff` and `proto` must each be valid for reads of at least `n`
/// elements.
unsafe fn dot_i16_i8_neon(coeff: *const i16, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 8;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s16(coeff.add(offset));
        let p_raw = vld1_s8(proto.add(offset));
        // Widen the 8 i8 prototypes to i16 to match the coefficients.
        let p = vmovl_s8(p_raw);
        acc = vmlal_s16(acc, vget_low_s16(c), vget_low_s16(p));
        acc = vmlal_high_s16(acc, c, p);
        offset += 8;
    }
    // Horizontal add of the four i32 lanes, then finish the tail scalar.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1933
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
/// Fills `logits` for a prototype-space ROI using the SDOT NEON kernel:
/// `logits[ly][lx] = coeff · proto_pixel - zp_c * sum(proto_pixel) - bias`,
/// where the per-pixel correction is looked up from the precomputed
/// `proto_sums` and skipped entirely when the coefficient zero point is 0.
/// Assumes NHWC prototypes with row stride `stride_y`.
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: the slice above guarantees num_protos readable proto
            // elements; `coeff` is assumed to hold at least num_protos
            // entries (callers slice it to exactly that length), and
            // callers gate on runtime dotprod detection before selecting
            // this function.
            let raw_dot =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
1971
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
/// Same ROI logit fill as `compute_logits_dotprod`, but using the baseline
/// NEON kernel for aarch64 cores without the dotprod extension:
/// `logits[ly][lx] = coeff · proto_pixel - zp_c * sum(proto_pixel) - bias`.
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: the slice above guarantees num_protos readable proto
            // elements; `coeff` is assumed to hold at least num_protos
            // entries (callers slice it to exactly that length).
            let raw_dot =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
2009
#[allow(clippy::too_many_arguments)]
/// Fully-quantized fast path: renders one mask tile per detection from i8
/// mask coefficients and i8 prototypes, staying in integer arithmetic from
/// the dot product through the final bilinear upsample.
///
/// Per detection: (1) map the letterboxed bbox into output pixels, (2)
/// find the prototype-space ROI covering those pixels, (3) compute integer
/// logits over the ROI (SDOT / NEON on aarch64, scalar elsewhere), (4)
/// bilinearly upsample the logits to the bbox using 10-bit fixed-point
/// weights and threshold at zero into a 0/255 tile. Detections are
/// processed in parallel via rayon.
///
/// # Errors
/// Returns `Error::NotSupported` when either tensor is per-channel
/// quantized: the bias algebra below needs one zero point per tensor.
fn scaled_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "mask_i8_fastpath",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        width,
        height,
        ?layout,
    )
    .entered();

    // Zero points of the coefficient and prototype tensors.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported".into(),
            ))
        }
    };

    // Letterbox rectangle in normalized coordinates; extents are floored
    // at EPSILON so the divisions below cannot blow up.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let hw = proto_h * proto_w;

    // With a nonzero coefficient zero point, every logit needs a per-pixel
    // correction zp_c * sum_k(proto[px][k]); precompute those channel sums
    // once per pixel (for either layout) so the inner loops stay pure dot
    // products.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    let mut s: i32 = 0;
                    for k in 0..num_protos {
                        s += protos[base + k] as i32;
                    }
                    s
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Prefer the SDOT kernel when the CPU actually has the dotprod
    // extension.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    let stride_y = proto_w * num_protos;

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            // Undo the letterbox and clamp the box to the unit square.
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            // Output-pixel rectangle of the box; never smaller than 1x1.
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Map an output pixel back to continuous prototype coordinates
            // (the +/-0.5 terms sample pixel centers).
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            // Smallest prototype-space ROI whose floor/ceil covers every
            // sample point, clamped to the prototype tensor bounds.
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // Expanding sum_k (c_k - zp_c)(p_k - zp_p) gives
            //   raw_dot - zp_c*sum(p_k) - (zp_p*sum(c_k) - n*zp_c*zp_p);
            // the parenthesized term is this per-detection `bias`, and
            // zp_c*sum(p_k) is the per-pixel `correction` applied below.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut logits = vec![0_i32; roi_h * roi_w];
            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            compute_logits_dotprod(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        } else {
                            compute_logits_base(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar path: one dot product per ROI
                        // pixel, with the same correction/bias fold.
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-planar prototypes: accumulate one plane at a
                    // time so each plane is streamed sequentially.
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_start = py * proto_w + proto_x0;
                            let out_row_start = ly_idx * roi_w;
                            for lx_idx in 0..roi_w {
                                logits[out_row_start + lx_idx] +=
                                    coeff_c * plane[row_start + lx_idx] as i32;
                            }
                        }
                    }
                    // Second pass folds the correction and bias into every
                    // accumulated logit.
                    for ly_idx in 0..roi_h {
                        let py = proto_y0 + ly_idx;
                        for lx_idx in 0..roi_w {
                            let idx = ly_idx * roi_w + lx_idx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                            } else {
                                0
                            };
                            logits[idx] -= correction + bias;
                        }
                    }
                }
            }

            // Fixed-point bilinear upsample of the ROI logits to the bbox.
            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // 10 fractional bits for the interpolation weights. For each
            // output column, precompute the two source columns and the
            // fractional weight between them; rows are computed on the fly
            // in the loop below.
            const FRAC_BITS: i32 = 10;
            const FRAC_SCALE: i32 = 1 << FRAC_BITS; let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                let y_frac_inv = FRAC_SCALE - y_frac;
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];

                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let tl = row_lo[x_lo];
                    let tr = row_lo[x_hi];
                    let bl = row_hi[x_lo];
                    let br = row_hi[x_hi];

                    // Sign-bit trick: the bitwise AND of four i32s is
                    // negative iff all four are negative, and blending
                    // four negative corners with non-negative weights can
                    // only give a negative logit — leave the pixel 0.
                    if (tl & tr & bl & br) < 0 {
                        continue;
                    }
                    // All corners strictly positive: the blend must be
                    // positive too, so skip the interpolation arithmetic.
                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
                        out_row[xi] = 255;
                        continue;
                    }

                    // Mixed-sign corners: full fixed-point bilinear blend,
                    // widened to i64 because the products carry
                    // 2*FRAC_BITS extra bits.
                    let x_frac_inv = FRAC_SCALE - x_frac;
                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
                    out_row[xi] = if logit > 0 { 255 } else { 0 };
                }
            }

            // The tile covers exactly the clamped normalized bbox.
            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
2311
2312#[allow(clippy::too_many_arguments)]
2313fn scaled_segmentations_i16_i8(
2314 detect: &[crate::DetectBox],
2315 coeff_all: &[i16],
2316 coeff_quant: &edgefirst_tensor::Quantization,
2317 protos: &[i8],
2318 proto_quant: &edgefirst_tensor::Quantization,
2319 proto_h: usize,
2320 proto_w: usize,
2321 num_protos: usize,
2322 letterbox: Option<[f32; 4]>,
2323 width: u32,
2324 height: u32,
2325 layout: edgefirst_decoder::ProtoLayout,
2326) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2327 use edgefirst_tensor::QuantMode;
2328
2329 let _span = tracing::trace_span!(
2330 "mask_i16_i8_fastpath",
2331 n = detect.len(),
2332 proto_h,
2333 proto_w,
2334 num_protos,
2335 width,
2336 height,
2337 ?layout,
2338 )
2339 .entered();
2340
2341 let zp_c: i32 = match coeff_quant.mode() {
2342 QuantMode::PerTensor { zero_point, .. } => zero_point,
2343 QuantMode::PerTensorSymmetric { .. } => 0,
2344 _ => {
2345 return Err(crate::Error::NotSupported(
2346 "per-channel coeff quantization not supported".into(),
2347 ))
2348 }
2349 };
2350 let zp_p: i32 = match proto_quant.mode() {
2351 QuantMode::PerTensor { zero_point, .. } => zero_point,
2352 QuantMode::PerTensorSymmetric { .. } => 0,
2353 _ => {
2354 return Err(crate::Error::NotSupported(
2355 "per-channel proto quantization not supported".into(),
2356 ))
2357 }
2358 };
2359
2360 let (lx0, lw, ly0, lh) = match letterbox {
2361 Some([lx0, ly0, lx1, ly1]) => {
2362 let lw = (lx1 - lx0).max(f32::EPSILON);
2363 let lh = (ly1 - ly0).max(f32::EPSILON);
2364 (lx0, lw, ly0, lh)
2365 }
2366 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2367 };
2368 let out_w = width as usize;
2369 let out_h = height as usize;
2370 let hw = proto_h * proto_w;
2371
2372 let proto_sums: Vec<i32> = if zp_c != 0 {
2374 match layout {
2375 edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
2376 .map(|px_idx| {
2377 let base = px_idx * num_protos;
2378 let mut s: i32 = 0;
2379 for k in 0..num_protos {
2380 s += protos[base + k] as i32;
2381 }
2382 s
2383 })
2384 .collect(),
2385 edgefirst_decoder::ProtoLayout::Nchw => {
2386 let mut sums = vec![0i32; hw];
2387 for c in 0..num_protos {
2388 let plane = &protos[c * hw..];
2389 for (px, s) in sums.iter_mut().enumerate() {
2390 *s += plane[px] as i32;
2391 }
2392 }
2393 sums
2394 }
2395 }
2396 } else {
2397 Vec::new()
2398 };
2399
2400 let stride_y = proto_w * num_protos;
2402
2403 detect
2404 .par_iter()
2405 .enumerate()
2406 .map(|(i, det)| {
2407 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2408 let bbox = det.bbox.to_canonical();
2409 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2410 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2411 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2412 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2413 let px0 = (xmin * out_w as f32).round() as usize;
2414 let py0 = (ymin * out_h as f32).round() as usize;
2415 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2416 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2417 let bbox_w = px1.saturating_sub(px0).max(1);
2418 let bbox_h = py1.saturating_sub(py0).max(1);
2419
2420 let sample_x_at = |px: f32| -> f32 {
2422 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2423 model_x_norm * proto_w as f32 - 0.5
2424 };
2425 let sample_y_at = |py: f32| -> f32 {
2426 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2427 model_y_norm * proto_h as f32 - 0.5
2428 };
2429 let s_x_min = sample_x_at(px0 as f32);
2430 let s_x_max = sample_x_at((px1 as f32) - 1.0);
2431 let s_y_min = sample_y_at(py0 as f32);
2432 let s_y_max = sample_y_at((py1 as f32) - 1.0);
2433 let proto_x0 = (s_x_min.floor() as isize)
2434 .max(0)
2435 .min(proto_w.saturating_sub(1) as isize) as usize;
2436 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2437 let proto_y0 = (s_y_min.floor() as isize)
2438 .max(0)
2439 .min(proto_h.saturating_sub(1) as isize) as usize;
2440 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2441 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2442 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2443
2444 let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
2446 let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
2447
2448 let mut logits = vec![0_i32; roi_h * roi_w];
2450 match layout {
2451 edgefirst_decoder::ProtoLayout::Nhwc => {
2452 #[cfg(target_arch = "aarch64")]
2453 {
2454 for ly_idx in 0..roi_h {
2455 let py = proto_y0 + ly_idx;
2456 let row_base = py * stride_y + proto_x0 * num_protos;
2457 for lx_idx in 0..roi_w {
2458 let pix_base = row_base + lx_idx * num_protos;
2459 let proto_px = &protos[pix_base..pix_base + num_protos];
2460 let raw_dot = unsafe {
2461 dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
2462 };
2463 let correction = if zp_c != 0 {
2464 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2465 } else {
2466 0
2467 };
2468 logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2469 }
2470 }
2471 }
2472 #[cfg(not(target_arch = "aarch64"))]
2473 {
2474 for ly_idx in 0..roi_h {
2475 let py = proto_y0 + ly_idx;
2476 let row_base = py * stride_y + proto_x0 * num_protos;
2477 for lx_idx in 0..roi_w {
2478 let pix_base = row_base + lx_idx * num_protos;
2479 let proto_px = &protos[pix_base..pix_base + num_protos];
2480 let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
2481 let correction = if zp_c != 0 {
2482 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2483 } else {
2484 0
2485 };
2486 logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2487 }
2488 }
2489 }
2490 }
2491 edgefirst_decoder::ProtoLayout::Nchw => {
2492 for c in 0..num_protos {
2494 let plane = &protos[c * hw..];
2495 let coeff_c = coeff[c] as i32;
2496 for ly_idx in 0..roi_h {
2497 let py = proto_y0 + ly_idx;
2498 let row_start = py * proto_w + proto_x0;
2499 let out_row_start = ly_idx * roi_w;
2500 for lx_idx in 0..roi_w {
2501 logits[out_row_start + lx_idx] +=
2502 coeff_c * plane[row_start + lx_idx] as i32;
2503 }
2504 }
2505 }
2506 for ly_idx in 0..roi_h {
2508 let py = proto_y0 + ly_idx;
2509 for lx_idx in 0..roi_w {
2510 let idx = ly_idx * roi_w + lx_idx;
2511 let correction = if zp_c != 0 {
2512 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2513 } else {
2514 0
2515 };
2516 logits[idx] -= correction + bias;
2517 }
2518 }
2519 }
2520 }
2521
2522 let roi_last_x = roi_w.saturating_sub(1);
2525 let roi_last_y = roi_h.saturating_sub(1);
2526
2527 const FRAC_BITS: i32 = 10;
2529 const FRAC_SCALE: i32 = 1 << FRAC_BITS; let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
2531 .map(|xi| {
2532 let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2533 let x_floor = sample_x.floor();
2534 let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
2535 let x_hi = (x_lo + 1).min(roi_w - 1);
2536 let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2537 (x_lo, x_hi, x_frac)
2538 })
2539 .collect();
2540
2541 let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2542 for yi in 0..bbox_h {
2543 let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2544 let y_floor = sample_y.floor();
2545 let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2546 let y_hi = (y_lo + 1).min(roi_h - 1);
2547 let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2548 let y_frac_inv = FRAC_SCALE - y_frac;
2549 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2550 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2551 let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2552
2553 for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2554 let tl = row_lo[x_lo];
2555 let tr = row_lo[x_hi];
2556 let bl = row_hi[x_lo];
2557 let br = row_hi[x_hi];
2558
2559 if (tl & tr & bl & br) < 0 {
2563 continue;
2565 }
2566 if tl > 0 && tr > 0 && bl > 0 && br > 0 {
2567 out_row[xi] = 255;
2569 continue;
2570 }
2571
2572 let x_frac_inv = FRAC_SCALE - x_frac;
2574 let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
2575 let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
2576 let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
2577 out_row[xi] = if logit > 0 { 255 } else { 0 };
2578 }
2579 }
2580
2581 let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2582 .expect("tile_buf length matches bbox_h * bbox_w");
2583 Ok(edgefirst_decoder::Segmentation {
2584 xmin,
2585 ymin,
2586 xmax,
2587 ymax,
2588 segmentation: tile,
2589 })
2590 })
2591 .collect()
2592}
2593
2594#[allow(clippy::too_many_arguments)]
2595fn scaled_run<P: Copy + Sync>(
2596 detect: &[crate::DetectBox],
2597 coeff_all: &[f32],
2598 protos: &[P],
2599 proto_h: usize,
2600 proto_w: usize,
2601 num_protos: usize,
2602 letterbox: Option<[f32; 4]>,
2603 width: u32,
2604 height: u32,
2605 acc_scale: f32,
2606 load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
2607) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2608 let (lx0, lw, ly0, lh) = match letterbox {
2609 Some([lx0, ly0, lx1, ly1]) => {
2610 let lw = (lx1 - lx0).max(f32::EPSILON);
2611 let lh = (ly1 - ly0).max(f32::EPSILON);
2612 (lx0, lw, ly0, lh)
2613 }
2614 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2615 };
2616 let out_w = width as usize;
2617 let out_h = height as usize;
2618 let stride_y = proto_w * num_protos;
2619
2620 detect
2642 .par_iter()
2643 .enumerate()
2644 .map(|(i, det)| {
2645 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2646 let bbox = det.bbox.to_canonical();
2647 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2648 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2649 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2650 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2651 let px0 = (xmin * out_w as f32).round() as usize;
2652 let py0 = (ymin * out_h as f32).round() as usize;
2653 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2654 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2655 let bbox_w = px1.saturating_sub(px0).max(1);
2656 let bbox_h = py1.saturating_sub(py0).max(1);
2657
2658 let sample_x_at = |px: f32| -> f32 {
2663 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2664 model_x_norm * proto_w as f32 - 0.5
2665 };
2666 let sample_y_at = |py: f32| -> f32 {
2667 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2668 model_y_norm * proto_h as f32 - 0.5
2669 };
2670 let s_x_min = sample_x_at(px0 as f32);
2671 let s_x_max = sample_x_at((px1 as f32) - 1.0);
2672 let s_y_min = sample_y_at(py0 as f32);
2673 let s_y_max = sample_y_at((py1 as f32) - 1.0);
2674 let proto_x0 = (s_x_min.floor() as isize)
2678 .max(0)
2679 .min(proto_w.saturating_sub(1) as isize) as usize;
2680 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2681 let proto_y0 = (s_y_min.floor() as isize)
2682 .max(0)
2683 .min(proto_h.saturating_sub(1) as isize) as usize;
2684 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2685 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2686 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2687
2688 if !acc_scale.is_finite() || acc_scale <= 0.0 {
2697 return Err(crate::Error::NotSupported(format!(
2698 "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
2699 )));
2700 }
2701 let _ = acc_scale; let mut logits = vec![0.0_f32; roi_h * roi_w];
2703 for ly_idx in 0..roi_h {
2704 let py = proto_y0 + ly_idx;
2705 let row_base = py * stride_y + proto_x0 * num_protos;
2706 for lx_idx in 0..roi_w {
2707 let pix_base = row_base + lx_idx * num_protos;
2708 let mut acc = 0.0_f32;
2709 let mut k = 0;
2711 let chunks = num_protos / 4;
2712 for _ in 0..chunks {
2713 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
2714 + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
2715 + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
2716 + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
2717 k += 4;
2718 }
2719 while k < num_protos {
2720 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
2721 k += 1;
2722 }
2723 logits[ly_idx * roi_w + lx_idx] = acc;
2724 }
2725 }
2726
2727 let roi_last_x = roi_w.saturating_sub(1);
2738 let roi_last_y = roi_h.saturating_sub(1);
2739
2740 let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
2742 .map(|xi| {
2743 let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2744 let x_floor = sample_x.floor();
2745 let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
2746 let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
2747 let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
2748 (x_lo, x_hi, x_frac)
2749 })
2750 .collect();
2751
2752 let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2755 for yi in 0..bbox_h {
2756 let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2757 let y_floor = sample_y.floor();
2758 let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2759 let y_hi = (y_lo + 1).min(roi_h - 1);
2760 let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
2761 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2762 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2763 let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2764 for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2765 let (xl, xh) = (x_lo as usize, x_hi as usize);
2766 let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
2767 let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
2768 let logit = l0 + (l1 - l0) * y_frac;
2769 out_row[xi] = if logit > 0.0 { 255 } else { 0 };
2770 }
2771 }
2772 let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2774 .expect("tile_buf length matches bbox_h * bbox_w");
2775 Ok(edgefirst_decoder::Segmentation {
2776 xmin,
2777 ymin,
2778 xmax,
2779 ymax,
2780 segmentation: tile,
2781 })
2782 })
2783 .collect()
2784}
2785
#[cfg(test)]
mod tests {
    use super::CPUProcessor;
    use edgefirst_decoder::{BoundingBox, DetectBox, ProtoData, ProtoLayout};
    use edgefirst_tensor::{Quantization, Tensor, TensorDyn};

    // Small fixed prototype geometry shared by every test.
    const PROTO_H: usize = 4;
    const PROTO_W: usize = 4;
    const NUM_PROTOS: usize = 8;

    /// Builds a detection from normalized box corners; score/label are fixed
    /// because only the geometry matters in these tests.
    fn det(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> DetectBox {
        let bbox = BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        };
        DetectBox {
            bbox,
            score: 0.9,
            label: 0,
        }
    }

    /// i8 tensor carrying per-tensor quantization metadata.
    fn make_i8_quant(shape: &[usize], data: &[i8], scale: f32, zp: i32) -> TensorDyn {
        let tensor = Tensor::<i8>::from_slice(data, shape)
            .unwrap()
            .with_quantization(Quantization::per_tensor(scale, zp))
            .unwrap();
        TensorDyn::I8(tensor)
    }

    /// i16 tensor carrying per-tensor quantization metadata.
    fn make_i16_quant(shape: &[usize], data: &[i16], scale: f32, zp: i32) -> TensorDyn {
        let tensor = Tensor::<i16>::from_slice(data, shape)
            .unwrap()
            .with_quantization(Quantization::per_tensor(scale, zp))
            .unwrap();
        TensorDyn::I16(tensor)
    }

    /// i16 tensor WITHOUT quantization metadata (exercises fallback paths).
    fn make_i16_raw(shape: &[usize], data: &[i16]) -> TensorDyn {
        TensorDyn::I16(Tensor::<i16>::from_slice(data, shape).unwrap())
    }

    /// Plain f32 tensor for the reference path.
    fn make_f32(shape: &[usize], data: &[f32]) -> TensorDyn {
        TensorDyn::F32(Tensor::<f32>::from_slice(data, shape).unwrap())
    }

    /// Deterministic prototype values cycling through [0, 126].
    fn gen_protos_i8(h: usize, w: usize, k: usize) -> Vec<i8> {
        let mut values = Vec::with_capacity(h * w * k);
        for i in 0..h * w * k {
            values.push((i % 127) as i8);
        }
        values
    }

    /// Deterministic i16 coefficients cycling through [-100, 100].
    fn gen_coeffs_i16(n: usize, k: usize) -> Vec<i16> {
        (0..n * k)
            .map(|i| ((i as i32 % 201) - 100) as i16)
            .collect()
    }

    /// Deterministic i8 coefficients cycling through [-100, 100].
    fn gen_coeffs_i8(n: usize, k: usize) -> Vec<i8> {
        (0..n * k).map(|i| ((i as i32 % 201) - 100) as i8).collect()
    }

    /// The standard quantized NHWC prototype tensor used by most tests.
    fn default_protos(scale: f32, zp: i32) -> TensorDyn {
        make_i8_quant(
            &[PROTO_H, PROTO_W, NUM_PROTOS],
            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
            scale,
            zp,
        )
    }

    /// Asserts two mask sets have identical shapes and differ on fewer than
    /// 5% of pixels; `what` names the comparison in the failure message.
    fn assert_masks_close(
        reference: &[edgefirst_decoder::Segmentation],
        candidate: &[edgefirst_decoder::Segmentation],
        what: &str,
    ) {
        assert_eq!(reference.len(), candidate.len());
        for (sf, si) in reference.iter().zip(candidate.iter()) {
            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
            let total = sf.segmentation.len();
            let mismatches = sf
                .segmentation
                .iter()
                .zip(si.segmentation.iter())
                .filter(|(a, b)| a != b)
                .count();
            let pct = mismatches as f64 / total as f64 * 100.0;
            assert!(
                pct < 5.0,
                "{what} mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
            );
        }
    }

    /// i16 coefficients x i8 protos with quantization present should succeed
    /// and yield a non-empty mask tile.
    #[test]
    fn materialize_proto_i16_i8_quant_produces_masks() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(
                &[1, NUM_PROTOS],
                &gen_coeffs_i16(1, NUM_PROTOS),
                0.01,
                0,
            ),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "materialize failed: {:?}", result.err());
        let segs = result.unwrap();
        assert_eq!(segs.len(), 1);
        let shape = segs[0].segmentation.shape();
        assert!(shape[0] > 0);
        assert!(shape[1] > 0);
    }

    /// i16 coefficients lacking quantization metadata must not error; the
    /// implementation is expected to fall back to the f32 path.
    #[test]
    fn materialize_proto_i16_no_quant_falls_back_to_f32() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS)),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(
            result.is_ok(),
            "missing coeff quant should fall back to f32 path, got: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Same as the quantized proto test, but through the scaled (fixed
    /// output resolution) entry point.
    #[test]
    fn materialize_scaled_i16_i8_quant_produces_masks() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(
                &[1, NUM_PROTOS],
                &gen_coeffs_i16(1, NUM_PROTOS),
                0.01,
                0,
            ),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
        assert!(
            result.is_ok(),
            "materialize_scaled failed: {:?}",
            result.err()
        );
        let segs = result.unwrap();
        assert_eq!(segs.len(), 1);
        let shape = segs[0].segmentation.shape();
        assert!(shape[0] > 0);
        assert!(shape[1] > 0);
    }

    /// Missing coefficient quantization on the scaled entry point must also
    /// fall back to the f32 path rather than erroring.
    #[test]
    fn materialize_scaled_i16_no_quant_falls_back_to_f32() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_raw(&[1, NUM_PROTOS], &gen_coeffs_i16(1, NUM_PROTOS)),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
        assert!(
            result.is_ok(),
            "missing coeff quant should fall back to f32 path, got: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }

    /// The integer path must agree with a hand-dequantized f32 reference on
    /// at least 95% of mask pixels.
    #[test]
    fn materialize_proto_i16_i8_matches_f32_reference() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9), det(0.3, 0.3, 0.7, 0.7)];
        let n_det = detect.len();
        let scale_c = 0.01_f32;
        let scale_p = 0.02_f32;
        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
        let raw_coeffs = gen_coeffs_i16(n_det, NUM_PROTOS);

        // Dequantize by hand to build the f32 reference inputs.
        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
        let proto_data_f32 = ProtoData {
            mask_coefficients: make_f32(&[n_det, NUM_PROTOS], &coeffs_f32),
            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
            layout: ProtoLayout::Nhwc,
        };
        let proto_data_int = ProtoData {
            mask_coefficients: make_i16_quant(&[n_det, NUM_PROTOS], &raw_coeffs, scale_c, 0),
            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
            layout: ProtoLayout::Nhwc,
        };

        let segs_f32 = cpu
            .materialize_segmentations(&detect, &proto_data_f32, None)
            .unwrap();
        let segs_int = cpu
            .materialize_segmentations(&detect, &proto_data_int, None)
            .unwrap();
        assert_masks_close(&segs_f32, &segs_int, "mask");
    }

    /// Several overlapping and nested detections each get their own mask.
    #[test]
    fn materialize_proto_i16_multiple_detections() {
        let cpu = CPUProcessor::new();
        let detect = vec![
            det(0.0, 0.0, 0.5, 0.5),
            det(0.5, 0.5, 1.0, 1.0),
            det(0.1, 0.1, 0.3, 0.3),
        ];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(
                &[3, NUM_PROTOS],
                &gen_coeffs_i16(3, NUM_PROTOS),
                0.01,
                0,
            ),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let segs = cpu
            .materialize_segmentations(&detect, &proto_data, None)
            .unwrap();
        assert_eq!(segs.len(), 3);
    }

    /// Zero detections produce an empty result, not an error.
    #[test]
    fn materialize_proto_i16_empty_detections() {
        let cpu = CPUProcessor::new();
        let detect: Vec<DetectBox> = vec![];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(&[0, NUM_PROTOS], &[], 0.01, 0),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let segs = cpu
            .materialize_segmentations(&detect, &proto_data, None)
            .unwrap();
        assert!(segs.is_empty());
    }

    /// The scaled integer path must also agree with its f32 reference on at
    /// least 95% of mask pixels.
    #[test]
    fn materialize_scaled_i16_i8_matches_f32_reference() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let scale_c = 0.01_f32;
        let scale_p = 0.02_f32;
        let raw_protos = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS);
        let raw_coeffs = gen_coeffs_i16(1, NUM_PROTOS);

        let protos_f32: Vec<f32> = raw_protos.iter().map(|&v| v as f32 * scale_p).collect();
        let coeffs_f32: Vec<f32> = raw_coeffs.iter().map(|&v| v as f32 * scale_c).collect();
        let proto_data_f32 = ProtoData {
            mask_coefficients: make_f32(&[1, NUM_PROTOS], &coeffs_f32),
            protos: make_f32(&[PROTO_H, PROTO_W, NUM_PROTOS], &protos_f32),
            layout: ProtoLayout::Nhwc,
        };
        let proto_data_int = ProtoData {
            mask_coefficients: make_i16_quant(&[1, NUM_PROTOS], &raw_coeffs, scale_c, 0),
            protos: make_i8_quant(&[PROTO_H, PROTO_W, NUM_PROTOS], &raw_protos, scale_p, 0),
            layout: ProtoLayout::Nhwc,
        };

        let (w, h) = (64_u32, 64_u32);
        let segs_f32 = cpu
            .materialize_scaled_segmentations(&detect, &proto_data_f32, None, w, h)
            .unwrap();
        let segs_int = cpu
            .materialize_scaled_segmentations(&detect, &proto_data_int, None, w, h)
            .unwrap();
        assert_masks_close(&segs_f32, &segs_int, "scaled mask");
    }

    /// Regression: i8 coefficients x i8 protos (both quantized) must work.
    #[test]
    fn materialize_proto_i8_i8_regression() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = ProtoData {
            mask_coefficients: make_i8_quant(
                &[1, NUM_PROTOS],
                &gen_coeffs_i8(1, NUM_PROTOS),
                0.01,
                0,
            ),
            protos: default_protos(0.02, 0),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "i8×i8 regression: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Non-zero zero-points on both tensors must be handled by the
    /// integer path.
    #[test]
    fn materialize_proto_i16_nonzero_zp() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(
                &[1, NUM_PROTOS],
                &gen_coeffs_i16(1, NUM_PROTOS),
                0.01,
                5,
            ),
            protos: default_protos(0.02, -10),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "nonzero zp failed: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Non-zero zero-points through the scaled entry point.
    #[test]
    fn materialize_scaled_i16_nonzero_zp() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = ProtoData {
            mask_coefficients: make_i16_quant(
                &[1, NUM_PROTOS],
                &gen_coeffs_i16(1, NUM_PROTOS),
                0.01,
                5,
            ),
            protos: default_protos(0.02, -10),
            layout: ProtoLayout::Nhwc,
        };
        let result = cpu.materialize_scaled_segmentations(&detect, &proto_data, None, 64, 64);
        assert!(
            result.is_ok(),
            "scaled nonzero zp failed: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }
}