1use super::anchor_nms::iou;
31use crate::{
32 error::{VisionError, VisionResult},
33 fpn::top_down::FeatureMap,
34 handle::LcgRng,
35};
36
/// Applies the SiLU activation, `v * sigmoid(v)`, to every element of `x` in place.
fn silu_inplace(x: &mut [f32]) {
    x.iter_mut().for_each(|v| {
        // Logistic gate computed from the value before it is overwritten.
        let gate = 1.0 / (1.0 + (-*v).exp());
        *v *= gate;
    });
}
45
/// Numerically stable softplus, `ln(1 + e^x)`.
///
/// Evaluated as `max(x, 0) + ln(1 + e^{-|x|})`, so the exponential's argument is
/// never positive and cannot overflow for large `x`.
#[inline]
fn softplus(x: f32) -> f32 {
    let positive_part = x.max(0.0);
    let correction = (-x.abs()).exp().ln_1p();
    positive_part + correction
}
51
52#[derive(Debug, Clone)]
56pub struct Conv2d {
57 weight: Vec<f32>,
58 bias: Vec<f32>,
59 c_in: usize,
60 c_out: usize,
61 k: usize,
62 stride: usize,
63 pad: usize,
64}
65
66impl Conv2d {
67 fn new(
69 c_in: usize,
70 c_out: usize,
71 k: usize,
72 stride: usize,
73 pad: usize,
74 rng: &mut LcgRng,
75 ) -> Self {
76 let fan_in = c_in * k * k;
77 let scale = (2.0 / fan_in as f32).sqrt();
78 let mut weight = vec![0.0f32; c_out * fan_in];
79 rng.fill_normal(&mut weight);
80 for w in &mut weight {
81 *w *= scale;
82 }
83 Self {
84 weight,
85 bias: vec![0.0f32; c_out],
86 c_in,
87 c_out,
88 k,
89 stride,
90 pad,
91 }
92 }
93
94 #[must_use]
96 #[inline]
97 pub fn out_channels(&self) -> usize {
98 self.c_out
99 }
100
101 pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
108 if x.channels != self.c_in {
109 return Err(VisionError::DimensionMismatch {
110 expected: self.c_in,
111 got: x.channels,
112 });
113 }
114 let (h, w) = (x.height, x.width);
115 if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
116 return Err(VisionError::InvalidImageSize {
117 height: h,
118 width: w,
119 channels: x.channels,
120 });
121 }
122 let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
123 let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
124 let mut out = vec![0.0f32; self.c_out * h_out * w_out];
125 let k = self.k;
126 for oc in 0..self.c_out {
127 let oc_w = oc * self.c_in * k * k;
128 for oh in 0..h_out {
129 for ow in 0..w_out {
130 let mut acc = self.bias[oc];
131 for ic in 0..self.c_in {
132 let in_base = ic * h * w;
133 let w_base = oc_w + ic * k * k;
134 for ki in 0..k {
135 let ih = oh * self.stride + ki;
136 if ih < self.pad || ih >= h + self.pad {
137 continue;
138 }
139 let ih = ih - self.pad;
140 for kj in 0..k {
141 let iw = ow * self.stride + kj;
142 if iw < self.pad || iw >= w + self.pad {
143 continue;
144 }
145 let iw = iw - self.pad;
146 acc += self.weight[w_base + ki * k + kj]
147 * out_in(x, in_base, ih, w, iw);
148 }
149 }
150 }
151 out[(oc * h_out + oh) * w_out + ow] = acc;
152 }
153 }
154 }
155 FeatureMap::new(out, self.c_out, h_out, w_out)
156 }
157}
158
159#[inline]
161fn out_in(x: &FeatureMap, in_base: usize, ih: usize, w: usize, iw: usize) -> f32 {
162 x.data[in_base + ih * w + iw]
163}
164
165#[derive(Debug, Clone)]
172pub struct DwConv2d {
173 weight: Vec<f32>,
174 bias: Vec<f32>,
175 c: usize,
176 k: usize,
177 stride: usize,
178 pad: usize,
179}
180
181impl DwConv2d {
182 fn new(c: usize, k: usize, stride: usize, pad: usize, rng: &mut LcgRng) -> Self {
183 let fan_in = k * k;
184 let scale = (2.0 / fan_in as f32).sqrt();
185 let mut weight = vec![0.0f32; c * fan_in];
186 rng.fill_normal(&mut weight);
187 for w in &mut weight {
188 *w *= scale;
189 }
190 Self {
191 weight,
192 bias: vec![0.0f32; c],
193 c,
194 k,
195 stride,
196 pad,
197 }
198 }
199
200 pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
206 if x.channels != self.c {
207 return Err(VisionError::DimensionMismatch {
208 expected: self.c,
209 got: x.channels,
210 });
211 }
212 let (h, w) = (x.height, x.width);
213 if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
214 return Err(VisionError::InvalidImageSize {
215 height: h,
216 width: w,
217 channels: x.channels,
218 });
219 }
220 let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
221 let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
222 let k = self.k;
223 let mut out = vec![0.0f32; self.c * h_out * w_out];
224 for ch in 0..self.c {
225 let in_base = ch * h * w;
226 let w_base = ch * k * k;
227 let bias = self.bias[ch];
228 for oh in 0..h_out {
229 for ow in 0..w_out {
230 let mut acc = bias;
231 for ki in 0..k {
232 let ih = oh * self.stride + ki;
233 if ih < self.pad || ih >= h + self.pad {
234 continue;
235 }
236 let ih = ih - self.pad;
237 for kj in 0..k {
238 let iw = ow * self.stride + kj;
239 if iw < self.pad || iw >= w + self.pad {
240 continue;
241 }
242 let iw = iw - self.pad;
243 acc +=
244 self.weight[w_base + ki * k + kj] * x.data[in_base + ih * w + iw];
245 }
246 }
247 out[(ch * h_out + oh) * w_out + ow] = acc;
248 }
249 }
250 }
251 FeatureMap::new(out, self.c, h_out, w_out)
252 }
253}
254
255fn concat_channels(a: &FeatureMap, b: &FeatureMap) -> VisionResult<FeatureMap> {
259 if a.height != b.height || a.width != b.width {
260 return Err(VisionError::ShapeMismatch {
261 lhs: vec![a.channels, a.height, a.width],
262 rhs: vec![b.channels, b.height, b.width],
263 });
264 }
265 let mut data = Vec::with_capacity(a.data.len() + b.data.len());
266 data.extend_from_slice(&a.data);
267 data.extend_from_slice(&b.data);
268 Ok(FeatureMap {
269 data,
270 channels: a.channels + b.channels,
271 height: a.height,
272 width: a.width,
273 })
274}
275
276fn add_inplace(dst: &mut FeatureMap, src: &FeatureMap) -> VisionResult<()> {
278 if dst.channels != src.channels || dst.height != src.height || dst.width != src.width {
279 return Err(VisionError::ShapeMismatch {
280 lhs: vec![dst.channels, dst.height, dst.width],
281 rhs: vec![src.channels, src.height, src.width],
282 });
283 }
284 for (a, b) in dst.data.iter_mut().zip(src.data.iter()) {
285 *a += *b;
286 }
287 Ok(())
288}
289
290fn upsample2x(x: &FeatureMap) -> FeatureMap {
292 let (c, h, w) = (x.channels, x.height, x.width);
293 let (h2, w2) = (h * 2, w * 2);
294 let mut out = vec![0.0f32; c * h2 * w2];
295 for ch in 0..c {
296 for i in 0..h {
297 for j in 0..w {
298 let v = x.data[(ch * h + i) * w + j];
299 let oi = i * 2;
300 let oj = j * 2;
301 out[(ch * h2 + oi) * w2 + oj] = v;
302 out[(ch * h2 + oi) * w2 + oj + 1] = v;
303 out[(ch * h2 + oi + 1) * w2 + oj] = v;
304 out[(ch * h2 + oi + 1) * w2 + oj + 1] = v;
305 }
306 }
307 }
308 FeatureMap {
309 data: out,
310 channels: c,
311 height: h2,
312 width: w2,
313 }
314}
315
316pub struct Bottleneck {
321 dw: DwConv2d,
322 pw: Conv2d,
323}
324
325impl Bottleneck {
326 fn new(channels: usize, dw_kernel: usize, rng: &mut LcgRng) -> Self {
327 let pad = (dw_kernel - 1) / 2;
328 Self {
329 dw: DwConv2d::new(channels, dw_kernel, 1, pad, rng),
330 pw: Conv2d::new(channels, channels, 1, 1, 0, rng),
331 }
332 }
333
334 fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
335 let mut y = self.dw.forward(x)?;
336 silu_inplace(&mut y.data);
337 let mut y = self.pw.forward(&y)?;
338 silu_inplace(&mut y.data);
339 add_inplace(&mut y, x)?; Ok(y)
341 }
342}
343
344pub struct CspLayer {
350 main_conv: Conv2d,
351 short_conv: Conv2d,
352 blocks: Vec<Bottleneck>,
353 final_conv: Conv2d,
354}
355
356impl CspLayer {
357 fn new(
358 in_channels: usize,
359 out_channels: usize,
360 n_blocks: usize,
361 dw_kernel: usize,
362 rng: &mut LcgRng,
363 ) -> Self {
364 let mid = out_channels / 2;
365 let main_conv = Conv2d::new(in_channels, mid, 1, 1, 0, rng);
366 let short_conv = Conv2d::new(in_channels, mid, 1, 1, 0, rng);
367 let blocks = (0..n_blocks)
368 .map(|_| Bottleneck::new(mid, dw_kernel, rng))
369 .collect();
370 let final_conv = Conv2d::new(2 * mid, out_channels, 1, 1, 0, rng);
372 Self {
373 main_conv,
374 short_conv,
375 blocks,
376 final_conv,
377 }
378 }
379
380 fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
381 let mut short = self.short_conv.forward(x)?;
382 silu_inplace(&mut short.data);
383 let mut main = self.main_conv.forward(x)?;
384 silu_inplace(&mut main.data);
385 for b in &self.blocks {
386 main = b.forward(&main)?;
387 }
388 let cat = concat_channels(&main, &short)?;
389 let mut out = self.final_conv.forward(&cat)?;
390 silu_inplace(&mut out.data);
391 Ok(out)
392 }
393}
394
395struct BackboneStage {
398 downsample: Conv2d,
399 csp: CspLayer,
400}
401
402pub struct CspNeXtBackbone {
405 stem: Conv2d,
406 stages: Vec<BackboneStage>,
407}
408
409impl CspNeXtBackbone {
410 fn new(cfg: &RtmDetConfig, rng: &mut LcgRng) -> Self {
411 let stem = Conv2d::new(cfg.in_chans, cfg.stem_channels, 3, 2, 1, rng);
412 let mut stages = Vec::with_capacity(cfg.stage_channels.len());
413 let mut prev = cfg.stem_channels;
414 for &c in &cfg.stage_channels {
415 let downsample = Conv2d::new(prev, c, 3, 2, 1, rng);
416 let csp = CspLayer::new(c, c, cfg.n_bottlenecks, cfg.dw_kernel, rng);
417 stages.push(BackboneStage { downsample, csp });
418 prev = c;
419 }
420 Self { stem, stages }
421 }
422
423 pub fn forward(&self, image: &FeatureMap) -> VisionResult<Vec<FeatureMap>> {
428 let mut x = self.stem.forward(image)?;
429 silu_inplace(&mut x.data);
430 let mut feats = Vec::with_capacity(self.stages.len());
431 for stage in &self.stages {
432 let mut d = stage.downsample.forward(&x)?;
433 silu_inplace(&mut d.data);
434 x = stage.csp.forward(&d)?;
435 feats.push(x.clone());
436 }
437 Ok(feats)
438 }
439}
440
441pub struct Pafpn {
445 lateral: Vec<Conv2d>,
446 top_down: Vec<Conv2d>,
447 downsample: Vec<Conv2d>,
448 bottom_up: Vec<Conv2d>,
449 n_levels: usize,
450}
451
452impl Pafpn {
453 fn new(in_channels: &[usize], out_channels: usize, rng: &mut LcgRng) -> Self {
454 let n_levels = in_channels.len();
455 let lateral = in_channels
456 .iter()
457 .map(|&c| Conv2d::new(c, out_channels, 1, 1, 0, rng))
458 .collect();
459 let top_down = (0..n_levels)
460 .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
461 .collect();
462 let downsample = (0..n_levels.saturating_sub(1))
463 .map(|_| Conv2d::new(out_channels, out_channels, 3, 2, 1, rng))
464 .collect();
465 let bottom_up = (0..n_levels.saturating_sub(1))
466 .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
467 .collect();
468 Self {
469 lateral,
470 top_down,
471 downsample,
472 bottom_up,
473 n_levels,
474 }
475 }
476
477 pub fn forward(&self, feats: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
485 if feats.len() != self.n_levels {
486 return Err(VisionError::DimensionMismatch {
487 expected: self.n_levels,
488 got: feats.len(),
489 });
490 }
491 let l = self.n_levels;
492
493 let mut lat: Vec<FeatureMap> = Vec::with_capacity(l);
495 for (f, conv) in feats.iter().zip(self.lateral.iter()) {
496 lat.push(conv.forward(f)?);
497 }
498
499 for level in (0..l.saturating_sub(1)).rev() {
501 let up = upsample2x(&lat[level + 1]);
502 add_inplace(&mut lat[level], &up)?;
503 let mut fused = self.top_down[level].forward(&lat[level])?;
504 silu_inplace(&mut fused.data);
505 lat[level] = fused;
506 }
507 if l > 0 {
509 let mut fused = self.top_down[l - 1].forward(&lat[l - 1])?;
510 silu_inplace(&mut fused.data);
511 lat[l - 1] = fused;
512 }
513
514 let mut outs: Vec<FeatureMap> = Vec::with_capacity(l);
516 outs.push(lat[0].clone());
517 for level in 1..l {
518 let mut down = self.downsample[level - 1].forward(&outs[level - 1])?;
519 silu_inplace(&mut down.data);
520 let mut merged = lat[level].clone();
521 add_inplace(&mut merged, &down)?;
522 let mut fused = self.bottom_up[level - 1].forward(&merged)?;
523 silu_inplace(&mut fused.data);
524 outs.push(fused);
525 }
526 Ok(outs)
527 }
528}
529
530pub struct DecoupledHead {
535 cls_conv: Conv2d,
536 cls_pred: Conv2d,
537 reg_conv: Conv2d,
538 reg_pred: Conv2d,
539}
540
541impl DecoupledHead {
542 fn new(channels: usize, n_classes: usize, rng: &mut LcgRng) -> Self {
543 Self {
544 cls_conv: Conv2d::new(channels, channels, 3, 1, 1, rng),
545 cls_pred: Conv2d::new(channels, n_classes, 1, 1, 0, rng),
546 reg_conv: Conv2d::new(channels, channels, 3, 1, 1, rng),
547 reg_pred: Conv2d::new(channels, 4, 1, 1, 0, rng),
548 }
549 }
550
551 pub fn forward_level(&self, x: &FeatureMap) -> VisionResult<(FeatureMap, FeatureMap)> {
557 let mut c = self.cls_conv.forward(x)?;
558 silu_inplace(&mut c.data);
559 let cls = self.cls_pred.forward(&c)?;
560
561 let mut r = self.reg_conv.forward(x)?;
562 silu_inplace(&mut r.data);
563 let reg = self.reg_pred.forward(&r)?;
564 Ok((cls, reg))
565 }
566}
567
568#[derive(Debug, Clone, PartialEq)]
572pub struct RtmDetConfig {
573 pub in_chans: usize,
575 pub img_size: usize,
577 pub stem_channels: usize,
579 pub stage_channels: Vec<usize>,
581 pub n_bottlenecks: usize,
583 pub dw_kernel: usize,
585 pub neck_channels: usize,
587 pub n_classes: usize,
589}
590
591impl RtmDetConfig {
592 pub fn new(
602 in_chans: usize,
603 img_size: usize,
604 stem_channels: usize,
605 stage_channels: Vec<usize>,
606 n_bottlenecks: usize,
607 dw_kernel: usize,
608 neck_channels: usize,
609 n_classes: usize,
610 ) -> VisionResult<Self> {
611 if in_chans == 0 || img_size == 0 {
612 return Err(VisionError::InvalidImageSize {
613 height: img_size,
614 width: img_size,
615 channels: in_chans,
616 });
617 }
618 if stage_channels.is_empty() {
619 return Err(VisionError::EmptyInput("rtmdet stage_channels"));
620 }
621 if n_classes == 0 {
622 return Err(VisionError::InvalidNumClasses(n_classes));
623 }
624 if dw_kernel == 0 || dw_kernel % 2 == 0 {
625 return Err(VisionError::InvalidPatchSize {
626 patch_size: dw_kernel,
627 img_size,
628 });
629 }
630 if stem_channels == 0 || neck_channels == 0 {
631 return Err(VisionError::DimensionMismatch {
632 expected: 1,
633 got: 0,
634 });
635 }
636 for &c in &stage_channels {
637 if c == 0 || c % 2 != 0 {
638 return Err(VisionError::DimensionMismatch {
639 expected: 2,
640 got: c,
641 });
642 }
643 }
644 Ok(Self {
645 in_chans,
646 img_size,
647 stem_channels,
648 stage_channels,
649 n_bottlenecks,
650 dw_kernel,
651 neck_channels,
652 n_classes,
653 })
654 }
655
656 #[must_use]
660 pub fn tiny() -> Self {
661 Self {
662 in_chans: 3,
663 img_size: 32,
664 stem_channels: 8,
665 stage_channels: vec![8, 16, 16],
666 n_bottlenecks: 1,
667 dw_kernel: 5,
668 neck_channels: 8,
669 n_classes: 4,
670 }
671 }
672
673 #[must_use]
675 #[inline]
676 pub fn n_levels(&self) -> usize {
677 self.stage_channels.len()
678 }
679}
680
681#[derive(Debug, Clone)]
685pub struct RtmDetOutput {
686 pub cls_scores: Vec<FeatureMap>,
688 pub bbox_preds: Vec<FeatureMap>,
690 pub strides: Vec<usize>,
692}
693
694pub struct RtmDet {
698 cfg: RtmDetConfig,
699 backbone: CspNeXtBackbone,
700 neck: Pafpn,
701 head: DecoupledHead,
702}
703
704impl RtmDet {
705 pub fn new(cfg: RtmDetConfig, rng: &mut LcgRng) -> VisionResult<Self> {
710 let cfg = RtmDetConfig::new(
711 cfg.in_chans,
712 cfg.img_size,
713 cfg.stem_channels,
714 cfg.stage_channels.clone(),
715 cfg.n_bottlenecks,
716 cfg.dw_kernel,
717 cfg.neck_channels,
718 cfg.n_classes,
719 )?;
720 let backbone = CspNeXtBackbone::new(&cfg, rng);
721 let neck = Pafpn::new(&cfg.stage_channels, cfg.neck_channels, rng);
722 let head = DecoupledHead::new(cfg.neck_channels, cfg.n_classes, rng);
723 Ok(Self {
724 cfg,
725 backbone,
726 neck,
727 head,
728 })
729 }
730
731 #[must_use]
733 #[inline]
734 pub fn config(&self) -> &RtmDetConfig {
735 &self.cfg
736 }
737
738 pub fn backbone_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
743 let img = self.make_image(image)?;
744 self.backbone.forward(&img)
745 }
746
747 pub fn neck_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
752 let feats = self.backbone_features(image)?;
753 self.neck.forward(feats)
754 }
755
756 pub fn forward(&self, image: &[f32]) -> VisionResult<RtmDetOutput> {
762 let neck = self.neck_features(image)?;
763 let mut cls_scores = Vec::with_capacity(neck.len());
764 let mut bbox_preds = Vec::with_capacity(neck.len());
765 let mut strides = Vec::with_capacity(neck.len());
766 for level in &neck {
767 let (cls, reg) = self.head.forward_level(level)?;
768 if cls
769 .data
770 .iter()
771 .chain(reg.data.iter())
772 .any(|v| !v.is_finite())
773 {
774 return Err(VisionError::NonFinite("rtmdet head output"));
775 }
776 strides.push(self.cfg.img_size / level.height.max(1));
777 cls_scores.push(cls);
778 bbox_preds.push(reg);
779 }
780 Ok(RtmDetOutput {
781 cls_scores,
782 bbox_preds,
783 strides,
784 })
785 }
786
787 fn make_image(&self, image: &[f32]) -> VisionResult<FeatureMap> {
788 FeatureMap::new(
789 image.to_vec(),
790 self.cfg.in_chans,
791 self.cfg.img_size,
792 self.cfg.img_size,
793 )
794 }
795}
796
797pub fn decode_level(
813 cls: &FeatureMap,
814 reg: &FeatureMap,
815 stride: usize,
816) -> VisionResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
817 if cls.height != reg.height || cls.width != reg.width {
818 return Err(VisionError::ShapeMismatch {
819 lhs: vec![cls.channels, cls.height, cls.width],
820 rhs: vec![reg.channels, reg.height, reg.width],
821 });
822 }
823 if reg.channels != 4 {
824 return Err(VisionError::DimensionMismatch {
825 expected: 4,
826 got: reg.channels,
827 });
828 }
829 let (h, w, n_cls) = (cls.height, cls.width, cls.channels);
830 let s = stride as f32;
831 let n = h * w;
832 let mut boxes = vec![0.0f32; n * 4];
833 let mut scores = vec![0.0f32; n];
834 let mut labels = vec![0usize; n];
835 for i in 0..h {
836 for j in 0..w {
837 let loc = i * w + j;
838 let cx = (j as f32 + 0.5) * s;
839 let cy = (i as f32 + 0.5) * s;
840 let l = softplus(reg.at(0, i, j)) * s;
841 let t = softplus(reg.at(1, i, j)) * s;
842 let r = softplus(reg.at(2, i, j)) * s;
843 let b = softplus(reg.at(3, i, j)) * s;
844 boxes[loc * 4] = cx - l;
845 boxes[loc * 4 + 1] = cy - t;
846 boxes[loc * 4 + 2] = cx + r;
847 boxes[loc * 4 + 3] = cy + b;
848
849 let mut best = f32::NEG_INFINITY;
850 let mut best_c = 0usize;
851 for c in 0..n_cls {
852 let p = 1.0 / (1.0 + (-cls.at(c, i, j)).exp());
853 if p > best {
854 best = p;
855 best_c = c;
856 }
857 }
858 scores[loc] = best;
859 labels[loc] = best_c;
860 }
861 }
862 Ok((boxes, scores, labels))
863}
864
865pub fn simota_cost(
899 pred_cls: &[f32],
900 pred_boxes: &[f32],
901 gt_labels: &[usize],
902 gt_boxes: &[f32],
903 n_classes: usize,
904 lambda_iou: f32,
905) -> VisionResult<Vec<f32>> {
906 if n_classes == 0 {
907 return Err(VisionError::InvalidNumClasses(n_classes));
908 }
909 if !lambda_iou.is_finite() {
910 return Err(VisionError::NonFinite("simota lambda_iou"));
911 }
912 let n_pred = pred_boxes.len() / 4;
913 let n_gt = gt_labels.len();
914 if n_pred == 0 {
915 return Err(VisionError::EmptyInput("simota predictions"));
916 }
917 if n_gt == 0 {
918 return Err(VisionError::EmptyInput("simota targets"));
919 }
920 if pred_boxes.len() != n_pred * 4 {
921 return Err(VisionError::DimensionMismatch {
922 expected: n_pred * 4,
923 got: pred_boxes.len(),
924 });
925 }
926 if pred_cls.len() != n_pred * n_classes {
927 return Err(VisionError::DimensionMismatch {
928 expected: n_pred * n_classes,
929 got: pred_cls.len(),
930 });
931 }
932 if gt_boxes.len() != n_gt * 4 {
933 return Err(VisionError::DimensionMismatch {
934 expected: n_gt * 4,
935 got: gt_boxes.len(),
936 });
937 }
938
939 const EPS: f32 = 1e-7;
940 let mut cost = vec![0.0f32; n_gt * n_pred];
941 for g in 0..n_gt {
942 let cls = gt_labels[g];
943 if cls >= n_classes {
944 return Err(VisionError::DimensionMismatch {
945 expected: n_classes,
946 got: cls,
947 });
948 }
949 let gbox = [
950 gt_boxes[g * 4],
951 gt_boxes[g * 4 + 1],
952 gt_boxes[g * 4 + 2],
953 gt_boxes[g * 4 + 3],
954 ];
955 for p in 0..n_pred {
956 let prob = pred_cls[p * n_classes + cls].clamp(EPS, 1.0);
957 let cls_cost = -prob.ln();
958 let pbox = [
959 pred_boxes[p * 4],
960 pred_boxes[p * 4 + 1],
961 pred_boxes[p * 4 + 2],
962 pred_boxes[p * 4 + 3],
963 ];
964 let iou_val = iou(&pbox, &gbox);
965 let iou_cost = -(iou_val + EPS).ln();
966 cost[g * n_pred + p] = cls_cost + lambda_iou * iou_cost;
967 }
968 }
969 if cost.iter().any(|v| !v.is_finite()) {
970 return Err(VisionError::NonFinite("simota cost"));
971 }
972 Ok(cost)
973}
974
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic N(0,1)-style image of size `in_chans × img_size × img_size`
    /// generated from `seed`.
    fn random_image(cfg: &RtmDetConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut img = vec![0.0f32; cfg.in_chans * cfg.img_size * cfg.img_size];
        rng.fill_normal(&mut img);
        img
    }

    /// The tiny preset has one level per stage (three stages).
    #[test]
    fn config_tiny_valid() {
        let cfg = RtmDetConfig::tiny();
        assert_eq!(cfg.n_levels(), 3);
    }

    /// Odd stage widths are rejected because CSP layers split them in half.
    #[test]
    fn config_odd_stage_channel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 15], 1, 5, 8, 4);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    /// Even depthwise kernels are rejected.
    #[test]
    fn config_even_dw_kernel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 4, 8, 4);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    /// Zero classes are rejected.
    #[test]
    fn config_zero_classes_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 5, 8, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    /// A 3×3 stride-2 conv with padding 1 maps 16×16 to 8×8.
    #[test]
    fn conv2d_stride2_halves_spatial() {
        let mut rng = LcgRng::new(1);
        let conv = Conv2d::new(3, 4, 3, 2, 1, &mut rng);
        let x = FeatureMap::new(vec![0.5f32; 3 * 16 * 16], 3, 16, 16).expect("ok");
        let y = conv.forward(&x).expect("ok");
        assert_eq!((y.channels, y.height, y.width), (4, 8, 8));
    }

    /// A depthwise kernel that is 1 at the centre tap and 0 elsewhere (zero bias)
    /// reproduces its input.
    #[test]
    fn dwconv_identity_kernel_is_input() {
        let mut rng = LcgRng::new(2);
        // 2 channels, 3×3 kernel, stride 1, padding 1 (size-preserving).
        let mut dw = DwConv2d::new(2, 3, 1, 1, &mut rng);
        for v in dw.weight.iter_mut() {
            *v = 0.0;
        }
        for ch in 0..2 {
            // Index 4 is the centre of a row-major 3×3 kernel.
            dw.weight[ch * 9 + 4] = 1.0;
        }
        let mut data = vec![0.0f32; 2 * 4 * 4];
        let mut r2 = LcgRng::new(3);
        r2.fill_normal(&mut data);
        let x = FeatureMap::new(data.clone(), 2, 4, 4).expect("ok");
        let y = dw.forward(&x).expect("ok");
        for (a, b) in y.data.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-5, "identity dw mismatch {a} vs {b}");
        }
    }

    /// Backbone yields one map per stage, each half the size of the previous one,
    /// with the configured stage widths.
    #[test]
    fn backbone_multiscale_halving() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(10);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 11);
        let feats = det.backbone_features(&img).expect("ok");
        assert_eq!(feats.len(), 3, "one feature per stage");
        let spatials: Vec<usize> = feats.iter().map(|f| f.height).collect();
        // 32 px input: stem ÷2 → 16, then each stride-2 stage → 8, 4, 2.
        assert_eq!(spatials, vec![8, 4, 2]);
        for w in feats.windows(2) {
            assert_eq!(w[0].height, w[1].height * 2, "each stage halves spatial");
            assert_eq!(w[0].width, w[1].width * 2);
        }
        let chans: Vec<usize> = feats.iter().map(|f| f.channels).collect();
        // Each stage outputs its configured width.
        assert_eq!(chans, cfg.stage_channels);
        assert!(feats.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    /// The neck keeps the number and sizes of levels and unifies their width.
    #[test]
    fn pafpn_uniform_channels_same_scales() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(12);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 13);
        let neck = det.neck_features(&img).expect("ok");
        assert_eq!(neck.len(), cfg.n_levels(), "neck preserves #scales");
        for fm in &neck {
            assert_eq!(fm.channels, cfg.neck_channels, "uniform fused channels");
        }
        let spatials: Vec<usize> = neck.iter().map(|f| f.height).collect();
        // Same spatial sizes as the backbone levels.
        assert_eq!(spatials, vec![8, 4, 2]);
        assert!(neck.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    /// Head outputs have `n_classes` / 4 channels per level and the expected strides.
    #[test]
    fn decoupled_head_shapes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(14);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 15);
        let out = det.forward(&img).expect("ok");
        assert_eq!(out.cls_scores.len(), 3);
        assert_eq!(out.bbox_preds.len(), 3);
        for (cls, reg) in out.cls_scores.iter().zip(out.bbox_preds.iter()) {
            assert_eq!(cls.channels, cfg.n_classes, "cls has n_classes channels");
            assert_eq!(reg.channels, 4, "reg has 4 channels");
            assert_eq!(cls.height, reg.height);
            assert_eq!(cls.data.len(), cfg.n_classes * cls.height * cls.width);
            assert_eq!(reg.data.len(), 4 * reg.height * reg.width);
        }
        // 32 px input over level sizes 8, 4, 2.
        assert_eq!(out.strides, vec![4, 8, 16]);
    }

    /// A random input produces only finite head outputs.
    #[test]
    fn forward_all_finite() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(16);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 17);
        let out = det.forward(&img).expect("ok");
        for fm in out.cls_scores.iter().chain(out.bbox_preds.iter()) {
            assert!(fm.data.iter().all(|v| v.is_finite()));
        }
    }

    /// Different inputs give different classification maps on the finest level.
    #[test]
    fn varying_input_changes_detections() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(18);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img_a = random_image(&cfg, 19);
        let img_b = random_image(&cfg, 20);
        let out_a = det.forward(&img_a).expect("ok");
        let out_b = det.forward(&img_b).expect("ok");
        // Total absolute difference of level-0 classification outputs.
        let diff: f32 = out_a.cls_scores[0]
            .data
            .iter()
            .zip(out_b.cls_scores[0].data.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-4,
            "detections should change with input, diff={diff}"
        );
    }

    /// Decoded boxes have positive extent, scores lie in [0, 1] and labels are valid.
    #[test]
    fn decode_level_produces_valid_boxes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(21);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 22);
        let out = det.forward(&img).expect("ok");
        let (boxes, scores, labels) =
            decode_level(&out.cls_scores[0], &out.bbox_preds[0], out.strides[0]).expect("ok");
        let n = out.cls_scores[0].height * out.cls_scores[0].width;
        assert_eq!(boxes.len(), n * 4);
        assert_eq!(scores.len(), n);
        assert_eq!(labels.len(), n);
        for loc in 0..n {
            assert!(boxes[loc * 4 + 2] > boxes[loc * 4], "x2 must exceed x1");
            assert!(boxes[loc * 4 + 3] > boxes[loc * 4 + 1], "y2 must exceed y1");
            assert!((0.0..=1.0).contains(&scores[loc]), "score in [0,1]");
            assert!(labels[loc] < cfg.n_classes);
        }
        assert!(boxes.iter().all(|v| v.is_finite()));
    }

    /// A prediction matching both the gt box and its class costs less than one
    /// matching neither.
    #[test]
    fn simota_lower_cost_for_better_match() {
        let n_classes = 2;
        // Two predictions × two class probabilities.
        let pred_cls = vec![
            0.9f32, 0.1, // pred 0: confident in class 0
            0.1, 0.9, // pred 1: confident in class 1
        ];
        let pred_boxes = vec![
            0.0f32, 0.0, 10.0, 10.0, // pred 0: identical to the gt box
            20.0, 20.0, 30.0, 30.0, // pred 1: disjoint from the gt box
        ];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];

        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            3.0,
        )
        .expect("ok");
        assert_eq!(cost.len(), 2, "[n_gt × n_pred]");
        assert!(cost.iter().all(|v| v.is_finite()), "cost must be finite");
        assert!(
            // Row 0 (the only gt): pred 0 vs pred 1.
            cost[0] < cost[1],
            "better cls+iou match must have lower cost: {} vs {}",
            cost[0],
            cost[1]
        );
    }

    /// With equal class probabilities, cost decreases as IoU with the gt increases.
    #[test]
    fn simota_cost_monotonic_in_iou() {
        let n_classes = 1;
        // Same class probability for all three predictions, so only IoU differs.
        let pred_cls = vec![0.8f32, 0.8, 0.8];
        let pred_boxes = vec![
            0.0f32, 0.0, 10.0, 10.0, // identical to gt
            5.0, 0.0, 15.0, 10.0, // half-shifted overlap
            50.0, 50.0, 60.0, 60.0, // disjoint
        ];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];
        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            2.0,
        )
        .expect("ok");
        assert!(cost[0] < cost[1], "higher IoU → lower cost");
        assert!(cost[1] < cost[2], "higher IoU → lower cost");
    }

    /// Inconsistent buffer sizes and empty inputs are reported as errors.
    #[test]
    fn simota_errors_on_bad_shapes() {
        let r = simota_cost(
            // One class probability for one prediction, but n_classes = 2.
            &[0.5f32],
            &[0.0, 0.0, 1.0, 1.0],
            &[0],
            &[0.0, 0.0, 1.0, 1.0],
            2,
            1.0,
        );
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
        let r2 = simota_cost(&[], &[], &[0], &[0.0, 0.0, 1.0, 1.0], 1, 1.0);
        // No predictions at all.
        assert!(matches!(r2, Err(VisionError::EmptyInput(_))));
    }

    /// Two detectors built from the same seed produce identical outputs.
    #[test]
    fn deterministic_same_seed() {
        let cfg = RtmDetConfig::tiny();
        let img = random_image(&cfg, 30);
        let mut ra = LcgRng::new(99);
        let mut rb = LcgRng::new(99);
        let da = RtmDet::new(cfg.clone(), &mut ra).expect("ok");
        let db = RtmDet::new(cfg, &mut rb).expect("ok");
        let oa = da.forward(&img).expect("ok");
        let ob = db.forward(&img).expect("ok");
        for (a, b) in oa.cls_scores.iter().zip(ob.cls_scores.iter()) {
            assert_eq!(a.data, b.data, "same seed → identical output");
        }
    }
}