Module rope_train

Differentiable Rotary Position Embedding — forward + backward.

Used by hf2q’s DWQ training tape (ADR-022 / flash_attn_train Phase 1a). This module provides a standalone RoPE op whose backward pass is the same op applied with negated positions:

Forward: Q’ = RoPE(Q, pos)
Backward: dQ = RoPE(dQ’, -pos), since the rotation matrix is orthogonal: R(θ)^T = R(-θ)
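The step from "R is orthogonal" to "negate the positions" can be written out for a single rotated pair (a, b). This derivation assumes each pair's angle is linear in the position, θ = pos · ω_p, as in standard RoPE frequency schedules; ω_p is a stand-in symbol for the per-pair frequency, not a name taken from the source.

```latex
\begin{aligned}
\begin{pmatrix} a' \\ b' \end{pmatrix} &= R(\theta)\begin{pmatrix} a \\ b \end{pmatrix},
\qquad R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix},
\qquad \theta = \mathrm{pos}\cdot\omega_p \\
\frac{\partial L}{\partial (a, b)} &= R(\theta)^{\top}\,\frac{\partial L}{\partial (a', b')}
= R(-\theta)\,\frac{\partial L}{\partial (a', b')},
\qquad -\theta = (-\mathrm{pos})\cdot\omega_p
\end{aligned}
```

Because negating pos negates every pair's angle, running the forward op on dQ’ with -pos applies R(θ)^T to each pair, which is exactly dQ. Elements that pass through unchanged in the forward also pass their gradient through unchanged.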

Implementation

Both the forward and the backward dispatch the same Metal kernel: rope_multi_bf16 / rope_multi_f32 from super::rope_multi. The backward simply passes a negated copy of the positions buffer, so no new Metal shader is needed.
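A minimal sketch of that reuse pattern, assuming a generic `forward` callable that stands in for the rope_multi dispatch. The helper names `negated_positions` and `rope_backward_via_forward` are hypothetical, and the real dispatch functions take Metal buffers and a params struct rather than slices.

```rust
/// Hypothetical helper: builds the negated copy of the positions buffer
/// that the backward pass hands to the forward kernel.
fn negated_positions(positions: &[i32]) -> Vec<i32> {
    // Positions are non-negative token indices, so negation cannot overflow.
    positions.iter().map(|&p| -p).collect()
}

/// Hypothetical backward wrapper: dQ = forward(dQ', -pos).
/// `forward` stands in for the rope_multi dispatch; its real signature differs.
fn rope_backward_via_forward<F>(d_out: &[f32], positions: &[i32], forward: F) -> Vec<f32>
where
    F: Fn(&[f32], &[i32]) -> Vec<f32>,
{
    let neg = negated_positions(positions);
    forward(d_out, &neg)
}
```

The design keeps a single source of truth for the rotation math: any fix to the forward kernel automatically applies to the backward.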

IMROPE convention (Qwen3.5 / Qwen3.6)

Qwen3.5 / Qwen3.6 use sections = [11, 11, 10, 0], freq_base = 1e7, and mode = IMROPE (40). The positions buffer is int32[4 * seq_len]: the first seq_len entries are the time-axis positions, the next seq_len the height-axis positions, then width, then extra. For text-only inputs, all four axes equal the token’s 1-D position.
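A sketch of building that positions buffer for a text-only sequence. The function and constant names are hypothetical, and the element types (i32 sections, f32 base) are assumptions; only the values and the axis layout come from the convention described above.

```rust
/// Qwen3.5 / Qwen3.6 IMROPE constants (hypothetical names; values from the docs).
const QWEN35_SECTIONS: [i32; 4] = [11, 11, 10, 0];
const QWEN35_FREQ_BASE: f32 = 1e7;
const IMROPE_MODE: i32 = 40;

/// Hypothetical helper: builds the int32[4 * seq_len] positions buffer for a
/// text-only sequence whose first token sits at `start`.
/// Axis order: time, height, width, extra.
fn text_only_positions(start: i32, seq_len: usize) -> Vec<i32> {
    let axis: Vec<i32> = (0..seq_len as i32).map(|i| start + i).collect();
    let mut buf = Vec::with_capacity(4 * seq_len);
    // Text-only: every axis repeats the same 1-D token positions.
    for _ in 0..4 {
        buf.extend_from_slice(&axis);
    }
    buf
}
```

For example, `text_only_positions(5, 3)` yields `[5, 6, 7, 5, 6, 7, 5, 6, 7, 5, 6, 7]`: four copies of the positions 5..8, one per axis.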

Pair indexing is NeoX-style: each thread rotates (x[p], x[p + head_dim/2]) for pair p ∈ [0, rope_dim/2). Pairs p ≥ rope_dim/2 pass through unchanged (partial-rotary tail).
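A minimal CPU sketch of the NeoX pairing and of the backward identity, for the text-only case where a single 1-D position drives every pair. It assumes rope_dim == head_dim, so there is no partial-rotary tail and the partner x[p + head_dim/2] is unambiguous, and it assumes the standard frequency schedule ω_p = freq_base^(-2p/head_dim). `rope_neox_ref` and `dot` are hypothetical names; this is neither the Metal kernel nor the cpu_rope_multi oracle.

```rust
/// Hypothetical CPU reference for one head vector (text-only, full rotary).
/// Assumes rope_dim == head_dim and omega_p = freq_base^(-2p / head_dim).
fn rope_neox_ref(x: &[f32], pos: i32, freq_base: f32) -> Vec<f32> {
    let head_dim = x.len();
    assert!(head_dim % 2 == 0, "head_dim must be even");
    let half = head_dim / 2;
    let mut out = vec![0.0f32; head_dim];
    for p in 0..half {
        // Angle for pair p is linear in pos, so negating pos negates theta.
        let omega = freq_base.powf(-2.0 * p as f32 / head_dim as f32);
        let theta = pos as f32 * omega;
        let (sin, cos) = theta.sin_cos();
        // NeoX pairing: element p rotates together with element p + half.
        let (a, b) = (x[p], x[p + half]);
        out[p] = a * cos - b * sin;
        out[p + half] = a * sin + b * cos;
    }
    out
}

/// Plain dot product, used to check the adjoint identity <Rx, g> = <x, R^T g>.
fn dot(u: &[f32], v: &[f32]) -> f32 {
    u.iter().zip(v).map(|(a, b)| a * b).sum()
}
```

For any vectors x and g, `dot(&rope_neox_ref(&x, pos, 1e7), &g)` matches `dot(&x, &rope_neox_ref(&g, -pos, 1e7))` up to rounding. That adjoint identity is what lets the backward reuse the forward with -pos.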

References

  • src/ops/rope_multi.rs — the underlying dispatch + buffer construction
  • src/shaders/rope_multi.metal — IMROPE / MROPE / VISION Metal kernel
  • tests/test_rope_multi.rs — parity oracle (cpu_rope_multi)
  • /opt/hf2q/src/inference/models/qwen35/full_attn.rs:18-19 — production call-sites
  • /opt/hf2q/src/inference/models/qwen35/mod.rs:235 — mrope_section=[11,11,10,0]

Structs

RopeTrainParams
Shape + frequency parameters for a differentiable RoPE dispatch.

Functions

dispatch_rope_backward_bf16
Apply the RoPE backward pass: dQ = RoPE(dQ', -pos).
dispatch_rope_backward_f32
f32 backward variant. Same contract as the bf16 backward.
dispatch_rope_forward_bf16
Apply RoPE (IMROPE mode) to in_buf and write the result to out_buf.
dispatch_rope_forward_f32
f32 forward variant. Same contract as the bf16 version; operates on f32 in_buf / out_buf.