use super::anchor_nms::iou;
use crate::{
error::{VisionError, VisionResult},
fpn::top_down::FeatureMap,
handle::LcgRng,
};
/// Applies the SiLU activation (`v * sigmoid(v)`) to every element of `x`, in place.
fn silu_inplace(x: &mut [f32]) {
    x.iter_mut().for_each(|value| {
        // Logistic gate evaluated on the value before it is overwritten.
        let gate = 1.0 / (1.0 + (-*value).exp());
        *value *= gate;
    });
}
/// Computes `ln(1 + e^x)` as `max(x, 0) + ln(1 + e^{-|x|})`.
///
/// The exponential is only ever taken of a non-positive number, so it cannot
/// overflow for large `x`.
#[inline]
fn softplus(x: f32) -> f32 {
    let positive_part = x.max(0.0);
    let decay = (-x.abs()).exp();
    positive_part + decay.ln_1p()
}
/// Dense 2-D convolution over a channel-major feature map, with one bias per
/// output channel.
#[derive(Debug, Clone)]
pub struct Conv2d {
/// Flattened filters, indexed by `forward` as `((oc * c_in + ic) * k + ki) * k + kj`.
weight: Vec<f32>,
/// One bias per output channel (`c_out` entries, zero-initialised by `new`).
bias: Vec<f32>,
/// Number of input channels that `forward` requires.
c_in: usize,
/// Number of output channels produced.
c_out: usize,
/// Side length of the square kernel.
k: usize,
/// Step between neighbouring output positions, measured in input pixels.
stride: usize,
/// Implicit zero padding on every side (padded taps are skipped in `forward`).
pad: usize,
}
impl Conv2d {
fn new(
c_in: usize,
c_out: usize,
k: usize,
stride: usize,
pad: usize,
rng: &mut LcgRng,
) -> Self {
let fan_in = c_in * k * k;
let scale = (2.0 / fan_in as f32).sqrt();
let mut weight = vec![0.0f32; c_out * fan_in];
rng.fill_normal(&mut weight);
for w in &mut weight {
*w *= scale;
}
Self {
weight,
bias: vec![0.0f32; c_out],
c_in,
c_out,
k,
stride,
pad,
}
}
#[must_use]
#[inline]
pub fn out_channels(&self) -> usize {
self.c_out
}
pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
if x.channels != self.c_in {
return Err(VisionError::DimensionMismatch {
expected: self.c_in,
got: x.channels,
});
}
let (h, w) = (x.height, x.width);
if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
return Err(VisionError::InvalidImageSize {
height: h,
width: w,
channels: x.channels,
});
}
let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
let mut out = vec![0.0f32; self.c_out * h_out * w_out];
let k = self.k;
for oc in 0..self.c_out {
let oc_w = oc * self.c_in * k * k;
for oh in 0..h_out {
for ow in 0..w_out {
let mut acc = self.bias[oc];
for ic in 0..self.c_in {
let in_base = ic * h * w;
let w_base = oc_w + ic * k * k;
for ki in 0..k {
let ih = oh * self.stride + ki;
if ih < self.pad || ih >= h + self.pad {
continue;
}
let ih = ih - self.pad;
for kj in 0..k {
let iw = ow * self.stride + kj;
if iw < self.pad || iw >= w + self.pad {
continue;
}
let iw = iw - self.pad;
acc += self.weight[w_base + ki * k + kj]
* out_in(x, in_base, ih, w, iw);
}
}
}
out[(oc * h_out + oh) * w_out + ow] = acc;
}
}
}
FeatureMap::new(out, self.c_out, h_out, w_out)
}
}
/// Reads the input sample at row `ih`, column `iw` of the channel plane that
/// starts at `in_base`, given a row length of `w`.
#[inline]
fn out_in(x: &FeatureMap, in_base: usize, ih: usize, w: usize, iw: usize) -> f32 {
    let offset = in_base + ih * w + iw;
    x.data[offset]
}
/// Depthwise 2-D convolution: every channel is filtered by its own `k × k`
/// kernel and no information is mixed across channels.
#[derive(Debug, Clone)]
pub struct DwConv2d {
/// Flattened per-channel kernels, indexed by `forward` as `(ch * k + ki) * k + kj`.
weight: Vec<f32>,
/// One bias per channel (`c` entries, zero-initialised by `new`).
bias: Vec<f32>,
/// Channel count of both the input and the output.
c: usize,
/// Side length of the square kernel.
k: usize,
/// Step between neighbouring output positions, measured in input pixels.
stride: usize,
/// Implicit zero padding on every side (padded taps are skipped in `forward`).
pad: usize,
}
impl DwConv2d {
fn new(c: usize, k: usize, stride: usize, pad: usize, rng: &mut LcgRng) -> Self {
let fan_in = k * k;
let scale = (2.0 / fan_in as f32).sqrt();
let mut weight = vec![0.0f32; c * fan_in];
rng.fill_normal(&mut weight);
for w in &mut weight {
*w *= scale;
}
Self {
weight,
bias: vec![0.0f32; c],
c,
k,
stride,
pad,
}
}
pub fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
if x.channels != self.c {
return Err(VisionError::DimensionMismatch {
expected: self.c,
got: x.channels,
});
}
let (h, w) = (x.height, x.width);
if h + 2 * self.pad < self.k || w + 2 * self.pad < self.k {
return Err(VisionError::InvalidImageSize {
height: h,
width: w,
channels: x.channels,
});
}
let h_out = (h + 2 * self.pad - self.k) / self.stride + 1;
let w_out = (w + 2 * self.pad - self.k) / self.stride + 1;
let k = self.k;
let mut out = vec![0.0f32; self.c * h_out * w_out];
for ch in 0..self.c {
let in_base = ch * h * w;
let w_base = ch * k * k;
let bias = self.bias[ch];
for oh in 0..h_out {
for ow in 0..w_out {
let mut acc = bias;
for ki in 0..k {
let ih = oh * self.stride + ki;
if ih < self.pad || ih >= h + self.pad {
continue;
}
let ih = ih - self.pad;
for kj in 0..k {
let iw = ow * self.stride + kj;
if iw < self.pad || iw >= w + self.pad {
continue;
}
let iw = iw - self.pad;
acc +=
self.weight[w_base + ki * k + kj] * x.data[in_base + ih * w + iw];
}
}
out[(ch * h_out + oh) * w_out + ow] = acc;
}
}
}
FeatureMap::new(out, self.c, h_out, w_out)
}
}
/// Stacks `b`'s channels after `a`'s.
///
/// The data is channel-major, so concatenating along channels amounts to
/// appending `b`'s buffer to `a`'s.
///
/// # Errors
/// `VisionError::ShapeMismatch` when the two maps differ in height or width.
fn concat_channels(a: &FeatureMap, b: &FeatureMap) -> VisionResult<FeatureMap> {
    let same_spatial = a.height == b.height && a.width == b.width;
    if !same_spatial {
        return Err(VisionError::ShapeMismatch {
            lhs: vec![a.channels, a.height, a.width],
            rhs: vec![b.channels, b.height, b.width],
        });
    }
    let data = [a.data.as_slice(), b.data.as_slice()].concat();
    Ok(FeatureMap {
        data,
        channels: a.channels + b.channels,
        height: a.height,
        width: a.width,
    })
}
/// Adds `src` element-wise into `dst`.
///
/// # Errors
/// `VisionError::ShapeMismatch` unless both maps agree in channels, height and
/// width; `dst` is left untouched in that case.
fn add_inplace(dst: &mut FeatureMap, src: &FeatureMap) -> VisionResult<()> {
    let dst_shape = (dst.channels, dst.height, dst.width);
    let src_shape = (src.channels, src.height, src.width);
    if dst_shape != src_shape {
        return Err(VisionError::ShapeMismatch {
            lhs: vec![dst.channels, dst.height, dst.width],
            rhs: vec![src.channels, src.height, src.width],
        });
    }
    dst.data
        .iter_mut()
        .zip(&src.data)
        .for_each(|(acc, addend)| *acc += *addend);
    Ok(())
}
/// Doubles the spatial resolution of `x` by nearest-neighbour replication:
/// every input pixel becomes a 2×2 block of identical output pixels.
fn upsample2x(x: &FeatureMap) -> FeatureMap {
    let (channels, in_h, in_w) = (x.channels, x.height, x.width);
    let (out_h, out_w) = (in_h * 2, in_w * 2);

    // Walk the output in storage order and gather from the source pixel that
    // covers each position (`oi / 2`, `oj / 2`).
    let mut data = Vec::with_capacity(channels * out_h * out_w);
    for ch in 0..channels {
        for oi in 0..out_h {
            let src_row = (ch * in_h + oi / 2) * in_w;
            data.extend((0..out_w).map(|oj| x.data[src_row + oj / 2]));
        }
    }
    FeatureMap {
        data,
        channels,
        height: out_h,
        width: out_w,
    }
}
/// Residual block: a depthwise convolution followed by a 1×1 convolution, with
/// the block input added back onto the result.
pub struct Bottleneck {
/// Depthwise spatial convolution (stride 1, padding `(kernel - 1) / 2`).
dw: DwConv2d,
/// 1×1 convolution that keeps the channel count unchanged.
pw: Conv2d,
}
impl Bottleneck {
    /// Builds a block over `channels` channels with a `dw_kernel`-sized
    /// depthwise kernel padded by `(dw_kernel - 1) / 2`.
    fn new(channels: usize, dw_kernel: usize, rng: &mut LcgRng) -> Self {
        let half_kernel = (dw_kernel - 1) / 2;
        // The depthwise layer draws from `rng` before the pointwise layer.
        let dw = DwConv2d::new(channels, dw_kernel, 1, half_kernel, rng);
        let pw = Conv2d::new(channels, channels, 1, 1, 0, rng);
        Self { dw, pw }
    }

