mlx_native/ops/gather_bench.rs
1//! Gather throughput microbenchmark dispatch.
2//!
3//! Provides two kernels for measuring KV cache read throughput:
4//!
5//! * `gather_bench_nibble` — Simulates TurboQuant SDPA: unpack 4-bit nibble
6//! indices then gather from a 16-entry centroid table.
7//! * `gather_bench_f16_seq` — Baseline: sequential F16 read + widen to F32.
8//!
9//! The throughput ratio between the two kernels determines whether nibble-gather
10//! meets the ADR-007 gate of ≥ 50% of sequential F16 throughput.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19use super::encode_helpers::{encode_with_args, KernelArg};
20
/// MSL source for the gather benchmark kernels (embedded at compile time).
///
/// A single source file provides both `gather_bench_nibble` and
/// `gather_bench_f16_seq`; [`register`] registers it under each function name.
pub static GATHER_BENCH_SHADER_SOURCE: &str = include_str!("../shaders/gather_bench.metal");
23
24/// Register both gather benchmark kernels with the given registry.
25///
26/// Each Metal kernel function name must be registered individually so that
27/// `KernelRegistry::get_pipeline` can look up the source by function name.
28pub fn register(registry: &mut KernelRegistry) {
29 registry.register_source("gather_bench_nibble", GATHER_BENCH_SHADER_SOURCE);
30 registry.register_source("gather_bench_f16_seq", GATHER_BENCH_SHADER_SOURCE);
31}
32
33/// Dispatch the nibble-gather kernel.
34///
35/// Reads nibble-packed 4-bit indices and gathers from a centroid table,
36/// simulating the TurboQuant SDPA KV-cache read path.
37///
38/// # Arguments
39///
40/// * `encoder` — Command encoder to record the dispatch into.
41/// * `registry` — Kernel registry (must have gather_bench registered).
42/// * `device` — Metal device for pipeline compilation.
43/// * `packed` — Nibble-packed index buffer `[capacity × head_dim/2]` (u8).
44/// * `centroids` — Centroid table buffer `[16 × head_dim]` (f32).
45/// * `out` — Output buffer `[capacity × head_dim]` (f32).
46/// * `capacity` — Number of token positions.
47/// * `head_dim` — Head dimension (must be even).
48///
49/// # Errors
50///
51/// Returns `MlxError::InvalidArgument` if `head_dim` is odd or parameters are
52/// inconsistent with buffer sizes.
53#[allow(clippy::too_many_arguments)]
54pub fn dispatch_gather_nibble(
55 encoder: &mut CommandEncoder,
56 registry: &mut KernelRegistry,
57 device: &metal::DeviceRef,
58 packed: &MlxBuffer,
59 centroids: &MlxBuffer,
60 out: &MlxBuffer,
61 capacity: u32,
62 head_dim: u32,
63) -> Result<()> {
64 if capacity == 0 || head_dim == 0 {
65 return Ok(());
66 }
67 if head_dim % 2 != 0 {
68 return Err(MlxError::InvalidArgument(format!(
69 "gather_bench_nibble: head_dim must be even, got {}",
70 head_dim
71 )));
72 }
73
74 let pipeline = registry.get_pipeline("gather_bench_nibble", device)?;
75
76 let capacity_bytes = capacity.to_ne_bytes();
77 let head_dim_bytes = head_dim.to_ne_bytes();
78
79 // Grid: x covers head_dim coordinates, y covers capacity positions.
80 // Threadgroup x = min(256, head_dim), y = 1.
81 let tg_x = std::cmp::min(256, head_dim as u64);
82 encode_with_args(
83 encoder,
84 pipeline,
85 &[
86 (0, KernelArg::Buffer(packed)),
87 (1, KernelArg::Buffer(centroids)),
88 (2, KernelArg::Bytes(&capacity_bytes)),
89 (3, KernelArg::Bytes(&head_dim_bytes)),
90 (4, KernelArg::Buffer(out)),
91 ],
92 MTLSize::new(head_dim as u64, capacity as u64, 1),
93 MTLSize::new(tg_x, 1, 1),
94 );
95
96 Ok(())
97}
98
99/// Dispatch the sequential F16 read kernel.
100///
101/// Reads every F16 element of the KV cache and widens it to F32, providing a
102/// throughput baseline against which `gather_bench_nibble` is compared.
103///
104/// # Arguments
105///
106/// * `encoder` — Command encoder to record the dispatch into.
107/// * `registry` — Kernel registry (must have gather_bench registered).
108/// * `device` — Metal device for pipeline compilation.
109/// * `cache` — F16 KV cache buffer `[capacity × head_dim]` (half / u16).
110/// * `out` — Output buffer `[capacity × head_dim]` (f32).
111/// * `capacity` — Number of token positions.
112/// * `head_dim` — Head dimension.
113///
114/// # Errors
115///
116/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
117pub fn dispatch_gather_f16_seq(
118 encoder: &mut CommandEncoder,
119 registry: &mut KernelRegistry,
120 device: &metal::DeviceRef,
121 cache: &MlxBuffer,
122 out: &MlxBuffer,
123 capacity: u32,
124 head_dim: u32,
125) -> Result<()> {
126 if capacity == 0 || head_dim == 0 {
127 return Ok(());
128 }
129
130 let pipeline = registry.get_pipeline("gather_bench_f16_seq", device)?;
131
132 let capacity_bytes = capacity.to_ne_bytes();
133 let head_dim_bytes = head_dim.to_ne_bytes();
134
135 let tg_x = std::cmp::min(256, head_dim as u64);
136 encode_with_args(
137 encoder,
138 pipeline,
139 &[
140 (0, KernelArg::Buffer(cache)),
141 (1, KernelArg::Bytes(&capacity_bytes)),
142 (2, KernelArg::Bytes(&head_dim_bytes)),
143 (3, KernelArg::Buffer(out)),
144 ],
145 MTLSize::new(head_dim as u64, capacity as u64, 1),
146 MTLSize::new(tg_x, 1, 1),
147 );
148
149 Ok(())
150}