use thiserror::Error;
/// Errors returned by the cross-attention routines.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum CrossAttnError {
/// The decoder buffer does not have the required length.
/// `expected` is the required element count, `got` the actual slice length.
#[error("decoder_hidden length {got} != decoder_seq_len * num_heads * head_dim = {expected}")]
DecoderDimMismatch { expected: usize, got: usize },
/// The encoder buffer does not have the required length.
/// `expected` is the required element count, `got` the actual slice length.
#[error("encoder_hidden length {got} != encoder_seq_len * num_heads * head_dim = {expected}")]
EncoderDimMismatch { expected: usize, got: usize },
/// The encoder mask does not have one entry per encoder position.
/// `expected` is the encoder sequence length, `got` the mask length.
#[error("encoder mask length {got} != encoder_seq_len {expected}")]
MaskLengthMismatch { expected: usize, got: usize },
/// `head_dim` was zero.
#[error("head_dim must be > 0")]
InvalidHeadDim,
/// `num_heads` was zero.
#[error("num_heads must be > 0")]
InvalidNumHeads,
}
/// Hyper-parameters for multi-head cross-attention.
#[derive(Debug, Clone)]
pub struct CrossAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Number of features per head.
    pub head_dim: usize,
    /// Dropout rate; [`CrossAttentionConfig::new`] initialises it to `0.0`.
    pub dropout_rate: f32,
    /// Multiplier applied to query/key dot products.
    pub scale: f32,
}

impl CrossAttentionConfig {
    /// Creates a configuration with `scale = 1 / sqrt(head_dim)` and a dropout
    /// rate of `0.0`.
    ///
    /// A `head_dim` of zero is accepted here and yields a scale of `1.0`, which
    /// avoids dividing by zero.
    pub fn new(num_heads: usize, head_dim: usize) -> Self {
        let scale = match head_dim {
            0 => 1.0,
            width => 1.0 / (width as f32).sqrt(),
        };
        Self {
            num_heads,
            head_dim,
            dropout_rate: 0.0,
            scale,
        }
    }

