Skip to main content

oxicuda_vision/detection/
rtmdet.rs

1//! RTMDet — a compact, faithful CPU reference of the one-stage detector from
2//! Lyu et al. 2022, *"RTMDet: An Empirical Study of Designing Real-Time Object
3//! Detectors"*.
4//!
5//! The model is built from three classic pieces, all implemented here with real
6//! convolutional arithmetic (no shape-only stubs):
7//!
8//! 1. **CSPNeXt backbone** — a stem followed by a stack of stages. Each stage
9//!    downsamples by a stride-2 convolution and then applies a **CSP layer**:
10//!    the channels are *split* into two branches, one branch passes through a
11//!    stack of **bottlenecks built from large-kernel depthwise convolutions**,
12//!    and the two branches are *concatenated* and fused by a 1×1 convolution.
13//!    Successive stages halve the spatial resolution, producing a multi-scale
14//!    feature hierarchy.
15//! 2. **PAFPN neck** — a Path-Aggregation Feature Pyramid Network performing a
16//!    *top-down* pass (upsample coarse features and add to finer ones) followed
17//!    by a *bottom-up* pass (downsample fine features and add to coarser ones),
18//!    emitting one fused map per input scale with a uniform channel count.
19//! 3. **Decoupled head** — a *shared* head with **separate classification and
20//!    regression branches**, producing per-location class logits and 4 box
21//!    regression values at every scale.
22//!
23//! It also provides the **SimOTA-lite** dynamic soft-label assignment cost
24//! (`cost = cls_cost + λ · iou_cost`) used to match predictions to ground truth.
25//!
26//! ## Tensor layout
27//! All feature maps use the channel-major [`FeatureMap`] layout
28//! (`[channels, height, width]`, row-major).
29
30use super::anchor_nms::iou;
31use crate::{
32    error::{VisionError, VisionResult},
33    fpn::top_down::FeatureMap,
34    handle::LcgRng,
35};
36
37// ─── Activations ───────────────────────────────────────────────────────────────
38
/// In-place SiLU (a.k.a. swish): every element `v` becomes `v · σ(v)`, where
/// `σ(v) = 1 / (1 + e^{-v})`.
///
/// Very negative inputs drive `e^{-v}` to `+∞`, so the gate becomes `0` and the
/// element collapses to (signed) zero instead of producing `NaN`.
fn silu_inplace(x: &mut [f32]) {
    x.iter_mut().for_each(|v| {
        let gate = 1.0 / (1.0 + (-*v).exp());
        *v *= gate;
    });
}
45
/// Softplus `log(1 + e^x)`, evaluated without overflow.
///
/// Uses the identity `log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|})`: the
/// exponent is never positive, so `exp` cannot overflow, and `ln_1p` keeps
/// precision when the tail term is tiny.
#[inline]
fn softplus(x: f32) -> f32 {
    let tail = (-x.abs()).exp().ln_1p();
    x.max(0.0) + tail
}
51
52// ─── Conv2d ────────────────────────────────────────────────────────────────────
53
54/// A dense 2-D convolution (groups = 1) with `[c_out, c_in, k, k]` weights.
55#[derive(Debug, Clone)]
56pub struct Conv2d {
57    weight: Vec<f32>,
58    bias: Vec<f32>,
59    c_in: usize,
60    c_out: usize,
61    k: usize,
62    stride: usize,
63    pad: usize,
64}
65
66impl Conv2d {
67    /// He-initialised convolution.
68    fn new(
69        c_in: usize,
70        c_out: usize,
71        k: usize,
72        stride: usize,
73        pad: usize,
74        rng: &mut LcgRng,
75    ) -> Self {
76        let fan_in = c_in * k * k;
77        let scale = (2.0 / fan_in as f32).sqrt();
78        let mut weight = vec![0.0f32; c_out * fan_in];
79        rng.fill_normal(&mut weight);
80        for w in &mut weight {
81            *w *= scale;
82        }
83        Self {
84            weight,
85            bias: vec![0.0f32; c_out],
86            c_in,
87            c_out,
88            k,
89            stride,
90            pad,
91        }
92    }
93
94    /// Output channel count.
95    #[must_use]
96    #[inline]
97    pub fn out_channels(&self) -> usize {
98        self.c_out
99    }
100
101    /// Forward pass over a [`FeatureMap`].
102    ///
103    /// # Errors
104    /// - [`VisionError::DimensionMismatch`] if `x.channels != c_in`.
105    /// - [`VisionError::InvalidImageSize`] if the kernel does not fit the
106    ///   padded input.
107    pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
108        if x.channels != self.c_in {
109            return Err(VisionError::DimensionMismatch {
110                expected: self.c_in,
111                got: x.channels,
112            });
113        }
114        let (h, w) = (x.height, x.width);
115        if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
116            return Err(VisionError::InvalidImageSize {
117                height: h,
118                width: w,
119                channels: x.channels,
120            });
121        }
122        let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
123        let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
124        let mut out = vec![0.0f32; self.c_out * h_out * w_out];
125        let k = self.k;
126        for oc in 0..self.c_out {
127            let oc_w = oc * self.c_in * k * k;
128            for oh in 0..h_out {
129                for ow in 0..w_out {
130                    let mut acc = self.bias[oc];
131                    for ic in 0..self.c_in {
132                        let in_base = ic * h * w;
133                        let w_base = oc_w + ic * k * k;
134                        for ki in 0..k {
135                            let ih = oh * self.stride + ki;
136                            if ih < self.pad || ih >= h + self.pad {
137                                continue;
138                            }
139                            let ih = ih - self.pad;
140                            for kj in 0..k {
141                                let iw = ow * self.stride + kj;
142                                if iw < self.pad || iw >= w + self.pad {
143                                    continue;
144                                }
145                                let iw = iw - self.pad;
146                                acc += self.weight[w_base + ki * k + kj]
147                                    * out_in(x, in_base, ih, w, iw);
148                            }
149                        }
150                    }
151                    out[(oc * h_out + oh) * w_out + ow] = acc;
152                }
153            }
154        }
155        FeatureMap::new(out, self.c_out, h_out, w_out)
156    }
157}
158
159/// Read `x.data[in_base + ih*w + iw]`. Helper to keep the inner conv loop tidy.
160#[inline]
161fn out_in(x: &FeatureMap, in_base: usize, ih: usize, w: usize, iw: usize) -> f32 {
162    x.data[in_base + ih * w + iw]
163}
164
165// ─── DwConv2d ──────────────────────────────────────────────────────────────────
166
167/// A depthwise 2-D convolution (groups = channels) with `[c, k, k]` weights.
168///
169/// Used for the **large-kernel** depthwise convolutions inside the CSPNeXt
170/// bottleneck.
171#[derive(Debug, Clone)]
172pub struct DwConv2d {
173    weight: Vec<f32>,
174    bias: Vec<f32>,
175    c: usize,
176    k: usize,
177    stride: usize,
178    pad: usize,
179}
180
181impl DwConv2d {
182    fn new(c: usize, k: usize, stride: usize, pad: usize, rng: &mut LcgRng) -> Self {
183        let fan_in = k * k;
184        let scale = (2.0 / fan_in as f32).sqrt();
185        let mut weight = vec![0.0f32; c * fan_in];
186        rng.fill_normal(&mut weight);
187        for w in &mut weight {
188            *w *= scale;
189        }
190        Self {
191            weight,
192            bias: vec![0.0f32; c],
193            c,
194            k,
195            stride,
196            pad,
197        }
198    }
199
200    /// Depthwise forward pass.
201    ///
202    /// # Errors
203    /// - [`VisionError::DimensionMismatch`] if `x.channels != c`.
204    /// - [`VisionError::InvalidImageSize`] if the kernel does not fit.
205    pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
206        if x.channels != self.c {
207            return Err(VisionError::DimensionMismatch {
208                expected: self.c,
209                got: x.channels,
210            });
211        }
212        let (h, w) = (x.height, x.width);
213        if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
214            return Err(VisionError::InvalidImageSize {
215                height: h,
216                width: w,
217                channels: x.channels,
218            });
219        }
220        let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
221        let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
222        let k = self.k;
223        let mut out = vec![0.0f32; self.c * h_out * w_out];
224        for ch in 0..self.c {
225            let in_base = ch * h * w;
226            let w_base = ch * k * k;
227            let bias = self.bias[ch];
228            for oh in 0..h_out {
229                for ow in 0..w_out {
230                    let mut acc = bias;
231                    for ki in 0..k {
232                        let ih = oh * self.stride + ki;
233                        if ih < self.pad || ih >= h + self.pad {
234                            continue;
235                        }
236                        let ih = ih - self.pad;
237                        for kj in 0..k {
238                            let iw = ow * self.stride + kj;
239                            if iw < self.pad || iw >= w + self.pad {
240                                continue;
241                            }
242                            let iw = iw - self.pad;
243                            acc +=
244                                self.weight[w_base + ki * k + kj] * x.data[in_base + ih * w + iw];
245                        }
246                    }
247                    out[(ch * h_out + oh) * w_out + ow] = acc;
248                }
249            }
250        }
251        FeatureMap::new(out, self.c, h_out, w_out)
252    }
253}
254
255// ─── Feature-map helpers ───────────────────────────────────────────────────────
256
257/// Concatenate two feature maps along the channel axis (same spatial size).
258fn concat_channels(a: &FeatureMap, b: &FeatureMap) -> VisionResult<FeatureMap> {
259    if a.height != b.height || a.width != b.width {
260        return Err(VisionError::ShapeMismatch {
261            lhs: vec![a.channels, a.height, a.width],
262            rhs: vec![b.channels, b.height, b.width],
263        });
264    }
265    let mut data = Vec::with_capacity(a.data.len() + b.data.len());
266    data.extend_from_slice(&a.data);
267    data.extend_from_slice(&b.data);
268    Ok(FeatureMap {
269        data,
270        channels: a.channels + b.channels,
271        height: a.height,
272        width: a.width,
273    })
274}
275
276/// Element-wise `dst += src` (shapes must match exactly).
277fn add_inplace(dst: &mut FeatureMap, src: &FeatureMap) -> VisionResult<()> {
278    if dst.channels != src.channels || dst.height != src.height || dst.width != src.width {
279        return Err(VisionError::ShapeMismatch {
280            lhs: vec![dst.channels, dst.height, dst.width],
281            rhs: vec![src.channels, src.height, src.width],
282        });
283    }
284    for (a, b) in dst.data.iter_mut().zip(src.data.iter()) {
285        *a += *b;
286    }
287    Ok(())
288}
289
290/// Nearest-neighbour 2× upsampling.
291fn upsample2x(x: &FeatureMap) -> FeatureMap {
292    let (c, h, w) = (x.channels, x.height, x.width);
293    let (h2, w2) = (h * 2, w * 2);
294    let mut out = vec![0.0f32; c * h2 * w2];
295    for ch in 0..c {
296        for i in 0..h {
297            for j in 0..w {
298                let v = x.data[(ch * h + i) * w + j];
299                let oi = i * 2;
300                let oj = j * 2;
301                out[(ch * h2 + oi) * w2 + oj] = v;
302                out[(ch * h2 + oi) * w2 + oj + 1] = v;
303                out[(ch * h2 + oi + 1) * w2 + oj] = v;
304                out[(ch * h2 + oi + 1) * w2 + oj + 1] = v;
305            }
306        }
307    }
308    FeatureMap {
309        data: out,
310        channels: c,
311        height: h2,
312        width: w2,
313    }
314}
315
316// ─── Bottleneck ────────────────────────────────────────────────────────────────
317
318/// CSPNeXt bottleneck: large-kernel depthwise conv → 1×1 pointwise conv, with a
319/// residual connection. Both convolutions are followed by SiLU.
320pub struct Bottleneck {
321    dw: DwConv2d,
322    pw: Conv2d,
323}
324
325impl Bottleneck {
326    fn new(channels: usize, dw_kernel: usize, rng: &mut LcgRng) -> Self {
327        let pad = (dw_kernel - 1) / 2;
328        Self {
329            dw: DwConv2d::new(channels, dw_kernel, 1, pad, rng),
330            pw: Conv2d::new(channels, channels, 1, 1, 0, rng),
331        }
332    }
333
334    fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
335        let mut y = self.dw.forward(x)?;
336        silu_inplace(&mut y.data);
337        let mut y = self.pw.forward(&y)?;
338        silu_inplace(&mut y.data);
339        add_inplace(&mut y, x)?; // residual
340        Ok(y)
341    }
342}
343
344// ─── CSP layer ─────────────────────────────────────────────────────────────────
345
346/// A Cross-Stage-Partial layer: split into a "main" and a "short" branch, run a
347/// stack of bottlenecks on the main branch, concatenate, and fuse with a 1×1
348/// convolution. `out_channels` must be even (`mid = out_channels / 2`).
349pub struct CspLayer {
350    main_conv: Conv2d,
351    short_conv: Conv2d,
352    blocks: Vec<Bottleneck>,
353    final_conv: Conv2d,
354}
355
356impl CspLayer {
357    fn new(
358        in_channels: usize,
359        out_channels: usize,
360        n_blocks: usize,
361        dw_kernel: usize,
362        rng: &mut LcgRng,
363    ) -> Self {
364        let mid = out_channels / 2;
365        let main_conv = Conv2d::new(in_channels, mid, 1, 1, 0, rng);
366        let short_conv = Conv2d::new(in_channels, mid, 1, 1, 0, rng);
367        let blocks = (0..n_blocks)
368            .map(|_| Bottleneck::new(mid, dw_kernel, rng))
369            .collect();
370        // After concat: 2 * mid = out_channels.
371        let final_conv = Conv2d::new(2 * mid, out_channels, 1, 1, 0, rng);
372        Self {
373            main_conv,
374            short_conv,
375            blocks,
376            final_conv,
377        }
378    }
379
380    fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
381        let mut short = self.short_conv.forward(x)?;
382        silu_inplace(&mut short.data);
383        let mut main = self.main_conv.forward(x)?;
384        silu_inplace(&mut main.data);
385        for b in &self.blocks {
386            main = b.forward(&main)?;
387        }
388        let cat = concat_channels(&main, &short)?;
389        let mut out = self.final_conv.forward(&cat)?;
390        silu_inplace(&mut out.data);
391        Ok(out)
392    }
393}
394
395// ─── Backbone ──────────────────────────────────────────────────────────────────
396
397struct BackboneStage {
398    downsample: Conv2d,
399    csp: CspLayer,
400}
401
402/// CSPNeXt-style backbone: a stride-2 stem followed by `stage_channels.len()`
403/// downsampling stages, each emitting one multi-scale feature map.
404pub struct CspNeXtBackbone {
405    stem: Conv2d,
406    stages: Vec<BackboneStage>,
407}
408
409impl CspNeXtBackbone {
410    fn new(cfg: &RtmDetConfig, rng: &mut LcgRng) -> Self {
411        let stem = Conv2d::new(cfg.in_chans, cfg.stem_channels, 3, 2, 1, rng);
412        let mut stages = Vec::with_capacity(cfg.stage_channels.len());
413        let mut prev = cfg.stem_channels;
414        for &c in &cfg.stage_channels {
415            let downsample = Conv2d::new(prev, c, 3, 2, 1, rng);
416            let csp = CspLayer::new(c, c, cfg.n_bottlenecks, cfg.dw_kernel, rng);
417            stages.push(BackboneStage { downsample, csp });
418            prev = c;
419        }
420        Self { stem, stages }
421    }
422
423    /// Forward pass returning one [`FeatureMap`] per stage (finest first).
424    ///
425    /// # Errors
426    /// Propagates convolution shape errors.
427    pub fn forward(&self, image: &FeatureMap) -> VisionResult<Vec<FeatureMap>> {
428        let mut x = self.stem.forward(image)?;
429        silu_inplace(&mut x.data);
430        let mut feats = Vec::with_capacity(self.stages.len());
431        for stage in &self.stages {
432            let mut d = stage.downsample.forward(&x)?;
433            silu_inplace(&mut d.data);
434            x = stage.csp.forward(&d)?;
435            feats.push(x.clone());
436        }
437        Ok(feats)
438    }
439}
440
441// ─── PAFPN neck ────────────────────────────────────────────────────────────────
442
443/// Path-Aggregation Feature Pyramid Network.
444pub struct Pafpn {
445    lateral: Vec<Conv2d>,
446    top_down: Vec<Conv2d>,
447    downsample: Vec<Conv2d>,
448    bottom_up: Vec<Conv2d>,
449    n_levels: usize,
450}
451
452impl Pafpn {
453    fn new(in_channels: &[usize], out_channels: usize, rng: &mut LcgRng) -> Self {
454        let n_levels = in_channels.len();
455        let lateral = in_channels
456            .iter()
457            .map(|&c| Conv2d::new(c, out_channels, 1, 1, 0, rng))
458            .collect();
459        let top_down = (0..n_levels)
460            .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
461            .collect();
462        let downsample = (0..n_levels.saturating_sub(1))
463            .map(|_| Conv2d::new(out_channels, out_channels, 3, 2, 1, rng))
464            .collect();
465        let bottom_up = (0..n_levels.saturating_sub(1))
466            .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
467            .collect();
468        Self {
469            lateral,
470            top_down,
471            downsample,
472            bottom_up,
473            n_levels,
474        }
475    }
476
477    /// Forward pass. `feats` must be ordered finest→coarsest and contain
478    /// `n_levels` maps whose spatial sizes halve between adjacent levels.
479    ///
480    /// # Errors
481    /// - [`VisionError::DimensionMismatch`] if `feats.len() != n_levels`.
482    /// - [`VisionError::ShapeMismatch`] if adjacent levels are not in a 2:1
483    ///   spatial ratio.
484    pub fn forward(&self, feats: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
485        if feats.len() != self.n_levels {
486            return Err(VisionError::DimensionMismatch {
487                expected: self.n_levels,
488                got: feats.len(),
489            });
490        }
491        let l = self.n_levels;
492
493        // Lateral 1×1 to unify channels.
494        let mut lat: Vec<FeatureMap> = Vec::with_capacity(l);
495        for (f, conv) in feats.iter().zip(self.lateral.iter()) {
496            lat.push(conv.forward(f)?);
497        }
498
499        // Top-down: coarse → fine.
500        for level in (0..l.saturating_sub(1)).rev() {
501            let up = upsample2x(&lat[level + 1]);
502            add_inplace(&mut lat[level], &up)?;
503            let mut fused = self.top_down[level].forward(&lat[level])?;
504            silu_inplace(&mut fused.data);
505            lat[level] = fused;
506        }
507        // Smooth the coarsest level too for symmetry.
508        if l > 0 {
509            let mut fused = self.top_down[l - 1].forward(&lat[l - 1])?;
510            silu_inplace(&mut fused.data);
511            lat[l - 1] = fused;
512        }
513
514        // Bottom-up: fine → coarse.
515        let mut outs: Vec<FeatureMap> = Vec::with_capacity(l);
516        outs.push(lat[0].clone());
517        for level in 1..l {
518            let mut down = self.downsample[level - 1].forward(&outs[level - 1])?;
519            silu_inplace(&mut down.data);
520            let mut merged = lat[level].clone();
521            add_inplace(&mut merged, &down)?;
522            let mut fused = self.bottom_up[level - 1].forward(&merged)?;
523            silu_inplace(&mut fused.data);
524            outs.push(fused);
525        }
526        Ok(outs)
527    }
528}
529
530// ─── Decoupled head ────────────────────────────────────────────────────────────
531
532/// Shared decoupled head with independent classification and regression
533/// branches.
534pub struct DecoupledHead {
535    cls_conv: Conv2d,
536    cls_pred: Conv2d,
537    reg_conv: Conv2d,
538    reg_pred: Conv2d,
539}
540
541impl DecoupledHead {
542    fn new(channels: usize, n_classes: usize, rng: &mut LcgRng) -> Self {
543        Self {
544            cls_conv: Conv2d::new(channels, channels, 3, 1, 1, rng),
545            cls_pred: Conv2d::new(channels, n_classes, 1, 1, 0, rng),
546            reg_conv: Conv2d::new(channels, channels, 3, 1, 1, rng),
547            reg_pred: Conv2d::new(channels, 4, 1, 1, 0, rng),
548        }
549    }
550
551    /// Apply the head to one feature map, returning `(cls_logits, reg)` where
552    /// `cls_logits` has `n_classes` channels and `reg` has 4 channels.
553    ///
554    /// # Errors
555    /// Propagates convolution shape errors.
556    pub fn forward_level(&self, x: &FeatureMap) -> VisionResult<(FeatureMap, FeatureMap)> {
557        let mut c = self.cls_conv.forward(x)?;
558        silu_inplace(&mut c.data);
559        let cls = self.cls_pred.forward(&c)?;
560
561        let mut r = self.reg_conv.forward(x)?;
562        silu_inplace(&mut r.data);
563        let reg = self.reg_pred.forward(&r)?;
564        Ok((cls, reg))
565    }
566}
567
568// ─── Config ────────────────────────────────────────────────────────────────────
569
570/// RTMDet hyper-parameters.
571#[derive(Debug, Clone, PartialEq)]
572pub struct RtmDetConfig {
573    /// Input image channels (e.g. 3).
574    pub in_chans: usize,
575    /// Square input spatial size.
576    pub img_size: usize,
577    /// Stem output channels.
578    pub stem_channels: usize,
579    /// Output channels of each backbone stage (each value must be even).
580    pub stage_channels: Vec<usize>,
581    /// Bottlenecks per CSP layer.
582    pub n_bottlenecks: usize,
583    /// Large depthwise kernel size (odd).
584    pub dw_kernel: usize,
585    /// Uniform PAFPN / head channel count.
586    pub neck_channels: usize,
587    /// Number of object classes.
588    pub n_classes: usize,
589}
590
591impl RtmDetConfig {
592    /// Create and validate a configuration.
593    ///
594    /// # Errors
595    /// - [`VisionError::InvalidImageSize`] if `in_chans == 0` or `img_size == 0`.
596    /// - [`VisionError::EmptyInput`] if `stage_channels` is empty.
597    /// - [`VisionError::InvalidNumClasses`] if `n_classes == 0`.
598    /// - [`VisionError::InvalidPatchSize`] if `dw_kernel` is even or zero.
599    /// - [`VisionError::DimensionMismatch`] if any stage channel count is odd
600    ///   or zero, or `neck_channels`/`stem_channels` is zero.
601    pub fn new(
602        in_chans: usize,
603        img_size: usize,
604        stem_channels: usize,
605        stage_channels: Vec<usize>,
606        n_bottlenecks: usize,
607        dw_kernel: usize,
608        neck_channels: usize,
609        n_classes: usize,
610    ) -> VisionResult<Self> {
611        if in_chans == 0 || img_size == 0 {
612            return Err(VisionError::InvalidImageSize {
613                height: img_size,
614                width: img_size,
615                channels: in_chans,
616            });
617        }
618        if stage_channels.is_empty() {
619            return Err(VisionError::EmptyInput("rtmdet stage_channels"));
620        }
621        if n_classes == 0 {
622            return Err(VisionError::InvalidNumClasses(n_classes));
623        }
624        if dw_kernel == 0 || dw_kernel % 2 == 0 {
625            return Err(VisionError::InvalidPatchSize {
626                patch_size: dw_kernel,
627                img_size,
628            });
629        }
630        if stem_channels == 0 || neck_channels == 0 {
631            return Err(VisionError::DimensionMismatch {
632                expected: 1,
633                got: 0,
634            });
635        }
636        for &c in &stage_channels {
637            if c == 0 || c % 2 != 0 {
638                return Err(VisionError::DimensionMismatch {
639                    expected: 2,
640                    got: c,
641                });
642            }
643        }
644        Ok(Self {
645            in_chans,
646            img_size,
647            stem_channels,
648            stage_channels,
649            n_bottlenecks,
650            dw_kernel,
651            neck_channels,
652            n_classes,
653        })
654    }
655
656    /// A tiny configuration for unit tests:
657    /// 3×32×32 input, stem 8, stages `[8, 16, 16]`, 1 bottleneck, 5×5 depthwise,
658    /// neck 8, 4 classes.
659    #[must_use]
660    pub fn tiny() -> Self {
661        Self {
662            in_chans: 3,
663            img_size: 32,
664            stem_channels: 8,
665            stage_channels: vec![8, 16, 16],
666            n_bottlenecks: 1,
667            dw_kernel: 5,
668            neck_channels: 8,
669            n_classes: 4,
670        }
671    }
672
673    /// Number of pyramid levels (= number of backbone stages).
674    #[must_use]
675    #[inline]
676    pub fn n_levels(&self) -> usize {
677        self.stage_channels.len()
678    }
679}
680
681// ─── Output ────────────────────────────────────────────────────────────────────
682
683/// Multi-scale detector output.
684#[derive(Debug, Clone)]
685pub struct RtmDetOutput {
686    /// Per-level class logits, each `[n_classes, H, W]`.
687    pub cls_scores: Vec<FeatureMap>,
688    /// Per-level raw box regression, each `[4, H, W]`.
689    pub bbox_preds: Vec<FeatureMap>,
690    /// Per-level stride (image pixels / feature pixel).
691    pub strides: Vec<usize>,
692}
693
694// ─── RtmDet ────────────────────────────────────────────────────────────────────
695
696/// The full RTMDet detector.
697pub struct RtmDet {
698    cfg: RtmDetConfig,
699    backbone: CspNeXtBackbone,
700    neck: Pafpn,
701    head: DecoupledHead,
702}
703
704impl RtmDet {
705    /// Build the detector with randomly-initialised weights.
706    ///
707    /// # Errors
708    /// Propagates configuration validation.
709    pub fn new(cfg: RtmDetConfig, rng: &mut LcgRng) -> VisionResult<Self> {
710        let cfg = RtmDetConfig::new(
711            cfg.in_chans,
712            cfg.img_size,
713            cfg.stem_channels,
714            cfg.stage_channels.clone(),
715            cfg.n_bottlenecks,
716            cfg.dw_kernel,
717            cfg.neck_channels,
718            cfg.n_classes,
719        )?;
720        let backbone = CspNeXtBackbone::new(&cfg, rng);
721        let neck = Pafpn::new(&cfg.stage_channels, cfg.neck_channels, rng);
722        let head = DecoupledHead::new(cfg.neck_channels, cfg.n_classes, rng);
723        Ok(Self {
724            cfg,
725            backbone,
726            neck,
727            head,
728        })
729    }
730
731    /// Read-only access to the configuration.
732    #[must_use]
733    #[inline]
734    pub fn config(&self) -> &RtmDetConfig {
735        &self.cfg
736    }
737
738    /// Run only the backbone, returning the multi-scale feature hierarchy.
739    ///
740    /// # Errors
741    /// - [`VisionError::DimensionMismatch`] if `image.len() != in_chans·H·W`.
742    pub fn backbone_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
743        let img = self.make_image(image)?;
744        self.backbone.forward(&img)
745    }
746
747    /// Run backbone + neck, returning the fused pyramid.
748    ///
749    /// # Errors
750    /// Propagates backbone / neck errors.
751    pub fn neck_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
752        let feats = self.backbone_features(image)?;
753        self.neck.forward(feats)
754    }
755
756    /// Full forward pass producing per-level class + box predictions.
757    ///
758    /// # Errors
759    /// Propagates backbone / neck / head errors, or [`VisionError::NonFinite`]
760    /// if any prediction is non-finite.
761    pub fn forward(&self, image: &[f32]) -> VisionResult<RtmDetOutput> {
762        let neck = self.neck_features(image)?;
763        let mut cls_scores = Vec::with_capacity(neck.len());
764        let mut bbox_preds = Vec::with_capacity(neck.len());
765        let mut strides = Vec::with_capacity(neck.len());
766        for level in &neck {
767            let (cls, reg) = self.head.forward_level(level)?;
768            if cls
769                .data
770                .iter()
771                .chain(reg.data.iter())
772                .any(|v| !v.is_finite())
773            {
774                return Err(VisionError::NonFinite("rtmdet head output"));
775            }
776            strides.push(self.cfg.img_size / level.height.max(1));
777            cls_scores.push(cls);
778            bbox_preds.push(reg);
779        }
780        Ok(RtmDetOutput {
781            cls_scores,
782            bbox_preds,
783            strides,
784        })
785    }
786
787    fn make_image(&self, image: &[f32]) -> VisionResult<FeatureMap> {
788        FeatureMap::new(
789            image.to_vec(),
790            self.cfg.in_chans,
791            self.cfg.img_size,
792            self.cfg.img_size,
793        )
794    }
795}
796
797// ─── Box decoding ──────────────────────────────────────────────────────────────
798
799/// Decode one level's raw predictions into image-space detections.
800///
801/// Each location `(i, j)` is treated as an anchor point at the cell centre
802/// `((j+0.5)·stride, (i+0.5)·stride)`. The 4 regression channels are decoded as
803/// non-negative left/top/right/bottom distances (via softplus, scaled by the
804/// stride) to give an axis-aligned box `[x1, y1, x2, y2]`. The score is the
805/// maximum per-class sigmoid probability and the label its arg-max.
806///
807/// Returns `(boxes [n·4], scores [n], labels [n])` with `n = H · W`.
808///
809/// # Errors
810/// - [`VisionError::DimensionMismatch`] if `cls.width != reg.width` /
811///   `cls.height != reg.height` or `reg.channels != 4`.
812pub fn decode_level(
813    cls: &FeatureMap,
814    reg: &FeatureMap,
815    stride: usize,
816) -> VisionResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
817    if cls.height != reg.height || cls.width != reg.width {
818        return Err(VisionError::ShapeMismatch {
819            lhs: vec![cls.channels, cls.height, cls.width],
820            rhs: vec![reg.channels, reg.height, reg.width],
821        });
822    }
823    if reg.channels != 4 {
824        return Err(VisionError::DimensionMismatch {
825            expected: 4,
826            got: reg.channels,
827        });
828    }
829    let (h, w, n_cls) = (cls.height, cls.width, cls.channels);
830    let s = stride as f32;
831    let n = h * w;
832    let mut boxes = vec![0.0f32; n * 4];
833    let mut scores = vec![0.0f32; n];
834    let mut labels = vec![0usize; n];
835    for i in 0..h {
836        for j in 0..w {
837            let loc = i * w + j;
838            let cx = (j as f32 + 0.5) * s;
839            let cy = (i as f32 + 0.5) * s;
840            let l = softplus(reg.at(0, i, j)) * s;
841            let t = softplus(reg.at(1, i, j)) * s;
842            let r = softplus(reg.at(2, i, j)) * s;
843            let b = softplus(reg.at(3, i, j)) * s;
844            boxes[loc * 4] = cx - l;
845            boxes[loc * 4 + 1] = cy - t;
846            boxes[loc * 4 + 2] = cx + r;
847            boxes[loc * 4 + 3] = cy + b;
848
849            let mut best = f32::NEG_INFINITY;
850            let mut best_c = 0usize;
851            for c in 0..n_cls {
852                let p = 1.0 / (1.0 + (-cls.at(c, i, j)).exp());
853                if p > best {
854                    best = p;
855                    best_c = c;
856                }
857            }
858            scores[loc] = best;
859            labels[loc] = best_c;
860        }
861    }
862    Ok((boxes, scores, labels))
863}
864
865// ─── SimOTA-lite cost ──────────────────────────────────────────────────────────
866
867/// SimOTA-lite dynamic soft-label assignment cost.
868///
869/// For every (ground-truth, prediction) pair the cost is
870///
871/// ```text
872/// cost = cls_cost + λ · iou_cost
873/// cls_cost = −log( p_pred(gt_class) )
874/// iou_cost = −log( IoU(pred_box, gt_box) )
875/// ```
876///
877/// A prediction that is *confident in the correct class* and *well-localised*
878/// (high IoU) therefore receives a **low** cost — exactly the property exploited
879/// by the dynamic assignment.
880///
881/// # Parameters
882/// - `pred_cls`: `[n_pred · n_classes]` per-class **probabilities** in `[0, 1]`.
883/// - `pred_boxes`: `[n_pred · 4]` predicted boxes `[x1, y1, x2, y2]`.
884/// - `gt_labels`: `[n_gt]` ground-truth class indices.
885/// - `gt_boxes`: `[n_gt · 4]` ground-truth boxes `[x1, y1, x2, y2]`.
886/// - `n_classes`: number of classes.
887/// - `lambda_iou`: weight on the IoU cost term.
888///
889/// # Returns
890/// Flat `[n_gt · n_pred]` cost matrix (row-major, `cost[g · n_pred + p]`).
891///
892/// # Errors
893/// - [`VisionError::EmptyInput`] if there are no predictions or no targets.
894/// - [`VisionError::DimensionMismatch`] on inconsistent input lengths or an
895///   out-of-range ground-truth label.
896/// - [`VisionError::NonFinite`] if `lambda_iou` is not finite or a cost is
897///   non-finite.
898pub fn simota_cost(
899    pred_cls: &[f32],
900    pred_boxes: &[f32],
901    gt_labels: &[usize],
902    gt_boxes: &[f32],
903    n_classes: usize,
904    lambda_iou: f32,
905) -> VisionResult<Vec<f32>> {
906    if n_classes == 0 {
907        return Err(VisionError::InvalidNumClasses(n_classes));
908    }
909    if !lambda_iou.is_finite() {
910        return Err(VisionError::NonFinite("simota lambda_iou"));
911    }
912    let n_pred = pred_boxes.len() / 4;
913    let n_gt = gt_labels.len();
914    if n_pred == 0 {
915        return Err(VisionError::EmptyInput("simota predictions"));
916    }
917    if n_gt == 0 {
918        return Err(VisionError::EmptyInput("simota targets"));
919    }
920    if pred_boxes.len() != n_pred * 4 {
921        return Err(VisionError::DimensionMismatch {
922            expected: n_pred * 4,
923            got: pred_boxes.len(),
924        });
925    }
926    if pred_cls.len() != n_pred * n_classes {
927        return Err(VisionError::DimensionMismatch {
928            expected: n_pred * n_classes,
929            got: pred_cls.len(),
930        });
931    }
932    if gt_boxes.len() != n_gt * 4 {
933        return Err(VisionError::DimensionMismatch {
934            expected: n_gt * 4,
935            got: gt_boxes.len(),
936        });
937    }
938
939    const EPS: f32 = 1e-7;
940    let mut cost = vec![0.0f32; n_gt * n_pred];
941    for g in 0..n_gt {
942        let cls = gt_labels[g];
943        if cls >= n_classes {
944            return Err(VisionError::DimensionMismatch {
945                expected: n_classes,
946                got: cls,
947            });
948        }
949        let gbox = [
950            gt_boxes[g * 4],
951            gt_boxes[g * 4 + 1],
952            gt_boxes[g * 4 + 2],
953            gt_boxes[g * 4 + 3],
954        ];
955        for p in 0..n_pred {
956            let prob = pred_cls[p * n_classes + cls].clamp(EPS, 1.0);
957            let cls_cost = -prob.ln();
958            let pbox = [
959                pred_boxes[p * 4],
960                pred_boxes[p * 4 + 1],
961                pred_boxes[p * 4 + 2],
962                pred_boxes[p * 4 + 3],
963            ];
964            let iou_val = iou(&pbox, &gbox);
965            let iou_cost = -(iou_val + EPS).ln();
966            cost[g * n_pred + p] = cls_cost + lambda_iou * iou_cost;
967        }
968    }
969    if cost.iter().any(|v| !v.is_finite()) {
970        return Err(VisionError::NonFinite("simota cost"));
971    }
972    Ok(cost)
973}
974
975// ─── Tests ───────────────────────────────────────────────────────────────────
976
#[cfg(test)]
mod tests {
    // Unit tests covering config validation, conv primitives, the backbone /
    // neck / head pipeline, box decoding and the SimOTA-lite cost.
    use super::*;

