Differentiable Rotary Position Embedding — forward + backward.
Used by hf2q’s DWQ training tape (ADR-022 / flash_attn_train Phase 1a). This module provides a standalone RoPE op that also serves as its own backward:
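Forward:  Q' = RoPE(Q, pos)
Backward: dQ = RoPE(dQ', -pos)

This works because the rotation matrix is orthogonal: Q' = R(θ)·Q gives dQ = R(θ)^T·dQ' by the chain rule, and R(θ)^T = R(-θ). Each angle θ is proportional to pos, so negating pos negates every θ.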
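A minimal numeric check of the identity R(θ)^T = R(-θ) on a single 2-D pair, written as a standalone Rust sketch. The helpers rotate and rotate_transpose are hypothetical illustrations, not part of this module; they only show why rotating the upstream gradient by -θ equals applying the transposed rotation.

```rust
/// Rotate the 2-vector (a, b) counter-clockwise by `theta`: R(theta) * [a, b].
fn rotate(theta: f64, a: f64, b: f64) -> (f64, f64) {
    let (s, c) = theta.sin_cos();
    (a * c - b * s, a * s + b * c)
}

/// Apply the transpose of the rotation matrix to (a, b).
/// R = [[c, -s], [s, c]], so R^T = [[c, s], [-s, c]].
fn rotate_transpose(theta: f64, a: f64, b: f64) -> (f64, f64) {
    let (s, c) = theta.sin_cos();
    (a * c + b * s, -a * s + b * c)
}

fn main() {
    let theta = 0.7;
    let (g0, g1) = (0.3, -1.2); // upstream gradient dQ' for one rotated pair
    let via_transpose = rotate_transpose(theta, g0, g1);
    let via_negated = rotate(-theta, g0, g1);
    // Both routes give the same dQ (up to floating-point rounding).
    assert!((via_transpose.0 - via_negated.0).abs() < 1e-12);
    assert!((via_transpose.1 - via_negated.1).abs() < 1e-12);
    println!("R(theta)^T dQ' == R(-theta) dQ'");
}
```

Because the backward is just a forward rotation with the sign of θ flipped, no separate gradient kernel is required.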
§Implementation
Both the forward and the backward dispatch the SAME Metal kernel —
rope_multi_bf16 / rope_multi_f32 from super::rope_multi. The
backward simply passes a negated copy of the positions buffer.
No new Metal shader is needed.
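The backward's only extra work is building that negated positions buffer. A minimal sketch, assuming the positions are held host-side as a plain i32 slice; negated_positions is a hypothetical helper, not this module's API, and the real code may construct the buffer differently.

```rust
/// Hypothetical helper (not part of this module's API): build the
/// backward-pass positions buffer by negating every entry of the forward
/// buffer, covering all four IMROPE axes at once.
fn negated_positions(positions: &[i32]) -> Vec<i32> {
    positions
        .iter()
        .map(|&p| p.checked_neg().expect("i32::MIN has no positive counterpart"))
        .collect()
}

fn main() {
    let seq_len = 3;
    // Text-only layout: four axes (time, height, width, extra), each 0..seq_len.
    let forward: Vec<i32> = (0..4).flat_map(|_| 0..seq_len as i32).collect();
    let backward = negated_positions(&forward);
    println!("{:?}", backward);
    // prints [0, -1, -2, 0, -1, -2, 0, -1, -2, 0, -1, -2]
}
```

Negating every axis, not only the time axis, matters: each rotated pair may draw its angle from any of the four axes, and all of those angles must flip sign.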
§IMROPE convention (Qwen3.5 / Qwen3.6)
Qwen3.5 / Qwen3.6 use sections = [11, 11, 10, 0] with freq_base = 1e7 and mode = IMROPE (40).
Positions layout: int32[4 * seq_len] — first seq_len entries are the
time-axis positions, next seq_len the height-axis, then width, then extra.
For text-only inputs all four axes equal the token’s 1-D position.
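The axis-major layout for a text-only sequence can be sketched as follows. The helpers text_only_positions and position_at are hypothetical, and the start argument (for example a KV-cache offset) is an assumption added for illustration.

```rust
/// Hypothetical helper: positions buffer for a text-only sequence in the
/// IMROPE layout int32[4 * seq_len], axis-major: [t..., h..., w..., extra...].
fn text_only_positions(start: i32, seq_len: usize) -> Vec<i32> {
    let mut pos = Vec::with_capacity(4 * seq_len);
    for _axis in 0..4 {
        // For text, every axis carries the token's 1-D position.
        pos.extend((0..seq_len as i32).map(|i| start + i));
    }
    pos
}

/// Position of token `tok` on axis `axis` (0 = time, 1 = height,
/// 2 = width, 3 = extra) in the axis-major buffer.
fn position_at(pos: &[i32], seq_len: usize, axis: usize, tok: usize) -> i32 {
    pos[axis * seq_len + tok]
}

fn main() {
    let seq_len = 3;
    let pos = text_only_positions(10, seq_len);
    assert_eq!(pos.len(), 4 * seq_len);
    println!("{:?}", pos); // [10, 11, 12, 10, 11, 12, 10, 11, 12, 10, 11, 12]
    println!("{}", position_at(&pos, seq_len, 2, 1)); // 11
}
```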
Pair indexing is NeoX-style: each thread rotates
(x[p], x[p + head_dim/2]) for pair p ∈ [0, rope_dim/2).
Pairs p ≥ rope_dim/2 pass through unchanged (partial-rotary tail).
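A CPU sketch of the pairing described above, for one head vector. It assumes a text-only input (all four axes share one position, so IMROPE's per-section axis selection is skipped) and the conventional inverse frequency freq_base^(-2p / rope_dim) for pair p; both are assumptions, and this is neither the Metal kernel nor the cpu_rope_multi oracle. It also checks that running the same function with -pos undoes the forward.

```rust
/// Hypothetical CPU sketch of NeoX-style partial RoPE on one head vector.
/// Rotates (x[p], x[p + head_dim/2]) for p in 0..rope_dim/2; other pairs pass through.
fn rope_neox_partial(x: &[f32], pos: f32, rope_dim: usize, freq_base: f32) -> Vec<f32> {
    let head_dim = x.len();
    assert!(head_dim % 2 == 0 && rope_dim % 2 == 0 && rope_dim <= head_dim);
    let half = head_dim / 2;
    let mut out = x.to_vec(); // pairs p >= rope_dim/2 keep their input values
    for p in 0..rope_dim / 2 {
        // Assumed frequency schedule: freq_base^(-2p / rope_dim).
        let inv_freq = freq_base.powf(-2.0 * p as f32 / rope_dim as f32);
        let (s, c) = (pos * inv_freq).sin_cos();
        let (a, b) = (x[p], x[p + half]);
        out[p] = a * c - b * s;
        out[p + half] = a * s + b * c;
    }
    out
}

fn main() {
    let x: Vec<f32> = (0..8).map(|i| i as f32 * 0.25 - 1.0).collect();
    let (rope_dim, base) = (4, 1e7_f32);
    let y = rope_neox_partial(&x, 5.0, rope_dim, base);
    // Backward-style call: same function, negated position.
    let back = rope_neox_partial(&y, -5.0, rope_dim, base);
    for (a, b) in x.iter().zip(&back) {
        assert!((a - b).abs() < 1e-5);
    }
    // Tail pairs p = 2, 3 (indices 2, 6 and 3, 7) are untouched.
    assert_eq!((y[2], y[3], y[6], y[7]), (x[2], x[3], x[6], x[7]));
    println!("round trip ok");
}
```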
§References
- src/ops/rope_multi.rs - the underlying dispatch + buffer construction
- src/shaders/rope_multi.metal - IMROPE / MROPE / VISION Metal kernel
- tests/test_rope_multi.rs - parity oracle (cpu_rope_multi)
- /opt/hf2q/src/inference/models/qwen35/full_attn.rs:18-19 - production call-sites
- /opt/hf2q/src/inference/models/qwen35/mod.rs:235 - mrope_section=[11,11,10,0]
§Structs
- RopeTrainParams - Shape + frequency parameters for a differentiable RoPE dispatch.
§Functions
- dispatch_rope_backward_bf16 - Apply the RoPE backward pass: dQ = RoPE(dQ', -pos).
- dispatch_rope_backward_f32 - f32 backward variant. Same contract as the bf16 backward.
- dispatch_rope_forward_bf16 - Apply RoPE (IMROPE mode) to in_buf and write the result to out_buf.
- dispatch_rope_forward_f32 - f32 forward variant. Same contract as the bf16 version; operates on f32 in_buf/out_buf.