use anyhow::{Result, ensure};
use burn::module::Module;
use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
use burn::tensor::Tensor;
use burn::tensor::activation;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn::tensor::backend::AutodiffBackend;
use burn::tensor::backend::Backend;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn::tensor::{DType, Shape, TensorPrimitive};
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_autodiff::Autodiff;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_autodiff::checkpoint::strategy::NoCheckpointing;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_autodiff::grads::Gradients;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_autodiff::ops::{Backward, Ops, OpsKind};
#[cfg(feature = "cuda")]
use burn_cubecl::cubecl::cuda::CudaRuntime;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::cubecl::std::tensor::layout::linear::LinearView;
#[cfg(feature = "wgpu-kernel")]
use burn_cubecl::cubecl::wgpu::WgpuRuntime;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::cubecl::{self, calculate_cube_count_elemwise, prelude::*};
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::kernel::into_contiguous;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::ops::numeric::zeros_client;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::tensor::CubeTensor;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn_cubecl::{BoolElement, CubeRuntime};
use serde::{Deserialize, Serialize};
use std::any::TypeId;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use std::{any::Any, marker::PhantomData};
// Workgroup width constant (128) for the pair-wise backward kernels.
// NOTE(review): presumably the X dimension of the launch cube for
// `dsa_attention_backward_pair_kernel` / `dsa_attention_backward_cached_pair_kernel`;
// the launch site is not visible from this part of the file — confirm there.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
const DSA_PAIR_BACKWARD_WORKGROUP_X: u32 = 128;
/// Result bundle returned by the cached DSA forward pass.
///
/// Only `output` is public; the remaining fields are private to this module.
/// NOTE(review): `weights` and `topk` are presumably retained so a later backward
/// pass can reuse them — confirm against the code that consumes this struct.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[derive(Debug)]
pub struct DsaAttentionCachedForwardOutput<R: CubeRuntime, BT: BoolElement> {
/// Rank-3 attention output tensor on the `f32`/`i32` cube backend for runtime `R`.
pub output: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
// Raw cube tensor of attention weights.
// NOTE(review): presumably the buffer filled by `dsa_attention_forward_cached_kernel` — confirm.
weights: CubeTensor<R>,
// Top-k value associated with this forward pass.
// NOTE(review): whether this is the clamped or the requested value is not visible here.
topk: usize,
}
// ---- CUDA kernel backends, one alias per float element type ----

/// CUDA cube backend using `f32` floats, `i32` ints and a `u8` bool element.
#[cfg(feature = "cuda")]
pub type DsaCudaKernelBackend = burn_cubecl::CubeBackend<CudaRuntime, f32, i32, u8>;
/// CUDA cube backend using `f16` floats.
#[cfg(feature = "cuda")]
pub type DsaCudaF16KernelBackend =
burn_cubecl::CubeBackend<CudaRuntime, burn::tensor::f16, i32, u8>;
/// CUDA cube backend using cubecl's `flex32` float element.
#[cfg(feature = "cuda")]
pub type DsaCudaFlex32KernelBackend =
burn_cubecl::CubeBackend<CudaRuntime, burn_cubecl::cubecl::flex32, i32, u8>;
/// CUDA cube backend using `bf16` floats.
#[cfg(feature = "cuda")]
pub type DsaCudaBf16KernelBackend =
burn_cubecl::CubeBackend<CudaRuntime, burn::tensor::bf16, i32, u8>;
// Float tensor primitive of the f32 CUDA kernel backend.
#[cfg(feature = "cuda")]
type DsaCudaKernelTensor = burn::tensor::ops::FloatTensor<DsaCudaKernelBackend>;

// ---- WGPU kernel backend ----

/// WGPU cube backend using `f32` floats, `i32` ints and a `u32` bool element.
#[cfg(feature = "wgpu-kernel")]
pub type DsaWgpuKernelBackend = burn_cubecl::CubeBackend<WgpuRuntime, f32, i32, u32>;
// Float tensor primitive of the WGPU kernel backend.
#[cfg(feature = "wgpu-kernel")]
type DsaWgpuKernelTensor = burn::tensor::ops::FloatTensor<DsaWgpuKernelBackend>;

// ---- Autodiff-wrapped variants of the f32 backends and their float tensor primitives ----

