# gllm-kernels
Low-level attention kernels for gllm with CUDA/ROCm support.
## Features
- FlashAttention: Memory-efficient attention with O(N) memory complexity (see the sketch after this list)
- Hierarchical Attention: Multi-level attention for ultra-long contexts (2M+ tokens)
- CUDA Kernels: Native CUDA implementation with PTX for NVIDIA GPUs
- ROCm/HIP Kernels: AMD GPU support (experimental)
- Multiple Backends: CPU (ndarray), CUDA, WebGPU via Burn
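The O(N) memory claim comes from the streaming (online-softmax) formulation of attention: scores are folded into running statistics instead of materializing the full N×N score matrix. The sketch below is a plain-Rust illustration of that idea for a single query; it is not the crate's kernel code.

```rust
// Minimal CPU sketch of online-softmax attention for one query.
// Only running statistics (max, denominator, weighted value sum) are kept,
// so memory is O(1) per query and O(N) overall instead of O(N^2).
fn streaming_attention(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut running_max = f32::NEG_INFINITY;      // running max of scores (stability)
    let mut denom = 0.0_f32;                      // running softmax denominator
    let mut acc = vec![0.0_f32; values[0].len()]; // running weighted sum of values

    for (k, v) in keys.iter().zip(values) {
        // Scaled dot-product score for this key
        let score = q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale;

        // Rescale previous partial sums whenever a new maximum appears
        let new_max = running_max.max(score);
        let correction = (running_max - new_max).exp();
        let weight = (score - new_max).exp();

        denom = denom * correction + weight;
        for (a, vi) in acc.iter_mut().zip(v) {
            *a = *a * correction + weight * *vi;
        }
        running_max = new_max;
    }

    acc.into_iter().map(|a| a / denom).collect()
}
```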
## Performance
| Implementation | Time (seq=512) | vs burn_cuda |
|---|---|---|
| cuda_kernel | 21.27ms | 37% faster |
| burn_cuda | 33.83ms | baseline |
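These timings depend on GPU, driver, and sequence length. Assuming the repository ships Cargo benches (an assumption, not verified here), a comparable run would look like:

```sh
# Assumes a bench target exists and the CUDA toolkit is installed.
cargo bench --features cuda-kernel
```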
## Installation

Add to your Cargo.toml:

```toml
[dependencies]
gllm-kernels = "0.1"
```
## Feature Flags

| Feature | Description | Default |
|---|---|---|
| cpu | CPU backend via burn-ndarray | Yes |
| cuda | CUDA backend via burn-cuda | No |
| cuda-kernel | Native CUDA kernels (requires CUDA toolkit) | No |
| wgpu | WebGPU backend | No |
| rocm-kernel | ROCm/HIP kernels (experimental) | No |
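Non-default backends are opt-in. For example, enabling the cuda feature from the table above might look like this in Cargo.toml:

```toml
[dependencies]
gllm-kernels = { version = "0.1", features = ["cuda"] }
```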
## Usage

### Basic FlashAttention

```rust
// Illustrative only: exact type names, paths, and signatures are assumed.
use gllm_kernels::{AttentionConfig, FlashAttention};

// Create the attention module from a configuration
let attention = FlashAttention::new(AttentionConfig::default());

// Forward pass over query/key/value tensors (q, k, v built elsewhere)
let output = attention.forward(q, k, v)?;
```
### CUDA Kernel (Native)

```rust
// Illustrative only: import paths and constructor arguments are assumed.
use std::sync::Arc;
use cudarc::driver::CudaContext; // or the crate's own context type
use gllm_kernels::FlashAttentionKernel;

// Create a CUDA context on device 0 and hand it to the kernel
let ctx: Arc<CudaContext> = CudaContext::new(0)?;
let kernel = FlashAttentionKernel::new(ctx)?;

// Forward pass over query/key/value tensors (q, k, v built elsewhere)
let output = kernel.forward(q, k, v)?;
```
### Deterministic Mode

For reproducible results in ultra-long context scenarios:

```rust
// Illustrative only: field names and Default impls are assumed.
use gllm_kernels::{AttentionConfig, DeterministicConfig};

let config = AttentionConfig {
    deterministic: Some(DeterministicConfig::default()),
    ..AttentionConfig::default()
};
```
## Architecture

```text
gllm-kernels
├── ops/
│   ├── flash_attention.rs      # HierarchicalFlashAttention
│   ├── flash_attention_v3.rs   # Advanced attention variants
│   ├── paged_attention.rs      # KV cache paging
│   ├── ring_attention.rs       # Distributed attention
│   ├── sparse_attention.rs     # Sparse patterns
│   ├── mla.rs                  # Multi-head Latent Attention
│   ├── mamba.rs                # State space models
│   └── kv_compression.rs       # KV cache compression
├── cuda_kernels/
│   ├── flash_attn.rs           # CUDA kernel bindings
│   └── kernels/
│       ├── tiled_attention.cu  # CUDA source
│       └── tiled_attention.ptx # Compiled PTX (sm_61)
├── hip_kernels/                # ROCm/HIP (experimental)
└── comm/                       # Distributed communication
```
## Building CUDA Kernels

If you need to recompile the PTX for a different GPU architecture, a typical nvcc invocation looks like this (the exact flags the project uses may differ):
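```sh
# Run from the directory containing tiled_attention.cu (adjust paths as needed).
nvcc --ptx -arch=sm_XX tiled_attention.cu -o tiled_attention.ptx
```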
Replace sm_XX with your GPU's compute capability (e.g., sm_61 for GTX 1060, sm_86 for RTX 3090).
Alternatively, the kernels can be pointed at a custom PTX file through an environment variable; the variable name below is a placeholder, so check the crate's documentation or build script for the real one:
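```sh
# Hypothetical variable name, shown only to illustrate the pattern.
export GLLM_KERNELS_PTX_PATH=/path/to/tiled_attention.ptx
```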
## License
Apache-2.0