1use super::CPUProcessor;
5use crate::Result;
6use edgefirst_decoder::{DetectBox, Segmentation};
7use ndarray::Axis;
8use rayon::prelude::*;
9
10impl CPUProcessor {
11 #[allow(clippy::too_many_arguments)]
12 pub(super) fn render_modelpack_segmentation(
13 &mut self,
14 dst_w: usize,
15 dst_h: usize,
16 dst_rs: usize,
17 dst_c: usize,
18 dst_slice: &mut [u8],
19 segmentation: &Segmentation,
20 opacity: f32,
21 ) -> Result<()> {
22 use ndarray_stats::QuantileExt;
23
24 let seg = &segmentation.segmentation;
25 let [seg_height, seg_width, seg_classes] = *seg.shape() else {
26 unreachable!("Array3 did not have [usize; 3] as shape");
27 };
28 let start_y = (dst_h as f32 * segmentation.ymin).round();
29 let end_y = (dst_h as f32 * segmentation.ymax).round();
30 let start_x = (dst_w as f32 * segmentation.xmin).round();
31 let end_x = (dst_w as f32 * segmentation.xmax).round();
32
33 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
34 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
35
36 let start_x_u = (start_x as usize).min(dst_w);
37 let start_y_u = (start_y as usize).min(dst_h);
38 let end_x_u = (end_x as usize).min(dst_w);
39 let end_y_u = (end_y as usize).min(dst_h);
40
41 let argmax = seg.map_axis(Axis(2), |r| r.argmax().unwrap());
42 let get_value_at_nearest = |x: f32, y: f32| -> usize {
43 let x = x.round() as usize;
44 let y = y.round() as usize;
45 argmax
46 .get([y.min(seg_height - 1), x.min(seg_width - 1)])
47 .copied()
48 .unwrap_or(0)
49 };
50
51 for y in start_y_u..end_y_u {
52 for x in start_x_u..end_x_u {
53 let seg_x = (x as f32 - start_x) * scale_x;
54 let seg_y = (y as f32 - start_y) * scale_y;
55 let label = get_value_at_nearest(seg_x, seg_y);
56
57 if label == seg_classes - 1 {
58 continue;
59 }
60
61 let color = self.colors[label % self.colors.len()];
62
63 let alpha = if opacity == 1.0 {
64 color[3] as u16
65 } else {
66 (color[3] as f32 * opacity).round() as u16
67 };
68
69 let dst_index = (y * dst_rs) + (x * dst_c);
70 for c in 0..3 {
71 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
72 + dst_slice[dst_index + c] as u16 * (255 - alpha))
73 / 255) as u8;
74 }
75 }
76 }
77
78 Ok(())
79 }
80
81 #[allow(clippy::too_many_arguments)]
82 pub(super) fn render_yolo_segmentation(
83 &mut self,
84 dst_w: usize,
85 dst_h: usize,
86 dst_rs: usize,
87 dst_c: usize,
88 dst_slice: &mut [u8],
89 segmentation: &Segmentation,
90 class: usize,
91 opacity: f32,
92 ) -> Result<()> {
93 let seg = &segmentation.segmentation;
94 let [seg_height, seg_width, classes] = *seg.shape() else {
95 unreachable!("Array3 did not have [usize;3] as shape");
96 };
97 debug_assert_eq!(classes, 1);
98
99 let start_y = (dst_h as f32 * segmentation.ymin).round();
100 let end_y = (dst_h as f32 * segmentation.ymax).round();
101 let start_x = (dst_w as f32 * segmentation.xmin).round();
102 let end_x = (dst_w as f32 * segmentation.xmax).round();
103
104 let scale_x = (seg_width as f32 - 1.0) / ((end_x - start_x) - 1.0);
105 let scale_y = (seg_height as f32 - 1.0) / ((end_y - start_y) - 1.0);
106
107 let start_x_u = (start_x as usize).min(dst_w);
108 let start_y_u = (start_y as usize).min(dst_h);
109 let end_x_u = (end_x as usize).min(dst_w);
110 let end_y_u = (end_y as usize).min(dst_h);
111
112 for y in start_y_u..end_y_u {
113 for x in start_x_u..end_x_u {
114 let seg_x = ((x as f32 - start_x) * scale_x) as usize;
115 let seg_y = ((y as f32 - start_y) * scale_y) as usize;
116 let val = *seg.get([seg_y, seg_x, 0]).unwrap_or(&0);
117
118 if val < 127 {
119 continue;
120 }
121
122 let color = self.colors[class % self.colors.len()];
123
124 let alpha = if opacity == 1.0 {
125 color[3] as u16
126 } else {
127 (color[3] as f32 * opacity).round() as u16
128 };
129
130 let dst_index = (y * dst_rs) + (x * dst_c);
131 for c in 0..3 {
132 dst_slice[dst_index + c] = ((color[c] as u16 * alpha
133 + dst_slice[dst_index + c] as u16 * (255 - alpha))
134 / 255) as u8;
135 }
136 }
137 }
138
139 Ok(())
140 }
141
142 #[allow(clippy::too_many_arguments)]
143 pub(super) fn render_box(
144 &mut self,
145 dst_w: usize,
146 dst_h: usize,
147 dst_rs: usize,
148 dst_c: usize,
149 dst_slice: &mut [u8],
150 detect: &[DetectBox],
151 color_mode: crate::ColorMode,
152 ) -> Result<()> {
153 const LINE_THICKNESS: usize = 3;
154
155 for (idx, d) in detect.iter().enumerate() {
156 use edgefirst_decoder::BoundingBox;
157
158 let color_index = color_mode.index(idx, d.label);
159 let [r, g, b, _] = self.colors[color_index % self.colors.len()];
160 let bbox = d.bbox.to_canonical();
161 let bbox = BoundingBox {
162 xmin: bbox.xmin.clamp(0.0, 1.0),
163 ymin: bbox.ymin.clamp(0.0, 1.0),
164 xmax: bbox.xmax.clamp(0.0, 1.0),
165 ymax: bbox.ymax.clamp(0.0, 1.0),
166 };
167 let inner = [
168 ((dst_w - 1) as f32 * bbox.xmin - 0.5).round() as usize,
169 ((dst_h - 1) as f32 * bbox.ymin - 0.5).round() as usize,
170 ((dst_w - 1) as f32 * bbox.xmax + 0.5).round() as usize,
171 ((dst_h - 1) as f32 * bbox.ymax + 0.5).round() as usize,
172 ];
173
174 let outer = [
175 inner[0].saturating_sub(LINE_THICKNESS),
176 inner[1].saturating_sub(LINE_THICKNESS),
177 (inner[2] + LINE_THICKNESS).min(dst_w),
178 (inner[3] + LINE_THICKNESS).min(dst_h),
179 ];
180
181 for y in outer[1] + 1..=inner[1] {
183 for x in outer[0] + 1..outer[2] {
184 let index = (y * dst_rs) + (x * dst_c);
185 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
186 }
187 }
188
189 for y in inner[1]..inner[3] {
191 for x in outer[0] + 1..=inner[0] {
192 let index = (y * dst_rs) + (x * dst_c);
193 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
194 }
195
196 for x in inner[2]..outer[2] {
197 let index = (y * dst_rs) + (x * dst_c);
198 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
199 }
200 }
201
202 for y in inner[3]..outer[3] {
204 for x in outer[0] + 1..outer[2] {
205 let index = (y * dst_rs) + (x * dst_c);
206 dst_slice[index..(index + 3)].copy_from_slice(&[r, g, b]);
207 }
208 }
209 }
210 Ok(())
211 }
212
    /// Builds one proto-resolution binary mask per detection from prototype
    /// masks and per-detection mask coefficients.
    ///
    /// Dispatch order:
    /// 1. `I8` coefficients with `I8` protos: integer kernel
    ///    `proto_segmentations_i8_i8`.
    /// 2. `I16` coefficients (with quantization metadata) with `I8` protos:
    ///    integer kernel `proto_segmentations_i16_i8`.
    /// 3. Otherwise, or when an integer kernel reports `NotSupported`: the
    ///    coefficients are converted to `f32` and a per-detection fused kernel
    ///    is run in parallel (rayon) according to the proto dtype
    ///    (`I8`, `F32` or `F16`).
    ///
    /// # Arguments
    /// * `detect` - detections; one mask is produced per entry, in order.
    /// * `proto_data` - proto tensor (rank 3, interpreted as `[H, W, C]` for
    ///   `Nhwc` or `[C, H, W]` for `Nchw`) and a `[N, C]` coefficient tensor
    ///   whose row count must equal `detect.len()`.
    /// * `letterbox` - optional `[lx0, ly0, lx1, ly1]` rectangle; its origin
    ///   and inverse extent are forwarded to `seg_from_roi` to rescale the
    ///   mask's bounds. Presumably this is the normalised area the image
    ///   content occupies inside the model input — confirm with callers.
    ///
    /// # Returns
    /// One `Segmentation` per detection, or an empty vector when `detect` is
    /// empty or the coefficient tensor has zero rows.
    ///
    /// # Errors
    /// * `InvalidShape` - protos are not rank 3, the coefficient shape does not
    ///   match, required quantization metadata is missing, or a dtype is not
    ///   handled.
    /// * `Internal` - coefficient row count differs from `detect.len()`.
    /// * `NotSupported` - `Nchw` layout with non-`I8` protos, or an unsupported
    ///   quantization mode on the fallback path.
    /// * Any error returned by mapping a tensor (`map()`) or by a kernel.
    pub fn materialize_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        // Trace span covering the whole call; it ends when `_span` is dropped.
        let _span = tracing::trace_span!(
            "image.materialize_masks",
            mode = "proto",
            n_detections = detect.len(),
        )
        .entered();

        if detect.is_empty() {
            return Ok(Vec::new());
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Pick height / width / channel count out of the shape per layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?} (expected [N, {num_protos}])"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        // Every detection needs exactly one coefficient row; the per-detection
        // slicing below relies on this.
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Letterbox origin and inverse extent. A non-positive extent falls
        // back to a factor of 1.0 so no division by zero can occur.
        let (lx0, inv_lw, ly0, inv_lh) = match letterbox {
            Some([lx0, ly0, lx1, ly1]) => {
                let lw = lx1 - lx0;
                let lh = ly1 - ly0;
                (
                    lx0,
                    if lw > 0.0 { 1.0 / lw } else { 1.0 },
                    ly0,
                    if lh > 0.0 { 1.0 / lh } else { 1.0 },
                )
            }
            None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
        };

        // Fast path 1: i8 coefficients x i8 protos, integer arithmetic only.
        // Both tensors must carry quantization metadata here.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match proto_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                lx0,
                inv_lw,
                ly0,
                inv_lh,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Kernel declined this input; continue to the f32 fallback.
                }
                Err(e) => return Err(e),
            }
        }

        // Fast path 2: i16 coefficients x i8 protos. Unlike the i8 path,
        // missing coefficient quantization is not an error here — the code
        // simply skips to the fallback below.
        if proto_data.mask_coefficients.dtype() == DType::I16
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i16()
                .expect("I16 coefficients");
            let coeff_m = coeff_t.map()?;
            if let Some(coeff_quant) = coeff_t.quantization() {
                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
                let proto_m = proto_t.map()?;
                let proto_quant = proto_t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                match proto_segmentations_i16_i8(
                    detect,
                    coeff_m.as_slice(),
                    coeff_quant,
                    proto_m.as_slice(),
                    proto_quant,
                    proto_h,
                    proto_w,
                    num_protos,
                    lx0,
                    inv_lw,
                    ly0,
                    inv_lh,
                    proto_data.layout,
                ) {
                    Ok(result) => return Ok(result),
                    Err(crate::Error::NotSupported(_)) => {
                        // Kernel declined this input; continue to the f32 fallback.
                    }
                    Err(e) => return Err(e),
                }
            }
        }

        // Fallback path. Only the I8 proto branch below transposes NCHW data,
        // so every other proto dtype must already be NHWC.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
            && proto_data.protos.dtype() != DType::I8
        {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
                    .into(),
            ));
        }
        // Convert the coefficients to f32 once. The storage is declared ahead
        // of the match so the slice borrowed from it outlives the match arms.
        let coeff_f32_storage: Vec<f32>;
        let coeff_f32_slice: &[f32] = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f32()
                    .expect("dtype matched F32");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().to_vec();
                &coeff_f32_storage[..]
            }
            DType::F16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_f16()
                    .expect("dtype matched F16");
                let m = t.map()?;
                coeff_f32_storage = m.as_slice().iter().map(|v| v.to_f32()).collect();
                &coeff_f32_storage[..]
            }
            DType::I8 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i8()
                    .expect("dtype matched I8");
                let m = t.map()?;
                // Dequantize as (v - zero_point) * scale; without metadata the
                // raw integer values are used as-is.
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I8 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            DType::I16 => {
                let t = proto_data
                    .mask_coefficients
                    .as_i16()
                    .expect("dtype matched I16");
                let m = t.map()?;
                // Same dequantization rule as the I8 branch above.
                coeff_f32_storage = if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I16 mask_coefficients quantization mode {other:?} not supported"
                            )));
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                };
                &coeff_f32_storage[..]
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported; expected F32, F16, I8, or I16"
                )));
            }
        };

        // Run the per-detection kernel that matches the proto dtype. Each arm
        // slices out the detection's coefficient row, derives its ROI on the
        // proto grid and wraps the resulting mask with `seg_from_roi`.
        match proto_data.protos.dtype() {
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("dtype matched I8");
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let m = t.map()?;
                let src_slice = m.as_slice();
                // The slice kernel receives pixel-major (NHWC) data, so NCHW
                // input is transposed into a temporary buffer first.
                let transposed_storage =
                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
                        let hw = proto_h * proto_w;
                        let mut nhwc = vec![0i8; hw * num_protos];
                        for c in 0..num_protos {
                            let plane = &src_slice[c * hw..(c + 1) * hw];
                            for px in 0..hw {
                                nhwc[px * num_protos + c] = plane[px];
                            }
                        }
                        Some(nhwc)
                    } else {
                        None
                    };
                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dequant_dot_sign_i8_slice(
                            protos_slice,
                            coeff,
                            quant,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        )?;
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("dtype matched F32");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f32_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("dtype matched F16");
                let m = t.map()?;
                let protos_slice = m.as_slice();
                detect
                    .par_iter()
                    .enumerate()
                    .map(|(i, det)| {
                        let coeff = &coeff_f32_slice[i * num_protos..(i + 1) * num_protos];
                        let (x0, y0, x1, y1, roi_w, roi_h) =
                            bbox_to_proto_roi(det, proto_w, proto_h);
                        let mask = fused_dot_sign_f16_slice(
                            protos_slice,
                            coeff,
                            proto_h,
                            proto_w,
                            y0,
                            x0,
                            roi_h,
                            roi_w,
                            num_protos,
                        );
                        Ok(seg_from_roi(
                            mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
                        ))
                    })
                    .collect()
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
583
    /// Builds one mask per detection like `materialize_segmentations`, but
    /// hands a target `width` x `height` to the `scaled_segmentations_*`
    /// kernels. The kernels are defined elsewhere; presumably they resample
    /// the masks to that output resolution — confirm against their
    /// documentation.
    ///
    /// Dispatch order:
    /// 1. `I8` coefficients with `I8` protos: `scaled_segmentations_i8_i8`.
    /// 2. `I16` coefficients (with quantization metadata) with `I8` protos:
    ///    `scaled_segmentations_i16_i8`.
    /// 3. Otherwise, or when an integer kernel reports `NotSupported`: the
    ///    coefficients are converted to `f32` and the slice kernel matching the
    ///    proto dtype (`F32`, `F16` or `I8`) is called.
    ///
    /// # Arguments
    /// * `detect` - detections; the coefficient tensor must have one row each.
    /// * `proto_data` - proto tensor (rank 3, interpreted as `[H, W, C]` for
    ///   `Nhwc` or `[C, H, W]` for `Nchw`) and a `[N, C]` coefficient tensor.
    /// * `letterbox` - optional rectangle forwarded unchanged to the kernels.
    /// * `width`, `height` - target size forwarded to the kernels; both must be
    ///   non-zero.
    ///
    /// # Returns
    /// The kernel's segmentations, or an empty vector when `detect` is empty
    /// or the coefficient tensor has zero rows.
    ///
    /// # Errors
    /// * `InvalidShape` - zero `width`/`height`, protos not rank 3, coefficient
    ///   shape mismatch, missing quantization metadata, or unhandled dtype.
    /// * `Internal` - coefficient row count differs from `detect.len()`.
    /// * `NotSupported` - `Nchw` layout with non-`I8` protos, or an unsupported
    ///   coefficient quantization mode on the fallback path.
    /// * Any error returned by mapping a tensor (`map()`) or by a kernel.
    pub fn materialize_scaled_segmentations(
        &self,
        detect: &[crate::DetectBox],
        proto_data: &crate::ProtoData,
        letterbox: Option<[f32; 4]>,
        width: u32,
        height: u32,
    ) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
        use edgefirst_tensor::{DType, TensorMapTrait, TensorTrait};

        // Trace span covering the whole call; it ends when `_span` is dropped.
        let _span = tracing::trace_span!(
            "image.materialize_masks",
            mode = "scaled",
            n_detections = detect.len(),
            width,
            height,
        )
        .entered();

        // Note: an empty detection list returns before the size check, so
        // zero width/height is only rejected when there is work to do.
        if detect.is_empty() {
            return Ok(Vec::new());
        }
        if width == 0 || height == 0 {
            return Err(crate::Error::InvalidShape(
                "Scaled mask width/height must be positive".into(),
            ));
        }
        let proto_shape = proto_data.protos.shape();
        if proto_shape.len() != 3 {
            return Err(crate::Error::InvalidShape(format!(
                "protos tensor must be rank-3, got {proto_shape:?}"
            )));
        }
        // Pick height / width / channel count out of the shape per layout.
        let (proto_h, proto_w, num_protos) = match proto_data.layout {
            edgefirst_decoder::ProtoLayout::Nhwc => {
                (proto_shape[0], proto_shape[1], proto_shape[2])
            }
            edgefirst_decoder::ProtoLayout::Nchw => {
                (proto_shape[1], proto_shape[2], proto_shape[0])
            }
        };
        let coeff_shape = proto_data.mask_coefficients.shape();
        if coeff_shape.len() != 2 || coeff_shape[1] != num_protos {
            return Err(crate::Error::InvalidShape(format!(
                "mask_coefficients shape {coeff_shape:?} incompatible with protos \
                 {proto_shape:?}"
            )));
        }
        if coeff_shape[0] == 0 {
            return Ok(Vec::new());
        }
        if coeff_shape[0] != detect.len() {
            return Err(crate::Error::Internal(format!(
                "mask_coefficients rows {} != detection count {}",
                coeff_shape[0],
                detect.len()
            )));
        }

        // Fast path 1: i8 coefficients x i8 protos. Both tensors must carry
        // quantization metadata here.
        if proto_data.mask_coefficients.dtype() == DType::I8
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i8()
                .expect("I8 coefficients");
            let coeff_m = coeff_t.map()?;
            let coeff_quant = coeff_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape(
                    "I8 mask_coefficients require quantization metadata".into(),
                )
            })?;
            let proto_t = proto_data.protos.as_i8().expect("I8 protos");
            let proto_m = proto_t.map()?;
            let proto_quant = proto_t.quantization().ok_or_else(|| {
                crate::Error::InvalidShape("I8 protos require quantization metadata".into())
            })?;
            match scaled_segmentations_i8_i8(
                detect,
                coeff_m.as_slice(),
                coeff_quant,
                proto_m.as_slice(),
                proto_quant,
                proto_h,
                proto_w,
                num_protos,
                letterbox,
                width,
                height,
                proto_data.layout,
            ) {
                Ok(result) => return Ok(result),
                Err(crate::Error::NotSupported(_)) => {
                    // Kernel declined this input; continue to the f32 fallback.
                }
                Err(e) => return Err(e),
            }
        }

        // Fast path 2: i16 coefficients x i8 protos. Missing coefficient
        // quantization skips to the fallback instead of failing.
        if proto_data.mask_coefficients.dtype() == DType::I16
            && proto_data.protos.dtype() == DType::I8
        {
            let coeff_t = proto_data
                .mask_coefficients
                .as_i16()
                .expect("I16 coefficients");
            let coeff_m = coeff_t.map()?;
            if let Some(coeff_quant) = coeff_t.quantization() {
                let proto_t = proto_data.protos.as_i8().expect("I8 protos");
                let proto_m = proto_t.map()?;
                let proto_quant = proto_t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                match scaled_segmentations_i16_i8(
                    detect,
                    coeff_m.as_slice(),
                    coeff_quant,
                    proto_m.as_slice(),
                    proto_quant,
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                    proto_data.layout,
                ) {
                    Ok(result) => return Ok(result),
                    // Kernel declined this input; continue to the f32 fallback.
                    Err(crate::Error::NotSupported(_)) => {}
                    Err(e) => return Err(e),
                }
            }
        }

        // Fallback path. Only the I8 proto branch below transposes NCHW data,
        // so every other proto dtype must already be NHWC.
        if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw
            && proto_data.protos.dtype() != DType::I8
        {
            return Err(crate::Error::NotSupported(
                "NCHW proto layout with non-I8 protos is not supported in the f32 fallback path"
                    .into(),
            ));
        }
        // Convert the coefficients to an owned f32 buffer.
        let coeff_f32: Vec<f32> = match proto_data.mask_coefficients.dtype() {
            DType::F32 => {
                let t = proto_data.mask_coefficients.as_f32().expect("F32");
                let m = t.map()?;
                m.as_slice().to_vec()
            }
            DType::F16 => {
                let t = proto_data.mask_coefficients.as_f16().expect("F16");
                let m = t.map()?;
                m.as_slice().iter().map(|v| v.to_f32()).collect()
            }
            DType::I8 => {
                let t = proto_data.mask_coefficients.as_i8().expect("I8");
                let m = t.map()?;
                // In this function I8 coefficients always need quantization
                // metadata; there is no raw-cast fallback for them.
                let q = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape(
                        "I8 mask_coefficients require quantization metadata".into(),
                    )
                })?;
                use edgefirst_tensor::QuantMode;
                let (scale, zp) = match q.mode() {
                    QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                    QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                    _ => {
                        return Err(crate::Error::NotSupported(
                            "per-channel mask_coefficients not supported".into(),
                        ))
                    }
                };
                // Dequantize as (v - zero_point) * scale.
                m.as_slice()
                    .iter()
                    .map(|&v| (v as f32 - zp) * scale)
                    .collect()
            }
            DType::I16 => {
                let t = proto_data.mask_coefficients.as_i16().expect("I16");
                let m = t.map()?;
                // I16 coefficients without metadata are used as raw values.
                if let Some(q) = t.quantization() {
                    use edgefirst_tensor::QuantMode;
                    let (scale, zp) = match q.mode() {
                        QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
                        QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
                        other => {
                            return Err(crate::Error::NotSupported(format!(
                                "I16 mask_coefficients quantization mode {other:?} not supported"
                            )))
                        }
                    };
                    m.as_slice()
                        .iter()
                        .map(|&v| (v as f32 - zp) * scale)
                        .collect()
                } else {
                    m.as_slice().iter().map(|&v| v as f32).collect()
                }
            }
            other => {
                return Err(crate::Error::InvalidShape(format!(
                    "mask_coefficients dtype {other:?} not supported"
                )));
            }
        };

        // Hand off to the slice kernel matching the proto dtype.
        match proto_data.protos.dtype() {
            DType::F32 => {
                let t = proto_data.protos.as_f32().expect("F32");
                let m = t.map()?;
                scaled_segmentations_f32_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::F16 => {
                let t = proto_data.protos.as_f16().expect("F16");
                let m = t.map()?;
                scaled_segmentations_f16_slice(
                    detect,
                    &coeff_f32,
                    m.as_slice(),
                    proto_h,
                    proto_w,
                    num_protos,
                    letterbox,
                    width,
                    height,
                )
            }
            DType::I8 => {
                let t = proto_data.protos.as_i8().expect("I8");
                let m = t.map()?;
                let quant = t.quantization().ok_or_else(|| {
                    crate::Error::InvalidShape("I8 protos require quantization metadata".into())
                })?;
                let src_slice = m.as_slice();
                // The slice kernel receives pixel-major (NHWC) data, so NCHW
                // input is transposed into a temporary buffer first.
                let transposed_storage =
                    if proto_data.layout == edgefirst_decoder::ProtoLayout::Nchw {
                        let hw = proto_h * proto_w;
                        let mut nhwc = vec![0i8; hw * num_protos];
                        for c in 0..num_protos {
                            let plane = &src_slice[c * hw..(c + 1) * hw];
                            for px in 0..hw {
                                nhwc[px * num_protos + c] = plane[px];
                            }
                        }
                        Some(nhwc)
                    } else {
                        None
                    };
                let protos_slice = transposed_storage.as_deref().unwrap_or(src_slice);
                scaled_segmentations_i8_slice(
                    detect,
                    &coeff_f32,
                    protos_slice,
                    proto_h,
                    proto_w,
                    num_protos,
                    quant,
                    letterbox,
                    width,
                    height,
                )
            }
            other => Err(crate::Error::InvalidShape(format!(
                "proto tensor dtype {other:?} not supported"
            ))),
        }
    }
