Skip to main content

mlx_native/ops/
qkv_split.rs

1//! GPU-accelerated split of a fused QKV tensor into separate Q/K/V outputs.
2//!
3//! Input layout (per token, contiguous f32):
4//!
5//! ```text
6//! qkv[t, :] = [ Q (q_sp) | K (k_sp) | V (v_sp) ]   (length = qkv_ch)
7//! ```
8//!
9//! Where `q_sp = n_k_heads * d_k`, `k_sp = n_k_heads * d_k`, and
10//! `v_sp = n_v_heads * d_v`. The kernel writes each input element to exactly
11//! one of `{q, k, v}` in a single dispatch — replacing the prior CPU
12//! download → triple-loop split → 3× upload round-trip used by the qwen35
13//! Gated DeltaNet prefill path.
14//!
15//! ADR-005 W-5b.18 (2026-04-27): targets the 838 ms / 17.5 ms-per-layer
16//! `layer.qkv_deinterleave` bucket in `hf2q::gpu_delta_net`.
17//!
18//! Production caller: `hf2q::inference::models::qwen35::gpu_delta_net::
19//! apply_proj` (prefill seq>1 branch).
20
21use metal::MTLSize;
22
23use crate::buffer::MlxBuffer;
24use crate::encoder::CommandEncoder;
25use crate::error::{MlxError, Result};
26use crate::kernel_registry::KernelRegistry;
27
28use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
29
30/// MSL source for the QKV-split kernel (embedded at compile time).
31pub static QKV_SPLIT_SHADER_SOURCE: &str = include_str!("../shaders/qkv_split.metal");
32
33/// Register the QKV-split shader source with the given kernel registry.
34///
35/// Idempotent — the source is also auto-registered by `KernelRegistry::new`,
36/// but this helper exists to mirror the convention used by other op modules
37/// (`copy::register`, `flash_attn_prefill::register`, ...).
38pub fn register(registry: &mut KernelRegistry) {
39    registry.register_source("qkv_split_f32", QKV_SPLIT_SHADER_SOURCE);
40}
41
42/// MSL-compatible params struct for the QKV split kernel.
43///
44/// Must match `QkvSplitParams` in `qkv_split.metal`.
45#[repr(C)]
46#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
47struct GpuQkvSplitParams {
48    seq: u32,
49    q_sp: u32,
50    k_sp: u32,
51    v_sp: u32,
52    qkv_ch: u32,
53}
54
/// Shape description for a fused-QKV split.
///
/// Every span is counted in f32 elements per token. The fused input row is
/// `q_sp + k_sp + v_sp` elements wide.
#[derive(Debug, Clone, Copy)]
pub struct QkvSplitParams {
    /// Sequence length: how many tokens (rows) the input holds.
    pub seq: u32,
    /// Width of the Q slice of each row (`n_k_heads * d_k`).
    pub q_sp: u32,
    /// Width of the K slice of each row (`n_k_heads * d_k`).
    pub k_sp: u32,
    /// Width of the V slice of each row (`n_v_heads * d_v`).
    pub v_sp: u32,
}
67
68/// Dispatch a fused-QKV split on the GPU.
69///
70/// Splits a `[seq, q_sp + k_sp + v_sp]` f32 input into three contiguous
71/// outputs — `q [seq, q_sp]`, `k [seq, k_sp]`, `v [seq, v_sp]` — in a
72/// single dispatch, no compute, no host round-trip.
73///
74/// # Arguments
75///
76/// * `encoder`  - Command encoder to record the dispatch into.
77/// * `registry` - Kernel registry (`qkv_split_f32` is auto-registered).
78/// * `device`   - Metal device for pipeline compilation.
79/// * `qkv`      - Input fused-QKV buffer, f32, contiguous.
80/// * `q`        - Output Q buffer, f32, contiguous.
81/// * `k`        - Output K buffer, f32, contiguous.
82/// * `v`        - Output V buffer, f32, contiguous.
83/// * `params`   - Shape parameters.
84///
85/// # Errors
86///
87/// Returns `MlxError::InvalidArgument` if any dimension is zero or any
88/// buffer is too small for the declared shapes.
89#[allow(clippy::too_many_arguments)]
90pub fn dispatch_qkv_split_f32(
91    encoder: &mut CommandEncoder,
92    registry: &mut KernelRegistry,
93    device: &metal::DeviceRef,
94    qkv: &MlxBuffer,
95    q: &MlxBuffer,
96    k: &MlxBuffer,
97    v: &MlxBuffer,
98    params: &QkvSplitParams,
99) -> Result<()> {
100    if params.seq == 0 || params.q_sp == 0 || params.k_sp == 0 || params.v_sp == 0 {
101        return Err(MlxError::InvalidArgument(
102            "qkv_split_f32: seq, q_sp, k_sp, v_sp must all be > 0".into(),
103        ));
104    }
105
106    let qkv_ch = params
107        .q_sp
108        .checked_add(params.k_sp)
109        .and_then(|qk| qk.checked_add(params.v_sp))
110        .ok_or_else(|| {
111            MlxError::InvalidArgument(
112                "qkv_split_f32: q_sp + k_sp + v_sp overflows u32".into(),
113            )
114        })?;
115
116    // Buffer-size sanity checks (all in bytes; f32 = 4 B).
117    let in_bytes = (params.seq as usize) * (qkv_ch as usize) * 4;
118    if qkv.byte_len() < in_bytes {
119        return Err(MlxError::InvalidArgument(format!(
120            "qkv_split_f32: qkv buffer too small: need {} bytes, have {}",
121            in_bytes,
122            qkv.byte_len()
123        )));
124    }
125    let q_bytes = (params.seq as usize) * (params.q_sp as usize) * 4;
126    if q.byte_len() < q_bytes {
127        return Err(MlxError::InvalidArgument(format!(
128            "qkv_split_f32: q buffer too small: need {} bytes, have {}",
129            q_bytes,
130            q.byte_len()
131        )));
132    }
133    let k_bytes = (params.seq as usize) * (params.k_sp as usize) * 4;
134    if k.byte_len() < k_bytes {
135        return Err(MlxError::InvalidArgument(format!(
136            "qkv_split_f32: k buffer too small: need {} bytes, have {}",
137            k_bytes,
138            k.byte_len()
139        )));
140    }
141    let v_bytes = (params.seq as usize) * (params.v_sp as usize) * 4;
142    if v.byte_len() < v_bytes {
143        return Err(MlxError::InvalidArgument(format!(
144            "qkv_split_f32: v buffer too small: need {} bytes, have {}",
145            v_bytes,
146            v.byte_len()
147        )));
148    }
149
150    let pipeline = registry.get_pipeline("qkv_split_f32", device)?;
151
152    let gpu_params = GpuQkvSplitParams {
153        seq: params.seq,
154        q_sp: params.q_sp,
155        k_sp: params.k_sp,
156        v_sp: params.v_sp,
157        qkv_ch,
158    };
159
160    let grid = MTLSize::new(qkv_ch as u64, params.seq as u64, 1);
161    let tg_x = std::cmp::min(256u64, qkv_ch as u64);
162    let tg = MTLSize::new(tg_x, 1, 1);
163
164    encode_with_args(
165        encoder,
166        pipeline,
167        &[
168            (0, KernelArg::Buffer(qkv)),
169            (1, KernelArg::Buffer(q)),
170            (2, KernelArg::Buffer(k)),
171            (3, KernelArg::Buffer(v)),
172            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
173        ],
174        grid,
175        tg,
176    );
177
178    Ok(())
179}