    /// Computes `silu(pw(silu(dw(x)))) + x`.
    ///
    /// # Errors
    /// Propagates any error from the two convolutions or from the residual add.
    fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
        let mut spatial = self.dw.forward(x)?;
        silu_inplace(&mut spatial.data);

        let mut mixed = self.pw.forward(&spatial)?;
        silu_inplace(&mut mixed.data);

        // Residual connection back to the block input.
        add_inplace(&mut mixed, x)?;
        Ok(mixed)
    }
}
/// Two-branch layer: a main branch that runs through the bottleneck blocks and
/// a shortcut branch, concatenated and fused by a final 1×1 convolution.
pub struct CspLayer {
/// 1×1 projection feeding the bottleneck branch.
main_conv: Conv2d,
/// 1×1 projection forming the shortcut branch.
short_conv: Conv2d,
/// Bottlenecks applied in sequence to the main branch.
blocks: Vec<Bottleneck>,
/// 1×1 convolution applied to the concatenation of both branches.
final_conv: Conv2d,
}
impl CspLayer {
    /// Builds a layer mapping `in_channels` to `out_channels`; each branch is
    /// `out_channels / 2` wide and the main branch holds `n_blocks` bottlenecks.
    fn new(
        in_channels: usize,
        out_channels: usize,
        n_blocks: usize,
        dw_kernel: usize,
        rng: &mut LcgRng,
    ) -> Self {
        let branch_width = out_channels / 2;
        // Construction order fixes the order in which `rng` is consumed.
        let main_conv = Conv2d::new(in_channels, branch_width, 1, 1, 0, rng);
        let short_conv = Conv2d::new(in_channels, branch_width, 1, 1, 0, rng);
        let mut blocks = Vec::with_capacity(n_blocks);
        for _ in 0..n_blocks {
            blocks.push(Bottleneck::new(branch_width, dw_kernel, rng));
        }
        let final_conv = Conv2d::new(2 * branch_width, out_channels, 1, 1, 0, rng);
        Self {
            main_conv,
            short_conv,
            blocks,
            final_conv,
        }
    }

