1use crate::{
13 error::{VisionError, VisionResult},
14 handle::LcgRng,
15};
16
17#[derive(Debug, Clone)]
21pub struct DetrConfig {
22 pub n_queries: usize,
24 pub embed_dim: usize,
26 pub n_heads: usize,
28 pub depth: usize,
30 pub mlp_ratio: usize,
32}
33
34impl DetrConfig {
35 pub fn new(
43 n_queries: usize,
44 embed_dim: usize,
45 n_heads: usize,
46 depth: usize,
47 mlp_ratio: usize,
48 ) -> VisionResult<Self> {
49 if embed_dim == 0 {
50 return Err(VisionError::InvalidEmbedDim(embed_dim));
51 }
52 if n_heads == 0 {
53 return Err(VisionError::InvalidNumHeads(n_heads));
54 }
55 if embed_dim % n_heads != 0 {
56 return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
57 }
58 if n_queries == 0 {
59 return Err(VisionError::DimensionMismatch {
60 expected: 1,
61 got: 0,
62 });
63 }
64 if depth == 0 {
65 return Err(VisionError::DimensionMismatch {
66 expected: 1,
67 got: 0,
68 });
69 }
70 if mlp_ratio == 0 {
71 return Err(VisionError::DimensionMismatch {
72 expected: 1,
73 got: 0,
74 });
75 }
76 Ok(Self {
77 n_queries,
78 embed_dim,
79 n_heads,
80 depth,
81 mlp_ratio,
82 })
83 }
84
85 pub fn tiny() -> Self {
89 Self {
90 n_queries: 4,
91 embed_dim: 32,
92 n_heads: 4,
93 depth: 1,
94 mlp_ratio: 4,
95 }
96 }
97
98 #[inline]
100 pub fn mlp_dim(&self) -> usize {
101 self.mlp_ratio * self.embed_dim
102 }
103
104 #[inline]
106 pub fn head_dim(&self) -> usize {
107 self.embed_dim / self.n_heads
108 }
109}
110
111pub struct DetrDecoderLayerWeights {
115 pub self_qkv_weight: Vec<f32>,
118 pub self_qkv_bias: Vec<f32>,
120 pub self_out_weight: Vec<f32>,
122 pub self_out_bias: Vec<f32>,
124
125 pub cross_q_weight: Vec<f32>,
128 pub cross_q_bias: Vec<f32>,
130 pub cross_kv_weight: Vec<f32>,
132 pub cross_kv_bias: Vec<f32>,
134 pub cross_out_weight: Vec<f32>,
136 pub cross_out_bias: Vec<f32>,
138
139 pub ffn1_weight: Vec<f32>,
142 pub ffn1_bias: Vec<f32>,
144 pub ffn2_weight: Vec<f32>,
146 pub ffn2_bias: Vec<f32>,
148
149 pub ln1_weight: Vec<f32>,
152 pub ln1_bias: Vec<f32>,
154 pub ln2_weight: Vec<f32>,
156 pub ln2_bias: Vec<f32>,
158 pub ln3_weight: Vec<f32>,
160 pub ln3_bias: Vec<f32>,
162}
163
164impl DetrDecoderLayerWeights {
165 pub fn default_init(cfg: &DetrConfig, rng: &mut LcgRng) -> Self {
170 let e = cfg.embed_dim;
171 let mlp = cfg.mlp_dim();
172 let scale = 1.0_f32 / (e as f32).sqrt();
173
174 let fill_scaled = |rng: &mut LcgRng, n: usize| -> Vec<f32> {
175 let mut v = vec![0.0f32; n];
176 rng.fill_normal(&mut v);
177 for x in &mut v {
178 *x *= scale;
179 }
180 v
181 };
182
183 let self_qkv_weight = fill_scaled(rng, 3 * e * e);
185 let self_qkv_bias = vec![0.0f32; 3 * e];
186 let self_out_weight = fill_scaled(rng, e * e);
187 let self_out_bias = vec![0.0f32; e];
188
189 let cross_q_weight = fill_scaled(rng, e * e);
191 let cross_q_bias = vec![0.0f32; e];
192 let cross_kv_weight = fill_scaled(rng, 2 * e * e);
193 let cross_kv_bias = vec![0.0f32; 2 * e];
194 let cross_out_weight = fill_scaled(rng, e * e);
195 let cross_out_bias = vec![0.0f32; e];
196
197 let ffn1_weight = fill_scaled(rng, mlp * e);
199 let ffn1_bias = vec![0.0f32; mlp];
200 let ffn2_weight = fill_scaled(rng, e * mlp);
201 let ffn2_bias = vec![0.0f32; e];
202
203 let ln1_weight = vec![1.0f32; e];
205 let ln1_bias = vec![0.0f32; e];
206 let ln2_weight = vec![1.0f32; e];
207 let ln2_bias = vec![0.0f32; e];
208 let ln3_weight = vec![1.0f32; e];
209 let ln3_bias = vec![0.0f32; e];
210
211 Self {
212 self_qkv_weight,
213 self_qkv_bias,
214 self_out_weight,
215 self_out_bias,
216 cross_q_weight,
217 cross_q_bias,
218 cross_kv_weight,
219 cross_kv_bias,
220 cross_out_weight,
221 cross_out_bias,
222 ffn1_weight,
223 ffn1_bias,
224 ffn2_weight,
225 ffn2_bias,
226 ln1_weight,
227 ln1_bias,
228 ln2_weight,
229 ln2_bias,
230 ln3_weight,
231 ln3_bias,
232 }
233 }
234}
235
236pub struct DetrDecoderLayer {
240 pub config: DetrConfig,
242 pub weights: DetrDecoderLayerWeights,
244}
245
246impl DetrDecoderLayer {
247 pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> Self {
249 let weights = DetrDecoderLayerWeights::default_init(&cfg, rng);
250 Self {
251 config: cfg,
252 weights,
253 }
254 }
255
256 pub fn forward(
277 &self,
278 queries: &[f32],
279 encoder_feats: &[f32],
280 n_enc_tokens: usize,
281 ) -> VisionResult<Vec<f32>> {
282 let e = self.config.embed_dim;
283 let nq = self.config.n_queries;
284 let nh = self.config.n_heads;
285 let w = &self.weights;
286
287 let expected_q = nq * e;
289 if queries.len() != expected_q {
290 return Err(VisionError::DimensionMismatch {
291 expected: expected_q,
292 got: queries.len(),
293 });
294 }
295 let expected_enc = n_enc_tokens * e;
296 if encoder_feats.len() != expected_enc {
297 return Err(VisionError::DimensionMismatch {
298 expected: expected_enc,
299 got: encoder_feats.len(),
300 });
301 }
302 if n_enc_tokens == 0 {
303 return Err(VisionError::EmptyInput("encoder features"));
304 }
305
306 let queries_normed = layer_norm(queries, &w.ln1_weight, &w.ln1_bias, nq, e, 1e-5);
309 let sa_out = mhsa_self(
311 &queries_normed,
312 nq,
313 e,
314 nh,
315 &w.self_qkv_weight,
316 &w.self_qkv_bias,
317 &w.self_out_weight,
318 &w.self_out_bias,
319 )?;
320 let q1: Vec<f32> = queries
322 .iter()
323 .zip(sa_out.iter())
324 .map(|(a, b)| a + b)
325 .collect();
326
327 let q1_normed = layer_norm(&q1, &w.ln2_weight, &w.ln2_bias, nq, e, 1e-5);
330 let ca_out = mhsa_cross(
332 &q1_normed,
333 nq,
334 encoder_feats,
335 n_enc_tokens,
336 e,
337 nh,
338 &w.cross_q_weight,
339 &w.cross_q_bias,
340 &w.cross_kv_weight,
341 &w.cross_kv_bias,
342 &w.cross_out_weight,
343 &w.cross_out_bias,
344 )?;
345 let q2: Vec<f32> = q1.iter().zip(ca_out.iter()).map(|(a, b)| a + b).collect();
347
348 let q2_normed = layer_norm(&q2, &w.ln3_weight, &w.ln3_bias, nq, e, 1e-5);
351 let mlp_dim = self.config.mlp_dim();
352 let ffn_mid = linear(&q2_normed, &w.ffn1_weight, &w.ffn1_bias, e, mlp_dim);
354 let ffn_mid: Vec<f32> = ffn_mid.iter().map(|&v| gelu_approx(v)).collect();
355 let ffn_out = linear(&ffn_mid, &w.ffn2_weight, &w.ffn2_bias, mlp_dim, e);
357 let out: Vec<f32> = q2.iter().zip(ffn_out.iter()).map(|(a, b)| a + b).collect();
359
360 Ok(out)
361 }
362}
363
364pub struct DetrDecoder {
368 pub layers: Vec<DetrDecoderLayer>,
370}
371
372impl DetrDecoder {
373 pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> VisionResult<Self> {
379 if cfg.depth == 0 {
380 return Err(VisionError::DimensionMismatch {
381 expected: 1,
382 got: 0,
383 });
384 }
385 let depth = cfg.depth;
386 let mut layers = Vec::with_capacity(depth);
387 for _ in 0..depth {
388 layers.push(DetrDecoderLayer::new(cfg.clone(), rng));
389 }
390 Ok(Self { layers })
391 }
392
393 pub fn forward(
403 &self,
404 queries: &[f32],
405 encoder_feats: &[f32],
406 n_enc_tokens: usize,
407 ) -> VisionResult<Vec<f32>> {
408 let mut current = queries.to_vec();
409 for layer in &self.layers {
410 current = layer.forward(¤t, encoder_feats, n_enc_tokens)?;
411 }
412 Ok(current)
413 }
414}
415
/// Layer normalisation applied independently to each of the first `n` rows of `x`.
///
/// Each row has width `d`. It is shifted to zero mean and scaled by
/// `1 / sqrt(var + eps)`, using the population variance. Channel `j` is then multiplied by
/// `weight[j]` and offset by `bias[j]`. The result holds `n * d` values.
///
/// Panics if `x` has fewer than `n * d` values or if `weight` / `bias` are shorter than `d`
/// (when `n > 0`).
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
    let width = d as f32;
    let mut normed = Vec::with_capacity(n * d);
    for r in 0..n {
        let row = &x[r * d..(r + 1) * d];
        let mu = row.iter().sum::<f32>() / width;
        let sigma2 = row.iter().map(|&v| (v - mu) * (v - mu)).sum::<f32>() / width;
        let rstd = 1.0 / (sigma2 + eps).sqrt();
        normed.extend(
            row.iter()
                .zip(&weight[..d])
                .zip(&bias[..d])
                .map(|((&v, &g), &b)| (v - mu) * rstd * g + b),
        );
    }
    normed
}
438
/// Dense affine layer applied to each full `n_in`-wide row of `x`.
///
/// `w` is row-major `[n_out, n_in]` and `b` has `n_out` entries. Output row `r`,
/// unit `o`, is `b[o] + sum_k x[r, k] * w[o, k]`. Any trailing partial row of `x` is
/// ignored. The result holds `(x.len() / n_in) * n_out` values.
///
/// Panics if `n_in` is zero (division by zero) or if `w` / `b` are too short.
fn linear(x: &[f32], w: &[f32], b: &[f32], n_in: usize, n_out: usize) -> Vec<f32> {
    let n_rows = x.len() / n_in;
    let mut y = Vec::with_capacity(n_rows * n_out);
    for x_row in x.chunks_exact(n_in) {
        y.extend((0..n_out).map(|o| {
            let w_row = &w[o * n_in..(o + 1) * n_in];
            x_row
                .iter()
                .zip(w_row)
                .fold(b[o], |acc, (&xi, &wi)| acc + xi * wi)
        }));
    }
    y
}
463
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
fn gelu_approx(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEFF: f32 = 0.044_715;
    let cubic_term = COEFF * x * x * x;
    let gate = 1.0 + (SQRT_2_OVER_PI * (x + cubic_term)).tanh();
    0.5 * x * gate
}
476
/// In-place softmax over each of the first `n_rows` rows of width `n_cols` in `logits`.
///
/// The row maximum is subtracted before exponentiating so large logits do not overflow.
/// If a row's exponent sum is not positive, the row is left unnormalised (scale 1.0).
/// Zero-width rows are a no-op.
fn softmax_rows(logits: &mut [f32], n_rows: usize, n_cols: usize) {
    for r in 0..n_rows {
        let start = r * n_cols;
        let row = &mut logits[start..start + n_cols];
        let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let total = row.iter_mut().fold(0.0f32, |acc, v| {
            *v = (*v - peak).exp();
            acc + *v
        });
        let scale = if total > 0.0 { 1.0 / total } else { 1.0 };
        row.iter_mut().for_each(|v| *v *= scale);
    }
}
493
494#[allow(clippy::too_many_arguments)]
498fn mhsa_self(
499 tokens: &[f32],
500 n_tokens: usize,
501 embed_dim: usize,
502 n_heads: usize,
503 qkv_weight: &[f32],
504 qkv_bias: &[f32],
505 out_weight: &[f32],
506 out_bias: &[f32],
507) -> VisionResult<Vec<f32>> {
508 let head_dim = embed_dim / n_heads;
509 let qkv = linear(tokens, qkv_weight, qkv_bias, embed_dim, 3 * embed_dim);
511
512 let mut q = vec![0.0f32; n_tokens * embed_dim];
514 let mut k = vec![0.0f32; n_tokens * embed_dim];
515 let mut v = vec![0.0f32; n_tokens * embed_dim];
516 for t in 0..n_tokens {
517 let src = &qkv[t * 3 * embed_dim..(t + 1) * 3 * embed_dim];
518 q[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[..embed_dim]);
519 k[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[embed_dim..2 * embed_dim]);
520 v[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[2 * embed_dim..]);
521 }
522
523 compute_attention(
524 &q, n_tokens, &k, n_tokens, &v, embed_dim, n_heads, head_dim, out_weight, out_bias,
525 )
526}
527
528#[allow(clippy::too_many_arguments)]
533fn mhsa_cross(
534 queries: &[f32],
535 n_queries: usize,
536 encoder: &[f32],
537 n_enc: usize,
538 embed_dim: usize,
539 n_heads: usize,
540 q_weight: &[f32],
541 q_bias: &[f32],
542 kv_weight: &[f32],
543 kv_bias: &[f32],
544 out_weight: &[f32],
545 out_bias: &[f32],
546) -> VisionResult<Vec<f32>> {
547 let head_dim = embed_dim / n_heads;
548
549 let q = linear(queries, q_weight, q_bias, embed_dim, embed_dim);
551
552 let kv = linear(encoder, kv_weight, kv_bias, embed_dim, 2 * embed_dim);
554
555 let mut k = vec![0.0f32; n_enc * embed_dim];
557 let mut v = vec![0.0f32; n_enc * embed_dim];
558 for t in 0..n_enc {
559 let src = &kv[t * 2 * embed_dim..(t + 1) * 2 * embed_dim];
560 k[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[..embed_dim]);
561 v[t * embed_dim..(t + 1) * embed_dim].copy_from_slice(&src[embed_dim..]);
562 }
563
564 compute_attention(
565 &q, n_queries, &k, n_enc, &v, embed_dim, n_heads, head_dim, out_weight, out_bias,
566 )
567}
568
569#[allow(clippy::too_many_arguments)]
579fn compute_attention(
580 q: &[f32],
581 n_q: usize,
582 k: &[f32],
583 n_k: usize,
584 v: &[f32],
585 embed_dim: usize,
586 n_heads: usize,
587 head_dim: usize,
588 out_weight: &[f32],
589 out_bias: &[f32],
590) -> VisionResult<Vec<f32>> {
591 let scale = 1.0_f32 / (head_dim as f32).sqrt();
592 let mut concat = vec![0.0f32; n_q * embed_dim];
593 let mut scores = vec![0.0f32; n_q * n_k];
594
595 for h in 0..n_heads {
596 let hd_off = h * head_dim;
597
598 for i in 0..n_q {
600 for j in 0..n_k {
601 let mut dot = 0.0f32;
602 for d in 0..head_dim {
603 dot += q[i * embed_dim + hd_off + d] * k[j * embed_dim + hd_off + d];
604 }
605 scores[i * n_k + j] = dot * scale;
606 }
607 }
608
609 softmax_rows(&mut scores, n_q, n_k);
611
612 for i in 0..n_q {
614 for d in 0..head_dim {
615 let mut acc = 0.0f32;
616 for j in 0..n_k {
617 acc += scores[i * n_k + j] * v[j * embed_dim + hd_off + d];
618 }
619 concat[i * embed_dim + hd_off + d] = acc;
620 }
621 }
622 }
623
624 let out = linear(&concat, out_weight, out_bias, embed_dim, embed_dim);
625
626 if out.iter().any(|v| !v.is_finite()) {
627 return Err(VisionError::NonFinite("DETR decoder attention output"));
628 }
629
630 Ok(out)
631}
632
#[cfg(test)]
mod tests {
    use super::*;