884}
885
886fn bbox_to_proto_roi(
903 det: &DetectBox,
904 proto_w: usize,
905 proto_h: usize,
906) -> (usize, usize, usize, usize, usize, usize) {
907 let bbox = det.bbox.to_canonical();
908 let xmin = bbox.xmin.clamp(0.0, 1.0);
909 let ymin = bbox.ymin.clamp(0.0, 1.0);
910 let xmax = bbox.xmax.clamp(0.0, 1.0);
911 let ymax = bbox.ymax.clamp(0.0, 1.0);
912 let x0 = ((xmin * proto_w as f32) as usize).min(proto_w.saturating_sub(1));
913 let y0 = ((ymin * proto_h as f32) as usize).min(proto_h.saturating_sub(1));
914 let x1 = ((xmax * proto_w as f32).ceil() as usize).min(proto_w);
915 let y1 = ((ymax * proto_h as f32).ceil() as usize).min(proto_h);
916 let roi_w = x1.saturating_sub(x0).max(1);
917 let roi_h = y1.saturating_sub(y0).max(1);
918 (x0, y0, x1, y1, roi_w, roi_h)
919}
920
921#[allow(clippy::too_many_arguments)]
925fn seg_from_roi(
926 mask: ndarray::Array3<u8>,
927 x0: usize,
928 y0: usize,
929 x1: usize,
930 y1: usize,
931 proto_w: usize,
932 proto_h: usize,
933 lx0: f32,
934 inv_lw: f32,
935 ly0: f32,
936 inv_lh: f32,
937) -> edgefirst_decoder::Segmentation {
938 let seg_xmin = ((x0 as f32 / proto_w as f32) - lx0) * inv_lw;
939 let seg_ymin = ((y0 as f32 / proto_h as f32) - ly0) * inv_lh;
940 let seg_xmax = ((x1 as f32 / proto_w as f32) - lx0) * inv_lw;
941 let seg_ymax = ((y1 as f32 / proto_h as f32) - ly0) * inv_lh;
942 edgefirst_decoder::Segmentation {
943 xmin: seg_xmin.clamp(0.0, 1.0),
944 ymin: seg_ymin.clamp(0.0, 1.0),
945 xmax: seg_xmax.clamp(0.0, 1.0),
946 ymax: seg_ymax.clamp(0.0, 1.0),
947 segmentation: mask,
948 }
949}
950
/// Integer-only mask kernel for `i8` coefficients and `i8` protos at proto
/// resolution.
///
/// For each detection, the ROI from `bbox_to_proto_roi` is scanned and a pixel
/// is set to 255 when its zero-point-corrected integer dot product is strictly
/// positive (0 otherwise). With coefficient zero point `zc`, proto zero point
/// `zp` and `N = num_protos`:
///
/// ```text
/// sum((c - zc) * (p - zp)) = sum(c*p) - zc*sum(p) - zp*sum(c) + N*zc*zp
///                          = raw_dot  - correction - bias
/// ```
///
/// where `correction = zc * sum(p)` (per pixel, from `proto_sums`) and
/// `bias = zp * sum(c) - N*zc*zp` (per detection). The quantization scales
/// are never read; the sign test matches the dequantized value only if the
/// product of both scales is positive — assumed, confirm with the quantizer.
///
/// # Arguments
/// * `detect` - detections; processed in parallel with rayon.
/// * `coeff_all` - coefficients, `num_protos` consecutive values per detection.
/// * `coeff_quant`, `proto_quant` - quantization metadata; only the zero
///   points are used.
/// * `protos` - proto values, indexed as `(y * proto_w + x) * num_protos + c`
///   for `Nhwc` and `c * proto_h * proto_w + y * proto_w + x` for `Nchw`.
/// * `proto_h`, `proto_w`, `num_protos` - proto grid dimensions.
/// * `lx0`, `inv_lw`, `ly0`, `inv_lh` - letterbox terms forwarded to
///   `seg_from_roi`.
/// * `layout` - memory layout of `protos`.
///
/// # Errors
/// Returns `NotSupported` when either quantization mode is neither
/// `PerTensor` nor `PerTensorSymmetric`.
///
/// # Panics
/// Panics if `coeff_all` or `protos` is shorter than the indexing implied by
/// the dimensions above.
#[allow(clippy::too_many_arguments)]
fn proto_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    lx0: f32,
    inv_lw: f32,
    ly0: f32,
    inv_lh: f32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "image.materialize_masks.kernel_i8",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        ?layout,
    )
    .entered();

    // Zero points only; symmetric quantization has an implicit zero point of 0.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported on proto-res i8 path".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported on proto-res i8 path".into(),
            ))
        }
    };

    let hw = proto_h * proto_w;

    // Per-pixel sum of all proto channels, shared by every detection. It is
    // only needed for the `zp_c * sum(p)` correction, so it stays empty when
    // the coefficient zero point is 0.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    protos[base..base + num_protos]
                        .iter()
                        .map(|&v| v as i32)
                        .sum()
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    // Open-ended slice starting at channel `c`; only its first
                    // `hw` entries are read below.
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime CPU feature detection, done once outside the parallel loop.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            // This detection's coefficient row and its ROI on the proto grid.
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);

            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            // Pixel-independent part of the zero-point correction.
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            // Row-major ROI mask, 0 = background, 255 = foreground.
            let mut mask_buf = vec![0u8; roi_h * roi_w];

            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    // Elements per proto row in pixel-major layout.
                    let stride_y = proto_w * num_protos;
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY (review note): both pointers come from
                                    // slices of exactly `num_protos` elements and this
                                    // branch runs only when `dotprod` was detected.
                                    // Presumably that is the helper's full contract —
                                    // confirm against `dot_i8_neon_dotprod`.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_dotprod(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        } else {
                            for ly in 0..roi_h {
                                let py = y0 + ly;
                                let row_base = py * stride_y + x0 * num_protos;
                                for lx in 0..roi_w {
                                    let pix_base = row_base + lx * num_protos;
                                    let proto_px = &protos[pix_base..pix_base + num_protos];
                                    // SAFETY (review note): both pointers come from
                                    // slices of exactly `num_protos` elements.
                                    // Presumably the helper needs only baseline NEON —
                                    // confirm against `dot_i8_neon_base`.
                                    let raw_dot = unsafe {
                                        dot_i8_neon_base(
                                            coeff.as_ptr(),
                                            proto_px.as_ptr(),
                                            num_protos,
                                        )
                                    };
                                    let correction = if zp_c != 0 {
                                        zp_c * proto_sums[py * proto_w + x0 + lx]
                                    } else {
                                        0
                                    };
                                    let logit = raw_dot - correction - bias;
                                    if logit > 0 {
                                        mask_buf[ly * roi_w + lx] = 255;
                                    }
                                }
                            }
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        // Portable scalar path for every other architecture.
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_base = py * stride_y + x0 * num_protos;
                            for lx in 0..roi_w {
                                let pix_base = row_base + lx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + x0 + lx]
                                } else {
                                    0
                                };
                                let logit = raw_dot - correction - bias;
                                if logit > 0 {
                                    mask_buf[ly * roi_w + lx] = 255;
                                }
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Channel-major data: accumulate the raw dot product one
                    // channel plane at a time, then threshold in a second pass.
                    let mut accum = vec![0i32; roi_h * roi_w];
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly in 0..roi_h {
                            let py = y0 + ly;
                            let row_start = py * proto_w + x0;
                            let out_row_start = ly * roi_w;
                            for lx in 0..roi_w {
                                accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
                            }
                        }
                    }
                    for ly in 0..roi_h {
                        let py = y0 + ly;
                        for lx in 0..roi_w {
                            let idx = ly * roi_w + lx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + x0 + lx]
                            } else {
                                0
                            };
                            let logit = accum[idx] - correction - bias;
                            if logit > 0 {
                                mask_buf[idx] = 255;
                            }
                        }
                    }
                }
            }

            let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
                .expect("mask_buf length matches roi_h * roi_w");
            Ok(seg_from_roi(
                mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
            ))
        })
        .collect()
}
1181
1182#[allow(clippy::too_many_arguments)]
1184fn proto_segmentations_i16_i8(
1185 detect: &[crate::DetectBox],
1186 coeff_all: &[i16],
1187 coeff_quant: &edgefirst_tensor::Quantization,
1188 protos: &[i8],
1189 proto_quant: &edgefirst_tensor::Quantization,
1190 proto_h: usize,
1191 proto_w: usize,
1192 num_protos: usize,
1193 lx0: f32,
1194 inv_lw: f32,
1195 ly0: f32,
1196 inv_lh: f32,
1197 layout: edgefirst_decoder::ProtoLayout,
1198) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1199 use edgefirst_tensor::QuantMode;
1200
1201 let _span = tracing::trace_span!(
1202 "image.materialize_masks.kernel_i16xi8",
1203 n = detect.len(),
1204 proto_h,
1205 proto_w,
1206 num_protos,
1207 ?layout,
1208 )
1209 .entered();
1210
1211 let zp_c: i32 = match coeff_quant.mode() {
1212 QuantMode::PerTensor { zero_point, .. } => zero_point,
1213 QuantMode::PerTensorSymmetric { .. } => 0,
1214 _ => {
1215 return Err(crate::Error::NotSupported(
1216 "per-channel coeff quantization not supported on proto-res i16 path".into(),
1217 ))
1218 }
1219 };
1220 let zp_p: i32 = match proto_quant.mode() {
1221 QuantMode::PerTensor { zero_point, .. } => zero_point,
1222 QuantMode::PerTensorSymmetric { .. } => 0,
1223 _ => {
1224 return Err(crate::Error::NotSupported(
1225 "per-channel proto quantization not supported on proto-res i8 path".into(),
1226 ))
1227 }
1228 };
1229
1230 let hw = proto_h * proto_w;
1231
1232 let proto_sums: Vec<i32> = if zp_c != 0 {
1234 match layout {
1235 edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
1236 .map(|px_idx| {
1237 let base = px_idx * num_protos;
1238 protos[base..base + num_protos]
1239 .iter()
1240 .map(|&v| v as i32)
1241 .sum()
1242 })
1243 .collect(),
1244 edgefirst_decoder::ProtoLayout::Nchw => {
1245 let mut sums = vec![0i32; hw];
1246 for c in 0..num_protos {
1247 let plane = &protos[c * hw..];
1248 for (px, s) in sums.iter_mut().enumerate() {
1249 *s += plane[px] as i32;
1250 }
1251 }
1252 sums
1253 }
1254 }
1255 } else {
1256 Vec::new()
1257 };
1258
1259 detect
1260 .par_iter()
1261 .enumerate()
1262 .map(|(i, det)| {
1263 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
1264 let (x0, y0, x1, y1, roi_w, roi_h) = bbox_to_proto_roi(det, proto_w, proto_h);
1265
1266 let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
1268 let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
1269
1270 let mut mask_buf = vec![0u8; roi_h * roi_w];
1271
1272 match layout {
1273 edgefirst_decoder::ProtoLayout::Nhwc => {
1274 let stride_y = proto_w * num_protos;
1275 #[cfg(target_arch = "aarch64")]
1276 {
1277 for ly in 0..roi_h {
1278 let py = y0 + ly;
1279 let row_base = py * stride_y + x0 * num_protos;
1280 for lx in 0..roi_w {
1281 let pix_base = row_base + lx * num_protos;
1282 let proto_px = &protos[pix_base..pix_base + num_protos];
1283 let raw_dot = unsafe {
1284 dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
1285 };
1286 let correction = if zp_c != 0 {
1287 zp_c * proto_sums[py * proto_w + x0 + lx]
1288 } else {
1289 0
1290 };
1291 let logit = raw_dot - correction - bias;
1292 if logit > 0 {
1293 mask_buf[ly * roi_w + lx] = 255;
1294 }
1295 }
1296 }
1297 }
1298 #[cfg(not(target_arch = "aarch64"))]
1299 {
1300 for ly in 0..roi_h {
1301 let py = y0 + ly;
1302 let row_base = py * stride_y + x0 * num_protos;
1303 for lx in 0..roi_w {
1304 let pix_base = row_base + lx * num_protos;
1305 let proto_px = &protos[pix_base..pix_base + num_protos];
1306 let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
1307 let correction = if zp_c != 0 {
1308 zp_c * proto_sums[py * proto_w + x0 + lx]
1309 } else {
1310 0
1311 };
1312 let logit = raw_dot - correction - bias;
1313 if logit > 0 {
1314 mask_buf[ly * roi_w + lx] = 255;
1315 }
1316 }
1317 }
1318 }
1319 }
1320 edgefirst_decoder::ProtoLayout::Nchw => {
1321 let mut accum = vec![0i32; roi_h * roi_w];
1325 for c in 0..num_protos {
1326 let plane = &protos[c * hw..];
1327 let coeff_c = coeff[c] as i32;
1328 for ly in 0..roi_h {
1329 let py = y0 + ly;
1330 let row_start = py * proto_w + x0;
1331 let out_row_start = ly * roi_w;
1332 for lx in 0..roi_w {
1333 accum[out_row_start + lx] += coeff_c * plane[row_start + lx] as i32;
1334 }
1335 }
1336 }
1337 for ly in 0..roi_h {
1339 let py = y0 + ly;
1340 for lx in 0..roi_w {
1341 let idx = ly * roi_w + lx;
1342 let correction = if zp_c != 0 {
1343 zp_c * proto_sums[py * proto_w + x0 + lx]
1344 } else {
1345 0
1346 };
1347 let logit = accum[idx] - correction - bias;
1348 if logit > 0 {
1349 mask_buf[idx] = 255;
1350 }
1351 }
1352 }
1353 }
1354 }
1355
1356 let mask = ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1357 .expect("mask_buf length matches roi_h * roi_w");
1358 Ok(seg_from_roi(
1359 mask, x0, y0, x1, y1, proto_w, proto_h, lx0, inv_lw, ly0, inv_lh,
1360 ))
1361 })
1362 .collect()
1363}
1364
1365#[allow(clippy::too_many_arguments)]
1376fn fused_dot_sign_f32_slice(
1377 protos: &[f32],
1378 coeff: &[f32],
1379 _proto_h: usize,
1380 proto_w: usize,
1381 y0: usize,
1382 x0: usize,
1383 roi_h: usize,
1384 roi_w: usize,
1385 num_protos: usize,
1386) -> ndarray::Array3<u8> {
1387 let stride_y = proto_w * num_protos;
1388 let mut mask_buf = vec![0u8; roi_h * roi_w];
1389 for y in 0..roi_h {
1390 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1391 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1392 for (x, out_px) in out_row.iter_mut().enumerate() {
1393 let base = row_base + x * num_protos;
1394 let mut acc = 0.0_f32;
1395 let mut k = 0;
1396 let chunks = num_protos / 4;
1397 for _ in 0..chunks {
1398 acc += coeff[k] * protos[base + k]
1399 + coeff[k + 1] * protos[base + k + 1]
1400 + coeff[k + 2] * protos[base + k + 2]
1401 + coeff[k + 3] * protos[base + k + 3];
1402 k += 4;
1403 }
1404 while k < num_protos {
1405 acc += coeff[k] * protos[base + k];
1406 k += 1;
1407 }
1408 if acc > 0.0 {
1409 *out_px = 255;
1410 }
1411 }
1412 }
1413 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1414 .expect("mask_buf length matches roi_h * roi_w")
1415}
1416
/// Builds a binary ROI mask from half-precision prototypes.
///
/// Compile-time dispatcher: when the build targets x86_64 with both `f16c`
/// and `fma` statically enabled it calls the vectorised
/// `fused_dot_sign_f16_slice_f16c`, otherwise the portable
/// `fused_dot_sign_f16_slice_scalar`. All arguments except `_proto_h`
/// (unused here) are forwarded unchanged.
///
/// Returns whatever the selected kernel returns, an `Array3<u8>` mask.
#[allow(clippy::too_many_arguments)]
fn fused_dot_sign_f16_slice(
    protos: &[half::f16],
    coeff: &[f32],
    _proto_h: usize,
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    ))]
    {
        // SAFETY: this branch is only compiled when `f16c` and `fma` are
        // enabled for the whole build, which covers the CPU-feature side of
        // the callee's contract.
        // NOTE(review): the callee is declared with "f16c,fma,avx"; this gate
        // does not name `avx` explicitly — presumably implied by `f16c`,
        // confirm. The ROI is also expected to lie inside `protos`; verify
        // against callers.
        unsafe {
            fused_dot_sign_f16_slice_f16c(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
        }
    }
    #[cfg(not(all(
        target_arch = "x86_64",
        target_feature = "f16c",
        target_feature = "fma"
    )))]
    {
        // Portable fallback for every other target / feature combination.
        fused_dot_sign_f16_slice_scalar(protos, coeff, proto_w, y0, x0, roi_h, roi_w, num_protos)
    }
}
1457
1458#[allow(clippy::too_many_arguments)]
1460fn fused_dot_sign_f16_slice_scalar(
1461 protos: &[half::f16],
1462 coeff: &[f32],
1463 proto_w: usize,
1464 y0: usize,
1465 x0: usize,
1466 roi_h: usize,
1467 roi_w: usize,
1468 num_protos: usize,
1469) -> ndarray::Array3<u8> {
1470 let stride_y = proto_w * num_protos;
1471 let mut mask_buf = vec![0u8; roi_h * roi_w];
1472 for y in 0..roi_h {
1473 let row_base = (y0 + y) * stride_y + x0 * num_protos;
1474 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1475 for (x, out_px) in out_row.iter_mut().enumerate() {
1476 let base = row_base + x * num_protos;
1477 let mut acc = 0.0_f32;
1478 let mut k = 0;
1479 let chunks = num_protos / 4;
1480 for _ in 0..chunks {
1481 acc += coeff[k] * protos[base + k].to_f32()
1482 + coeff[k + 1] * protos[base + k + 1].to_f32()
1483 + coeff[k + 2] * protos[base + k + 2].to_f32()
1484 + coeff[k + 3] * protos[base + k + 3].to_f32();
1485 k += 4;
1486 }
1487 while k < num_protos {
1488 acc += coeff[k] * protos[base + k].to_f32();
1489 k += 1;
1490 }
1491 if acc > 0.0 {
1492 *out_px = 255;
1493 }
1494 }
1495 }
1496 ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1497 .expect("mask_buf length matches roi_h * roi_w")
1498}
1499
/// F16C + FMA kernel that builds a binary ROI mask from `f16` prototypes.
///
/// For each ROI pixel, eight half-precision prototype values at a time are
/// widened to `f32` (`_mm256_cvtph_ps`) and fused-multiply-added with the
/// matching eight coefficients. The eight partial sums are reduced
/// horizontally, the remaining `num_protos % 8` channels are handled in
/// scalar code, and the mask byte is `255` when the total is strictly
/// positive, `0` otherwise.
///
/// `protos` is indexed as `(row * proto_w + col) * num_protos + channel`;
/// `(y0, x0)` is the ROI origin and `(roi_h, roi_w)` its size, all in
/// prototype pixels. Returns an array of shape `(roi_h, roi_w, 1)`.
///
/// # Safety
/// The executing CPU must support `f16c`, `fma` and `avx`. Memory accesses
/// are bounds-checked internally, so there is no further requirement on the
/// slice arguments.
///
/// # Panics
/// Panics if `coeff` holds fewer than `num_protos` values or if an ROI pixel
/// lies outside `protos`.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "f16c",
    target_feature = "fma"
))]
#[allow(clippy::too_many_arguments)]
#[target_feature(enable = "f16c,fma,avx")]
unsafe fn fused_dot_sign_f16_slice_f16c(
    protos: &[half::f16],
    coeff: &[f32],
    proto_w: usize,
    y0: usize,
    x0: usize,
    roi_h: usize,
    roi_w: usize,
    num_protos: usize,
) -> ndarray::Array3<u8> {
    use core::arch::x86_64::{
        _mm256_castps256_ps128, _mm256_cvtph_ps, _mm256_extractf128_ps, _mm256_fmadd_ps,
        _mm256_loadu_ps, _mm256_setzero_ps, _mm_add_ps, _mm_cvtss_f32, _mm_hadd_ps,
        _mm_loadu_si128,
    };

    // The vector loop reads through raw pointers, so the coefficient length
    // is validated once here instead of relying on the caller.
    let coeff = &coeff[..num_protos];

    let stride_y = proto_w * num_protos;
    let chunks8 = num_protos / 8;
    let mut mask_buf = vec![0u8; roi_h * roi_w];

    for y in 0..roi_h {
        let row_base = (y0 + y) * stride_y + x0 * num_protos;
        let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
        for (x, out_px) in out_row.iter_mut().enumerate() {
            let base = row_base + x * num_protos;
            // Checked subslice: an out-of-range ROI panics here (as in the
            // scalar kernel) rather than reading out of bounds below.
            let proto_px = &protos[base..base + num_protos];
            let mut acc_v = _mm256_setzero_ps();
            let mut k = 0;
            for _ in 0..chunks8 {
                // In bounds: `k + 8 <= num_protos`, and both `proto_px` and
                // `coeff` are exactly `num_protos` elements long.
                let p_ptr = proto_px
                    .as_ptr()
                    .add(k)
                    .cast::<core::arch::x86_64::__m128i>();
                // 16 bytes = eight f16 values; unaligned load.
                let raw = _mm_loadu_si128(p_ptr);
                let widened = _mm256_cvtph_ps(raw);
                let coeffs_v = _mm256_loadu_ps(coeff.as_ptr().add(k));
                acc_v = _mm256_fmadd_ps(coeffs_v, widened, acc_v);
                k += 8;
            }
            // Horizontal reduction of the eight lanes: 8 -> 4 -> 2 -> 1.
            let lo = _mm256_castps256_ps128(acc_v);
            let hi = _mm256_extractf128_ps::<1>(acc_v);
            let sum4 = _mm_add_ps(lo, hi);
            let sum2 = _mm_hadd_ps(sum4, sum4);
            let sum1 = _mm_hadd_ps(sum2, sum2);
            let mut acc = _mm_cvtss_f32(sum1);

            // Scalar tail for channel counts that are not a multiple of 8.
            while k < num_protos {
                acc += coeff[k] * proto_px[k].to_f32();
                k += 1;
            }

            if acc > 0.0 {
                *out_px = 255;
            }
        }
    }
    ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
        .expect("mask_buf length matches roi_h * roi_w")
}
1576
1577#[allow(clippy::too_many_arguments)]
1581fn fused_dequant_dot_sign_i8_slice(
1582 protos: &[i8],
1583 coeff: &[f32],
1584 quant: &edgefirst_tensor::Quantization,
1585 _proto_h: usize,
1586 proto_w: usize,
1587 y0: usize,
1588 x0: usize,
1589 roi_h: usize,
1590 roi_w: usize,
1591 num_protos: usize,
1592) -> crate::Result<ndarray::Array3<u8>> {
1593 use edgefirst_tensor::QuantMode;
1594 let stride_y = proto_w * num_protos;
1595
1596 let mut stack_scratch = [0.0_f32; 64];
1598 let mut heap_scratch: Vec<f32>;
1599 let scaled_coeff: &mut [f32] = if num_protos <= stack_scratch.len() {
1600 &mut stack_scratch[..num_protos]
1601 } else {
1602 heap_scratch = vec![0.0_f32; num_protos];
1603 heap_scratch.as_mut_slice()
1604 };
1605 let zp_offset: f32;
1606 match quant.mode() {
1607 QuantMode::PerTensorSymmetric { scale } => {
1608 for k in 0..num_protos {
1609 scaled_coeff[k] = coeff[k] * scale;
1610 }
1611 zp_offset = 0.0;
1612 }
1613 QuantMode::PerTensor { scale, zero_point } => {
1614 for k in 0..num_protos {
1615 scaled_coeff[k] = coeff[k] * scale;
1616 }
1617 zp_offset = zero_point as f32 * scaled_coeff.iter().take(num_protos).sum::<f32>();
1618 }
1619 QuantMode::PerChannelSymmetric { scales, axis } => {
1620 if axis != 2 {
1621 return Err(crate::Error::NotSupported(format!(
1622 "per-channel quantization on axis {axis} not supported \
1623 (only channel axis 2 is implemented on this kernel)"
1624 )));
1625 }
1626 for k in 0..num_protos {
1627 scaled_coeff[k] = coeff[k] * scales[k];
1628 }
1629 zp_offset = 0.0;
1630 }
1631 QuantMode::PerChannel {
1632 scales,
1633 zero_points,
1634 axis,
1635 } => {
1636 if axis != 2 {
1637 return Err(crate::Error::NotSupported(format!(
1638 "per-channel quantization on axis {axis} not supported \
1639 (only channel axis 2 is implemented on this kernel)"
1640 )));
1641 }
1642 for k in 0..num_protos {
1643 scaled_coeff[k] = coeff[k] * scales[k];
1644 }
1645 zp_offset = (0..num_protos)
1646 .map(|k| scaled_coeff[k] * zero_points[k] as f32)
1647 .sum();
1648 }
1649 }
1650
1651 let mut mask_buf = vec![0u8; roi_h * roi_w];
1652 for y in 0..roi_h {
1653 let row_base = (y0 + y) * stride_y + (x0) * num_protos;
1654 let out_row = &mut mask_buf[y * roi_w..(y + 1) * roi_w];
1655 for (x, out_px) in out_row.iter_mut().enumerate() {
1656 let base = row_base + x * num_protos;
1657 let mut acc = 0.0_f32;
1658 let mut k = 0;
1659 let chunks = num_protos / 4;
1660 for _ in 0..chunks {
1661 let p0 = protos[base + k] as f32;
1662 let p1 = protos[base + k + 1] as f32;
1663 let p2 = protos[base + k + 2] as f32;
1664 let p3 = protos[base + k + 3] as f32;
1665 acc += scaled_coeff[k] * p0
1666 + scaled_coeff[k + 1] * p1
1667 + scaled_coeff[k + 2] * p2
1668 + scaled_coeff[k + 3] * p3;
1669 k += 4;
1670 }
1671 while k < num_protos {
1672 acc += scaled_coeff[k] * protos[base + k] as f32;
1673 k += 1;
1674 }
1675 if acc > zp_offset {
1676 *out_px = 255;
1677 }
1678 }
1679 }
1680 Ok(ndarray::Array3::from_shape_vec((roi_h, roi_w, 1), mask_buf)
1681 .expect("mask_buf length matches roi_h * roi_w"))
1682}
1683
/// `f32` front end for `scaled_run`.
///
/// Forwards every argument unchanged and supplies the two type-specific
/// pieces: a scale of `1.0` and a load closure that returns the prototype
/// value as-is (its second argument is ignored).
///
/// NOTE(review): the meaning of the scale and the closure's second argument
/// is defined by `scaled_run`, which is outside this view — presumably the
/// scale is a dequantization factor, hence `1.0` for float data; confirm.
///
/// # Errors
/// Returns whatever error `scaled_run` returns.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_f32_slice(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[f32],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    scaled_run(
        detect,
        coeff_all,
        protos,
        proto_h,
        proto_w,
        num_protos,
        letterbox,
        width,
        height,
        // Float prototypes need no rescaling.
        1.0,
        // Identity load: the stored value is already an f32.
        |p, _| *p,
    )
}
1710
/// `f16` front end for `scaled_run`.
///
/// Forwards every argument unchanged and supplies the two type-specific
/// pieces: a scale of `1.0` and a load closure that widens each
/// half-precision prototype value to `f32` (its second argument is ignored).
///
/// NOTE(review): the meaning of the scale and the closure's second argument
/// is defined by `scaled_run`, which is outside this view — presumably the
/// scale is a dequantization factor, hence `1.0` for float data; confirm.
///
/// # Errors
/// Returns whatever error `scaled_run` returns.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_f16_slice(
    detect: &[crate::DetectBox],
    coeff_all: &[f32],
    protos: &[half::f16],
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    scaled_run(
        detect,
        coeff_all,
        protos,
        proto_h,
        proto_w,
        num_protos,
        letterbox,
        width,
        height,
        // Float prototypes need no rescaling.
        1.0,
        // Widen half precision to f32 on load.
        |p: &half::f16, _| p.to_f32(),
    )
}
1737
1738#[allow(clippy::too_many_arguments)]
1739fn scaled_segmentations_i8_slice(
1740 detect: &[crate::DetectBox],
1741 coeff_all: &[f32],
1742 protos: &[i8],
1743 proto_h: usize,
1744 proto_w: usize,
1745 num_protos: usize,
1746 quant: &edgefirst_tensor::Quantization,
1747 letterbox: Option<[f32; 4]>,
1748 width: u32,
1749 height: u32,
1750) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
1751 use edgefirst_tensor::QuantMode;
1752 let (scale, zp) = match quant.mode() {
1756 QuantMode::PerTensor { scale, zero_point } => (scale, zero_point as f32),
1757 QuantMode::PerTensorSymmetric { scale } => (scale, 0.0),
1758 QuantMode::PerChannel { axis, .. } | QuantMode::PerChannelSymmetric { axis, .. } => {
1759 return Err(crate::Error::NotSupported(format!(
1760 "per-channel quantization (axis={axis}) on scaled seg path \
1761 not yet supported"
1762 )));
1763 }
1764 };
1765 scaled_run(
1766 detect,
1767 coeff_all,
1768 protos,
1769 proto_h,
1770 proto_w,
1771 num_protos,
1772 letterbox,
1773 width,
1774 height,
1775 scale,
1776 move |p: &i8, _| *p as f32 - zp,
1777 )
1778}
1779
/// Portable dot product of the first `n` entries of two `i8` slices,
/// accumulated in `i32`.
///
/// Elements are consumed four at a time with a scalar tail for the
/// remaining `n % 4` entries.
///
/// # Panics
/// Panics if either slice holds fewer than `n` elements.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i8_scalar(coeff: &[i8], proto: &[i8], n: usize) -> i32 {
    let mut lhs_quads = coeff[..n].chunks_exact(4);
    let mut rhs_quads = proto[..n].chunks_exact(4);

    let mut total = 0_i32;
    for (a, b) in lhs_quads.by_ref().zip(rhs_quads.by_ref()) {
        total += a[0] as i32 * b[0] as i32
            + a[1] as i32 * b[1] as i32
            + a[2] as i32 * b[2] as i32
            + a[3] as i32 * b[3] as i32;
    }
    for (&a, &b) in lhs_quads.remainder().iter().zip(rhs_quads.remainder()) {
        total += a as i32 * b as i32;
    }
    total
}
1814
/// Baseline-NEON dot product of `n` signed bytes, accumulated in `i32`.
///
/// Processes 16 elements per iteration with widening multiplies
/// (`i8 x i8 -> i16`) followed by pairwise add-accumulate into four `i32`
/// lanes, then at most one 8-element step, then a scalar loop for whatever
/// is left.
///
/// # Safety
/// `coeff` and `proto` must each be valid for reads of `n` consecutive `i8`
/// values; every access below is at an offset strictly less than `n`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_base(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        // Low and high halves are multiplied separately because the
        // widening multiply yields eight i16 products per instruction.
        let lo = vmull_s8(vget_low_s8(c), vget_low_s8(p));
        let hi = vmull_high_s8(c, p);
        // Pairwise-add adjacent i16 products and accumulate into i32 lanes.
        acc = vpadalq_s16(acc, lo);
        acc = vpadalq_s16(acc, hi);
        offset += 16;
    }
    let remainder = n - offset;
    // One extra 8-wide step when at least half a vector is left.
    if remainder >= 8 {
        let c = vld1_s8(coeff.add(offset));
        let p = vld1_s8(proto.add(offset));
        let prod = vmull_s8(c, p);
        acc = vpadalq_s16(acc, prod);
        offset += 8;
    }
    // Reduce the four lanes, then finish the last (< 8) elements in scalar code.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1849
