use crate::{
error::{VisionError, VisionResult},
handle::LcgRng,
};
/// Hyper-parameters shared by every layer of the DETR-style decoder in this file.
///
/// The fields are public, so a value built with a struct literal bypasses the
/// checks performed by `DetrConfig::new`.
#[derive(Debug, Clone)]
pub struct DetrConfig {
    /// Number of query tokens; `DetrDecoderLayer::forward` requires
    /// `queries.len() == n_queries * embed_dim`.
    pub n_queries: usize,
    /// Width of each token vector (queries and encoder features alike).
    pub embed_dim: usize,
    /// Number of attention heads; `DetrConfig::new` requires it to divide `embed_dim`.
    pub n_heads: usize,
    /// Number of decoder layers created by `DetrDecoder::new`.
    pub depth: usize,
    /// Multiplier giving the feed-forward hidden width (`mlp_ratio * embed_dim`).
    pub mlp_ratio: usize,
}
impl DetrConfig {
    /// Builds a configuration after validating every hyper-parameter.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// - `VisionError::InvalidEmbedDim` when `embed_dim` is zero,
    /// - `VisionError::InvalidNumHeads` when `n_heads` is zero,
    /// - `VisionError::HeadDimMismatch` when `n_heads` does not divide `embed_dim`,
    /// - `VisionError::DimensionMismatch { expected: 1, got: 0 }` when any of
    ///   `n_queries`, `depth` or `mlp_ratio` is zero.
    pub fn new(
        n_queries: usize,
        embed_dim: usize,
        n_heads: usize,
        depth: usize,
        mlp_ratio: usize,
    ) -> VisionResult<Self> {
        match (embed_dim, n_heads) {
            (0, _) => return Err(VisionError::InvalidEmbedDim(embed_dim)),
            (_, 0) => return Err(VisionError::InvalidNumHeads(n_heads)),
            (width, heads) if width % heads != 0 => {
                return Err(VisionError::HeadDimMismatch { n_heads, embed_dim });
            }
            _ => {}
        }

        // The three remaining counts share one error value, so a single
        // membership test is equivalent to checking them one after another.
        if [n_queries, depth, mlp_ratio].contains(&0) {
            return Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0,
            });
        }

        Ok(Self {
            n_queries,
            embed_dim,
            n_heads,
            depth,
            mlp_ratio,
        })
    }

    /// Small fixed configuration: 4 queries, width 32, 4 heads, 1 layer, MLP ratio 4.
    pub fn tiny() -> Self {
        Self {
            embed_dim: 32,
            n_heads: 4,
            n_queries: 4,
            mlp_ratio: 4,
            depth: 1,
        }
    }

    /// Hidden width of the feed-forward block: `mlp_ratio * embed_dim`.
    #[inline]
    pub fn mlp_dim(&self) -> usize {
        let Self {
            mlp_ratio,
            embed_dim,
            ..
        } = *self;
        mlp_ratio * embed_dim
    }

    /// Per-head channel count: `embed_dim / n_heads` (integer division).
    #[inline]
    pub fn head_dim(&self) -> usize {
        let Self {
            embed_dim, n_heads, ..
        } = *self;
        embed_dim / n_heads
    }
}
/// Parameters of one decoder layer, stored as flat `f32` buffers.
///
/// Weight matrices are consumed by `linear` in this file, which reads them
/// row-major as `[n_out][n_in]`. The lengths below are the ones allocated by
/// `default_init`, with `e = embed_dim` and `mlp = mlp_ratio * embed_dim`.
pub struct DetrDecoderLayerWeights {
    /// Fused self-attention Q/K/V projection, `3 * e * e` values; the output
    /// row is split as Q, then K, then V by `mhsa_self`.
    pub self_qkv_weight: Vec<f32>,
    /// Bias of the fused Q/K/V projection, `3 * e` values.
    pub self_qkv_bias: Vec<f32>,
    /// Self-attention output projection, `e * e` values.
    pub self_out_weight: Vec<f32>,
    /// Bias of the self-attention output projection, `e` values.
    pub self_out_bias: Vec<f32>,
    /// Cross-attention query projection, `e * e` values.
    pub cross_q_weight: Vec<f32>,
    /// Bias of the cross-attention query projection, `e` values.
    pub cross_q_bias: Vec<f32>,
    /// Fused cross-attention K/V projection applied to encoder features,
    /// `2 * e * e` values; the output row is split as K, then V by `mhsa_cross`.
    pub cross_kv_weight: Vec<f32>,
    /// Bias of the fused K/V projection, `2 * e` values.
    pub cross_kv_bias: Vec<f32>,
    /// Cross-attention output projection, `e * e` values.
    pub cross_out_weight: Vec<f32>,
    /// Bias of the cross-attention output projection, `e` values.
    pub cross_out_bias: Vec<f32>,
    /// First feed-forward projection (`e` -> `mlp`), `mlp * e` values.
    pub ffn1_weight: Vec<f32>,
    /// Bias of the first feed-forward projection, `mlp` values.
    pub ffn1_bias: Vec<f32>,
    /// Second feed-forward projection (`mlp` -> `e`), `e * mlp` values.
    pub ffn2_weight: Vec<f32>,
    /// Bias of the second feed-forward projection, `e` values.
    pub ffn2_bias: Vec<f32>,
    /// Scale of the layer norm applied before self-attention, `e` values.
    pub ln1_weight: Vec<f32>,
    /// Shift of the layer norm applied before self-attention, `e` values.
    pub ln1_bias: Vec<f32>,
    /// Scale of the layer norm applied before cross-attention, `e` values.
    pub ln2_weight: Vec<f32>,
    /// Shift of the layer norm applied before cross-attention, `e` values.
    pub ln2_bias: Vec<f32>,
    /// Scale of the layer norm applied before the feed-forward block, `e` values.
    pub ln3_weight: Vec<f32>,
    /// Shift of the layer norm applied before the feed-forward block, `e` values.
    pub ln3_bias: Vec<f32>,
}
impl DetrDecoderLayerWeights {
    /// Creates freshly initialised weights for one decoder layer.
    ///
    /// Every projection matrix is filled through `rng.fill_normal` and then
    /// multiplied by `1 / sqrt(embed_dim)`; projection biases and layer-norm
    /// shifts are zero and layer-norm scales are one.
    ///
    /// The generator is consumed in a fixed order (self QKV, self out, cross Q,
    /// cross KV, cross out, FFN 1, FFN 2), so a given seed always produces the
    /// same weights.
    pub fn default_init(cfg: &DetrConfig, rng: &mut LcgRng) -> Self {
        /// Draws `len` values from `rng` and rescales each one by `factor`.
        fn scaled_normal(rng: &mut LcgRng, len: usize, factor: f32) -> Vec<f32> {
            let mut buf = vec![0.0f32; len];
            rng.fill_normal(&mut buf);
            buf.iter_mut().for_each(|value| *value *= factor);
            buf
        }

        let dim = cfg.embed_dim;
        let hidden = cfg.mlp_dim();
        let factor = 1.0_f32 / (dim as f32).sqrt();
        let zeros = |len: usize| vec![0.0f32; len];
        let ones = |len: usize| vec![1.0f32; len];

        // Field initialisers run top to bottom, which keeps the RNG draw order
        // identical to the declaration order of the weight matrices.
        Self {
            self_qkv_weight: scaled_normal(rng, 3 * dim * dim, factor),
            self_qkv_bias: zeros(3 * dim),
            self_out_weight: scaled_normal(rng, dim * dim, factor),
            self_out_bias: zeros(dim),
            cross_q_weight: scaled_normal(rng, dim * dim, factor),
            cross_q_bias: zeros(dim),
            cross_kv_weight: scaled_normal(rng, 2 * dim * dim, factor),
            cross_kv_bias: zeros(2 * dim),
            cross_out_weight: scaled_normal(rng, dim * dim, factor),
            cross_out_bias: zeros(dim),
            ffn1_weight: scaled_normal(rng, hidden * dim, factor),
            ffn1_bias: zeros(hidden),
            ffn2_weight: scaled_normal(rng, dim * hidden, factor),
            ffn2_bias: zeros(dim),
            ln1_weight: ones(dim),
            ln1_bias: zeros(dim),
            ln2_weight: ones(dim),
            ln2_bias: zeros(dim),
            ln3_weight: ones(dim),
            ln3_bias: zeros(dim),
        }
    }
}
/// One decoder layer: its hyper-parameters together with its weights.
pub struct DetrDecoderLayer {
    /// Shape information read by `forward` (query count, width, heads, MLP width).
    pub config: DetrConfig,
    /// Parameters used by the attention, feed-forward and layer-norm steps.
    pub weights: DetrDecoderLayerWeights,
}
impl DetrDecoderLayer {
pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> Self {
let weights = DetrDecoderLayerWeights::default_init(&cfg, rng);
Self {
config: cfg,
weights,
}
}
pub fn forward(
&self,
queries: &[f32],
encoder_feats: &[f32],
n_enc_tokens: usize,
) -> VisionResult<Vec<f32>> {
let e = self.config.embed_dim;
let nq = self.config.n_queries;
let nh = self.config.n_heads;
let w = &self.weights;
let expected_q = nq * e;
if queries.len() != expected_q {
return Err(VisionError::DimensionMismatch {
expected: expected_q,
got: queries.len(),
});
}
let expected_enc = n_enc_tokens * e;
if encoder_feats.len() != expected_enc {
return Err(VisionError::DimensionMismatch {
expected: expected_enc,
got: encoder_feats.len(),
});
}
if n_enc_tokens == 0 {
return Err(VisionError::EmptyInput("encoder features"));
}
let queries_normed = layer_norm(queries, &w.ln1_weight, &w.ln1_bias, nq, e, 1e-5);
let sa_out = mhsa_self(
&queries_normed,
nq,
e,
nh,
&w.self_qkv_weight,
&w.self_qkv_bias,
&w.self_out_weight,
&w.self_out_bias,
)?;
let q1: Vec<f32> = queries
.iter()
.zip(sa_out.iter())
.map(|(a, b)| a + b)
.collect();
let q1_normed = layer_norm(&q1, &w.ln2_weight, &w.ln2_bias, nq, e, 1e-5);
let ca_out = mhsa_cross(
&q1_normed,
nq,
encoder_feats,
n_enc_tokens,
e,
nh,
&w.cross_q_weight,
&w.cross_q_bias,
&w.cross_kv_weight,
&w.cross_kv_bias,
&w.cross_out_weight,
&w.cross_out_bias,
)?;
let q2: Vec<f32> = q1.iter().zip(ca_out.iter()).map(|(a, b)| a + b).collect();
let q2_normed = layer_norm(&q2, &w.ln3_weight, &w.ln3_bias, nq, e, 1e-5);
let mlp_dim = self.config.mlp_dim();
let ffn_mid = linear(&q2_normed, &w.ffn1_weight, &w.ffn1_bias, e, mlp_dim);
let ffn_mid: Vec<f32> = ffn_mid.iter().map(|&v| gelu_approx(v)).collect();
let ffn_out = linear(&ffn_mid, &w.ffn2_weight, &w.ffn2_bias, mlp_dim, e);
let out: Vec<f32> = q2.iter().zip(ffn_out.iter()).map(|(a, b)| a + b).collect();
Ok(out)
}
}
/// A stack of decoder layers applied one after another by `forward`.
pub struct DetrDecoder {
    /// Layers in application order; `DetrDecoder::new` creates `cfg.depth` of them.
    pub layers: Vec<DetrDecoderLayer>,
}
impl DetrDecoder {
    /// Builds a decoder with `cfg.depth` independently initialised layers.
    ///
    /// All layers share the same configuration but draw their weights from
    /// `rng` one after another, so each layer gets different values.
    ///
    /// # Errors
    ///
    /// Returns `VisionError::DimensionMismatch { expected: 1, got: 0 }` when
    /// `cfg.depth` is zero.
    pub fn new(cfg: DetrConfig, rng: &mut LcgRng) -> VisionResult<Self> {
        if cfg.depth == 0 {
            return Err(VisionError::DimensionMismatch {
                expected: 1,
                got: 0,
            });
        }
        let depth = cfg.depth;
        let mut layers = Vec::with_capacity(depth);
        for _ in 0..depth {
            layers.push(DetrDecoderLayer::new(cfg.clone(), rng));
        }
        Ok(Self { layers })
    }

