Skip to main content

oxicuda_vision/detection/
detr_decoder.rs

//! DETR (DEtection TRansformer) decoder.
//!
//! Implements the DETR decoder as described in "End-to-End Object Detection with
//! Transformers" (Carion et al., 2020).  Each decoder layer applies:
//!
//! 1. **Self-attention** over the object query embeddings (pre-norm).
//! 2. **Cross-attention** from object queries to encoder memory (pre-norm).
//! 3. **Feed-forward network** (two-layer MLP with GELU, pre-norm).
//!
//! The decoder stacks `depth` such layers sequentially.
11
12use crate::{
13    error::{VisionError, VisionResult},
14    handle::LcgRng,
15};
16
17// ─── DetrConfig ───────────────────────────────────────────────────────────────
18
/// DETR decoder hyper-parameters.
///
/// Prefer [`DetrConfig::new`], which validates the values. The fields are public, so a
/// struct literal bypasses those checks.
///
/// Configurations compare field by field (`PartialEq`/`Eq`), which makes them easy to
/// assert on in tests and to deduplicate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DetrConfig {
    /// Number of object query vectors.
    pub n_queries: usize,
    /// Embedding dimension for all tokens (queries and encoder features).
    pub embed_dim: usize,
    /// Number of attention heads (must divide `embed_dim`).
    pub n_heads: usize,
    /// Number of decoder layers.
    pub depth: usize,
    /// MLP expansion factor: `mlp_dim = mlp_ratio * embed_dim`.
    pub mlp_ratio: usize,
}
33
34impl DetrConfig {
35    /// Construct a validated `DetrConfig`.
36    ///
37    /// # Errors
38    /// - `InvalidEmbedDim` if `embed_dim == 0`.
39    /// - `InvalidNumHeads` if `n_heads == 0`.
40    /// - `HeadDimMismatch` if `embed_dim % n_heads != 0`.
41    /// - `DimensionMismatch` if `n_queries == 0`, `depth == 0`, or `mlp_ratio == 0`.
42    pub fn new(
43        n_queries: usize,
44        embed_dim: usize,
45        n_heads: usize,
46        depth: usize,
47        mlp_ratio: usize,
48    ) -> VisionResult<Self> {
49        if embed_dim == 0 {
50            return Err(VisionError::InvalidEmbedDim(embed_dim));
51        }
52        if n_heads == 0 {
53            return Err(VisionError::InvalidNumHeads(n_heads));
54        }
55        if embed_dim % n_heads != 0 {
56            return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
57        }
58        if n_queries == 0 {
59            return Err(VisionError::DimensionMismatch {
60                expected: 1,
61                got: 0,
62            });
63        }
64        if depth == 0 {
65            return Err(VisionError::DimensionMismatch {
66                expected: 1,
67                got: 0,
68            });
69        }
70        if mlp_ratio == 0 {
71            return Err(VisionError::DimensionMismatch {
72                expected: 1,
73                got: 0,
74            });
75        }
76        Ok(Self {
77            n_queries,
78            embed_dim,
79            n_heads,
80            depth,
81            mlp_ratio,
82        })
83    }
84
85    /// A tiny configuration for unit tests.
86    ///
87    /// `n_queries=4, embed_dim=32, n_heads=4, depth=1, mlp_ratio=4`.
88    pub fn tiny() -> Self {
89        Self {
90            n_queries: 4,
91            embed_dim: 32,
92            n_heads: 4,
93            depth: 1,
94            mlp_ratio: 4,
95        }
96    }
97
98    /// MLP hidden dimension.
99    #[inline]
100    pub fn mlp_dim(&self) -> usize {
101        self.mlp_ratio * self.embed_dim
102    }
103
104    /// Per-head dimension.
105    #[inline]
106    pub fn head_dim(&self) -> usize {
107        self.embed_dim / self.n_heads
108    }
109}
110
111// ─── DetrDecoderLayerWeights ──────────────────────────────────────────────────
112
/// All learnable weights for a single DETR decoder layer.
///
/// Every tensor is stored flat in a `Vec<f32>`; the shapes quoted on each field are the
/// sizes produced by `DetrDecoderLayerWeights::default_init`. Projection matrices are
/// laid out as `[n_out × n_in]` rows, which is how the `linear` helper reads them.
#[derive(Debug, Clone)]
pub struct DetrDecoderLayerWeights {
    // ── Self-attention (queries attend to queries) ────────────────────────────
    /// Fused QKV projection: `[3 × embed_dim × embed_dim]`.
    pub self_qkv_weight: Vec<f32>,
    /// Fused QKV bias: `[3 × embed_dim]`.
    pub self_qkv_bias: Vec<f32>,
    /// Output projection: `[embed_dim × embed_dim]`.
    pub self_out_weight: Vec<f32>,
    /// Output projection bias: `[embed_dim]`.
    pub self_out_bias: Vec<f32>,