/// Dot product of `n` signed bytes using the AArch64 `sdot` instruction,
/// accumulated in `i32`.
///
/// Processes 16 elements per iteration: one `sdot` accumulates the sixteen
/// byte products into the four `i32` lanes of the accumulator. The remaining
/// `n % 16` elements are handled by a scalar loop.
///
/// The instruction is emitted through inline assembly with
/// `.arch_extension dotprod`, so no compile-time target feature is required
/// to build this function.
///
/// # Safety
/// * `coeff` and `proto` must each be valid for reads of `n` consecutive
///   `i8` values.
/// * The executing CPU must implement the dot-product extension; callers are
///   expected to check this at run time before calling.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i8_neon_dotprod(coeff: *const i8, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 16;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s8(coeff.add(offset));
        let p = vld1q_s8(proto.add(offset));
        let result: int32x4_t;
        // `acc` is both input and output of `sdot`; the updated register is
        // bound to `result` and copied back below. The asm touches only
        // registers (`pure, nomem, nostack`).
        core::arch::asm!(
            ".arch_extension dotprod",
            "sdot {acc:v}.4s, {a:v}.16b, {b:v}.16b",
            acc = inout(vreg) acc => result,
            a = in(vreg) c,
            b = in(vreg) p,
            options(pure, nomem, nostack),
        );
        acc = result;
        offset += 16;
    }
    // Reduce the four lanes, then finish the last (< 16) elements in scalar code.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1886
