#[cfg(feature = "cuda")]
use super::{batched_multihead_attention_optimized, GpuResidentTensor};
#[cfg(feature = "cuda")]
use crate::driver::CudaContext;
#[cfg(feature = "cuda")]
use crate::error::Result;
#[cfg(feature = "cuda")]
/// Print the mean and standard deviation of a GPU tensor's contents to stderr.
///
/// Copies the tensor back to host via `peek_host`; silently does nothing if the
/// copy fails or the tensor is empty (an empty tensor would otherwise produce
/// NaN from the divide-by-zero). Debug aid only — enabled by callers behind the
/// `WHISPER_DEBUG_GPU_INTERNALS` env var.
fn debug_gpu_stats(label: &str, tensor: &GpuResidentTensor<f32>) {
    if let Ok(host) = tensor.peek_host() {
        if host.is_empty() {
            return;
        }
        let n = host.len() as f32;
        let mean = host.iter().sum::<f32>() / n;
        // True (population) standard deviation: sqrt(E[(x - mean)^2]).
        // The previous code printed sqrt(E[x^2]), which is the RMS, not the
        // std — misleading for any tensor whose mean is far from zero.
        let variance = host.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
        let std = variance.sqrt();
        eprintln!("[DEBUG-GPU-INTERNAL] {label}: mean={mean:.6}, std={std:.6}");
    }
}
#[cfg(feature = "cuda")]
/// Dump summary statistics for a weight/bias tensor pair to stderr.
///
/// Best-effort debug aid: each tensor is copied back to host independently and
/// skipped silently if `peek_host` fails. Weights report length/mean/max; the
/// bias reports length/mean only.
fn debug_gpu_weight(label: &str, weight: &GpuResidentTensor<f32>, bias: &GpuResidentTensor<f32>) {
    if let Ok(w) = weight.peek_host() {
        let len = w.len();
        let mean = w.iter().sum::<f32>() / len as f32;
        let max = w.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        eprintln!(
            "[DEBUG-GPU-INTERNAL] {label}_w: len={}, mean={:.6}, max={:.6}",
            len, mean, max
        );
    }
    if let Ok(b) = bias.peek_host() {
        let len = b.len();
        let mean = b.iter().sum::<f32>() / len as f32;
        eprintln!(
            "[DEBUG-GPU-INTERNAL] {label}_b: len={}, mean={:.6}",
            len, mean
        );
    }
}
#[cfg(feature = "cuda")]
/// GPU-resident parameters for one encoder transformer block.
///
/// Field names follow the pre-LN transformer layout consumed by
/// `forward_encoder_block_gpu`: LayerNorm 1 -> Q/K/V/O attention projections
/// -> LayerNorm 2 -> feed-forward up/down projections.
pub struct GpuEncoderBlockWeights {
    /// LayerNorm scale applied before self-attention.
    pub ln1_gamma: GpuResidentTensor<f32>,
    /// LayerNorm shift applied before self-attention.
    pub ln1_beta: GpuResidentTensor<f32>,
    /// Query projection weight.
    pub w_q: GpuResidentTensor<f32>,
    /// Query projection bias.
    pub b_q: GpuResidentTensor<f32>,
    /// Key projection weight.
    pub w_k: GpuResidentTensor<f32>,
    /// Key projection bias.
    pub b_k: GpuResidentTensor<f32>,
    /// Value projection weight.
    pub w_v: GpuResidentTensor<f32>,
    /// Value projection bias.
    pub b_v: GpuResidentTensor<f32>,
    /// Attention output projection weight.
    pub w_o: GpuResidentTensor<f32>,
    /// Attention output projection bias.
    pub b_o: GpuResidentTensor<f32>,
    /// LayerNorm scale applied before the feed-forward sub-block.
    pub ln2_gamma: GpuResidentTensor<f32>,
    /// LayerNorm shift applied before the feed-forward sub-block.
    pub ln2_beta: GpuResidentTensor<f32>,
    /// Feed-forward up-projection weight (d_model -> ffn_dim).
    pub ffn_up_w: GpuResidentTensor<f32>,
    /// Feed-forward up-projection bias.
    pub ffn_up_b: GpuResidentTensor<f32>,
    /// Feed-forward down-projection weight (ffn_dim -> d_model).
    pub ffn_down_w: GpuResidentTensor<f32>,
    /// Feed-forward down-projection bias.
    pub ffn_down_b: GpuResidentTensor<f32>,
}
#[cfg(feature = "cuda")]
/// GPU-resident weights for a two-layer convolutional frontend.
///
/// The exact tensor layouts are defined by the conv kernels that consume
/// them (not visible in this file).
pub struct GpuConvFrontendWeights {
    /// First convolution filter weights.
    pub conv1_weight: GpuResidentTensor<f32>,
    /// First convolution bias.
    pub conv1_bias: GpuResidentTensor<f32>,
    /// Second convolution filter weights.
    pub conv2_weight: GpuResidentTensor<f32>,
    /// Second convolution bias.
    pub conv2_bias: GpuResidentTensor<f32>,
}
#[cfg(feature = "cuda")]
/// GPU-resident parameters for one decoder transformer block.
///
/// Pre-LN layout with three sub-blocks, each preceded by its own LayerNorm:
/// masked self-attention (`self_*`), cross-attention over the encoder output
/// (`cross_*`), and the feed-forward network (`ffn_*`).
pub struct GpuDecoderBlockWeights {
    /// LayerNorm scale before self-attention.
    pub ln1_gamma: GpuResidentTensor<f32>,
    /// LayerNorm shift before self-attention.
    pub ln1_beta: GpuResidentTensor<f32>,
    /// Self-attention query projection weight.
    pub self_w_q: GpuResidentTensor<f32>,
    /// Self-attention query projection bias.
    pub self_b_q: GpuResidentTensor<f32>,
    /// Self-attention key projection weight.
    pub self_w_k: GpuResidentTensor<f32>,
    /// Self-attention key projection bias.
    pub self_b_k: GpuResidentTensor<f32>,
    /// Self-attention value projection weight.
    pub self_w_v: GpuResidentTensor<f32>,
    /// Self-attention value projection bias.
    pub self_b_v: GpuResidentTensor<f32>,
    /// Self-attention output projection weight.
    pub self_w_o: GpuResidentTensor<f32>,
    /// Self-attention output projection bias.
    pub self_b_o: GpuResidentTensor<f32>,
    /// LayerNorm scale before cross-attention.
    pub ln2_gamma: GpuResidentTensor<f32>,
    /// LayerNorm shift before cross-attention.
    pub ln2_beta: GpuResidentTensor<f32>,
    /// Cross-attention query projection weight.
    pub cross_w_q: GpuResidentTensor<f32>,
    /// Cross-attention query projection bias.
    pub cross_b_q: GpuResidentTensor<f32>,
    /// Cross-attention key projection weight (applied to encoder output).
    pub cross_w_k: GpuResidentTensor<f32>,
    /// Cross-attention key projection bias.
    pub cross_b_k: GpuResidentTensor<f32>,
    /// Cross-attention value projection weight (applied to encoder output).
    pub cross_w_v: GpuResidentTensor<f32>,
    /// Cross-attention value projection bias.
    pub cross_b_v: GpuResidentTensor<f32>,
    /// Cross-attention output projection weight.
    pub cross_w_o: GpuResidentTensor<f32>,
    /// Cross-attention output projection bias.
    pub cross_b_o: GpuResidentTensor<f32>,
    /// LayerNorm scale before the feed-forward sub-block.
    pub ln3_gamma: GpuResidentTensor<f32>,
    /// LayerNorm shift before the feed-forward sub-block.
    pub ln3_beta: GpuResidentTensor<f32>,
    /// Feed-forward up-projection weight (d_model -> ffn_dim).
    pub ffn_up_w: GpuResidentTensor<f32>,
    /// Feed-forward up-projection bias.
    pub ffn_up_b: GpuResidentTensor<f32>,
    /// Feed-forward down-projection weight (ffn_dim -> d_model).
    pub ffn_down_w: GpuResidentTensor<f32>,
    /// Feed-forward down-projection bias.
    pub ffn_down_b: GpuResidentTensor<f32>,
}
#[cfg(feature = "cuda")]
/// Preallocated GPU key/value cache for incremental decoding.
///
/// The buffers are allocated once at `max_seq_len * d_model` floats each;
/// `seq_len` tracks how many positions are currently valid. Resetting the
/// cache only rewinds `seq_len` — the GPU memory is reused, not reallocated.
pub struct GpuKvCache {
    /// Cached key activations (capacity `max_seq_len * d_model` floats).
    pub key: GpuResidentTensor<f32>,
    /// Cached value activations (same capacity as `key`).
    pub value: GpuResidentTensor<f32>,
    /// Number of positions currently filled (valid prefix of the buffers).
    pub seq_len: usize,
    /// Maximum number of positions the buffers were sized for.
    pub max_seq_len: usize,
    /// Feature width per position.
    pub d_model: usize,
}
#[cfg(feature = "cuda")]
impl GpuKvCache {
    /// Allocate a zero-initialized cache with room for `max_seq_len` positions
    /// of `d_model` floats in each of the key and value buffers.
    ///
    /// # Errors
    /// Propagates any failure from uploading the zeroed buffers to the GPU.
    pub fn new(ctx: &CudaContext, max_seq_len: usize, d_model: usize) -> Result<Self> {
        let capacity = max_seq_len * d_model;
        // One host-side zero buffer is enough; it is uploaded twice.
        let host_zeros = vec![0.0f32; capacity];
        Ok(Self {
            key: GpuResidentTensor::from_host(ctx, &host_zeros)?,
            value: GpuResidentTensor::from_host(ctx, &host_zeros)?,
            seq_len: 0,
            max_seq_len,
            d_model,
        })
    }