#[cfg(feature = "cuda")]
type DsaCudaAutodiffBackend = Autodiff<DsaCudaKernelBackend>;
#[cfg(feature = "cuda")]
type DsaCudaAutodiffTensor = burn::tensor::ops::FloatTensor<DsaCudaAutodiffBackend>;
#[cfg(feature = "wgpu-kernel")]
type DsaWgpuAutodiffBackend = Autodiff<DsaWgpuKernelBackend>;
#[cfg(feature = "wgpu-kernel")]
type DsaWgpuAutodiffTensor = burn::tensor::ops::FloatTensor<DsaWgpuAutodiffBackend>;
/// Selects which implementation executes DSA attention.
///
/// Variants are (de)serialized in `snake_case`, i.e. `"reference"`, `"cuda_kernel"`
/// and `"wgpu_kernel"`. The default is [`DsaExecutor::Reference`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DsaExecutor {
    /// Reference implementation; this is the default executor.
    #[default]
    Reference,
    /// Custom kernel executor named for the CUDA runtime.
    CudaKernel,
    /// Custom kernel executor named for the WGPU runtime.
    WgpuKernel,
}
/// Returns `true` only when `topk` is strictly smaller than `context_len`.
///
/// When `topk` is greater than or equal to `context_len` the function returns
/// `false`.
fn should_use_custom_dsa_attention(topk: usize, context_len: usize) -> bool {
    context_len > topk
}
/// Fills `topk_mask` with ones: every position below `elements` is marked as selected.
///
/// One invocation handles one flat position; invocations at or beyond `elements`
/// terminate without writing.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_full_mask_kernel(topk_mask: &mut LinearView<f32, ReadWrite>, elements: usize) {
    let position = ABSOLUTE_POS;
    if position >= elements {
        terminate!();
    }
    let selected = f32::cast_from(1u32);
    topk_mask[position] = selected;
}
/// Builds a 0/1 top-k mask from selector scores, one invocation per score.
///
/// `selector` and `topk_mask` are addressed with the same flat index, where the
/// innermost `context_len` entries form one row. For its own entry, each invocation
/// counts how many entries of the same row outrank it: a strictly larger score
/// outranks it, and an equal score outranks it only when it sits at a lower index
/// (so ties are broken in favour of earlier positions). The entry is written as `1.0`
/// when fewer than `topk` entries outrank it, otherwise `0.0`.
///
/// Invocations at or beyond `elements` terminate without writing.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_topk_mask_kernel(
    selector: &LinearView<f32>,
    topk_mask: &mut LinearView<f32, ReadWrite>,
    elements: usize,
    query_len: usize,
    context_len: usize,
    topk: usize,
) {
    let idx = ABSOLUTE_POS;
    if idx >= elements {
        terminate!();
    }
    // Position inside the (query_len x context_len) plane, then inside its row.
    let qk = idx % (query_len * context_len);
    let k_index = qk % context_len;
    // Flat index of the first entry of this row; invariant across the scan below.
    let row_start = idx - k_index;
    let selector_score = selector[idx];
    let mut rank = 0usize;
    let mut other = 0usize;
    while other < context_len {
        let other_score = selector[row_start + other];
        if other_score > selector_score || (other_score == selector_score && other < k_index) {
            rank += 1usize;
        }
        other += 1usize;
    }
    let one = f32::cast_from(1u32);
    let zero = f32::cast_from(0u32);
    topk_mask[idx] = if rank < topk { one } else { zero };
}
/// Computes masked softmax attention, one invocation per output element.
///
/// The flat position is decomposed as `[batch, query_len, heads * value_dim]`. The
/// index arithmetic below reads `query` as `[batch, heads, query_len, qk_dim]`, `key`
/// as `[batch, heads, context_len, qk_dim]`, `value` as
/// `[batch, heads, context_len, value_dim]` and `topk_mask` as
/// `[batch, query_len, context_len]`, all as flat row-major buffers.
///
/// Only keys whose mask entry is greater than zero take part. `topk` is accepted but
/// not used by the kernel body. Invocations at or beyond `elements` terminate early.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_attention_forward_kernel(
query: &LinearView<f32>,
key: &LinearView<f32>,
value: &LinearView<f32>,
topk_mask: &LinearView<f32>,
output: &mut LinearView<f32, ReadWrite>,
elements: usize,
query_len: usize,
context_len: usize,
heads: usize,
qk_dim: usize,
value_dim: usize,
topk: usize,
) {
let idx = ABSOLUTE_POS;
if idx >= elements {
terminate!();
}
// Decompose the flat output index into (batch b, query row, head h, value channel d).
let value_plane = heads * value_dim;
let b = idx / (query_len * value_plane);
let local = idx % (query_len * value_plane);
let q_index = local / value_plane;
let head_value = local % value_plane;
let h = head_value / value_dim;
let d = head_value % value_dim;
let _ = topk;
let zero = f32::cast_from(0u32);
// NOTE: despite its name this holds sqrt(qk_dim); logits are divided by it.
let inv_scale = f32::sqrt(f32::cast_from(qk_dim as u32));
// Starting value (-1e9) for the running maximum of the scaled logits.
let mut max_logit = f32::cast_from(-1000000000i32);
// Pass 1: find the largest scaled logit among the selected keys.
let mut k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut qd = 0usize;
while qd < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
dot += query[q_offset] * key[k_offset];
qd += 1usize;
}
let logit = dot / inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_index += 1usize;
}
// Pass 2: recompute each selected logit, shift it by the maximum before exponentiating,
// and accumulate both the softmax denominator and the weighted value channel `d`.
let mut denom = zero;
let mut acc = zero;
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut qd = 0usize;
while qd < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
dot += query[q_offset] * key[k_offset];
qd += 1usize;
}
let weight = f32::exp((dot / inv_scale) - max_logit);
let value_offset = (((b * heads + h) * context_len + k_index) * value_dim) + d;
denom += weight;
acc += weight * value[value_offset];
}
k_index += 1usize;
}
// If no key was selected the denominator stays zero and the output is written as zero.
output[idx] = if denom > zero { acc / denom } else { zero };
}
/// Masked softmax attention forward pass that additionally stores the attention weights.
///
/// Output indexing and the two softmax passes mirror `dsa_attention_forward_kernel`:
/// one invocation per element of an output indexed as
/// `[batch, query_len, heads * value_dim]`. In addition, the invocation whose value
/// channel `d` is zero writes the row of `weights` for its `(b, h, q_index)`, indexed
/// as `[batch, heads, query_len, context_len]`: first the unnormalized exponentials,
/// then (in a final sweep) the values divided by the softmax denominator. Masked
/// entries, and whole rows with a zero denominator, are written as zero.
///
/// `topk` is accepted but not used. Invocations at or beyond `elements` terminate early.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_attention_forward_cached_kernel(
query: &LinearView<f32>,
key: &LinearView<f32>,
value: &LinearView<f32>,
topk_mask: &LinearView<f32>,
output: &mut LinearView<f32, ReadWrite>,
weights: &mut LinearView<f32, ReadWrite>,
elements: usize,
query_len: usize,
context_len: usize,
heads: usize,
qk_dim: usize,
value_dim: usize,
topk: usize,
) {
let idx = ABSOLUTE_POS;
if idx >= elements {
terminate!();
}
// Decompose the flat output index into (batch b, query row, head h, value channel d).
let value_plane = heads * value_dim;
let b = idx / (query_len * value_plane);
let local = idx % (query_len * value_plane);
let q_index = local / value_plane;
let head_value = local % value_plane;
let h = head_value / value_dim;
let d = head_value % value_dim;
let _ = topk;
let zero = f32::cast_from(0u32);
// NOTE: despite its name this holds sqrt(qk_dim); logits are divided by it.
let inv_scale = f32::sqrt(f32::cast_from(qk_dim as u32));
// Starting value (-1e9) for the running maximum of the scaled logits.
let mut max_logit = f32::cast_from(-1000000000i32);
// Pass 1: largest scaled logit among the selected keys.
let mut k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut qd = 0usize;
while qd < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
dot += query[q_offset] * key[k_offset];
qd += 1usize;
}
let logit = dot / inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_index += 1usize;
}
// Pass 2: accumulate denominator and weighted value; the d == 0 invocation also
// records the unnormalized weight (zero for masked keys) for every key position.
let mut denom = zero;
let mut acc = zero;
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
let mut unnormalized = zero;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut qd = 0usize;
while qd < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
dot += query[q_offset] * key[k_offset];
qd += 1usize;
}
let weight = f32::exp((dot / inv_scale) - max_logit);
let value_offset = (((b * heads + h) * context_len + k_index) * value_dim) + d;
denom += weight;
acc += weight * value[value_offset];
unnormalized = weight;
}
if d == 0usize {
let weight_index = (((b * heads + h) * query_len + q_index) * context_len) + k_index;
weights[weight_index] = unnormalized;
}
k_index += 1usize;
}
if denom > zero {
output[idx] = acc / denom;
// Normalize the stored weights in place; this invocation wrote them just above.
if d == 0usize {
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
let weight_index =
(((b * heads + h) * query_len + q_index) * context_len) + k_index;
if topk_mask[mask_index] > zero {
weights[weight_index] = weights[weight_index] / denom;
} else {
weights[weight_index] = zero;
}
k_index += 1usize;
}
}
} else {
// No selected key: output and the whole weight row are zero.
output[idx] = zero;
if d == 0usize {
k_index = 0usize;
while k_index < context_len {
let weight_index =
(((b * heads + h) * query_len + q_index) * context_len) + k_index;
weights[weight_index] = zero;
k_index += 1usize;
}
}
}
}
/// Backward pass of masked softmax attention, one invocation per gradient element.
///
/// The flat position enumerates, per batch, first the `grad_query` elements
/// (`heads * query_len * qk_dim`), then the `grad_key` elements
/// (`heads * context_len * qk_dim`), then the `grad_value` elements
/// (`heads * context_len * value_dim`). Each invocation recomputes the softmax
/// statistics it needs from `query`, `key` and `topk_mask` and writes exactly one
/// gradient element with a plain store.
///
/// The index arithmetic reads `query`/`key`/`value` as `[batch, heads, len, dim]`,
/// `topk_mask` as `[batch, query_len, context_len]`, and `output`/`grad_output` as
/// `[batch, query_len, heads * value_dim]`, all as flat row-major buffers.
/// `topk` is accepted but not used. Invocations at or beyond `elements` terminate early.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_attention_backward_kernel(
query: &LinearView<f32>,
key: &LinearView<f32>,
value: &LinearView<f32>,
topk_mask: &LinearView<f32>,
output: &LinearView<f32>,
grad_output: &LinearView<f32>,
grad_query: &mut LinearView<f32, ReadWrite>,
grad_key: &mut LinearView<f32, ReadWrite>,
grad_value: &mut LinearView<f32, ReadWrite>,
elements: usize,
query_len: usize,
context_len: usize,
heads: usize,
qk_dim: usize,
value_dim: usize,
topk: usize,
) {
let idx = ABSOLUTE_POS;
if idx >= elements {
terminate!();
}
// Split the flat index into a batch and a position within that batch's
// [query grads | key grads | value grads] range.
let query_elements = heads * query_len * qk_dim;
let key_elements = heads * context_len * qk_dim;
let value_elements = heads * context_len * value_dim;
let per_batch = query_elements + key_elements + value_elements;
let b = idx / per_batch;
let local = idx % per_batch;
let _ = topk;
let zero = f32::cast_from(0u32);
// NOTE: despite its name this holds sqrt(qk_dim); logits and gradients are divided by it.
let inv_scale = f32::sqrt(f32::cast_from(qk_dim as u32));
if local < query_elements {
// ---- grad_query[b, h, q_index, qd] ----
let hqd = local;
let h = hqd / (query_len * qk_dim);
let q_local = hqd % (query_len * qk_dim);
let q_index = q_local / qk_dim;
let qd = q_local % qk_dim;
let mut acc = zero;
let mut max_logit = f32::cast_from(-1000000000i32);
// Largest scaled logit over the selected keys of this query row.
let mut k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let logit = dot / inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_index += 1usize;
}
// Softmax denominator over the selected keys.
let mut denom = zero;
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
denom += f32::exp((dot / inv_scale) - max_logit);
}
k_index += 1usize;
}
// Sum over selected keys: weight * (grad_output . (value[k] - output)) * key[k, qd] / sqrt(qk_dim).
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero && denom > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let weight = f32::exp((dot / inv_scale) - max_logit) / denom;
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset =
(b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
let value_offset = (((b * heads + h) * context_len + k_index) * value_dim) + vd;
value_dot +=
grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
acc += weight * value_dot * key[k_offset] / inv_scale;
}
k_index += 1usize;
}
let grad_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
grad_query[grad_offset] = acc;
} else if local < query_elements + key_elements {
// ---- grad_key[b, h, k_index_target, qd] ----
let hkd = local - query_elements;
let h = hkd / (context_len * qk_dim);
let k_local = hkd % (context_len * qk_dim);
let k_index_target = k_local / qk_dim;
let qd = k_local % qk_dim;
let mut acc = zero;
// Visit every query row that selected the target key; the softmax statistics
// of each such row are recomputed from scratch.
let mut q_index = 0usize;
while q_index < query_len {
let target_mask_index = (b * query_len + q_index) * context_len + k_index_target;
if topk_mask[target_mask_index] > zero {
let mut max_logit = f32::cast_from(-1000000000i32);
let mut k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let logit = dot / inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_index += 1usize;
}
let mut denom = zero;
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
denom += f32::exp((dot / inv_scale) - max_logit);
}
k_index += 1usize;
}
if denom > zero {
// Softmax weight of the target key within this query row.
let mut target_dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset =
(((b * heads + h) * context_len + k_index_target) * qk_dim) + d;
target_dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let weight = f32::exp((target_dot / inv_scale) - max_logit) / denom;
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset =
(b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
let value_offset =
(((b * heads + h) * context_len + k_index_target) * value_dim) + vd;
value_dot +=
grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
acc += weight * value_dot * query[q_offset] / inv_scale;
}
}
q_index += 1usize;
}
let grad_offset = (((b * heads + h) * context_len + k_index_target) * qk_dim) + qd;
grad_key[grad_offset] = acc;
} else {
// ---- grad_value[b, h, k_index_target, vd] ----
let hvd = local - query_elements - key_elements;
let h = hvd / (context_len * value_dim);
let k_local = hvd % (context_len * value_dim);
let k_index_target = k_local / value_dim;
let vd = k_local % value_dim;
let mut acc = zero;
// Sum over query rows that selected the target key: weight * grad_output[.., vd].
let mut q_index = 0usize;
while q_index < query_len {
let target_mask_index = (b * query_len + q_index) * context_len + k_index_target;
if topk_mask[target_mask_index] > zero {
let mut max_logit = f32::cast_from(-1000000000i32);
let mut k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let logit = dot / inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_index += 1usize;
}
let mut denom = zero;
k_index = 0usize;
while k_index < context_len {
let mask_index = (b * query_len + q_index) * context_len + k_index;
if topk_mask[mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
denom += f32::exp((dot / inv_scale) - max_logit);
}
k_index += 1usize;
}
if denom > zero {
let mut target_dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + d;
let k_offset =
(((b * heads + h) * context_len + k_index_target) * qk_dim) + d;
target_dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let weight = f32::exp((target_dot / inv_scale) - max_logit) / denom;
let out_offset =
(b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
acc += weight * grad_output[out_offset];
}
}
q_index += 1usize;
}
let grad_offset = (((b * heads + h) * context_len + k_index_target) * value_dim) + vd;
grad_value[grad_offset] = acc;
}
}
/// Backward pass of masked softmax attention, one invocation per `(b, h, q, k)` pair.
///
/// Each invocation whose mask entry is selected recomputes the softmax weight of its
/// pair and adds its contributions to `grad_query`, `grad_key` and `grad_value` with
/// atomic `fetch_add`, since several pairs contribute to the same gradient element.
/// NOTE(review): the gradient tensors are only added to here, so they presumably must
/// be zero-initialized by the caller — confirm at the launch site.
///
/// `query`, `key`, `value` and the gradient tensors are addressed through their
/// strides for dimensions 0..=2, with the innermost index added directly (i.e. the
/// code assumes a unit stride on the last dimension). `topk_mask` is addressed as a
/// contiguous `[batch, query_len, context_len]` buffer. `grad_output` is addressed
/// with the offsets computed from `output`'s strides.
/// All sizes are compile-time (`#[comptime]`) parameters.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch)]
fn dsa_attention_backward_pair_kernel(
query: &cubecl::prelude::Tensor<f32>,
key: &cubecl::prelude::Tensor<f32>,
value: &cubecl::prelude::Tensor<f32>,
topk_mask: &cubecl::prelude::Tensor<f32>,
output: &cubecl::prelude::Tensor<f32>,
grad_output: &cubecl::prelude::Tensor<f32>,
grad_query: &mut cubecl::prelude::Tensor<Atomic<f32>>,
grad_key: &mut cubecl::prelude::Tensor<Atomic<f32>>,
grad_value: &mut cubecl::prelude::Tensor<Atomic<f32>>,
#[comptime] batch: usize,
#[comptime] heads: usize,
#[comptime] query_len: usize,
#[comptime] context_len: usize,
#[comptime] qk_dim: usize,
#[comptime] value_dim: usize,
) {
let idx = ABSOLUTE_POS as usize;
let pair_elements = batch * heads * query_len * context_len;
if idx >= pair_elements {
terminate!();
}
// Decompose the pair index; the key position varies fastest, then query, head, batch.
let k_index = idx % context_len;
let q_index = (idx / context_len) % query_len;
let h = (idx / (context_len * query_len)) % heads;
let b = idx / (context_len * query_len * heads);
let mask_index = (b * query_len + q_index) * context_len + k_index;
let zero = f32::cast_from(0u32);
// Pairs that were not selected contribute nothing.
if topk_mask[mask_index] <= zero {
terminate!();
}
// Here `inv_scale` really is 1 / sqrt(qk_dim); logits are multiplied by it.
let inv_scale = f32::cast_from(1u32) / f32::sqrt(f32::cast_from(qk_dim as u32));
let mut max_logit = f32::cast_from(-1000000000i32);
// Largest scaled logit over the selected keys of this query row.
let mut k_scan = 0usize;
while k_scan < context_len {
let scan_mask_index = (b * query_len + q_index) * context_len + k_scan;
if topk_mask[scan_mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset =
b * query.stride(0) + h * query.stride(1) + q_index * query.stride(2) + d;
let k_offset = b * key.stride(0) + h * key.stride(1) + k_scan * key.stride(2) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let logit = dot * inv_scale;
if logit > max_logit {
max_logit = logit;
}
}
k_scan += 1usize;
}
// Softmax denominator over the selected keys.
let mut denom = zero;
k_scan = 0usize;
while k_scan < context_len {
let scan_mask_index = (b * query_len + q_index) * context_len + k_scan;
if topk_mask[scan_mask_index] > zero {
let mut dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset =
b * query.stride(0) + h * query.stride(1) + q_index * query.stride(2) + d;
let k_offset = b * key.stride(0) + h * key.stride(1) + k_scan * key.stride(2) + d;
dot += query[q_offset] * key[k_offset];
d += 1usize;
}
denom += f32::exp(dot * inv_scale - max_logit);
}
k_scan += 1usize;
}
if denom <= zero {
terminate!();
}
// Softmax weight of this invocation's own (q_index, k_index) pair.
let mut target_dot = zero;
let mut d = 0usize;
while d < qk_dim {
let q_offset = b * query.stride(0) + h * query.stride(1) + q_index * query.stride(2) + d;
let k_offset = b * key.stride(0) + h * key.stride(1) + k_index * key.stride(2) + d;
target_dot += query[q_offset] * key[k_offset];
d += 1usize;
}
let weight = f32::exp(target_dot * inv_scale - max_logit) / denom;
// value_dot = grad_output . (value[k] - output) for this head.
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset = b * output.stride(0) + q_index * output.stride(1) + h * value_dim + vd;
let value_offset =
b * value.stride(0) + h * value.stride(1) + k_index * value.stride(2) + vd;
value_dot += grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
// Scatter this pair's contribution to the query and key gradients.
d = 0usize;
while d < qk_dim {
let q_offset = b * query.stride(0) + h * query.stride(1) + q_index * query.stride(2) + d;
let k_offset = b * key.stride(0) + h * key.stride(1) + k_index * key.stride(2) + d;
let grad_q_offset = b * grad_query.stride(0)
+ h * grad_query.stride(1)
+ q_index * grad_query.stride(2)
+ d;
let grad_k_offset =
b * grad_key.stride(0) + h * grad_key.stride(1) + k_index * grad_key.stride(2) + d;
let scaled = weight * value_dot * inv_scale;
grad_query[grad_q_offset].fetch_add(scaled * key[k_offset]);
grad_key[grad_k_offset].fetch_add(scaled * query[q_offset]);
d += 1usize;
}
// Scatter this pair's contribution to the value gradient.
vd = 0usize;
while vd < value_dim {
let out_offset = b * output.stride(0) + q_index * output.stride(1) + h * value_dim + vd;
let grad_value_offset = b * grad_value.stride(0)
+ h * grad_value.stride(1)
+ k_index * grad_value.stride(2)
+ vd;
grad_value[grad_value_offset].fetch_add(weight * grad_output[out_offset]);
vd += 1usize;
}
}
/// Pair-wise attention backward pass that reads precomputed attention weights.
///
/// One invocation per `(b, h, q, k)` pair. Instead of recomputing the softmax, the
/// weight for the pair is read from `weights` at the flat pair index (so `weights` is
/// addressed as a contiguous `[batch, heads, query_len, context_len]` buffer). Pairs
/// whose weight is not greater than zero are skipped. Contributions are added to the
/// gradient tensors with atomic `fetch_add`.
/// NOTE(review): the gradient tensors are only added to here, so they presumably must
/// be zero-initialized by the caller — confirm at the launch site.
///
/// `query`, `key`, `value` and the gradient tensors are addressed through their
/// strides for dimensions 0..=2 with the innermost index added directly;
/// `grad_output` is addressed with the offsets computed from `output`'s strides.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch)]
fn dsa_attention_backward_cached_pair_kernel(
query: &cubecl::prelude::Tensor<f32>,
key: &cubecl::prelude::Tensor<f32>,
value: &cubecl::prelude::Tensor<f32>,
weights: &cubecl::prelude::Tensor<f32>,
output: &cubecl::prelude::Tensor<f32>,
grad_output: &cubecl::prelude::Tensor<f32>,
grad_query: &mut cubecl::prelude::Tensor<Atomic<f32>>,
grad_key: &mut cubecl::prelude::Tensor<Atomic<f32>>,
grad_value: &mut cubecl::prelude::Tensor<Atomic<f32>>,
#[comptime] batch: usize,
#[comptime] heads: usize,
#[comptime] query_len: usize,
#[comptime] context_len: usize,
#[comptime] qk_dim: usize,
#[comptime] value_dim: usize,
) {
let idx = ABSOLUTE_POS as usize;
let pair_elements = batch * heads * query_len * context_len;
if idx >= pair_elements {
terminate!();
}
// Decompose the pair index; the key position varies fastest, then query, head, batch.
let k_index = idx % context_len;
let q_index = (idx / context_len) % query_len;
let h = (idx / (context_len * query_len)) % heads;
let b = idx / (context_len * query_len * heads);
let zero = f32::cast_from(0u32);
// Cached weight for this pair; non-positive weights contribute nothing.
let weight = weights[idx];
if weight <= zero {
terminate!();
}
let inv_scale = f32::cast_from(1u32) / f32::sqrt(f32::cast_from(qk_dim as u32));
// value_dot = grad_output . (value[k] - output) for this head.
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset = b * output.stride(0) + q_index * output.stride(1) + h * value_dim + vd;
let value_offset =
b * value.stride(0) + h * value.stride(1) + k_index * value.stride(2) + vd;
value_dot += grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
// Scatter this pair's contribution to the query and key gradients.
let mut d = 0usize;
while d < qk_dim {
let q_offset = b * query.stride(0) + h * query.stride(1) + q_index * query.stride(2) + d;
let k_offset = b * key.stride(0) + h * key.stride(1) + k_index * key.stride(2) + d;
let grad_q_offset = b * grad_query.stride(0)
+ h * grad_query.stride(1)
+ q_index * grad_query.stride(2)
+ d;
let grad_k_offset =
b * grad_key.stride(0) + h * grad_key.stride(1) + k_index * grad_key.stride(2) + d;
let scaled = weight * value_dot * inv_scale;
grad_query[grad_q_offset].fetch_add(scaled * key[k_offset]);
grad_key[grad_k_offset].fetch_add(scaled * query[q_offset]);
d += 1usize;
}
// Scatter this pair's contribution to the value gradient.
vd = 0usize;
while vd < value_dim {
let out_offset = b * output.stride(0) + q_index * output.stride(1) + h * value_dim + vd;
let grad_value_offset = b * grad_value.stride(0)
+ h * grad_value.stride(1)
+ k_index * grad_value.stride(2)
+ vd;
grad_value[grad_value_offset].fetch_add(weight * grad_output[out_offset]);
vd += 1usize;
}
}
/// Element-wise attention backward pass that reads precomputed attention weights.
///
/// One invocation per gradient element. Per batch, the flat position enumerates the
/// `grad_query` elements first, then `grad_key`, then `grad_value`; each invocation
/// gathers all contributions to its element and writes it with a plain store (no
/// atomics). Weights are read from `weights`, addressed as
/// `[batch, heads, query_len, context_len]`; entries that are not greater than zero
/// are skipped.
///
/// The index arithmetic reads `query`/`key`/`value` as `[batch, heads, len, dim]` and
/// `output`/`grad_output` as `[batch, query_len, heads * value_dim]`, all as flat
/// row-major buffers. Invocations at or beyond `elements` terminate early.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[cube(launch_unchecked, address_type = "dynamic")]
fn dsa_attention_backward_cached_element_kernel(
query: &LinearView<f32>,
key: &LinearView<f32>,
value: &LinearView<f32>,
weights: &LinearView<f32>,
output: &LinearView<f32>,
grad_output: &LinearView<f32>,
grad_query: &mut LinearView<f32, ReadWrite>,
grad_key: &mut LinearView<f32, ReadWrite>,
grad_value: &mut LinearView<f32, ReadWrite>,
elements: usize,
query_len: usize,
context_len: usize,
heads: usize,
qk_dim: usize,
value_dim: usize,
) {
let idx = ABSOLUTE_POS;
if idx >= elements {
terminate!();
}
// Split the flat index into a batch and a position within that batch's
// [query grads | key grads | value grads] range.
let query_elements = heads * query_len * qk_dim;
let key_elements = heads * context_len * qk_dim;
let value_elements = heads * context_len * value_dim;
let per_batch = query_elements + key_elements + value_elements;
let b = idx / per_batch;
let local = idx % per_batch;
let zero = f32::cast_from(0u32);
// 1 / sqrt(qk_dim); gradients w.r.t. query and key are multiplied by it.
let inv_scale = f32::cast_from(1u32) / f32::sqrt(f32::cast_from(qk_dim as u32));
if local < query_elements {
// ---- grad_query[b, h, q_index, qd]: sum over keys ----
let hqd = local;
let h = hqd / (query_len * qk_dim);
let q_local = hqd % (query_len * qk_dim);
let q_index = q_local / qk_dim;
let qd = q_local % qk_dim;
let mut acc = zero;
let mut k_index = 0usize;
while k_index < context_len {
let weight_index = (((b * heads + h) * query_len + q_index) * context_len) + k_index;
let weight = weights[weight_index];
if weight > zero {
// value_dot = grad_output . (value[k] - output) for this head.
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset =
(b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
let value_offset = (((b * heads + h) * context_len + k_index) * value_dim) + vd;
value_dot +=
grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
let k_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
acc += weight * value_dot * key[k_offset] * inv_scale;
}
k_index += 1usize;
}
let grad_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
grad_query[grad_offset] = acc;
} else if local < query_elements + key_elements {
// ---- grad_key[b, h, k_index, qd]: sum over queries ----
let hkd = local - query_elements;
let h = hkd / (context_len * qk_dim);
let k_local = hkd % (context_len * qk_dim);
let k_index = k_local / qk_dim;
let qd = k_local % qk_dim;
let mut acc = zero;
let mut q_index = 0usize;
while q_index < query_len {
let weight_index = (((b * heads + h) * query_len + q_index) * context_len) + k_index;
let weight = weights[weight_index];
if weight > zero {
let mut value_dot = zero;
let mut vd = 0usize;
while vd < value_dim {
let out_offset =
(b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
let value_offset = (((b * heads + h) * context_len + k_index) * value_dim) + vd;
value_dot +=
grad_output[out_offset] * (value[value_offset] - output[out_offset]);
vd += 1usize;
}
let q_offset = (((b * heads + h) * query_len + q_index) * qk_dim) + qd;
acc += weight * value_dot * query[q_offset] * inv_scale;
}
q_index += 1usize;
}
let grad_offset = (((b * heads + h) * context_len + k_index) * qk_dim) + qd;
grad_key[grad_offset] = acc;
} else {
// ---- grad_value[b, h, k_index, vd]: sum over queries of weight * grad_output ----
let hvd = local - query_elements - key_elements;
let h = hvd / (context_len * value_dim);
let k_local = hvd % (context_len * value_dim);
let k_index = k_local / value_dim;
let vd = k_local % value_dim;
let mut acc = zero;
let mut q_index = 0usize;
while q_index < query_len {
let weight_index = (((b * heads + h) * query_len + q_index) * context_len) + k_index;
let weight = weights[weight_index];
if weight > zero {
let out_offset = (b * query_len + q_index) * heads * value_dim + h * value_dim + vd;
acc += weight * grad_output[out_offset];
}
q_index += 1usize;
}
let grad_offset = (((b * heads + h) * context_len + k_index) * value_dim) + vd;
grad_value[grad_offset] = acc;
}
}
/// Reports whether the `BURN_DSA_CACHED_ELEMENT_BACKWARD` environment variable
/// switches the cached element-wise backward path on.
///
/// The flag counts as enabled when its value, after trimming surrounding whitespace,
/// is `1`, `true` or `on`; the words are matched ASCII case-insensitively, so `True`
/// and `On` behave the same as `true`/`TRUE` and `on`/`ON`. An unset variable, a
/// value that is not valid Unicode, or any other value yields `false`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn cached_element_backward_enabled() -> bool {
    std::env::var("BURN_DSA_CACHED_ELEMENT_BACKWARD")
        .map(|raw| {
            let flag = raw.trim();
            flag == "1" || flag.eq_ignore_ascii_case("true") || flag.eq_ignore_ascii_case("on")
        })
        .unwrap_or(false)
}
/// Decides whether the cached element-wise backward kernel should be used.
///
/// Returns `true` when the environment override reported by
/// `cached_element_backward_enabled` is set, or when `topk` is at least three
/// quarters of `context_len` (compared as `4 * topk >= 3 * context_len` with
/// saturating multiplication).
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn should_use_cached_element_backward(topk: usize, context_len: usize) -> bool {
    if cached_element_backward_enabled() {
        return true;
    }
    let scaled_topk = topk.saturating_mul(4);
    let dense_threshold = context_len.saturating_mul(3);
    scaled_topk >= dense_threshold
}
/// Runs the fused DSA attention forward pass on a cube runtime.
///
/// Expected shapes, as enforced by the checks below:
/// * `query`: `[batch, heads, query_len, qk_dim]`
/// * `key`: `[batch, heads, context_len, qk_dim]`
/// * `value`: `[batch, heads, context_len, value_dim]`
/// * `selector`: `[batch, query_len, context_len]`
///
/// `topk` is clamped to `context_len` before it is used. The result has shape
/// `[batch, query_len, heads * value_dim]`.
///
/// # Errors
///
/// Fails when the shapes disagree, when `query_len`, `context_len`, `heads`, `qk_dim`
/// or `value_dim` is zero, when `topk` is zero, when the output element or byte count
/// overflows `usize`, or when building the top-k mask fails.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
pub fn dsa_attention_forward_runtime<R: CubeRuntime, BT: BoolElement>(
    query: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    key: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    value: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    selector: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    topk: usize,
) -> Result<Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>> {
    let [batch, heads, query_len, qk_dim] = query.shape().dims::<4>();
    let key_shape = key.shape().dims::<4>();
    let value_shape = value.shape().dims::<4>();
    let selector_shape = selector.shape().dims::<3>();
    let context_len = key_shape[2];
    let value_dim = value_shape[3];

    // Shape validation; the checks run in the same order as the messages suggest.
    let key_matches_query =
        key_shape[0] == batch && key_shape[1] == heads && key_shape[3] == qk_dim;
    ensure!(
        key_matches_query,
        "DSA fused key shape mismatch: query={:?} key={:?}",
        [batch, heads, query_len, qk_dim],
        key_shape
    );
    let value_matches_key =
        value_shape[0] == batch && value_shape[1] == heads && value_shape[2] == context_len;
    ensure!(
        value_matches_key,
        "DSA fused value shape mismatch: key={key_shape:?} value={value_shape:?}",
    );
    let expected_selector_shape = [batch, query_len, context_len];
    ensure!(
        selector_shape == expected_selector_shape,
        "DSA fused selector shape mismatch: selector={selector_shape:?} expected={:?}",
        expected_selector_shape
    );
    let dims_are_nonzero =
        query_len > 0 && context_len > 0 && heads > 0 && qk_dim > 0 && value_dim > 0;
    ensure!(
        dims_are_nonzero,
        "DSA fused attention requires non-empty tensors"
    );
    ensure!(topk > 0, "DSA fused attention topk must be nonzero");
    let effective_topk = topk.min(context_len);

    // The kernels index flat buffers, so every input is made contiguous first.
    let query = into_contiguous(query.into_primitive().tensor());
    let key = into_contiguous(key.into_primitive().tensor());
    let value = into_contiguous(value.into_primitive().tensor());
    let selector = into_contiguous(selector.into_primitive().tensor());
    let topk_mask =
        dsa_topk_mask_runtime::<R>(selector, batch, query_len, context_len, effective_topk)?;

    // Size the output with overflow-checked arithmetic.
    let output_shape = [batch, query_len, heads * value_dim];
    let mut elements = 1usize;
    for dim in output_shape {
        elements = elements
            .checked_mul(dim)
            .ok_or_else(|| anyhow::anyhow!("DSA fused output element count overflow"))?;
    }
    let bytes = elements
        .checked_mul(core::mem::size_of::<f32>())
        .ok_or_else(|| anyhow::anyhow!("DSA fused output byte size overflow"))?;

    let client = query.client.clone();
    let output = CubeTensor::new_contiguous(
        client.clone(),
        query.device.clone(),
        Shape::new(output_shape),
        client.empty(bytes),
        DType::F32,
    );
    let cube_dim = CubeDim::new(&client, elements);
    let cube_count = calculate_cube_count_elemwise(&client, elements, cube_dim);

    // Use the widest address type required by any tensor taking part in the launch.
    let address_candidates = [
        query.required_address_type(),
        key.required_address_type(),
        value.required_address_type(),
        topk_mask.required_address_type(),
        output.required_address_type(),
    ];
    let address_type = address_candidates.into_iter().max().unwrap_or_default();

    // SAFETY: the kernel terminates for positions at or beyond `elements`, and every
    // tensor handed over is contiguous with the shapes validated above, which is the
    // layout the kernel's flat index arithmetic relies on.
    unsafe {
        dsa_attention_forward_kernel::launch_unchecked::<R>(
            &client,
            cube_count,
            cube_dim,
            address_type,
            query.into_linear_view(),
            key.into_linear_view(),
            value.into_linear_view(),
            topk_mask.into_linear_view(),
            output.clone().into_linear_view(),
            elements,
            query_len,
            context_len,
            heads,
            qk_dim,
            value_dim,
            effective_topk,
        );
    }
    Ok(Tensor::from_primitive(TensorPrimitive::Float(output)))
}
/// Runs the fused DSA attention forward launch and keeps the attention weights for backward.
///
/// Shape contract enforced below (a violation returns an error before anything is launched):
/// - `query`: `[batch, heads, query_len, qk_dim]`
/// - `key`: `[batch, heads, context_len, qk_dim]`
/// - `value`: `[batch, heads, context_len, value_dim]`
/// - `selector`: `[batch, query_len, context_len]`
///
/// The output buffer is allocated as `[batch, query_len, heads * value_dim]` and the weight
/// buffer as `[batch, heads, query_len, context_len]`; both are handed to
/// `dsa_attention_forward_cached_kernel` to be written. `topk` is clamped to `context_len`
/// and the clamped value is stored in the returned struct.
///
/// # Errors
/// Returns an error on a shape mismatch, an empty dimension, `topk == 0`, or when an element
/// or byte count does not fit in `usize`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
pub fn dsa_attention_forward_cached_runtime<R: CubeRuntime, BT: BoolElement>(
    query: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    key: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    value: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    selector: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    topk: usize,
) -> Result<DsaAttentionCachedForwardOutput<R, BT>> {
    let [batch, heads, query_len, qk_dim] = query.shape().dims::<4>();
    let key_shape = key.shape().dims::<4>();
    let value_shape = value.shape().dims::<4>();
    let selector_shape = selector.shape().dims::<3>();
    ensure!(
        key_shape[0] == batch && key_shape[1] == heads && key_shape[3] == qk_dim,
        "DSA fused key shape mismatch: query={:?} key={:?}",
        [batch, heads, query_len, qk_dim],
        key_shape
    );
    ensure!(
        value_shape[0] == batch && value_shape[1] == heads && value_shape[2] == key_shape[2],
        "DSA fused value shape mismatch: key={key_shape:?} value={value_shape:?}",
    );
    ensure!(
        selector_shape == [batch, query_len, key_shape[2]],
        "DSA fused selector shape mismatch: selector={selector_shape:?} expected={:?}",
        [batch, query_len, key_shape[2]]
    );
    ensure!(
        query_len > 0 && key_shape[2] > 0 && heads > 0 && qk_dim > 0 && value_shape[3] > 0,
        "DSA fused attention requires non-empty tensors"
    );
    ensure!(topk > 0, "DSA fused attention topk must be nonzero");
    // The kernels are launched over linear views, so every operand is made contiguous first.
    let query = into_contiguous(query.into_primitive().tensor());
    let key = into_contiguous(key.into_primitive().tensor());
    let value = into_contiguous(value.into_primitive().tensor());
    let selector = into_contiguous(selector.into_primitive().tensor());
    let context_len = key_shape[2];
    let value_dim = value_shape[3];
    // Clamp once and reuse: the mask, the kernel argument and the cached result must agree.
    let effective_topk = topk.min(context_len);
    let topk_mask =
        dsa_topk_mask_runtime::<R>(selector, batch, query_len, context_len, effective_topk)?;
    // Checked like every other size below so a wrapped product can never size the buffer.
    let merged_dim = heads
        .checked_mul(value_dim)
        .ok_or_else(|| anyhow::anyhow!("DSA fused output element count overflow"))?;
    let output_shape = [batch, query_len, merged_dim];
    let elements = output_shape
        .iter()
        .try_fold(1usize, |acc, dim| acc.checked_mul(*dim))
        .ok_or_else(|| anyhow::anyhow!("DSA fused output element count overflow"))?;
    let output_bytes = elements
        .checked_mul(core::mem::size_of::<f32>())
        .ok_or_else(|| anyhow::anyhow!("DSA fused output byte size overflow"))?;
    let pair_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(query_len))
        .and_then(|v| v.checked_mul(context_len))
        .ok_or_else(|| anyhow::anyhow!("DSA fused weight element count overflow"))?;
    let weight_bytes = pair_elements
        .checked_mul(core::mem::size_of::<f32>())
        .ok_or_else(|| anyhow::anyhow!("DSA fused weight byte size overflow"))?;
    let output = CubeTensor::new_contiguous(
        query.client.clone(),
        query.device.clone(),
        Shape::new(output_shape),
        query.client.empty(output_bytes),
        DType::F32,
    );
    let weights = CubeTensor::new_contiguous(
        query.client.clone(),
        query.device.clone(),
        Shape::new([batch, heads, query_len, context_len]),
        query.client.empty(weight_bytes),
        DType::F32,
    );
    // The launch grid is sized from the number of output elements.
    let cube_dim = CubeDim::new(&query.client, elements);
    let cube_count = calculate_cube_count_elemwise(&query.client, elements, cube_dim);
    // Use the widest address type any participating buffer requires.
    let address_type = [
        query.required_address_type(),
        key.required_address_type(),
        value.required_address_type(),
        topk_mask.required_address_type(),
        output.required_address_type(),
        weights.required_address_type(),
    ]
    .into_iter()
    .max()
    .unwrap_or_default();
    let client = query.client.clone();
    // SAFETY: `launch_unchecked` performs no launch-time validation. Every buffer passed here
    // was validated or allocated above with the shapes listed in the doc comment.
    // NOTE(review): soundness also depends on the kernel's own indexing, which is defined
    // elsewhere in this file.
    unsafe {
        dsa_attention_forward_cached_kernel::launch_unchecked::<R>(
            &client,
            cube_count,
            cube_dim,
            address_type,
            query.into_linear_view(),
            key.into_linear_view(),
            value.into_linear_view(),
            topk_mask.into_linear_view(),
            output.clone().into_linear_view(),
            weights.clone().into_linear_view(),
            elements,
            query_len,
            context_len,
            heads,
            qk_dim,
            value_dim,
            effective_topk,
        );
    }
    Ok(DsaAttentionCachedForwardOutput {
        output: Tensor::from_primitive(TensorPrimitive::Float(output)),
        weights,
        topk: effective_topk,
    })
}
/// Allocates a `[batch, query_len, context_len]` f32 mask buffer and launches the kernel
/// that fills it.
///
/// When `topk >= context_len` the selector values are not read and `dsa_full_mask_kernel`
/// is launched; otherwise `dsa_topk_mask_kernel` is launched with the selector.
///
/// # Errors
/// Returns an error if the element or byte count of the mask overflows `usize`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn dsa_topk_mask_runtime<R: CubeRuntime>(
    selector: CubeTensor<R>,
    batch: usize,
    query_len: usize,
    context_len: usize,
    topk: usize,
) -> Result<CubeTensor<R>> {
    let mask_len = batch
        .checked_mul(query_len)
        .and_then(|rows| rows.checked_mul(context_len))
        .ok_or_else(|| anyhow::anyhow!("DSA top-k mask element count overflow"))?;
    let mask_bytes = mask_len
        .checked_mul(core::mem::size_of::<f32>())
        .ok_or_else(|| anyhow::anyhow!("DSA top-k mask byte size overflow"))?;
    let client = selector.client.clone();
    let mask = CubeTensor::new_contiguous(
        client.clone(),
        selector.device.clone(),
        Shape::new([batch, query_len, context_len]),
        client.empty(mask_bytes),
        DType::F32,
    );
    // Both kernels share the same launch geometry, sized from the mask element count.
    let cube_dim = CubeDim::new(&client, mask_len);
    let cube_count = calculate_cube_count_elemwise(&client, mask_len, cube_dim);
    if topk >= context_len {
        let address_type = mask.required_address_type();
        // SAFETY: unchecked launch over the mask buffer allocated above with `mask_len`
        // elements. NOTE(review): relies on the kernel's indexing, defined elsewhere.
        unsafe {
            dsa_full_mask_kernel::launch_unchecked::<R>(
                &client,
                cube_count,
                cube_dim,
                address_type,
                mask.clone().into_linear_view(),
                mask_len,
            );
        }
    } else {
        let address_type = [
            selector.required_address_type(),
            mask.required_address_type(),
        ]
        .into_iter()
        .max()
        .unwrap_or_default();
        // SAFETY: unchecked launch; the mask was allocated above with `mask_len` elements.
        // NOTE(review): the selector's shape is validated by the callers, not here.
        unsafe {
            dsa_topk_mask_kernel::launch_unchecked::<R>(
                &client,
                cube_count,
                cube_dim,
                address_type,
                selector.into_linear_view(),
                mask.clone().into_linear_view(),
                mask_len,
                query_len,
                context_len,
                // `topk < context_len` on this branch, so no further clamping is needed.
                topk,
            );
        }
    }
    Ok(mask)
}
/// Gradients produced by the DSA attention backward launches in this file.
///
/// Each gradient is allocated with the same shape as the forward input it belongs to.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[derive(Debug)]
pub struct DsaAttentionBackwardOutput<B: Backend> {
/// Gradient with respect to the query tensor.
pub grad_query: Tensor<B, 4>,
/// Gradient with respect to the key tensor.
pub grad_key: Tensor<B, 4>,
/// Gradient with respect to the value tensor.
pub grad_value: Tensor<B, 4>,
}
/// Launches the DSA attention backward kernel that works from the selector scores rather
/// than from cached attention weights.
///
/// Shape contract enforced below:
/// - `query`: `[batch, heads, query_len, qk_dim]`
/// - `key`: `[batch, heads, context_len, qk_dim]`
/// - `value`: `[batch, heads, context_len, value_dim]`
/// - `selector`: `[batch, query_len, context_len]`
/// - `output` and `grad_output`: `[batch, query_len, heads * value_dim]`
///
/// The three gradient buffers are zero-initialised before the launch and returned with the
/// shapes of `query`, `key` and `value` respectively.
///
/// # Errors
/// Returns an error on a shape mismatch, `topk == 0`, an element count that overflows
/// `usize`, or a launch grid that does not fit in `u32`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
pub fn dsa_attention_backward_runtime<R: CubeRuntime, BT: BoolElement>(
    query: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    key: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    value: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    selector: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    output: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    grad_output: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    topk: usize,
) -> Result<DsaAttentionBackwardOutput<burn_cubecl::CubeBackend<R, f32, i32, BT>>> {
    let [batch, heads, query_len, qk_dim] = query.shape().dims::<4>();
    let key_shape = key.shape().dims::<4>();
    let value_shape = value.shape().dims::<4>();
    let selector_shape = selector.shape().dims::<3>();
    let output_shape = output.shape().dims::<3>();
    ensure!(
        key_shape[0] == batch && key_shape[1] == heads && key_shape[3] == qk_dim,
        "DSA backward key shape mismatch: query={:?} key={:?}",
        [batch, heads, query_len, qk_dim],
        key_shape
    );
    ensure!(
        value_shape[0] == batch && value_shape[1] == heads && value_shape[2] == key_shape[2],
        "DSA backward value shape mismatch: key={key_shape:?} value={value_shape:?}",
    );
    ensure!(
        selector_shape == [batch, query_len, key_shape[2]],
        "DSA backward selector shape mismatch: selector={selector_shape:?} expected={:?}",
        [batch, query_len, key_shape[2]]
    );
    ensure!(
        output_shape == [batch, query_len, heads * value_shape[3]]
            && grad_output.shape().dims::<3>() == output_shape,
        "DSA backward output shape mismatch: output={output_shape:?} grad={:?}",
        grad_output.shape().dims::<3>()
    );
    ensure!(topk > 0, "DSA backward topk must be nonzero");
    let query = into_contiguous(query.into_primitive().tensor());
    let key = into_contiguous(key.into_primitive().tensor());
    let value = into_contiguous(value.into_primitive().tensor());
    let selector = into_contiguous(selector.into_primitive().tensor());
    let output = into_contiguous(output.into_primitive().tensor());
    let grad_output = into_contiguous(grad_output.into_primitive().tensor());
    let context_len = key_shape[2];
    let value_dim = value_shape[3];
    let topk_mask = dsa_topk_mask_runtime::<R>(
        selector,
        batch,
        query_len,
        context_len,
        topk.min(context_len),
    )?;
    // These counts are not passed to the kernel; computing them with checked arithmetic
    // rejects gradient shapes whose element count would overflow `usize`.
    let query_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(query_len))
        .and_then(|v| v.checked_mul(qk_dim))
        .ok_or_else(|| anyhow::anyhow!("DSA backward query element count overflow"))?;
    let key_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(context_len))
        .and_then(|v| v.checked_mul(qk_dim))
        .ok_or_else(|| anyhow::anyhow!("DSA backward key element count overflow"))?;
    let value_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(context_len))
        .and_then(|v| v.checked_mul(value_dim))
        .ok_or_else(|| anyhow::anyhow!("DSA backward value element count overflow"))?;
    let _ = (query_elements, key_elements, value_elements);
    // The launch covers `batch * heads * query_len * context_len` units, rounded up to whole
    // workgroups of `DSA_PAIR_BACKWARD_WORKGROUP_X`.
    let pair_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(query_len))
        .and_then(|v| v.checked_mul(context_len))
        .ok_or_else(|| anyhow::anyhow!("DSA backward pair element count overflow"))?;
    // A truncating `as u32` would silently launch too few cubes and leave part of the
    // gradients at zero, so an oversized grid is reported as an error instead.
    let pair_cube_count =
        u32::try_from(pair_elements.div_ceil(DSA_PAIR_BACKWARD_WORKGROUP_X as usize))
            .map_err(|_| anyhow::anyhow!("DSA backward pair cube count exceeds u32 range"))?;
    let client = query.client.clone();
    let device = query.device.clone();
    let grad_query = zeros_client::<R>(
        client.clone(),
        device.clone(),
        Shape::new([batch, heads, query_len, qk_dim]),
        DType::F32,
    );
    let grad_key = zeros_client::<R>(
        client.clone(),
        device.clone(),
        Shape::new([batch, heads, context_len, qk_dim]),
        DType::F32,
    );
    let grad_value = zeros_client::<R>(
        client.clone(),
        device,
        Shape::new([batch, heads, context_len, value_dim]),
        DType::F32,
    );
    dsa_attention_backward_pair_kernel::launch::<R>(
        &client,
        CubeCount::Static(pair_cube_count, 1, 1),
        CubeDim::new_1d(DSA_PAIR_BACKWARD_WORKGROUP_X),
        query.into_tensor_arg(),
        key.into_tensor_arg(),
        value.into_tensor_arg(),
        topk_mask.into_tensor_arg(),
        output.into_tensor_arg(),
        grad_output.into_tensor_arg(),
        grad_query.clone().into_tensor_arg(),
        grad_key.clone().into_tensor_arg(),
        grad_value.clone().into_tensor_arg(),
        batch,
        heads,
        query_len,
        context_len,
        qk_dim,
        value_dim,
    );
    Ok(DsaAttentionBackwardOutput {
        grad_query: Tensor::from_primitive(TensorPrimitive::Float(grad_query)),
        grad_key: Tensor::from_primitive(TensorPrimitive::Float(grad_key)),
        grad_value: Tensor::from_primitive(TensorPrimitive::Float(grad_value)),
    })
}
/// Launches the DSA attention backward pass using the weights cached by the forward pass.
///
/// Shape contract enforced below:
/// - `query`: `[batch, heads, query_len, qk_dim]`
/// - `key`: `[batch, heads, context_len, qk_dim]`
/// - `value`: `[batch, heads, context_len, value_dim]`
/// - `forward.output` and `grad_output`: `[batch, query_len, heads * value_dim]`
/// - `forward.weights`: `[batch, heads, query_len, context_len]`
///
/// `should_use_cached_element_backward(forward.topk, context_len)` selects between two
/// kernels: the element kernel (grid sized from the total number of gradient elements) and
/// the pair kernel (grid sized from `batch * heads * query_len * context_len`). Both write
/// into zero-initialised gradient buffers shaped like `query`, `key` and `value`.
///
/// # Errors
/// Returns an error on a shape mismatch, an element count that overflows `usize`, or a
/// pair-kernel launch grid that does not fit in `u32`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
pub fn dsa_attention_backward_cached_runtime<R: CubeRuntime, BT: BoolElement>(
    query: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    key: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    value: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>,
    forward: DsaAttentionCachedForwardOutput<R, BT>,
    grad_output: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
) -> Result<DsaAttentionBackwardOutput<burn_cubecl::CubeBackend<R, f32, i32, BT>>> {
    let [batch, heads, query_len, qk_dim] = query.shape().dims::<4>();
    let key_shape = key.shape().dims::<4>();
    let value_shape = value.shape().dims::<4>();
    let output_shape = forward.output.shape().dims::<3>();
    ensure!(
        key_shape[0] == batch && key_shape[1] == heads && key_shape[3] == qk_dim,
        "DSA cached backward key shape mismatch: query={:?} key={:?}",
        [batch, heads, query_len, qk_dim],
        key_shape
    );
    ensure!(
        value_shape[0] == batch && value_shape[1] == heads && value_shape[2] == key_shape[2],
        "DSA cached backward value shape mismatch: key={key_shape:?} value={value_shape:?}",
    );
    ensure!(
        output_shape == [batch, query_len, heads * value_shape[3]]
            && grad_output.shape().dims::<3>() == output_shape,
        "DSA cached backward output shape mismatch: output={output_shape:?} grad={:?}",
        grad_output.shape().dims::<3>()
    );
    ensure!(
        forward.weights.meta.shape.dims::<4>() == [batch, heads, query_len, key_shape[2]],
        "DSA cached backward weight shape mismatch: weights={:?} expected={:?}",
        forward.weights.meta.shape.dims::<4>(),
        [batch, heads, query_len, key_shape[2]]
    );
    let query = into_contiguous(query.into_primitive().tensor());
    let key = into_contiguous(key.into_primitive().tensor());
    let value = into_contiguous(value.into_primitive().tensor());
    let output = into_contiguous(forward.output.into_primitive().tensor());
    let grad_output = into_contiguous(grad_output.into_primitive().tensor());
    let context_len = key_shape[2];
    let value_dim = value_shape[3];
    let use_element_backward = should_use_cached_element_backward(forward.topk, context_len);
    let weights = into_contiguous(forward.weights);
    let client = query.client.clone();
    let device = query.device.clone();
    let grad_query = zeros_client::<R>(
        client.clone(),
        device.clone(),
        Shape::new([batch, heads, query_len, qk_dim]),
        DType::F32,
    );
    let grad_key = zeros_client::<R>(
        client.clone(),
        device.clone(),
        Shape::new([batch, heads, context_len, qk_dim]),
        DType::F32,
    );
    let grad_value = zeros_client::<R>(
        client.clone(),
        device,
        Shape::new([batch, heads, context_len, value_dim]),
        DType::F32,
    );
    // Validated on both paths, although only the pair kernel sizes its grid from it.
    let pair_elements = batch
        .checked_mul(heads)
        .and_then(|v| v.checked_mul(query_len))
        .and_then(|v| v.checked_mul(context_len))
        .ok_or_else(|| anyhow::anyhow!("DSA cached backward pair element count overflow"))?;
    if use_element_backward {
        let query_elements = batch
            .checked_mul(heads)
            .and_then(|v| v.checked_mul(query_len))
            .and_then(|v| v.checked_mul(qk_dim))
            .ok_or_else(|| anyhow::anyhow!("DSA cached backward query element count overflow"))?;
        let key_elements = batch
            .checked_mul(heads)
            .and_then(|v| v.checked_mul(context_len))
            .and_then(|v| v.checked_mul(qk_dim))
            .ok_or_else(|| anyhow::anyhow!("DSA cached backward key element count overflow"))?;
        let value_elements = batch
            .checked_mul(heads)
            .and_then(|v| v.checked_mul(context_len))
            .and_then(|v| v.checked_mul(value_dim))
            .ok_or_else(|| anyhow::anyhow!("DSA cached backward value element count overflow"))?;
        // The element kernel's grid spans all three gradient buffers together.
        let elements = query_elements
            .checked_add(key_elements)
            .and_then(|v| v.checked_add(value_elements))
            .ok_or_else(|| anyhow::anyhow!("DSA cached backward element count overflow"))?;
        let cube_dim = CubeDim::new(&client, elements);
        let cube_count = calculate_cube_count_elemwise(&client, elements, cube_dim);
        let address_type = [
            query.required_address_type(),
            key.required_address_type(),
            value.required_address_type(),
            weights.required_address_type(),
            output.required_address_type(),
            grad_output.required_address_type(),
            grad_query.required_address_type(),
            grad_key.required_address_type(),
            grad_value.required_address_type(),
        ]
        .into_iter()
        .max()
        .unwrap_or_default();
        // SAFETY: `launch_unchecked` performs no launch-time validation. All operand shapes
        // were checked above and the gradient buffers were allocated with matching shapes.
        // NOTE(review): soundness also depends on the kernel's own indexing, defined elsewhere.
        unsafe {
            dsa_attention_backward_cached_element_kernel::launch_unchecked::<R>(
                &client,
                cube_count,
                cube_dim,
                address_type,
                query.into_linear_view(),
                key.into_linear_view(),
                value.into_linear_view(),
                weights.into_linear_view(),
                output.into_linear_view(),
                grad_output.into_linear_view(),
                grad_query.clone().into_linear_view(),
                grad_key.clone().into_linear_view(),
                grad_value.clone().into_linear_view(),
                elements,
                query_len,
                context_len,
                heads,
                qk_dim,
                value_dim,
            );
        }
        return Ok(DsaAttentionBackwardOutput {
            grad_query: Tensor::from_primitive(TensorPrimitive::Float(grad_query)),
            grad_key: Tensor::from_primitive(TensorPrimitive::Float(grad_key)),
            grad_value: Tensor::from_primitive(TensorPrimitive::Float(grad_value)),
        });
    }
    // A truncating `as u32` would silently launch too few cubes and leave part of the
    // gradients at zero, so an oversized grid is reported as an error instead.
    let pair_cube_count =
        u32::try_from(pair_elements.div_ceil(DSA_PAIR_BACKWARD_WORKGROUP_X as usize)).map_err(
            |_| anyhow::anyhow!("DSA cached backward pair cube count exceeds u32 range"),
        )?;
    dsa_attention_backward_cached_pair_kernel::launch::<R>(
        &client,
        CubeCount::Static(pair_cube_count, 1, 1),
        CubeDim::new_1d(DSA_PAIR_BACKWARD_WORKGROUP_X),
        query.into_tensor_arg(),
        key.into_tensor_arg(),
        value.into_tensor_arg(),
        weights.into_tensor_arg(),
        output.into_tensor_arg(),
        grad_output.into_tensor_arg(),
        grad_query.clone().into_tensor_arg(),
        grad_key.clone().into_tensor_arg(),
        grad_value.clone().into_tensor_arg(),
        batch,
        heads,
        query_len,
        context_len,
        qk_dim,
        value_dim,
    );
    Ok(DsaAttentionBackwardOutput {
        grad_query: Tensor::from_primitive(TensorPrimitive::Float(grad_query)),
        grad_key: Tensor::from_primitive(TensorPrimitive::Float(grad_key)),
        grad_value: Tensor::from_primitive(TensorPrimitive::Float(grad_value)),
    })
}
/// Attempts to reinterpret a backend float primitive as the concrete type `T`.
///
/// The value is type-erased through `Box<dyn Any>` and downcast; `None` is returned when
/// `B::FloatTensorPrimitive` is not exactly `T`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn try_cast_primitive<B: Backend, T: 'static>(value: B::FloatTensorPrimitive) -> Option<T>
where
    B::FloatTensorPrimitive: 'static,
{
    let erased: Box<dyn Any> = Box::new(value);
    match erased.downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}
