#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
pub mod gpu {
    use burn::prelude::*;
    use burn_cubecl::tensor::CubeTensor;
    use burn_cubecl::CubeRuntime;

    /// Computes `softmax(q · kᵀ, dim = 3) · v`, processing the query sequence in
    /// chunks of at most `tile_size` rows to bound the size of the score tensor.
    ///
    /// Only the query axis (dim 2) is tiled. Each chunk still attends over the
    /// full `k`/`v`, so a chunk's score tensor covers `tile_len` query rows
    /// instead of all `n`. The chunk outputs are concatenated back along dim 2.
    /// Despite the name, this is composed from ordinary tensor ops (`matmul`,
    /// `softmax`, `narrow`, `cat`), not a single fused kernel.
    ///
    /// No `1/sqrt(d_head)` scaling and no mask are applied. Callers that want
    /// scaled dot-product attention must pre-scale `q` (or `k`) themselves.
    ///
    /// # Arguments
    /// * `q` — queries; dims are read as `[batch, heads, seq, d_head]`.
    /// * `k` — keys; its last two dims are transposed before the `q · kᵀ` matmul.
    /// * `v` — values, multiplied on the right of the attention weights.
    /// * `tile_size` — maximum number of query rows processed per chunk.
    ///
    /// # Panics
    /// Panics if `tile_size` is zero. Burn's own shape checks may also panic if
    /// `q`, `k` and `v` have incompatible shapes for the matmuls.
    pub fn fused_attention<B, R>(
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
        tile_size: usize,
    ) -> Tensor<B, 4>
    where
        B: Backend<FloatTensorPrimitive = CubeTensor<R>>,
        R: CubeRuntime,
    {
        assert!(
            tile_size > 0,
            "fused_attention: tile_size must be greater than zero"
        );

        let [_, _, n, _] = q.dims();
        // Transpose once; every chunk reuses the same transposed keys.
        let k_t = k.transpose();

        // One chunk covers the whole sequence: skip narrow/cat entirely. This also
        // avoids calling `Tensor::cat` on an empty list when `n == 0`.
        if n <= tile_size {
            return attend(q, k_t, v);
        }

        let tiles: Vec<Tensor<B, 4>> = (0..n)
            .step_by(tile_size)
            .map(|offset| {
                // The last chunk may be shorter than `tile_size`.
                let tile_len = tile_size.min(n - offset);
                let q_tile = q.clone().narrow(2, offset, tile_len);
                attend(q_tile, k_t.clone(), v.clone())
            })
            .collect();

        Tensor::cat(tiles, 2)
    }

    /// Dense attention for one query chunk: `softmax(q_tile · k_t, dim = 3) · v`.
    fn attend<B: Backend>(
        q_tile: Tensor<B, 4>,
        k_t: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let scores = q_tile.matmul(k_t);
        let weights = burn::tensor::activation::softmax(scores, 3);
        weights.matmul(v)
    }
}