    /// Runs both branches on `x`, concatenates them (main first, shortcut
    /// second) and fuses the result; SiLU follows every convolution.
    ///
    /// # Errors
    /// Propagates any error from the convolutions, bottlenecks or concatenation.
    fn forward(&self, x: &FeatureMap) -> VisionResult<FeatureMap> {
        let mut shortcut = self.short_conv.forward(x)?;
        silu_inplace(&mut shortcut.data);

        let mut trunk = self.main_conv.forward(x)?;
        silu_inplace(&mut trunk.data);
        for block in &self.blocks {
            trunk = block.forward(&trunk)?;
        }

        let joined = concat_channels(&trunk, &shortcut)?;
        let mut fused = self.final_conv.forward(&joined)?;
        silu_inplace(&mut fused.data);
        Ok(fused)
    }
}
/// One backbone stage: a stride-2 convolution followed by a CSP layer.
struct BackboneStage {
/// 3×3, stride-2 convolution that changes the channel count to the stage width.
downsample: Conv2d,
/// CSP layer operating at the stage width.
csp: CspLayer,
}
/// Backbone made of a stride-2 stem convolution followed by a sequence of
/// stages; `forward` returns one feature map per stage.
pub struct CspNeXtBackbone {
/// 3×3, stride-2 convolution applied to the input image.
stem: Conv2d,
/// Stages applied in order, each consuming the previous stage's output.
stages: Vec<BackboneStage>,
}
impl CspNeXtBackbone {
    /// Builds the stem plus one stage per entry of `cfg.stage_channels`.
    ///
    /// Each stage's downsample convolution maps the previous width (the stem
    /// width for the first stage) to the stage width.
    fn new(cfg: &RtmDetConfig, rng: &mut LcgRng) -> Self {
        let stem = Conv2d::new(cfg.in_chans, cfg.stem_channels, 3, 2, 1, rng);
        let mut stages = Vec::with_capacity(cfg.stage_channels.len());
        let mut prev = cfg.stem_channels;
        for &c in &cfg.stage_channels {
            let downsample = Conv2d::new(prev, c, 3, 2, 1, rng);
            let csp = CspLayer::new(c, c, cfg.n_bottlenecks, cfg.dw_kernel, rng);
            stages.push(BackboneStage { downsample, csp });
            prev = c;
        }
        Self { stem, stages }
    }

    /// Runs the stem and every stage, returning the output of each stage in
    /// order (the stem activation itself is not returned).
    ///
    /// # Errors
    /// Propagates any error raised by the stem, downsample or CSP layers.
    pub fn forward(&self, image: &FeatureMap) -> VisionResult<Vec<FeatureMap>> {
        let mut stem_out = self.stem.forward(image)?;
        silu_inplace(&mut stem_out.data);

        let mut feats: Vec<FeatureMap> = Vec::with_capacity(self.stages.len());
        for stage in &self.stages {
            // Borrow the previous stage's output straight from `feats` instead
            // of deep-copying every feature map to keep a running input.
            let input = feats.last().unwrap_or(&stem_out);
            let mut down = stage.downsample.forward(input)?;
            silu_inplace(&mut down.data);
            feats.push(stage.csp.forward(&down)?);
        }
        Ok(feats)
    }
}
/// Feature-pyramid neck with a top-down pass followed by a bottom-up pass.
pub struct Pafpn {
/// One 1×1 convolution per level, mapping that level's channels to the neck width.
lateral: Vec<Conv2d>,
/// One 3×3, stride-1 convolution per level, applied during the top-down pass.
top_down: Vec<Conv2d>,
/// `n_levels - 1` 3×3, stride-2 convolutions that shrink a finer output to the next level.
downsample: Vec<Conv2d>,
/// `n_levels - 1` 3×3, stride-1 convolutions applied after bottom-up merging.
bottom_up: Vec<Conv2d>,
/// Number of pyramid levels that `forward` expects.
n_levels: usize,
}
impl Pafpn {
    /// Builds a neck for levels with `in_channels[i]` input channels each, all
    /// fused to `out_channels`.
    fn new(in_channels: &[usize], out_channels: usize, rng: &mut LcgRng) -> Self {
        let n_levels = in_channels.len();
        // Construction order fixes the order in which `rng` is consumed.
        let lateral = in_channels
            .iter()
            .map(|&c| Conv2d::new(c, out_channels, 1, 1, 0, rng))
            .collect();
        let top_down = (0..n_levels)
            .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
            .collect();
        let downsample = (0..n_levels.saturating_sub(1))
            .map(|_| Conv2d::new(out_channels, out_channels, 3, 2, 1, rng))
            .collect();
        let bottom_up = (0..n_levels.saturating_sub(1))
            .map(|_| Conv2d::new(out_channels, out_channels, 3, 1, 1, rng))
            .collect();
        Self {
            lateral,
            top_down,
            downsample,
            bottom_up,
            n_levels,
        }
    }

