use ferrotorch_core::grad_fns::activation::softmax;
use ferrotorch_core::grad_fns::arithmetic::{add, mul};
use ferrotorch_core::grad_fns::linalg::{bmm_differentiable, mm_differentiable};
use ferrotorch_core::grad_fns::shape::{expand, transpose_2d};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
/// Multi-head scaled dot-product attention with optional grouped-query
/// attention (GQA), in which several query heads share one key/value head.
///
/// As built by the constructors, projection weights are stored as
/// `[out_features, in_features]` matrices and bias vectors are optional.
#[derive(Debug)]
pub struct MultiheadAttention<T: Float> {
    /// Model width: the last dimension of the query/key/value inputs and of the output.
    pub embed_dim: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads; equal to `num_heads` for standard MHA.
    pub num_kv_heads: usize,
    /// Per-head width, `embed_dim / num_heads`.
    pub head_dim: usize,
    /// Query projection weight, `[embed_dim, embed_dim]`.
    pub q_proj: Parameter<T>,
    /// Key projection weight, `[num_kv_heads * head_dim, embed_dim]`.
    pub k_proj: Parameter<T>,
    /// Value projection weight, `[num_kv_heads * head_dim, embed_dim]`.
    pub v_proj: Parameter<T>,
    /// Output projection weight, `[embed_dim, embed_dim]`.
    pub out_proj: Parameter<T>,
    /// Query bias, `[embed_dim]`; present when constructed with `bias = true`.
    pub q_bias: Option<Parameter<T>>,
    /// Key bias, `[num_kv_heads * head_dim]`; present when constructed with `bias = true`.
    pub k_bias: Option<Parameter<T>>,
    /// Value bias, `[num_kv_heads * head_dim]`; present when constructed with `bias = true`.
    pub v_bias: Option<Parameter<T>>,
    /// Output bias, `[embed_dim]`; present when constructed with `bias = true`.
    pub out_bias: Option<Parameter<T>>,
    /// Training-mode flag toggled by `Module::train` / `Module::eval`. The forward
    /// passes in this file do not read it.
    pub training: bool,
}
impl<T: Float> MultiheadAttention<T> {
    /// Creates a standard multi-head attention layer in which every query head
    /// has its own key/value head.
    ///
    /// Equivalent to [`MultiheadAttention::with_gqa`] with
    /// `num_kv_heads == num_heads`; see that constructor for weight shapes,
    /// initialization, and error conditions.
    pub fn new(embed_dim: usize, num_heads: usize, bias: bool) -> FerrotorchResult<Self> {
        Self::with_gqa(embed_dim, num_heads, num_heads, bias)
    }