/// Portable dot product of the first `n` entries of an `i16` slice and an
/// `i8` slice, accumulated in `i32`.
///
/// Elements are consumed four at a time with a scalar tail for the
/// remaining `n % 4` entries.
///
/// # Panics
/// Panics if either slice holds fewer than `n` elements.
#[cfg_attr(target_arch = "aarch64", allow(dead_code))]
#[inline(always)]
fn dot_i16_i8_scalar(coeff: &[i16], proto: &[i8], n: usize) -> i32 {
    let mut lhs_quads = coeff[..n].chunks_exact(4);
    let mut rhs_quads = proto[..n].chunks_exact(4);

    let mut total = 0_i32;
    for (a, b) in lhs_quads.by_ref().zip(rhs_quads.by_ref()) {
        total += a[0] as i32 * b[0] as i32
            + a[1] as i32 * b[1] as i32
            + a[2] as i32 * b[2] as i32
            + a[3] as i32 * b[3] as i32;
    }
    for (&a, &b) in lhs_quads.remainder().iter().zip(rhs_quads.remainder()) {
        total += a as i32 * b as i32;
    }
    total
}
1907
/// NEON dot product of `n` `i16` coefficients with `n` `i8` prototype
/// values, accumulated in `i32`.
///
/// Processes 8 elements per iteration: the eight prototype bytes are
/// sign-extended to `i16`, then two widening multiply-accumulates
/// (`i16 x i16 -> i32`) add the low and high halves into four `i32` lanes.
/// The remaining `n % 8` elements are handled by a scalar loop.
///
/// # Safety
/// `coeff` must be valid for reads of `n` consecutive `i16` values and
/// `proto` for reads of `n` consecutive `i8` values.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn dot_i16_i8_neon(coeff: *const i16, proto: *const i8, n: usize) -> i32 {
    use std::arch::aarch64::*;
    let mut acc = vdupq_n_s32(0);
    let full_chunks = n / 8;
    let mut offset = 0usize;
    for _ in 0..full_chunks {
        let c = vld1q_s16(coeff.add(offset));
        let p_raw = vld1_s8(proto.add(offset));
        // Sign-extend the bytes so both operands are i16 lanes.
        let p = vmovl_s8(p_raw);
        acc = vmlal_s16(acc, vget_low_s16(c), vget_low_s16(p));
        acc = vmlal_high_s16(acc, c, p);
        offset += 8;
    }
    // Reduce the four lanes, then finish the last (< 8) elements in scalar code.
    let mut scalar_acc = vaddvq_s32(acc);
    while offset < n {
        scalar_acc += *coeff.add(offset) as i32 * *proto.add(offset) as i32;
        offset += 1;
    }
    scalar_acc
}
1932
/// Fills `logits` with zero-point-corrected integer logits for one ROI of an
/// interleaved `i8` prototype buffer, using the `sdot`-based dot product.
///
/// For the ROI pixel at local position `(ly, lx)` the value written to
/// `logits[ly * roi_w + lx]` is
/// `dot(coeff, proto_px) - zp_c * proto_sums[pixel] - bias`, where the
/// `zp_c` term is skipped (and `proto_sums` not read) when `zp_c == 0`.
///
/// * `protos` - indexed as `py * stride_y + px * num_protos + channel`.
/// * `proto_sums` - per-pixel channel sums, indexed as `py * proto_w + px`.
/// * `proto_x0`, `proto_y0` - ROI origin; `roi_w`, `roi_h` - ROI size.
///
/// The caller must have verified at run time that the CPU implements the
/// AArch64 dot-product extension before calling this function.
///
/// # Panics
/// Panics if `coeff` holds fewer than `num_protos` values, or if an ROI
/// pixel falls outside `protos`, `proto_sums` or `logits`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_dotprod(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // The kernel reads `num_protos` coefficients through a raw pointer, so
    // the length is checked here; without it a short `coeff` would be an
    // out-of-bounds read from a safe function.
    let coeff = &coeff[..num_protos];
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: `coeff` and `proto_px` are both exactly `num_protos`
            // elements long (checked slices above). CPU support for `sdot`
            // is the caller's documented obligation.
            let raw_dot =
                unsafe { dot_i8_neon_dotprod(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // Coefficient zero-point term; `proto_sums` may be empty when
            // `zp_c == 0`, so it is only indexed when needed.
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
1970
/// Fills `logits` with zero-point-corrected integer logits for one ROI of an
/// interleaved `i8` prototype buffer, using the baseline-NEON dot product.
///
/// For the ROI pixel at local position `(ly, lx)` the value written to
/// `logits[ly * roi_w + lx]` is
/// `dot(coeff, proto_px) - zp_c * proto_sums[pixel] - bias`, where the
/// `zp_c` term is skipped (and `proto_sums` not read) when `zp_c == 0`.
///
/// * `protos` - indexed as `py * stride_y + px * num_protos + channel`.
/// * `proto_sums` - per-pixel channel sums, indexed as `py * proto_w + px`.
/// * `proto_x0`, `proto_y0` - ROI origin; `roi_w`, `roi_h` - ROI size.
///
/// # Panics
/// Panics if `coeff` holds fewer than `num_protos` values, or if an ROI
/// pixel falls outside `protos`, `proto_sums` or `logits`.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn compute_logits_base(
    logits: &mut [i32],
    coeff: &[i8],
    protos: &[i8],
    proto_sums: &[i32],
    proto_w: usize,
    proto_x0: usize,
    proto_y0: usize,
    roi_w: usize,
    roi_h: usize,
    stride_y: usize,
    num_protos: usize,
    zp_c: i32,
    bias: i32,
) {
    // The kernel reads `num_protos` coefficients through a raw pointer, so
    // the length is checked here; without it a short `coeff` would be an
    // out-of-bounds read from a safe function.
    let coeff = &coeff[..num_protos];
    for ly_idx in 0..roi_h {
        let py = proto_y0 + ly_idx;
        let row_base = py * stride_y + proto_x0 * num_protos;
        for lx_idx in 0..roi_w {
            let pix_base = row_base + lx_idx * num_protos;
            let proto_px = &protos[pix_base..pix_base + num_protos];
            // SAFETY: `coeff` and `proto_px` are both exactly `num_protos`
            // elements long (checked slices above), which is the only
            // requirement of the baseline NEON kernel.
            let raw_dot =
                unsafe { dot_i8_neon_base(coeff.as_ptr(), proto_px.as_ptr(), num_protos) };
            // Coefficient zero-point term; `proto_sums` may be empty when
            // `zp_c == 0`, so it is only indexed when needed.
            let correction = if zp_c != 0 {
                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
            } else {
                0
            };
            logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
        }
    }
}
2008
/// Produces one output-resolution binary mask per detection from `i8`
/// coefficients and `i8` prototypes, entirely in integer arithmetic.
///
/// Per detection the function:
/// 1. maps the bounding box through the letterbox into normalised output
///    coordinates and then into output pixels (`px0..px1`, `py0..py1`);
/// 2. finds the prototype-space ROI covering the sample positions of those
///    pixels;
/// 3. computes zero-point-corrected integer logits for every ROI pixel;
/// 4. bilinearly interpolates the logits at each output pixel using
///    10-bit fixed-point weights and writes `255` where the interpolated
///    logit is positive, `0` otherwise.
///
/// Detections are processed in parallel (`par_iter`).
///
/// * `coeff_all` - `num_protos` coefficients per detection, concatenated.
/// * `protos` - prototype tensor in the layout given by `layout`.
/// * `letterbox` - `Some([lx0, ly0, lx1, ly1])`, or `None` for the identity
///   mapping `(0, 0)-(1, 1)`.
/// * `width`, `height` - output size in pixels.
///
/// Each returned `Segmentation` carries the letterbox-corrected box clamped
/// to `[0, 1]` and a tile of shape `(bbox_h, bbox_w, 1)`.
///
/// # Errors
/// Returns `Error::NotSupported` when either quantization is anything other
/// than `PerTensor` or `PerTensorSymmetric`.
///
/// # Panics
/// Panics if `coeff_all` is shorter than `detect.len() * num_protos` or if
/// an ROI pixel falls outside `protos`.
#[allow(clippy::too_many_arguments)]
fn scaled_segmentations_i8_i8(
    detect: &[crate::DetectBox],
    coeff_all: &[i8],
    coeff_quant: &edgefirst_tensor::Quantization,
    protos: &[i8],
    proto_quant: &edgefirst_tensor::Quantization,
    proto_h: usize,
    proto_w: usize,
    num_protos: usize,
    letterbox: Option<[f32; 4]>,
    width: u32,
    height: u32,
    layout: edgefirst_decoder::ProtoLayout,
) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
    use edgefirst_tensor::QuantMode;

    let _span = tracing::trace_span!(
        "image.materialize_masks.kernel_i8_scaled",
        n = detect.len(),
        proto_h,
        proto_w,
        num_protos,
        width,
        height,
        ?layout,
    )
    .entered();

    // Only the zero points are needed: the scales are not read here, and
    // only the sign of the logit decides the mask.
    // NOTE(review): this presumes both scales are positive — confirm.
    let zp_c: i32 = match coeff_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel coeff quantization not supported".into(),
            ))
        }
    };
    let zp_p: i32 = match proto_quant.mode() {
        QuantMode::PerTensor { zero_point, .. } => zero_point,
        QuantMode::PerTensorSymmetric { .. } => 0,
        _ => {
            return Err(crate::Error::NotSupported(
                "per-channel proto quantization not supported".into(),
            ))
        }
    };

    // Letterbox origin and extent; the extent is floored at EPSILON so the
    // divisions below never divide by zero.
    let (lx0, lw, ly0, lh) = match letterbox {
        Some([lx0, ly0, lx1, ly1]) => {
            let lw = (lx1 - lx0).max(f32::EPSILON);
            let lh = (ly1 - ly0).max(f32::EPSILON);
            (lx0, lw, ly0, lh)
        }
        None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
    };
    let out_w = width as usize;
    let out_h = height as usize;
    let hw = proto_h * proto_w;

    // Per-pixel sum of prototype channels, needed only for the coefficient
    // zero-point term; left empty (and never indexed) when zp_c == 0.
    let proto_sums: Vec<i32> = if zp_c != 0 {
        match layout {
            edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
                .map(|px_idx| {
                    let base = px_idx * num_protos;
                    let mut s: i32 = 0;
                    for k in 0..num_protos {
                        s += protos[base + k] as i32;
                    }
                    s
                })
                .collect(),
            edgefirst_decoder::ProtoLayout::Nchw => {
                let mut sums = vec![0i32; hw];
                for c in 0..num_protos {
                    let plane = &protos[c * hw..];
                    for (px, s) in sums.iter_mut().enumerate() {
                        *s += plane[px] as i32;
                    }
                }
                sums
            }
        }
    } else {
        Vec::new()
    };

    // Runtime CPU probe, done once and shared by all detections.
    #[cfg(target_arch = "aarch64")]
    let use_dotprod = std::arch::is_aarch64_feature_detected!("dotprod");

    // Row stride of the interleaved (NHWC) prototype buffer.
    let stride_y = proto_w * num_protos;

    detect
        .par_iter()
        .enumerate()
        .map(|(i, det)| {
            let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
            // Undo the letterbox and clamp the box to the visible image.
            let bbox = det.bbox.to_canonical();
            let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
            let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
            let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
            let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
            let px0 = (xmin * out_w as f32).round() as usize;
            let py0 = (ymin * out_h as f32).round() as usize;
            let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
            let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
            // Degenerate boxes still yield a 1x1 tile.
            let bbox_w = px1.saturating_sub(px0).max(1);
            let bbox_h = py1.saturating_sub(py0).max(1);

            // Output pixel centre -> prototype-space coordinate (pixel
            // centres at +0.5 on both grids, letterbox re-applied).
            let sample_x_at = |px: f32| -> f32 {
                let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
                model_x_norm * proto_w as f32 - 0.5
            };
            let sample_y_at = |py: f32| -> f32 {
                let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
                model_y_norm * proto_h as f32 - 0.5
            };
            // Prototype ROI covering every tap the box can touch, clamped
            // to the prototype grid and at least 1x1.
            let s_x_min = sample_x_at(px0 as f32);
            let s_x_max = sample_x_at((px1 as f32) - 1.0);
            let s_y_min = sample_y_at(py0 as f32);
            let s_y_max = sample_y_at((py1 as f32) - 1.0);
            let proto_x0 = (s_x_min.floor() as isize)
                .max(0)
                .min(proto_w.saturating_sub(1) as isize) as usize;
            let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
            let proto_y0 = (s_y_min.floor() as isize)
                .max(0)
                .min(proto_h.saturating_sub(1) as isize) as usize;
            let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
            let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
            let roi_h = proto_y1.saturating_sub(proto_y0).max(1);

            // sum_k (c_k - zp_c)(p_k - zp_p)
            //   = dot(c, p) - zp_c * sum(p) - zp_p * sum(c) + N * zp_c * zp_p
            // `bias` collects the two terms that do not depend on the pixel.
            let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
            let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;

            let mut logits = vec![0_i32; roi_h * roi_w];
            match layout {
                edgefirst_decoder::ProtoLayout::Nhwc => {
                    #[cfg(target_arch = "aarch64")]
                    {
                        if use_dotprod {
                            compute_logits_dotprod(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        } else {
                            compute_logits_base(
                                &mut logits,
                                coeff,
                                protos,
                                &proto_sums,
                                proto_w,
                                proto_x0,
                                proto_y0,
                                roi_w,
                                roi_h,
                                stride_y,
                                num_protos,
                                zp_c,
                                bias,
                            );
                        }
                    }
                    #[cfg(not(target_arch = "aarch64"))]
                    {
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_base = py * stride_y + proto_x0 * num_protos;
                            for lx_idx in 0..roi_w {
                                let pix_base = row_base + lx_idx * num_protos;
                                let proto_px = &protos[pix_base..pix_base + num_protos];
                                let raw_dot = dot_i8_scalar(coeff, proto_px, num_protos);
                                let correction = if zp_c != 0 {
                                    zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                                } else {
                                    0
                                };
                                logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
                            }
                        }
                    }
                }
                edgefirst_decoder::ProtoLayout::Nchw => {
                    // Planar layout: accumulate one channel plane at a time
                    // so the inner loop walks contiguous memory.
                    for c in 0..num_protos {
                        let plane = &protos[c * hw..];
                        let coeff_c = coeff[c] as i32;
                        for ly_idx in 0..roi_h {
                            let py = proto_y0 + ly_idx;
                            let row_start = py * proto_w + proto_x0;
                            let out_row_start = ly_idx * roi_w;
                            for lx_idx in 0..roi_w {
                                logits[out_row_start + lx_idx] +=
                                    coeff_c * plane[row_start + lx_idx] as i32;
                            }
                        }
                    }
                    // Second pass applies the zero-point terms.
                    for ly_idx in 0..roi_h {
                        let py = proto_y0 + ly_idx;
                        for lx_idx in 0..roi_w {
                            let idx = ly_idx * roi_w + lx_idx;
                            let correction = if zp_c != 0 {
                                zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
                            } else {
                                0
                            };
                            logits[idx] -= correction + bias;
                        }
                    }
                }
            }

            let roi_last_x = roi_w.saturating_sub(1);
            let roi_last_y = roi_h.saturating_sub(1);

            // Interpolation weights are 10-bit fixed point (0..=1024).
            const FRAC_BITS: i32 = 10;
            const FRAC_SCALE: i32 = 1 << FRAC_BITS;

            // Horizontal taps depend only on the column, so they are
            // computed once per detection and reused for every row.
            // NOTE(review): when `sample_x` is negative (box touching the
            // left border) `x_lo` is clamped to 0 but `x_hi` becomes 1 and
            // `x_frac` is the fraction of the unclamped coordinate, so the
            // border value is blended rather than replicated. The same holds
            // for `sample_y` below. Confirm this matches the reference path.
            let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
                .map(|xi| {
                    let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
                    let x_floor = sample_x.floor();
                    let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
                    let x_hi = (x_lo + 1).min(roi_w - 1);
                    let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                    (x_lo, x_hi, x_frac)
                })
                .collect();

            let mut tile_buf = vec![0u8; bbox_h * bbox_w];
            for yi in 0..bbox_h {
                let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
                let y_floor = sample_y.floor();
                let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
                let y_hi = (y_lo + 1).min(roi_h - 1);
                let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
                let y_frac_inv = FRAC_SCALE - y_frac;
                let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
                let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
                let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];

                for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
                    let tl = row_lo[x_lo];
                    let tr = row_lo[x_hi];
                    let bl = row_hi[x_lo];
                    let br = row_hi[x_hi];

                    // The AND of four i32 values is negative only when all
                    // four sign bits are set: every tap is negative, so the
                    // blend is negative and the pixel stays 0.
                    if (tl & tr & bl & br) < 0 {
                        continue;
                    }
                    // All taps positive: the blend is positive, no math needed.
                    if tl > 0 && tr > 0 && bl > 0 && br > 0 {
                        out_row[xi] = 255;
                        continue;
                    }

                    // Mixed signs: evaluate the bilinear blend in i64, which
                    // leaves room for an i32 logit times two 10-bit weights.
                    let x_frac_inv = FRAC_SCALE - x_frac;
                    let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
                    let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
                    let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
                    out_row[xi] = if logit > 0 { 255 } else { 0 };
                }
            }

            let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
                .expect("tile_buf length matches bbox_h * bbox_w");
            Ok(edgefirst_decoder::Segmentation {
                xmin,
                ymin,
                xmax,
                ymax,
                segmentation: tile,
            })
        })
        .collect()
}
2310
2311#[allow(clippy::too_many_arguments)]
2312fn scaled_segmentations_i16_i8(
2313 detect: &[crate::DetectBox],
2314 coeff_all: &[i16],
2315 coeff_quant: &edgefirst_tensor::Quantization,
2316 protos: &[i8],
2317 proto_quant: &edgefirst_tensor::Quantization,
2318 proto_h: usize,
2319 proto_w: usize,
2320 num_protos: usize,
2321 letterbox: Option<[f32; 4]>,
2322 width: u32,
2323 height: u32,
2324 layout: edgefirst_decoder::ProtoLayout,
2325) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2326 use edgefirst_tensor::QuantMode;
2327
2328 let _span = tracing::trace_span!(
2329 "image.materialize_masks.kernel_i16xi8_scaled",
2330 n = detect.len(),
2331 proto_h,
2332 proto_w,
2333 num_protos,
2334 width,
2335 height,
2336 ?layout,
2337 )
2338 .entered();
2339
2340 let zp_c: i32 = match coeff_quant.mode() {
2341 QuantMode::PerTensor { zero_point, .. } => zero_point,
2342 QuantMode::PerTensorSymmetric { .. } => 0,
2343 _ => {
2344 return Err(crate::Error::NotSupported(
2345 "per-channel coeff quantization not supported".into(),
2346 ))
2347 }
2348 };
2349 let zp_p: i32 = match proto_quant.mode() {
2350 QuantMode::PerTensor { zero_point, .. } => zero_point,
2351 QuantMode::PerTensorSymmetric { .. } => 0,
2352 _ => {
2353 return Err(crate::Error::NotSupported(
2354 "per-channel proto quantization not supported".into(),
2355 ))
2356 }
2357 };
2358
2359 let (lx0, lw, ly0, lh) = match letterbox {
2360 Some([lx0, ly0, lx1, ly1]) => {
2361 let lw = (lx1 - lx0).max(f32::EPSILON);
2362 let lh = (ly1 - ly0).max(f32::EPSILON);
2363 (lx0, lw, ly0, lh)
2364 }
2365 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2366 };
2367 let out_w = width as usize;
2368 let out_h = height as usize;
2369 let hw = proto_h * proto_w;
2370
2371 let proto_sums: Vec<i32> = if zp_c != 0 {
2373 match layout {
2374 edgefirst_decoder::ProtoLayout::Nhwc => (0..hw)
2375 .map(|px_idx| {
2376 let base = px_idx * num_protos;
2377 let mut s: i32 = 0;
2378 for k in 0..num_protos {
2379 s += protos[base + k] as i32;
2380 }
2381 s
2382 })
2383 .collect(),
2384 edgefirst_decoder::ProtoLayout::Nchw => {
2385 let mut sums = vec![0i32; hw];
2386 for c in 0..num_protos {
2387 let plane = &protos[c * hw..];
2388 for (px, s) in sums.iter_mut().enumerate() {
2389 *s += plane[px] as i32;
2390 }
2391 }
2392 sums
2393 }
2394 }
2395 } else {
2396 Vec::new()
2397 };
2398
2399 let stride_y = proto_w * num_protos;
2401
2402 detect
2403 .par_iter()
2404 .enumerate()
2405 .map(|(i, det)| {
2406 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2407 let bbox = det.bbox.to_canonical();
2408 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2409 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2410 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2411 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2412 let px0 = (xmin * out_w as f32).round() as usize;
2413 let py0 = (ymin * out_h as f32).round() as usize;
2414 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2415 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2416 let bbox_w = px1.saturating_sub(px0).max(1);
2417 let bbox_h = py1.saturating_sub(py0).max(1);
2418
2419 let sample_x_at = |px: f32| -> f32 {
2421 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2422 model_x_norm * proto_w as f32 - 0.5
2423 };
2424 let sample_y_at = |py: f32| -> f32 {
2425 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2426 model_y_norm * proto_h as f32 - 0.5
2427 };
2428 let s_x_min = sample_x_at(px0 as f32);
2429 let s_x_max = sample_x_at((px1 as f32) - 1.0);
2430 let s_y_min = sample_y_at(py0 as f32);
2431 let s_y_max = sample_y_at((py1 as f32) - 1.0);
2432 let proto_x0 = (s_x_min.floor() as isize)
2433 .max(0)
2434 .min(proto_w.saturating_sub(1) as isize) as usize;
2435 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2436 let proto_y0 = (s_y_min.floor() as isize)
2437 .max(0)
2438 .min(proto_h.saturating_sub(1) as isize) as usize;
2439 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2440 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2441 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2442
2443 let coeff_sum: i32 = coeff.iter().map(|&c| c as i32).sum();
2445 let bias = zp_p * coeff_sum - (num_protos as i32) * zp_c * zp_p;
2446
2447 let mut logits = vec![0_i32; roi_h * roi_w];
2449 match layout {
2450 edgefirst_decoder::ProtoLayout::Nhwc => {
2451 #[cfg(target_arch = "aarch64")]
2452 {
2453 for ly_idx in 0..roi_h {
2454 let py = proto_y0 + ly_idx;
2455 let row_base = py * stride_y + proto_x0 * num_protos;
2456 for lx_idx in 0..roi_w {
2457 let pix_base = row_base + lx_idx * num_protos;
2458 let proto_px = &protos[pix_base..pix_base + num_protos];
2459 let raw_dot = unsafe {
2460 dot_i16_i8_neon(coeff.as_ptr(), proto_px.as_ptr(), num_protos)
2461 };
2462 let correction = if zp_c != 0 {
2463 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2464 } else {
2465 0
2466 };
2467 logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2468 }
2469 }
2470 }
2471 #[cfg(not(target_arch = "aarch64"))]
2472 {
2473 for ly_idx in 0..roi_h {
2474 let py = proto_y0 + ly_idx;
2475 let row_base = py * stride_y + proto_x0 * num_protos;
2476 for lx_idx in 0..roi_w {
2477 let pix_base = row_base + lx_idx * num_protos;
2478 let proto_px = &protos[pix_base..pix_base + num_protos];
2479 let raw_dot = dot_i16_i8_scalar(coeff, proto_px, num_protos);
2480 let correction = if zp_c != 0 {
2481 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2482 } else {
2483 0
2484 };
2485 logits[ly_idx * roi_w + lx_idx] = raw_dot - correction - bias;
2486 }
2487 }
2488 }
2489 }
2490 edgefirst_decoder::ProtoLayout::Nchw => {
2491 for c in 0..num_protos {
2493 let plane = &protos[c * hw..];
2494 let coeff_c = coeff[c] as i32;
2495 for ly_idx in 0..roi_h {
2496 let py = proto_y0 + ly_idx;
2497 let row_start = py * proto_w + proto_x0;
2498 let out_row_start = ly_idx * roi_w;
2499 for lx_idx in 0..roi_w {
2500 logits[out_row_start + lx_idx] +=
2501 coeff_c * plane[row_start + lx_idx] as i32;
2502 }
2503 }
2504 }
2505 for ly_idx in 0..roi_h {
2507 let py = proto_y0 + ly_idx;
2508 for lx_idx in 0..roi_w {
2509 let idx = ly_idx * roi_w + lx_idx;
2510 let correction = if zp_c != 0 {
2511 zp_c * proto_sums[py * proto_w + proto_x0 + lx_idx]
2512 } else {
2513 0
2514 };
2515 logits[idx] -= correction + bias;
2516 }
2517 }
2518 }
2519 }
2520
2521 let roi_last_x = roi_w.saturating_sub(1);
2524 let roi_last_y = roi_h.saturating_sub(1);
2525
2526 const FRAC_BITS: i32 = 10;
2528 const FRAC_SCALE: i32 = 1 << FRAC_BITS; let x_coords: Vec<(usize, usize, i32)> = (0..bbox_w)
2530 .map(|xi| {
2531 let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2532 let x_floor = sample_x.floor();
2533 let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as usize;
2534 let x_hi = (x_lo + 1).min(roi_w - 1);
2535 let x_frac = ((sample_x - x_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2536 (x_lo, x_hi, x_frac)
2537 })
2538 .collect();
2539
2540 let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2541 for yi in 0..bbox_h {
2542 let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2543 let y_floor = sample_y.floor();
2544 let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2545 let y_hi = (y_lo + 1).min(roi_h - 1);
2546 let y_frac = ((sample_y - y_floor).clamp(0.0, 1.0) * FRAC_SCALE as f32) as i32;
2547 let y_frac_inv = FRAC_SCALE - y_frac;
2548 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2549 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2550 let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2551
2552 for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2553 let tl = row_lo[x_lo];
2554 let tr = row_lo[x_hi];
2555 let bl = row_hi[x_lo];
2556 let br = row_hi[x_hi];
2557
2558 if (tl & tr & bl & br) < 0 {
2562 continue;
2564 }
2565 if tl > 0 && tr > 0 && bl > 0 && br > 0 {
2566 out_row[xi] = 255;
2568 continue;
2569 }
2570
2571 let x_frac_inv = FRAC_SCALE - x_frac;
2573 let l0 = tl as i64 * x_frac_inv as i64 + tr as i64 * x_frac as i64;
2574 let l1 = bl as i64 * x_frac_inv as i64 + br as i64 * x_frac as i64;
2575 let logit = l0 * y_frac_inv as i64 + l1 * y_frac as i64;
2576 out_row[xi] = if logit > 0 { 255 } else { 0 };
2577 }
2578 }
2579
2580 let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2581 .expect("tile_buf length matches bbox_h * bbox_w");
2582 Ok(edgefirst_decoder::Segmentation {
2583 xmin,
2584 ymin,
2585 xmax,
2586 ymax,
2587 segmentation: tile,
2588 })
2589 })
2590 .collect()
2591}
2592
2593#[allow(clippy::too_many_arguments)]
2594fn scaled_run<P: Copy + Sync>(
2595 detect: &[crate::DetectBox],
2596 coeff_all: &[f32],
2597 protos: &[P],
2598 proto_h: usize,
2599 proto_w: usize,
2600 num_protos: usize,
2601 letterbox: Option<[f32; 4]>,
2602 width: u32,
2603 height: u32,
2604 acc_scale: f32,
2605 load_f32: impl Fn(&P, f32) -> f32 + Copy + Sync,
2606) -> crate::Result<Vec<edgefirst_decoder::Segmentation>> {
2607 let (lx0, lw, ly0, lh) = match letterbox {
2608 Some([lx0, ly0, lx1, ly1]) => {
2609 let lw = (lx1 - lx0).max(f32::EPSILON);
2610 let lh = (ly1 - ly0).max(f32::EPSILON);
2611 (lx0, lw, ly0, lh)
2612 }
2613 None => (0.0_f32, 1.0_f32, 0.0_f32, 1.0_f32),
2614 };
2615 let out_w = width as usize;
2616 let out_h = height as usize;
2617 let stride_y = proto_w * num_protos;
2618
2619 detect
2641 .par_iter()
2642 .enumerate()
2643 .map(|(i, det)| {
2644 let coeff = &coeff_all[i * num_protos..(i + 1) * num_protos];
2645 let bbox = det.bbox.to_canonical();
2646 let xmin = ((bbox.xmin - lx0) / lw).clamp(0.0, 1.0);
2647 let ymin = ((bbox.ymin - ly0) / lh).clamp(0.0, 1.0);
2648 let xmax = ((bbox.xmax - lx0) / lw).clamp(0.0, 1.0);
2649 let ymax = ((bbox.ymax - ly0) / lh).clamp(0.0, 1.0);
2650 let px0 = (xmin * out_w as f32).round() as usize;
2651 let py0 = (ymin * out_h as f32).round() as usize;
2652 let px1 = ((xmax * out_w as f32).round() as usize).min(out_w);
2653 let py1 = ((ymax * out_h as f32).round() as usize).min(out_h);
2654 let bbox_w = px1.saturating_sub(px0).max(1);
2655 let bbox_h = py1.saturating_sub(py0).max(1);
2656
2657 let sample_x_at = |px: f32| -> f32 {
2662 let model_x_norm = lx0 + (px + 0.5) / out_w as f32 * lw;
2663 model_x_norm * proto_w as f32 - 0.5
2664 };
2665 let sample_y_at = |py: f32| -> f32 {
2666 let model_y_norm = ly0 + (py + 0.5) / out_h as f32 * lh;
2667 model_y_norm * proto_h as f32 - 0.5
2668 };
2669 let s_x_min = sample_x_at(px0 as f32);
2670 let s_x_max = sample_x_at((px1 as f32) - 1.0);
2671 let s_y_min = sample_y_at(py0 as f32);
2672 let s_y_max = sample_y_at((py1 as f32) - 1.0);
2673 let proto_x0 = (s_x_min.floor() as isize)
2677 .max(0)
2678 .min(proto_w.saturating_sub(1) as isize) as usize;
2679 let proto_x1 = ((s_x_max.ceil() as isize) + 1).max(0).min(proto_w as isize) as usize;
2680 let proto_y0 = (s_y_min.floor() as isize)
2681 .max(0)
2682 .min(proto_h.saturating_sub(1) as isize) as usize;
2683 let proto_y1 = ((s_y_max.ceil() as isize) + 1).max(0).min(proto_h as isize) as usize;
2684 let roi_w = proto_x1.saturating_sub(proto_x0).max(1);
2685 let roi_h = proto_y1.saturating_sub(proto_y0).max(1);
2686
2687 if !acc_scale.is_finite() || acc_scale <= 0.0 {
2696 return Err(crate::Error::NotSupported(format!(
2697 "acc_scale must be finite and positive for sign-threshold optimization (got {acc_scale})"
2698 )));
2699 }
2700 let _ = acc_scale; let mut logits = vec![0.0_f32; roi_h * roi_w];
2702 for ly_idx in 0..roi_h {
2703 let py = proto_y0 + ly_idx;
2704 let row_base = py * stride_y + proto_x0 * num_protos;
2705 for lx_idx in 0..roi_w {
2706 let pix_base = row_base + lx_idx * num_protos;
2707 let mut acc = 0.0_f32;
2708 let mut k = 0;
2710 let chunks = num_protos / 4;
2711 for _ in 0..chunks {
2712 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0)
2713 + coeff[k + 1] * load_f32(&protos[pix_base + k + 1], 0.0)
2714 + coeff[k + 2] * load_f32(&protos[pix_base + k + 2], 0.0)
2715 + coeff[k + 3] * load_f32(&protos[pix_base + k + 3], 0.0);
2716 k += 4;
2717 }
2718 while k < num_protos {
2719 acc += coeff[k] * load_f32(&protos[pix_base + k], 0.0);
2720 k += 1;
2721 }
2722 logits[ly_idx * roi_w + lx_idx] = acc;
2723 }
2724 }
2725
2726 let roi_last_x = roi_w.saturating_sub(1);
2737 let roi_last_y = roi_h.saturating_sub(1);
2738
2739 let x_coords: Vec<(u32, u32, f32)> = (0..bbox_w)
2741 .map(|xi| {
2742 let sample_x = sample_x_at((px0 + xi) as f32) - proto_x0 as f32;
2743 let x_floor = sample_x.floor();
2744 let x_lo = (x_floor as isize).max(0).min(roi_last_x as isize) as u32;
2745 let x_hi = (x_lo as usize + 1).min(roi_w - 1) as u32;
2746 let x_frac = (sample_x - x_floor).clamp(0.0, 1.0);
2747 (x_lo, x_hi, x_frac)
2748 })
2749 .collect();
2750
2751 let mut tile_buf = vec![0u8; bbox_h * bbox_w];
2754 for yi in 0..bbox_h {
2755 let sample_y = sample_y_at((py0 + yi) as f32) - proto_y0 as f32;
2756 let y_floor = sample_y.floor();
2757 let y_lo = (y_floor as isize).max(0).min(roi_last_y as isize) as usize;
2758 let y_hi = (y_lo + 1).min(roi_h - 1);
2759 let y_frac = (sample_y - y_floor).clamp(0.0, 1.0);
2760 let row_lo = &logits[y_lo * roi_w..y_lo * roi_w + roi_w];
2761 let row_hi = &logits[y_hi * roi_w..y_hi * roi_w + roi_w];
2762 let out_row = &mut tile_buf[yi * bbox_w..(yi + 1) * bbox_w];
2763 for (xi, &(x_lo, x_hi, x_frac)) in x_coords.iter().enumerate() {
2764 let (xl, xh) = (x_lo as usize, x_hi as usize);
2765 let l0 = row_lo[xl] + (row_lo[xh] - row_lo[xl]) * x_frac;
2766 let l1 = row_hi[xl] + (row_hi[xh] - row_hi[xl]) * x_frac;
2767 let logit = l0 + (l1 - l0) * y_frac;
2768 out_row[xi] = if logit > 0.0 { 255 } else { 0 };
2769 }
2770 }
2771 let tile = ndarray::Array3::from_shape_vec((bbox_h, bbox_w, 1), tile_buf)
2773 .expect("tile_buf length matches bbox_h * bbox_w");
2774 Ok(edgefirst_decoder::Segmentation {
2775 xmin,
2776 ymin,
2777 xmax,
2778 ymax,
2779 segmentation: tile,
2780 })
2781 })
2782 .collect()
2783}
2784
#[cfg(test)]
mod tests {
    //! Tests for mask materialisation from quantized and float prototypes.
    //!
    //! All tests share one small synthetic fixture (a 4x4 prototype grid with
    //! 8 channels, NHWC layout); the helpers below build it so each test only
    //! states what differs.

