mlx_native/ops/repeat_tiled.rs
1//! GPU-accelerated tiled-GQA broadcast: `[T, Hg, K]` → `[T, H, K]` F32.
2//!
3//! Replaces the hf2q-side CPU triple-loop tiled-replicate at
4//! `gpu_delta_net.rs:893-940` (`q_expanded` / `k_expanded` fill,
5//! ~497 ms / 10.4 ms-per-layer at PP4106 per the W-5b.17 audit).
6//!
7//! Mapping:
8//!
9//! ```text
10//! dst[t, h, k] = src[t, h % Hg, k]
11//! ```
12//!
13//! Where `Hg = n_k_heads`, `H = n_v_heads`, `K = head_dim`. The "tiled"
14//! variant matches Qwen3.6 GGUF tensor layout (per
15//! `project_qwen36_gqa_tiled_vs_block` and `gpu_delta_net.rs:834-866`),
16//! and is the same convention as llama.cpp's `ggml_repeat_4d` graph op.
17//!
18//! ADR-005 W-5b.19 (2026-04-27): single-dispatch GPU broadcast eliminates
19//! the chunk-wrapper's CPU memcpy bucket. Production caller:
20//! `hf2q::inference::models::qwen35::gpu_delta_net::apply_gated_delta_net_chunk`
21//! (chunk-prefill GQA pre-expansion).
22
23use metal::MTLSize;
24
25use crate::buffer::MlxBuffer;
26use crate::encoder::CommandEncoder;
27use crate::error::{MlxError, Result};
28use crate::kernel_registry::KernelRegistry;
29
30use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
31
/// MSL source for the tiled-repeat kernel (embedded at compile time).
///
/// Compiled lazily into a pipeline by `KernelRegistry::get_pipeline`; the
/// function name inside must be `repeat_tiled_f32` to match registration.
pub static REPEAT_TILED_SHADER_SOURCE: &str =
    include_str!("../shaders/repeat_tiled.metal");
35
36/// Register the repeat-tiled shader source with the given kernel registry.
37///
38/// Idempotent — the source is also auto-registered by `KernelRegistry::new`,
39/// but this helper exists to mirror the convention used by other op modules
40/// (`copy::register`, `flash_attn_prefill::register`, ...).
41pub fn register(registry: &mut KernelRegistry) {
42 registry.register_source("repeat_tiled_f32", REPEAT_TILED_SHADER_SOURCE);
43}
44
/// MSL-compatible params struct. Must match `RepeatTiledParams` in
/// `repeat_tiled.metal`.
///
/// `#[repr(C)]` pins the field order and layout so the bytes passed via
/// `KernelArg::Bytes` line up with the shader-side struct, and the
/// `bytemuck::Pod` derive is what allows `as_bytes` to reinterpret the
/// value as a raw byte slice without `unsafe` at the call site.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRepeatTiledParams {
    /// Number of tokens (T).
    seq: u32,
    /// Source head count (Hg).
    hg: u32,
    /// Destination head count (H).
    h: u32,
    /// Per-head element count (K).
    k: u32,
}
55
56/// Parameters for a tiled-GQA broadcast operation.
57#[derive(Clone, Copy, Debug)]
58pub struct RepeatTiledParams {
59 /// Number of tokens (T).
60 pub seq: u32,
61 /// Source head count (Hg = n_k_heads).
62 pub hg: u32,
63 /// Destination head count (H = n_v_heads). Must satisfy `H % Hg == 0`.
64 pub h: u32,
65 /// Per-head element count (K = head_dim).
66 pub k: u32,
67}
68
69/// Dispatch a tiled-GQA broadcast on the GPU.
70///
71/// Expands a `[seq, hg, k]` f32 input to a `[seq, h, k]` f32 output via
72/// `dst[t, h, k] = src[t, h % hg, k]` in a single dispatch — no compute,
73/// no host round-trip.
74///
75/// # Arguments
76///
77/// * `encoder` - Command encoder to record the dispatch into.
78/// * `registry` - Kernel registry (`repeat_tiled_f32` is auto-registered).
79/// * `device` - Metal device for pipeline compilation.
80/// * `src` - Input buffer, f32, contiguous, ≥ `seq*hg*k` elements.
81/// * `dst` - Output buffer, f32, contiguous, ≥ `seq*h*k` elements.
82/// * `params` - Shape parameters.
83///
84/// # Errors
85///
86/// Returns `MlxError::InvalidArgument` if any dimension is zero, if
87/// `h % hg != 0`, or if either buffer is too small for the declared shapes.
88pub fn dispatch_repeat_tiled_f32(
89 encoder: &mut CommandEncoder,
90 registry: &mut KernelRegistry,
91 device: &metal::DeviceRef,
92 src: &MlxBuffer,
93 dst: &MlxBuffer,
94 params: &RepeatTiledParams,
95) -> Result<()> {
96 if params.seq == 0 || params.hg == 0 || params.h == 0 || params.k == 0 {
97 return Err(MlxError::InvalidArgument(
98 "repeat_tiled_f32: seq, hg, h, k must all be > 0".into(),
99 ));
100 }
101 if params.h % params.hg != 0 {
102 return Err(MlxError::InvalidArgument(format!(
103 "repeat_tiled_f32: h ({}) must be a multiple of hg ({})",
104 params.h, params.hg
105 )));
106 }
107
108 // Buffer-size sanity checks (in bytes; f32 = 4 B).
109 let src_elems = (params.seq as usize)
110 .checked_mul(params.hg as usize)
111 .and_then(|v| v.checked_mul(params.k as usize))
112 .ok_or_else(|| {
113 MlxError::InvalidArgument(
114 "repeat_tiled_f32: seq*hg*k overflows usize".into(),
115 )
116 })?;
117 let dst_elems = (params.seq as usize)
118 .checked_mul(params.h as usize)
119 .and_then(|v| v.checked_mul(params.k as usize))
120 .ok_or_else(|| {
121 MlxError::InvalidArgument(
122 "repeat_tiled_f32: seq*h*k overflows usize".into(),
123 )
124 })?;
125
126 let src_bytes = src_elems * 4;
127 if src.byte_len() < src_bytes {
128 return Err(MlxError::InvalidArgument(format!(
129 "repeat_tiled_f32: src buffer too small: need {} bytes, have {}",
130 src_bytes,
131 src.byte_len()
132 )));
133 }
134 let dst_bytes = dst_elems * 4;
135 if dst.byte_len() < dst_bytes {
136 return Err(MlxError::InvalidArgument(format!(
137 "repeat_tiled_f32: dst buffer too small: need {} bytes, have {}",
138 dst_bytes,
139 dst.byte_len()
140 )));
141 }
142
143 let pipeline = registry.get_pipeline("repeat_tiled_f32", device)?;
144
145 let gpu_params = GpuRepeatTiledParams {
146 seq: params.seq,
147 hg: params.hg,
148 h: params.h,
149 k: params.k,
150 };
151
152 // Grid: (K, H, T) — one thread per output element. Threadgroup width
153 // along K dimension (innermost / contiguous in dst write) up to 256.
154 let grid = MTLSize::new(params.k as u64, params.h as u64, params.seq as u64);
155 let tg_x = std::cmp::min(256u64, params.k as u64);
156 let tg = MTLSize::new(tg_x, 1, 1);
157
158 encode_with_args(
159 encoder,
160 pipeline,
161 &[
162 (0, KernelArg::Buffer(src)),
163 (1, KernelArg::Buffer(dst)),
164 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
165 ],
166 grid,
167 tg,
168 );
169
170 Ok(())
171}