#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use mlx_native::ops::rope_multi::{
build_rope_multi_buffers, RopeMultiMode, RopeMultiParams,
};
use mlx_native::{DType, KernelRegistry, MlxBuffer, MlxDevice};
/// Builds the per-test fixtures: a device handle and a fresh kernel registry.
///
/// # Panics
/// Panics with `"MlxDevice::new"` if the device cannot be created.
fn setup() -> (MlxDevice, KernelRegistry) {
    let gpu = MlxDevice::new().expect("MlxDevice::new");
    (gpu, KernelRegistry::new())
}
/// Allocates an `F32` buffer with one element per entry of `data`
/// (shape `[data.len()]`) and copies `data` into it.
///
/// # Panics
/// Panics if the allocation fails (`"alloc"`) or the buffer cannot be viewed
/// as a mutable `f32` slice (`"mut"`).
fn upload_f32(device: &MlxDevice, data: &[f32]) -> MlxBuffer {
    let elem_count = data.len();
    let byte_len = elem_count * std::mem::size_of::<f32>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::F32, vec![elem_count])
        .expect("alloc");
    let dst = buffer.as_mut_slice::<f32>().expect("mut");
    dst.copy_from_slice(data);
    buffer
}
/// Allocates an `I32` buffer with one element per entry of `data`
/// (shape `[data.len()]`) and copies `data` into it.
///
/// # Panics
/// Panics if the allocation fails (`"alloc"`) or the buffer cannot be viewed
/// as a mutable `i32` slice (`"mut"`).
fn upload_i32(device: &MlxDevice, data: &[i32]) -> MlxBuffer {
    let elem_count = data.len();
    let byte_len = elem_count * std::mem::size_of::<i32>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::I32, vec![elem_count])
        .expect("alloc");
    let dst = buffer.as_mut_slice::<i32>().expect("mut");
    dst.copy_from_slice(data);
    buffer
}
/// CPU reference for choosing which position axis (0..=3) drives a sector.
///
/// * `Imrope`: the candidate axis is `sector % 3`; it is used only while
///   `sector < 3 * s[axis]`, otherwise axis 3 is returned.
/// * `Mrope`: sectors are assigned in contiguous runs of `s[0]`, `s[1]`,
///   `s[2]` entries to axes 0, 1, 2; anything past those runs gets axis 3.
fn pick_axis_cpu(sector: u32, mode: RopeMultiMode, s: [u32; 4]) -> u32 {
    match mode {
        RopeMultiMode::Imrope => {
            // `sector % 3` is always 0, 1 or 2, so exactly one interleaved
            // lane is a candidate; it wins only inside its own bound.
            let lane = sector % 3;
            if sector < 3 * s[lane as usize] {
                lane
            } else {
                3
            }
        }
        RopeMultiMode::Mrope => {
            // Walk the cumulative section boundaries left to right.
            let first_end = s[0];
            if sector < first_end {
                return 0;
            }
            let second_end = first_end + s[1];
            if sector < second_end {
                return 1;
            }
            let third_end = second_end + s[2];
            if sector < third_end {
                2
            } else {
                3
            }
        }
    }
}
/// CPU reference implementation of the multi-axis rotary embedding.
///
/// `input` is read as `seq_len * n_heads` rows of `head_dim` values. Within a
/// row, element `pair` is rotated together with element `pair + head_dim / 2`
/// for the first `rope_dim / 2` pairs; the remaining pairs are copied through.
/// The angle for a pair is `pos / freq_base^(2 * pair / rope_dim)`, where
/// `pos` is read from `positions[axis * seq_len + seq_idx]` and `axis` comes
/// from [`pick_axis_cpu`].
///
/// Returns a new vector of the same length as `input`.
fn cpu_rope_multi(
    input: &[f32],
    positions: &[i32],
    p: RopeMultiParams,
) -> Vec<f32> {
    let row_count = (p.seq_len * p.n_heads) as usize;
    let head_dim = p.head_dim as usize;
    let half_dim = head_dim / 2;
    let rope_dim = p.rope_dim as usize;
    let rotated_pairs = rope_dim / 2;
    // Clamp to 1 so the modulo below never divides by zero.
    let sect_dims = p.sections.iter().sum::<u32>().max(1);

    let mut out = input.to_vec();
    for row in 0..row_count {
        let row_start = row * head_dim;
        // Rows are laid out with heads varying fastest within a sequence step.
        let seq_idx = (row as u32) / p.n_heads;
        for pair in 0..half_dim {
            let lo = row_start + pair;
            let hi = lo + half_dim;

            if pair >= rotated_pairs {
                // Outside the rotary span: values pass through untouched.
                out[lo] = input[lo];
                out[hi] = input[hi];
                continue;
            }

            let sector = (pair as u32) % sect_dims;
            let axis = pick_axis_cpu(sector, p.mode, p.sections);
            let pos = positions[(axis * p.seq_len + seq_idx) as usize];

            let dim_ratio = 2.0 * (pair as f32) / (rope_dim as f32);
            let freq = 1.0 / (p.freq_base.powf(dim_ratio));
            let theta = (pos as f32) * freq;
            let (sin_t, cos_t) = theta.sin_cos();

            let x0 = input[lo];
            let x1 = input[hi];
            out[lo] = x0 * cos_t - x1 * sin_t;
            out[hi] = x0 * sin_t + x1 * cos_t;
        }
    }
    out
}
/// Runs the `rope_multi` kernel once on `F32` data and returns the output.
///
/// Uploads `input_data` and `positions_data`, allocates an output buffer of
/// the same element count as the input, builds the parameter buffers with
/// `build_rope_multi_buffers`, dispatches the kernel, waits for completion
/// and copies the result back into a `Vec<f32>`.
///
/// # Panics
/// Panics (via `expect`) if any allocation, buffer construction, dispatch,
/// commit, or read-back step fails.
fn run_rope_multi(
    device: &MlxDevice,
    registry: &mut KernelRegistry,
    input_data: &[f32],
    positions_data: &[i32],
    p: RopeMultiParams,
) -> Vec<f32> {
    let input_buf = upload_f32(device, input_data);
    let output_buf = device
        .alloc_buffer(input_data.len() * 4, DType::F32, vec![input_data.len()])
        .expect("output");
    let positions_buf = upload_i32(device, positions_data);
    let (params_buf, rope_params_buf, sections_buf) =
        build_rope_multi_buffers(device, p).expect("build bufs");
    let mut enc = device.command_encoder().expect("enc");
    mlx_native::ops::rope_multi::dispatch_rope_multi(
        &mut enc,
        registry,
        device.metal_device(),
        &input_buf,
        &output_buf,
        &positions_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        p,
    )
    .expect("dispatch");
    // Block until the work has finished before reading the output back.
    enc.commit_and_wait().expect("commit");
    output_buf.as_slice::<f32>().expect("read").to_vec()
}
/// Small IMROPE case (head_dim 16, rope_dim 8, sections `[2, 2, 1, 0]`).
///
/// Compares the kernel output with the CPU reference, then checks pair 0
/// against a hand-computed rotation by angle 5.0 and confirms that elements
/// 4 and 12 (outside the rotary span) are passed through.
#[test]
fn test_rope_multi_imrope_spec_driven_n8_sections_2_2_1_0() {
    let (device, mut registry) = setup();
    let positions = [5i32, 7, 9, 0];
    let input: Vec<f32> = (0..16).map(|i| (i as f32 + 1.0) * 0.1).collect();
    let p = RopeMultiParams {
        head_dim: 16,
        rope_dim: 8,
        n_heads: 1,
        seq_len: 1,
        freq_base: 2.0,
        mode: RopeMultiMode::Imrope,
        sections: [2, 2, 1, 0],
    };

    let got = run_rope_multi(&device, &mut registry, &input, &positions, p);
    let want = cpu_rope_multi(&input, &positions, p);
    for (i, (g, w)) in got.iter().copied().zip(want.iter().copied()).enumerate() {
        let d = (g - w).abs();
        assert!(
            d < 1e-5,
            "spec-driven mismatch at {}: got {}, want {}, diff {}",
            i, g, w, d
        );
    }

    // Pair 0 couples elements 0 and 8 (inputs 0.1 and 0.9).
    let (sin5, cos5) = 5.0_f32.sin_cos();
    let expected_x = 0.1 * cos5 - 0.9 * sin5;
    let expected_y = 0.1 * sin5 + 0.9 * cos5;
    assert!((got[0] - expected_x).abs() < 1e-5, "pair0.x mismatch");
    assert!((got[8] - expected_y).abs() < 1e-5, "pair0.y mismatch");

    // Pair 4 lies past rope_dim / 2, so both of its halves must be untouched.
    assert!((got[4] - input[4]).abs() < 1e-7, "pass-through x0 broken");
    assert!((got[12] - input[12]).abs() < 1e-7, "pass-through x1 broken");
}
/// With distinct per-axis positions, MROPE and IMROPE must disagree on pair 1.
///
/// Both runs share every parameter except `mode`; the test fails if elements
/// 1 and 9 (the two halves of pair 1) are both within 1e-3 across the modes.
#[test]
fn test_rope_multi_mrope_vs_imrope_differ_for_distinct_positions() {
    let (device, mut registry) = setup();
    let p_imrope = RopeMultiParams {
        head_dim: 16,
        rope_dim: 8,
        n_heads: 1,
        seq_len: 1,
        freq_base: 2.0,
        mode: RopeMultiMode::Imrope,
        sections: [2, 2, 1, 0],
    };
    let p_mrope = RopeMultiParams {
        mode: RopeMultiMode::Mrope,
        ..p_imrope
    };

    let positions = [5i32, 7, 9, 0];
    let input: Vec<f32> = (0..16).map(|i| (i as f32 + 1.0) * 0.1).collect();

    let got_imrope = run_rope_multi(&device, &mut registry, &input, &positions, p_imrope);
    let got_mrope = run_rope_multi(&device, &mut registry, &input, &positions, p_mrope);

    let low_half_differs = (got_imrope[1] - got_mrope[1]).abs() > 1e-3;
    let high_half_differs = (got_imrope[9] - got_mrope[9]).abs() > 1e-3;
    assert!(
        low_half_differs || high_half_differs,
        "MROPE and IMROPE produced identical pair 1 — axis-mapping may be wrong"
    );
}
/// Kernel-vs-CPU parity on pseudo-random input (4 heads, 6 sequence steps).
///
/// Input comes from a fixed-seed linear congruential generator so the test is
/// reproducible; each of the four position axes holds `0..seq_len`.
#[test]
fn test_rope_multi_random_cpu_parity() {
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 32,
        rope_dim: 16,
        n_heads: 4,
        seq_len: 6,
        freq_base: 10000.0,
        mode: RopeMultiMode::Imrope,
        sections: [3, 3, 2, 0],
    };
    let n_elem = (p.seq_len * p.n_heads) as usize * (p.head_dim as usize);

    // Fixed-seed LCG; samples are scaled into roughly [-1.5, 1.5].
    let mut state = 0xdeadbeefu32;
    let next_sample = || -> f32 {
        state = state.wrapping_mul(1103515245).wrapping_add(12345);
        (state as i32 as f32) / (i32::MAX as f32) * 1.5
    };
    let input: Vec<f32> = std::iter::repeat_with(next_sample).take(n_elem).collect();

    // Four axes, each carrying the positions 0..seq_len.
    let positions: Vec<i32> = (0..4).flat_map(|_| 0..p.seq_len as i32).collect();

    let got = run_rope_multi(&device, &mut registry, &input, &positions, p);
    let want = cpu_rope_multi(&input, &positions, p);
    for (i, (g, w)) in got.iter().copied().zip(want.iter().copied()).enumerate() {
        let d = (g - w).abs();
        assert!(
            d < 1e-5,
            "random parity mismatch at {}: got {}, want {}",
            i, g, w
        );
    }
}
/// Larger partial-rotary shape (head_dim 256, rope_dim 64): determinism,
/// CPU parity, and an untouched tail.
///
/// Three identical runs must be bit-for-bit equal, the first run must match
/// the CPU reference within 1e-4, and pairs 32..128 of every row must equal
/// the input.
#[test]
fn test_rope_multi_qwen35_shape_determinism() {
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 256,
        rope_dim: 64,
        n_heads: 4,
        seq_len: 3,
        freq_base: 1e7,
        mode: RopeMultiMode::Imrope,
        sections: [11, 11, 10, 0],
    };
    let head_dim = p.head_dim as usize;
    let n_rows = (p.seq_len * p.n_heads) as usize;
    let n_elem = n_rows * head_dim;

    // Fixed-seed LCG so every invocation sees the same input.
    let mut state = 0xfeed1234u32;
    let mut input: Vec<f32> = Vec::with_capacity(n_elem);
    for _ in 0..n_elem {
        state = state.wrapping_mul(1103515245).wrapping_add(12345);
        input.push((state as i32 as f32) / (i32::MAX as f32));
    }
    let positions: Vec<i32> = (0..4).flat_map(|_| 0..p.seq_len as i32).collect();

    let got1 = run_rope_multi(&device, &mut registry, &input, &positions, p);
    let got2 = run_rope_multi(&device, &mut registry, &input, &positions, p);
    let got3 = run_rope_multi(&device, &mut registry, &input, &positions, p);

    // Bit-level comparison: repeated runs must be exactly reproducible.
    for i in 0..n_elem {
        let first_bits = got1[i].to_bits();
        assert_eq!(
            first_bits,
            got2[i].to_bits(),
            "non-deterministic at {}: {} vs {}",
            i, got1[i], got2[i]
        );
        assert_eq!(first_bits, got3[i].to_bits(), "non-deterministic run3");
    }

    let want = cpu_rope_multi(&input, &positions, p);
    for (i, (g, w)) in got1.iter().copied().zip(want.iter().copied()).enumerate() {
        let d = (g - w).abs();
        assert!(
            d < 1e-4,
            "qwen35-shape parity mismatch at {}: got {}, want {}",
            i, g, w
        );
    }

    // Pairs from rope_dim / 2 (= 32) up to head_dim / 2 (= 128) are not rotated.
    let half_dim = head_dim / 2;
    let first_tail_pair = p.rope_dim as usize / 2;
    for row in 0..n_rows {
        let base = row * head_dim;
        for pair in first_tail_pair..half_dim {
            let ix0 = base + pair;
            let ix1 = ix0 + half_dim;
            assert!(
                (got1[ix0] - input[ix0]).abs() < 1e-7,
                "partial-rotary tail modified at row={}, pair={}",
                row, pair
            );
            assert!(
                (got1[ix1] - input[ix1]).abs() < 1e-7,
                "partial-rotary tail modified at row={}, pair+half",
                row
            );
        }
    }
}
/// When all four position axes carry the same values, IMROPE must reduce to
/// a single-position rotary embedding computed inline here.
///
/// The expected output rotates element `pair` with `pair + head_dim / 2`
/// using the sequence index as the position, independent of any axis choice.
#[test]
fn test_rope_multi_imrope_text_equals_neox_rope() {
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 32,
        rope_dim: 16,
        n_heads: 2,
        seq_len: 4,
        freq_base: 1e4,
        mode: RopeMultiMode::Imrope,
        sections: [3, 3, 2, 0],
    };
    let head_dim = p.head_dim as usize;
    let n_rows = (p.seq_len * p.n_heads) as usize;
    let n_elem = n_rows * head_dim;

    let input: Vec<f32> = (0..n_elem).map(|i| (i as f32) * 0.01).collect();
    // Every axis gets the same positions 0..seq_len.
    let positions: Vec<i32> = (0..4).flat_map(|_| 0..p.seq_len as i32).collect();

    let got = run_rope_multi(&device, &mut registry, &input, &positions, p);

    // Single-axis reference: only the first rope_dim / 2 pairs are rotated.
    let half_dim = (p.head_dim / 2) as usize;
    let half_rope = (p.rope_dim / 2) as usize;
    let mut want = input.clone();
    for row in 0..n_rows {
        let row_start = row * head_dim;
        let pos = (row as u32 / p.n_heads) as f32;
        for pair in 0..half_rope {
            let dim_ratio = 2.0 * pair as f32 / p.rope_dim as f32;
            let freq = 1.0 / p.freq_base.powf(dim_ratio);
            let (sa, ca) = (pos * freq).sin_cos();
            let lo = row_start + pair;
            let hi = lo + half_dim;
            let (x0, x1) = (input[lo], input[hi]);
            want[lo] = x0 * ca - x1 * sa;
            want[hi] = x0 * sa + x1 * ca;
        }
    }

    for (i, (g, w)) in got.iter().copied().zip(want.iter().copied()).enumerate() {
        let d = (g - w).abs();
        assert!(
            d < 1e-5,
            "text IMROPE != NeoX at {}: got {}, want {}",
            i, g, w
        );
    }
}
/// The BF16 kernel path must agree with the F32 path within 5e-2.
///
/// The same values are run once as `F32` (through `run_rope_multi`) and once
/// as `BF16` with buffers built by hand, because `run_rope_multi` only
/// handles `f32` data. Outputs are compared element by element after
/// widening the BF16 result to `f32`.
#[test]
fn test_rope_multi_bf16_matches_f32_within_tolerance() {
    use half::bf16;
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 16,
        rope_dim: 8,
        n_heads: 2,
        seq_len: 3,
        freq_base: 1e4,
        mode: RopeMultiMode::Imrope,
        sections: [2, 2, 1, 0],
    };
    let n_rows = (p.seq_len * p.n_heads) as usize;
    let n_elem = n_rows * (p.head_dim as usize);
    let input_f32: Vec<f32> = (0..n_elem).map(|i| (i as f32) * 0.05 - 1.0).collect();
    let positions: Vec<i32> = (0..p.seq_len as i32)
        .cycle()
        .take(4 * p.seq_len as usize)
        .collect();

    // F32 baseline.
    let f32_out = run_rope_multi(&device, &mut registry, &input_f32, &positions, p);

    // BF16 run: two bytes per element for both input and output buffers.
    let input_bf: Vec<bf16> = input_f32.iter().map(|&v| bf16::from_f32(v)).collect();
    let mut in_buf = device
        .alloc_buffer(n_elem * 2, DType::BF16, vec![n_elem])
        .expect("input bf16");
    in_buf.as_mut_slice::<bf16>().expect("mut").copy_from_slice(&input_bf);
    let out_buf = device
        .alloc_buffer(n_elem * 2, DType::BF16, vec![n_elem])
        .expect("output bf16");
    let positions_buf = upload_i32(&device, &positions);
    let (params_buf, rope_params_buf, sections_buf) =
        build_rope_multi_buffers(&device, p).expect("bufs");
    let mut enc = device.command_encoder().expect("enc");
    mlx_native::ops::rope_multi::dispatch_rope_multi(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &in_buf,
        &out_buf,
        &positions_buf,
        &params_buf,
        &rope_params_buf,
        &sections_buf,
        p,
    )
    .expect("dispatch bf16");
    enc.commit_and_wait().expect("commit");

    let bf_out: Vec<bf16> = out_buf.as_slice::<bf16>().expect("read").to_vec();
    for (i, (bf, f)) in bf_out.iter().zip(f32_out.iter()).enumerate() {
        let diff = (bf.to_f32() - f).abs();
        assert!(
            diff < 5e-2,
            "bf16 drift at {}: bf={}, f32={}, diff={}",
            i, bf.to_f32(), f, diff
        );
    }
}
/// Dispatch must return an error when `head_dim` is odd (15 here).
///
/// The data buffers are placeholders: the test only inspects the `Result`
/// of `dispatch_rope_multi` and never commits the encoder.
#[test]
fn test_rope_multi_rejects_odd_head_dim() {
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 15,
        rope_dim: 4,
        n_heads: 1,
        seq_len: 1,
        freq_base: 1e4,
        mode: RopeMultiMode::Imrope,
        sections: [1, 1, 0, 0],
    };
    // The same placeholder buffer is passed as both input and output.
    let dummy = device.alloc_buffer(4, DType::F32, vec![1]).expect("d");
    let pos = device.alloc_buffer(16, DType::I32, vec![4]).expect("p");
    let (params, rope_params, sections) = build_rope_multi_buffers(&device, p).expect("b");
    let mut enc = device.command_encoder().expect("enc");
    let res = mlx_native::ops::rope_multi::dispatch_rope_multi(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &dummy,
        &dummy,
        &pos,
        &params,
        &rope_params,
        &sections,
        p,
    );
    assert!(res.is_err(), "odd head_dim should error");
}
/// Dispatch must return an error when `rope_dim` (16) exceeds `head_dim` (8).
///
/// The data buffers are placeholders: the test only inspects the `Result`
/// of `dispatch_rope_multi` and never commits the encoder.
#[test]
fn test_rope_multi_rejects_rope_dim_gt_head_dim() {
    let (device, mut registry) = setup();
    let p = RopeMultiParams {
        head_dim: 8,
        rope_dim: 16,
        n_heads: 1,
        seq_len: 1,
        freq_base: 1e4,
        mode: RopeMultiMode::Imrope,
        sections: [1, 1, 0, 0],
    };
    // The same placeholder buffer is passed as both input and output.
    let dummy = device.alloc_buffer(4, DType::F32, vec![1]).expect("d");
    let pos = device.alloc_buffer(16, DType::I32, vec![4]).expect("p");
    let (params, rope_params, sections) = build_rope_multi_buffers(&device, p).expect("b");
    let mut enc = device.command_encoder().expect("enc");
    let res = mlx_native::ops::rope_multi::dispatch_rope_multi(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &dummy,
        &dummy,
        &pos,
        &params,
        &rope_params,
        &sections,
        p,
    );
    assert!(res.is_err(), "rope_dim > head_dim should error");
}