1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
//! GPU-accelerated split of a fused QKV tensor into separate Q/K/V outputs.
//!
//! Input layout (per token, contiguous f32):
//!
//! ```text
//! qkv[t, :] = [ Q (q_sp) | K (k_sp) | V (v_sp) ] (length = qkv_ch)
//! ```
//!
//! Where `q_sp = n_k_heads * d_k`, `k_sp = n_k_heads * d_k`, and
//! `v_sp = n_v_heads * d_v`. The kernel writes each input element to exactly
//! one of `{q, k, v}` in a single dispatch — replacing the prior CPU
//! download → triple-loop split → 3× upload round-trip used by the qwen35
//! Gated DeltaNet prefill path.
//!
//! ADR-005 W-5b.18 (2026-04-27): targets the 838 ms / 17.5 ms-per-layer
//! `layer.qkv_deinterleave` bucket in `hf2q::gpu_delta_net`.
//!
//! Production caller: `hf2q::inference::models::qwen35::gpu_delta_net::
//! apply_proj` (prefill seq>1 branch).
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// MSL source for the QKV-split kernel.
///
/// Embedded into the binary at compile time via `include_str!`; the path
/// `../shaders/qkv_split.metal` is resolved relative to this source file.
/// [`register`] hands this string to a `KernelRegistry` under the
/// `"qkv_split_f32"` key.
pub static QKV_SPLIT_SHADER_SOURCE: &str = include_str!("../shaders/qkv_split.metal");
/// Add the QKV-split MSL source to `registry` under the `qkv_split_f32` key.
///
/// `KernelRegistry::new` already auto-registers this source, so calling this
/// function is idempotent. It is kept so that this module exposes the same
/// `register` entry point as the other op modules (`copy::register`,
/// `flash_attn_prefill::register`, ...).
pub fn register(registry: &mut KernelRegistry) {
    // Key under which the dispatcher later looks the pipeline up.
    const KERNEL_NAME: &str = "qkv_split_f32";

    registry.register_source(KERNEL_NAME, QKV_SPLIT_SHADER_SOURCE);
}
/// MSL-compatible params struct for the QKV split kernel.
///
/// Must match `QkvSplitParams` in `qkv_split.metal`. The struct is
/// `#[repr(C)]` and is handed to the kernel as raw bytes (argument index 4 of
/// the dispatch), so the field order and the `u32` field types are part of
/// the host/GPU contract — do not reorder or resize fields without updating
/// the shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuQkvSplitParams {
// Number of tokens; copied from the public `QkvSplitParams::seq`.
seq: u32,
// Q span per token, in f32 elements; copied from `QkvSplitParams::q_sp`.
q_sp: u32,
// K span per token, in f32 elements; copied from `QkvSplitParams::k_sp`.
k_sp: u32,
// V span per token, in f32 elements; copied from `QkvSplitParams::v_sp`.
v_sp: u32,
// Fused width per token, `q_sp + k_sp + v_sp`; computed (overflow-checked)
// by `dispatch_qkv_split_f32` rather than supplied by the caller.
qkv_ch: u32,
}
/// Parameters for a fused-QKV split operation.
///
/// Describes the shape of the fused input `[seq, q_sp + k_sp + v_sp]` and of
/// the three outputs `[seq, q_sp]`, `[seq, k_sp]`, `[seq, v_sp]`. All spans
/// are counted in f32 elements, not bytes. `dispatch_qkv_split_f32` rejects
/// any field equal to zero and any span sum that does not fit in `u32`.
#[derive(Clone, Copy, Debug)]
pub struct QkvSplitParams {
/// Number of tokens in the sequence dimension. Must be > 0.
pub seq: u32,
/// Q span per token, in f32 elements (== `n_k_heads * d_k`). Must be > 0.
pub q_sp: u32,
/// K span per token, in f32 elements (== `n_k_heads * d_k`). Must be > 0.
pub k_sp: u32,
/// V span per token, in f32 elements (== `n_v_heads * d_v`). Must be > 0.
pub v_sp: u32,
}
/// Dispatch a fused-QKV split on the GPU.
///
/// Splits a `[seq, q_sp + k_sp + v_sp]` f32 input into three contiguous
/// outputs — `q [seq, q_sp]`, `k [seq, k_sp]`, `v [seq, v_sp]` — in a
/// single dispatch, no compute, no host round-trip.
///
/// The call only records the dispatch into `encoder`; it does not commit or
/// wait for the command buffer.
///
/// # Arguments
///
/// * `encoder`  - Command encoder to record the dispatch into.
/// * `registry` - Kernel registry (`qkv_split_f32` is auto-registered).
/// * `device`   - Metal device for pipeline compilation.
/// * `qkv`      - Input fused-QKV buffer, f32, contiguous.
/// * `q`        - Output Q buffer, f32, contiguous.
/// * `k`        - Output K buffer, f32, contiguous.
/// * `v`        - Output V buffer, f32, contiguous.
/// * `params`   - Shape parameters.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if
///
/// * any of `seq`, `q_sp`, `k_sp`, `v_sp` is zero,
/// * `q_sp + k_sp + v_sp` overflows `u32`,
/// * the byte size of any tensor (`seq * span * 4`) overflows `usize`, or
/// * any buffer is smaller than its declared shape requires.
///
/// Any error returned by `KernelRegistry::get_pipeline` is propagated
/// unchanged.
#[allow(clippy::too_many_arguments)] // mirrors the kernel's buffer bindings 1:1
pub fn dispatch_qkv_split_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    qkv: &MlxBuffer,
    q: &MlxBuffer,
    k: &MlxBuffer,
    v: &MlxBuffer,
    params: &QkvSplitParams,
) -> Result<()> {
    // Byte size of a `[seq, span]` f32 tensor, or `None` if it does not fit
    // in `usize`. Two u32 extents times 4 can exceed 2^64, so plain `*` would
    // panic in debug builds and wrap in release builds — the latter would let
    // an undersized buffer slip past the size guard below.
    fn f32_tensor_bytes(seq: u32, span: u32) -> Option<usize> {
        (seq as usize)
            .checked_mul(span as usize)?
            .checked_mul(std::mem::size_of::<f32>())
    }

    if params.seq == 0 || params.q_sp == 0 || params.k_sp == 0 || params.v_sp == 0 {
        return Err(MlxError::InvalidArgument(
            "qkv_split_f32: seq, q_sp, k_sp, v_sp must all be > 0".into(),
        ));
    }

    // Fused channel count per token; checked so the u32 handed to the kernel
    // is never a wrapped value.
    let qkv_ch = params
        .q_sp
        .checked_add(params.k_sp)
        .and_then(|qk| qk.checked_add(params.v_sp))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "qkv_split_f32: q_sp + k_sp + v_sp overflows u32".into(),
            )
        })?;

    // Buffer-size sanity check shared by the input and the three outputs
    // (all in bytes; f32 = 4 B). `name` is only used in error messages.
    let ensure_capacity = |name: &str, buf: &MlxBuffer, span: u32| -> Result<()> {
        let need = f32_tensor_bytes(params.seq, span).ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "qkv_split_f32: {} byte size overflows usize (seq = {}, span = {})",
                name, params.seq, span
            ))
        })?;
        let have = buf.byte_len();
        if have < need {
            return Err(MlxError::InvalidArgument(format!(
                "qkv_split_f32: {} buffer too small: need {} bytes, have {}",
                name, need, have
            )));
        }
        Ok(())
    };
    ensure_capacity("qkv", qkv, qkv_ch)?;
    ensure_capacity("q", q, params.q_sp)?;
    ensure_capacity("k", k, params.k_sp)?;
    ensure_capacity("v", v, params.v_sp)?;

    let pipeline = registry.get_pipeline("qkv_split_f32", device)?;

    let gpu_params = GpuQkvSplitParams {
        seq: params.seq,
        q_sp: params.q_sp,
        k_sp: params.k_sp,
        v_sp: params.v_sp,
        qkv_ch,
    };

    // Grid extent is qkv_ch x seq x 1.
    // NOTE(review): presumably one thread per input element, with x = channel
    // and y = token — confirm against `encode_with_args` and the shader.
    let grid = MTLSize::new(qkv_ch as u64, params.seq as u64, 1);
    // Threadgroup width is capped at 256 and never exceeds the grid width;
    // qkv_ch >= 1 here because every span was checked to be non-zero above.
    let tg_x = std::cmp::min(256u64, qkv_ch as u64);
    let tg = MTLSize::new(tg_x, 1, 1);

    // Binding indices 0..=4: fused input, Q/K/V outputs, then the params
    // struct passed by value as raw bytes.
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(qkv)),
            (1, KernelArg::Buffer(q)),
            (2, KernelArg::Buffer(k)),
            (3, KernelArg::Buffer(v)),
            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}