//! Variable-length (ragged/packed) attention traits
use crate::Result;
use Runtime;
use Tensor;
/// Variable-length Flash Attention — packed sequences with cu_seqlens indexing
///
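/// Eliminates padding waste by packing sequences of different lengths
/// back-to-back along a single token dimension. For variable-length batches
/// this typically saves 30-50% of memory; the actual saving depends on how
/// uneven the sequence lengths are.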
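///
/// The saving can be estimated directly. Padding stores `batch_size * max_len`
/// token slots, while packing stores only `sum(lengths)`. The sketch below counts
/// token slots only: it assumes Q/K/V/output memory scales linearly with token
/// count and ignores kernel workspace. The helper names are hypothetical and not
/// part of this crate.
///
/// ```rust
/// // Token slots used when every sequence is padded to the longest one.
/// fn padded_tokens(seq_lens: &[usize]) -> usize {
///     let max_len = seq_lens.iter().copied().max().unwrap_or(0);
///     max_len * seq_lens.len()
/// }
///
/// // Token slots used when sequences are packed back-to-back.
/// fn packed_tokens(seq_lens: &[usize]) -> usize {
///     seq_lens.iter().sum()
/// }
///
/// // Fraction of the padded storage that packing avoids (0.0 when nothing is stored).
/// fn packing_savings(seq_lens: &[usize]) -> f64 {
///     let padded = padded_tokens(seq_lens);
///     if padded == 0 {
///         return 0.0;
///     }
///     1.0 - packed_tokens(seq_lens) as f64 / padded as f64
/// }
/// ```
///
/// For lengths `[512, 300, 128]`, padding needs 1536 token slots and packing
/// needs 940, a saving of roughly 39%. Equal lengths save nothing.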
///
/// Supports both MHA (`num_kv_heads == num_heads`) and GQA
/// (`num_kv_heads < num_heads`, where `num_heads % num_kv_heads == 0`).
///
/// # Layout contract
///
/// - `q`: `[total_tokens_q, num_heads, head_dim]` — packed queries
/// - `k`: `[total_tokens_k, num_kv_heads, head_dim]` — packed keys (GQA: fewer heads)
/// - `v`: `[total_tokens_k, num_kv_heads, head_dim]` — packed values (GQA: fewer heads)
/// - `cu_seqlens_q`: `[batch_size + 1]` — cumulative query sequence lengths (I32)
/// - `cu_seqlens_k`: `[batch_size + 1]` — cumulative key sequence lengths (I32)
/// - Output: `[total_tokens_q, num_heads, head_dim]`
/// - Logsumexp: `[total_tokens_q, num_heads]` (F32)
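///
/// A caller could check this contract before launching a kernel. The sketch
/// below is hypothetical: it works on plain shape slices (`&[usize]`) and `i32`
/// offsets rather than this crate's `Tensor` type, and it also enforces the GQA
/// divisibility rule stated above.
///
/// ```rust
/// // Checks one cu_seqlens array: starts at 0, never decreases, ends at `total`.
/// fn check_cu_seqlens(cu: &[i32], total: usize) -> Result<(), String> {
///     if cu.first() != Some(&0) {
///         return Err("cu_seqlens must start at 0".into());
///     }
///     if cu.windows(2).any(|w| w[1] < w[0]) {
///         return Err("cu_seqlens must be non-decreasing".into());
///     }
///     // The checks above guarantee the last entry exists and is >= 0.
///     if *cu.last().unwrap() as usize != total {
///         return Err("last cu_seqlens entry must equal the total token count".into());
///     }
///     Ok(())
/// }
///
/// // Hypothetical pre-launch check of the varlen layout contract.
/// fn check_varlen_layout(
///     q: &[usize],
///     k: &[usize],
///     v: &[usize],
///     cu_seqlens_q: &[i32],
///     cu_seqlens_k: &[i32],
/// ) -> Result<(), String> {
///     if q.len() != 3 || k.len() != 3 {
///         return Err("q and k must be [tokens, heads, head_dim]".into());
///     }
///     if k != v {
///         return Err("k and v must have identical shapes".into());
///     }
///     let (total_q, num_heads, head_dim) = (q[0], q[1], q[2]);
///     let (total_k, num_kv_heads, head_dim_k) = (k[0], k[1], k[2]);
///     if head_dim != head_dim_k {
///         return Err("q and k/v must share head_dim".into());
///     }
///     // MHA is the special case num_kv_heads == num_heads.
///     if num_kv_heads == 0 || num_heads % num_kv_heads != 0 {
///         return Err("num_heads must be a multiple of num_kv_heads".into());
///     }
///     // Both offset arrays describe the same batch: batch_size + 1 entries each.
///     if cu_seqlens_q.len() != cu_seqlens_k.len() {
///         return Err("cu_seqlens_q and cu_seqlens_k lengths differ".into());
///     }
///     check_cu_seqlens(cu_seqlens_q, total_q)?;
///     check_cu_seqlens(cu_seqlens_k, total_k)
/// }
/// ```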
///
/// For MHA, pass `num_kv_heads == num_heads`; the K/V layout is then identical
/// to the old MHA-only contract.
///
/// # GQA key/value head mapping
///
/// `kv_head_idx = q_head_idx / (num_heads / num_kv_heads)`
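///
/// A minimal sketch of this mapping, using a hypothetical helper name. Consecutive
/// query heads form groups of `num_heads / num_kv_heads`, and each group reads
/// one KV head.
///
/// ```rust
/// // Maps a query head to the KV head it reads (hypothetical helper).
/// fn kv_head_for(q_head_idx: usize, num_heads: usize, num_kv_heads: usize) -> usize {
///     assert!(num_kv_heads > 0 && num_heads % num_kv_heads == 0);
///     assert!(q_head_idx < num_heads);
///     // Number of consecutive query heads that share one KV head.
///     let group_size = num_heads / num_kv_heads;
///     q_head_idx / group_size
/// }
/// ```
///
/// With `num_heads = 32` and `num_kv_heads = 8`, the group size is 4. Query heads
/// 0-3 read KV head 0, and heads 28-31 read KV head 7. With
/// `num_kv_heads == num_heads`, the group size is 1 and the mapping is the
/// identity, which is the MHA case.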
///
/// # Cumulative sequence lengths
///
/// `cu_seqlens[0] = 0`; `cu_seqlens[i]` is the sum of the lengths of sequences
/// `0` through `i - 1`.
/// For batch `[512, 300, 128]`: `cu_seqlens = [0, 512, 812, 940]`.
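///
/// The offsets also act as an index into the packed buffer. Sequence `i` owns
/// tokens `cu_seqlens[i]..cu_seqlens[i + 1]`. The following minimal sketch uses
/// hypothetical helper names (not this crate's API). It builds the offsets from
/// per-sequence lengths and maps a packed token back to its sequence with a
/// binary search.
///
/// ```rust
/// use std::ops::Range;
///
/// // Builds cu_seqlens from per-sequence lengths (I32, matching the contract).
/// fn cu_seqlens(lengths: &[i32]) -> Vec<i32> {
///     let mut cu = Vec::with_capacity(lengths.len() + 1);
///     let mut acc = 0;
///     cu.push(acc);
///     for &len in lengths {
///         acc += len;
///         cu.push(acc);
///     }
///     cu
/// }
///
/// // Packed token range of sequence `i`.
/// fn seq_range(cu: &[i32], i: usize) -> Range<usize> {
///     (cu[i] as usize)..(cu[i + 1] as usize)
/// }
///
/// // Sequence that owns packed token `t` (requires t < total tokens).
/// fn seq_of_token(cu: &[i32], t: usize) -> usize {
///     debug_assert!(t < *cu.last().unwrap() as usize);
///     // partition_point counts offsets <= t; the owner is the last of them.
///     cu.partition_point(|&off| (off as usize) <= t) - 1
/// }
/// ```
///
/// For the batch above, sequence 1 occupies packed tokens `512..812`, and token
/// 812 belongs to sequence 2.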