Skip to main content

mlx_native/ops/
gather_bench.rs

1//! Gather throughput microbenchmark dispatch.
2//!
3//! Provides two kernels for measuring KV cache read throughput:
4//!
5//! * `gather_bench_nibble`   — Simulates TurboQuant SDPA: unpack 4-bit nibble
6//!   indices then gather from a 16-entry centroid table.
7//! * `gather_bench_f16_seq`  — Baseline: sequential F16 read + widen to F32.
8//!
9//! The throughput ratio between the two kernels determines whether nibble-gather
10//! meets the ADR-007 gate of ≥ 50% of sequential F16 throughput.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19use super::encode_helpers::{encode_with_args, KernelArg};
20
/// MSL source text shared by the gather benchmark kernels.
///
/// The contents of `../shaders/gather_bench.metal` are embedded into the
/// binary at compile time via `include_str!`, so no file is read at runtime.
pub static GATHER_BENCH_SHADER_SOURCE: &str = include_str!("../shaders/gather_bench.metal");
23
24/// Register both gather benchmark kernels with the given registry.
25///
26/// Each Metal kernel function name must be registered individually so that
27/// `KernelRegistry::get_pipeline` can look up the source by function name.
28pub fn register(registry: &mut KernelRegistry) {
29    registry.register_source("gather_bench_nibble", GATHER_BENCH_SHADER_SOURCE);
30    registry.register_source("gather_bench_f16_seq", GATHER_BENCH_SHADER_SOURCE);
31}
32
33/// Dispatch the nibble-gather kernel.
34///
35/// Reads nibble-packed 4-bit indices and gathers from a centroid table,
36/// simulating the TurboQuant SDPA KV-cache read path.
37///
38/// # Arguments
39///
40/// * `encoder`   — Command encoder to record the dispatch into.
41/// * `registry`  — Kernel registry (must have gather_bench registered).
42/// * `device`    — Metal device for pipeline compilation.
43/// * `packed`    — Nibble-packed index buffer `[capacity × head_dim/2]` (u8).
44/// * `centroids` — Centroid table buffer `[16 × head_dim]` (f32).
45/// * `out`       — Output buffer `[capacity × head_dim]` (f32).
46/// * `capacity`  — Number of token positions.
47/// * `head_dim`  — Head dimension (must be even).
48///
49/// # Errors
50///
51/// Returns `MlxError::InvalidArgument` if `head_dim` is odd or parameters are
52/// inconsistent with buffer sizes.
53#[allow(clippy::too_many_arguments)]
54pub fn dispatch_gather_nibble(
55    encoder: &mut CommandEncoder,
56    registry: &mut KernelRegistry,
57    device: &metal::DeviceRef,
58    packed: &MlxBuffer,
59    centroids: &MlxBuffer,
60    out: &MlxBuffer,
61    capacity: u32,
62    head_dim: u32,
63) -> Result<()> {
64    if capacity == 0 || head_dim == 0 {
65        return Ok(());
66    }
67    if head_dim % 2 != 0 {
68        return Err(MlxError::InvalidArgument(format!(
69            "gather_bench_nibble: head_dim must be even, got {}",
70            head_dim
71        )));
72    }
73
74    let pipeline = registry.get_pipeline("gather_bench_nibble", device)?;
75
76    let capacity_bytes = capacity.to_ne_bytes();
77    let head_dim_bytes = head_dim.to_ne_bytes();
78
79    // Grid: x covers head_dim coordinates, y covers capacity positions.
80    // Threadgroup x = min(256, head_dim), y = 1.
81    let tg_x = std::cmp::min(256, head_dim as u64);
82    encode_with_args(
83        encoder,
84        pipeline,
85        &[
86            (0, KernelArg::Buffer(packed)),
87            (1, KernelArg::Buffer(centroids)),
88            (2, KernelArg::Bytes(&capacity_bytes)),
89            (3, KernelArg::Bytes(&head_dim_bytes)),
90            (4, KernelArg::Buffer(out)),
91        ],
92        MTLSize::new(head_dim as u64, capacity as u64, 1),
93        MTLSize::new(tg_x, 1, 1),
94    );
95
96    Ok(())
97}
98
99/// Dispatch the sequential F16 read kernel.
100///
101/// Reads every F16 element of the KV cache and widens it to F32, providing a
102/// throughput baseline against which `gather_bench_nibble` is compared.
103///
104/// # Arguments
105///
106/// * `encoder`   — Command encoder to record the dispatch into.
107/// * `registry`  — Kernel registry (must have gather_bench registered).
108/// * `device`    — Metal device for pipeline compilation.
109/// * `cache`     — F16 KV cache buffer `[capacity × head_dim]` (half / u16).
110/// * `out`       — Output buffer `[capacity × head_dim]` (f32).
111/// * `capacity`  — Number of token positions.
112/// * `head_dim`  — Head dimension.
113///
114/// # Errors
115///
116/// Returns `MlxError::InvalidArgument` if parameters are inconsistent.
117pub fn dispatch_gather_f16_seq(
118    encoder: &mut CommandEncoder,
119    registry: &mut KernelRegistry,
120    device: &metal::DeviceRef,
121    cache: &MlxBuffer,
122    out: &MlxBuffer,
123    capacity: u32,
124    head_dim: u32,
125) -> Result<()> {
126    if capacity == 0 || head_dim == 0 {
127        return Ok(());
128    }
129
130    let pipeline = registry.get_pipeline("gather_bench_f16_seq", device)?;
131
132    let capacity_bytes = capacity.to_ne_bytes();
133    let head_dim_bytes = head_dim.to_ne_bytes();
134
135    let tg_x = std::cmp::min(256, head_dim as u64);
136    encode_with_args(
137        encoder,
138        pipeline,
139        &[
140            (0, KernelArg::Buffer(cache)),
141            (1, KernelArg::Bytes(&capacity_bytes)),
142            (2, KernelArg::Bytes(&head_dim_bytes)),
143            (3, KernelArg::Buffer(out)),
144        ],
145        MTLSize::new(head_dim as u64, capacity as u64, 1),
146        MTLSize::new(tg_x, 1, 1),
147    );
148
149    Ok(())
150}