#[cfg(feature = "cuda")]
use crate::driver::{CudaContext, GpuBuffer};
#[cfg(feature = "cuda")]
use crate::error::Result;
#[allow(unused_imports)]
#[cfg(feature = "cuda")]
use crate::kernels::Batched4DGemmKernel;
#[cfg(feature = "cuda")]
use super::super::GpuResidentTensor;
#[cfg(feature = "cuda")]
use super::helpers::{
    batched_gemm, batched_scale_all, batched_softmax_all, batched_to_interleaved_all,
    batched_transpose_all, copy_head_to_output, extract_single_head, interleaved_to_batched_all,
    transpose_matrix,
};
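
/// Multi-head scaled dot-product attention over interleaved
/// `[seq_len, n_heads * head_dim]` tensors, computed one head at a time.
///
/// For each head `h`:
///
/// ```text
/// out_h = softmax(Q_h · K_hᵀ / sqrt(head_dim)) · V_h
/// ```
///
/// Setting the `WHISPER_DEBUG_ATTN` environment variable prints summary
/// statistics for head 0 at each stage.
///
/// # Example
///
/// A minimal sketch; `ctx`, `q`, `k`, and `v` are hypothetical bindings for an
/// initialized context and three uploaded tensors:
///
/// ```ignore
/// // 8 heads of 64 dims over a 1500-frame sequence.
/// let out = batched_multihead_attention(&ctx, &q, &k, &v, 8, 64, 1500)?;
/// assert_eq!(out.len(), 1500 * 8 * 64);
/// ```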
#[cfg(feature = "cuda")]
pub fn batched_multihead_attention(
ctx: &CudaContext,
q: &GpuResidentTensor<f32>,
k: &GpuResidentTensor<f32>,
v: &GpuResidentTensor<f32>,
n_heads: u32,
head_dim: u32,
seq_len: u32,
) -> Result<GpuResidentTensor<f32>> {
let d_model = (n_heads * head_dim) as usize;
let expected_size = (seq_len as usize) * d_model;
if q.len() != expected_size {
return Err(crate::GpuError::InvalidParameter(format!(
"Q has {} elements, expected {} (seq_len={}, d_model={})",
q.len(),
expected_size,
seq_len,
d_model
)));
}
    if k.len() != expected_size || v.len() != expected_size {
        return Err(crate::GpuError::InvalidParameter(format!(
            "K has {} elements and V has {} elements, expected {} (same as Q)",
            k.len(),
            v.len(),
            expected_size
        )));
    }
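    // Standard attention scaling: dividing the raw scores by sqrt(d_k)
    // keeps the softmax logits in a numerically stable range.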
let scale = 1.0 / (head_dim as f32).sqrt();
let output_buffer = GpuBuffer::new(ctx, expected_size)?;
let debug_attn = std::env::var("WHISPER_DEBUG_ATTN").is_ok();
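    // Naive path: run attention for one head at a time and scatter each
    // head's [seq_len, head_dim] result into the interleaved output buffer.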
for h in 0..n_heads {
let out_h = compute_single_head_attention(
ctx, q, k, v, h, seq_len, n_heads, head_dim, scale, debug_attn,
)?;
copy_head_to_output(ctx, &output_buffer, &out_h, h, seq_len, n_heads, head_dim)?;
}
Ok(GpuResidentTensor::from_buffer_internal(output_buffer, 1))
}
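
/// Runs the attention pipeline for a single head: extract the head's Q/K/V
/// slices, form `Q_h · K_hᵀ`, scale by `1/sqrt(head_dim)`, apply a row-wise
/// softmax, and multiply by `V_h` to produce the `[seq_len, head_dim]` output.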
#[cfg(feature = "cuda")]
fn compute_single_head_attention(
ctx: &CudaContext,
q: &GpuResidentTensor<f32>,
k: &GpuResidentTensor<f32>,
v: &GpuResidentTensor<f32>,
h: u32,
seq_len: u32,
n_heads: u32,
head_dim: u32,
scale: f32,
debug_attn: bool,
) -> Result<GpuResidentTensor<f32>> {
let q_h = extract_single_head(ctx, q, h, seq_len, n_heads, head_dim)?;
let k_h = extract_single_head(ctx, k, h, seq_len, n_heads, head_dim)?;
let v_h = extract_single_head(ctx, v, h, seq_len, n_heads, head_dim)?;
maybe_debug_head_inputs(debug_attn, h, &q_h, &k_h, &v_h)?;
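    // Transpose K_h so the GEMM computes Q_h (seq_len x head_dim) times
    // K_hᵀ (head_dim x seq_len), yielding the [seq_len, seq_len] raw scores.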
let kt_h = transpose_matrix(ctx, &k_h.buffer, seq_len, head_dim)?;
let kt_tensor = GpuResidentTensor::from_buffer_internal(kt_h, 1);
let scores_h = q_h.matmul(ctx, &kt_tensor, seq_len, seq_len, head_dim)?;
maybe_debug_scores(debug_attn, h, &scores_h)?;
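    // Scale by 1/sqrt(d_k), then a row-wise softmax over each query's scores
    // turns them into attention weights.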
let scaled_h = scores_h.scale(ctx, scale)?;
let attn_h = scaled_h.softmax(ctx, seq_len)?;
maybe_debug_attention_weights(debug_attn, h, &attn_h, seq_len)?;
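    // Attention-weighted sum of values: [seq_len, seq_len] x [seq_len, head_dim]
    // -> [seq_len, head_dim].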
let out_h = attn_h.matmul(ctx, &v_h, seq_len, head_dim, seq_len)?;
maybe_debug_head_output(debug_attn, h, &out_h)?;
Ok(out_h)
}
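
// Debug helpers: each prints summary statistics for head 0 only, and only when
// WHISPER_DEBUG_ATTN is set; each incurs a device-to-host copy via `peek_host`,
// so the copy is skipped entirely in normal operation.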
#[cfg(feature = "cuda")]
fn maybe_debug_head_inputs(
debug_attn: bool,
h: u32,
q_h: &GpuResidentTensor<f32>,
k_h: &GpuResidentTensor<f32>,
v_h: &GpuResidentTensor<f32>,
) -> Result<()> {
if !(debug_attn && h == 0) {
return Ok(());
}
let q_host = q_h.peek_host()?;
let k_host = k_h.peek_host()?;
let v_host = v_h.peek_host()?;
eprintln!(
"[DEBUG-ATTN] head 0: Q_h mean={:.6}, K_h mean={:.6}, V_h mean={:.6}",
q_host.iter().sum::<f32>() / q_host.len() as f32,
k_host.iter().sum::<f32>() / k_host.len() as f32,
v_host.iter().sum::<f32>() / v_host.len() as f32
);
Ok(())
}
#[cfg(feature = "cuda")]
fn maybe_debug_scores(debug_attn: bool, h: u32, scores_h: &GpuResidentTensor<f32>) -> Result<()> {
if !(debug_attn && h == 0) {
return Ok(());
}
let scores_host = scores_h.peek_host()?;
eprintln!(
"[DEBUG-ATTN] head 0: scores mean={:.6}, max={:.6}",
scores_host.iter().sum::<f32>() / scores_host.len() as f32,
scores_host.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
);
Ok(())
}
#[cfg(feature = "cuda")]
fn maybe_debug_attention_weights(
debug_attn: bool,
h: u32,
attn_h: &GpuResidentTensor<f32>,
seq_len: u32,
) -> Result<()> {
if !(debug_attn && h == 0) {
return Ok(());
}
let attn_host = attn_h.peek_host()?;
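    // Softmax rows should each sum to ~1.0, so the first row is a quick
    // sanity check on the kernel.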
let first_row_sum: f32 = attn_host[..seq_len as usize].iter().sum();
eprintln!(
"[DEBUG-ATTN] head 0: attn first_row_sum={:.6}, mean={:.6}",
first_row_sum,
attn_host.iter().sum::<f32>() / attn_host.len() as f32
);
Ok(())
}
#[cfg(feature = "cuda")]
fn maybe_debug_head_output(debug_attn: bool, h: u32, out_h: &GpuResidentTensor<f32>) -> Result<()> {
if !(debug_attn && h == 0) {
return Ok(());
}
    let out_host = out_h.peek_host()?;
    let mean = out_host.iter().sum::<f32>() / out_host.len() as f32;
    // Population standard deviation; sqrt(E[x^2]) alone would be the RMS,
    // not the std, whenever the mean is nonzero.
    let std = (out_host.iter().map(|v| (v - mean).powi(2)).sum::<f32>()
        / out_host.len() as f32)
        .sqrt();
    eprintln!("[DEBUG-ATTN] head 0: out mean={:.6}, std={:.6}", mean, std);
Ok(())
}
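
/// Multi-head attention with every head batched into single kernel launches.
///
/// Computes the same result as [`batched_multihead_attention`], but reshapes
/// Q/K/V from interleaved `[seq_len, n_heads * head_dim]` into batched
/// `[n_heads, seq_len, head_dim]` buffers so that the transpose, GEMM, scale,
/// and softmax each run once over all heads instead of once per head.
///
/// # Example
///
/// A hypothetical parity check against the naive path (`ctx`, `q`, `k`, `v`
/// assumed initialized as above):
///
/// ```ignore
/// let naive = batched_multihead_attention(&ctx, &q, &k, &v, 8, 64, 1500)?;
/// let fast = batched_multihead_attention_optimized(&ctx, &q, &k, &v, 8, 64, 1500)?;
/// let (a, b) = (naive.peek_host()?, fast.peek_host()?);
/// assert!(a.iter().zip(&b).all(|(x, y)| (x - y).abs() < 1e-4));
/// ```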
#[cfg(feature = "cuda")]
pub fn batched_multihead_attention_optimized(
ctx: &CudaContext,
q: &GpuResidentTensor<f32>,
k: &GpuResidentTensor<f32>,
v: &GpuResidentTensor<f32>,
n_heads: u32,
head_dim: u32,
seq_len: u32,
) -> Result<GpuResidentTensor<f32>> {
let d_model = (n_heads * head_dim) as usize;
let expected_size = (seq_len as usize) * d_model;
    if q.len() != expected_size || k.len() != expected_size || v.len() != expected_size {
        return Err(crate::GpuError::InvalidParameter(format!(
            "Q/K/V size mismatch: got {}/{}/{} elements, expected {} (seq_len={}, d_model={})",
            q.len(),
            k.len(),
            v.len(),
            expected_size,
            seq_len,
            d_model
        )));
    }
let scale = 1.0 / (head_dim as f32).sqrt();
let batch = n_heads;
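    // Convert interleaved [seq_len, n_heads * head_dim] Q/K/V into per-head
    // batched [n_heads, seq_len, head_dim] layout for the batched kernels.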
let q_batched = interleaved_to_batched_all(ctx, q, seq_len, n_heads, head_dim)?;
let k_batched = interleaved_to_batched_all(ctx, k, seq_len, n_heads, head_dim)?;
let v_batched = interleaved_to_batched_all(ctx, v, seq_len, n_heads, head_dim)?;
let kt_batched = batched_transpose_all(ctx, &k_batched, batch, seq_len, head_dim)?;
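    // One batched GEMM computes every head's raw scores at once:
    // [batch, seq_len, head_dim] x [batch, head_dim, seq_len] -> [batch, seq_len, seq_len].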
let scores = batched_gemm(ctx, &q_batched, &kt_batched, batch, seq_len, seq_len, head_dim)?;
let total_scores = batch * seq_len * seq_len;
let scaled_scores = batched_scale_all(ctx, &scores, scale, total_scores)?;
let attn = batched_softmax_all(ctx, &scaled_scores, batch * seq_len, seq_len)?;
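    // Second batched GEMM applies the attention weights to the values:
    // [batch, seq_len, seq_len] x [batch, seq_len, head_dim] -> [batch, seq_len, head_dim].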
let out_batched = batched_gemm(ctx, &attn, &v_batched, batch, seq_len, head_dim, seq_len)?;
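    // Pack the per-head outputs back into the interleaved [seq_len, d_model]
    // layout expected by downstream layers.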
let output = batched_to_interleaved_all(ctx, &out_batched, seq_len, n_heads, head_dim)?;
Ok(output)
}