    /// Logically clear the cache. The GPU buffers are left as-is and will be
    /// overwritten as new positions are appended.
    pub fn reset(&mut self) {
        self.seq_len = 0;
    }

    /// Number of positions currently held in the cache.
    pub fn len(&self) -> usize {
        self.seq_len
    }

    /// True when the cache holds no positions.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy)]
/// Dimensions describing a GPU decoder stack.
pub struct GpuDecoderConfig {
    /// Model (embedding) width.
    pub d_model: u32,
    /// Number of attention heads; presumably divides `d_model` evenly — the
    /// encoder path computes `head_dim = d_model / n_heads` with integer division.
    pub n_heads: u32,
    /// Hidden width of the feed-forward layer.
    pub ffn_dim: u32,
    /// Maximum decoding sequence length (e.g. for KV-cache sizing) — usage not
    /// visible in this chunk; confirm against the decoder forward pass.
    pub max_seq_len: u32,
    /// Number of decoder blocks in the stack.
    pub n_layers: u32,
}
#[cfg(feature = "cuda")]
#[derive(Debug, Clone, Copy)]
/// Dimensions describing a GPU encoder block, consumed by
/// `forward_encoder_block_gpu` (sequence length is derived from the input
/// tensor, so it is not part of the config).
pub struct GpuEncoderConfig {
    /// Model (embedding) width.
    pub d_model: u32,
    /// Number of attention heads; `head_dim` is computed as `d_model / n_heads`.
    pub n_heads: u32,
    /// Hidden width of the feed-forward layer.
    pub ffn_dim: u32,
}
#[cfg(feature = "cuda")]
/// Run one pre-LN encoder transformer block entirely on the GPU.
///
/// Pipeline: LayerNorm 1 -> Q/K/V projections -> multi-head attention ->
/// output projection -> residual add -> LayerNorm 2 -> fused linear+GELU
/// up-projection -> down-projection -> residual add.
///
/// `x` is treated as a flattened `seq_len x d_model` activation buffer;
/// `seq_len` is derived from `x.len() / d_model`. Setting the
/// `WHISPER_DEBUG_GPU_INTERNALS` env var (to any value) prints intermediate
/// tensor statistics to stderr.
///
/// # Errors
/// Propagates any failure from the underlying GPU kernels.
pub fn forward_encoder_block_gpu(
    ctx: &CudaContext,
    x: &GpuResidentTensor<f32>,
    weights: &GpuEncoderBlockWeights,
    config: &GpuEncoderConfig,
) -> Result<GpuResidentTensor<f32>> {
    let d_model = config.d_model;
    let n_heads = config.n_heads;
    let ffn_dim = config.ffn_dim;
    let head_dim = d_model / n_heads;
    let seq_len = (x.len() / d_model as usize) as u32;

    // Opt-in tracing of intermediate tensors; presence of the var is enough.
    let debug = std::env::var("WHISPER_DEBUG_GPU_INTERNALS").is_ok();
    let trace = |label: &str, t: &GpuResidentTensor<f32>| {
        if debug {
            debug_gpu_stats(label, t);
        }
    };

    // --- Self-attention sub-block (pre-norm) ---
    let normed = x.layer_norm(ctx, &weights.ln1_gamma, &weights.ln1_beta, d_model, seq_len)?;
    trace("LN1 output", &normed);
    if debug {
        debug_gpu_weight("q", &weights.w_q, &weights.b_q);
    }
    let q = normed.linear(ctx, &weights.w_q, Some(&weights.b_q), seq_len, d_model, d_model)?;
    let k = normed.linear(ctx, &weights.w_k, Some(&weights.b_k), seq_len, d_model, d_model)?;
    let v = normed.linear(ctx, &weights.w_v, Some(&weights.b_v), seq_len, d_model, d_model)?;
    trace("Q", &q);
    trace("K", &k);
    trace("V", &v);

    let attn = batched_multihead_attention_optimized(ctx, &q, &k, &v, n_heads, head_dim, seq_len)?;
    trace("attn_out", &attn);
    let projected = attn.linear(ctx, &weights.w_o, Some(&weights.b_o), seq_len, d_model, d_model)?;
    trace("attn_proj", &projected);
    // First residual connection around the attention sub-block.
    let after_attn = x.add(ctx, &projected)?;
    trace("residual1", &after_attn);

    // --- Feed-forward sub-block (pre-norm) ---
    let normed2 =
        after_attn.layer_norm(ctx, &weights.ln2_gamma, &weights.ln2_beta, d_model, seq_len)?;
    trace("LN2 output", &normed2);
    let up =
        normed2.fused_linear_gelu(ctx, &weights.ffn_up_w, &weights.ffn_up_b, seq_len, d_model, ffn_dim)?;
    trace("ffn_gelu (fused)", &up);
    let down =
        up.linear(ctx, &weights.ffn_down_w, Some(&weights.ffn_down_b), seq_len, ffn_dim, d_model)?;
    trace("ffn_down", &down);

    // Second residual connection around the feed-forward sub-block.
    let output = after_attn.add(ctx, &down)?;
    trace("block_output", &output);
    Ok(output)
}