    /// Fuses the pyramid `feats` (finest level first) and returns one map per
    /// level, in the same order.
    ///
    /// The top-down pass upsamples each coarser level by 2× and adds it into
    /// the next finer one; the bottom-up pass downsamples each finer output
    /// and adds it into the next coarser lateral. A neck built with zero
    /// levels returns an empty vector.
    ///
    /// # Errors
    /// * `VisionError::DimensionMismatch` when `feats.len()` differs from the
    ///   configured number of levels.
    /// * Any error from the convolutions or from adding maps whose shapes do
    ///   not line up (adjacent levels must differ by exactly 2× spatially).
    pub fn forward(&self, feats: Vec<FeatureMap>) -> VisionResult<Vec<FeatureMap>> {
        let n = self.n_levels;
        if feats.len() != n {
            return Err(VisionError::DimensionMismatch {
                expected: n,
                got: feats.len(),
            });
        }
        // Nothing to fuse; avoids indexing into an empty pyramid below.
        if n == 0 {
            return Ok(Vec::new());
        }

        let mut lat: Vec<FeatureMap> = Vec::with_capacity(n);
        for (feat, conv) in feats.iter().zip(self.lateral.iter()) {
            lat.push(conv.forward(feat)?);
        }

        // Top-down: coarse to fine.
        for level in (0..n - 1).rev() {
            let up = upsample2x(&lat[level + 1]);
            add_inplace(&mut lat[level], &up)?;
            let mut fused = self.top_down[level].forward(&lat[level])?;
            silu_inplace(&mut fused.data);
            lat[level] = fused;
        }
        // The coarsest level has nothing above it; it is convolved after it has
        // been used as the upsampling source for the level below.
        let top = n - 1;
        let mut fused_top = self.top_down[top].forward(&lat[top])?;
        silu_inplace(&mut fused_top.data);
        lat[top] = fused_top;

        // Bottom-up: fine to coarse. `lat` is owned, so each lateral is moved
        // out rather than cloned.
        let mut laterals = lat.into_iter();
        let mut outs: Vec<FeatureMap> = Vec::with_capacity(n);
        if let Some(finest) = laterals.next() {
            outs.push(finest);
        }
        for (idx, mut merged) in laterals.enumerate() {
            // `idx` counts from 0 for pyramid level 1, so it addresses both the
            // previous output and the matching downsample / bottom-up convs.
            let mut down = self.downsample[idx].forward(&outs[idx])?;
            silu_inplace(&mut down.data);
            add_inplace(&mut merged, &down)?;
            let mut fused = self.bottom_up[idx].forward(&merged)?;
            silu_inplace(&mut fused.data);
            outs.push(fused);
        }
        Ok(outs)
    }
}
/// Detection head with separate classification and box-regression branches.
pub struct DecoupledHead {
/// 3×3 convolution at the start of the classification branch.
cls_conv: Conv2d,
/// 1×1 convolution producing one channel per class.
cls_pred: Conv2d,
/// 3×3 convolution at the start of the regression branch.
reg_conv: Conv2d,
/// 1×1 convolution producing four channels per location.
reg_pred: Conv2d,
}
impl DecoupledHead {
    /// Builds a head for `channels`-wide inputs that predicts `n_classes`
    /// class channels and four regression channels.
    fn new(channels: usize, n_classes: usize, rng: &mut LcgRng) -> Self {
        // Construction order fixes the order in which `rng` is consumed.
        let cls_conv = Conv2d::new(channels, channels, 3, 1, 1, rng);
        let cls_pred = Conv2d::new(channels, n_classes, 1, 1, 0, rng);
        let reg_conv = Conv2d::new(channels, channels, 3, 1, 1, rng);
        let reg_pred = Conv2d::new(channels, 4, 1, 1, 0, rng);
        Self {
            cls_conv,
            cls_pred,
            reg_conv,
            reg_pred,
        }
    }