/// Attempts to reinterpret a concrete value `T` as the float primitive of backend `B`.
///
/// Inverse of `try_cast_primitive`: the value is type-erased through `Box<dyn Any>` and
/// downcast, yielding `None` when `T` is not exactly `B::FloatTensorPrimitive`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn try_cast_backend<B: Backend, T: 'static>(value: T) -> Option<B::FloatTensorPrimitive>
where
    B::FloatTensorPrimitive: 'static,
{
    let erased: Box<dyn Any> = Box::new(value);
    match erased.downcast::<B::FloatTensorPrimitive>() {
        Ok(primitive) => Some(*primitive),
        Err(_) => None,
    }
}
/// State captured at forward time for the custom DSA attention backward step.
///
/// `FT` is the float tensor primitive of the inner (non-autodiff) backend.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[derive(Debug, Clone)]
struct DsaAttentionBackwardState<FT> {
/// Query tensor passed to the forward kernel.
query: FT,
/// Key tensor passed to the forward kernel.
key: FT,
/// Value tensor passed to the forward kernel.
value: FT,
/// Output produced by the forward kernel.
output: FT,
/// Attention weights cached by the forward kernel.
weights: FT,
/// Top-k value reported by the forward pass (already clamped to the context length).
topk: usize,
}
/// Zero-sized marker type that carries the `Backward` implementations for backend `B`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
#[derive(Debug)]
struct DsaAttentionBackward<B>(PhantomData<B>);
/// Runtime-generic body of the autodiff backward step for fused DSA attention.
///
/// Consumes the gradient registered for the op's output node, rebuilds the cached forward
/// result from the saved state, runs `dsa_attention_backward_cached_runtime`, and registers
/// the query/key/value gradients on whichever parents are present (in that order).
///
/// # Panics
/// Panics with "DSA fused attention backward failed" if the backward runtime returns an
/// error.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn dsa_attention_backward_impl<R: CubeRuntime, BT: BoolElement>(
    ops: Ops<DsaAttentionBackwardState<CubeTensor<R>>, 3>,
    grads: &mut Gradients,
) {
    let upstream = grads.consume::<burn_cubecl::CubeBackend<R, f32, i32, BT>>(&ops.node);
    let grad_output = Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>::from_primitive(
        TensorPrimitive::Float(upstream),
    );
    let parents = ops.parents;
    let DsaAttentionBackwardState {
        query,
        key,
        value,
        output,
        weights,
        topk,
    } = ops.state;
    let forward = DsaAttentionCachedForwardOutput {
        output: Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>::from_primitive(
            TensorPrimitive::Float(output),
        ),
        weights,
        topk,
    };
    let DsaAttentionBackwardOutput {
        grad_query,
        grad_key,
        grad_value,
    } = dsa_attention_backward_cached_runtime::<R, BT>(
        Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>::from_primitive(
            TensorPrimitive::Float(query),
        ),
        Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>::from_primitive(
            TensorPrimitive::Float(key),
        ),
        Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 4>::from_primitive(
            TensorPrimitive::Float(value),
        ),
        forward,
        grad_output,
    )
    .expect("DSA fused attention backward failed");
    // Parents were registered as [query, key, value]; pair each slot with its gradient.
    for (parent, gradient) in parents.iter().zip([grad_query, grad_key, grad_value]) {
        let Some(parent) = parent else {
            continue;
        };
        grads.register::<burn_cubecl::CubeBackend<R, f32, i32, BT>>(
            parent.id,
            gradient.into_primitive().tensor(),
        );
    }
}
/// Autodiff backward hook for the WGPU kernel backend (three tracked parents).
#[cfg(feature = "wgpu-kernel")]
impl Backward<DsaWgpuKernelBackend, 3> for DsaAttentionBackward<DsaWgpuKernelBackend> {
type State = DsaAttentionBackwardState<CubeTensor<WgpuRuntime>>;
/// Delegates to the runtime-generic implementation with `WgpuRuntime` and `u32` as the
/// bool element type; the checkpointer is not used.
fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
_checkpointer: &mut burn_autodiff::checkpoint::base::Checkpointer,
) {
dsa_attention_backward_impl::<WgpuRuntime, u32>(ops, grads);
}
}
/// Autodiff backward hook for the CUDA kernel backend (three tracked parents).
#[cfg(feature = "cuda")]
impl Backward<DsaCudaKernelBackend, 3> for DsaAttentionBackward<DsaCudaKernelBackend> {
type State = DsaAttentionBackwardState<CubeTensor<CudaRuntime>>;
/// Delegates to the runtime-generic implementation with `CudaRuntime` and `u8` as the
/// bool element type; the checkpointer is not used.
fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
_checkpointer: &mut burn_autodiff::checkpoint::base::Checkpointer,
) {
dsa_attention_backward_impl::<CudaRuntime, u8>(ops, grads);
}
}
/// Tries to run fused DSA attention on the WGPU autodiff backend with a custom backward op.
///
/// Returns `None` when backend `B`'s float primitive is not `DsaWgpuAutodiffTensor`, when
/// the cached forward launch reports an error, or when the result cannot be cast back to
/// `B`. Only `query`, `key` and `value` are registered as graph parents; `selector` is read
/// through its inner tensor and receives no gradient from this op.
#[cfg(feature = "wgpu-kernel")]
fn try_dsa_attention_custom_backward_wgpu<B: Backend>(
query: Tensor<B, 4>,
key: Tensor<B, 4>,
value: Tensor<B, 4>,
selector: Tensor<B, 3>,
topk: usize,
) -> Option<Tensor<B, 3>>
where
B::FloatTensorPrimitive: 'static,
{
// Each cast bails out with `None` if `B` is not the WGPU autodiff backend.
let query_ad: DsaWgpuAutodiffTensor =
try_cast_primitive::<B, _>(query.into_primitive().tensor())?;
let key_ad: DsaWgpuAutodiffTensor = try_cast_primitive::<B, _>(key.into_primitive().tensor())?;
let value_ad: DsaWgpuAutodiffTensor =
try_cast_primitive::<B, _>(value.into_primitive().tensor())?;
let selector_ad: DsaWgpuAutodiffTensor =
try_cast_primitive::<B, _>(selector.into_primitive().tensor())?;
// Strip the autodiff wrapper; the `*_ad` values are cloned because their graph nodes are
// still needed when the backward op is registered below.
let query_inner = <DsaWgpuAutodiffBackend as AutodiffBackend>::inner(query_ad.clone());
let key_inner = <DsaWgpuAutodiffBackend as AutodiffBackend>::inner(key_ad.clone());
let value_inner = <DsaWgpuAutodiffBackend as AutodiffBackend>::inner(value_ad.clone());
let selector_inner = <DsaWgpuAutodiffBackend as AutodiffBackend>::inner(selector_ad);
// Run the forward launch on the inner backend; any error is turned into `None`.
let forward = dsa_attention_forward_cached_runtime::<WgpuRuntime, u32>(
Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
query_inner.clone(),
)),
Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
key_inner.clone(),
)),
Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
value_inner.clone(),
)),
Tensor::<DsaWgpuKernelBackend, 3>::from_primitive(TensorPrimitive::Float(
selector_inner.clone(),
)),
topk,
)
.ok()?;
let forward_topk = forward.topk;
let output = forward.output.into_primitive().tensor();
// Register the custom backward op with query/key/value as its three parents.
let output_ad = match DsaAttentionBackward::<DsaWgpuKernelBackend>(PhantomData)
.prepare::<NoCheckpointing>([
query_ad.node.clone(),
key_ad.node.clone(),
value_ad.node.clone(),
])
.compute_bound()
.stateful()
{
// Tracked: save the inputs, output and cached weights for the backward step.
OpsKind::Tracked(prep) => prep.finish(
DsaAttentionBackwardState {
query: query_inner,
key: key_inner,
value: value_inner,
output: output.clone(),
weights: forward.weights,
topk: forward_topk,
},
output,
),
// Untracked: no state is saved, only the output is wrapped.
OpsKind::UnTracked(prep) => prep.finish(output),
};
// Convert the autodiff tensor back into backend `B`'s primitive type.
Some(Tensor::<B, 3>::from_primitive(TensorPrimitive::Float(
try_cast_backend::<B, _>(output_ad)?,
)))
}
/// Tries to run fused DSA attention on the CUDA autodiff backend with a custom backward op.
///
/// Returns `None` when backend `B`'s float primitive is not `DsaCudaAutodiffTensor`, when
/// the cached forward launch reports an error, or when the result cannot be cast back to
/// `B`. Only `query`, `key` and `value` are registered as graph parents; `selector` is read
/// through its inner tensor and receives no gradient from this op.
#[cfg(feature = "cuda")]
fn try_dsa_attention_custom_backward_cuda<B: Backend>(
query: Tensor<B, 4>,
key: Tensor<B, 4>,
value: Tensor<B, 4>,
selector: Tensor<B, 3>,
topk: usize,
) -> Option<Tensor<B, 3>>
where
B::FloatTensorPrimitive: 'static,
{
// Each cast bails out with `None` if `B` is not the CUDA autodiff backend.
let query_ad: DsaCudaAutodiffTensor =
try_cast_primitive::<B, _>(query.into_primitive().tensor())?;
let key_ad: DsaCudaAutodiffTensor = try_cast_primitive::<B, _>(key.into_primitive().tensor())?;
let value_ad: DsaCudaAutodiffTensor =
try_cast_primitive::<B, _>(value.into_primitive().tensor())?;
let selector_ad: DsaCudaAutodiffTensor =
try_cast_primitive::<B, _>(selector.into_primitive().tensor())?;
// Strip the autodiff wrapper; the `*_ad` values are cloned because their graph nodes are
// still needed when the backward op is registered below.
let query_inner = <DsaCudaAutodiffBackend as AutodiffBackend>::inner(query_ad.clone());
let key_inner = <DsaCudaAutodiffBackend as AutodiffBackend>::inner(key_ad.clone());
let value_inner = <DsaCudaAutodiffBackend as AutodiffBackend>::inner(value_ad.clone());
let selector_inner = <DsaCudaAutodiffBackend as AutodiffBackend>::inner(selector_ad);
// Run the forward launch on the inner backend; any error is turned into `None`.
let forward = dsa_attention_forward_cached_runtime::<CudaRuntime, u8>(
Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
query_inner.clone(),
)),
Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
key_inner.clone(),
)),
Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(
value_inner.clone(),
)),
Tensor::<DsaCudaKernelBackend, 3>::from_primitive(TensorPrimitive::Float(
selector_inner.clone(),
)),
topk,
)
.ok()?;
let forward_topk = forward.topk;
let output = forward.output.into_primitive().tensor();
// Register the custom backward op with query/key/value as its three parents.
let output_ad = match DsaAttentionBackward::<DsaCudaKernelBackend>(PhantomData)
.prepare::<NoCheckpointing>([
query_ad.node.clone(),
key_ad.node.clone(),
value_ad.node.clone(),
])
.compute_bound()
.stateful()
{
// Tracked: save the inputs, output and cached weights for the backward step.
OpsKind::Tracked(prep) => prep.finish(
DsaAttentionBackwardState {
query: query_inner,
key: key_inner,
value: value_inner,
output: output.clone(),
weights: forward.weights,
topk: forward_topk,
},
output,
),
// Untracked: no state is saved, only the output is wrapped.
OpsKind::UnTracked(prep) => prep.finish(output),
};
// Convert the autodiff tensor back into backend `B`'s primitive type.
Some(Tensor::<B, 3>::from_primitive(TensorPrimitive::Float(
try_cast_backend::<B, _>(output_ad)?,
)))
}
/// Configuration of a DSA block.
///
/// `#[serde(default)]` means fields missing from serialized input are taken from
/// `DsaConfig::default()`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct DsaConfig {
/// Number of query heads; must be a nonzero multiple of `n_kv_heads`.
pub n_heads: usize,
/// Number of key/value heads; repeated along the head axis to match `n_heads`.
pub n_kv_heads: usize,
/// Per-head value width, and the first summand of the per-head query/key width.
/// Must be a power of two.
pub dsa_dim: usize,
/// Extra per-head width added to `dsa_dim` for queries and keys. Must be a nonzero
/// power of two.
pub dsa_tail_dim: usize,
/// Output width of the selector projections (`q_score` / `k_score`).
pub score_dim: usize,
/// Number of context positions kept per query; clamped to the context length at use.
pub topk: usize,
/// `topk` must be divisible by this value.
/// NOTE(review): its use beyond validation is not visible in this part of the file.
pub block_i: usize,
/// Ratio of the feed-forward hidden width to `hidden_dim`; must be finite and >= 1.
pub mlp_ratio: f32,
/// Epsilon passed to both layer norms; must be finite and positive.
pub layer_norm_eps: f64,
/// Selects the reference implementation or one of the fused-kernel executors.
pub executor: DsaExecutor,
}
impl Default for DsaConfig {
/// Returns the default configuration: 12 query heads sharing 1 key/value head, a
/// 128 + 32 wide query/key head, top-64 selection and the reference executor.
fn default() -> Self {
Self {
n_heads: 12,
n_kv_heads: 1,
dsa_dim: 128,
dsa_tail_dim: 32,
score_dim: 32,
topk: 64,
block_i: 16,
mlp_ratio: 4.0,
layer_norm_eps: 1.0e-6,
executor: DsaExecutor::Reference,
}
}
}
impl DsaConfig {
    /// Checks the configuration against the constraints this module relies on.
    ///
    /// Checks run in a fixed order and the first failing one determines the error message.
    ///
    /// # Errors
    /// Returns an error when `hidden_dim` or any size field is zero, when `n_heads` is not
    /// a multiple of `n_kv_heads`, when `dsa_dim` or `dsa_tail_dim` is not a power of two,
    /// when `topk` is not a multiple of `block_i`, when `mlp_ratio` is not finite or is
    /// below 1, or when `layer_norm_eps` is not finite and positive.
    pub fn validate(&self, hidden_dim: usize) -> Result<()> {
        let Self {
            n_heads,
            n_kv_heads,
            dsa_dim,
            dsa_tail_dim,
            score_dim,
            topk,
            block_i,
            mlp_ratio,
            layer_norm_eps,
            ..
        } = *self;

        ensure!(hidden_dim != 0, "DSA hidden_dim must be nonzero");

        // Head layout.
        ensure!(n_heads != 0, "dsa.n_heads must be nonzero");
        ensure!(n_kv_heads != 0, "dsa.n_kv_heads must be nonzero");
        ensure!(
            n_heads.is_multiple_of(n_kv_heads),
            "dsa.n_heads must be divisible by dsa.n_kv_heads"
        );

        // Per-head widths.
        ensure!(dsa_dim != 0, "dsa.dsa_dim must be nonzero");
        ensure!(
            dsa_dim.is_power_of_two(),
            "dsa.dsa_dim must be a power of two for the fused-kernel profile"
        );
        // `is_power_of_two` is false for zero, so it also covers the nonzero requirement.
        ensure!(
            dsa_tail_dim.is_power_of_two(),
            "dsa.dsa_tail_dim must be a nonzero power of two"
        );
        ensure!(score_dim != 0, "dsa.score_dim must be nonzero");

        // Top-k selection.
        ensure!(topk != 0, "dsa.topk must be nonzero");
        ensure!(block_i != 0, "dsa.block_i must be nonzero");
        ensure!(
            topk.is_multiple_of(block_i),
            "dsa.topk must be divisible by dsa.block_i"
        );

        // Floating-point settings.
        ensure!(
            mlp_ratio.is_finite() && mlp_ratio >= 1.0,
            "dsa.mlp_ratio must be finite and >= 1"
        );
        ensure!(
            layer_norm_eps.is_finite() && layer_norm_eps > 0.0,
            "dsa.layer_norm_eps must be finite and positive"
        );
        Ok(())
    }

