1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
//! Dense bf16 × f32 → f32 matmul using Apple M3+ tensor cores
//! (`mpp::tensor_ops::matmul2d`).
//!
//! Mirrors the API shape of `quantized_matmul_ggml::quantized_matmul_ggml`
//! (M-N-K with batch broadcasting via r2/r3) but operates on dense bf16
//! weights instead of GGML block-quantized weights. Used by hf2q's
//! non-flash-attention prefill path for Q@K^T and scores@V, matching
//! llama.cpp's `ggml_mul_mat` dispatch when `-fa 0`.
//!
//! Derived from llama.cpp (MIT). See `src/shaders/dense_mm_bf16_tensor.metal`.
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Host-side shape description for `dense_matmul_bf16_f32_tensor`.
///
/// All fields are element counts, not byte sizes. The dispatch function
/// rejects these parameters with `MlxError::InvalidArgument` when any of
/// `m`, `n`, `k` or either batch count is zero, when `k < 32`, or when
/// `src1_batch` is not a multiple of `src0_batch`.
#[derive(Debug, Clone, Copy)]
pub struct DenseMmBf16F32Params {
/// M — number of src1 rows (= output rows per batch).
pub m: u32,
/// N — number of src0 rows (= output cols per batch).
pub n: u32,
/// K — contraction dimension, shared by src0 and src1. Must be at
/// least 32.
pub k: u32,
/// src0 batch count (e.g. nkv for attention GQA). Every batch slice
/// is laid out contiguously as `[n, k]` bf16 row-major.
pub src0_batch: u32,
/// src1 batch count (e.g. nh for attention). Every slice is
/// `[m, k]` f32 row-major. Must be an integer multiple of
/// `src0_batch` — the kernel broadcasts each src0 slice across
/// `src1_batch / src0_batch` consecutive src1 slices (GQA head
/// broadcast).
pub src1_batch: u32,
}
/// GPU-side params struct; matches `DenseMmBf16F32TensorParams` in
/// `shaders/dense_mm_bf16_tensor.metal` byte-for-byte.
///
/// The struct is `#[repr(C)]`, so field order and the explicit `_pad*`
/// members define the binary layout; do not reorder or resize fields
/// without changing the shader-side struct in lockstep. The padding is
/// spelled out explicitly so the struct contains no implicit padding
/// bytes, which the `bytemuck::Pod` derive requires. Total size is
/// 88 bytes.
///
/// Field names follow the ggml `ne*` (element count) / `nb*` (byte
/// stride) convention; suffix `0x` refers to src0, `1x` to src1.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct DenseMmBf16F32TensorGpuParams {
ne00: i32, // K (contract dim)
ne02: i32, // src0 batch count
nb01: u64, // src0 row stride (bytes)
nb02: u64, // src0 batch stride (bytes)
nb03: u64, // unused; the host wrapper always writes 0
ne12: i32, // src1 batch count
_pad0: u32, // explicit padding so `nb10` starts on an 8-byte boundary
nb10: u64, // src1 element stride (bytes) = size of f32 = 4
nb11: u64, // src1 row stride (bytes)
nb12: u64, // src1 batch stride (bytes)
nb13: u64, // unused; the host wrapper always writes 0
ne0: i32, // N (output cols = src0 rows)
ne1: i32, // M (output rows = src1 rows)
r2: i16, // ne12 / ne02 (GQA head broadcast factor)
r3: i16, // second broadcast factor; the host wrapper always writes 1
_pad1: u32, // explicit tail padding: rounds the size up to a multiple of 8
}
/// Reports whether a value of `HF2Q_LARGE_TILE_MM` opts into the V2
/// large-tile kernel.
///
/// Truthy spellings are `1`, `true` and `yes`, compared ASCII
/// case-insensitively. Every other value (including the empty string)
/// selects the V1 kernel.
fn dense_mm_large_tile_flag_is_truthy(value: &str) -> bool {
    ["1", "true", "yes"]
        .iter()
        .any(|truthy| value.eq_ignore_ascii_case(truthy))
}

/// Byte size of a dense `[batch, rows, cols]` tensor with `elem_size`-byte
/// elements, or `InvalidArgument` if the product does not fit in `usize`.
fn dense_mm_checked_bytes(
    operand: &str,
    batch: u32,
    rows: u32,
    cols: u32,
    elem_size: usize,
) -> Result<usize> {
    (batch as usize)
        .checked_mul(rows as usize)
        .and_then(|elems| elems.checked_mul(cols as usize))
        .and_then(|elems| elems.checked_mul(elem_size))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "dense_matmul_bf16_f32_tensor: {} byte size overflows usize for [{}×{}×{}]",
                operand, batch, rows, cols
            ))
        })
}