    use super::CPUProcessor;
    use edgefirst_decoder::{BoundingBox, DetectBox, ProtoData, ProtoLayout, Segmentation};
    use edgefirst_tensor::{Quantization, Tensor, TensorDyn};

    const PROTO_H: usize = 4;
    const PROTO_W: usize = 4;
    const NUM_PROTOS: usize = 8;
    /// Shape of the prototype tensor used by every fixture.
    const PROTO_SHAPE: [usize; 3] = [PROTO_H, PROTO_W, NUM_PROTOS];
    /// Quantization scale of the coefficient fixtures.
    const COEFF_SCALE: f32 = 0.01;
    /// Quantization scale of the prototype fixtures.
    const PROTO_SCALE: f32 = 0.02;
    /// Output width and height requested from the scaled entry point.
    const OUT_SIZE: u32 = 64;

    /// Builds a detection with the given box, score 0.9 and label 0.
    fn det(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> DetectBox {
        DetectBox {
            bbox: BoundingBox {
                xmin,
                ymin,
                xmax,
                ymax,
            },
            score: 0.9,
            label: 0,
        }
    }

    /// Wraps `data` as an `i8` tensor with per-tensor quantization.
    fn make_i8_quant(shape: &[usize], data: &[i8], scale: f32, zp: i32) -> TensorDyn {
        let tensor = Tensor::<i8>::from_slice(data, shape)
            .unwrap()
            .with_quantization(Quantization::per_tensor(scale, zp))
            .unwrap();
        TensorDyn::I8(tensor)
    }

    /// Wraps `data` as an `i16` tensor with per-tensor quantization.
    fn make_i16_quant(shape: &[usize], data: &[i16], scale: f32, zp: i32) -> TensorDyn {
        let tensor = Tensor::<i16>::from_slice(data, shape)
            .unwrap()
            .with_quantization(Quantization::per_tensor(scale, zp))
            .unwrap();
        TensorDyn::I16(tensor)
    }

    /// Wraps `data` as an `i16` tensor without any quantization metadata.
    fn make_i16_raw(shape: &[usize], data: &[i16]) -> TensorDyn {
        TensorDyn::I16(Tensor::<i16>::from_slice(data, shape).unwrap())
    }

    /// Wraps `data` as an `f32` tensor.
    fn make_f32(shape: &[usize], data: &[f32]) -> TensorDyn {
        TensorDyn::F32(Tensor::<f32>::from_slice(data, shape).unwrap())
    }

    /// Deterministic prototype values cycling through `0..127`.
    fn gen_protos_i8(h: usize, w: usize, k: usize) -> Vec<i8> {
        (0..h * w * k).map(|i| (i % 127) as i8).collect()
    }

    /// Deterministic `i16` coefficients cycling through `-100..=100`.
    fn gen_coeffs_i16(n: usize, k: usize) -> Vec<i16> {
        (0..n * k)
            .map(|i| ((i as i32 % 201) - 100) as i16)
            .collect()
    }

    /// Deterministic `i8` coefficients cycling through `-100..=100`.
    fn gen_coeffs_i8(n: usize, k: usize) -> Vec<i8> {
        (0..n * k).map(|i| ((i as i32 % 201) - 100) as i8).collect()
    }

    /// Quantized `i8` prototype fixture with the given zero point.
    fn protos_i8(zp: i32) -> TensorDyn {
        make_i8_quant(
            &PROTO_SHAPE,
            &gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS),
            PROTO_SCALE,
            zp,
        )
    }