    /// Per-head query/key width: `dsa_dim + dsa_tail_dim`.
    pub fn qk_dim(&self) -> usize {
        let Self {
            dsa_dim,
            dsa_tail_dim,
            ..
        } = self;
        dsa_dim + dsa_tail_dim
    }
}
/// Two-layer feed-forward network used inside a DSA block.
#[derive(Module, Debug)]
pub struct DsaMlp<B: Backend> {
/// Projection from the model width to the hidden width.
pub fc1: Linear<B>,
/// Projection from the hidden width back to the model width.
pub fc2: Linear<B>,
}
impl<B: Backend> DsaMlp<B> {
    /// Creates the two linear layers on `device`.
    ///
    /// The hidden width is `round(hidden_dim * max(mlp_ratio, 1.0))`; both the model width
    /// and the hidden width are raised to at least 1.
    pub fn new(hidden_dim: usize, mlp_ratio: f32, device: &B::Device) -> Self {
        let ratio = mlp_ratio.max(1.0);
        let model_width = hidden_dim.max(1);
        let inner_width = ((hidden_dim as f32 * ratio).round() as usize).max(1);
        // `fc1` is initialised before `fc2`, matching field order.
        let fc1 = LinearConfig::new(model_width, inner_width).init(device);
        let fc2 = LinearConfig::new(inner_width, model_width).init(device);
        Self { fc1, fc2 }
    }