    /// Runs both branches on one pyramid level and returns
    /// `(class_map, regression_map)`.
    ///
    /// SiLU follows the 3×3 convolution of each branch; the 1×1 prediction
    /// layers are returned without an activation.
    ///
    /// # Errors
    /// Propagates any error from the four convolutions.
    pub fn forward_level(&self, x: &FeatureMap) -> VisionResult<(FeatureMap, FeatureMap)> {
        let mut cls_feat = self.cls_conv.forward(x)?;
        silu_inplace(&mut cls_feat.data);
        let cls_out = self.cls_pred.forward(&cls_feat)?;

        let mut reg_feat = self.reg_conv.forward(x)?;
        silu_inplace(&mut reg_feat.data);
        let reg_out = self.reg_pred.forward(&reg_feat)?;

        Ok((cls_out, reg_out))
    }
}
/// Hyper-parameters describing an `RtmDet` model.
#[derive(Debug, Clone, PartialEq)]
pub struct RtmDetConfig {
/// Number of channels in the input image.
pub in_chans: usize,
/// Side length of the square input image.
pub img_size: usize,
/// Output channels of the backbone stem convolution.
pub stem_channels: usize,
/// Output channels of each backbone stage; one pyramid level per entry.
pub stage_channels: Vec<usize>,
/// Number of bottleneck blocks inside every CSP layer.
pub n_bottlenecks: usize,
/// Kernel size of the depthwise convolutions (`new` requires it to be odd).
pub dw_kernel: usize,
/// Channel width shared by all neck outputs and by the head.
pub neck_channels: usize,
/// Number of object classes predicted by the head.
pub n_classes: usize,
}
impl RtmDetConfig {
    /// Validates the hyper-parameters and assembles a configuration.
    ///
    /// # Errors
    /// Checks run in this order and the first failure is returned:
    /// * `InvalidImageSize` when `in_chans` or `img_size` is zero;
    /// * `EmptyInput` when `stage_channels` is empty;
    /// * `InvalidNumClasses` when `n_classes` is zero;
    /// * `InvalidPatchSize` when `dw_kernel` is zero or even;
    /// * `DimensionMismatch` when `stem_channels` or `neck_channels` is zero;
    /// * `DimensionMismatch` when any stage width is zero or odd.
    pub fn new(
        in_chans: usize,
        img_size: usize,
        stem_channels: usize,
        stage_channels: Vec<usize>,
        n_bottlenecks: usize,
        dw_kernel: usize,
        neck_channels: usize,
        n_classes: usize,
    ) -> VisionResult<Self> {
        let degenerate_image = in_chans == 0 || img_size == 0;
        if degenerate_image {
            return Err(VisionError::InvalidImageSize {
                height: img_size,
                width: img_size,
                channels: in_chans,
            });
        }
        if stage_channels.is_empty() {
            return Err(VisionError::EmptyInput("rtmdet stage_channels"));
        }
        if n_classes == 0 {
            return Err(VisionError::InvalidNumClasses(n_classes));
        }
        // Only odd kernels pass; zero is even and is rejected by the same test.
        if dw_kernel % 2 != 1 {
            return Err(VisionError::InvalidPatchSize {
                patch_size: dw_kernel,
                img_size,
            });
        }
        if [stem_channels, neck_channels].contains(&0) {
            return Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0,
            });
        }
        // Every stage width is split in half by the CSP layer, so it must be
        // a positive even number.
        let first_bad_width = stage_channels.iter().find(|&&c| c == 0 || c % 2 != 0);
        if let Some(&bad) = first_bad_width {
            return Err(VisionError::DimensionMismatch {
                expected: 2,
                got: bad,
            });
        }
        Ok(Self {
            in_chans,
            img_size,
            stem_channels,
            stage_channels,
            n_bottlenecks,
            dw_kernel,
            neck_channels,
            n_classes,
        })
    }

    /// A small three-level preset (32×32 RGB input, 4 classes).
    #[must_use]
    pub fn tiny() -> Self {
        let stage_channels = vec![8, 16, 16];
        Self {
            in_chans: 3,
            img_size: 32,
            stem_channels: 8,
            stage_channels,
            n_bottlenecks: 1,
            dw_kernel: 5,
            neck_channels: 8,
            n_classes: 4,
        }
    }

    /// Number of pyramid levels, i.e. the number of backbone stages.
    #[must_use]
    #[inline]
    pub fn n_levels(&self) -> usize {
        self.stage_channels.len()
    }
}
/// Raw head outputs of `RtmDet::forward`, one entry per pyramid level.
#[derive(Debug, Clone)]
pub struct RtmDetOutput {
/// Classification maps produced by the head's class branch, one per level.
pub cls_scores: Vec<FeatureMap>,
/// Regression maps produced by the head's box branch, one per level.
pub bbox_preds: Vec<FeatureMap>,
/// Per-level stride, computed as `img_size / level_height`.
pub strides: Vec<usize>,
}
/// Detector assembled from a backbone, a pyramid neck and a head shared by all levels.
pub struct RtmDet {
/// Validated configuration the model was built from.
cfg: RtmDetConfig,
/// Multi-scale feature extractor.
backbone: CspNeXtBackbone,
/// Neck that fuses the backbone features to a common channel width.
neck: Pafpn,
/// Prediction head applied to every neck level.
head: DecoupledHead,
}
impl RtmDet {
pub fn new(cfg: RtmDetConfig, rng: &mut LcgRng) -> VisionResult<Self> {
let cfg = RtmDetConfig::new(
cfg.in_chans,
cfg.img_size,
cfg.stem_channels,
cfg.stage_channels.clone(),
cfg.n_bottlenecks,
cfg.dw_kernel,
cfg.neck_channels,
cfg.n_classes,
)?;
let backbone = CspNeXtBackbone::new(&cfg, rng);
let neck = Pafpn::new(&cfg.stage_channels, cfg.neck_channels, rng);
let head = DecoupledHead::new(cfg.neck_channels, cfg.n_classes, rng);
Ok(Self {
cfg,
backbone,
neck,
head,
})
}
#[must_use]
#[inline]
pub fn config(&self) -> &RtmDetConfig {
&self.cfg
}
pub fn backbone_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
let img = self.make_image(image)?;
self.backbone.forward(&img)
}
pub fn neck_features(&self, image: &[f32]) -> VisionResult<Vec<FeatureMap>> {
let feats = self.backbone_features(image)?;
self.neck.forward(feats)
}
pub fn forward(&self, image: &[f32]) -> VisionResult<RtmDetOutput> {
let neck = self.neck_features(image)?;
let mut cls_scores = Vec::with_capacity(neck.len());
let mut bbox_preds = Vec::with_capacity(neck.len());
let mut strides = Vec::with_capacity(neck.len());
for level in &neck {
let (cls, reg) = self.head.forward_level(level)?;
if cls
.data
.iter()
.chain(reg.data.iter())
.any(|v| !v.is_finite())
{
return Err(VisionError::NonFinite("rtmdet head output"));
}
strides.push(self.cfg.img_size / level.height.max(1));
cls_scores.push(cls);
bbox_preds.push(reg);
}
Ok(RtmDetOutput {
cls_scores,
bbox_preds,
strides,
})
}
fn make_image(&self, image: &[f32]) -> VisionResult<FeatureMap> {
FeatureMap::new(
image.to_vec(),
self.cfg.in_chans,
self.cfg.img_size,
self.cfg.img_size,
)
}
}
/// Decodes one pyramid level into boxes, scores and labels.
///
/// For the location at row `i`, column `j` the anchor point is
/// `((j + 0.5) * stride, (i + 0.5) * stride)`. The four regression channels
/// are passed through softplus, scaled by `stride`, and read as distances to
/// the left, top, right and bottom edges.
///
/// Returns `(boxes, scores, labels)` in row-major location order:
/// * `boxes` holds `[x1, y1, x2, y2]` per location (`4 * h * w` values);
/// * `scores` holds the largest per-class sigmoid at each location;
/// * `labels` holds the class index of that score (the first one on ties).
///
/// # Errors
/// * `VisionError::ShapeMismatch` when `cls` and `reg` differ in height or width.
/// * `VisionError::DimensionMismatch` when `reg` does not have four channels.
/// * `VisionError::InvalidNumClasses` when `cls` has no channels, since no
///   score or label can then be produced.
pub fn decode_level(
    cls: &FeatureMap,
    reg: &FeatureMap,
    stride: usize,
) -> VisionResult<(Vec<f32>, Vec<f32>, Vec<usize>)> {
    if cls.height != reg.height || cls.width != reg.width {
        return Err(VisionError::ShapeMismatch {
            lhs: vec![cls.channels, cls.height, cls.width],
            rhs: vec![reg.channels, reg.height, reg.width],
        });
    }
    if reg.channels != 4 {
        return Err(VisionError::DimensionMismatch {
            expected: 4,
            got: reg.channels,
        });
    }
    let (h, w, n_cls) = (cls.height, cls.width, cls.channels);
    // Without classes the arg-max below would report a score of -inf with
    // label 0; reject the input instead of returning that placeholder.
    if n_cls == 0 {
        return Err(VisionError::InvalidNumClasses(n_cls));
    }
    let s = stride as f32;
    let n = h * w;
    let mut boxes = vec![0.0f32; n * 4];
    let mut scores = vec![0.0f32; n];
    let mut labels = vec![0usize; n];
    for i in 0..h {
        for j in 0..w {
            let loc = i * w + j;
            let cx = (j as f32 + 0.5) * s;
            let cy = (i as f32 + 0.5) * s;
            // Softplus keeps every edge distance strictly positive.
            let left = softplus(reg.at(0, i, j)) * s;
            let top = softplus(reg.at(1, i, j)) * s;
            let right = softplus(reg.at(2, i, j)) * s;
            let bottom = softplus(reg.at(3, i, j)) * s;
            boxes[loc * 4] = cx - left;
            boxes[loc * 4 + 1] = cy - top;
            boxes[loc * 4 + 2] = cx + right;
            boxes[loc * 4 + 3] = cy + bottom;

            let mut best = f32::NEG_INFINITY;
            let mut best_c = 0usize;
            for c in 0..n_cls {
                let p = 1.0 / (1.0 + (-cls.at(c, i, j)).exp());
                // Strict comparison: the first class wins ties, NaN never wins.
                if p > best {
                    best = p;
                    best_c = c;
                }
            }
            scores[loc] = best;
            labels[loc] = best_c;
        }
    }
    Ok((boxes, scores, labels))
}
/// Builds the assignment-cost matrix between ground-truth boxes and predictions.
///
/// Entry `g * n_pred + p` of the result is
/// `-ln(clamp(prob, EPS, 1)) + lambda_iou * -ln(iou + EPS)`, where `prob` is
/// prediction `p`'s value for ground truth `g`'s class and `iou` is the
/// overlap reported by `iou` for the two boxes.
///
/// `pred_cls` holds `n_classes` values per prediction; `pred_boxes` and
/// `gt_boxes` hold four values per box.
///
/// # Errors
/// * `InvalidNumClasses` when `n_classes` is zero.
/// * `NonFinite` when `lambda_iou` or any resulting cost is not finite.
/// * `EmptyInput` when there are no complete prediction boxes or no targets.
/// * `DimensionMismatch` when a slice length disagrees with the counts, or a
///   ground-truth label is `>= n_classes`.
pub fn simota_cost(
    pred_cls: &[f32],
    pred_boxes: &[f32],
    gt_labels: &[usize],
    gt_boxes: &[f32],
    n_classes: usize,
    lambda_iou: f32,
) -> VisionResult<Vec<f32>> {
    const EPS: f32 = 1e-7;

    if n_classes == 0 {
        return Err(VisionError::InvalidNumClasses(n_classes));
    }
    if !lambda_iou.is_finite() {
        return Err(VisionError::NonFinite("simota lambda_iou"));
    }
    let n_pred = pred_boxes.len() / 4;
    let n_gt = gt_labels.len();
    if n_pred == 0 {
        return Err(VisionError::EmptyInput("simota predictions"));
    }
    if n_gt == 0 {
        return Err(VisionError::EmptyInput("simota targets"));
    }

    // Shared length check for the three flat input buffers.
    let require_len = |expected: usize, got: usize| {
        if expected == got {
            Ok(())
        } else {
            Err(VisionError::DimensionMismatch { expected, got })
        }
    };
    require_len(n_pred * 4, pred_boxes.len())?;
    require_len(n_pred * n_classes, pred_cls.len())?;
    require_len(n_gt * 4, gt_boxes.len())?;

    // Costs are pushed row by row (one row per ground truth), which yields the
    // `g * n_pred + p` layout.
    let mut cost = Vec::with_capacity(n_gt * n_pred);
    for (&label, gt) in gt_labels.iter().zip(gt_boxes.chunks_exact(4)) {
        if label >= n_classes {
            return Err(VisionError::DimensionMismatch {
                expected: n_classes,
                got: label,
            });
        }
        let gbox = [gt[0], gt[1], gt[2], gt[3]];
        let predictions = pred_cls
            .chunks_exact(n_classes)
            .zip(pred_boxes.chunks_exact(4));
        for (probs, pb) in predictions {
            let prob = probs[label].clamp(EPS, 1.0);
            let cls_cost = -prob.ln();
            let pbox = [pb[0], pb[1], pb[2], pb[3]];
            let overlap = iou(&pbox, &gbox);
            let iou_cost = -(overlap + EPS).ln();
            cost.push(cls_cost + lambda_iou * iou_cost);
        }
    }
    if !cost.iter().all(|v| v.is_finite()) {
        return Err(VisionError::NonFinite("simota cost"));
    }
    Ok(cost)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random image shaped for `cfg`.
    fn random_image(cfg: &RtmDetConfig, seed: u64) -> Vec<f32> {
        let mut rng = LcgRng::new(seed);
        let mut img = vec![0.0f32; cfg.in_chans * cfg.img_size * cfg.img_size];
        rng.fill_normal(&mut img);
        img
    }

    #[test]
    fn config_tiny_valid() {
        let cfg = RtmDetConfig::tiny();
        assert_eq!(cfg.n_levels(), 3);
    }

    #[test]
    fn config_odd_stage_channel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 15], 1, 5, 8, 4);
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
    }

    #[test]
    fn config_even_dw_kernel_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 4, 8, 4);
        assert!(matches!(r, Err(VisionError::InvalidPatchSize { .. })));
    }

    #[test]
    fn config_zero_classes_errors() {
        let r = RtmDetConfig::new(3, 32, 8, vec![8, 16], 1, 5, 8, 0);
        assert!(matches!(r, Err(VisionError::InvalidNumClasses(0))));
    }

    #[test]
    fn conv2d_stride2_halves_spatial() {
        let mut rng = LcgRng::new(1);
        let conv = Conv2d::new(3, 4, 3, 2, 1, &mut rng);
        let x = FeatureMap::new(vec![0.5f32; 3 * 16 * 16], 3, 16, 16).expect("ok");
        let y = conv.forward(&x).expect("ok");
        assert_eq!((y.channels, y.height, y.width), (4, 8, 8));
    }

    #[test]
    fn dwconv_identity_kernel_is_input() {
        let mut rng = LcgRng::new(2);
        let mut dw = DwConv2d::new(2, 3, 1, 1, &mut rng);
        for v in dw.weight.iter_mut() {
            *v = 0.0;
        }
        // A single 1.0 at the centre tap of each 3×3 kernel is the identity.
        for ch in 0..2 {
            dw.weight[ch * 9 + 4] = 1.0;
        }
        let mut data = vec![0.0f32; 2 * 4 * 4];
        let mut r2 = LcgRng::new(3);
        r2.fill_normal(&mut data);
        let x = FeatureMap::new(data.clone(), 2, 4, 4).expect("ok");
        let y = dw.forward(&x).expect("ok");
        for (a, b) in y.data.iter().zip(data.iter()) {
            assert!((a - b).abs() < 1e-5, "identity dw mismatch {a} vs {b}");
        }
    }

    #[test]
    fn backbone_multiscale_halving() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(10);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 11);
        let feats = det.backbone_features(&img).expect("ok");
        assert_eq!(feats.len(), 3, "one feature per stage");
        let spatials: Vec<usize> = feats.iter().map(|f| f.height).collect();
        assert_eq!(spatials, vec![8, 4, 2]);
        for w in feats.windows(2) {
            assert_eq!(w[0].height, w[1].height * 2, "each stage halves spatial");
            assert_eq!(w[0].width, w[1].width * 2);
        }
        let chans: Vec<usize> = feats.iter().map(|f| f.channels).collect();
        assert_eq!(chans, cfg.stage_channels);
        assert!(feats.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    #[test]
    fn pafpn_uniform_channels_same_scales() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(12);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 13);
        let neck = det.neck_features(&img).expect("ok");
        assert_eq!(neck.len(), cfg.n_levels(), "neck preserves #scales");
        for fm in &neck {
            assert_eq!(fm.channels, cfg.neck_channels, "uniform fused channels");
        }
        let spatials: Vec<usize> = neck.iter().map(|f| f.height).collect();
        assert_eq!(spatials, vec![8, 4, 2]);
        assert!(neck.iter().all(|f| f.data.iter().all(|v| v.is_finite())));
    }

    #[test]
    fn decoupled_head_shapes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(14);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 15);
        let out = det.forward(&img).expect("ok");
        assert_eq!(out.cls_scores.len(), 3);
        assert_eq!(out.bbox_preds.len(), 3);
        for (cls, reg) in out.cls_scores.iter().zip(out.bbox_preds.iter()) {
            assert_eq!(cls.channels, cfg.n_classes, "cls has n_classes channels");
            assert_eq!(reg.channels, 4, "reg has 4 channels");
            assert_eq!(cls.height, reg.height);
            assert_eq!(cls.data.len(), cfg.n_classes * cls.height * cls.width);
            assert_eq!(reg.data.len(), 4 * reg.height * reg.width);
        }
        assert_eq!(out.strides, vec![4, 8, 16]);
    }

    #[test]
    fn forward_all_finite() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(16);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 17);
        let out = det.forward(&img).expect("ok");
        for fm in out.cls_scores.iter().chain(out.bbox_preds.iter()) {
            assert!(fm.data.iter().all(|v| v.is_finite()));
        }
    }

    #[test]
    fn varying_input_changes_detections() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(18);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img_a = random_image(&cfg, 19);
        let img_b = random_image(&cfg, 20);
        let out_a = det.forward(&img_a).expect("ok");
        let out_b = det.forward(&img_b).expect("ok");
        let diff: f32 = out_a.cls_scores[0]
            .data
            .iter()
            .zip(out_b.cls_scores[0].data.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(
            diff > 1e-4,
            "detections should change with input, diff={diff}"
        );
    }

    #[test]
    fn decode_level_produces_valid_boxes() {
        let cfg = RtmDetConfig::tiny();
        let mut rng = LcgRng::new(21);
        let det = RtmDet::new(cfg.clone(), &mut rng).expect("ok");
        let img = random_image(&cfg, 22);
        let out = det.forward(&img).expect("ok");
        let (boxes, scores, labels) =
            decode_level(&out.cls_scores[0], &out.bbox_preds[0], out.strides[0]).expect("ok");
        let n = out.cls_scores[0].height * out.cls_scores[0].width;
        assert_eq!(boxes.len(), n * 4);
        assert_eq!(scores.len(), n);
        assert_eq!(labels.len(), n);
        for loc in 0..n {
            assert!(boxes[loc * 4 + 2] > boxes[loc * 4], "x2 must exceed x1");
            assert!(boxes[loc * 4 + 3] > boxes[loc * 4 + 1], "y2 must exceed y1");
            assert!((0.0..=1.0).contains(&scores[loc]), "score in [0,1]");
            assert!(labels[loc] < cfg.n_classes);
        }
        assert!(boxes.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn simota_lower_cost_for_better_match() {
        let n_classes = 2;
        // Two predictions with two class probabilities each.
        let pred_cls = vec![0.9f32, 0.1, 0.1, 0.9];
        // The first box coincides with the ground truth, the second is disjoint.
        let pred_boxes = vec![0.0f32, 0.0, 10.0, 10.0, 20.0, 20.0, 30.0, 30.0];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];
        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            3.0,
        )
        .expect("ok");
        assert_eq!(cost.len(), 2, "[n_gt × n_pred]");
        assert!(cost.iter().all(|v| v.is_finite()), "cost must be finite");
        assert!(
            cost[0] < cost[1],
            "better cls+iou match must have lower cost: {} vs {}",
            cost[0],
            cost[1]
        );
    }

    #[test]
    fn simota_cost_monotonic_in_iou() {
        let n_classes = 1;
        let pred_cls = vec![0.8f32, 0.8, 0.8];
        // Full overlap, partial overlap and no overlap with the ground truth.
        let pred_boxes = vec![
            0.0f32, 0.0, 10.0, 10.0, 5.0, 0.0, 15.0, 10.0, 50.0, 50.0, 60.0, 60.0,
        ];
        let gt_labels = vec![0usize];
        let gt_boxes = vec![0.0f32, 0.0, 10.0, 10.0];
        let cost = simota_cost(
            &pred_cls,
            &pred_boxes,
            &gt_labels,
            &gt_boxes,
            n_classes,
            2.0,
        )
        .expect("ok");
        assert!(cost[0] < cost[1], "higher IoU → lower cost");
        assert!(cost[1] < cost[2], "higher IoU → lower cost");
    }

    #[test]
    fn simota_errors_on_bad_shapes() {
        // One class probability supplied where two are required.
        let r = simota_cost(
            &[0.5f32],
            &[0.0, 0.0, 1.0, 1.0],
            &[0],
            &[0.0, 0.0, 1.0, 1.0],
            2,
            1.0,
        );
        assert!(matches!(r, Err(VisionError::DimensionMismatch { .. })));
        let r2 = simota_cost(&[], &[], &[0], &[0.0, 0.0, 1.0, 1.0], 1, 1.0);
        assert!(matches!(r2, Err(VisionError::EmptyInput(_))));
    }

    #[test]
    fn deterministic_same_seed() {
        let cfg = RtmDetConfig::tiny();
        let img = random_image(&cfg, 30);
        let mut ra = LcgRng::new(99);
        let mut rb = LcgRng::new(99);
        let da = RtmDet::new(cfg.clone(), &mut ra).expect("ok");
        let db = RtmDet::new(cfg, &mut rb).expect("ok");
        let oa = da.forward(&img).expect("ok");
        let ob = db.forward(&img).expect("ok");
        for (a, b) in oa.cls_scores.iter().zip(ob.cls_scores.iter()) {
            assert_eq!(a.data, b.data, "same seed → identical output");
        }
    }
}