    // ── Cross-attention (queries attend to encoder memory) ────────────────────
    /// Query projection: `[embed_dim × embed_dim]`.
    pub cross_q_weight: Vec<f32>,
    /// Query projection bias: `[embed_dim]`.
    pub cross_q_bias: Vec<f32>,
    /// Fused Key+Value projection from encoder: `[2 × embed_dim × embed_dim]`.
    pub cross_kv_weight: Vec<f32>,
    /// Fused KV bias: `[2 × embed_dim]`.
    pub cross_kv_bias: Vec<f32>,
    /// Cross-attention output projection: `[embed_dim × embed_dim]`.
    pub cross_out_weight: Vec<f32>,
    /// Cross-attention output bias: `[embed_dim]`.
    pub cross_out_bias: Vec<f32>,

    // ── Feed-forward network ─────────────────────────────────────────────────
    /// FFN first layer: `[mlp_dim × embed_dim]`.
    pub ffn1_weight: Vec<f32>,
    /// FFN first layer bias: `[mlp_dim]`.
    pub ffn1_bias: Vec<f32>,
    /// FFN second layer: `[embed_dim × mlp_dim]`.
    pub ffn2_weight: Vec<f32>,
    /// FFN second layer bias: `[embed_dim]`.
    pub ffn2_bias: Vec<f32>,