    /// Returns the combined width of all heads, `num_heads * head_dim`.
    pub fn hidden_dim(&self) -> usize {
        self.head_dim * self.num_heads
    }
}
/// Computes softmax-normalised attention weights for a single head.
///
/// `queries` is read as `decoder_seq` consecutive rows of `head_dim` values and
/// `keys` as `encoder_seq` such rows. Each query/key pair is scored with a dot
/// product multiplied by `scale`; every query's scores are then passed through
/// a softmax (the row maximum is subtracted before exponentiating).
///
/// Returns a row-major buffer of `decoder_seq * encoder_seq` weights, where
/// index `dq * encoder_seq + ek` holds the weight of key `ek` for query `dq`.
/// Elements of `queries` / `keys` beyond the required lengths are ignored.
///
/// # Errors
///
/// * [`CrossAttnError::InvalidHeadDim`] if `head_dim` is zero.
/// * [`CrossAttnError::DecoderDimMismatch`] if `queries` holds fewer than
///   `decoder_seq * head_dim` values.
/// * [`CrossAttnError::EncoderDimMismatch`] if there is at least one query and
///   one key position and `keys` holds fewer than `encoder_seq * head_dim`
///   values.
pub fn compute_attention_weights(
    queries: &[f32],
    keys: &[f32],
    decoder_seq: usize,
    encoder_seq: usize,
    head_dim: usize,
    scale: f32,
) -> Result<Vec<f32>, CrossAttnError> {
    if head_dim == 0 {
        return Err(CrossAttnError::InvalidHeadDim);
    }

    // Report undersized buffers through the error type instead of panicking on
    // an out-of-range slice further down.
    let queries_needed = decoder_seq * head_dim;
    if queries.len() < queries_needed {
        return Err(CrossAttnError::DecoderDimMismatch {
            expected: queries_needed,
            got: queries.len(),
        });
    }

    // With no query rows or no key rows the weight matrix has no entries and
    // `keys` is never read, so there is nothing left to validate or compute.
    if decoder_seq == 0 || encoder_seq == 0 {
        return Ok(Vec::new());
    }

    let keys_needed = encoder_seq * head_dim;
    if keys.len() < keys_needed {
        return Err(CrossAttnError::EncoderDimMismatch {
            expected: keys_needed,
            got: keys.len(),
        });
    }

    let queries = &queries[..queries_needed];
    let keys = &keys[..keys_needed];
    let mut weights = vec![0.0f32; decoder_seq * encoder_seq];

    let rows = queries
        .chunks_exact(head_dim)
        .zip(weights.chunks_exact_mut(encoder_seq));
    for (q_row, w_row) in rows {
        // Pass 1: raw scaled scores, remembering the largest one.
        let mut max_score = f32::NEG_INFINITY;
        for (slot, k_row) in w_row.iter_mut().zip(keys.chunks_exact(head_dim)) {
            let score = dot_scaled(q_row, k_row, scale);
            *slot = score;
            if score > max_score {
                max_score = score;
            }
        }

        // Pass 2: exponentiate relative to the maximum and accumulate the sum.
        let mut sum_exp = 0.0f32;
        for slot in w_row.iter_mut() {
            let e = (*slot - max_score).exp();
            *slot = e;
            sum_exp += e;
        }

        // Pass 3: normalise; a non-positive sum leaves the row untouched.
        if sum_exp > 0.0 {
            let inv = 1.0 / sum_exp;
            for slot in w_row.iter_mut() {
                *slot *= inv;
            }
        }
    }
    Ok(weights)
}
/// Runs scaled dot-product cross-attention for a single head.
///
/// `queries` is read as `decoder_seq` rows of `head_dim` values; `keys` and
/// `values` as `encoder_seq` rows of `head_dim` values each. For every query
/// the scaled dot products against the keys are softmax-normalised and used to
/// form a weighted sum of the value rows.
///
/// When `mask` is given, encoder positions whose entry is `false` receive a
/// score of negative infinity and therefore a weight of zero; their key rows
/// are never read. A query for which every position is masked produces an
/// all-zero output row.
///
/// Returns a row-major `decoder_seq * head_dim` buffer.
///
/// # Errors
///
/// * [`CrossAttnError::InvalidHeadDim`] if `head_dim` is zero.
/// * [`CrossAttnError::MaskLengthMismatch`] if `mask` is present and its length
///   is not `encoder_seq`.
/// * [`CrossAttnError::DecoderDimMismatch`] if `queries` holds fewer than
///   `decoder_seq * head_dim` values.
/// * [`CrossAttnError::EncoderDimMismatch`] if, with at least one query row,
///   `values` holds fewer than `encoder_seq * head_dim` values or the key row
///   of an unmasked position lies beyond the end of `keys`.
// The flat-slice-plus-dimensions calling convention needs all eight parameters.
#[allow(clippy::too_many_arguments)]
pub fn single_head_cross_attention(
    queries: &[f32],
    keys: &[f32],
    values: &[f32],
    decoder_seq: usize,
    encoder_seq: usize,
    head_dim: usize,
    scale: f32,
    mask: Option<&[bool]>,
) -> Result<Vec<f32>, CrossAttnError> {
    if head_dim == 0 {
        return Err(CrossAttnError::InvalidHeadDim);
    }
    if let Some(m) = mask {
        if m.len() != encoder_seq {
            return Err(CrossAttnError::MaskLengthMismatch {
                expected: encoder_seq,
                got: m.len(),
            });
        }
    }

    // Undersized buffers are reported through the error type rather than
    // panicking on an out-of-range slice.
    let queries_needed = decoder_seq * head_dim;
    if queries.len() < queries_needed {
        return Err(CrossAttnError::DecoderDimMismatch {
            expected: queries_needed,
            got: queries.len(),
        });
    }
    // Without query rows neither `keys` nor `values` is ever read.
    if decoder_seq == 0 {
        return Ok(Vec::new());
    }
    let encoder_needed = encoder_seq * head_dim;
    if values.len() < encoder_needed {
        return Err(CrossAttnError::EncoderDimMismatch {
            expected: encoder_needed,
            got: values.len(),
        });
    }
    let value_rows = &values[..encoder_needed];

    let mut output = vec![0.0f32; queries_needed];
    // Reused for every query row; each entry is overwritten before it is read.
    let mut scores = vec![0.0f32; encoder_seq];

    let rows = queries[..queries_needed]
        .chunks_exact(head_dim)
        .zip(output.chunks_exact_mut(head_dim));
    for (q_row, out_row) in rows {
        let mut max_score = f32::NEG_INFINITY;
        for (ek, slot) in scores.iter_mut().enumerate() {
            let attend = match mask {
                Some(m) => m[ek],
                None => true,
            };
            let score = if attend {
                // Key rows are validated only when actually needed, so rows of
                // masked positions may be absent.
                let start = ek * head_dim;
                let Some(k_row) = keys.get(start..start + head_dim) else {
                    return Err(CrossAttnError::EncoderDimMismatch {
                        expected: encoder_needed,
                        got: keys.len(),
                    });
                };
                dot_scaled(q_row, k_row, scale)
            } else {
                f32::NEG_INFINITY
            };
            *slot = score;
            if score > max_score {
                max_score = score;
            }
        }
        // Everything masked (or no encoder positions): use 0 as the reference
        // so the subtraction below stays well-defined and every exp() is 0.
        if max_score == f32::NEG_INFINITY {
            max_score = 0.0;
        }

        let mut sum_exp = 0.0f32;
        for slot in scores.iter_mut() {
            let e = (*slot - max_score).exp();
            *slot = e;
            sum_exp += e;
        }
        if sum_exp > 0.0 {
            let inv = 1.0 / sum_exp;
            for slot in scores.iter_mut() {
                *slot *= inv;
            }
        }

        // Weighted sum of the value rows.
        for (&w, v_row) in scores.iter().zip(value_rows.chunks_exact(head_dim)) {
            for (acc, &v) in out_row.iter_mut().zip(v_row) {
                *acc += w * v;
            }
        }
    }
    Ok(output)
}
/// Multi-head cross-attention from decoder states onto encoder states.
///
/// `decoder_hidden` and `encoder_hidden` are flat buffers laid out as
/// `[position][head][feature]`, with `config.num_heads` heads of
/// `config.head_dim` features per position. For each head, the decoder slice
/// provides the queries and the encoder slice is used as both keys and values;
/// the per-head results are written back into the same interleaved layout.
///
/// `encoder_mask`, when present, has one entry per encoder position; positions
/// marked `false` are excluded from attention for every head.
///
/// Returns a buffer of `decoder_seq_len * num_heads * head_dim` values.
///
/// # Errors
///
/// * [`CrossAttnError::InvalidNumHeads`] if `config.num_heads` is zero.
/// * [`CrossAttnError::InvalidHeadDim`] if `config.head_dim` is zero.
/// * [`CrossAttnError::DecoderDimMismatch`] if `decoder_hidden.len()` is not
///   `decoder_seq_len * num_heads * head_dim`.
/// * [`CrossAttnError::EncoderDimMismatch`] if `encoder_hidden.len()` is not
///   `encoder_seq_len * num_heads * head_dim`.
/// * [`CrossAttnError::MaskLengthMismatch`] if the mask length is not
///   `encoder_seq_len`.
pub fn cross_attention_forward(
    decoder_hidden: &[f32],
    encoder_hidden: &[f32],
    decoder_seq_len: usize,
    encoder_seq_len: usize,
    config: &CrossAttentionConfig,
    encoder_mask: Option<&[bool]>,
) -> Result<Vec<f32>, CrossAttnError> {
    let num_heads = config.num_heads;
    let head_dim = config.head_dim;
    if num_heads == 0 {
        return Err(CrossAttnError::InvalidNumHeads);
    }
    if head_dim == 0 {
        return Err(CrossAttnError::InvalidHeadDim);
    }

    let dec_expected = decoder_seq_len * num_heads * head_dim;
    if decoder_hidden.len() != dec_expected {
        return Err(CrossAttnError::DecoderDimMismatch {
            expected: dec_expected,
            got: decoder_hidden.len(),
        });
    }
    let enc_expected = encoder_seq_len * num_heads * head_dim;
    if encoder_hidden.len() != enc_expected {
        return Err(CrossAttnError::EncoderDimMismatch {
            expected: enc_expected,
            got: encoder_hidden.len(),
        });
    }
    if let Some(m) = encoder_mask {
        if m.len() != encoder_seq_len {
            return Err(CrossAttnError::MaskLengthMismatch {
                expected: encoder_seq_len,
                got: m.len(),
            });
        }
    }

    let mut output = vec![0.0f32; dec_expected];
    for h in 0..num_heads {
        let head_queries = extract_head(decoder_hidden, decoder_seq_len, num_heads, head_dim, h);
        let head_memory = extract_head(encoder_hidden, encoder_seq_len, num_heads, head_dim, h);
        // The encoder slice serves as both keys and values, so it is borrowed
        // twice instead of being cloned for every head.
        let head_out = single_head_cross_attention(
            &head_queries,
            &head_memory,
            &head_memory,
            decoder_seq_len,
            encoder_seq_len,
            head_dim,
            config.scale,
            encoder_mask,
        )?;
        scatter_head(
            &mut output,
            &head_out,
            decoder_seq_len,
            num_heads,
            head_dim,
            h,
        );
    }
    Ok(output)
}
/// Multi-head cross-attention in which decoder position `dq` may only attend
/// to encoder positions `0..=dq` (clamped to `encoder_seq_len`).
///
/// Both hidden buffers are flat and laid out as `[position][head][feature]`
/// with `config.num_heads` heads of `config.head_dim` features. For each head
/// the encoder slice is used as both keys and values. Positions outside the
/// allowed prefix get a score of negative infinity and thus a softmax weight
/// of zero.
///
/// Returns a buffer of `decoder_seq_len * num_heads * head_dim` values.
///
/// # Errors
///
/// * [`CrossAttnError::InvalidNumHeads`] if `config.num_heads` is zero.
/// * [`CrossAttnError::InvalidHeadDim`] if `config.head_dim` is zero.
/// * [`CrossAttnError::DecoderDimMismatch`] if `decoder_hidden.len()` is not
///   `decoder_seq_len * num_heads * head_dim`.
/// * [`CrossAttnError::EncoderDimMismatch`] if `encoder_hidden.len()` is not
///   `encoder_seq_len * num_heads * head_dim`.
pub fn causal_cross_attention(
    decoder_hidden: &[f32],
    encoder_hidden: &[f32],
    decoder_seq_len: usize,
    encoder_seq_len: usize,
    config: &CrossAttentionConfig,
) -> Result<Vec<f32>, CrossAttnError> {
    let (num_heads, head_dim) = (config.num_heads, config.head_dim);
    if num_heads == 0 {
        return Err(CrossAttnError::InvalidNumHeads);
    }
    if head_dim == 0 {
        return Err(CrossAttnError::InvalidHeadDim);
    }

    let dec_expected = decoder_seq_len * num_heads * head_dim;
    if decoder_hidden.len() != dec_expected {
        return Err(CrossAttnError::DecoderDimMismatch {
            expected: dec_expected,
            got: decoder_hidden.len(),
        });
    }
    let enc_expected = encoder_seq_len * num_heads * head_dim;
    if encoder_hidden.len() != enc_expected {
        return Err(CrossAttnError::EncoderDimMismatch {
            expected: enc_expected,
            got: encoder_hidden.len(),
        });
    }

    let mut output = vec![0.0f32; dec_expected];
    // Scratch row reused for every (head, query); fully rewritten each time.
    let mut scores = vec![0.0f32; encoder_seq_len];

    for h in 0..num_heads {
        let head_queries = extract_head(decoder_hidden, decoder_seq_len, num_heads, head_dim, h);
        // The encoder slice doubles as keys and values.
        let head_memory = extract_head(encoder_hidden, encoder_seq_len, num_heads, head_dim, h);
        let mut head_out = vec![0.0f32; decoder_seq_len * head_dim];

        let rows = head_queries
            .chunks_exact(head_dim)
            .zip(head_out.chunks_exact_mut(head_dim));
        for (dq, (q_row, out_row)) in rows.enumerate() {
            // Number of leading encoder positions this query may look at.
            let visible = encoder_seq_len.min(dq + 1);

            let mut peak = f32::NEG_INFINITY;
            let keyed = scores.iter_mut().zip(head_memory.chunks_exact(head_dim));
            for (ek, (slot, k_row)) in keyed.enumerate() {
                let score = if ek >= visible {
                    f32::NEG_INFINITY
                } else {
                    dot_scaled(q_row, k_row, config.scale)
                };
                *slot = score;
                if score > peak {
                    peak = score;
                }
            }
            // No finite score at all: fall back to 0 as the softmax reference.
            if peak == f32::NEG_INFINITY {
                peak = 0.0;
            }

            let mut total = 0.0f32;
            for slot in scores.iter_mut() {
                let e = (*slot - peak).exp();
                *slot = e;
                total += e;
            }
            if total > 0.0 {
                let inv = 1.0 / total;
                for slot in scores.iter_mut() {
                    *slot *= inv;
                }
            }

            // Weighted sum over the encoder rows.
            for (&w, v_row) in scores.iter().zip(head_memory.chunks_exact(head_dim)) {
                for (acc, &v) in out_row.iter_mut().zip(v_row) {
                    *acc += w * v;
                }
            }
        }

        scatter_head(
            &mut output,
            &head_out,
            decoder_seq_len,
            num_heads,
            head_dim,
            h,
        );
    }
    Ok(output)
}
/// Gathers one head's features from an interleaved `[position][head][feature]`
/// buffer into a contiguous `[position][feature]` vector of
/// `seq_len * head_dim` values.
///
/// Panics if `hidden` is too short for the requested positions.
fn extract_head(
    hidden: &[f32],
    seq_len: usize,
    num_heads: usize,
    head_dim: usize,
    head: usize,
) -> Vec<f32> {
    let mut gathered = Vec::with_capacity(seq_len * head_dim);
    for pos in 0..seq_len {
        // Start of this head's features within position `pos`.
        let begin = pos * num_heads * head_dim + head * head_dim;
        gathered.extend_from_slice(&hidden[begin..][..head_dim]);
    }
    gathered
}
/// Writes one head's contiguous `[position][feature]` data back into its slots
/// of an interleaved `[position][head][feature]` buffer.
///
/// Only the `head_dim` values belonging to `head` at each of the `seq_len`
/// positions are overwritten. Panics if either buffer is too short.
fn scatter_head(
    output: &mut [f32],
    head_data: &[f32],
    seq_len: usize,
    num_heads: usize,
    head_dim: usize,
    head: usize,
) {
    for pos in 0..seq_len {
        let slot = pos * num_heads * head_dim + head * head_dim;
        let target = &mut output[slot..][..head_dim];
        let row = &head_data[pos * head_dim..][..head_dim];
        target.copy_from_slice(row);
    }
}
/// Dot product of `a` and `b` multiplied by `scale`.
///
/// If the slices differ in length, only their common prefix is used.
#[inline]
fn dot_scaled(a: &[f32], b: &[f32], scale: f32) -> f32 {
    // Explicit left-to-right fold from 0.0 keeps the accumulation order fixed.
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    dot * scale
}
// Unit tests for the cross-attention routines defined in the parent module.
#[cfg(test)]
mod tests {
use super::*;
// Absolute tolerance for floating-point comparisons.
const EPS: f32 = 1e-5;
// Builds a flat buffer of `seq * num_heads * head_dim` elements, all set to `fill`.
fn make_hidden(seq: usize, num_heads: usize, head_dim: usize, fill: f32) -> Vec<f32> {
vec![fill; seq * num_heads * head_dim]
}
// `hidden_dim` is the product of head count and head width.
#[test]
fn cross_attn_config_hidden_dim() {
let cfg = CrossAttentionConfig::new(4, 8);
assert_eq!(
cfg.hidden_dim(),
32,
"hidden_dim should be num_heads * head_dim"
);
}
// The forward pass returns one value per decoder position, head and feature.
#[test]
fn cross_attention_output_shape() {
let num_heads = 2;
let head_dim = 4;
let dec_seq = 3;
let enc_seq = 5;
let cfg = CrossAttentionConfig::new(num_heads, head_dim);
let dec = make_hidden(dec_seq, num_heads, head_dim, 0.1);
let enc = make_hidden(enc_seq, num_heads, head_dim, 0.2);
let out = cross_attention_forward(&dec, &enc, dec_seq, enc_seq, &cfg, None)
.expect("cross_attention_forward should succeed");
assert_eq!(out.len(), dec_seq * num_heads * head_dim);
}
// With a single encoder position the softmax weight is 1, so the output
// equals that encoder vector (the encoder doubles as the values).
#[test]
fn cross_attention_identity_query() {
let head_dim = 4;
let cfg = CrossAttentionConfig::new(1, head_dim);
let dec = vec![1.0f32, 0.0, 0.0, 0.0]; let enc = vec![1.0f32, 0.0, 0.0, 0.0]; let out = cross_attention_forward(&dec, &enc, 1, 1, &cfg, None).expect("should succeed");
for i in 0..head_dim {
assert!(
(out[i] - enc[i]).abs() < EPS,
"output[{i}] = {} expected {}",
out[i],
enc[i]
);
}
}
// Only encoder position 0 is unmasked, so the output must be exactly its
// vector `[1.0, 0.0]`.
#[test]
fn cross_attention_with_mask() {
let num_heads = 1;
let head_dim = 2;
let dec_seq = 1;
let enc_seq = 3;
let cfg = CrossAttentionConfig::new(num_heads, head_dim);
let dec = vec![1.0f32, 0.0];
let enc = vec![1.0f32, 0.0, 0.0, 1.0, 0.0, 1.0];
let mask = vec![true, false, false];
let out = cross_attention_forward(&dec, &enc, dec_seq, enc_seq, &cfg, Some(&mask))
.expect("masked cross attention should succeed");
assert!(
(out[0] - 1.0).abs() < EPS,
"output[0] = {} expected 1.0",
out[0]
);
assert!(
(out[1] - 0.0).abs() < EPS,
"output[1] = {} expected 0.0",
out[1]
);
}
// A zero query scores every position equally; since every encoder row is
// `[1.0, 2.0]`, the uniform average is `[1.0, 2.0]` as well.
#[test]
fn cross_attention_uniform_encoder() {
let num_heads = 1;
let head_dim = 2;
let dec_seq = 1;
let enc_seq = 4;
let cfg = CrossAttentionConfig::new(num_heads, head_dim);
let dec = vec![0.0f32; dec_seq * num_heads * head_dim];
let enc: Vec<f32> = (0..enc_seq * num_heads * head_dim)
.map(|i| if i % 2 == 0 { 1.0 } else { 2.0 })
.collect();
let out = cross_attention_forward(&dec, &enc, dec_seq, enc_seq, &cfg, None)
.expect("uniform encoder cross attention should succeed");
assert!((out[0] - 1.0).abs() < EPS, "expected 1.0 got {}", out[0]);
assert!((out[1] - 2.0).abs() < EPS, "expected 2.0 got {}", out[1]);
}
// The single-head routine returns `decoder_seq * head_dim` values.
#[test]
fn single_head_output_shape() {
let dec_seq = 3;
let enc_seq = 5;
let head_dim = 4;
let q = vec![0.1f32; dec_seq * head_dim];
let k = vec![0.2f32; enc_seq * head_dim];
let v = vec![0.3f32; enc_seq * head_dim];
let out = single_head_cross_attention(&q, &k, &v, dec_seq, enc_seq, head_dim, 0.5, None)
.expect("single head should succeed");
assert_eq!(out.len(), dec_seq * head_dim);
}
// Two calls with identical inputs must produce bit-identical outputs.
#[test]
fn single_head_deterministic() {
let dec_seq = 2;
let enc_seq = 3;
let head_dim = 4;
let q: Vec<f32> = (0..dec_seq * head_dim).map(|i| i as f32 * 0.1).collect();
let k: Vec<f32> = (0..enc_seq * head_dim).map(|i| i as f32 * 0.05).collect();
let v: Vec<f32> = (0..enc_seq * head_dim).map(|i| (i as f32).sin()).collect();
let out1 = single_head_cross_attention(&q, &k, &v, dec_seq, enc_seq, head_dim, 0.5, None)
.expect("first call should succeed");
let out2 = single_head_cross_attention(&q, &k, &v, dec_seq, enc_seq, head_dim, 0.5, None)
.expect("second call should succeed");
assert_eq!(out1, out2, "single_head must be deterministic");
}
// Changing `scale` changes the softmax sharpness and therefore the output.
#[test]
fn single_head_scale_effect() {
let dec_seq = 1;
let enc_seq = 2;
let head_dim = 2;
let q = vec![1.0f32, 0.0];
let k = vec![1.0f32, 0.0, 0.0, 1.0];
let v = vec![1.0f32, 0.0, 0.0, 1.0];
let out1 = single_head_cross_attention(&q, &k, &v, dec_seq, enc_seq, head_dim, 1.0, None)
.expect("scale=1.0 should succeed");
let out2 = single_head_cross_attention(&q, &k, &v, dec_seq, enc_seq, head_dim, 0.01, None)
.expect("scale=0.01 should succeed");
assert_ne!(out1, out2, "different scale must produce different output");
}
// The causal variant has the same output length as the unmasked forward pass.
#[test]
fn causal_cross_attention_shape() {
let num_heads = 2;
let head_dim = 4;
let dec_seq = 3;
let enc_seq = 5;
let cfg = CrossAttentionConfig::new(num_heads, head_dim);
let dec = make_hidden(dec_seq, num_heads, head_dim, 0.1);
let enc = make_hidden(enc_seq, num_heads, head_dim, 0.2);
let out = causal_cross_attention(&dec, &enc, dec_seq, enc_seq, &cfg)
.expect("causal cross attention should succeed");
assert_eq!(out.len(), dec_seq * num_heads * head_dim);
}
// The weight matrix has one entry per (query, key) pair.
#[test]
fn attention_weights_shape() {
let dec_seq = 3;
let enc_seq = 5;
let head_dim = 4;
let q = vec![0.1f32; dec_seq * head_dim];
let k = vec![0.2f32; enc_seq * head_dim];
let weights = compute_attention_weights(&q, &k, dec_seq, enc_seq, head_dim, 0.5)
.expect("compute_attention_weights should succeed");
assert_eq!(weights.len(), dec_seq * enc_seq);
}
// Each query's weights form a softmax distribution and so sum to 1.
#[test]
fn attention_weights_sum_to_one() {
let dec_seq = 4;
let enc_seq = 6;
let head_dim = 8;
let q: Vec<f32> = (0..dec_seq * head_dim)
.map(|i| (i as f32) * 0.1 - 1.0)
.collect();
let k: Vec<f32> = (0..enc_seq * head_dim).map(|i| (i as f32) * 0.05).collect();
let weights = compute_attention_weights(&q, &k, dec_seq, enc_seq, head_dim, 0.5)
.expect("compute_attention_weights should succeed");
for dq in 0..dec_seq {
let row_sum: f32 = weights[dq * enc_seq..(dq + 1) * enc_seq].iter().sum();
assert!(
(row_sum - 1.0).abs() < 1e-5,
"row {dq} sums to {row_sum}, expected 1.0"
);
}
}
// `head_dim == 0` is rejected before any buffer length is checked.
#[test]
fn cross_attn_invalid_head_dim_error() {
let cfg = CrossAttentionConfig::new(2, 0);
let dec = vec![0.0f32; 0]; let enc = vec![0.0f32; 0];
let result = cross_attention_forward(&dec, &enc, 1, 1, &cfg, None);
assert!(
matches!(result, Err(CrossAttnError::InvalidHeadDim)),
"head_dim=0 should return InvalidHeadDim"
);
}
// A decoder buffer of 3 values where 1 * 2 * 4 = 8 are required is rejected.
#[test]
fn cross_attn_dim_mismatch_error() {
let cfg = CrossAttentionConfig::new(2, 4);
let dec = vec![0.0f32; 3]; let enc = vec![0.0f32; 8];
let result = cross_attention_forward(&dec, &enc, 1, 1, &cfg, None);
assert!(
matches!(result, Err(CrossAttnError::DecoderDimMismatch { .. })),
"wrong decoder_hidden size should return DecoderDimMismatch"
);
}
}