    /// Feeds `queries` through every layer in order.
    ///
    /// The same `encoder_feats` buffer (`n_enc_tokens * embed_dim` values) is
    /// passed to each layer; the output of one layer becomes the query input
    /// of the next. With an empty `layers` vector the result is a copy of
    /// `queries`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by a layer's `forward`.
    pub fn forward(
        &self,
        queries: &[f32],
        encoder_feats: &[f32],
        n_enc_tokens: usize,
    ) -> VisionResult<Vec<f32>> {
        self.layers
            .iter()
            .try_fold(queries.to_vec(), |current, layer| {
                layer.forward(&current, encoder_feats, n_enc_tokens)
            })
    }
}
/// Applies layer normalisation independently to each of the `n` rows of `x`.
///
/// `x` is read as a row-major `n x d` matrix. Every row is shifted by its mean,
/// divided by `sqrt(variance + eps)` (population variance, i.e. divided by `d`),
/// then scaled by `weight` and shifted by `bias` element-wise.
///
/// Returns a new `n * d` buffer.
///
/// # Panics
///
/// Panics when `x` holds fewer than `n * d` values, or when `weight` or `bias`
/// is shorter than `d` while `n > 0`.
fn layer_norm(x: &[f32], weight: &[f32], bias: &[f32], n: usize, d: usize, eps: f32) -> Vec<f32> {
    let width = d as f32;
    let mut normalized = Vec::with_capacity(n * d);
    for row_idx in 0..n {
        let row = &x[row_idx * d..(row_idx + 1) * d];
        let mean = row.iter().sum::<f32>() / width;
        let variance = row
            .iter()
            .map(|&value| {
                let centered = value - mean;
                centered * centered
            })
            .sum::<f32>()
            / width;
        let inv_std = 1.0 / (variance + eps).sqrt();
        normalized.extend(
            row.iter()
                .enumerate()
                .map(|(col, &value)| (value - mean) * inv_std * weight[col] + bias[col]),
        );
    }
    normalized
}
/// Dense affine map applied to every row of `x`: `out = x * w^T + b`.
///
/// * `x` - row-major input; the row count is `x.len() / n_in` (trailing
///   values that do not fill a whole row are ignored).
/// * `w` - row-major `[n_out][n_in]` weight matrix.
/// * `b` - `n_out` bias values.
///
/// Returns a row-major buffer with `n_out` values per input row.
///
/// # Panics
///
/// Panics when `n_in` is zero, or when `w` or `b` is too short for the
/// requested shape while there is at least one input row.
fn linear(x: &[f32], w: &[f32], b: &[f32], n_in: usize, n_out: usize) -> Vec<f32> {
    let rows = x.len() / n_in;
    let mut result = Vec::with_capacity(rows * n_out);
    for input_row in x.chunks_exact(n_in) {
        for out_idx in 0..n_out {
            let weights = &w[out_idx * n_in..(out_idx + 1) * n_in];
            // Start from the bias and add products left to right.
            let value = input_row
                .iter()
                .zip(weights)
                .fold(b[out_idx], |acc, (&xi, &wi)| acc + xi * wi);
            result.push(value);
        }
    }
    result
}
/// Tanh-based approximation of GELU:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
fn gelu_approx(x: f32) -> f32 {
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC_COEFF: f32 = 0.044_715;
    let cubic_term = CUBIC_COEFF * x * x * x;
    let gate = (SQRT_2_OVER_PI * (x + cubic_term)).tanh();
    let half_x = x * 0.5;
    half_x * (1.0 + gate)
}
/// Replaces each row of a row-major `n_rows x n_cols` matrix with its softmax.
///
/// The row maximum is subtracted before exponentiating so large logits do not
/// overflow. If the exponentials do not sum to a positive number the row is
/// left un-normalised (multiplied by one).
///
/// # Panics
///
/// Panics when `logits` holds fewer than `n_rows * n_cols` values.
fn softmax_rows(logits: &mut [f32], n_rows: usize, n_cols: usize) {
    for row_idx in 0..n_rows {
        let start = row_idx * n_cols;
        let row = &mut logits[start..start + n_cols];

        let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        let mut total = 0.0f32;
        row.iter_mut().for_each(|value| {
            *value = (*value - peak).exp();
            total += *value;
        });

        let norm = if total > 0.0 { 1.0 / total } else { 1.0 };
        row.iter_mut().for_each(|value| *value *= norm);
    }
}
/// Multi-head self-attention over `n_tokens` tokens of width `embed_dim`.
///
/// A single fused projection produces `3 * embed_dim` values per token, which
/// are split into query, key and value rows (in that order) and handed to
/// `compute_attention`.
///
/// # Errors
///
/// Propagates any error returned by `compute_attention`.
///
/// # Panics
///
/// Panics when `n_heads` is zero or when a buffer is too short for the
/// requested shape.
#[allow(clippy::too_many_arguments)]
fn mhsa_self(
    tokens: &[f32],
    n_tokens: usize,
    embed_dim: usize,
    n_heads: usize,
    qkv_weight: &[f32],
    qkv_bias: &[f32],
    out_weight: &[f32],
    out_bias: &[f32],
) -> VisionResult<Vec<f32>> {
    let head_dim = embed_dim / n_heads;
    let stride = 3 * embed_dim;
    let fused = linear(tokens, qkv_weight, qkv_bias, embed_dim, stride);

    let capacity = n_tokens * embed_dim;
    let mut queries = Vec::with_capacity(capacity);
    let mut keys = Vec::with_capacity(capacity);
    let mut values = Vec::with_capacity(capacity);
    for token in 0..n_tokens {
        // Each fused row is laid out as [Q | K | V].
        let row = &fused[token * stride..(token + 1) * stride];
        let (q_part, rest) = row.split_at(embed_dim);
        let (k_part, v_part) = rest.split_at(embed_dim);
        queries.extend_from_slice(q_part);
        keys.extend_from_slice(k_part);
        values.extend_from_slice(v_part);
    }

    compute_attention(
        &queries, n_tokens, &keys, n_tokens, &values, embed_dim, n_heads, head_dim, out_weight,
        out_bias,
    )
}
/// Multi-head cross-attention from `n_queries` query tokens to `n_enc` encoder tokens.
///
/// Queries get their own projection; the encoder tokens go through one fused
/// projection producing `2 * embed_dim` values per token, split into key and
/// value rows (in that order) before calling `compute_attention`.
///
/// # Errors
///
/// Propagates any error returned by `compute_attention`.
///
/// # Panics
///
/// Panics when `n_heads` is zero or when a buffer is too short for the
/// requested shape.
#[allow(clippy::too_many_arguments)]
fn mhsa_cross(
    queries: &[f32],
    n_queries: usize,
    encoder: &[f32],
    n_enc: usize,
    embed_dim: usize,
    n_heads: usize,
    q_weight: &[f32],
    q_bias: &[f32],
    kv_weight: &[f32],
    kv_bias: &[f32],
    out_weight: &[f32],
    out_bias: &[f32],
) -> VisionResult<Vec<f32>> {
    let head_dim = embed_dim / n_heads;
    let projected_q = linear(queries, q_weight, q_bias, embed_dim, embed_dim);

    let stride = 2 * embed_dim;
    let fused = linear(encoder, kv_weight, kv_bias, embed_dim, stride);

    let capacity = n_enc * embed_dim;
    let mut keys = Vec::with_capacity(capacity);
    let mut values = Vec::with_capacity(capacity);
    for token in 0..n_enc {
        // Each fused row is laid out as [K | V].
        let row = &fused[token * stride..(token + 1) * stride];
        let (k_part, v_part) = row.split_at(embed_dim);
        keys.extend_from_slice(k_part);
        values.extend_from_slice(v_part);
    }

    compute_attention(
        &projected_q,
        n_queries,
        &keys,
        n_enc,
        &values,
        embed_dim,
        n_heads,
        head_dim,
        out_weight,
        out_bias,
    )
}
/// Scaled dot-product attention over all heads, followed by the output projection.
///
/// `q` is `n_q x embed_dim`, `k` and `v` are `n_k x embed_dim`, all row-major.
/// Head `h` owns channels `h * head_dim .. (h + 1) * head_dim` of every row.
/// For each head the scores `q . k / sqrt(head_dim)` are soft-maxed per query
/// and used to average the value rows; the per-head results are written back
/// into their channel range and the combined buffer is projected with
/// `out_weight` / `out_bias`.
///
/// # Errors
///
/// Returns `VisionError::NonFinite` when the projected output contains a NaN
/// or an infinity.
#[allow(clippy::too_many_arguments)]
fn compute_attention(
    q: &[f32],
    n_q: usize,
    k: &[f32],
    n_k: usize,
    v: &[f32],
    embed_dim: usize,
    n_heads: usize,
    head_dim: usize,
    out_weight: &[f32],
    out_bias: &[f32],
) -> VisionResult<Vec<f32>> {
    let inv_sqrt_dim = 1.0_f32 / (head_dim as f32).sqrt();
    let mut merged = vec![0.0f32; n_q * embed_dim];
    // One score buffer serves every head; each head overwrites all of it.
    let mut attn = vec![0.0f32; n_q * n_k];

    for head in 0..n_heads {
        let lane = head * head_dim;

        // Scaled query/key similarities for this head.
        for qi in 0..n_q {
            let q_base = qi * embed_dim + lane;
            for kj in 0..n_k {
                let k_base = kj * embed_dim + lane;
                let dot = (0..head_dim)
                    .fold(0.0f32, |acc, d| acc + q[q_base + d] * k[k_base + d]);
                attn[qi * n_k + kj] = dot * inv_sqrt_dim;
            }
        }

        softmax_rows(&mut attn, n_q, n_k);

        // Attention-weighted average of the value rows, one channel at a time.
        for qi in 0..n_q {
            let probs = &attn[qi * n_k..(qi + 1) * n_k];
            for d in 0..head_dim {
                let channel = lane + d;
                let mixed = probs
                    .iter()
                    .enumerate()
                    .fold(0.0f32, |acc, (kj, &p)| acc + p * v[kj * embed_dim + channel]);
                merged[qi * embed_dim + channel] = mixed;
            }
        }
    }

    let projected = linear(&merged, out_weight, out_bias, embed_dim, embed_dim);
    if !projected.iter().all(|value| value.is_finite()) {
        return Err(VisionError::NonFinite("DETR decoder attention output"));
    }
    Ok(projected)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic generator shared by every test.
    fn seeded_rng() -> LcgRng {
        LcgRng::new(42)
    }

    /// Draws `len` values from `rng` through `fill_normal`.
    fn normal_vec(rng: &mut LcgRng, len: usize) -> Vec<f32> {
        let mut buf = vec![0.0f32; len];
        rng.fill_normal(&mut buf);
        buf
    }

    #[test]
    fn detr_config_tiny() {
        let cfg = DetrConfig::tiny();
        assert_eq!(cfg.n_queries, 4);
        assert_eq!(cfg.embed_dim, 32);
        assert_eq!(cfg.n_heads, 4);
        assert_eq!(cfg.depth, 1);
        assert_eq!(cfg.mlp_ratio, 4);
        assert_eq!(cfg.mlp_dim(), 128);
        assert_eq!(cfg.head_dim(), 8);
    }

    #[test]
    fn detr_config_invalid_embed_dim_zero() {
        let result = DetrConfig::new(4, 0, 4, 1, 4);
        assert!(matches!(result, Err(VisionError::InvalidEmbedDim(0))));
    }

    #[test]
    fn detr_config_invalid_heads_zero() {
        let result = DetrConfig::new(4, 32, 0, 1, 4);
        assert!(matches!(result, Err(VisionError::InvalidNumHeads(0))));
    }

    #[test]
    fn detr_config_head_dim_mismatch() {
        // 32 is not divisible by 3.
        let result = DetrConfig::new(4, 32, 3, 1, 4);
        assert!(matches!(result, Err(VisionError::HeadDimMismatch { .. })));
    }

    #[test]
    fn detr_config_zero_queries_errors() {
        assert!(DetrConfig::new(0, 32, 4, 1, 4).is_err());
    }

    #[test]
    fn single_layer_forward_shape() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::tiny();
        let expected_len = cfg.n_queries * cfg.embed_dim;
        let width = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.1f32; expected_len];
        let encoder = vec![0.2f32; 8 * width];
        let out = layer.forward(&queries, &encoder, 8).expect("forward ok");

        assert_eq!(
            out.len(),
            expected_len,
            "output shape [n_queries × embed_dim]"
        );
    }

    #[test]
    fn single_layer_forward_finite() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::tiny();
        let (n_q, width) = (cfg.n_queries, cfg.embed_dim);
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        // Inputs are drawn after the layer weights, queries before encoder tokens.
        let queries = normal_vec(&mut rng, n_q * width);
        let encoder = normal_vec(&mut rng, 16 * width);
        let out = layer.forward(&queries, &encoder, 16).expect("forward ok");

        assert!(out.iter().all(|v| v.is_finite()), "non-finite in output");
    }

    #[test]
    fn single_layer_forward_wrong_query_size_errors() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::tiny();
        let width = cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        // Three query rows where the tiny configuration expects four.
        let queries = vec![0.0f32; 3 * width];
        let encoder = vec![0.0f32; 8 * width];
        let result = layer.forward(&queries, &encoder, 8);

        assert!(
            matches!(result, Err(VisionError::DimensionMismatch { .. })),
            "expected DimensionMismatch"
        );
    }

    #[test]
    fn single_layer_forward_empty_encoder_errors() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::tiny();
        let query_len = cfg.n_queries * cfg.embed_dim;
        let layer = DetrDecoderLayer::new(cfg, &mut rng);

        let queries = vec![0.0f32; query_len];
        let result = layer.forward(&queries, &[], 0);

        assert!(result.is_err(), "expected error for empty encoder");
    }

    #[test]
    fn multi_layer_decoder_forward_shape() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::new(4, 32, 4, 3, 4).expect("valid config");
        let (n_q, width) = (cfg.n_queries, cfg.embed_dim);
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let queries = vec![0.1f32; n_q * width];
        let encoder = vec![0.2f32; 12 * width];
        let out = decoder
            .forward(&queries, &encoder, 12)
            .expect("multi-layer ok");

        assert_eq!(out.len(), n_q * width, "multi-layer output shape preserved");
    }

    #[test]
    fn multi_layer_decoder_forward_finite() {
        let mut rng = seeded_rng();
        let cfg = DetrConfig::new(8, 32, 4, 2, 4).expect("valid config");
        let (n_q, width) = (cfg.n_queries, cfg.embed_dim);
        let decoder = DetrDecoder::new(cfg, &mut rng).expect("valid decoder");

        let queries = normal_vec(&mut rng, n_q * width);
        let encoder = normal_vec(&mut rng, 6 * width);
        let out = decoder.forward(&queries, &encoder, 6).expect("forward ok");

        assert!(
            out.iter().all(|v| v.is_finite()),
            "non-finite in multi-layer output"
        );
    }

    #[test]
    fn layer_norm_constant_row_is_zero() {
        let x = vec![5.0f32; 32];
        let scale = vec![1.0f32; 32];
        let shift = vec![0.0f32; 32];
        let out = layer_norm(&x, &scale, &shift, 1, 32, 1e-5);
        for v in &out {
            assert!(v.abs() < 1e-5, "expected near-zero, got {v}");
        }
    }

    #[test]
    fn gelu_zero() {
        assert!((gelu_approx(0.0) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn gelu_large_pos() {
        assert!((gelu_approx(10.0) - 10.0).abs() < 1e-3);
    }

    #[test]
    fn gelu_large_neg() {
        assert!(gelu_approx(-10.0).abs() < 1e-3);
    }
}