1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
//! Gather throughput microbenchmark dispatch.
//!
//! Provides two kernels for measuring KV cache read throughput:
//!
//! * `gather_bench_nibble` — Simulates TurboQuant SDPA: unpack 4-bit nibble
//! indices then gather from a 16-entry centroid table.
//! * `gather_bench_f16_seq` — Baseline: sequential F16 read + widen to F32.
//!
//! The throughput ratio between the two kernels determines whether nibble-gather
//! meets the ADR-007 gate of ≥ 50% of sequential F16 throughput.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_with_args, KernelArg};
/// MSL source for the gather benchmark kernels (embedded at compile time).
///
/// A single source file defines both `gather_bench_nibble` and
/// `gather_bench_f16_seq`; `register` maps each function name to this text.
pub static GATHER_BENCH_SHADER_SOURCE: &str = include_str!("../shaders/gather_bench.metal");
/// Make both gather benchmark kernels available through `registry`.
///
/// Pipelines are looked up by Metal function name in
/// `KernelRegistry::get_pipeline`, so every entry point defined in the shared
/// shader source is registered under its own name. The source text is the same
/// for all of them.
pub fn register(registry: &mut KernelRegistry) {
    // Order matches the kernels' declaration in the module docs.
    const KERNEL_NAMES: [&str; 2] = ["gather_bench_nibble", "gather_bench_f16_seq"];
    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, GATHER_BENCH_SHADER_SOURCE);
    }
}
/// Dispatch the nibble-gather kernel.
///
/// Reads nibble-packed 4-bit indices and gathers from a centroid table,
/// simulating the TurboQuant SDPA KV-cache read path.
///
/// The dispatch grid is `(head_dim, capacity, 1)`: x spans the coordinates of
/// one head and y spans token positions. Threadgroups are `min(256, head_dim)`
/// wide in x and 1 in y and z.
///
/// # Arguments
///
/// * `encoder` — Command encoder to record the dispatch into.
/// * `registry` — Kernel registry (must have gather_bench registered).
/// * `device` — Metal device for pipeline compilation.
/// * `packed` — Nibble-packed index buffer `[capacity × head_dim/2]` (u8),
///   bound at index 0.
/// * `centroids` — Centroid table buffer `[16 × head_dim]` (f32), bound at
///   index 1.
/// * `out` — Output buffer `[capacity × head_dim]` (f32), bound at index 4.
/// * `capacity` — Number of token positions (passed by value at index 2).
/// * `head_dim` — Head dimension, must be even (passed by value at index 3).
///
/// If `capacity` or `head_dim` is zero, nothing is encoded and `Ok(())` is
/// returned.
///
/// # Errors
///
/// * `MlxError::InvalidArgument` if `head_dim` is odd (and non-zero).
/// * Any error returned by `KernelRegistry::get_pipeline` for
///   `"gather_bench_nibble"`.
///
/// Buffer lengths are **not** validated against `capacity` / `head_dim`. The
/// caller is responsible for sizing `packed`, `centroids`, and `out` as
/// described above.
// The argument list mirrors the encoder/registry/device plumbing plus the
// kernel's five bindings. Grouping them into a struct would only obscure the
// binding order.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gather_nibble(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    packed: &MlxBuffer,
    centroids: &MlxBuffer,
    out: &MlxBuffer,
    capacity: u32,
    head_dim: u32,
) -> Result<()> {
    // Empty workload: skip encoding a zero-sized dispatch entirely.
    if capacity == 0 || head_dim == 0 {
        return Ok(());
    }
    // Each packed byte holds two 4-bit indices. An odd head_dim would leave
    // half a byte per row.
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gather_bench_nibble: head_dim must be even, got {}",
            head_dim
        )));
    }

    let pipeline = registry.get_pipeline("gather_bench_nibble", device)?;

    // Scalar constants are passed inline as raw native-endian bytes.
    let capacity_bytes = capacity.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();

    // Grid: x covers head_dim coordinates, y covers capacity positions.
    // Threadgroup x = min(256, head_dim), y = 1.
    let grid_x = u64::from(head_dim);
    let grid_y = u64::from(capacity);
    let tg_x = std::cmp::min(256, grid_x);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(packed)),
            (1, KernelArg::Buffer(centroids)),
            (2, KernelArg::Bytes(&capacity_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Buffer(out)),
        ],
        MTLSize::new(grid_x, grid_y, 1),
        MTLSize::new(tg_x, 1, 1),
    );
    Ok(())
}
/// Dispatch the sequential F16 read kernel.
///
/// Reads every F16 element of the KV cache and widens it to F32, providing a
/// throughput baseline against which `gather_bench_nibble` is compared.
///
/// Uses the same launch geometry as `dispatch_gather_nibble`: a grid of
/// `(head_dim, capacity, 1)` with threadgroups of `min(256, head_dim) × 1 × 1`.
///
/// # Arguments
///
/// * `encoder` — Command encoder to record the dispatch into.
/// * `registry` — Kernel registry (must have gather_bench registered).
/// * `device` — Metal device for pipeline compilation.
/// * `cache` — F16 KV cache buffer `[capacity × head_dim]` (half / u16),
///   bound at index 0.
/// * `out` — Output buffer `[capacity × head_dim]` (f32), bound at index 3.
/// * `capacity` — Number of token positions (passed by value at index 1).
/// * `head_dim` — Head dimension (passed by value at index 2).
///
/// If `capacity` or `head_dim` is zero, nothing is encoded and `Ok(())` is
/// returned.
///
/// # Errors
///
/// Propagates any error returned by `KernelRegistry::get_pipeline` for
/// `"gather_bench_f16_seq"`. No argument validation is performed. In
/// particular, buffer lengths are **not** checked against `capacity` /
/// `head_dim`, so the caller must size `cache` and `out` as described above.
pub fn dispatch_gather_f16_seq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    cache: &MlxBuffer,
    out: &MlxBuffer,
    capacity: u32,
    head_dim: u32,
) -> Result<()> {
    // Empty workload: skip encoding a zero-sized dispatch entirely.
    if capacity == 0 || head_dim == 0 {
        return Ok(());
    }

    let pipeline = registry.get_pipeline("gather_bench_f16_seq", device)?;

    // Scalar constants are passed inline as raw native-endian bytes.
    let capacity_bytes = capacity.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();

    // Grid: x covers head_dim coordinates, y covers capacity positions.
    // Threadgroup x = min(256, head_dim), y = 1.
    let grid_x = u64::from(head_dim);
    let grid_y = u64::from(capacity);
    let tg_x = std::cmp::min(256, grid_x);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(cache)),
            (1, KernelArg::Bytes(&capacity_bytes)),
            (2, KernelArg::Bytes(&head_dim_bytes)),
            (3, KernelArg::Buffer(out)),
        ],
        MTLSize::new(grid_x, grid_y, 1),
        MTLSize::new(tg_x, 1, 1),
    );
    Ok(())
}