/// Converts a host-side `u32` dimension into the `i32` used by the GPU
/// params struct, rejecting values that would wrap negative.
fn dense_mm_dim_to_i32(name: &str, value: u32) -> Result<i32> {
    if value > i32::MAX as u32 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: {} ({}) exceeds i32::MAX",
            name, value
        )));
    }
    Ok(value as i32)
}

/// Dense bf16 × f32 → f32 matmul, tensor-API path.
///
/// Computes `output[b, m, n] = sum_k src0[b/r2, n, k] * src1[b, m, k]`
/// for every `b` in `0..src1_batch`, where `r2 = src1_batch / src0_batch`.
/// Implements llama.cpp's `kernel_mul_mm_bf16_f32` contract on the
/// tensor-core path.
///
/// Buffer layout contract (the caller is responsible for the element
/// types; this function only validates byte lengths):
/// - `src0`: bf16 `[src0_batch, n, k]` row-major.
/// - `src1`: f32 `[src1_batch, m, k]` row-major.
/// - `dst`: f32 `[src1_batch, m, n]` row-major (output).
///
/// Kernel selection: when the environment variable `HF2Q_LARGE_TILE_MM`
/// is set to `1`, `true` or `yes` (ASCII case-insensitive) the V2
/// large-tile pipeline is dispatched; otherwise the V1 pipeline is used.
///
/// The function only encodes the dispatch on `encoder`; it does not
/// commit or wait for the work.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when:
/// - any of M, N, K or either batch count is zero;
/// - `k < 32` (kernel requires at least one NK=32 tile);
/// - `src1_batch` is not a multiple of `src0_batch`;
/// - a dimension does not fit the kernel's `i32` fields, or the
///   broadcast factor does not fit its `i16` field;
/// - an expected buffer size overflows `usize`, or a buffer is smaller
///   than the size implied by `params`.
///
/// Errors from `KernelRegistry::get_pipeline` are propagated unchanged.
pub fn dense_matmul_bf16_f32_tensor(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    src0: &MlxBuffer,
    src1: &MlxBuffer,
    dst: &MlxBuffer,
    params: &DenseMmBf16F32Params,
) -> Result<()> {
    // ---- Shape validation -------------------------------------------------
    if params.m == 0 || params.n == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_matmul_bf16_f32_tensor: M, N, K must all be > 0".into(),
        ));
    }
    if params.k < 32 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: K ({}) must be >= 32",
            params.k
        )));
    }
    if params.src0_batch == 0 || params.src1_batch == 0 {
        return Err(MlxError::InvalidArgument(
            "dense_matmul_bf16_f32_tensor: batch counts must be > 0".into(),
        ));
    }
    if params.src1_batch % params.src0_batch != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: src1_batch ({}) must be a \
             multiple of src0_batch ({}) for GQA broadcast",
            params.src1_batch, params.src0_batch
        )));
    }

    // ---- Range validation for the narrower GPU-side fields ---------------
    // The kernel params carry dimensions as i32 and the broadcast factor as
    // i16; a plain `as` cast would silently wrap for oversized inputs.
    let ne00 = dense_mm_dim_to_i32("K", params.k)?;
    let ne0 = dense_mm_dim_to_i32("N", params.n)?;
    let ne1 = dense_mm_dim_to_i32("M", params.m)?;
    let ne02 = dense_mm_dim_to_i32("src0_batch", params.src0_batch)?;
    let ne12 = dense_mm_dim_to_i32("src1_batch", params.src1_batch)?;
    let broadcast = params.src1_batch / params.src0_batch;
    if broadcast > i16::MAX as u32 {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: broadcast factor src1_batch / src0_batch ({}) \
             exceeds i16::MAX",
            broadcast
        )));
    }
    let r2 = broadcast as i16;

    // ---- Buffer-size validation (overflow-checked) ------------------------
    let bf16_sz = DType::BF16.size_of();
    let f32_sz = DType::F32.size_of();

    let expected_src0_bytes =
        dense_mm_checked_bytes("src0", params.src0_batch, params.n, params.k, bf16_sz)?;
    if src0.byte_len() < expected_src0_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: src0 too small: expected {} bytes for \
             [{}×{}×{}] bf16, got {}",
            expected_src0_bytes,
            params.src0_batch,
            params.n,
            params.k,
            src0.byte_len()
        )));
    }

    let expected_src1_bytes =
        dense_mm_checked_bytes("src1", params.src1_batch, params.m, params.k, f32_sz)?;
    if src1.byte_len() < expected_src1_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: src1 too small: expected {} bytes for \
             [{}×{}×{}] f32, got {}",
            expected_src1_bytes,
            params.src1_batch,
            params.m,
            params.k,
            src1.byte_len()
        )));
    }

    let expected_dst_bytes =
        dense_mm_checked_bytes("dst", params.src1_batch, params.m, params.n, f32_sz)?;
    if dst.byte_len() < expected_dst_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "dense_matmul_bf16_f32_tensor: dst too small: expected {} bytes for \
             [{}×{}×{}] f32, got {}",
            expected_dst_bytes,
            params.src1_batch,
            params.m,
            params.n,
            dst.byte_len()
        )));
    }

    // ---- Kernel variant selection -----------------------------------------
    // ADR-029 iter-80 H60: V2 large-tile (NRA=64, NRB=128) variant
    // env-gated by HF2Q_LARGE_TILE_MM. Default OFF until coherence +
    // thermal-fair bench parity proven. Treated as a fan-out shim: V1
    // pipeline + grid (NR0=64, NR1=32) when off, V2 pipeline + grid
    // (NRA=64, NRB=128) when on. An unset or non-Unicode variable → V1.
    let use_v2_large_tile = std::env::var("HF2Q_LARGE_TILE_MM")
        .map(|value| dense_mm_large_tile_flag_is_truthy(&value))
        .unwrap_or(false);
    let kernel_name = if use_v2_large_tile {
        "hf2q_dense_mm_bf16_f32_tensor_v2"
    } else {
        "hf2q_dense_mm_bf16_f32_tensor"
    };
    let pipeline = registry.get_pipeline(kernel_name, device.metal_device())?;

    // ---- GPU params --------------------------------------------------------
    // Byte strides for contiguous row-major operands. These products cannot
    // overflow u64: each is bounded by the corresponding expected_*_bytes
    // value that was computed with checked arithmetic above.
    let nb01 = (params.k as u64) * (bf16_sz as u64); // src0 row
    let nb02 = (params.n as u64) * nb01; // src0 batch
    let nb11 = (params.k as u64) * (f32_sz as u64); // src1 row
    let nb12 = (params.m as u64) * nb11; // src1 batch
    let gpu_params = DenseMmBf16F32TensorGpuParams {
        ne00,
        ne02,
        nb01,
        nb02,
        nb03: 0,
        ne12,
        _pad0: 0,
        nb10: f32_sz as u64,
        nb11,
        nb12,
        nb13: 0,
        ne0,
        ne1,
        r2,
        r3: 1,
        _pad1: 0,
    };

    // ---- Dispatch geometry -------------------------------------------------
    // V1 tile: NR0=64 (M_peer axis = hf2q-N), NR1=32 (N_peer axis = hf2q-M).
    // V2 tile: NRA=64 (M_peer = hf2q-N), NRB=128 (N_peer = hf2q-M).
    // Note hf2q axis swap: ne0 = hf2q-N (M_peer), ne1 = hf2q-M (N_peer);
    // tgpig.y covers M_peer-axis (NRA/NR0), tgpig.x covers N_peer-axis
    // (NRB/NR1). Threads-per-TG = NUM_THREADS = 128 in both (4 simdgroups
    // × 32 lanes). V2 shmem: A-tile only (NRA × NK = 64 × 32 × 2 B =
    // 4096 B), B read direct from device → halved shmem budget vs V1.
    const NR0: u64 = 64;
    const NR1_V1: u64 = 32;
    const NRB_V2: u64 = 128;
    const THREADS_PER_TG: u64 = 128;
    const SHMEM_V1: u64 = 8192;
    const SHMEM_V2: u64 = 4096;
    let (nr1, shmem_bytes) = if use_v2_large_tile {
        (NRB_V2, SHMEM_V2)
    } else {
        (NR1_V1, SHMEM_V1)
    };

    // Grid: (ceil(M/nr1), ceil(N/NR0), src1_batch). M → tgpig.x (covers
    // N_peer = hf2q-M), N → tgpig.y (covers M_peer = hf2q-N), batch → z.
    let threadgroups = metal::MTLSize::new(
        (params.m as u64 + nr1 - 1) / nr1,
        (params.n as u64 + NR0 - 1) / NR0,
        params.src1_batch as u64,
    );
    let threads_per_tg = metal::MTLSize::new(THREADS_PER_TG, 1, 1);

    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(src0)),
            (2, KernelArg::Buffer(src1)),
            (3, KernelArg::Buffer(dst)),
        ],
        &[(0, shmem_bytes)],
        threadgroups,
        threads_per_tg,
    );
    Ok(())
}