    /// Applies `fc1`, GELU, then `fc2`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let expanded = self.fc1.forward(x);
        let activated = activation::gelu(expanded);
        self.fc2.forward(activated)
    }
}
/// One DSA block: normalised sparse attention with a residual connection, followed by a
/// normalised feed-forward network with a residual connection.
#[derive(Module, Debug)]
pub struct DsaBlock<B: Backend> {
/// Layer norm applied to both query and context tokens before attention.
pub attention_norm: LayerNorm<B>,
/// Query-side projection for the top-k selector scores.
pub q_score: Linear<B>,
/// Context-side projection for the top-k selector scores.
pub k_score: Linear<B>,
/// Query projection (`n_heads * qk_dim` outputs).
pub q_proj: Linear<B>,
/// Key projection (`n_kv_heads * qk_dim` outputs).
pub k_proj: Linear<B>,
/// Value projection (`n_kv_heads * dsa_dim` outputs).
pub v_proj: Linear<B>,
/// Output projection from `n_heads * dsa_dim` back to the model width.
pub o_proj: Linear<B>,
/// Layer norm applied before the feed-forward network.
pub ffn_norm: LayerNorm<B>,
/// Feed-forward network.
pub ffn: DsaMlp<B>,
/// Block configuration; marked `#[module(skip)]`.
#[module(skip)]
pub config: DsaConfig,
}
impl<B: Backend> DsaBlock<B>
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    /// Builds a block for model width `hidden_dim` on `device`.
    ///
    /// A zero `n_heads` is replaced by 1 before the configuration is validated.
    ///
    /// # Errors
    /// Returns the error from `DsaConfig::validate` when the configuration is invalid.
    pub fn new(hidden_dim: usize, mut config: DsaConfig, device: &B::Device) -> Result<Self> {
        config.n_heads = config.n_heads.max(1);
        config.validate(hidden_dim)?;

        let qk_dim = config.qk_dim();
        let value_dim = config.dsa_dim;
        let (query_heads, kv_heads) = (config.n_heads, config.n_kv_heads);
        let score_dim = config.score_dim;
        let eps = config.layer_norm_eps;
        let mlp_ratio = config.mlp_ratio;

        // Layers are initialised in field order.
        Ok(Self {
            attention_norm: LayerNormConfig::new(hidden_dim)
                .with_epsilon(eps)
                .init(device),
            q_score: LinearConfig::new(hidden_dim, score_dim).init(device),
            k_score: LinearConfig::new(hidden_dim, score_dim).init(device),
            q_proj: LinearConfig::new(hidden_dim, query_heads * qk_dim).init(device),
            k_proj: LinearConfig::new(hidden_dim, kv_heads * qk_dim).init(device),
            v_proj: LinearConfig::new(hidden_dim, kv_heads * value_dim).init(device),
            o_proj: LinearConfig::new(query_heads * value_dim, hidden_dim).init(device),
            ffn_norm: LayerNormConfig::new(hidden_dim)
                .with_epsilon(eps)
                .init(device),
            ffn: DsaMlp::new(hidden_dim, mlp_ratio, device),
            config,
        })
    }

    /// Self-attention form: the input serves as both query and context.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let context = x.clone();
        self.forward_with_context(x, context)
    }

    /// Runs the block with separate query and context token sequences.
    ///
    /// With the `Reference` executor the reference attention is used directly; with a kernel
    /// executor `try_dsa_fused_attention` is tried first and the reference attention is the
    /// fallback when it returns `None`.
    pub fn forward_with_context(
        &self,
        query_tokens: Tensor<B, 3>,
        context_tokens: Tensor<B, 3>,
    ) -> Tensor<B, 3> {
        let normed_query = self.attention_norm.forward(query_tokens.clone());
        let normed_context = self.attention_norm.forward(context_tokens);
        let attention = match self.config.executor {
            DsaExecutor::Reference => {
                self.reference_sparse_attention(normed_query, normed_context)
            }
            DsaExecutor::CudaKernel | DsaExecutor::WgpuKernel => {
                match try_dsa_fused_attention(self, normed_query.clone(), normed_context.clone()) {
                    Some(fused) => fused,
                    None => self.reference_sparse_attention(normed_query, normed_context),
                }
            }
        };
        // Residual around attention, then residual around the feed-forward network.
        let residual = query_tokens + self.o_proj.forward(attention);
        let ffn_out = self.ffn.forward(self.ffn_norm.forward(residual.clone()));
        residual + ffn_out
    }

    /// Reference sparse attention built from standard tensor operations.
    ///
    /// When `topk` is smaller than the context length, logits whose selector score is below
    /// the `topk`-th largest score of their query row are filled with `-1.0e9` before the
    /// softmax. Returns a `[batch, query_len, n_heads * dsa_dim]` tensor.
    fn reference_sparse_attention(
        &self,
        query_tokens: Tensor<B, 3>,
        context_tokens: Tensor<B, 3>,
    ) -> Tensor<B, 3> {
        let [batch, query_len, _] = query_tokens.shape().dims::<3>();
        let [context_batch, context_len, _] = context_tokens.shape().dims::<3>();
        debug_assert_eq!(batch, context_batch);

        let cfg = &self.config;
        let qk_dim = cfg.qk_dim();
        let value_dim = cfg.dsa_dim;
        let (heads, kv_heads) = (cfg.n_heads, cfg.n_kv_heads);

        // Project and move the head axis in front of the sequence axis.
        let q = self
            .q_proj
            .forward(query_tokens.clone())
            .reshape([batch, query_len, heads, qk_dim])
            .permute([0, 2, 1, 3]);
        let k = self
            .k_proj
            .forward(context_tokens.clone())
            .reshape([batch, context_len, kv_heads, qk_dim])
            .permute([0, 2, 1, 3]);
        let v = self
            .v_proj
            .forward(context_tokens.clone())
            .reshape([batch, context_len, kv_heads, value_dim])
            .permute([0, 2, 1, 3]);

        // Repeat key/value heads so they line up with the query heads.
        let (k, v) = if kv_heads == heads {
            (k, v)
        } else {
            let group = heads / kv_heads;
            (k.repeat_dim(1, group), v.repeat_dim(1, group))
        };

        let scale = (qk_dim as f64).powf(-0.5);
        let logits = q.matmul(k.swap_dims(2, 3)).mul_scalar(scale);

        let keep = cfg.topk.min(context_len);
        let logits = if keep < context_len {
            let selector = self.selection_scores(query_tokens, context_tokens);
            let (top_values, _) = selector.clone().topk_with_indices(keep, 2);
            // The smallest of the kept scores is the per-row cut-off.
            let threshold = top_values
                .slice_dim(2, keep - 1..keep)
                .expand([batch, query_len, context_len]);
            let dropped = selector
                .lower(threshold)
                .unsqueeze_dim::<4>(1)
                .repeat_dim(1, heads);
            logits.mask_fill(dropped, -1.0e9)
        } else {
            logits
        };

        let weights = activation::softmax(logits, 3);
        let merged_shape = [batch, query_len, heads * value_dim];
        weights
            .matmul(v)
            .permute([0, 2, 1, 3])
            .reshape(merged_shape)
    }

    /// Selector scores: `q_score(query) x k_score(context)^T`, scaled by
    /// `score_dim^-0.5`. The result has one row per query token and one column per context
    /// token.
    fn selection_scores(
        &self,
        query_tokens: Tensor<B, 3>,
        context_tokens: Tensor<B, 3>,
    ) -> Tensor<B, 3> {
        let scale = (self.config.score_dim as f64).powf(-0.5);
        let query_scores = self.q_score.forward(query_tokens);
        let context_scores = self.k_score.forward(context_tokens).swap_dims(1, 2);
        query_scores.matmul(context_scores).mul_scalar(scale)
    }
}
/// Projects the tokens with `block`'s layers and runs the fused forward launch.
///
/// Returns `None` when the batch sizes differ, when either sequence is empty, when
/// `should_use_custom_dsa_attention` declines the fused path, or when the launch reports an
/// error. When `topk` covers the whole context the selector is an all-zero tensor instead
/// of the projected selector scores.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn dsa_fused_sparse_attention_for_block<R: CubeRuntime, BT: BoolElement>(
    block: &DsaBlock<burn_cubecl::CubeBackend<R, f32, i32, BT>>,
    query_tokens: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
    context_tokens: Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>,
) -> Option<Tensor<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>> {
    let [batch, query_len, _] = query_tokens.shape().dims::<3>();
    let [context_batch, context_len, _] = context_tokens.shape().dims::<3>();
    let cfg = &block.config;

    let shapes_usable = batch == context_batch && context_len != 0 && query_len != 0;
    if !shapes_usable || !should_use_custom_dsa_attention(cfg.topk, context_len) {
        return None;
    }

    let qk_dim = cfg.qk_dim();
    let value_dim = cfg.dsa_dim;
    let (heads, kv_heads) = (cfg.n_heads, cfg.n_kv_heads);

    // Project and move the head axis in front of the sequence axis.
    let q = block
        .q_proj
        .forward(query_tokens.clone())
        .reshape([batch, query_len, heads, qk_dim])
        .permute([0, 2, 1, 3]);
    let k = block
        .k_proj
        .forward(context_tokens.clone())
        .reshape([batch, context_len, kv_heads, qk_dim])
        .permute([0, 2, 1, 3]);
    let v = block
        .v_proj
        .forward(context_tokens.clone())
        .reshape([batch, context_len, kv_heads, value_dim])
        .permute([0, 2, 1, 3]);

    // Repeat key/value heads so they line up with the query heads.
    let (k, v) = if kv_heads == heads {
        (k, v)
    } else {
        let group = heads / kv_heads;
        (k.repeat_dim(1, group), v.repeat_dim(1, group))
    };

    let selector = if cfg.topk < context_len {
        block.selection_scores(query_tokens, context_tokens)
    } else {
        // Every position is kept, so the selector projections are skipped.
        Tensor::<burn_cubecl::CubeBackend<R, f32, i32, BT>, 3>::zeros(
            [batch, query_len, context_len],
            &query_tokens.device(),
        )
    };

    match dsa_attention_forward_runtime::<R, BT>(q, k, v, selector, cfg.topk) {
        Ok(attention) => Some(attention),
        Err(_) => None,
    }
}
#[cfg(feature = "wgpu-kernel")]
impl DsaBlock<DsaWgpuKernelBackend> {
    /// Runs the fused attention path on the WGPU kernel backend.
    ///
    /// Returns `None` when the fused path is not applicable or the launch fails.
    pub fn fused_sparse_attention(
        &self,
        query_tokens: Tensor<DsaWgpuKernelBackend, 3>,
        context_tokens: Tensor<DsaWgpuKernelBackend, 3>,
    ) -> Option<Tensor<DsaWgpuKernelBackend, 3>> {
        dsa_fused_sparse_attention_for_block::<WgpuRuntime, u32>(self, query_tokens, context_tokens)
    }