    /// Creates an attention layer with grouped-query attention (GQA): `num_heads`
    /// query heads share `num_kv_heads` key/value heads.
    ///
    /// With `head_dim = embed_dim / num_heads` and `kv_dim = num_kv_heads * head_dim`,
    /// the weights are:
    /// - `q_proj`, `out_proj`: `[embed_dim, embed_dim]`
    /// - `k_proj`, `v_proj`: `[kv_dim, embed_dim]`
    ///
    /// All weights are initialized with `xavier_uniform`. When `bias` is true, the
    /// four bias vectors (`[embed_dim]` for q/out, `[kv_dim]` for k/v) are created
    /// and zero-initialized. The layer starts in training mode.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] in any of these cases:
    /// - any dimension is zero;
    /// - `embed_dim` is not divisible by `num_heads`;
    /// - `num_heads` is not divisible by `num_kv_heads`.
    ///
    /// Errors from parameter allocation or initialization are propagated.
    pub fn with_gqa(
        embed_dim: usize,
        num_heads: usize,
        num_kv_heads: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        if embed_dim == 0 || num_heads == 0 || num_kv_heads == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "embed_dim, num_heads, num_kv_heads must be positive".into(),
            });
        }
        if embed_dim % num_heads != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
                ),
            });
        }
        // Each key/value head must serve a whole number of query heads.
        if num_heads % num_kv_heads != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
                ),
            });
        }
        let head_dim = embed_dim / num_heads;
        // Total width of the (shared) key/value projections.
        let kv_dim = num_kv_heads * head_dim;
        let mut q_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
        let mut k_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
        let mut v_proj = Parameter::zeros(&[kv_dim, embed_dim])?;
        let mut out_proj = Parameter::zeros(&[embed_dim, embed_dim])?;
        xavier_uniform(&mut q_proj)?;
        xavier_uniform(&mut k_proj)?;
        xavier_uniform(&mut v_proj)?;
        xavier_uniform(&mut out_proj)?;
        let (q_bias, k_bias, v_bias, out_bias) = if bias {
            let mut qb = Parameter::zeros(&[embed_dim])?;
            let mut kb = Parameter::zeros(&[kv_dim])?;
            let mut vb = Parameter::zeros(&[kv_dim])?;
            let mut ob = Parameter::zeros(&[embed_dim])?;
            // Explicit zero initialization through `crate::init::zeros`.
            zeros(&mut qb)?;
            zeros(&mut kb)?;
            zeros(&mut vb)?;
            zeros(&mut ob)?;
            (Some(qb), Some(kb), Some(vb), Some(ob))
        } else {
            (None, None, None, None)
        };
        Ok(Self {
            embed_dim,
            num_heads,
            num_kv_heads,
            head_dim,
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            q_bias,
            k_bias,
            v_bias,
            out_bias,
            training: true,
        })
    }

    /// Computes scaled dot-product attention with separate query, key and value inputs.
    ///
    /// * `query`: `[batch, seq_q, embed_dim]`
    /// * `key`, `value`: `[batch, seq_k, embed_dim]`
    /// * `causal_mask`: when true, query position `i` is prevented from attending
    ///   to key positions `j > i`. This requires `seq_q == seq_k`.
    ///
    /// Returns a tensor of shape `[batch, seq_q, embed_dim]`.
    ///
    /// # Errors
    ///
    /// * [`FerrotorchError::InvalidArgument`] in either of these cases:
    ///   - any input is not 3-D;
    ///   - `causal_mask` is set and `seq_q != seq_k`.
    /// * [`FerrotorchError::ShapeMismatch`] in any of these cases:
    ///   - an input's last dimension is not `embed_dim`;
    ///   - the batch sizes differ;
    ///   - the key and value sequence lengths differ.
    /// * Any error propagated from the underlying tensor operations.
    pub fn forward_qkv(
        &self,
        query: &Tensor<T>,
        key: &Tensor<T>,
        value: &Tensor<T>,
        causal_mask: bool,
    ) -> FerrotorchResult<Tensor<T>> {
        if query.ndim() != 3 || key.ndim() != 3 || value.ndim() != 3 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "MultiheadAttention expects 3-D inputs [batch, seq, embed_dim], \
                     got query {:?}, key {:?}, value {:?}",
                    query.shape(),
                    key.shape(),
                    value.shape()
                ),
            });
        }
        let batch = query.shape()[0];
        let seq_q = query.shape()[1];
        let seq_k = key.shape()[1];
        if query.shape()[2] != self.embed_dim
            || key.shape()[2] != self.embed_dim
            || value.shape()[2] != self.embed_dim
        {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "embed_dim mismatch: expected {}, got query={}, key={}, value={}",
                    self.embed_dim,
                    query.shape()[2],
                    key.shape()[2],
                    value.shape()[2]
                ),
            });
        }
        if key.shape()[0] != batch || value.shape()[0] != batch {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "batch size mismatch: query batch={}, key batch={}, value batch={}",
                    batch,
                    key.shape()[0],
                    value.shape()[0]
                ),
            });
        }
        if key.shape()[1] != value.shape()[1] {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "key and value seq_len must match: key={}, value={}",
                    key.shape()[1],
                    value.shape()[1]
                ),
            });
        }
        if causal_mask && seq_q != seq_k {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "causal mask requires seq_q == seq_k, got seq_q={seq_q}, seq_k={seq_k}"
                ),
            });
        }
        // Single-query / single-key fast path. A softmax over one score is exactly 1,
        // so the attention context equals the projected value and the Q/K projections
        // can be skipped entirely. The path is limited to MHA because under GQA the
        // value projection is only `kv_dim` wide.
        if seq_q == 1 && seq_k == 1 && !causal_mask && self.num_kv_heads == self.num_heads {
            use ferrotorch_core::grad_fns::linalg::linear_fused;
            let v_2d = value.reshape_t(&[batch as isize, self.embed_dim as isize])?;
            let v_proj = linear_fused(
                &v_2d,
                self.v_proj.tensor(),
                self.v_bias.as_ref().map(|b| b.tensor()),
            )?;
            let output = linear_fused(
                &v_proj,
                self.out_proj.tensor(),
                self.out_bias.as_ref().map(|b| b.tensor()),
            )?;
            return output.reshape_t(&[batch as isize, 1, self.embed_dim as isize]);
        }
        let nh = self.num_heads;
        let nkv = self.num_kv_heads;
        let hd = self.head_dim;
        // Number of query heads served by each key/value head.
        let group_size = nh / nkv;
        // Weights are stored [out, in]; transpose once so projections are `x @ W^T`.
        let wq_t = transpose_2d(self.q_proj.tensor())?;
        let wk_t = transpose_2d(self.k_proj.tensor())?;
        let wv_t = transpose_2d(self.v_proj.tensor())?;
        let wo_t = transpose_2d(self.out_proj.tensor())?;
        // Fold batch and sequence together so each projection is one 2-D matmul.
        let flat_q = query.reshape_t(&[-1, self.embed_dim as isize])?;
        let flat_k = key.reshape_t(&[-1, self.embed_dim as isize])?;
        let flat_v = value.reshape_t(&[-1, self.embed_dim as isize])?;
        let mut q_proj = mm_differentiable(&flat_q, &wq_t)?;
        let mut k_proj = mm_differentiable(&flat_k, &wk_t)?;
        let mut v_proj = mm_differentiable(&flat_v, &wv_t)?;
        // Biases are broadcast to one copy per flattened row before adding.
        if let Some(ref qb) = self.q_bias {
            let b = expand_bias_to_2d(qb.tensor(), batch * seq_q)?;
            q_proj = add(&q_proj, &b)?;
        }
        if let Some(ref kb) = self.k_bias {
            let b = expand_bias_to_2d(kb.tensor(), batch * seq_k)?;
            k_proj = add(&k_proj, &b)?;
        }
        if let Some(ref vb) = self.v_bias {
            let b = expand_bias_to_2d(vb.tensor(), batch * seq_k)?;
            v_proj = add(&v_proj, &b)?;
        }
        // [batch*seq_q, embed] -> [batch, seq_q, nh, hd] -> [batch, nh, seq_q, hd]
        // -> [batch*nh, seq_q, hd]
        let q = q_proj
            .reshape_t(&[batch as isize, seq_q as isize, nh as isize, hd as isize])?
            .permute(&[0, 2, 1, 3])?
            .contiguous()?
            .reshape_t(&[(batch * nh) as isize, seq_q as isize, hd as isize])?;
        // Keys/values: [batch*seq_k, kv_dim] -> [batch, nkv, seq_k, hd]
        let mut k = k_proj
            .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
            .permute(&[0, 2, 1, 3])?
            .contiguous()?;
        let mut v = v_proj
            .reshape_t(&[batch as isize, seq_k as isize, nkv as isize, hd as isize])?
            .permute(&[0, 2, 1, 3])?
            .contiguous()?;
        // GQA: insert a group axis, broadcast it to `group_size`, and fold it into the
        // head axis. Head index becomes `kv_head * group_size + g`, so query head `h`
        // reads key/value head `h / group_size`.
        if group_size > 1 {
            k = k
                .reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
            k = expand(&k, &[batch, nkv, group_size, seq_k, hd])?;
            k = k.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
            v = v
                .reshape_t(&[batch as isize, nkv as isize, 1, seq_k as isize, hd as isize])?;
            v = expand(&v, &[batch, nkv, group_size, seq_k, hd])?;
            v = v.reshape_t(&[batch as isize, nh as isize, seq_k as isize, hd as isize])?;
        }
        let k = k.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
        let v = v.reshape_t(&[(batch * nh) as isize, seq_k as isize, hd as isize])?;
        // [batch*nh, hd, seq_k], so bmm(q, k_t) gives scores [batch*nh, seq_q, seq_k].
        let k_t = k.permute(&[0, 2, 1])?.contiguous()?;
        let scores = bmm_differentiable(&q, &k_t)?;
        // Scale by 1/sqrt(head_dim). The scalar is a 1-element tensor on the scores'
        // device, relying on `mul` to broadcast it.
        let scale_val = T::from(1.0 / (hd as f64).sqrt()).unwrap();
        let scale_tensor = Tensor::from_storage(
            TensorStorage::on_device(vec![scale_val], scores.device())?,
            vec![1],
            false,
        )?;
        let scaled = mul(&scores, &scale_tensor)?;
        // Additive causal mask: 0 on and below the diagonal, -1e9 (a large finite
        // negative rather than -inf) above it. It is built on the CPU as
        // [1, seq_q, seq_k], moved to the scores' device when that is CUDA, and
        // `add` is relied on to broadcast it across the batch*heads axis.
        let masked = if causal_mask {
            let neg_inf = T::from(-1e9).unwrap();
            let zero = <T as num_traits::Zero>::zero();
            let mut mask_data = vec![zero; seq_q * seq_k];
            for i in 0..seq_q {
                for j in (i + 1)..seq_k {
                    mask_data[i * seq_k + j] = neg_inf;
                }
            }
            let mask = Tensor::from_storage(
                TensorStorage::cpu(mask_data),
                vec![1, seq_q, seq_k],
                false,
            )?;
            let mask = if scaled.is_cuda() { mask.to(scaled.device())? } else { mask };
            add(&scaled, &mask)?
        } else {
            scaled
        };
        // Presumably normalizes over the last (key) axis; confirm in `softmax`.
        let weights = softmax(&masked)?;
        // [batch*nh, seq_q, hd]
        let context = bmm_differentiable(&weights, &v)?;
        // [batch*nh, seq_q, hd] -> [batch, nh, seq_q, hd] -> [batch, seq_q, nh, hd]
        // -> [batch*seq_q, embed_dim] (heads concatenated along the feature axis)
        let context = context
            .reshape_t(&[batch as isize, nh as isize, seq_q as isize, hd as isize])?
            .permute(&[0, 2, 1, 3])?
            .contiguous()?
            .reshape_t(&[(batch * seq_q) as isize, self.embed_dim as isize])?;
        let mut output = mm_differentiable(&context, &wo_t)?;
        if let Some(ref ob) = self.out_bias {
            let b = expand_bias_to_2d(ob.tensor(), batch * seq_q)?;
            output = add(&output, &b)?;
        }
        output.reshape_t(&[batch as isize, seq_q as isize, self.embed_dim as isize])
    }

    /// Model width (last dimension of inputs and output).
    #[inline]
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Number of query heads.
    #[inline]
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Number of key/value heads.
    #[inline]
    pub fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }

    /// Per-head width, `embed_dim / num_heads`.
    #[inline]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Returns true when key/value heads are shared between query heads
    /// (`num_kv_heads != num_heads`).
    #[inline]
    pub fn is_gqa(&self) -> bool {
        self.num_kv_heads != self.num_heads
    }

    /// Computes `out_proj(v_proj(input))` with fused linear layers.
    ///
    /// This is the same computation as the single-token fast path of
    /// [`Self::forward_qkv`]: with one key per query the attention weight is 1,
    /// so only the value and output projections matter. It is presumably meant
    /// for `[tokens, embed_dim]` inputs where every row is an independent
    /// length-1 sequence; confirm against callers.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] for GQA layers. For those
    /// layers `v_proj` produces `kv_dim`-wide rows, while `out_proj` expects
    /// `embed_dim`-wide rows. Errors from `linear_fused` are propagated.
    pub fn forward_2d(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        use ferrotorch_core::grad_fns::linalg::linear_fused;
        if self.is_gqa() {
            return Err(FerrotorchError::InvalidArgument {
                message:
                    "forward_2d is MHA-only; use forward_qkv for GQA (num_kv_heads != num_heads)"
                        .into(),
            });
        }
        let v_proj = linear_fused(
            input,
            self.v_proj.tensor(),
            self.v_bias.as_ref().map(|b| b.tensor()),
        )?;
        linear_fused(
            &v_proj,
            self.out_proj.tensor(),
            self.out_bias.as_ref().map(|b| b.tensor()),
        )
    }
}
impl<T: Float> Module<T> for MultiheadAttention<T> {
    /// Self-attention: `input` serves as query, key and value, with no causal mask.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let use_causal_mask = false;
        self.forward_qkv(input, input, input, use_causal_mask)
    }

    /// The four projection weights (q, k, v, out), then every bias that is
    /// present, in the same q, k, v, out order.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        let weights = [&self.q_proj, &self.k_proj, &self.v_proj, &self.out_proj];
        let biases = [&self.q_bias, &self.k_bias, &self.v_bias, &self.out_bias];
        weights
            .iter()
            .copied()
            .chain(biases.iter().filter_map(|&bias| bias.as_ref()))
            .collect()
    }

    /// Mutable counterpart of `parameters`, with the same ordering.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let Self {
            q_proj,
            k_proj,
            v_proj,
            out_proj,
            q_bias,
            k_bias,
            v_bias,
            out_bias,
            ..
        } = self;
        let mut collected: Vec<&mut Parameter<T>> = vec![q_proj, k_proj, v_proj, out_proj];
        collected.extend(q_bias.as_mut());
        collected.extend(k_bias.as_mut());
        collected.extend(v_bias.as_mut());
        collected.extend(out_bias.as_mut());
        collected
    }

    /// Parameters paired with PyTorch-style names (`<proj>.weight` / `<proj>.bias`),
    /// in the same order as `parameters`.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let weights = [
            ("q_proj.weight", &self.q_proj),
            ("k_proj.weight", &self.k_proj),
            ("v_proj.weight", &self.v_proj),
            ("out_proj.weight", &self.out_proj),
        ];
        let biases = [
            ("q_proj.bias", &self.q_bias),
            ("k_proj.bias", &self.k_bias),
            ("v_proj.bias", &self.v_bias),
            ("out_proj.bias", &self.out_bias),
        ];
        let mut named: Vec<(String, &Parameter<T>)> = weights
            .iter()
            .map(|&(label, param)| (label.to_string(), param))
            .collect();
        named.extend(
            biases
                .iter()
                .filter_map(|&(label, bias)| bias.as_ref().map(|param| (label.to_string(), param))),
        );
        named
    }

    /// Switches the layer into training mode.
    fn train(&mut self) {
        self.training = true;
    }

    /// Switches the layer into evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    /// Reports the current training-mode flag.
    fn is_training(&self) -> bool {
        self.training
    }
}
/// Broadcasts a 1-D bias of length `d` into a `[rows, d]` tensor, by viewing it
/// as `[1, d]` and expanding the leading axis, so it can be added to a batch of
/// projected rows.
fn expand_bias_to_2d<T: Float>(bias: &Tensor<T>, rows: usize) -> FerrotorchResult<Tensor<T>> {
    let features = bias.shape()[0];
    expand(&bias.reshape_t(&[1, features as isize])?, &[rows, features])
}
pub fn reshape_to_heads<T: Float>(
tensor: &Tensor<T>,
num_heads: usize,
seq_len: usize,
head_dim: usize,
) -> FerrotorchResult<Tensor<T>> {
let data = tensor.data()?;
let mut result = vec![<T as num_traits::Zero>::zero(); num_heads * seq_len * head_dim];
for s in 0..seq_len {
for h in 0..num_heads {
for d in 0..head_dim {
let src_idx = s * (num_heads * head_dim) + h * head_dim + d;
let dst_idx = h * (seq_len * head_dim) + s * head_dim + d;
result[dst_idx] = data[src_idx];
}
}
}
Tensor::from_storage(
TensorStorage::cpu(result),
vec![num_heads, seq_len, head_dim],
tensor.requires_grad(),
)
}
pub fn transpose_heads_to_2d<T: Float>(
tensor: &Tensor<T>,
num_heads: usize,
seq_len: usize,
head_dim: usize,
) -> FerrotorchResult<Tensor<T>> {
let embed_dim = num_heads * head_dim;
let data = tensor.data_vec()?;
let mut result = vec![<T as num_traits::Zero>::zero(); seq_len * embed_dim];
for h in 0..num_heads {
for s in 0..seq_len {
for d in 0..head_dim {
let src_idx = h * (seq_len * head_dim) + s * head_dim + d;
let dst_idx = s * embed_dim + h * head_dim + d;
result[dst_idx] = data[src_idx];
}
}
}
let device = tensor.device();
Tensor::from_storage(
TensorStorage::on_device(result, device)?,
vec![seq_len, embed_dim],
false,
)
}
/// Repeats each key/value head `group_size` times, so that a
/// `[num_kv_heads, seq, head_dim]` tensor lines up with
/// `num_kv_heads * group_size` query heads.
///
/// Output head `h` is a copy of input head `h / group_size`, so the repeats of
/// each kv head are adjacent. This is the same grouping `forward_qkv` uses for GQA.
///
/// When `group_size == 1` the input is returned as a clone, without a rank check.
///
/// Otherwise the data is copied out with `data_vec` and re-stored on `kv`'s
/// device. The result is built directly from storage, so no grad function links
/// it to `kv`; only the `requires_grad` flag is carried over.
///
/// # Errors
///
/// * [`FerrotorchError::InvalidArgument`] if `group_size` is zero.
/// * [`FerrotorchError::ShapeMismatch`] if `kv` is not 3-D (checked only when
///   `group_size > 1`).
/// * Any error from reading `kv`'s data or allocating the output storage.
pub fn repeat_kv<T: Float>(kv: &Tensor<T>, group_size: usize) -> FerrotorchResult<Tensor<T>> {
    // A zero group size would otherwise silently produce a tensor with no heads.
    if group_size == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "repeat_kv: group_size must be positive".into(),
        });
    }
    if group_size == 1 {
        return Ok(kv.clone());
    }
    let shape = kv.shape();
    if shape.len() != 3 {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "repeat_kv expects 3-D [num_kv_heads, seq, head_dim], got {:?}",
                shape
            ),
        });
    }
    let (num_kv_heads, seq, head_dim) = (shape[0], shape[1], shape[2]);
    let num_q_heads = num_kv_heads * group_size;
    let data = kv.data_vec()?;
    let head_stride = seq * head_dim;
    let mut out = vec![<T as num_traits::Zero>::zero(); num_q_heads * head_stride];
    for h in 0..num_q_heads {
        // Query head h reads the kv head that owns its group.
        let src_start = (h / group_size) * head_stride;
        let dst_start = h * head_stride;
        out[dst_start..dst_start + head_stride]
            .copy_from_slice(&data[src_start..src_start + head_stride]);
    }
    let device = kv.device();
    Tensor::from_storage(
        TensorStorage::on_device(out, device)?,
        vec![num_q_heads, seq, head_dim],
        kv.requires_grad(),
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, non-constant `f64` test data of length `len`.
    fn wavy(len: usize) -> Vec<f64> {
        (0..len).map(|i| (i as f64 * 0.37).sin()).collect()
    }

    #[test]
    fn test_new_valid() {
        let mha = MultiheadAttention::<f32>::new(64, 8, true);
        assert!(mha.is_ok());
        let mha = mha.unwrap();
        assert_eq!(mha.embed_dim(), 64);
        assert_eq!(mha.num_heads(), 8);
        assert_eq!(mha.head_dim(), 8);
    }

    #[test]
    fn test_new_invalid_divisibility() {
        let result = MultiheadAttention::<f32>::new(65, 8, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_zero_dims() {
        assert!(MultiheadAttention::<f32>::new(0, 4, false).is_err());
        assert!(MultiheadAttention::<f32>::new(64, 0, false).is_err());
    }

    #[test]
    fn test_parameter_count_with_bias() {
        let mha = MultiheadAttention::<f32>::new(16, 4, true).unwrap();
        let params = mha.parameters();
        let total: usize = params.iter().map(|p| p.numel()).sum();
        let embed_dim = 16usize;
        // Four square projection matrices plus four bias vectors.
        let expected = 4 * embed_dim * embed_dim + 4 * embed_dim;
        assert_eq!(total, expected);
        // 4 weights + 4 biases.
        assert_eq!(params.len(), 8);
    }

    #[test]
    fn test_parameter_count_without_bias() {
        let mha = MultiheadAttention::<f32>::new(16, 4, false).unwrap();
        let params = mha.parameters();
        let total: usize = params.iter().map(|p| p.numel()).sum();
        let embed_dim = 16usize;
        let expected = 4 * embed_dim * embed_dim;
        assert_eq!(total, expected);
        // Weights only.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_named_parameters() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let named = mha.named_parameters();
        let names: Vec<&str> = named.iter().map(|(n, _)| n.as_str()).collect();
        assert!(names.contains(&"q_proj.weight"));
        assert!(names.contains(&"k_proj.weight"));
        assert!(names.contains(&"v_proj.weight"));
        assert!(names.contains(&"out_proj.weight"));
        assert!(names.contains(&"q_proj.bias"));
        assert!(names.contains(&"k_proj.bias"));
        assert!(names.contains(&"v_proj.bias"));
        assert!(names.contains(&"out_proj.bias"));
    }

    #[test]
    fn test_output_shape() {
        let mha = MultiheadAttention::<f32>::new(16, 4, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let output = mha.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_output_shape_no_bias() {
        let mha = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let output = mha.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_self_attention_basic_forward() {
        let mha = MultiheadAttention::<f64>::new(4, 2, true).unwrap();
        let input = ferrotorch_core::ones::<f64>(&[1, 2, 4]).unwrap();
        let output = mha.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 2, 4]);
        let data = output.data().unwrap();
        for &v in data {
            assert!(v.is_finite(), "output contains non-finite value: {v}");
        }
    }

    #[test]
    fn test_cross_attention_shape() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        let output = mha.forward_qkv(&query, &kv, &kv, false).unwrap();
        assert_eq!(output.shape(), &[1, 3, 8]);
    }

    #[test]
    fn test_causal_mask_different_seq_lens_error() {
        let mha = MultiheadAttention::<f32>::new(8, 2, false).unwrap();
        let query = ferrotorch_core::zeros::<f32>(&[1, 3, 8]).unwrap();
        let kv = ferrotorch_core::zeros::<f32>(&[1, 5, 8]).unwrap();
        let result = mha.forward_qkv(&query, &kv, &kv, true);
        assert!(result.is_err());
    }

    #[test]
    fn test_single_token_fast_path_matches_general_path() {
        // With seq_q == seq_k == 1 and no mask, forward_qkv takes a shortcut that
        // skips the Q/K projections. Requesting a causal mask forces the general
        // path (a 1x1 causal mask is all zeros), so both results must agree.
        let mha = MultiheadAttention::<f64>::new(8, 2, true).unwrap();
        let x = ferrotorch_core::from_slice::<f64>(&wavy(16), &[2, 1, 8]).unwrap();
        let fast = mha.forward_qkv(&x, &x, &x, false).unwrap();
        let general = mha.forward_qkv(&x, &x, &x, true).unwrap();
        assert_eq!(fast.shape(), general.shape());
        let fast = fast.data_vec().unwrap();
        let general = general.data_vec().unwrap();
        for (a, b) in fast.iter().zip(&general) {
            assert!((a - b).abs() < 1e-10, "fast path {a} != general path {b}");
        }
    }

    #[test]
    fn test_causal_mask_hides_future_tokens() {
        // Under a causal mask, the first output position may depend only on the
        // first input token, so perturbing the last token must leave it unchanged.
        let mha = MultiheadAttention::<f64>::new(8, 2, true).unwrap();
        let base = wavy(24);
        let mut altered = base.clone();
        for v in &mut altered[16..] {
            *v += 1.0;
        }
        let x1 = ferrotorch_core::from_slice::<f64>(&base, &[1, 3, 8]).unwrap();
        let x2 = ferrotorch_core::from_slice::<f64>(&altered, &[1, 3, 8]).unwrap();
        let y1 = mha.forward_qkv(&x1, &x1, &x1, true).unwrap().data_vec().unwrap();
        let y2 = mha.forward_qkv(&x2, &x2, &x2, true).unwrap().data_vec().unwrap();
        for (a, b) in y1[..8].iter().zip(&y2[..8]) {
            assert!((a - b).abs() < 1e-12, "position 0 leaked future info: {a} vs {b}");
        }
        // Sanity check: the last position does observe the perturbation.
        assert!(y1[16..].iter().zip(&y2[16..]).any(|(a, b)| (a - b).abs() > 1e-9));
    }

    #[test]
    fn test_train_eval_toggle() {
        let mut mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        assert!(mha.is_training());
        mha.eval();
        assert!(!mha.is_training());
        mha.train();
        assert!(mha.is_training());
    }

    #[test]
    fn test_wrong_embed_dim_input() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[1, 3, 4]).unwrap();
        let result = mha.forward(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_2d_input_rejected() {
        let mha = MultiheadAttention::<f32>::new(8, 2, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[3, 8]).unwrap();
        let result = mha.forward(&input);
        assert!(result.is_err());
    }

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MultiheadAttention<f32>>();
        assert_send_sync::<MultiheadAttention<f64>>();
    }

    #[test]
    fn test_with_gqa_valid_construction() {
        let mha = MultiheadAttention::<f32>::with_gqa(4096, 32, 8, false).unwrap();
        assert_eq!(mha.embed_dim(), 4096);
        assert_eq!(mha.num_heads(), 32);
        assert_eq!(mha.num_kv_heads(), 8);
        assert_eq!(mha.head_dim(), 128);
        assert!(mha.is_gqa());
    }

    #[test]
    fn test_with_gqa_kv_proj_shapes() {
        let mha = MultiheadAttention::<f32>::with_gqa(64, 8, 2, true).unwrap();
        // num_kv_heads * head_dim
        let kv_dim = 2 * (64 / 8);
        assert_eq!(mha.q_proj.shape(), &[64, 64]);
        assert_eq!(mha.k_proj.shape(), &[kv_dim, 64]);
        assert_eq!(mha.v_proj.shape(), &[kv_dim, 64]);
        assert_eq!(mha.out_proj.shape(), &[64, 64]);
        assert_eq!(mha.q_bias.as_ref().unwrap().shape(), &[64]);
        assert_eq!(mha.k_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(mha.v_bias.as_ref().unwrap().shape(), &[kv_dim]);
        assert_eq!(mha.out_bias.as_ref().unwrap().shape(), &[64]);
    }

    #[test]
    fn test_with_gqa_rejects_non_divisible_kv_heads() {
        let result = MultiheadAttention::<f32>::with_gqa(64, 8, 3, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_with_gqa_rejects_zero_kv_heads() {
        let result = MultiheadAttention::<f32>::with_gqa(64, 8, 0, false);
        assert!(result.is_err());
    }

    #[test]
    fn test_with_gqa_equivalent_to_new_when_kv_equals_q() {
        let gqa = MultiheadAttention::<f32>::with_gqa(32, 4, 4, true).unwrap();
        let mha = MultiheadAttention::<f32>::new(32, 4, true).unwrap();
        assert_eq!(gqa.num_kv_heads(), mha.num_kv_heads());
        assert_eq!(gqa.k_proj.shape(), mha.k_proj.shape());
        assert_eq!(gqa.v_proj.shape(), mha.v_proj.shape());
        assert!(!gqa.is_gqa());
    }

    #[test]
    fn test_repeat_kv_noop_on_group_size_1() {
        let kv = ferrotorch_core::from_slice::<f32>(
            &(0..24).map(|i| i as f32).collect::<Vec<_>>(),
            &[2, 3, 4],
        )
        .unwrap();
        let out = repeat_kv(&kv, 1).unwrap();
        assert_eq!(out.shape(), kv.shape());
        assert_eq!(out.data_vec().unwrap(), kv.data_vec().unwrap());
    }

    #[test]
    fn test_repeat_kv_copies_correct_heads() {
        // Two kv heads, each [seq=2, head_dim=3].
        let data: Vec<f32> = vec![
            10.0, 11.0, 12.0, 13.0, 14.0, 15.0, // kv head 0
            20.0, 21.0, 22.0, 23.0, 24.0, 25.0, // kv head 1
        ];
        let kv = ferrotorch_core::from_slice::<f32>(&data, &[2, 2, 3]).unwrap();
        let out = repeat_kv(&kv, 3).unwrap();
        assert_eq!(out.shape(), &[6, 2, 3]);
        let out_data = out.data_vec().unwrap();
        // seq * head_dim
        let head_stride = 2 * 3;
        // Query heads 0..3 share kv head 0.
        for h in 0..3 {
            let start = h * head_stride;
            assert_eq!(&out_data[start..start + head_stride], &data[0..head_stride]);
        }
        // Query heads 3..6 share kv head 1.
        for h in 3..6 {
            let start = h * head_stride;
            assert_eq!(
                &out_data[start..start + head_stride],
                &data[head_stride..2 * head_stride]
            );
        }
    }

    #[test]
    fn test_repeat_kv_rejects_wrong_rank() {
        // 2-D input is not a valid [num_kv_heads, seq, head_dim] tensor.
        let kv = ferrotorch_core::zeros::<f32>(&[4, 8]).unwrap();
        assert!(repeat_kv(&kv, 2).is_err());
    }

    #[test]
    fn test_gqa_forward_output_shape_preserved() {
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, true).unwrap();
        let input = ferrotorch_core::zeros::<f32>(&[2, 5, 16]).unwrap();
        let out = mha.forward(&input).unwrap();
        assert_eq!(out.shape(), &[2, 5, 16]);
    }

    #[test]
    fn test_gqa_forward_produces_finite_values() {
        let mha = MultiheadAttention::<f64>::with_gqa(8, 4, 2, true).unwrap();
        let input = ferrotorch_core::ones::<f64>(&[1, 3, 8]).unwrap();
        let out = mha.forward(&input).unwrap();
        let data = out.data().unwrap();
        for &v in data {
            assert!(v.is_finite(), "GQA output non-finite: {v}");
        }
    }

    #[test]
    fn test_gqa_forward_decoder_style_single_token() {
        let mha = MultiheadAttention::<f32>::with_gqa(32, 8, 2, false).unwrap();
        let input = ferrotorch_core::ones::<f32>(&[1, 1, 32]).unwrap();
        let out = mha.forward(&input).unwrap();
        assert_eq!(out.shape(), &[1, 1, 32]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_gqa_forward_with_causal_mask() {
        let mha = MultiheadAttention::<f32>::with_gqa(16, 4, 2, false).unwrap();
        let x = ferrotorch_core::ones::<f32>(&[1, 4, 16]).unwrap();
        let out = mha.forward_qkv(&x, &x, &x, true).unwrap();
        assert_eq!(out.shape(), &[1, 4, 16]);
        for &v in out.data().unwrap() {
            assert!(v.is_finite());
        }
    }
}