use scirs2_core::ndarray::{Array3, ArrayView2, ArrayView3};
use scirs2_core::numeric::{Float, NumAssignOps, Zero};
use super::utils::{attention, AttentionConfig, AttentionMask};
use crate::error::{LinalgError, LinalgResult};
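/// Multi-head attention: projects `query`/`key`/`value` with `wq`/`wk`/`wv`,
/// splits the result into `config.num_heads` heads of width `config.head_dim`,
/// runs scaled dot-product attention per head, then concatenates the heads and
/// applies the output projection `wo`.
///
/// Shapes: `query` is `[batch, seq_q, d_model]`, `key` and `value` are
/// `[batch, seq_k, d_model]` and `[batch, seq_v, d_model]`, and all four weight
/// matrices are `[d_model, d_model]`, where `d_model == num_heads * head_dim`.
/// If `config.scale` is `None`, the conventional `1 / sqrt(head_dim)` scaling
/// is used.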
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn multi_head_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
wq: &ArrayView2<F>,
wk: &ArrayView2<F>,
wv: &ArrayView2<F>,
wo: &ArrayView2<F>,
mask: Option<&AttentionMask>,
config: &AttentionConfig,
) -> LinalgResult<Array3<F>>
where
    F: Float + NumAssignOps + Zero + std::fmt::Debug,
{
let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
let seq_len_k = key.shape()[1];
let seq_len_v = value.shape()[1];
if key.shape()[2] != d_model || value.shape()[2] != d_model {
        return Err(LinalgError::DimensionError(format!(
            "Model dimensions must match: query={}, key={}, value={}",
            d_model,
            key.shape()[2],
            value.shape()[2]
        )));
}
if wq.shape() != [d_model, d_model]
|| wk.shape() != [d_model, d_model]
|| wv.shape() != [d_model, d_model]
|| wo.shape() != [d_model, d_model]
{
return Err(LinalgError::DimensionError(
"Weight matrices must have shape [d_model, d_model]".to_string(),
));
}
let num_heads = config.num_heads;
let head_dim = config.head_dim;
let scale = match config.scale {
Some(s) => F::from(s).ok_or_else(|| {
LinalgError::ValueError("Failed to convert scale to target type".to_string())
})?,
        None => {
            // head_dim is a usize, so only zero is invalid here.
            if head_dim == 0 {
                return Err(LinalgError::ValueError(
                    "Head dimension must be positive".to_string(),
                ));
            }
            let default_scale = 1.0 / (head_dim as f64).sqrt();
F::from(default_scale).ok_or_else(|| {
LinalgError::ValueError(
"Failed to convert default scale to target type".to_string(),
)
})?
}
};
if d_model != num_heads * head_dim {
return Err(LinalgError::ValueError(format!(
"Model dimension ({d_model}) must equal num_heads ({num_heads}) * head_dim ({head_dim})"
)));
}
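    // Linear projections: Q = query · Wq, K = key · Wk, V = value · Wv,
    // computed with explicit loops (O(seq_len * d_model^2) per batch).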
let mut q_proj = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
let mut k_proj = Array3::<F>::zeros((batchsize, seq_len_k, d_model));
let mut v_proj = Array3::<F>::zeros((batchsize, seq_len_v, d_model));
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += query[[b, i, k]] * wq[[k, j]];
}
q_proj[[b, i, j]] = sum;
}
}
for i in 0..seq_len_k {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += key[[b, i, k]] * wk[[k, j]];
}
k_proj[[b, i, j]] = sum;
}
}
for i in 0..seq_len_v {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += value[[b, i, k]] * wv[[k, j]];
}
v_proj[[b, i, j]] = sum;
}
}
}
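    // Slice each head's channels out of the projections and run scaled
    // dot-product attention on that head independently.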
let mut head_outputs = Vec::with_capacity(num_heads);
for h in 0..num_heads {
        let start_idx = h * head_dim;
        let end_idx = start_idx + head_dim;
        let q_head = q_proj.slice(scirs2_core::ndarray::s![.., .., start_idx..end_idx]);
        let k_head = k_proj.slice(scirs2_core::ndarray::s![.., .., start_idx..end_idx]);
        let v_head = v_proj.slice(scirs2_core::ndarray::s![.., .., start_idx..end_idx]);
let head_output = attention(&q_head, &k_head, &v_head, mask, scale)?;
head_outputs.push(head_output);
}
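    // Concatenate the per-head outputs back into a [batch, seq_q, d_model] tensor.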
let mut concat_output = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
    for (h, head_output) in head_outputs.iter().enumerate() {
        let start_idx = h * head_dim;
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..head_dim {
concat_output[[b, i, start_idx + j]] = head_output[[b, i, j]];
}
}
}
}
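    // Final linear projection: output = concat · Wo.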
let mut output = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += concat_output[[b, i, k]] * wo[[k, j]];
}
output[[b, i, j]] = sum;
}
}
}
Ok(output)
}
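/// Grouped-query attention (GQA): `num_heads` query heads share
/// `num_kv_heads` key/value heads, so every group of
/// `num_heads / num_kv_heads` query heads attends against the same KV slice.
/// This shrinks the K/V projections from `[d_model, d_model]` to
/// `[d_model, num_kv_heads * head_dim]`. With `num_kv_heads == num_heads`
/// this reduces to standard multi-head attention; with `num_kv_heads == 1`
/// it is multi-query attention.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore`; the identity/zero weights are
/// placeholders chosen only to show the expected shapes):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array2, Array3};
///
/// let (b, s, d) = (1, 4, 8);              // batch, sequence length, d_model
/// let (num_heads, num_kv_heads) = (4, 2); // head_dim = 2, kv_dim = 4
/// let head_dim = d / num_heads;
/// let q = Array3::<f64>::zeros((b, s, d));
/// let wq = Array2::<f64>::eye(d);
/// let wk = Array2::<f64>::zeros((d, num_kv_heads * head_dim));
/// let scale = 1.0 / (head_dim as f64).sqrt();
/// let out = grouped_query_attention(
///     &q.view(), &q.view(), &q.view(),
///     &wq.view(), &wk.view(), &wk.view(), &wq.view(),
///     None, num_heads, num_kv_heads, scale,
/// )?;
/// assert_eq!(out.shape(), &[b, s, d]);
/// ```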
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn grouped_query_attention<F>(
query: &ArrayView3<F>,
key: &ArrayView3<F>,
value: &ArrayView3<F>,
wq: &ArrayView2<F>,
wk: &ArrayView2<F>,
wv: &ArrayView2<F>,
wo: &ArrayView2<F>,
mask: Option<&AttentionMask>,
num_heads: usize,
num_kv_heads: usize,
scale: F,
) -> LinalgResult<Array3<F>>
where
    F: Float + NumAssignOps + Zero + std::fmt::Debug,
{
    let (batchsize, seq_len_q, d_model) = (query.shape()[0], query.shape()[1], query.shape()[2]);
    let seq_len_k = key.shape()[1];
    if key.shape()[2] != d_model || value.shape()[2] != d_model || value.shape()[1] != seq_len_k {
        return Err(LinalgError::DimensionError(format!(
            "Key/value shapes must be [batch, seq_k, {}]; got key {:?}, value {:?}",
            d_model,
            key.shape(),
            value.shape()
        )));
    }
if !num_heads.is_multiple_of(num_kv_heads) {
return Err(LinalgError::ValueError(format!(
"Number of query heads ({num_heads}) must be divisible by number of KV heads ({num_kv_heads})"
)));
}
    if num_heads == 0 || !d_model.is_multiple_of(num_heads) {
        return Err(LinalgError::ValueError(format!(
            "Model dimension ({d_model}) must be divisible by the number of heads ({num_heads})"
        )));
    }
    let heads_per_kv = num_heads / num_kv_heads;
    let head_dim = d_model / num_heads;
let kv_dim = num_kv_heads * head_dim;
if wq.shape() != [d_model, d_model]
|| wk.shape() != [d_model, kv_dim]
|| wv.shape() != [d_model, kv_dim]
|| wo.shape() != [d_model, d_model]
{
        return Err(LinalgError::DimensionError(format!(
            "Weight matrices must have shapes [{d_model}, {d_model}] for wq/wo and [{d_model}, {kv_dim}] for wk/wv"
        )));
}
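    // Linear projections; K and V project into the narrower kv_dim, since
    // the KV heads are shared across query-head groups.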
let mut q_proj = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
let mut k_proj = Array3::<F>::zeros((batchsize, seq_len_k, kv_dim));
let mut v_proj = Array3::<F>::zeros((batchsize, seq_len_k, kv_dim));
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += query[[b, i, k]] * wq[[k, j]];
}
q_proj[[b, i, j]] = sum;
}
}
for i in 0..seq_len_k {
for j in 0..kv_dim {
let mut sum = F::zero();
for k in 0..d_model {
sum += key[[b, i, k]] * wk[[k, j]];
}
k_proj[[b, i, j]] = sum;
}
}
for i in 0..seq_len_k {
for j in 0..kv_dim {
let mut sum = F::zero();
for k in 0..d_model {
sum += value[[b, i, k]] * wv[[k, j]];
}
v_proj[[b, i, j]] = sum;
}
}
}
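    // Each query head attends against its group's shared KV slice and writes
    // its output directly into the concatenated buffer.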
let mut concat_output = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
for h in 0..num_heads {
let kv_head_idx = h / heads_per_kv;
let q_start = h * head_dim;
let q_end = q_start + head_dim;
let kv_start = kv_head_idx * head_dim;
let kv_end = kv_start + head_dim;
let q_head = q_proj.slice(scirs2_core::ndarray::s![.., .., q_start..q_end]);
let k_head = k_proj.slice(scirs2_core::ndarray::s![.., .., kv_start..kv_end]);
let v_head = v_proj.slice(scirs2_core::ndarray::s![.., .., kv_start..kv_end]);
let head_output = attention(&q_head, &k_head, &v_head, mask, scale)?;
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..head_dim {
concat_output[[b, i, q_start + j]] = head_output[[b, i, j]];
}
}
}
}
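    // Final linear projection: output = concat · Wo.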
let mut output = Array3::<F>::zeros((batchsize, seq_len_q, d_model));
for b in 0..batchsize {
for i in 0..seq_len_q {
for j in 0..d_model {
let mut sum = F::zero();
for k in 0..d_model {
sum += concat_output[[b, i, k]] * wo[[k, j]];
}
output[[b, i, j]] = sum;
}
}
}
Ok(output)
}