    /// Full block forward that always tries the fused kernel first, regardless of the
    /// configured executor, and falls back to the reference attention when it returns
    /// `None`.
    pub fn forward_with_context_fused_kernel(
        &self,
        query_tokens: Tensor<DsaWgpuKernelBackend, 3>,
        context_tokens: Tensor<DsaWgpuKernelBackend, 3>,
    ) -> Tensor<DsaWgpuKernelBackend, 3> {
        let query_norm = self.attention_norm.forward(query_tokens.clone());
        // `context_tokens` is not needed afterwards, so it is moved rather than cloned,
        // matching the generic `forward_with_context`.
        let context_norm = self.attention_norm.forward(context_tokens);
        let attention = self
            .fused_sparse_attention(query_norm.clone(), context_norm.clone())
            .unwrap_or_else(|| self.reference_sparse_attention(query_norm, context_norm));
        // Residual around attention, then residual around the feed-forward network.
        let h = query_tokens + self.o_proj.forward(attention);
        h.clone() + self.ffn.forward(self.ffn_norm.forward(h))
    }
}
#[cfg(feature = "cuda")]
impl DsaBlock<DsaCudaKernelBackend> {
    /// Runs the fused attention path on the CUDA kernel backend.
    ///
    /// Returns `None` when the fused path is not applicable or the launch fails.
    pub fn fused_sparse_attention(
        &self,
        query_tokens: Tensor<DsaCudaKernelBackend, 3>,
        context_tokens: Tensor<DsaCudaKernelBackend, 3>,
    ) -> Option<Tensor<DsaCudaKernelBackend, 3>> {
        dsa_fused_sparse_attention_for_block::<CudaRuntime, u8>(self, query_tokens, context_tokens)
    }

