// mlx_native/ops/repeat_tiled.rs
//! GPU-accelerated tiled-GQA broadcast: `[T, Hg, K]` → `[T, H, K]` F32.
//!
//! Replaces the hf2q-side CPU triple-loop tiled-replicate at
//! `gpu_delta_net.rs:893-940` (`q_expanded` / `k_expanded` fill,
//! ~497 ms / 10.4 ms-per-layer at PP4106 per the W-5b.17 audit).
//!
//! Mapping:
//!
//! ```text
//! dst[t, h, k] = src[t, h % Hg, k]
//! ```
//!
//! Where `Hg = n_k_heads`, `H = n_v_heads`, `K = head_dim`. The "tiled"
//! variant matches Qwen3.6 GGUF tensor layout (per
//! `project_qwen36_gqa_tiled_vs_block` and `gpu_delta_net.rs:834-866`),
//! and is the same convention as llama.cpp's `ggml_repeat_4d` graph op.
//!
//! ADR-005 W-5b.19 (2026-04-27): single-dispatch GPU broadcast eliminates
//! the chunk-wrapper's CPU memcpy bucket. Production caller:
//! `hf2q::inference::models::qwen35::gpu_delta_net::apply_gated_delta_net_chunk`
//! (chunk-prefill GQA pre-expansion).

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
31
/// MSL source for the tiled-repeat kernel (embedded at compile time).
///
/// Compiled into the binary via `include_str!` and registered under the
/// kernel name `repeat_tiled_f32` (see `register`).
pub static REPEAT_TILED_SHADER_SOURCE: &str =
    include_str!("../shaders/repeat_tiled.metal");
36/// Register the repeat-tiled shader source with the given kernel registry.
37///
38/// Idempotent — the source is also auto-registered by `KernelRegistry::new`,
39/// but this helper exists to mirror the convention used by other op modules
40/// (`copy::register`, `flash_attn_prefill::register`, ...).
41pub fn register(registry: &mut KernelRegistry) {
42    registry.register_source("repeat_tiled_f32", REPEAT_TILED_SHADER_SOURCE);
43}
44
/// MSL-compatible params struct. Must match `RepeatTiledParams` in
/// `repeat_tiled.metal`.
///
/// `#[repr(C)]` plus `bytemuck::Pod` allow the struct to be handed to the
/// shader as a raw byte slice via `as_bytes` (see
/// `dispatch_repeat_tiled_f32`). Field order must stay in sync with the
/// MSL-side definition.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRepeatTiledParams {
    /// Token count (T).
    seq: u32,
    /// Source head count (Hg).
    hg: u32,
    /// Destination head count (H).
    h: u32,
    /// Per-head element count (K).
    k: u32,
}
55
/// Parameters for a tiled-GQA broadcast operation.
///
/// Describes the `[seq, hg, k]` → `[seq, h, k]` f32 expansion performed by
/// `dispatch_repeat_tiled_f32`.
#[derive(Clone, Copy, Debug)]
pub struct RepeatTiledParams {
    /// Number of tokens (T).
    pub seq: u32,
    /// Source head count (Hg = n_k_heads).
    pub hg: u32,
    /// Destination head count (H = n_v_heads). Must satisfy `H % Hg == 0`.
    pub h: u32,
    /// Per-head element count (K = head_dim).
    pub k: u32,
}
68
69/// Dispatch a tiled-GQA broadcast on the GPU.
70///
71/// Expands a `[seq, hg, k]` f32 input to a `[seq, h, k]` f32 output via
72/// `dst[t, h, k] = src[t, h % hg, k]` in a single dispatch — no compute,
73/// no host round-trip.
74///
75/// # Arguments
76///
77/// * `encoder`  - Command encoder to record the dispatch into.
78/// * `registry` - Kernel registry (`repeat_tiled_f32` is auto-registered).
79/// * `device`   - Metal device for pipeline compilation.
80/// * `src`      - Input buffer, f32, contiguous, ≥ `seq*hg*k` elements.
81/// * `dst`      - Output buffer, f32, contiguous, ≥ `seq*h*k` elements.
82/// * `params`   - Shape parameters.
83///
84/// # Errors
85///
86/// Returns `MlxError::InvalidArgument` if any dimension is zero, if
87/// `h % hg != 0`, or if either buffer is too small for the declared shapes.
88pub fn dispatch_repeat_tiled_f32(
89    encoder: &mut CommandEncoder,
90    registry: &mut KernelRegistry,
91    device: &metal::DeviceRef,
92    src: &MlxBuffer,
93    dst: &MlxBuffer,
94    params: &RepeatTiledParams,
95) -> Result<()> {
96    if params.seq == 0 || params.hg == 0 || params.h == 0 || params.k == 0 {
97        return Err(MlxError::InvalidArgument(
98            "repeat_tiled_f32: seq, hg, h, k must all be > 0".into(),
99        ));
100    }
101    if params.h % params.hg != 0 {
102        return Err(MlxError::InvalidArgument(format!(
103            "repeat_tiled_f32: h ({}) must be a multiple of hg ({})",
104            params.h, params.hg
105        )));
106    }
107
108    // Buffer-size sanity checks (in bytes; f32 = 4 B).
109    let src_elems = (params.seq as usize)
110        .checked_mul(params.hg as usize)
111        .and_then(|v| v.checked_mul(params.k as usize))
112        .ok_or_else(|| {
113            MlxError::InvalidArgument(
114                "repeat_tiled_f32: seq*hg*k overflows usize".into(),
115            )
116        })?;
117    let dst_elems = (params.seq as usize)
118        .checked_mul(params.h as usize)
119        .and_then(|v| v.checked_mul(params.k as usize))
120        .ok_or_else(|| {
121            MlxError::InvalidArgument(
122                "repeat_tiled_f32: seq*h*k overflows usize".into(),
123            )
124        })?;
125
126    let src_bytes = src_elems * 4;
127    if src.byte_len() < src_bytes {
128        return Err(MlxError::InvalidArgument(format!(
129            "repeat_tiled_f32: src buffer too small: need {} bytes, have {}",
130            src_bytes,
131            src.byte_len()
132        )));
133    }
134    let dst_bytes = dst_elems * 4;
135    if dst.byte_len() < dst_bytes {
136        return Err(MlxError::InvalidArgument(format!(
137            "repeat_tiled_f32: dst buffer too small: need {} bytes, have {}",
138            dst_bytes,
139            dst.byte_len()
140        )));
141    }
142
143    let pipeline = registry.get_pipeline("repeat_tiled_f32", device)?;
144
145    let gpu_params = GpuRepeatTiledParams {
146        seq: params.seq,
147        hg: params.hg,
148        h: params.h,
149        k: params.k,
150    };
151
152    // Grid: (K, H, T) — one thread per output element. Threadgroup width
153    // along K dimension (innermost / contiguous in dst write) up to 256.
154    let grid = MTLSize::new(params.k as u64, params.h as u64, params.seq as u64);
155    let tg_x = std::cmp::min(256u64, params.k as u64);
156    let tg = MTLSize::new(tg_x, 1, 1);
157
158    encode_with_args(
159        encoder,
160        pipeline,
161        &[
162            (0, KernelArg::Buffer(src)),
163            (1, KernelArg::Buffer(dst)),
164            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
165        ],
166        grid,
167        tg,
168    );
169
170    Ok(())
171}