    /// Quantized `i16` coefficient fixture for `n_det` detections.
    fn coeffs_i16(n_det: usize, zp: i32) -> TensorDyn {
        make_i16_quant(
            &[n_det, NUM_PROTOS],
            &gen_coeffs_i16(n_det, NUM_PROTOS),
            COEFF_SCALE,
            zp,
        )
    }

    /// `i16` coefficient fixture that carries no quantization metadata.
    fn coeffs_i16_raw(n_det: usize) -> TensorDyn {
        make_i16_raw(&[n_det, NUM_PROTOS], &gen_coeffs_i16(n_det, NUM_PROTOS))
    }

    /// Bundles coefficients and prototypes into NHWC `ProtoData`.
    fn nhwc(mask_coefficients: TensorDyn, protos: TensorDyn) -> ProtoData {
        ProtoData {
            mask_coefficients,
            protos,
            layout: ProtoLayout::Nhwc,
        }
    }

    /// Float reference: the integer fixtures dequantized with zero point 0.
    fn f32_reference(n_det: usize) -> ProtoData {
        let protos: Vec<f32> = gen_protos_i8(PROTO_H, PROTO_W, NUM_PROTOS)
            .iter()
            .map(|&v| v as f32 * PROTO_SCALE)
            .collect();
        let coeffs: Vec<f32> = gen_coeffs_i16(n_det, NUM_PROTOS)
            .iter()
            .map(|&v| v as f32 * COEFF_SCALE)
            .collect();
        nhwc(
            make_f32(&[n_det, NUM_PROTOS], &coeffs),
            make_f32(&PROTO_SHAPE, &protos),
        )
    }

