#![allow(
clippy::expect_used,
clippy::unwrap_used,
clippy::panic,
clippy::cast_precision_loss,
clippy::cast_possible_truncation
)]
use half::bf16;
use mlx_native::ops::rope_train::{
dispatch_rope_backward_bf16, dispatch_rope_backward_f32, dispatch_rope_forward_bf16,
dispatch_rope_forward_f32, RopeTrainParams,
};
use mlx_native::ops::rope_multi::register as register_rope_multi;
use mlx_native::{DType, KernelRegistry, MlxBuffer, MlxDevice};
/// Opens the default Metal device and builds a kernel registry with the
/// `rope_multi` kernels registered, as every test below needs both.
fn device_and_registry() -> (MlxDevice, KernelRegistry) {
    let device = MlxDevice::new().expect("MlxDevice::new");
    let registry = {
        let mut reg = KernelRegistry::new();
        register_rope_multi(&mut reg);
        reg
    };
    (device, registry)
}
/// Rounds every `f32` in `data` to bf16 and copies the result into a freshly
/// allocated 1-D BF16 device buffer of the same element count.
fn upload_bf16(device: &MlxDevice, data: &[f32]) -> MlxBuffer {
    let converted: Vec<bf16> = data.iter().copied().map(bf16::from_f32).collect();
    let elems = converted.len();
    let byte_len = elems * std::mem::size_of::<bf16>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::BF16, vec![elems])
        .expect("alloc bf16");
    buffer
        .as_mut_slice::<bf16>()
        .expect("mut")
        .copy_from_slice(&converted);
    buffer
}
/// Copies `data` verbatim into a freshly allocated 1-D F32 device buffer.
fn upload_f32(device: &MlxDevice, data: &[f32]) -> MlxBuffer {
    let elems = data.len();
    let byte_len = elems * std::mem::size_of::<f32>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::F32, vec![elems])
        .expect("alloc f32");
    buffer
        .as_mut_slice::<f32>()
        .expect("mut")
        .copy_from_slice(data);
    buffer
}
/// Copies `data` verbatim into a freshly allocated 1-D I32 device buffer
/// (used for the position tables).
fn upload_i32(device: &MlxDevice, data: &[i32]) -> MlxBuffer {
    let elems = data.len();
    let byte_len = elems * std::mem::size_of::<i32>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::I32, vec![elems])
        .expect("alloc i32");
    buffer
        .as_mut_slice::<i32>()
        .expect("mut")
        .copy_from_slice(data);
    buffer
}
fn alloc_bf16_zeros(device: &MlxDevice, n: usize) -> MlxBuffer {
device
.alloc_buffer(n * 2, DType::BF16, vec![n])
.expect("alloc bf16 zeros")
}
/// Allocates a 1-D F32 device buffer holding `n` elements.
///
/// NOTE(review): zero-initialisation is implied by the name only; this file
/// does not show that `alloc_buffer` clears memory.
fn alloc_f32_zeros(device: &MlxDevice, n: usize) -> MlxBuffer {
    let byte_len = n * std::mem::size_of::<f32>();
    device
        .alloc_buffer(byte_len, DType::F32, vec![n])
        .expect("alloc f32 zeros")
}
fn download_bf16(buf: &MlxBuffer) -> Vec<f32> {
buf.as_slice::<bf16>()
.expect("read bf16")
.iter()
.map(|v| v.to_f32())
.collect()
}
/// Copies an F32 device buffer back into an owned host vector.
fn download_f32(buf: &MlxBuffer) -> Vec<f32> {
    let view = buf.as_slice::<f32>().expect("read f32");
    view.iter().copied().collect()
}
/// Builds a `[4, batch * seq_len]` position table for plain-text input: all
/// four axis planes hold the same running token index `0..batch * seq_len`.
///
/// The index is not reset per batch entry, so batch `b` sees positions
/// `b * seq_len .. (b + 1) * seq_len`.
fn text_positions(batch: u32, seq_len: u32) -> Vec<i32> {
    let tokens = (batch * seq_len) as i32;
    (0..4).flat_map(|_axis| 0..tokens).collect()
}
/// Maps an imrope sector to the position axis that drives it.
///
/// Sectors are interleaved across axes 0, 1, 2, 0, 1, 2, ...; sector `s` uses
/// axis `s % 3` while `s < 3 * sections[s % 3]`, and falls back to axis 3
/// otherwise.
fn pick_axis_imrope(sector: u32, sections: [u32; 4]) -> usize {
    let lane = (sector % 3) as usize;
    if sector < 3 * sections[lane] {
        lane
    } else {
        3
    }
}
/// CPU (f32) reference for the interleaved multi-axis RoPE ("imrope") forward pass.
///
/// Layout, as implied by the row → (batch, token) decomposition below:
/// - `input` is `[batch, seq_len, n_heads, head_dim]` flattened row-major, with
///   heads varying fastest within a token.
/// - `positions` is `[4, batch * seq_len]`, one plane per axis selected by
///   `pick_axis_imrope`.
///
/// For every row, pair `k < rope_dim / 2` rotates the elements at `k` and
/// `k + head_dim / 2` by the angle `pos * theta_base^(-2k / rope_dim)`. All other
/// elements are copied through unchanged (partial rotary).
///
/// NOTE(review): the partner offset is `head_dim / 2`, not `rope_dim / 2`. For
/// partial rotary (rope_dim < head_dim) this differs from the HF-style
/// `rotate_half(q[..., :rotary_dim])` convention. Confirm it matches both the
/// Metal kernel and the target model.
///
/// # Panics
/// If `input.len() != batch * n_heads * seq_len * head_dim` or
/// `positions.len() != 4 * batch * seq_len`.
pub fn rope_reference_f32(input: &[f32], positions: &[i32], p: &RopeTrainParams) -> Vec<f32> {
    let batch = p.batch as usize;
    let n_heads = p.n_heads as usize;
    let seq = p.seq_len as usize;
    let hd = p.head_dim as usize;
    let rope_dim = p.rope_dim as usize;
    let half_dim = hd / 2;
    // Pairs at or beyond rope_dim / 2 are pass-through; never exceed the head.
    let rotated_pairs = (rope_dim / 2).min(half_dim);
    // `.max(1)` keeps the modulo below well-defined for an all-zero `sections`.
    let sect_dims = p.sections.iter().sum::<u32>().max(1) as usize;
    let rows_per_batch = n_heads * seq;
    let n_rows = batch * rows_per_batch;
    let tokens = batch * seq;

    assert_eq!(input.len(), n_rows * hd, "rope_reference_f32: input size mismatch");
    assert_eq!(positions.len(), 4 * tokens, "rope_reference_f32: positions size mismatch");

    // Start from a verbatim copy; only rotated pairs are overwritten below.
    let mut out = input.to_vec();
    for row in 0..n_rows {
        let b = row / rows_per_batch;
        let token = b * seq + (row % rows_per_batch) / n_heads;
        let lo = row * hd;
        for pair in 0..rotated_pairs {
            let axis = pick_axis_imrope((pair % sect_dims) as u32, p.sections);
            let pos = positions[axis * tokens + token] as f32;
            let freq = 1.0 / p.theta_base.powf(2.0 * pair as f32 / rope_dim as f32);
            let (sin_a, cos_a) = (pos * freq).sin_cos();
            let (i0, i1) = (lo + pair, lo + pair + half_dim);
            let (x0, x1) = (input[i0], input[i1]);
            out[i0] = x0 * cos_a - x1 * sin_a;
            out[i1] = x0 * sin_a + x1 * cos_a;
        }
    }
    out
}
/// Asserts element-wise that `|got[i] - want[i]| <= atol` for every index.
///
/// # Panics
/// - If the slices differ in length.
/// - At the first index whose absolute difference exceeds `atol` or is NaN.
///   A NaN never satisfies `<=`, so NaN in either slice fails. The message
///   reports that exact index together with both values and the difference.
fn assert_close_atol(label: &str, got: &[f32], want: &[f32], atol: f32) {
    assert_eq!(got.len(), want.len(), "{label}: length mismatch");
    for (i, (&g, &w)) in got.iter().zip(want).enumerate() {
        let d = (g - w).abs();
        // Report `i` itself: the previous running-max index went stale for NaN
        // differences (NaN never compares greater), pointing at the wrong element.
        assert!(
            d <= atol,
            "{label}: i={i}: got={g} want={w} diff={d} > atol={atol}"
        );
    }
}
/// Uploads `input_f32` as bf16, runs the bf16 RoPE forward kernel with the
/// given positions and params, waits for the command buffer to finish, and
/// returns the output widened to f32.
fn run_forward_bf16(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    input_f32: &[f32],
    positions: &[i32],
    p: &RopeTrainParams,
) -> Vec<f32> {
    let x = upload_bf16(device, input_f32);
    let y = alloc_bf16_zeros(device, input_f32.len());
    let pos = upload_i32(device, positions);

    let mut encoder = device.command_encoder().expect("enc");
    let dispatched = dispatch_rope_forward_bf16(
        &mut encoder,
        registry,
        device.metal_device(),
        device,
        &x,
        &pos,
        &y,
        p,
    );
    dispatched.expect("dispatch_rope_forward_bf16");
    encoder.commit_and_wait().expect("commit");
    download_bf16(&y)
}
/// Uploads the upstream gradient as bf16, runs the bf16 RoPE backward kernel,
/// waits for completion, and returns the input gradient widened to f32.
fn run_backward_bf16(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    grad_out_f32: &[f32],
    positions: &[i32],
    p: &RopeTrainParams,
) -> Vec<f32> {
    let dy = upload_bf16(device, grad_out_f32);
    let dx = alloc_bf16_zeros(device, grad_out_f32.len());
    let pos = upload_i32(device, positions);

    let mut encoder = device.command_encoder().expect("enc");
    let dispatched = dispatch_rope_backward_bf16(
        &mut encoder,
        registry,
        device.metal_device(),
        device,
        &dy,
        &pos,
        &dx,
        p,
    );
    dispatched.expect("dispatch_rope_backward_bf16");
    encoder.commit_and_wait().expect("commit");
    download_bf16(&dx)
}
/// Uploads `input_f32` unchanged, runs the f32 RoPE forward kernel, waits for
/// completion, and returns the output.
fn run_forward_f32(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    input_f32: &[f32],
    positions: &[i32],
    p: &RopeTrainParams,
) -> Vec<f32> {
    let x = upload_f32(device, input_f32);
    let y = alloc_f32_zeros(device, input_f32.len());
    let pos = upload_i32(device, positions);

    let mut encoder = device.command_encoder().expect("enc");
    let dispatched = dispatch_rope_forward_f32(
        &mut encoder,
        registry,
        device.metal_device(),
        device,
        &x,
        &pos,
        &y,
        p,
    );
    dispatched.expect("dispatch_rope_forward_f32");
    encoder.commit_and_wait().expect("commit");
    download_f32(&y)
}
/// Uploads the upstream gradient unchanged, runs the f32 RoPE backward kernel,
/// waits for completion, and returns the input gradient.
fn run_backward_f32(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    grad_out_f32: &[f32],
    positions: &[i32],
    p: &RopeTrainParams,
) -> Vec<f32> {
    let dy = upload_f32(device, grad_out_f32);
    let dx = alloc_f32_zeros(device, grad_out_f32.len());
    let pos = upload_i32(device, positions);

    let mut encoder = device.command_encoder().expect("enc");
    let dispatched = dispatch_rope_backward_f32(
        &mut encoder,
        registry,
        device.metal_device(),
        device,
        &dy,
        &pos,
        &dx,
        p,
    );
    dispatched.expect("dispatch_rope_backward_f32");
    encoder.commit_and_wait().expect("commit");
    download_f32(&dx)
}
/// bf16 forward kernel vs. the f32 CPU reference on a full-rotary shape
/// (rope_dim == head_dim), using deterministic LCG inputs in [-1, 1].
#[test]
fn test_forward_parity_b1_h4_s64_d64() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 4,
        seq_len: 64,
        head_dim: 64,
        rope_dim: 64,
        theta_base: 1e6,
        sections: [11, 11, 10, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let mut lcg = 0xdeadc0deu32;
    let input: Vec<f32> = std::iter::repeat_with(|| {
        lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
        lcg as i32 as f32 / i32::MAX as f32
    })
    .take(elems)
    .collect();
    let positions = text_positions(p.batch, p.seq_len);

    let want = rope_reference_f32(&input, &positions, &p);
    let got = run_forward_bf16(&device, &mut registry, &input, &positions, &p);
    assert_close_atol("fwd B1 H4 S64 D64", &got, &want, 6e-3);
}
/// bf16 forward vs. the CPU reference on a batched partial-rotary shape
/// (rope_dim 64 < head_dim 128). Also checks that the non-rotated tail of
/// every row passes through unchanged, up to bf16 rounding.
#[test]
fn test_forward_parity_b2_h2_s128_d128() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 2,
        n_heads: 2,
        seq_len: 128,
        head_dim: 128,
        rope_dim: 64,
        theta_base: 1e6,
        sections: [11, 11, 10, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let mut lcg = 0xcafebabeu32;
    let input: Vec<f32> = std::iter::repeat_with(|| {
        lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
        lcg as i32 as f32 / i32::MAX as f32
    })
    .take(elems)
    .collect();
    let positions = text_positions(p.batch, p.seq_len);

    let got = run_forward_bf16(&device, &mut registry, &input, &positions, &p);
    let want = rope_reference_f32(&input, &positions, &p);
    assert_close_atol("fwd B2 H2 S128 D128 partial", &got, &want, 6e-3);

    let hd = p.head_dim as usize;
    let half_dim = hd / 2;
    let half_rope = p.rope_dim as usize / 2;
    for (out_row, in_row) in got.chunks_exact(hd).zip(input.chunks_exact(hd)) {
        for pair in half_rope..half_dim {
            let (x0_got, x0_in) = (out_row[pair], in_row[pair]);
            let (x1_got, x1_in) = (out_row[pair + half_dim], in_row[pair + half_dim]);
            assert!(
                (x0_got - x0_in).abs() < 5e-3,
                "partial-rotary tail[pair={pair}] x0 modified: got={}, input={}",
                x0_got,
                x0_in
            );
            assert!(
                (x1_got - x1_in).abs() < 5e-3,
                "partial-rotary tail[pair={pair}] x1 modified: got={}, input={}",
                x1_got,
                x1_in
            );
        }
    }
}
/// The bf16 backward kernel must be bit-identical, element by element, to the
/// bf16 forward kernel run at negated positions.
#[test]
fn test_backward_equals_forward_with_negated_pos_bf16() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 4,
        seq_len: 32,
        head_dim: 64,
        rope_dim: 64,
        theta_base: 1e6,
        sections: [11, 11, 10, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let mut lcg = 0xf00dbabeu32;
    let dy: Vec<f32> = std::iter::repeat_with(|| {
        lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
        lcg as i32 as f32 / i32::MAX as f32
    })
    .take(elems)
    .collect();
    let positions = text_positions(p.batch, p.seq_len);
    let neg_positions: Vec<i32> = positions.iter().map(|&pos| -pos).collect();

    let backward_result = run_backward_bf16(&device, &mut registry, &dy, &positions, &p);
    let forward_negated_result =
        run_forward_bf16(&device, &mut registry, &dy, &neg_positions, &p);
    assert_eq!(
        backward_result.len(),
        forward_negated_result.len(),
        "length mismatch"
    );
    for (i, (&b, &f)) in backward_result
        .iter()
        .zip(&forward_negated_result)
        .enumerate()
    {
        // Compare raw bf16 bit patterns: "close" is not good enough here.
        let (b_bits, f_bits) = (bf16::from_f32(b).to_bits(), bf16::from_f32(f).to_bits());
        assert_eq!(
            b_bits, f_bits,
            "backward != forward(-pos) at i={}: backward={}, forward_neg={}",
            i, b, f
        );
    }
}
/// Applying the bf16 backward kernel to the bf16 forward output, at the same
/// positions, must recover the original input within bf16 round-off.
#[test]
fn test_round_trip_identity_bf16() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 2,
        seq_len: 16,
        head_dim: 64,
        rope_dim: 64,
        theta_base: 1e6,
        sections: [11, 11, 10, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let x: Vec<f32> = (0..elems)
        .map(|i| (i as f32 * 0.03 + 0.01).sin() * 0.5)
        .collect();
    let positions = text_positions(p.batch, p.seq_len);

    let rotated = run_forward_bf16(&device, &mut registry, &x, &positions, &p);
    let recovered = run_backward_bf16(&device, &mut registry, &rotated, &positions, &p);
    assert_close_atol("round-trip identity", &recovered, &x, 5e-3);
}
/// Finite-difference falsifier for the f32 backward kernel.
///
/// RoPE is linear in its input, so for any upstream gradient `dy` the backward
/// result must satisfy `dx[j] == <dy, (f(x + h*e_j) - f(x)) / h>` up to f32
/// rounding (vector-Jacobian product vs. a directional finite difference).
///
/// Two things make this an actual falsifier:
/// - `dy` is dense and non-trivial, so each probe depends on both the diagonal
///   (`cos`) and the partner (`±sin`) Jacobian entries. A backward that forgets
///   to transpose the rotation is caught. A one-hot `dy` read at the same index
///   only sees the diagonal, which is identical for R and Rᵀ.
/// - Probes sit in rows with non-zero positions. With `text_positions` and
///   `n_heads == 1`, row `r` has position `r`. Row 0 (position 0) is the
///   identity rotation and cannot tell a correct backward from a wrong one.
#[test]
fn test_finite_diff_falsifier_f32() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 1,
        seq_len: 8,
        head_dim: 16,
        rope_dim: 16,
        theta_base: 1e4,
        sections: [3, 3, 2, 0],
    };
    let n = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let hd = p.head_dim as usize;
    let input_f32: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1 - 0.5).tanh()).collect();
    let dy: Vec<f32> = (0..n).map(|i| (i as f32 * 0.37 + 0.2).sin()).collect();
    let positions = text_positions(p.batch, p.seq_len);
    let eps = 1e-2_f32;
    // (row, lane): lanes 0..8 are x0 of pair `lane`; lanes 8..16 are x1 of pair
    // `lane - 8`. Low pairs have the largest frequencies, so sin(pos * freq)
    // is far from 0 and a wrong-sign partner term changes dx noticeably.
    let probes: [(usize, usize); 5] = [(1, 0), (3, 1), (5, 8), (6, 2), (7, 9)];

    let baseline = run_forward_f32(&device, &mut registry, &input_f32, &positions, &p);
    let dx = run_backward_f32(&device, &mut registry, &dy, &positions, &p);
    for &(row, lane) in probes.iter() {
        let probe_in = row * hd + lane;
        let mut input_perturbed = input_f32.clone();
        input_perturbed[probe_in] += eps;
        // Divide by the step actually representable in f32, not the nominal eps.
        let step = input_perturbed[probe_in] - input_f32[probe_in];
        let perturbed = run_forward_f32(&device, &mut registry, &input_perturbed, &positions, &p);
        let fd_grad = perturbed
            .iter()
            .zip(&baseline)
            .zip(&dy)
            .map(|((&yp, &y0), &g)| g * (yp - y0))
            .sum::<f32>()
            / step;
        let analytic_grad = dx[probe_in];
        let diff = (analytic_grad - fd_grad).abs();
        let scale = analytic_grad.abs().max(fd_grad.abs()).max(1.0);
        assert!(
            diff / scale <= 1e-2,
            "finite-diff falsifier FAILED at probe_in={probe_in} (row={row}, lane={lane}): \
             analytic={analytic_grad:.6}, fd={fd_grad:.6}, diff={diff:.6}, scale={scale:.6}"
        );
    }
}
/// Checks that each imrope axis position rotates only the pairs assigned to
/// that axis.
///
/// With `sections = [4, 4, 4, 0]` and `rope_dim = head_dim = 24` (12 pairs, all
/// rotated), `pick_axis_imrope` assigns pair `k` to axis `k % 3`. Each of the
/// three axes is tested in turn, with only that axis given a non-zero position:
/// - Pairs owned by another axis see position 0 and must come out unchanged.
/// - The active axis' pairs must visibly rotate.
///
/// Every input element is non-zero. With zero inputs, any rotation would also
/// yield zeros, hiding a position leaking into the wrong section.
#[test]
fn test_imrope_section_independence() {
    let (device, mut registry) = device_and_registry();
    let p_full = RopeTrainParams {
        batch: 1,
        n_heads: 1,
        seq_len: 1,
        head_dim: 24,
        rope_dim: 24,
        theta_base: 1e4,
        sections: [4, 4, 4, 0],
    };
    let n = (p_full.batch * p_full.n_heads * p_full.seq_len * p_full.head_dim) as usize;
    let half_dim = p_full.head_dim as usize / 2;
    let n_pairs = (p_full.rope_dim as usize / 2).min(half_dim);
    let sect_dims = p_full.sections.iter().sum::<u32>().max(1);

    // Distinct, exactly representable, non-zero values for every element.
    let input: Vec<f32> = (0..n).map(|i| 0.5 + 0.0625 * i as f32).collect();

    let got_zero = run_forward_f32(&device, &mut registry, &input, &[0i32; 4], &p_full);
    assert_close_atol("pos=0 identity", &got_zero, &input, 1e-6);

    for active_axis in 0..3usize {
        // batch == seq_len == 1, so the position table holds one entry per axis.
        let mut positions = [0i32; 4];
        positions[active_axis] = 5;
        let got = run_forward_f32(&device, &mut registry, &input, &positions, &p_full);

        let mut found_nonzero_rotation = false;
        for pair in 0..n_pairs {
            let owner = pick_axis_imrope(pair as u32 % sect_dims, p_full.sections);
            let (i0, i1) = (pair, pair + half_dim);
            let d0 = (got[i0] - input[i0]).abs();
            let d1 = (got[i1] - input[i1]).abs();
            if owner == active_axis {
                found_nonzero_rotation |= d0 > 1e-4 || d1 > 1e-4;
            } else {
                assert!(
                    d0 < 1e-6 && d1 < 1e-6,
                    "axis-{active_axis} position leaked into pair {pair} (owned by axis {owner}): \
                     x0 got={} input={}, x1 got={} input={}",
                    got[i0],
                    input[i0],
                    got[i1],
                    input[i1]
                );
            }
        }
        assert!(
            found_nonzero_rotation,
            "axis-{active_axis} pairs were NOT rotated when axis-{active_axis} position=5 \
             (expected non-zero rotation)"
        );
    }
}
/// bf16 forward parity against the CPU reference at a Qwen3.5-like attention
/// shape (head_dim 256, rope_dim 64, sections [11, 11, 10, 0]). The 192
/// non-rotated lanes of each row must pass through up to bf16 rounding.
#[test]
fn test_qwen35_production_shape_forward_parity() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 4,
        seq_len: 8,
        head_dim: 256,
        rope_dim: 64,
        theta_base: 1e6,
        sections: [11, 11, 10, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let mut lcg = 0xf1e2d3c4u32;
    let input: Vec<f32> = std::iter::repeat_with(|| {
        lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
        (lcg as i32 as f32 / i32::MAX as f32) * 0.5
    })
    .take(elems)
    .collect();
    let positions = text_positions(p.batch, p.seq_len);

    let got = run_forward_bf16(&device, &mut registry, &input, &positions, &p);
    let want = rope_reference_f32(&input, &positions, &p);
    assert_close_atol("qwen35 production shape", &got, &want, 6e-3);

    let hd = p.head_dim as usize;
    let half_dim = hd / 2;
    let half_rope = p.rope_dim as usize / 2;
    for (out_row, in_row) in got.chunks_exact(hd).zip(input.chunks_exact(hd)) {
        for pair in half_rope..half_dim {
            let hi = pair + half_dim;
            assert!(
                (out_row[pair] - in_row[pair]).abs() < 5e-3,
                "qwen35 partial-rotary tail pair={pair} x0: got={}, input={}",
                out_row[pair],
                in_row[pair]
            );
            assert!(
                (out_row[hi] - in_row[hi]).abs() < 5e-3,
                "qwen35 partial-rotary tail pair={pair} x1: got={}, input={}",
                out_row[hi],
                in_row[hi]
            );
        }
    }
}
/// The bf16 forward dispatcher must reject an odd `head_dim`. All other
/// params are otherwise consistent: rope_dim even and <= head_dim, sections
/// summing to rope_dim / 2.
#[test]
fn test_validates_odd_head_dim() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 1,
        seq_len: 4,
        head_dim: 15,
        rope_dim: 4,
        theta_base: 1e4,
        sections: [1, 1, 0, 0],
    };
    let elems = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let zero_positions = vec![0i32; 4 * p.seq_len as usize];
    let in_buf = alloc_bf16_zeros(&device, elems);
    let out_buf = alloc_bf16_zeros(&device, elems);
    let pos_buf = upload_i32(&device, &zero_positions);

    let mut enc = device.command_encoder().expect("enc");
    let outcome = dispatch_rope_forward_bf16(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &device,
        &in_buf,
        &pos_buf,
        &out_buf,
        &p,
    );
    assert!(outcome.is_err(), "odd head_dim must error");
}
/// The bf16 forward dispatcher must reject `rope_dim > head_dim`.
///
/// `sections` sums to `rope_dim / 2` (6 + 6 + 4 = 16), matching every other
/// test in this file. Otherwise a sections/rope_dim mismatch could be what
/// makes the dispatch fail, and the test would pass without exercising the
/// rope_dim > head_dim check.
#[test]
fn test_validates_rope_dim_gt_head_dim() {
    let (device, mut registry) = device_and_registry();
    let p = RopeTrainParams {
        batch: 1,
        n_heads: 1,
        seq_len: 4,
        head_dim: 16,
        rope_dim: 32,
        theta_base: 1e4,
        sections: [6, 6, 4, 0],
    };
    let n = (p.batch * p.n_heads * p.seq_len * p.head_dim) as usize;
    let in_buf = alloc_bf16_zeros(&device, n);
    let out_buf = alloc_bf16_zeros(&device, n);
    let zero_positions = vec![0i32; 4 * p.seq_len as usize];
    let pos_buf = upload_i32(&device, &zero_positions);
    let mut enc = device.command_encoder().expect("enc");
    let res = dispatch_rope_forward_bf16(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &device,
        &in_buf,
        &pos_buf,
        &out_buf,
        &p,
    );
    assert!(res.is_err(), "rope_dim > head_dim must error");
}