    /// Full block forward that always tries the fused kernel first, regardless of the
    /// configured executor, and falls back to the reference attention when it returns
    /// `None`.
    pub fn forward_with_context_fused_kernel(
        &self,
        query_tokens: Tensor<DsaCudaKernelBackend, 3>,
        context_tokens: Tensor<DsaCudaKernelBackend, 3>,
    ) -> Tensor<DsaCudaKernelBackend, 3> {
        let query_norm = self.attention_norm.forward(query_tokens.clone());
        // `context_tokens` is not needed afterwards, so it is moved rather than cloned,
        // matching the generic `forward_with_context`.
        let context_norm = self.attention_norm.forward(context_tokens);
        let attention = self
            .fused_sparse_attention(query_norm.clone(), context_norm.clone())
            .unwrap_or_else(|| self.reference_sparse_attention(query_norm, context_norm));
        // Residual around attention, then residual around the feed-forward network.
        let h = query_tokens + self.o_proj.forward(attention);
        h.clone() + self.ffn.forward(self.ffn_norm.forward(h))
    }
}
/// Reports whether backend `B` can run the requested DSA `executor`.
///
/// The reference executor is always supported. The kernel executors need
/// `B`'s float element to be `f32` and `B`'s float tensor primitive to be one
/// of the matching kernel tensor types (plain or autodiff). When the relevant
/// cargo feature is disabled, that kernel executor is reported as unsupported.
pub fn supports_dsa_fused_backend<B: Backend>(executor: DsaExecutor) -> bool
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    let float_is_f32 = TypeId::of::<B::FloatElem>() == TypeId::of::<f32>();
    if !float_is_f32 {
        // Non-f32 backends only ever qualify for the reference executor.
        return executor == DsaExecutor::Reference;
    }
    match executor {
        DsaExecutor::Reference => true,
        DsaExecutor::CudaKernel => {
            #[cfg(feature = "cuda")]
            {
                let primitive = TypeId::of::<B::FloatTensorPrimitive>();
                [
                    TypeId::of::<DsaCudaKernelTensor>(),
                    TypeId::of::<DsaCudaAutodiffTensor>(),
                ]
                .contains(&primitive)
            }
            #[cfg(not(feature = "cuda"))]
            {
                false
            }
        }
        DsaExecutor::WgpuKernel => {
            #[cfg(feature = "wgpu-kernel")]
            {
                let primitive = TypeId::of::<B::FloatTensorPrimitive>();
                [
                    TypeId::of::<DsaWgpuKernelTensor>(),
                    TypeId::of::<DsaWgpuAutodiffTensor>(),
                ]
                .contains(&primitive)
            }
            #[cfg(not(feature = "wgpu-kernel"))]
            {
                false
            }
        }
    }
}
/// Reports whether `executor` was compiled into this build.
///
/// This only looks at cargo features (`cuda`, `wgpu-kernel`); it does not
/// inspect the backend in use. The reference executor is always available.
pub fn supports_dsa_fused_executor(executor: DsaExecutor) -> bool {
    let cuda_compiled = cfg!(feature = "cuda");
    let wgpu_compiled = cfg!(feature = "wgpu-kernel");
    match executor {
        DsaExecutor::Reference => true,
        DsaExecutor::CudaKernel => cuda_compiled,
        DsaExecutor::WgpuKernel => wgpu_compiled,
    }
}
/// Returns a status label for running `executor` on backend `B`.
///
/// The label is `"available"` when [`supports_dsa_fused_backend`] accepts the
/// combination and `"reference_fallback"` otherwise.
pub fn dsa_kernel_status<B: Backend>(executor: DsaExecutor) -> &'static str
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    let fused_ready = supports_dsa_fused_backend::<B>(executor);
    if !fused_ready {
        return "reference_fallback";
    }
    "available"
}
/// Attempts the fused DSA attention path for `block` on a generic backend.
///
/// Returns `None` without running a kernel when the backend does not support
/// the configured executor, when the batch sizes differ, when either sequence
/// is empty, or when `should_use_custom_dsa_attention` rejects the
/// `topk` / context-length combination. Otherwise it projects the queries,
/// keys and values into head-major layout, repeats the key/value heads up to
/// the query head count, builds the selector scores, and forwards everything
/// to [`try_dsa_attention_forward`].
fn try_dsa_fused_attention<B: Backend>(
    block: &DsaBlock<B>,
    query_tokens: Tensor<B, 3>,
    context_tokens: Tensor<B, 3>,
) -> Option<Tensor<B, 3>> {
    let config = &block.config;
    if !supports_dsa_fused_backend::<B>(config.executor) {
        return None;
    }

    let [batch, query_len, _hidden] = query_tokens.shape().dims::<3>();
    let [context_batch, context_len, _] = context_tokens.shape().dims::<3>();
    let shapes_usable = batch == context_batch && query_len != 0 && context_len != 0;
    if !shapes_usable || !should_use_custom_dsa_attention(config.topk, context_len) {
        return None;
    }

    let (heads, kv_heads) = (config.n_heads, config.n_kv_heads);
    let qk_dim = config.qk_dim();

    // Project, split the feature axis into heads, then move heads ahead of
    // the sequence axis: [batch, len, heads, dim] -> [batch, heads, len, dim].
    let query = block
        .q_proj
        .forward(query_tokens.clone())
        .reshape([batch, query_len, heads, qk_dim])
        .permute([0, 2, 1, 3]);
    let key = block
        .k_proj
        .forward(context_tokens.clone())
        .reshape([batch, context_len, kv_heads, qk_dim])
        .permute([0, 2, 1, 3]);
    let value = block
        .v_proj
        .forward(context_tokens.clone())
        .reshape([batch, context_len, kv_heads, config.dsa_dim])
        .permute([0, 2, 1, 3]);

    // With fewer key/value heads than query heads, repeat them along the
    // head axis so every query head has a partner.
    let (key, value) = if kv_heads == heads {
        (key, value)
    } else {
        let repeat = heads / kv_heads;
        (key.repeat_dim(1, repeat), value.repeat_dim(1, repeat))
    };

    // Selector scores are only computed when top-k actually prunes the
    // context; otherwise an all-zero selector is passed.
    let selector = if config.topk < context_len {
        block.selection_scores(query_tokens, context_tokens)
    } else {
        Tensor::<B, 3>::zeros([batch, query_len, context_len], &query_tokens.device())
    };

    try_dsa_attention_forward(query, key, value, selector, config.topk)
}
/// Dispatches projected attention inputs to the first kernel path that accepts them.
///
/// The selector is detached first. The compiled-in paths are then tried in a
/// fixed order: CUDA custom-backward, WGPU custom-backward, raw CUDA, raw
/// WGPU. The first `Some` result is returned. `None` is returned when no path
/// is compiled in or every path declines.
// Without the `cuda` / `wgpu-kernel` features the arguments are never read.
#[allow(unused_variables)]
fn try_dsa_attention_forward<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    selector: Tensor<B, 3>,
    topk: usize,
) -> Option<Tensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    let selector = selector.detach();

    // Every attempt except the last works on clones so that later attempts
    // still own the inputs, whichever feature combination is compiled.
    #[cfg(feature = "cuda")]
    if let found @ Some(_) = try_dsa_attention_custom_backward_cuda::<B>(
        query.clone(),
        key.clone(),
        value.clone(),
        selector.clone(),
        topk,
    ) {
        return found;
    }

    #[cfg(feature = "wgpu-kernel")]
    if let found @ Some(_) = try_dsa_attention_custom_backward_wgpu::<B>(
        query.clone(),
        key.clone(),
        value.clone(),
        selector.clone(),
        topk,
    ) {
        return found;
    }

    #[cfg(feature = "cuda")]
    if let found @ Some(_) = try_dsa_attention_raw_cuda::<B>(
        query.clone(),
        key.clone(),
        value.clone(),
        selector.clone(),
        topk,
    ) {
        return found;
    }

    #[cfg(feature = "wgpu-kernel")]
    if let found @ Some(_) = try_dsa_attention_raw_wgpu::<B>(query, key, value, selector, topk) {
        return found;
    }

    None
}
/// Runs the CUDA forward kernel on tensors of a generic backend `B`.
///
/// Each input primitive goes through `try_cast_primitive`, and the first
/// failed cast makes the function return `None` (presumably when `B` is not
/// the CUDA kernel backend; confirm against that helper). The kernel result
/// goes back through `try_cast_backend`. A kernel error also yields `None`.
#[cfg(feature = "cuda")]
fn try_dsa_attention_raw_cuda<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    selector: Tensor<B, 3>,
    topk: usize,
) -> Option<Tensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    // Cast all four primitives before building any kernel-backend tensor.
    let query_raw = try_cast_primitive::<B, _>(query.into_primitive().tensor())?;
    let key_raw = try_cast_primitive::<B, _>(key.into_primitive().tensor())?;
    let value_raw = try_cast_primitive::<B, _>(value.into_primitive().tensor())?;
    let selector_raw = try_cast_primitive::<B, _>(selector.into_primitive().tensor())?;

    let kernel_query =
        Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(query_raw));
    let kernel_key =
        Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(key_raw));
    let kernel_value =
        Tensor::<DsaCudaKernelBackend, 4>::from_primitive(TensorPrimitive::Float(value_raw));
    let kernel_selector =
        Tensor::<DsaCudaKernelBackend, 3>::from_primitive(TensorPrimitive::Float(selector_raw));

    let attended = dsa_attention_forward_runtime::<CudaRuntime, u8>(
        kernel_query,
        kernel_key,
        kernel_value,
        kernel_selector,
        topk,
    )
    .ok()?;

    let attended_raw = attended.into_primitive().tensor();
    let restored = try_cast_backend::<B, _>(attended_raw)?;
    Some(Tensor::<B, 3>::from_primitive(TensorPrimitive::Float(
        restored,
    )))
}
/// Runs the WGPU forward kernel on tensors of a generic backend `B`.
///
/// Each input primitive goes through `try_cast_primitive`, and the first
/// failed cast makes the function return `None` (presumably when `B` is not
/// the WGPU kernel backend; confirm against that helper). The kernel result
/// goes back through `try_cast_backend`. A kernel error also yields `None`.
#[cfg(feature = "wgpu-kernel")]
fn try_dsa_attention_raw_wgpu<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    selector: Tensor<B, 3>,
    topk: usize,
) -> Option<Tensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    // Cast all four primitives before building any kernel-backend tensor.
    let query_raw = try_cast_primitive::<B, _>(query.into_primitive().tensor())?;
    let key_raw = try_cast_primitive::<B, _>(key.into_primitive().tensor())?;
    let value_raw = try_cast_primitive::<B, _>(value.into_primitive().tensor())?;
    let selector_raw = try_cast_primitive::<B, _>(selector.into_primitive().tensor())?;

    let kernel_query =
        Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(query_raw));
    let kernel_key =
        Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(key_raw));
    let kernel_value =
        Tensor::<DsaWgpuKernelBackend, 4>::from_primitive(TensorPrimitive::Float(value_raw));
    let kernel_selector =
        Tensor::<DsaWgpuKernelBackend, 3>::from_primitive(TensorPrimitive::Float(selector_raw));

    let attended = dsa_attention_forward_runtime::<WgpuRuntime, u32>(
        kernel_query,
        kernel_key,
        kernel_value,
        kernel_selector,
        topk,
    )
    .ok()?;

    let attended_raw = attended.into_primitive().tensor();
    let restored = try_cast_backend::<B, _>(attended_raw)?;
    Some(Tensor::<B, 3>::from_primitive(TensorPrimitive::Float(
        restored,
    )))
}
/// Estimates the parameter count of one DSA block of width `hidden_dim`.
///
/// The total is the sum of eight linear terms and two layer-norm terms: two
/// `hidden -> score_dim` projections, the query/key/value projections, the
/// output projection, and a two-layer MLP. The MLP's inner width is
/// `hidden_dim * max(mlp_ratio, 1.0)`, rounded.
pub fn dsa_parameter_count(hidden_dim: usize, config: &DsaConfig) -> usize {
    let qk_dim = config.qk_dim();
    let mlp_ratio = config.mlp_ratio.max(1.0);
    let mlp_hidden = ((hidden_dim as f32) * mlp_ratio).round() as usize;

    let terms = [
        // Two score projections.
        linear_params(hidden_dim, config.score_dim),
        linear_params(hidden_dim, config.score_dim),
        // Query, key and value projections.
        linear_params(hidden_dim, config.n_heads * qk_dim),
        linear_params(hidden_dim, config.n_kv_heads * qk_dim),
        linear_params(hidden_dim, config.n_kv_heads * config.dsa_dim),
        // Output projection followed by a layer norm.
        linear_params(config.n_heads * config.dsa_dim, hidden_dim),
        layer_norm_params(hidden_dim),
        // Feed-forward pair followed by a layer norm.
        linear_params(hidden_dim, mlp_hidden),
        linear_params(mlp_hidden, hidden_dim),
        layer_norm_params(hidden_dim),
    ];
    terms.iter().sum()
}
/// Estimates multiply-accumulate operations per token for one DSA block.
///
/// Head counts and `context_tokens` are clamped to at least 1, and `topk` is
/// capped at the context length. The selector cost (two score projections
/// plus one `score_dim` dot product per context token) is only charged when
/// `topk` is smaller than the context. The attention term scales with
/// `max(topk, 1)`.
pub fn dsa_macs_per_token(hidden_dim: usize, context_tokens: usize, config: &DsaConfig) -> u128 {
    let qk_dim = config.qk_dim();
    let heads = config.n_heads.max(1);
    let kv_heads = config.n_kv_heads.max(1);
    let context_tokens = context_tokens.max(1);
    let topk = config.topk.min(context_tokens);
    let mlp_ratio = config.mlp_ratio.max(1.0);
    let mlp_hidden = ((hidden_dim as f32) * mlp_ratio).round() as usize;

    let selector_is_active = topk < context_tokens;
    let selector_macs: u128 = if selector_is_active {
        linear_macs(hidden_dim, config.score_dim)
            + linear_macs(hidden_dim, config.score_dim)
            + (context_tokens as u128) * (config.score_dim as u128)
    } else {
        0
    };

    let attended = topk.max(1) as u128;
    let per_pair = qk_dim as u128 + config.dsa_dim as u128;
    let attention_macs = attended * (heads as u128) * per_pair;

    let terms = [
        selector_macs,
        linear_macs(hidden_dim, heads * qk_dim),
        linear_macs(hidden_dim, kv_heads * qk_dim),
        linear_macs(hidden_dim, kv_heads * config.dsa_dim),
        attention_macs,
        linear_macs(heads * config.dsa_dim, hidden_dim),
        linear_macs(hidden_dim, mlp_hidden),
        linear_macs(mlp_hidden, hidden_dim),
    ];
    terms.iter().sum()
}
/// Parameter count of a linear layer: `input * output` weights plus `output`
/// further parameters, saturating at `usize::MAX` instead of overflowing.
fn linear_params(input: usize, output: usize) -> usize {
    let weights = input.saturating_mul(output);
    weights.saturating_add(output)
}
/// Parameter count of a layer norm: twice `hidden_dim`, saturating at
/// `usize::MAX` instead of overflowing.
fn layer_norm_params(hidden_dim: usize) -> usize {
    hidden_dim.saturating_add(hidden_dim)
}
/// Multiply-accumulate count of a linear layer: `input * output`, computed in
/// `u128` with a saturating multiply.
fn linear_macs(input: usize, output: usize) -> u128 {
    let (fan_in, fan_out) = (input as u128, output as u128);
    fan_in.saturating_mul(fan_out)
}
/// Unit tests for the DSA block, its cost estimators, and the feature-gated
/// CUDA / WGPU kernel paths. The GPU tests are only compiled when the
/// corresponding cargo feature is enabled.
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn::tensor::Distribution;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use burn::tensor::TensorData;
// CPU backend used by the tests that need no GPU feature.
type TestBackend = NdArray<f32>;
/// Builds a small configuration (2 query heads, 1 key/value head, dims of 4,
/// `topk` 2) for the given executor; unspecified fields take their defaults.
fn tiny_config(executor: DsaExecutor) -> DsaConfig {
DsaConfig {
n_heads: 2,
n_kv_heads: 1,
dsa_dim: 4,
dsa_tail_dim: 4,
score_dim: 4,
topk: 2,
block_i: 1,
mlp_ratio: 2.0,
executor,
..DsaConfig::default()
}
}
/// The block's `forward` returns a tensor with the same shape as its input.
#[test]
fn dsa_block_preserves_shape() {
let device = Default::default();
let block = DsaBlock::<TestBackend>::new(8, tiny_config(DsaExecutor::Reference), &device)
.expect("DSA block");
let input = Tensor::<TestBackend, 3>::zeros([2, 5, 8], &device);
let output = block.forward(input);
assert_eq!(output.shape().dims::<3>(), [2, 5, 8]);
}
/// On the NdArray backend a block configured with the CUDA executor still
/// runs, produces the same output shape as the reference executor, and the
/// status label reads "reference_fallback".
#[test]
fn dsa_kernel_executors_fallback_to_reference() {
let device = Default::default();
let reference =
DsaBlock::<TestBackend>::new(8, tiny_config(DsaExecutor::Reference), &device)
.expect("reference DSA");
let kernel = DsaBlock::<TestBackend>::new(8, tiny_config(DsaExecutor::CudaKernel), &device)
.expect("kernel DSA");
let input = Tensor::<TestBackend, 3>::zeros([1, 4, 8], &device);
let reference_output = reference.forward(input.clone());
let kernel_output = kernel.forward(input);
// Only shapes are compared here, not values.
assert_eq!(reference_output.shape(), kernel_output.shape());
assert_eq!(
dsa_kernel_status::<TestBackend>(DsaExecutor::CudaKernel),
"reference_fallback"
);
}
/// Both cost estimators return a positive number for the tiny configuration.
#[test]
fn dsa_estimates_are_nonzero() {
let config = tiny_config(DsaExecutor::Reference);
assert!(dsa_parameter_count(8, &config) > 0);
assert!(dsa_macs_per_token(8, 4, &config) > 0);
}
/// The custom attention path is selected for `topk` 2 of 4 context tokens and
/// rejected when `topk` equals or exceeds the context length.
#[test]
fn custom_dsa_attention_is_reserved_for_sparse_topk() {
assert!(should_use_custom_dsa_attention(2, 4));
assert!(!should_use_custom_dsa_attention(4, 4));
assert!(!should_use_custom_dsa_attention(8, 4));
}
/// The bf16 CUDA backend does not qualify for the CUDA kernel executor and is
/// reported as "reference_fallback".
#[cfg(feature = "cuda")]
#[test]
fn cuda_bf16_backend_reports_reference_fallback_for_f32_kernel() {
assert_eq!(
dsa_kernel_status::<DsaCudaBf16KernelBackend>(DsaExecutor::CudaKernel),
"reference_fallback"
);
}
/// When `topk` covers the whole context, `score_dim` does not affect the MAC
/// estimate; when `topk` is smaller, a wider `score_dim` increases it.
#[test]
fn dsa_full_topk_macs_skip_selector_projection() {
let mut narrow_score = tiny_config(DsaExecutor::Reference);
narrow_score.topk = 4;
narrow_score.score_dim = 4;
let mut wide_score = narrow_score.clone();
wide_score.score_dim = 64;
// topk == context_tokens: selector cost is not charged.
assert_eq!(
dsa_macs_per_token(8, 4, &narrow_score),
dsa_macs_per_token(8, 4, &wide_score)
);
// topk < context_tokens: selector cost grows with score_dim.
narrow_score.topk = 2;
wide_score.topk = 2;
assert!(dsa_macs_per_token(8, 4, &wide_score) > dsa_macs_per_token(8, 4, &narrow_score));
}
/// Compares `forward_with_context` against `forward_with_context_fused_kernel`
/// on the CUDA kernel backend with random inputs (3 query tokens, 5 context
/// tokens, `topk` 2); the largest absolute difference must be below 1e-3.
#[cfg(feature = "cuda")]
#[test]
fn cuda_fused_kernel_matches_reference_block() {
let device = Default::default();
let config = tiny_config(DsaExecutor::Reference);
assert_eq!(
dsa_kernel_status::<DsaCudaKernelBackend>(DsaExecutor::CudaKernel),
"available"
);
let block = DsaBlock::<DsaCudaKernelBackend>::new(8, config, &device).expect("DSA block");
let query = Tensor::<DsaCudaKernelBackend, 3>::random(
[1, 3, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let context = Tensor::<DsaCudaKernelBackend, 3>::random(
[1, 5, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let reference = block.forward_with_context(query.clone(), context.clone());
let fused = block.forward_with_context_fused_kernel(query, context);
let max_diff = reference.sub(fused).abs().max().into_scalar();
assert!(max_diff < 1.0e-3, "max diff {max_diff}");
}
/// Same comparison as above with `topk` equal to the context length (4).
// NOTE(review): `should_use_custom_dsa_attention(4, 4)` is asserted to be
// false in this module, and the fused helpers appear to return `None` when
// that predicate is false, so this test may only exercise the reference
// fallback inside `forward_with_context_fused_kernel` — confirm.
#[cfg(feature = "cuda")]
#[test]
fn cuda_fused_full_topk_kernel_matches_reference_block() {
let device = Default::default();
let mut config = tiny_config(DsaExecutor::Reference);
config.topk = 4;
config.block_i = 1;
let block = DsaBlock::<DsaCudaKernelBackend>::new(8, config, &device).expect("DSA block");
let query = Tensor::<DsaCudaKernelBackend, 3>::random(
[1, 4, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let context = Tensor::<DsaCudaKernelBackend, 3>::random(
[1, 4, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let reference = block.forward_with_context(query.clone(), context.clone());
let fused = block.forward_with_context_fused_kernel(query, context);
let max_diff = reference.sub(fused).abs().max().into_scalar();
assert!(max_diff < 1.0e-3, "max diff {max_diff}");
}
/// Generates reproducible tensor data of the given shape: element `index`
/// holds `((index * stride) % 29) / 29 + offset`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn deterministic_data<const D: usize>(
shape: [usize; D],
stride: usize,
offset: f32,
) -> TensorData {
let len = shape.iter().product::<usize>();
let values = (0..len)
.map(|index| {
let value = ((index * stride) % 29) as f32 / 29.0;
value + offset
})
.collect::<Vec<_>>();
TensorData::new(values, shape)
}
/// Reference attention built from ordinary tensor ops, used as the ground
/// truth for the custom-backward tests. Takes already-projected
/// `[batch, heads, len, dim]` query/key/value tensors and a
/// `[batch, query_len, context_len]` selector, and returns
/// `[batch, query_len, heads * value_dim]`.
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
fn projected_reference_attention<B: Backend>(
query: Tensor<B, 4>,
key: Tensor<B, 4>,
value: Tensor<B, 4>,
selector: Tensor<B, 3>,
topk: usize,
) -> Tensor<B, 3>
where
B::FloatTensorPrimitive: 'static,
{
let [batch, heads, query_len, qk_dim] = query.shape().dims::<4>();
let context_len = key.shape().dims::<4>()[2];
let value_dim = value.shape().dims::<4>()[3];
// Scaled dot-product logits: Q * K^T / sqrt(qk_dim).
let mut logits = query
.matmul(key.swap_dims(2, 3))
.mul_scalar((qk_dim as f64).powf(-0.5));
let topk = topk.min(context_len);
if topk < context_len {
// The k-th largest selector score per query row is the cut-off; context
// positions scoring strictly below it are masked out of the softmax.
let (top_values, _top_indices) = selector.clone().topk_with_indices(topk, 2);
let threshold =
top_values
.slice_dim(2, topk - 1..topk)
.expand([batch, query_len, context_len]);
let mask = selector.lower(threshold);
// The same mask is applied to every head.
logits = logits.mask_fill(mask.unsqueeze_dim::<4>(1).repeat_dim(1, heads), -1.0e9);
}
activation::softmax(logits, 3)
.matmul(value)
.permute([0, 2, 1, 3])
.reshape([batch, query_len, heads * value_dim])
}
/// On the CUDA autodiff backend, the gradients of `try_dsa_attention_forward`
/// with respect to query, key and value must match the gradients of
/// `projected_reference_attention` to within 2e-3 (largest absolute difference).
#[cfg(feature = "cuda")]
#[test]
fn cuda_custom_backward_matches_reference_graph_projected_attention() {
type AdBackend = Autodiff<DsaCudaKernelBackend>;
let device = Default::default();
assert_eq!(
dsa_kernel_status::<AdBackend>(DsaExecutor::CudaKernel),
"available"
);
let query_data = deterministic_data([1, 2, 3, 4], 3, -0.35);
let key_data = deterministic_data([1, 2, 5, 4], 5, -0.25);
let value_data = deterministic_data([1, 2, 5, 4], 7, -0.15);
let selector_data = deterministic_data([1, 3, 5], 11, -0.1);
let topk = 2;
// Custom path: separate leaf tensors so its gradients are tracked independently.
let query_custom =
Tensor::<AdBackend, 4>::from_data(query_data.clone(), &device).require_grad();
let key_custom =
Tensor::<AdBackend, 4>::from_data(key_data.clone(), &device).require_grad();
let value_custom =
Tensor::<AdBackend, 4>::from_data(value_data.clone(), &device).require_grad();
let selector_custom = Tensor::<AdBackend, 3>::from_data(selector_data.clone(), &device);
let custom = try_dsa_attention_forward(
query_custom.clone(),
key_custom.clone(),
value_custom.clone(),
selector_custom,
topk,
)
.expect("CUDA DSA custom backward path");
let custom_grads = custom.sum().backward();
let custom_query_grad = query_custom.grad(&custom_grads).expect("custom query grad");
let custom_key_grad = key_custom.grad(&custom_grads).expect("custom key grad");
let custom_value_grad = value_custom.grad(&custom_grads).expect("custom value grad");
// Reference path: the same data fed through the plain tensor-op graph.
let query_reference = Tensor::<AdBackend, 4>::from_data(query_data, &device).require_grad();
let key_reference = Tensor::<AdBackend, 4>::from_data(key_data, &device).require_grad();
let value_reference = Tensor::<AdBackend, 4>::from_data(value_data, &device).require_grad();
let selector_reference = Tensor::<AdBackend, 3>::from_data(selector_data, &device);
let reference = projected_reference_attention(
query_reference.clone(),
key_reference.clone(),
value_reference.clone(),
selector_reference,
topk,
);
let reference_grads = reference.sum().backward();
let reference_query_grad = query_reference
.grad(&reference_grads)
.expect("reference query grad");
let reference_key_grad = key_reference
.grad(&reference_grads)
.expect("reference key grad");
let reference_value_grad = value_reference
.grad(&reference_grads)
.expect("reference value grad");
let query_diff = custom_query_grad
.sub(reference_query_grad)
.abs()
.max()
.into_scalar();
let key_diff = custom_key_grad
.sub(reference_key_grad)
.abs()
.max()
.into_scalar();
let value_diff = custom_value_grad
.sub(reference_value_grad)
.abs()
.max()
.into_scalar();
assert!(query_diff < 2.0e-3, "query grad diff {query_diff}");
assert!(key_diff < 2.0e-3, "key grad diff {key_diff}");
assert!(value_diff < 2.0e-3, "value grad diff {value_diff}");
}
/// Even when the selector is marked `require_grad`, no gradient is recorded
/// for it after backpropagating through `try_dsa_attention_forward`.
#[cfg(feature = "cuda")]
#[test]
fn cuda_custom_backward_detaches_selector_graph() {
type AdBackend = Autodiff<DsaCudaKernelBackend>;
let device = Default::default();
let query =
Tensor::<AdBackend, 4>::from_data(deterministic_data([1, 2, 3, 4], 3, -0.35), &device)
.require_grad();
let key =
Tensor::<AdBackend, 4>::from_data(deterministic_data([1, 2, 5, 4], 5, -0.25), &device)
.require_grad();
let value =
Tensor::<AdBackend, 4>::from_data(deterministic_data([1, 2, 5, 4], 7, -0.15), &device)
.require_grad();
let selector =
Tensor::<AdBackend, 3>::from_data(deterministic_data([1, 3, 5], 11, -0.1), &device)
.require_grad();
let custom = try_dsa_attention_forward(query, key, value, selector.clone(), 2)
.expect("CUDA DSA custom backward path");
let grads = custom.sum().backward();
assert!(selector.grad(&grads).is_none());
}
/// Compares `forward_with_context` against `forward_with_context_fused_kernel`
/// on the WGPU kernel backend with random inputs (3 query tokens, 5 context
/// tokens, `topk` 2); the largest absolute difference must be below 1e-3.
#[cfg(feature = "wgpu-kernel")]
#[test]
fn wgpu_fused_kernel_matches_reference_block() {
let device = Default::default();
let config = tiny_config(DsaExecutor::Reference);
assert_eq!(
dsa_kernel_status::<DsaWgpuKernelBackend>(DsaExecutor::WgpuKernel),
"available"
);
let block = DsaBlock::<DsaWgpuKernelBackend>::new(8, config, &device).expect("DSA block");
let query = Tensor::<DsaWgpuKernelBackend, 3>::random(
[1, 3, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let context = Tensor::<DsaWgpuKernelBackend, 3>::random(
[1, 5, 8],
Distribution::Normal(0.0, 1.0),
&device,
);
let reference = block.forward_with_context(query.clone(), context.clone());
let fused = block.forward_with_context_fused_kernel(query, context);
let max_diff = reference.sub(fused).abs().max().into_scalar();
assert!(max_diff < 1.0e-3, "max diff {max_diff}");
}
/// On the WGPU autodiff backend, the gradients of `try_dsa_attention_forward`
/// with respect to query, key and value must match the gradients of
/// `projected_reference_attention` to within 2e-3 (largest absolute difference).
#[cfg(feature = "wgpu-kernel")]
#[test]
fn wgpu_custom_backward_matches_reference_graph_projected_attention() {
type AdBackend = Autodiff<DsaWgpuKernelBackend>;
let device = Default::default();
assert_eq!(
dsa_kernel_status::<AdBackend>(DsaExecutor::WgpuKernel),
"available"
);
let query_data = deterministic_data([1, 2, 3, 4], 3, -0.35);
let key_data = deterministic_data([1, 2, 5, 4], 5, -0.25);
let value_data = deterministic_data([1, 2, 5, 4], 7, -0.15);
let selector_data = deterministic_data([1, 3, 5], 11, -0.1);
let topk = 2;
// Custom path: separate leaf tensors so its gradients are tracked independently.
let query_custom =
Tensor::<AdBackend, 4>::from_data(query_data.clone(), &device).require_grad();
let key_custom =
Tensor::<AdBackend, 4>::from_data(key_data.clone(), &device).require_grad();
let value_custom =
Tensor::<AdBackend, 4>::from_data(value_data.clone(), &device).require_grad();
let selector_custom = Tensor::<AdBackend, 3>::from_data(selector_data.clone(), &device);
let custom = try_dsa_attention_forward(
query_custom.clone(),
key_custom.clone(),
value_custom.clone(),
selector_custom,
topk,
)
.expect("WGPU DSA custom backward path");
let custom_grads = custom.sum().backward();
let custom_query_grad = query_custom.grad(&custom_grads).expect("custom query grad");
let custom_key_grad = key_custom.grad(&custom_grads).expect("custom key grad");
let custom_value_grad = value_custom.grad(&custom_grads).expect("custom value grad");
// Reference path: the same data fed through the plain tensor-op graph.
let query_reference = Tensor::<AdBackend, 4>::from_data(query_data, &device).require_grad();
let key_reference = Tensor::<AdBackend, 4>::from_data(key_data, &device).require_grad();
let value_reference = Tensor::<AdBackend, 4>::from_data(value_data, &device).require_grad();
let selector_reference = Tensor::<AdBackend, 3>::from_data(selector_data, &device);
let reference = projected_reference_attention(
query_reference.clone(),
key_reference.clone(),
value_reference.clone(),
selector_reference,
topk,
);
let reference_grads = reference.sum().backward();
let reference_query_grad = query_reference
.grad(&reference_grads)
.expect("reference query grad");
let reference_key_grad = key_reference
.grad(&reference_grads)
.expect("reference key grad");
let reference_value_grad = value_reference
.grad(&reference_grads)
.expect("reference value grad");
let query_diff = custom_query_grad
.sub(reference_query_grad)
.abs()
.max()
.into_scalar();
let key_diff = custom_key_grad
.sub(reference_key_grad)
.abs()
.max()
.into_scalar();
let value_diff = custom_value_grad
.sub(reference_value_grad)
.abs()
.max()
.into_scalar();
assert!(query_diff < 2.0e-3, "query grad diff {query_diff}");
assert!(key_diff < 2.0e-3, "key grad diff {key_diff}");
assert!(value_diff < 2.0e-3, "value grad diff {value_diff}");
}
}