    /// Asserts that exactly one mask was produced and that it is non-empty.
    fn assert_single_nonempty_mask(segs: &[Segmentation]) {
        assert_eq!(segs.len(), 1);
        let shape = segs[0].segmentation.shape();
        assert!(shape[0] > 0);
        assert!(shape[1] > 0);
    }

    /// Asserts that two mask lists have equal shapes and differ in fewer than
    /// 5% of their pixels; `label` prefixes the failure message.
    fn assert_masks_match(reference: &[Segmentation], candidate: &[Segmentation], label: &str) {
        assert_eq!(reference.len(), candidate.len());
        for (sf, si) in reference.iter().zip(candidate) {
            assert_eq!(sf.segmentation.shape(), si.segmentation.shape());
            let total = sf.segmentation.len();
            let mismatches = sf
                .segmentation
                .iter()
                .zip(si.segmentation.iter())
                .filter(|(a, b)| a != b)
                .count();
            let pct = mismatches as f64 / total as f64 * 100.0;
            assert!(
                pct < 5.0,
                "{label} mismatch {mismatches}/{total} ({pct:.1}%) exceeds 5% threshold"
            );
        }
    }

    /// Quantized i16 coefficients with i8 prototypes yield one non-empty mask.
    #[test]
    fn materialize_proto_i16_i8_quant_produces_masks() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = nhwc(coeffs_i16(1, 0), protos_i8(0));
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "materialize failed: {:?}", result.err());
        assert_single_nonempty_mask(&result.unwrap());
    }

    /// i16 coefficients without quantization metadata must still succeed; per
    /// the assertion message this is expected to go through the f32 path.
    #[test]
    fn materialize_proto_i16_no_quant_falls_back_to_f32() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
        let proto_data = nhwc(coeffs_i16_raw(1), protos_i8(0));
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(
            result.is_ok(),
            "missing coeff quant should fall back to f32 path, got: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Scaled entry point: quantized i16 x i8 yields one non-empty mask.
    #[test]
    fn materialize_scaled_i16_i8_quant_produces_masks() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = nhwc(coeffs_i16(1, 0), protos_i8(0));
        let result =
            cpu.materialize_scaled_segmentations(&detect, &proto_data, None, OUT_SIZE, OUT_SIZE);
        assert!(
            result.is_ok(),
            "materialize_scaled failed: {:?}",
            result.err()
        );
        assert_single_nonempty_mask(&result.unwrap());
    }

    /// Scaled entry point: i16 coefficients without quantization metadata
    /// must still succeed.
    #[test]
    fn materialize_scaled_i16_no_quant_falls_back_to_f32() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.2, 0.2, 0.8, 0.8)];
        let proto_data = nhwc(coeffs_i16_raw(1), protos_i8(0));
        let result =
            cpu.materialize_scaled_segmentations(&detect, &proto_data, None, OUT_SIZE, OUT_SIZE);
        assert!(
            result.is_ok(),
            "missing coeff quant should fall back to f32 path, got: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Integer masks agree with the dequantized float reference within 5%.
    #[test]
    fn materialize_proto_i16_i8_matches_f32_reference() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9), det(0.3, 0.3, 0.7, 0.7)];
        let n_det = detect.len();
        let proto_data_f32 = f32_reference(n_det);
        let proto_data_int = nhwc(coeffs_i16(n_det, 0), protos_i8(0));

        let segs_f32 = cpu
            .materialize_segmentations(&detect, &proto_data_f32, None)
            .unwrap();
        let segs_int = cpu
            .materialize_segmentations(&detect, &proto_data_int, None)
            .unwrap();

        assert_masks_match(&segs_f32, &segs_int, "mask");
    }

    /// Three detections produce three masks.
    #[test]
    fn materialize_proto_i16_multiple_detections() {
        let cpu = CPUProcessor::new();
        let detect = vec![
            det(0.0, 0.0, 0.5, 0.5),
            det(0.5, 0.5, 1.0, 1.0),
            det(0.1, 0.1, 0.3, 0.3),
        ];
        let proto_data = nhwc(coeffs_i16(3, 0), protos_i8(0));
        let segs = cpu
            .materialize_segmentations(&detect, &proto_data, None)
            .unwrap();
        assert_eq!(segs.len(), 3);
    }

    /// An empty detection list yields an empty result rather than an error.
    #[test]
    fn materialize_proto_i16_empty_detections() {
        let cpu = CPUProcessor::new();
        let detect: Vec<DetectBox> = vec![];
        let proto_data = nhwc(coeffs_i16(0, 0), protos_i8(0));
        let segs = cpu
            .materialize_segmentations(&detect, &proto_data, None)
            .unwrap();
        assert!(segs.is_empty());
    }

    /// Scaled entry point: integer masks agree with the float reference
    /// within 5%.
    #[test]
    fn materialize_scaled_i16_i8_matches_f32_reference() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data_f32 = f32_reference(1);
        let proto_data_int = nhwc(coeffs_i16(1, 0), protos_i8(0));

        let segs_f32 = cpu
            .materialize_scaled_segmentations(&detect, &proto_data_f32, None, OUT_SIZE, OUT_SIZE)
            .unwrap();
        let segs_int = cpu
            .materialize_scaled_segmentations(&detect, &proto_data_int, None, OUT_SIZE, OUT_SIZE)
            .unwrap();

        assert_masks_match(&segs_f32, &segs_int, "scaled mask");
    }

    /// The i8 coefficient x i8 prototype combination keeps working.
    #[test]
    fn materialize_proto_i8_i8_regression() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let coeffs = make_i8_quant(
            &[1, NUM_PROTOS],
            &gen_coeffs_i8(1, NUM_PROTOS),
            COEFF_SCALE,
            0,
        );
        let proto_data = nhwc(coeffs, protos_i8(0));
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "i8×i8 regression: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Non-zero zero points on both tensors are accepted.
    #[test]
    fn materialize_proto_i16_nonzero_zp() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = nhwc(coeffs_i16(1, 5), protos_i8(-10));
        let result = cpu.materialize_segmentations(&detect, &proto_data, None);
        assert!(result.is_ok(), "nonzero zp failed: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    /// Scaled entry point: non-zero zero points on both tensors are accepted.
    #[test]
    fn materialize_scaled_i16_nonzero_zp() {
        let cpu = CPUProcessor::new();
        let detect = vec![det(0.1, 0.1, 0.9, 0.9)];
        let proto_data = nhwc(coeffs_i16(1, 5), protos_i8(-10));
        let result =
            cpu.materialize_scaled_segmentations(&detect, &proto_data, None, OUT_SIZE, OUT_SIZE);
        assert!(
            result.is_ok(),
            "scaled nonzero zp failed: {:?}",
            result.err()
        );
        assert_eq!(result.unwrap().len(), 1);
    }
}