

Function attention_spirv 

pub fn attention_spirv(
    batch_heads: u32,
    seq_q: u32,
    seq_kv: u32,
    head_dim: u32,
    scale: f32,
    causal: bool,
) -> Vec<u32>

Generate an OpenCL SPIR-V compute kernel for scaled dot-product attention and return the module as a vector of 32-bit SPIR-V words.

Each work-item handles one (batch_head, query_position) pair.
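To make the per-work-item computation concrete, here is a CPU reference sketch in which each (bh, i) iteration of the inner loop does what one work-item is assumed to do. It relies on assumptions the docs do not state: Q and O are contiguous row-major [batch_heads][seq_q][head_dim] arrays, K and V are [batch_heads][seq_kv][head_dim], and with causal set, query i attends only to key positions 0..=i. Shape and attention_reference are hypothetical names and are not part of this crate.

```rust
/// Hypothetical shape descriptor mirroring the u32 arguments of attention_spirv.
#[derive(Clone, Copy)]
struct Shape {
    batch_heads: usize,
    seq_q: usize,
    seq_kv: usize,
    head_dim: usize,
}

/// CPU reference for scaled dot-product attention.
/// ASSUMED layout: Q/O are [batch_heads][seq_q][head_dim], K/V are
/// [batch_heads][seq_kv][head_dim], all contiguous and row-major.
/// ASSUMED causal rule: query i sees key positions 0..=i.
fn attention_reference(q: &[f32], k: &[f32], v: &[f32], s: Shape, scale: f32, causal: bool) -> Vec<f32> {
    let d = s.head_dim;
    let mut out = vec![0.0f32; s.batch_heads * s.seq_q * d];
    for bh in 0..s.batch_heads {
        let kv_base = bh * s.seq_kv * d;
        // Each (bh, i) iteration corresponds to one GPU work-item.
        for i in 0..s.seq_q {
            let row = (bh * s.seq_q + i) * d;
            let q_row = &q[row..row + d];
            let n_keys = if causal { (i + 1).min(s.seq_kv) } else { s.seq_kv };

            // scores[j] = scale * dot(q_i, k_j)
            let mut scores: Vec<f32> = (0..n_keys)
                .map(|j| {
                    let k_row = &k[kv_base + j * d..kv_base + (j + 1) * d];
                    scale * q_row.iter().zip(k_row).map(|(a, b)| a * b).sum::<f32>()
                })
                .collect();

            // Numerically stable softmax: subtract the row max before exp.
            let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut denom = 0.0f32;
            for sc in scores.iter_mut() {
                *sc = (*sc - max).exp();
                denom += *sc;
            }

            // o_i = sum_j (scores[j] / denom) * v_j
            let o_row = &mut out[row..row + d];
            for (j, w) in scores.iter().enumerate() {
                let v_row = &v[kv_base + j * d..kv_base + (j + 1) * d];
                for (o, x) in o_row.iter_mut().zip(v_row) {
                    *o += (w / denom) * x;
                }
            }
        }
    }
    out
}
```

A reference like this can be compared element-wise (with a tolerance) against the O buffer read back from the device.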

Kernel parameters:

(CrossWorkgroup float* Q,
 CrossWorkgroup float* K,
 CrossWorkgroup float* V,
 CrossWorkgroup float* O)
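On the host, these four arguments need f32 buffers, and since each work-item handles one (batch_head, query_position) pair, a launch needs batch_heads * seq_q work-items in total. The sketch below computes those sizes under an assumed contiguous row-major layout (Q and O hold batch_heads * seq_q * head_dim floats, K and V hold batch_heads * seq_kv * head_dim floats), which the docs do not specify. LaunchSizes is a hypothetical helper, and whether the kernel expects a 1-D or 2-D range should be checked against the generated module.

```rust
use std::mem::size_of;

/// Hypothetical helper: element counts per kernel argument and total work-items,
/// under an ASSUMED contiguous row-major layout:
/// Q/O: batch_heads * seq_q * head_dim, K/V: batch_heads * seq_kv * head_dim.
#[derive(Debug, PartialEq)]
struct LaunchSizes {
    q_len: usize,
    k_len: usize,
    v_len: usize,
    o_len: usize,
    /// One work-item per (batch_head, query_position) pair.
    work_items: usize,
}

impl LaunchSizes {
    fn new(batch_heads: u32, seq_q: u32, seq_kv: u32, head_dim: u32) -> Self {
        let (bh, sq, skv, d) = (
            batch_heads as usize,
            seq_q as usize,
            seq_kv as usize,
            head_dim as usize,
        );
        LaunchSizes {
            q_len: bh * sq * d,
            k_len: bh * skv * d,
            v_len: bh * skv * d,
            o_len: bh * sq * d,
            work_items: bh * sq,
        }
    }

    /// Bytes to allocate for a CrossWorkgroup float* buffer of `len` elements.
    fn bytes(len: usize) -> usize {
        len * size_of::<f32>()
    }
}
```

For example, batch_heads = 8, seq_q = 128, seq_kv = 256 and head_dim = 64 give 1024 work-items and a 262144-byte output buffer.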

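The dimensions batch_heads, seq_q, seq_kv and head_dim are baked into the module as OpConstant values, so each distinct shape needs its own generated module.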
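The returned Vec<u32> is a SPIR-V binary, so a cheap sanity check before handing it to a driver is to inspect the 5-word module header defined by the SPIR-V specification (magic number 0x07230203, version, generator, ID bound, and a reserved schema word of 0). The sketch below does that and flattens the words to bytes, for example for an API that takes the module as a byte buffer such as OpenCL's clCreateProgramWithIL. The helper names are hypothetical and not part of this crate.

```rust
/// SPIR-V magic number (first word of every module).
const SPIRV_MAGIC: u32 = 0x0723_0203;

/// Hypothetical helper: validate the 5-word SPIR-V header and
/// return the (major, minor) version it declares.
fn check_spirv_header(words: &[u32]) -> Result<(u8, u8), String> {
    if words.len() < 5 {
        return Err(format!("module too short: {} words", words.len()));
    }
    if words[0] != SPIRV_MAGIC {
        return Err(format!("bad magic: {:#010x}", words[0]));
    }
    if words[4] != 0 {
        return Err("reserved schema word must be 0".to_string());
    }
    // Version word, high to low bytes: 0 | major | minor | 0.
    let major = ((words[1] >> 16) & 0xff) as u8;
    let minor = ((words[1] >> 8) & 0xff) as u8;
    Ok((major, minor))
}

/// Hypothetical helper: flatten words to little-endian bytes.
/// SPIR-V consumers detect endianness from the magic number.
fn spirv_to_bytes(words: &[u32]) -> Vec<u8> {
    words.iter().flat_map(|w| w.to_le_bytes()).collect()
}
```

In practice the header check would run on the Vec<u32> returned by attention_spirv before creating the program object.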