    fn random_image(cfg: &RtmDetConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut img = vec![0.0f32; cfg.in_chans * cfg.img_size * cfg.img_size];
        rng.fill_normal(&mut img);
        img
    }

    // ── Config validation ─────────────────────────────────────────────────────

    #[test]
    fn config_tiny_valid() {
        let cfg = RtmDetConfig::tiny();
        assert_eq!(cfg.n_levels(), 3);
    }

    #[test]
    fn config_odd_stage_channel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 15], 1, 5, 8, 4);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn config_even_dw_kernel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 4, 8, 4);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_zero_classes_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 5, 8, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    // ── Conv sanity ───────────────────────────────────────────────────────────

    #[test]
    fn conv2d_stride2_halves_spatial() {
        let mut rng = LcgRng::new(1);
        let conv = Conv2d::new(3, 4, 3, 2, 1, &mut rng);
        let x = FeatureMap::new(vec![0.5f32; 3 * 16 * 16], 3, 16, 16).expect("ok");
        let y = conv.forward(&x).expect("ok");
        assert_eq!((y.channels, y.height, y.width), (4, 8, 8));
    }

    #[test]
    fn dwconv_identity_kernel_is_input() {
        // A centre-delta 3×3 depthwise kernel acts as the identity.
        // NOTE(review): only the weights are overwritten here. This relies on
        // `DwConv2d::new` initialising any bias to zero; confirm if that changes.
        let mut rng = LcgRng::new(2);
        let mut dw = DwConv2d::new(2, 3, 1, 1, &mut rng);
        for v in dw.weight.iter_mut() {
            *v = 0.0;
        }
        // centre of a 3×3 kernel is index 4.
        for ch in 0..2 {
            dw.weight[ch * 9 + 4] = 1.0;
        }
        let mut data = vec![0.0f32; 2 * 4 * 4];
        let mut r2 = LcgRng::new(3);
        r2.fill_normal(&mut data);
        let x = FeatureMap::new(data.clone(), 2, 4, 4).expect("ok");
        let y = dw.forward(&x).expect("ok");
        for (a, b) in y.data.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-5, "identity dw mismatch {a} vs {b}");
        }
    }

    // ── Backbone: multi-scale halving ─────────────────────────────────────────

    #[test]
    fn backbone_multiscale_halving() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(10);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 11);
        let feats = det.backbone_features(&img).expect("ok");
        assert_eq!(feats.len(), 3, "one feature per stage");
        // Stage spatials: 8, 4, 2 — each halves the previous.
        let spatials: Vec<usize> = feats.iter().map(|f| f.height).collect();
        assert_eq!(spatials, vec![8, 4, 2]);
        for w in feats.windows(2) {
            assert_eq!(w[0].height, w[1].height * 2, "each stage halves spatial");
            assert_eq!(w[0].width, w[1].width * 2);
        }
        // Channels match the configured stage widths.
        let chans: Vec<usize> = feats.iter().map(|f| f.channels).collect();
        assert_eq!(chans, cfg.stage_channels);
        assert!(feats.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    // ── Neck: same #scales, fused channels ────────────────────────────────────

    #[test]
    fn pafpn_uniform_channels_same_scales() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(12);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 13);
        let neck = det.neck_features(&img).expect("ok");
        assert_eq!(neck.len(), cfg.n_levels(), "neck preserves #scales");
        for fm in &neck {
            assert_eq!(fm.channels, cfg.neck_channels, "uniform fused channels");
        }
        // Spatial sizes preserved per level (8, 4, 2).
        let spatials: Vec<usize> = neck.iter().map(|f| f.height).collect();
        assert_eq!(spatials, vec![8, 4, 2]);
        assert!(neck.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    // ── Decoupled head shapes ─────────────────────────────────────────────────

    #[test]
    fn decoupled_head_shapes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(14);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 15);
        let out = det.forward(&img).expect("ok");
        assert_eq!(out.cls_scores.len(), 3);
        assert_eq!(out.bbox_preds.len(), 3);
        for (cls, reg) in out.cls_scores.iter().zip(out.bbox_preds.iter()) {
            assert_eq!(cls.channels, cfg.n_classes, "cls has n_classes channels");
            assert_eq!(reg.channels, 4, "reg has 4 channels");
            assert_eq!(cls.height, reg.height);
            assert_eq!(cls.data.len(), cfg.n_classes * cls.height * cls.width);
            assert_eq!(reg.data.len(), 4 * reg.height * reg.width);
        }
        assert_eq!(out.strides, vec![4, 8, 16]);
    }

    #[test]
    fn forward_all_finite() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(16);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 17);
        let out = det.forward(&img).expect("ok");
        for fm in out.cls_scores.iter().chain(out.bbox_preds.iter()) {
            assert!(fm.data.iter().all(|v| v.is_finite()));
        }
    }

    // ── Varying input changes detections ──────────────────────────────────────

    #[test]
    fn varying_input_changes_detections() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(18);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img_a = random_image(&cfg, 19);
        let img_b = random_image(&cfg, 20);
        let out_a = det.forward(&img_a).expect("ok");
        let out_b = det.forward(&img_b).expect("ok");
        // The finest-level class scores must differ for different inputs.
        let diff: f32 = out_a.cls_scores[0]
            .data
            .iter()
            .zip(out_b.cls_scores[0].data.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-4,
            "detections should change with input, diff={diff}"
        );
    }

    #[test]
    fn decode_level_produces_valid_boxes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(21);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 22);
        let out = det.forward(&img).expect("ok");
        let (boxes, scores, labels) =
            decode_level(&out.cls_scores[0], &out.bbox_preds[0], out.strides[0]).expect("ok");
        let n = out.cls_scores[0].height * out.cls_scores[0].width;
        assert_eq!(boxes.len(), n * 4);
        assert_eq!(scores.len(), n);
        assert_eq!(labels.len(), n);
        for loc in 0..n {
            // x2 > x1 and y2 > y1 (softplus distances are positive).
            assert!(boxes[loc * 4 + 2] > boxes[loc * 4], "x2 must exceed x1");
            assert!(boxes[loc * 4 + 3] > boxes[loc * 4 + 1], "y2 must exceed y1");
            assert!((0.0..=1.0).contains(&scores[loc]), "score in [0,1]");
            assert!(labels[loc] < cfg.n_classes);
        }
        assert!(boxes.iter().all(|v| v.is_finite()));
    }

    // ── SimOTA-lite cost ──────────────────────────────────────────────────────

    #[test]
    fn simota_lower_cost_for_better_match() {
        // 2 predictions, 1 ground-truth (class 0, box [0,0,10,10]).
        let n_classes = 2;
        // pred 0: confident on class 0, perfectly localised → should be cheap.
        // pred 1: confident on the wrong class, far away → should be expensive.
        let pred_cls = vec![
            0.9f32, 0.1, // pred 0
            0.1, 0.9, // pred 1
        ];
        let pred_boxes = vec![
            0.0f32, 0.0, 10.0, 10.0, // pred 0 (IoU = 1)
            20.0, 20.0, 30.0, 30.0, // pred 1 (IoU = 0)
        ];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];

        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            3.0,
        )
        .expect("ok");
        assert_eq!(cost.len(), 2, "[n_gt × n_pred]");
        assert!(cost.iter().all(|v| v.is_finite()), "cost must be finite");
        // cost[gt 0, pred 0] must be lower than cost[gt 0, pred 1].
        assert!(
            cost[0] < cost[1],
            "better cls+iou match must have lower cost: {} vs {}",
            cost[0],
            cost[1]
        );
    }

    #[test]
    fn simota_cost_monotonic_in_iou() {
        // Same class confidence, varying IoU → cost decreases with IoU.
        let n_classes = 1;
        let pred_cls = vec![0.8f32, 0.8, 0.8];
        let pred_boxes = vec![
            0.0f32, 0.0, 10.0, 10.0, // IoU 1.0
            5.0, 0.0, 15.0, 10.0, // IoU 1/3
            50.0, 50.0, 60.0, 60.0, // IoU 0
        ];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];
        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            2.0,
        )
        .expect("ok");
        assert!(cost[0] < cost[1], "higher IoU → lower cost");
        assert!(cost[1] < cost[2], "higher IoU → lower cost");
    }

    #[test]
    fn simota_zero_probability_and_disjoint_box_stay_finite() {
        // Both log terms are epsilon-stabilised, so the worst case (p = 0 and
        // IoU = 0) must be large but finite rather than +∞.
        let cost = simota_cost(
            &[0.0f32],
            &[20.0, 20.0, 30.0, 30.0],
            &[0],
            &[0.0, 0.0, 10.0, 10.0],
            1,
            1.0,
        )
        .expect("ok");
        assert_eq!(cost.len(), 1);
        assert!(cost[0].is_finite() && cost[0] > 0.0, "cost={}", cost[0]);
    }

    #[test]
    fn simota_errors_on_bad_shapes() {
        // pred_cls length wrong.
        let r = simota_cost(
            &[0.5f32],
            &[0.0, 0.0, 1.0, 1.0],
            &[0],
            &[0.0, 0.0, 1.0, 1.0],
            2,
            1.0,
        );
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
        // empty predictions.
        let r2 = simota_cost(&[], &[], &[0], &[0.0, 0.0, 1.0, 1.0], 1, 1.0);
        assert!(matches!(r2, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn simota_errors_on_out_of_range_label() {
        // Label 2 is invalid for a 2-class problem.
        let r = simota_cost(
            &[0.5f32, 0.5],
            &[0.0, 0.0, 1.0, 1.0],
            &[2],
            &[0.0, 0.0, 1.0, 1.0],
            2,
            1.0,
        );
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 2,
                got: 2
            })
        ));
    }

    #[test]
    fn simota_errors_on_non_finite_lambda() {
        for lambda in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let r = simota_cost(
                &[0.5f32],
                &[0.0, 0.0, 1.0, 1.0],
                &[0],
                &[0.0, 0.0, 1.0, 1.0],
                1,
                lambda,
            );
            assert!(
                matches!(r, Err(VisionError::NonFinite(_))),
                "lambda={lambda} must be rejected"
            );
        }
    }

    #[test]
    fn simota_errors_on_zero_classes_and_empty_targets() {
        let r = simota_cost(
            &[],
            &[0.0, 0.0, 1.0, 1.0],
            &[0],
            &[0.0, 0.0, 1.0, 1.0],
            0,
            1.0,
        );
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
        let r2 = simota_cost(&[0.5f32], &[0.0, 0.0, 1.0, 1.0], &[], &[], 1, 1.0);
        assert!(matches!(r2, Err(VisionError::EmptyInput(_))));
    }

    // ── Determinism ───────────────────────────────────────────────────────────

    #[test]
    fn deterministic_same_seed() {
        let cfg = RtmDetConfig::tiny();
        let img = random_image(&cfg, 30);
        let mut ra = LcgRng::new(99);
        let mut rb = LcgRng::new(99);
        let da = RtmDet::new(cfg.clone(), &mut ra).expect("ok");
        let db = RtmDet::new(cfg, &mut rb).expect("ok");
        let oa = da.forward(&img).expect("ok");
        let ob = db.forward(&img).expect("ok");
        for (a, b) in oa.cls_scores.iter().zip(ob.cls_scores.iter()) {
            assert_eq!(a.data, b.data, "same seed → identical output");
        }
        // The regression branch must be just as reproducible as the class branch.
        for (a, b) in oa.bbox_preds.iter().zip(ob.bbox_preds.iter()) {
            assert_eq!(a.data, b.data, "same seed → identical box regression");
        }
    }
}