    fn make_rng() -> LcgRng {
        LcgRng::new(42)
    }

    // ---- DetrConfig -------------------------------------------------------------------

    #[test]
    fn detr_config_tiny() {
        let cfg = DetrConfig::tiny();
        assert_eq!(cfg.n_queries, 4);
        assert_eq!(cfg.embed_dim, 32);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.depth, 1);
        assert_eq!(cfg.mlp_ratio, 4);
        assert_eq!(cfg.mlp_dim(), 128);
        assert_eq!(cfg.head_dim(), 8);
    }

    #[test]
    fn detr_config_invalid_embed_dim_zero() {
        let r = DetrConfig::new(4, 0, 4, 1, 4);
        assert!(matches!(r, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn detr_config_invalid_heads_zero() {
        let r = DetrConfig::new(4, 32, 0, 1, 4);
        assert!(matches!(r, Err(VisionError::InvalidNumHeads(0))));
    }

    #[test]
    fn detr_config_head_dim_mismatch() {
        let r = DetrConfig::new(4, 32, 3, 1, 4);
        assert!(matches!(r, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn detr_config_zero_queries_errors() {
        let r = DetrConfig::new(0, 32, 4, 1, 4);
        assert!(r.is_err());
    }

    #[test]
    fn detr_config_zero_depth_or_mlp_ratio_errors() {
        assert!(matches!(
            DetrConfig::new(4, 32, 4, 0, 4),
            Err(VisionError::DimensionMismatch { expected: 1, got: 0 })
        ));
        assert!(matches!(
            DetrConfig::new(4, 32, 4, 1, 0),
            Err(VisionError::DimensionMismatch { expected: 1, got: 0 })
        ));
    }

    // ---- Single decoder layer ---------------------------------------------------------

    #[test]
    fn single_layer_forward_shape() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.1f32; nq * e];
        let encoder = vec![0.2f32; 8 * e];
        let out = layer.forward(&queries, &encoder, 8).expect("forward ok");

        assert_eq!(out.len(), nq * e, "output shape [n_queries × embed_dim]");
    }

    #[test]
    fn single_layer_forward_finite() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let mut queries = vec![0.0f32; nq * e];
        rng.fill_normal(&mut queries);
        let mut encoder = vec![0.0f32; 16 * e];
        rng.fill_normal(&mut encoder);

        let out = layer.forward(&queries, &encoder, 16).expect("forward ok");
        assert!(out.iter().all(|v| v.is_finite()), "non-finite in output");
    }

    #[test]
    fn single_layer_forward_wrong_query_size_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.0f32; 3 * e];
        let encoder = vec![0.0f32; 8 * e];
        let r = layer.forward(&queries, &encoder, 8);
        assert!(
            matches!(r, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch"
        );
    }

    #[test]
    fn single_layer_forward_wrong_encoder_size_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.0f32; nq * e];
        let encoder = vec![0.0f32; 7 * e];
        let r = layer.forward(&queries, &encoder, 8);
        assert!(
            matches!(
                r,
                Err(VisionError::DimensionMismatch { expected, got })
                    if expected == 8 * e && got == 7 * e
            ),
            "expected DimensionMismatch for encoder length"
        );
    }

    #[test]
    fn single_layer_forward_empty_encoder_errors() {
        let mut rng = make_rng();
        let cfg = DetrConfig::tiny();
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.0f32; nq * e];
        let r = layer.forward(&queries, &[], 0);
        assert!(r.is_err(), "expected error for empty encoder");
    }

    // ---- Multi-layer decoder ----------------------------------------------------------

    #[test]
    fn multi_layer_decoder_forward_shape() {
        let mut rng = make_rng();
        let cfg = DetrConfig::new(4, 32, 4, 3, 4).expect("valid config");
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let queries = vec![0.1f32; nq * e];
        let encoder = vec![0.2f32; 12 * e];
        let out = decoder
            .forward(&queries, &encoder, 12)
            .expect("multi-layer ok");

        assert_eq!(out.len(), nq * e, "multi-layer output shape preserved");
    }

    #[test]
    fn multi_layer_decoder_forward_finite() {
        let mut rng = make_rng();
        let cfg = DetrConfig::new(8, 32, 4, 2, 4).expect("valid config");
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let mut queries = vec![0.0f32; nq * e];
        rng.fill_normal(&mut queries);
        let mut encoder = vec![0.0f32; 6 * e];
        rng.fill_normal(&mut encoder);

        let out = decoder.forward(&queries, &encoder, 6).expect("forward ok");
        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in multi-layer output"
        );
    }

    #[test]
    fn decoder_rejects_zero_depth() {
        // Bypasses DetrConfig::new, so DetrDecoder::new must catch it itself.
        let cfg = DetrConfig {
            depth: 0,
            ..DetrConfig::tiny()
        };
        let r = DetrDecoder::new(cfg, &mut make_rng());
        assert!(matches!(
            r,
            Err(VisionError::DimensionMismatch { expected: 1, got: 0 })
        ));
    }

    #[test]
    fn decoder_forward_is_deterministic_for_fixed_seed() {
        let cfg = DetrConfig::new(4, 32, 4, 2, 4).expect("valid config");
        let nq = cfg.n_queries;
        let e = cfg.embed_dim;
        let a = DetrDecoder::new(cfg.clone(), &mut make_rng()).expect("valid decoder");
        let b = DetrDecoder::new(cfg, &mut make_rng()).expect("valid decoder");

        let mut rng = LcgRng::new(7);
        let mut queries = vec![0.0f32; nq * e];
        rng.fill_normal(&mut queries);
        let mut encoder = vec![0.0f32; 5 * e];
        rng.fill_normal(&mut encoder);

        let out_a = a.forward(&queries, &encoder, 5).expect("forward ok");
        let out_b = b.forward(&queries, &encoder, 5).expect("forward ok");
        assert_eq!(out_a, out_b, "same seed must give identical outputs");
    }

    // ---- Numeric helpers --------------------------------------------------------------

    #[test]
    fn layer_norm_constant_row_is_zero() {
        let x = vec![5.0f32; 32];
        let w = vec![1.0f32; 32];
        let b = vec![0.0f32; 32];
        let out = layer_norm(&x, &w, &b, 1, 32, 1e-5);
        for v in &out {
            assert!(v.abs() < 1e-5, "expected near-zero, got {v}");
        }
    }

    #[test]
    fn layer_norm_standardises_row() {
        let x = [1.0f32, 2.0, 3.0, 4.0];
        let out = layer_norm(&x, &[1.0f32; 4], &[0.0f32; 4], 1, 4, 1e-5);
        let mean: f32 = out.iter().sum::<f32>() / 4.0;
        let var: f32 = out.iter().map(|v| v * v).sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-5, "mean {mean}");
        assert!((var - 1.0).abs() < 1e-3, "variance {var}");
    }

    #[test]
    fn linear_computes_affine_rows() {
        let w = [1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0];
        let b = [0.5f32, 0.0, -1.0];
        assert_eq!(
            linear(&[1.0, 2.0, 3.0, 4.0], &w, &b, 2, 3),
            vec![1.5f32, 2.0, 2.0, 3.5, 4.0, 6.0]
        );
    }

    #[test]
    fn softmax_rows_sum_to_one() {
        let mut logits = vec![1.0f32, 2.0, 3.0, -1.0, 0.0, 5.0];
        softmax_rows(&mut logits, 2, 3);
        for row in logits.chunks_exact(3) {
            let s: f32 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "row sum {s}");
            assert!(row.iter().all(|&p| (0.0..=1.0).contains(&p)));
        }
    }

    #[test]
    fn gelu_zero() {
        assert!((gelu_approx(0.0) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_one() {
        assert!((gelu_approx(1.0) - 0.841_192).abs() < 1e-4);
    }

    #[test]
    fn gelu_large_pos() {
        assert!((gelu_approx(10.0) - 10.0).abs() < 1e-3);
    }

    #[test]
    fn gelu_large_neg() {
        assert!(gelu_approx(-10.0).abs() < 1e-3);
    }
}