// mlx_native/ops/rope_train.rs
1//! Differentiable Rotary Position Embedding — forward + backward.
2//!
3//! Used by hf2q's DWQ training tape (ADR-022 / flash_attn_train Phase 1a).
4//! This module provides a standalone RoPE op that is its own backward:
5//!
6//! Forward: Q' = RoPE(Q, pos)
7//! Backward: dQ = RoPE(dQ', -pos) (rotation matrix is orthogonal; R(θ)^T = R(-θ))
8//!
9//! # Implementation
10//!
11//! Both the forward and the backward dispatch the SAME Metal kernel —
12//! `rope_multi_bf16` / `rope_multi_f32` from [`super::rope_multi`]. The
13//! backward simply passes a negated copy of the positions buffer.
14//! No new Metal shader is needed.
15//!
16//! # IMROPE convention (Qwen3.5 / Qwen3.6)
17//!
//! `sections = [11, 11, 10, 0]` with mode = IMROPE (40). The frequency base
//! must match the model's GGUF `<prefix>.rope.freq_base` key (the Qwen3.5
//! config uses `rope_theta = 1_000_000` = 1e6 — see `RopeTrainParams::theta_base`).
19//! Positions layout: `int32[4 * seq_len]` — first `seq_len` entries are the
20//! time-axis positions, next `seq_len` the height-axis, then width, then extra.
21//! For text-only inputs all four axes equal the token's 1-D position.
22//!
23//! Pair indexing is NeoX-style: each thread rotates
24//! `(x[p], x[p + head_dim/2])` for pair `p ∈ [0, rope_dim/2)`.
25//! Pairs `p ≥ rope_dim/2` pass through unchanged (partial-rotary tail).
26//!
27//! # References
28//!
29//! - `src/ops/rope_multi.rs` — the underlying dispatch + buffer construction
30//! - `src/shaders/rope_multi.metal` — IMROPE / MROPE / VISION Metal kernel
31//! - `tests/test_rope_multi.rs` — parity oracle (cpu_rope_multi)
32//! - `/opt/hf2q/src/inference/models/qwen35/full_attn.rs:18-19` — production call-sites
33//! - `/opt/hf2q/src/inference/models/qwen35/mod.rs:235` — mrope_section=[11,11,10,0]
34
35use crate::buffer::MlxBuffer;
36use crate::device::MlxDevice;
37use crate::dtypes::DType;
38use crate::encoder::CommandEncoder;
39use crate::error::{MlxError, Result};
40use crate::kernel_registry::KernelRegistry;
41use crate::ops::rope_multi::{
42 build_rope_multi_buffers, dispatch_rope_multi, RopeMultiMode, RopeMultiParams,
43};
44
45// ---------------------------------------------------------------------------
46// Public parameter struct
47// ---------------------------------------------------------------------------
48
/// Shape + frequency parameters for a differentiable RoPE dispatch.
///
/// Non-IMROPE (plain NeoX RoPE) is expressed as `sections = [head_dim/2, 0, 0, 0]`
/// with `mode = Imrope` — all pairs fall into axis 0 (text time-axis) which is
/// the only axis used. Alternatively, callers can use `rope_multi` directly
/// with `mode = Mrope` and `sections = [rope_dim/2, 0, 0, 0]`.
#[derive(Debug, Clone, Copy)]
pub struct RopeTrainParams {
    /// Batch size. Folded into the kernel's row count at dispatch time
    /// (see `to_rope_multi_params`); `pos_buf` must cover `batch * seq_len`
    /// positions per axis.
    pub batch: u32,
    /// Number of query/key heads.
    pub n_heads: u32,
    /// Number of tokens per batch element.
    pub seq_len: u32,
    /// Full head dimension (must be even).
    pub head_dim: u32,
    /// Number of dimensions that participate in rotation (≤ head_dim, even).
    /// Pairs `[rope_dim/2, head_dim/2)` pass through unchanged.
    pub rope_dim: u32,
    /// Base frequency (theta). Qwen3.5/3.6: `1_000_000.0 = 1e6`.
    ///
    /// Note: the metal shader comment in `test_rope_multi.rs` line 347 uses
    /// `1e7`; the Qwen3.5 model config uses `rope_theta = 1_000_000` = 1e6.
    /// The caller MUST pass the value that matches the model's GGUF
    /// `<prefix>.rope.freq_base` key.
    pub theta_base: f32,
    /// Section counts `[s0, s1, s2, s3]` for IMROPE / MROPE.
    ///
    /// Qwen3.5 / Qwen3.6: `[11, 11, 10, 0]` (IMROPE, matches
    /// `/opt/hf2q/src/inference/models/qwen35/mod.rs:235`).
    ///
    /// Sum `s0+s1+s2+s3` should equal `rope_dim / 2` for full rotary-section
    /// coverage. The kernel tolerates sums smaller than `rope_dim/2`
    /// (sectors wrap modulo the sum), but callers should pass the canonical
    /// value from the model config.
    ///
    /// For non-IMROPE plain NeoX: `[rope_dim/2, 0, 0, 0]` with MROPE mode
    /// puts every pair in axis-0 (time).
    pub sections: [u32; 4],
}
87
88// ---------------------------------------------------------------------------
89// Helpers
90// ---------------------------------------------------------------------------
91
/// Translate `RopeTrainParams` into the `RopeMultiParams` needed by
/// [`dispatch_rope_multi`], using IMROPE mode.
///
/// The batch dimension is folded into `seq_len` so one dispatch covers all
/// `batch * seq_len * n_heads` rows; this matches the documented `pos_buf`
/// layout, whose per-axis blocks are `batch * seq_len` entries long.
fn to_rope_multi_params(p: &RopeTrainParams) -> RopeMultiParams {
    RopeMultiParams {
        head_dim: p.head_dim,
        rope_dim: p.rope_dim,
        // rope_multi rows = seq_len * n_heads (batch dim is NOT part of
        // rope_multi; callers must slice or iterate per batch element).
        // Here we fold batch into seq_len (NOT n_heads — the positions are
        // indexed per token row) for a single dispatch covering all
        // batch × seq_len × n_heads rows.
        n_heads: p.n_heads,
        seq_len: p.seq_len * p.batch, // fold batch: n_rows = batch*seq_len*n_heads
        freq_base: p.theta_base,
        mode: RopeMultiMode::Imrope,
        sections: p.sections,
    }
}
109
110/// Validate `RopeTrainParams`.
111fn validate_params(p: &RopeTrainParams) -> Result<()> {
112 if p.batch == 0 || p.n_heads == 0 || p.seq_len == 0 || p.head_dim == 0 || p.rope_dim == 0 {
113 return Err(MlxError::InvalidArgument(
114 "rope_train: batch, n_heads, seq_len, head_dim, rope_dim must all be > 0".into(),
115 ));
116 }
117 if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
118 return Err(MlxError::InvalidArgument(
119 "rope_train: head_dim and rope_dim must be even".into(),
120 ));
121 }
122 if p.rope_dim > p.head_dim {
123 return Err(MlxError::InvalidArgument(format!(
124 "rope_train: rope_dim ({}) must be <= head_dim ({})",
125 p.rope_dim, p.head_dim
126 )));
127 }
128 if !p.theta_base.is_finite() || p.theta_base <= 0.0 {
129 return Err(MlxError::InvalidArgument(format!(
130 "rope_train: theta_base must be finite and positive, got {}",
131 p.theta_base
132 )));
133 }
134 Ok(())
135}
136
137/// Check that `buf` has the expected element count and dtype.
138fn validate_io(label: &str, buf: &MlxBuffer, expected_elems: usize, expected_dtype: DType) -> Result<()> {
139 if buf.element_count() != expected_elems {
140 return Err(MlxError::InvalidArgument(format!(
141 "rope_train: {label} element count {} != expected {}",
142 buf.element_count(),
143 expected_elems
144 )));
145 }
146 if buf.dtype() != expected_dtype {
147 return Err(MlxError::InvalidArgument(format!(
148 "rope_train: {label} dtype {} != expected {}",
149 buf.dtype(),
150 expected_dtype
151 )));
152 }
153 Ok(())
154}
155
156/// Expected element count for `in_buf` / `out_buf`.
157fn tensor_elems(p: &RopeTrainParams) -> usize {
158 p.batch as usize * p.n_heads as usize * p.seq_len as usize * p.head_dim as usize
159}
160
161/// Expected element count for `pos_buf`.
162fn pos_elems(p: &RopeTrainParams) -> usize {
163 4 * p.seq_len as usize * p.batch as usize
164}
165
166// ---------------------------------------------------------------------------
167// Forward — bf16
168// ---------------------------------------------------------------------------
169
170/// Apply RoPE (IMROPE mode) to `in_buf` and write result to `out_buf`.
171///
172/// # Buffers
173///
174/// | Buffer | Shape | DType |
175/// |:------------ |:--------------------------------------- |:----- |
176/// | `in_buf` | `[batch, n_heads, seq_len, head_dim]` | bf16 |
177/// | `pos_buf` | `[4 * batch * seq_len]` i32 | i32 |
178/// | `out_buf` | same as `in_buf` | bf16 |
179///
180/// The buffer layout for `pos_buf` folds batch into the positions array:
181/// `[t_positions_batch0..batchN, h_positions_..., w_positions_..., e_positions_...]`
182/// where each axis block has length `batch * seq_len`.
183///
184/// # Grid mapping
185///
186/// The underlying `rope_multi` kernel maps rows as `seq_len * n_heads` where
187/// `seq_len` here is `batch * seq_len` (batch is folded in) and `n_heads` is
188/// as given in `params`. Thread `(pair_idx, row_idx)` handles one NeoX pair
189/// for one (head, token) row.
190#[allow(clippy::too_many_arguments)]
191pub fn dispatch_rope_forward_bf16(
192 encoder: &mut CommandEncoder,
193 registry: &mut KernelRegistry,
194 device: &metal::DeviceRef,
195 mlx_device: &MlxDevice,
196 in_buf: &MlxBuffer,
197 pos_buf: &MlxBuffer,
198 out_buf: &MlxBuffer,
199 params: &RopeTrainParams,
200) -> Result<()> {
201 validate_params(params)?;
202 let n_elems = tensor_elems(params);
203 validate_io("in_buf", in_buf, n_elems, DType::BF16)?;
204 validate_io("out_buf", out_buf, n_elems, DType::BF16)?;
205 // pos_buf must be i32 (signed — backward uses negative positions).
206 if pos_buf.element_count() != pos_elems(params) {
207 return Err(MlxError::InvalidArgument(format!(
208 "rope_train forward: pos_buf element count {} != 4 * batch({}) * seq_len({}) = {}",
209 pos_buf.element_count(),
210 params.batch,
211 params.seq_len,
212 pos_elems(params)
213 )));
214 }
215 match pos_buf.dtype() {
216 DType::I32 | DType::U32 => {}
217 other => {
218 return Err(MlxError::InvalidArgument(format!(
219 "rope_train forward: pos_buf dtype {other} must be i32 or u32"
220 )));
221 }
222 }
223
224 let mp = to_rope_multi_params(params);
225 let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
226 dispatch_rope_multi(
227 encoder,
228 registry,
229 device,
230 in_buf,
231 out_buf,
232 pos_buf,
233 ¶ms_buf,
234 &rope_params_buf,
235 §ions_buf,
236 mp,
237 )
238}
239
240// ---------------------------------------------------------------------------
241// Backward — bf16
242// ---------------------------------------------------------------------------
243
244/// Apply the RoPE backward pass: `dQ = RoPE(dQ', -pos)`.
245///
246/// Mathematically, the rotation matrix is orthogonal: `R(θ)^T = R(-θ)`.
247/// Therefore `∂Q'/∂Q = R(pos)` and the VJP is `dQ = R(pos)^T · dQ' = R(-pos) · dQ'`.
248///
249/// This function negates every entry in `pos_buf` (on the CPU before upload
250/// to the kernel) and dispatches `dispatch_rope_forward_bf16` with the negated
251/// positions.
252///
253/// # Buffers
254///
255/// | Buffer | Shape | DType |
256/// |:-------------- |:------------------------------------- |:----- |
257/// | `grad_out_buf` | `[batch, n_heads, seq_len, head_dim]` | bf16 |
258/// | `pos_buf` | `[4 * batch * seq_len]` i32 (forward positions; NOT negated — this function negates internally) | i32 |
259/// | `grad_in_buf` | same as `grad_out_buf` | bf16 |
260#[allow(clippy::too_many_arguments)]
261pub fn dispatch_rope_backward_bf16(
262 encoder: &mut CommandEncoder,
263 registry: &mut KernelRegistry,
264 device: &metal::DeviceRef,
265 mlx_device: &MlxDevice,
266 grad_out_buf: &MlxBuffer,
267 pos_buf: &MlxBuffer,
268 grad_in_buf: &MlxBuffer,
269 params: &RopeTrainParams,
270) -> Result<()> {
271 validate_params(params)?;
272 let n_elems = tensor_elems(params);
273 validate_io("grad_out_buf", grad_out_buf, n_elems, DType::BF16)?;
274 validate_io("grad_in_buf", grad_in_buf, n_elems, DType::BF16)?;
275 if pos_buf.element_count() != pos_elems(params) {
276 return Err(MlxError::InvalidArgument(format!(
277 "rope_train backward: pos_buf element count {} != 4 * batch({}) * seq_len({}) = {}",
278 pos_buf.element_count(),
279 params.batch,
280 params.seq_len,
281 pos_elems(params)
282 )));
283 }
284 match pos_buf.dtype() {
285 DType::I32 | DType::U32 => {}
286 other => {
287 return Err(MlxError::InvalidArgument(format!(
288 "rope_train backward: pos_buf dtype {other} must be i32 or u32"
289 )));
290 }
291 }
292
293 // Build a negated positions buffer on the host. The pos values are i32
294 // (signed); negation is defined. U32-typed buffers are reinterpreted as
295 // i32 for negation (the bit pattern is the same as a two's-complement
296 // signed negate for values > 0; for pos=0, -0 = 0).
297 let neg_pos_buf = negate_pos_buf_i32(mlx_device, pos_buf)?;
298
299 let mp = to_rope_multi_params(params);
300 let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
301 dispatch_rope_multi(
302 encoder,
303 registry,
304 device,
305 grad_out_buf,
306 grad_in_buf,
307 &neg_pos_buf,
308 ¶ms_buf,
309 &rope_params_buf,
310 §ions_buf,
311 mp,
312 )
313}
314
315/// Build a new i32 buffer whose values are the negation of `pos_buf`'s values.
316///
317/// Handles both `DType::I32` and `DType::U32` source buffers (the rope_multi
318/// kernel accepts both; we produce `DType::I32` for the negated output).
319fn negate_pos_buf_i32(device: &MlxDevice, pos_buf: &MlxBuffer) -> Result<MlxBuffer> {
320 let n = pos_buf.element_count();
321 let src_bytes: Vec<i32> = match pos_buf.dtype() {
322 DType::I32 => pos_buf.as_slice::<i32>()?.to_vec(),
323 DType::U32 => pos_buf
324 .as_slice::<u32>()?
325 .iter()
326 .map(|&v| v as i32)
327 .collect(),
328 other => {
329 return Err(MlxError::InvalidArgument(format!(
330 "negate_pos_buf: unsupported dtype {other}"
331 )))
332 }
333 };
334
335 let negated: Vec<i32> = src_bytes.iter().map(|&v| v.wrapping_neg()).collect();
336 let mut buf = device.alloc_buffer(n * 4, DType::I32, vec![n])?;
337 buf.as_mut_slice::<i32>()?.copy_from_slice(&negated);
338 Ok(buf)
339}
340
341// ---------------------------------------------------------------------------
342// f32 variants
343// ---------------------------------------------------------------------------
344
345/// f32 forward variant. Same contract as the bf16 version; operates on f32
346/// `in_buf` / `out_buf`.
347#[allow(clippy::too_many_arguments)]
348pub fn dispatch_rope_forward_f32(
349 encoder: &mut CommandEncoder,
350 registry: &mut KernelRegistry,
351 device: &metal::DeviceRef,
352 mlx_device: &MlxDevice,
353 in_buf: &MlxBuffer,
354 pos_buf: &MlxBuffer,
355 out_buf: &MlxBuffer,
356 params: &RopeTrainParams,
357) -> Result<()> {
358 validate_params(params)?;
359 let n_elems = tensor_elems(params);
360 validate_io("in_buf", in_buf, n_elems, DType::F32)?;
361 validate_io("out_buf", out_buf, n_elems, DType::F32)?;
362 if pos_buf.element_count() != pos_elems(params) {
363 return Err(MlxError::InvalidArgument(format!(
364 "rope_train f32 forward: pos_buf element count {} != {}",
365 pos_buf.element_count(),
366 pos_elems(params)
367 )));
368 }
369 match pos_buf.dtype() {
370 DType::I32 | DType::U32 => {}
371 other => {
372 return Err(MlxError::InvalidArgument(format!(
373 "rope_train f32 forward: pos_buf dtype {other} must be i32 or u32"
374 )));
375 }
376 }
377
378 let mp = to_rope_multi_params(params);
379 let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
380 dispatch_rope_multi(
381 encoder,
382 registry,
383 device,
384 in_buf,
385 out_buf,
386 pos_buf,
387 ¶ms_buf,
388 &rope_params_buf,
389 §ions_buf,
390 mp,
391 )
392}
393
394/// f32 backward variant. Same contract as the bf16 backward.
395#[allow(clippy::too_many_arguments)]
396pub fn dispatch_rope_backward_f32(
397 encoder: &mut CommandEncoder,
398 registry: &mut KernelRegistry,
399 device: &metal::DeviceRef,
400 mlx_device: &MlxDevice,
401 grad_out_buf: &MlxBuffer,
402 pos_buf: &MlxBuffer,
403 grad_in_buf: &MlxBuffer,
404 params: &RopeTrainParams,
405) -> Result<()> {
406 validate_params(params)?;
407 let n_elems = tensor_elems(params);
408 validate_io("grad_out_buf", grad_out_buf, n_elems, DType::F32)?;
409 validate_io("grad_in_buf", grad_in_buf, n_elems, DType::F32)?;
410 if pos_buf.element_count() != pos_elems(params) {
411 return Err(MlxError::InvalidArgument(format!(
412 "rope_train f32 backward: pos_buf element count {} != {}",
413 pos_buf.element_count(),
414 pos_elems(params)
415 )));
416 }
417 match pos_buf.dtype() {
418 DType::I32 | DType::U32 => {}
419 other => {
420 return Err(MlxError::InvalidArgument(format!(
421 "rope_train f32 backward: pos_buf dtype {other} must be i32 or u32"
422 )));
423 }
424 }
425
426 let neg_pos_buf = negate_pos_buf_i32(mlx_device, pos_buf)?;
427 let mp = to_rope_multi_params(params);
428 let (params_buf, rope_params_buf, sections_buf) = build_rope_multi_buffers(mlx_device, mp)?;
429 dispatch_rope_multi(
430 encoder,
431 registry,
432 device,
433 grad_out_buf,
434 grad_in_buf,
435 &neg_pos_buf,
436 ¶ms_buf,
437 &rope_params_buf,
438 §ions_buf,
439 mp,
440 )
441}