    // ── Layer normalisation (three pre-norms per layer) ───────────────────────
    /// LN before self-attention: scale `[embed_dim]`.
    pub ln1_weight: Vec<f32>,
    /// LN before self-attention: bias `[embed_dim]`.
    pub ln1_bias: Vec<f32>,
    /// LN before cross-attention: scale `[embed_dim]`.
    pub ln2_weight: Vec<f32>,
    /// LN before cross-attention: bias `[embed_dim]`.
    pub ln2_bias: Vec<f32>,
    /// LN before FFN: scale `[embed_dim]`.
    pub ln3_weight: Vec<f32>,
    /// LN before FFN: bias `[embed_dim]`.
    pub ln3_bias: Vec<f32>,
}
163
164impl DetrDecoderLayerWeights {
165    /// Xavier-style default initialisation.
166    ///
167    /// Attention/FFN weights: N(0, 1/√embed_dim); biases: zeros;
168    /// LayerNorm weights: ones; biases: zeros.
169    pub fn default_init(cfg: &DetrConfig, rng: &mut LcgRng) -> Self {
170        let e = cfg.embed_dim;
171        let mlp = cfg.mlp_dim();
172        let scale = 1.0_f32 / (e as f32).sqrt();
173
174        let fill_scaled = |rng: &mut LcgRng, n: usize| -> Vec<f32> {
175            let mut v = vec![0.0f32; n];
176            rng.fill_normal(&mut v);
177            for x in &mut v {
178                *x *= scale;
179            }
180            v
181        };
182
183        // Self-attention
184        let self_qkv_weight = fill_scaled(rng, 3 * e * e);
185        let self_qkv_bias = vec![0.0f32; 3 * e];
186        let self_out_weight = fill_scaled(rng, e * e);
187        let self_out_bias = vec![0.0f32; e];
188
189        // Cross-attention
190        let cross_q_weight = fill_scaled(rng, e * e);
191        let cross_q_bias = vec![0.0f32; e];
192        let cross_kv_weight = fill_scaled(rng, 2 * e * e);
193        let cross_kv_bias = vec![0.0f32; 2 * e];
194        let cross_out_weight = fill_scaled(rng, e * e);
195        let cross_out_bias = vec![0.0f32; e];
196
197        // FFN
198        let ffn1_weight = fill_scaled(rng, mlp * e);
199        let ffn1_bias = vec![0.0f32; mlp];
200        let ffn2_weight = fill_scaled(rng, e * mlp);
201        let ffn2_bias = vec![0.0f32; e];
202
203        // Layer norms
204        let ln1_weight = vec![1.0f32; e];
205        let ln1_bias = vec![0.0f32; e];
206        let ln2_weight = vec![1.0f32; e];
207        let ln2_bias = vec![0.0f32; e];
208        let ln3_weight = vec![1.0f32; e];
209        let ln3_bias = vec![0.0f32; e];
210
211        Self {
212            self_qkv_weight,
213            self_qkv_bias,
214            self_out_weight,
215            self_out_bias,
216            cross_q_weight,
217            cross_q_bias,
218            cross_kv_weight,
219            cross_kv_bias,
220            cross_out_weight,
221            cross_out_bias,
222            ffn1_weight,
223            ffn1_bias,
224            ffn2_weight,
225            ffn2_bias,
226            ln1_weight,
227            ln1_bias,
228            ln2_weight,
229            ln2_bias,
230            ln3_weight,
231            ln3_bias,
232        }
233    }
234}
235
236// ─── DetrDecoderLayer ─────────────────────────────────────────────────────────
237
238/// A single DETR decoder layer.
239pub struct DetrDecoderLayer {
240    /// Decoder configuration (n_queries, embed_dim, n_heads, …).
241    pub config: DetrConfig,
242    /// Learned weights for this layer.
243    pub weights: DetrDecoderLayerWeights,
244}
245
246impl DetrDecoderLayer {
247    /// Construct a new decoder layer with Xavier-initialised weights.
248    pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> Self {
249        let weights = DetrDecoderLayerWeights::default_init(&cfg, rng);
250        Self {
251            config: cfg,
252            weights,
253        }
254    }
255
256    /// Forward pass for one decoder layer.
257    ///
258    /// Pre-norm residual scheme:
259    /// ```text
260    /// q1  = self_attn(LN1(queries)) + queries
261    /// q2  = cross_attn(LN2(q1), key=encoder, val=encoder) + q1
262    /// out = FFN(LN3(q2)) + q2
263    /// ```
264    ///
265    /// # Parameters
266    /// - `queries`:       flat `[n_queries × embed_dim]`.
267    /// - `encoder_feats`: flat `[n_enc_tokens × embed_dim]`.
268    /// - `n_enc_tokens`:  number of encoder feature tokens.
269    ///
270    /// # Returns
271    /// Updated queries: flat `[n_queries × embed_dim]`.
272    ///
273    /// # Errors
274    /// - `DimensionMismatch` if input tensor lengths are inconsistent.
275    /// - `NonFinite` if NaN/Inf appear in attention output.
276    pub fn forward(
277        &self,
278        queries: &[f32],
279        encoder_feats: &[f32],
280        n_enc_tokens: usize,
281    ) -> VisionResult<Vec<f32>> {
282        let e = self.config.embed_dim;
283        let nq = self.config.n_queries;
284        let nh = self.config.n_heads;
285        let w = &self.weights;
286
287        // Validate input sizes.
288        let expected_q = nq * e;
289        if queries.len() != expected_q {
290            return Err(VisionError::DimensionMismatch {
291                expected: expected_q,
292                got: queries.len(),
293            });
294        }
295        let expected_enc = n_enc_tokens * e;
296        if encoder_feats.len() != expected_enc {
297            return Err(VisionError::DimensionMismatch {
298                expected: expected_enc,
299                got: encoder_feats.len(),
300            });
301        }
302        if n_enc_tokens == 0 {
303            return Err(VisionError::EmptyInput("encoder features"));
304        }
305
306        // ── Step 1: Self-attention ────────────────────────────────────────────
307        // Pre-norm: LN1(queries)
308        let queries_normed = layer_norm(queries, &w.ln1_weight, &w.ln1_bias, nq, e, 1e-5);
309        // Self-attn: Q=K=V=queries_normed
310        let sa_out = mhsa_self(
311            &queries_normed,
312            nq,
313            e,
314            nh,
315            &w.self_qkv_weight,
316            &w.self_qkv_bias,
317            &w.self_out_weight,
318            &w.self_out_bias,
319        )?;
320        // Residual 1: queries + self_attn_out
321        let q1: Vec<f32> = queries
322            .iter()
323            .zip(sa_out.iter())
324            .map(|(a, b)| a + b)
325            .collect();
326
327        // ── Step 2: Cross-attention ───────────────────────────────────────────
328        // Pre-norm: LN2(q1)
329        let q1_normed = layer_norm(&q1, &w.ln2_weight, &w.ln2_bias, nq, e, 1e-5);
330        // Cross-attn: Q from normed queries, K/V from encoder
331        let ca_out = mhsa_cross(
332            &q1_normed,
333            nq,
334            encoder_feats,
335            n_enc_tokens,
336            e,
337            nh,
338            &w.cross_q_weight,
339            &w.cross_q_bias,
340            &w.cross_kv_weight,
341            &w.cross_kv_bias,
342            &w.cross_out_weight,
343            &w.cross_out_bias,
344        )?;
345        // Residual 2: q1 + cross_attn_out
346        let q2: Vec<f32> = q1.iter().zip(ca_out.iter()).map(|(a, b)| a + b).collect();
347
348        // ── Step 3: FFN ───────────────────────────────────────────────────────
349        // Pre-norm: LN3(q2)
350        let q2_normed = layer_norm(&q2, &w.ln3_weight, &w.ln3_bias, nq, e, 1e-5);
351        let mlp_dim = self.config.mlp_dim();
352        // Linear1 → GELU
353        let ffn_mid = linear(&q2_normed, &w.ffn1_weight, &w.ffn1_bias, e, mlp_dim);
354        let ffn_mid: Vec<f32> = ffn_mid.iter().map(|&v| gelu_approx(v)).collect();
355        // Linear2
356        let ffn_out = linear(&ffn_mid, &w.ffn2_weight, &w.ffn2_bias, mlp_dim, e);
357        // Residual 3: q2 + ffn_out
358        let out: Vec<f32> = q2.iter().zip(ffn_out.iter()).map(|(a, b)| a + b).collect();
359
360        Ok(out)
361    }
362}
363
364// ─── DetrDecoder ─────────────────────────────────────────────────────────────
365
366/// Multi-layer DETR decoder: stacks `config.depth` decoder layers.
367pub struct DetrDecoder {
368    /// Decoder layers in order of application.
369    pub layers: Vec<DetrDecoderLayer>,
370}
371
372impl DetrDecoder {
373    /// Build a new `DetrDecoder` with `cfg.depth` layers, all Xavier-initialised.
374    ///
375    /// # Errors
376    /// - `DimensionMismatch` if `cfg.depth == 0`.
377    /// - Propagates errors from `DetrConfig` validation (via cloning).
378    pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> VisionResult<Self> {
379        if cfg.depth == 0 {
380            return Err(VisionError::DimensionMismatch {
381                expected: 1,
382                got: 0,
383            });
384        }
385        let depth = cfg.depth;
386        let mut layers = Vec::with_capacity(depth);
387        for _ in 0..depth {
388            layers.push(DetrDecoderLayer::new(cfg.clone(), rng));
389        }
390        Ok(Self { layers })
391    }
392
393    /// Apply all decoder layers in sequence.
394    ///
395    /// # Parameters
396    /// - `queries`:       flat `[n_queries × embed_dim]`.
397    /// - `encoder_feats`: flat `[n_enc_tokens × embed_dim]`.
398    /// - `n_enc_tokens`:  number of encoder memory tokens.
399    ///
400    /// # Returns
401    /// Final queries: flat `[n_queries × embed_dim]`.
402    pub fn forward(
403        &self,
404        queries: &[f32],
405        encoder_feats: &[f32],
406        n_enc_tokens: usize,
407    ) -> VisionResult<Vec<f32>> {
408        let mut current = queries.to_vec();
409        for layer in &self.layers {
410            current = layer.forward(&current, encoder_feats, n_enc_tokens)?;
411        }
412        Ok(current)
413    }
414}
415
416// ─── Internal helpers ─────────────────────────────────────────────────────────
417
/// Layer normalisation applied independently to each row.
///
/// `x` is read as `n` rows of `d` values. Every row is shifted to zero mean, divided by
/// `sqrt(variance + eps)` (population variance), then scaled and offset per column:
/// ```text
/// out[i, j] = (x[i, j] - mean_i) / sqrt(var_i + eps) * weight[j] + bias[j]
/// ```
///
/// Returns a new `[n × d]` buffer. Panics if `x` holds fewer than `n * d` values, or if
/// `weight`/`bias` hold fewer than `d` values while `n > 0`.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
    let width = d as f32;
    let mut normed = Vec::with_capacity(n * d);
    for r in 0..n {
        let row = &x[r * d..(r + 1) * d];
        let mean = row.iter().sum::<f32>() / width;
        let var = row
            .iter()
            .map(|&v| {
                let centred = v - mean;
                centred * centred
            })
            .sum::<f32>()
            / width;
        let inv_std = 1.0 / (var + eps).sqrt();

        let (gamma, beta) = (&weight[..d], &bias[..d]);
        normed.extend(
            row.iter()
                .zip(gamma)
                .zip(beta)
                .map(|((&v, &g), &b)| (v - mean) * inv_std * g + b),
        );
    }
    normed
}
438
/// Affine map applied to every row of `x`: `y = x Wᵀ + b`.
///
/// - `x`: `[batch × n_in]`, where `batch` is derived as `x.len() / n_in`.
/// - `w`: `[n_out × n_in]`, one row of input weights per output feature.
/// - `b`: `[n_out]`.
///
/// Returns `[batch × n_out]`. Each output starts from its bias and accumulates the
/// products in input order. Panics if `n_in == 0` or if `w`/`b` are too short.
fn linear(x: &[f32], w: &[f32], b: &[f32], n_in: usize, n_out: usize) -> Vec<f32> {
    let batch = x.len() / n_in;
    let mut y = Vec::with_capacity(batch * n_out);
    for r in 0..batch {
        let input = &x[r * n_in..(r + 1) * n_in];
        for o in 0..n_out {
            let weights = &w[o * n_in..(o + 1) * n_in];
            let value = input
                .iter()
                .zip(weights)
                .fold(b[o], |acc, (&xi, &wi)| acc + xi * wi);
            y.push(value);
        }
    }
    y
}
463
/// Tanh-based approximation of the GELU activation.
///
/// ```text
/// GELU(x) ≈ x * 0.5 * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
/// ```
#[inline]
fn gelu_approx(x: f32) -> f32 {
    /// √(2/π), rounded to `f32` precision.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    /// Coefficient of the cubic term in the approximation.
    const COEFF: f32 = 0.044_715;

    let cubic = COEFF * x * x * x;
    let gate = 1.0 + (SQRT_2_OVER_PI * (x + cubic)).tanh();
    let half_x = x * 0.5;
    half_x * gate
}
476
/// In-place softmax over each of the `n_rows` rows of `n_cols` values in `logits`.
///
/// The row maximum is subtracted before exponentiating so large logits cannot overflow.
/// A row whose exponentials do not sum to a positive number is left un-normalised.
fn softmax_rows(logits: &mut [f32], n_rows: usize, n_cols: usize) {
    for r in 0..n_rows {
        let row = &mut logits[r * n_cols..(r + 1) * n_cols];
        let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        // Exponentiate in place while accumulating the normaliser.
        let total = row.iter_mut().fold(0.0f32, |acc, v| {
            *v = (*v - peak).exp();
            acc + *v
        });

        let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
        row.iter_mut().for_each(|v| *v *= norm);
    }
}
493
494/// Multi-head **self**-attention: Q, K, V all from the same token sequence.
495///
496/// Uses a fused `[3 * embed_dim × embed_dim]` QKV projection matrix.
497#[allow(clippy::too_many_arguments)]
498fn mhsa_self(
499    tokens: &[f32],
500    n_tokens: usize,
501    embed_dim: usize,
502    n_heads: usize,
503    qkv_weight: &[f32],
504    qkv_bias: &[f32],
505    out_weight: &[f32],
506    out_bias: &[f32],
507) -> VisionResult<Vec<f32>> {
508    let head_dim = embed_dim / n_heads;
509    // Fused QKV projection: [n_tokens × 3*embed_dim]
510    let qkv = linear(tokens, qkv_weight, qkv_bias, embed_dim, 3 * embed_dim);
511
512    // Split into Q, K, V each [n_tokens × embed_dim]
513    let mut q = vec![0.0f32; n_tokens * embed_dim];
514    let mut k = vec![0.0f32; n_tokens * embed_dim];
515    let mut v = vec![0.0f32; n_tokens * embed_dim];
516    for t in 0..n_tokens {
517        let src = &qkv[t * 3 * embed_dim..(t + 1) * 3 * embed_dim];
518        q[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[..embed_dim]);
519        k[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[embed_dim..2 * embed_dim]);
520        v[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[2 * embed_dim..]);
521    }
522
523    compute_attention(
524        &q, n_tokens, &k, n_tokens, &v, embed_dim, n_heads, head_dim, out_weight, out_bias,
525    )
526}
527
528/// Multi-head **cross**-attention: Q from queries, K/V from encoder memory.
529///
530/// `q_weight`: `[embed_dim × embed_dim]`
531/// `kv_weight`: `[2 * embed_dim × embed_dim]` (first half = K, second half = V)
532#[allow(clippy::too_many_arguments)]
533fn mhsa_cross(
534    queries: &[f32],
535    n_queries: usize,
536    encoder: &[f32],
537    n_enc: usize,
538    embed_dim: usize,
539    n_heads: usize,
540    q_weight: &[f32],
541    q_bias: &[f32],
542    kv_weight: &[f32],
543    kv_bias: &[f32],
544    out_weight: &[f32],
545    out_bias: &[f32],
546) -> VisionResult<Vec<f32>> {
547    let head_dim = embed_dim / n_heads;
548
549    // Q projection: [n_queries × embed_dim]
550    let q = linear(queries, q_weight, q_bias, embed_dim, embed_dim);
551
552    // KV fused projection: [n_enc × 2*embed_dim]
553    let kv = linear(encoder, kv_weight, kv_bias, embed_dim, 2 * embed_dim);
554
555    // Split KV into K and V each [n_enc × embed_dim]
556    let mut k = vec![0.0f32; n_enc * embed_dim];
557    let mut v = vec![0.0f32; n_enc * embed_dim];
558    for t in 0..n_enc {
559        let src = &kv[t * 2 * embed_dim..(t + 1) * 2 * embed_dim];
560        k[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[..embed_dim]);
561        v[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[embed_dim..]);
562    }
563
564    compute_attention(
565        &q, n_queries, &k, n_enc, &v, embed_dim, n_heads, head_dim, out_weight, out_bias,
566    )
567}
568
569/// Core scaled dot-product attention computation.
570///
571/// Given already-projected Q `[n_q × embed_dim]`, K `[n_k × embed_dim]`,
572/// V `[n_k × embed_dim]`, computes:
573/// ```text
574/// scores = Q @ K^T / sqrt(head_dim)  [n_q × n_k] per head
575/// attn   = softmax(scores) @ V
576/// out    = concat(attn_heads) @ out_weight + out_bias
577/// ```
578#[allow(clippy::too_many_arguments)]
579fn compute_attention(
580    q: &[f32],
581    n_q: usize,
582    k: &[f32],
583    n_k: usize,
584    v: &[f32],
585    embed_dim: usize,
586    n_heads: usize,
587    head_dim: usize,
588    out_weight: &[f32],
589    out_bias: &[f32],
590) -> VisionResult<Vec<f32>> {
591    let scale = 1.0_f32 / (head_dim as f32).sqrt();
592    let mut concat = vec![0.0f32; n_q * embed_dim];
593    let mut scores = vec![0.0f32; n_q * n_k];
594
595    for h in 0..n_heads {
596        let hd_off = h * head_dim;
597
598        // Compute scores[i, j] = scale * dot(Q[i, h*hd..], K[j, h*hd..])
599        for i in 0..n_q {
600            for j in 0..n_k {
601                let mut dot = 0.0f32;
602                for d in 0..head_dim {
603                    dot += q[i * embed_dim + hd_off + d] * k[j * embed_dim + hd_off + d];
604                }
605                scores[i * n_k + j] = dot * scale;
606            }
607        }
608
609        // Row-wise softmax over keys
610        softmax_rows(&mut scores, n_q, n_k);
611
612        // Weighted value sum: out[i, h*hd + d] = Σ_j scores[i,j] * V[j, h*hd + d]
613        for i in 0..n_q {
614            for d in 0..head_dim {
615                let mut acc = 0.0f32;
616                for j in 0..n_k {
617                    acc += scores[i * n_k + j] * v[j * embed_dim + hd_off + d];
618                }
619                concat[i * embed_dim + hd_off + d] = acc;
620            }
621        }
622    }
623
624    let out = linear(&concat, out_weight, out_bias, embed_dim, embed_dim);
625
626    if out.iter().any(|v| !v.is_finite()) {
627        return Err(VisionError::NonFinite("DETR decoder attention output"));
628    }
629
630    Ok(out)
631}
632
633// ─── Tests ───────────────────────────────────────────────────────────────────
634
/// Unit tests for the DETR decoder: config validation, single-layer and multi-layer
/// forward passes, and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-seed RNG so weight initialisation is the same on every run.
    fn make_rng() -> LcgRng {
        LcgRng::new(42)
    }

    // ── DetrConfig ─────────────────────────────────────────────────────────────

    #[test]
    fn detr_config_tiny() {
        let cfg = DetrConfig::tiny();
        assert_eq!(cfg.n_queries, 4);
        assert_eq!(cfg.embed_dim, 32);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.depth, 1);
        assert_eq!(cfg.mlp_ratio, 4);
        assert_eq!(cfg.mlp_dim(), 128);
        assert_eq!(cfg.head_dim(), 8);
    }

    #[test]
    fn detr_config_new_accepts_valid_values() {
        let cfg = DetrConfig::new(4, 32, 4, 1, 4).expect("valid config");
        assert_eq!(cfg.n_queries, 4);
        assert_eq!(cfg.embed_dim, 32);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.depth, 1);
        assert_eq!(cfg.mlp_ratio, 4);
    }

    #[test]
    fn detr_config_invalid_embed_dim_zero() {
        let r = DetrConfig::new(4, 0, 4, 1, 4);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn detr_config_invalid_heads_zero() {
        let r = DetrConfig::new(4, 32, 0, 1, 4);
        assert!(matches!(r, Err(VisionError::InvalidNumHeads(0))));
    }

    #[test]
    fn detr_config_head_dim_mismatch() {
        let r = DetrConfig::new(4, 32, 3, 1, 4); // 32 % 3 != 0
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn detr_config_zero_queries_errors() {
        let r = DetrConfig::new(0, 32, 4, 1, 4);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0
            })
        ));
    }

    #[test]
    fn detr_config_zero_depth_errors() {
        let r = DetrConfig::new(4, 32, 4, 0, 4);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0
            })
        ));
    }

    #[test]
    fn detr_config_zero_mlp_ratio_errors() {
        let r = DetrConfig::new(4, 32, 4, 1, 0);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0
            })
        ));
    }

    // ── Single layer forward ───────────────────────────────────────────────────

    #[test]
    fn single_layer_forward_shape() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.1f32; nq * e];
        let encoder = vec![0.2f32; 8 * e]; // 8 encoder tokens
        let out = layer.forward(&queries, &encoder, 8).expect("forward ok");

        assert_eq!(out.len(), nq * e, "output shape [n_queries × embed_dim]");
    }

    #[test]
    fn single_layer_forward_finite() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let mut queries = vec![0.0f32; nq * e];
        rng.fill_normal(&mut queries);
        let mut encoder = vec![0.0f32; 16 * e];
        rng.fill_normal(&mut encoder);

        let out = layer.forward(&queries, &encoder, 16).expect("forward ok");
        assert!(out.iter().all(|v| v.is_finite()), "non-finite in output");
    }

    #[test]
    fn single_layer_forward_wrong_query_size_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        // Provide wrong number of elements for queries
        let queries = vec![0.0f32; 3 * e]; // should be 4 * e
        let encoder = vec![0.0f32; 8 * e];
        let r = layer.forward(&queries, &encoder, 8);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch"
        );
    }

    #[test]
    fn single_layer_forward_wrong_encoder_size_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        // Encoder buffer holds 7 tokens but 8 are announced.
        let queries = vec![0.0f32; nq * e];
        let encoder = vec![0.0f32; 7 * e];
        let r = layer.forward(&queries, &encoder, 8);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch"
        );
    }

    #[test]
    fn single_layer_forward_empty_encoder_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.0f32; nq * e];
        let r = layer.forward(&queries, &[], 0);
        assert!(
            matches!(r, Err(VisionError::EmptyInput(_))),
            "expected EmptyInput for empty encoder"
        );
    }

    // ── Multi-layer decoder ────────────────────────────────────────────────────

    #[test]
    fn decoder_new_builds_depth_layers() {
        let mut rng = make_rng();
        let cfg = DetrConfig::new(4, 32, 4, 3, 4).expect("valid config");
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");
        assert_eq!(decoder.layers.len(), 3);
    }

    #[test]
    fn decoder_new_zero_depth_errors() {
        let mut rng = make_rng();
        // Struct literal bypasses `DetrConfig::new`, so the decoder must catch this.
        let cfg = DetrConfig {
            depth: 0,
            ..DetrConfig::tiny()
        };
        let r = DetrDecoder::new(cfg, &mut rng);
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0
            })
        ));
    }

    #[test]
    fn multi_layer_decoder_forward_shape() {
        let mut rng = make_rng();
        let cfg = DetrConfig::new(4, 32, 4, 3, 4).expect("valid config");
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let queries = vec![0.1f32; nq * e];
        let encoder = vec![0.2f32; 12 * e];
        let out = decoder
            .forward(&queries, &encoder, 12)
            .expect("multi-layer ok");

        assert_eq!(out.len(), nq * e, "multi-layer output shape preserved");
    }

    #[test]
    fn multi_layer_decoder_forward_finite() {
        let mut rng = make_rng();
        let cfg = DetrConfig::new(8, 32, 4, 2, 4).expect("valid config");
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let mut queries = vec![0.0f32; nq * e];
        rng.fill_normal(&mut queries);
        let mut encoder = vec![0.0f32; 6 * e];
        rng.fill_normal(&mut encoder);

        let out = decoder.forward(&queries, &encoder, 6).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in multi-layer output"
        );
    }

    // ── layer_norm ────────────────────────────────────────────────────────────

    #[test]
    fn layer_norm_constant_row_is_zero() {
        let x = vec![5.0f32; 32];
        let w = vec![1.0f32; 32];
        let b = vec![0.0f32; 32];
        let out = layer_norm(&x, &w, &b, 1, 32, 1e-5);
        for v in &out {
            assert!(v.abs() < 1e-5, "expected near-zero, got {v}");
        }
    }

    #[test]
    fn layer_norm_applies_scale_and_bias() {
        // Row [1, 2, 3, 4]: mean 2.5, population variance 1.25.
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let out = layer_norm(&x, &[2.0; 4], &[1.0; 4], 1, 4, 1e-5);
        let inv_std = 1.0f32 / (1.25f32 + 1e-5).sqrt();
        for (o, c) in out.iter().zip([-1.5f32, -0.5, 0.5, 1.5]) {
            assert!((o - (c * inv_std * 2.0 + 1.0)).abs() < 1e-5);
        }
    }

    // ── linear ────────────────────────────────────────────────────────────────

    #[test]
    fn linear_applies_weights_and_bias() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let w = [1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0];
        let b = [0.5f32, -0.5, 0.0];
        let y = linear(&x, &w, &b, 2, 3);
        assert_eq!(y, vec![1.5, 1.5, 3.0, 3.5, 3.5, 7.0]);
    }

    // ── softmax_rows ──────────────────────────────────────────────────────────

    #[test]
    fn softmax_rows_sum_to_one() {
        let mut m = vec![0.0f32, 0.0, 1000.0, 1001.0];
        softmax_rows(&mut m, 2, 2);
        assert!((m[0] - 0.5).abs() < 1e-6);
        assert!((m[1] - 0.5).abs() < 1e-6);
        assert!(m.iter().all(|v| v.is_finite()));
        assert!((m[2] + m[3] - 1.0).abs() < 1e-5);
        assert!(m[2] < m[3]);
    }

    // ── gelu_approx ───────────────────────────────────────────────────────────

    #[test]
    fn gelu_zero() {
        assert!((gelu_approx(0.0) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_large_pos() {
        assert!((gelu_approx(10.0) - 10.0).abs() < 1e-3);
    }

    #[test]
    fn gelu_large_neg() {
        assert!(gelu_approx(-10.0).abs() < 1e-3);
    }
}