use metal::{
Buffer, CommandBufferRef, ComputePipelineState, MTLResourceOptions,
MTLSize, NSUInteger,
};
use crate::riir::backend::gpu::encoder::ComputeEncoder;
use crate::riir::backend::gpu::metal::{buffer_as_slice, MetalContext, MetalError};
/// Errors returned by the GPU YaRN RoPE entry points in this module.
#[derive(Debug, thiserror::Error)]
pub enum GpuRopeError {
/// The requested position was negative; `pos` carries the offending value.
#[error("position must be non-negative (got {pos})")]
NegativePos { pos: i32 },
/// The activation buffer length is incompatible with `rotary_dim`.
///
/// `yarn_rope_apply_oneshot` reports this when `x.len()` is not a multiple
/// of `rotary_dim`; in that case `expected` holds `rotary_dim` itself.
#[error("buffer length {got} != num_heads * rotary_dim ({expected})")]
BufLen { got: usize, expected: usize },
/// The inverse-frequency table does not have `rotary_dim / 2` entries.
#[error("inv_freq length {got} != rotary_dim/2 ({expected})")]
InvFreqLen { got: usize, expected: usize },
/// A failure from the Metal backend; `#[from]` lets `?` convert a
/// `MetalError` into this variant automatically.
#[error("Metal backend: {0}")]
Metal(#[from] MetalError),
}
/// Encodes one YaRN RoPE dispatch onto `cmdbuf` without committing it.
///
/// Bindings set on the encoder:
/// * buffer 0 — `x_buf`, buffer 1 — `inv_freq_buf` (both at offset 0);
/// * bytes 2..=5 — `num_heads`, `rotary_dim`, `pos` converted to `f32`,
///   and `mscale`, in that order.
///
/// The dispatch is issued with sizes `(num_heads, rotary_dim / 2, 1)` and
/// `(1, 1, 1)` — presumably the grid and per-threadgroup extents; confirm
/// against `ComputeEncoder::dispatch`.
///
/// # Errors
///
/// Returns [`GpuRopeError::NegativePos`] when `pos < 0`; nothing is encoded
/// in that case.
// One argument per kernel binding, so the long parameter list is intentional.
#[allow(clippy::too_many_arguments)]
pub fn encode_yarn_rope_apply(
    cmdbuf: &CommandBufferRef,
    pipe: &ComputePipelineState,
    x_buf: &Buffer,
    inv_freq_buf: &Buffer,
    num_heads: u32,
    rotary_dim: u32,
    pos: i32,
    mscale: f32,
) -> Result<(), GpuRopeError> {
    if pos < 0 {
        return Err(GpuRopeError::NegativePos { pos });
    }

    // The position is handed to the kernel as a float.
    let position = pos as f32;

    // One slot per (head, rotation pair).
    let pair_count = rotary_dim / 2;
    let grid = MTLSize::new(num_heads as NSUInteger, pair_count as NSUInteger, 1);
    let group = MTLSize::new(1, 1, 1);

    ComputeEncoder::begin(cmdbuf)
        .pipeline(pipe)
        .buffer(0, x_buf, 0)
        .buffer(1, inv_freq_buf, 0)
        .bytes(2, &num_heads)
        .bytes(3, &rotary_dim)
        .bytes(4, &position)
        .bytes(5, &mscale)
        .dispatch(grid, group);

    Ok(())
}
/// Applies YaRN rotary embedding to `x` in place on the GPU and blocks until
/// the result is available.
///
/// `x` is treated as `x.len() / rotary_dim` consecutive heads of `rotary_dim`
/// values each. The function uploads `x` and `inv_freq` into shared Metal
/// buffers, encodes a single dispatch via [`encode_yarn_rope_apply`], commits,
/// waits for completion, and copies the buffer contents back into `x`.
///
/// An empty `x` (after `pos` and `inv_freq` have been validated) is a no-op
/// and returns `Ok(())` without touching the GPU.
///
/// # Errors
///
/// * [`GpuRopeError::NegativePos`] — `pos < 0`.
/// * [`GpuRopeError::InvFreqLen`] — `inv_freq.len() != rotary_dim / 2`.
/// * [`GpuRopeError::BufLen`] — `x` is non-empty and `rotary_dim` is zero,
///   `x.len()` is not a multiple of `rotary_dim`, or the head count /
///   `rotary_dim` does not fit the kernel's `u32` parameters.
/// * [`GpuRopeError::Metal`] — the `yarn_rope_apply` pipeline lookup failed.
pub fn yarn_rope_apply_oneshot(
    metal: &mut MetalContext,
    x: &mut [f32],
    rotary_dim: usize,
    inv_freq: &[f32],
    pos: i32,
    mscale: f32,
) -> Result<(), GpuRopeError> {
    if pos < 0 {
        return Err(GpuRopeError::NegativePos { pos });
    }
    let half = rotary_dim / 2;
    if inv_freq.len() != half {
        return Err(GpuRopeError::InvFreqLen {
            got: inv_freq.len(),
            expected: half,
        });
    }

    // Nothing to rotate. Returning here also avoids creating a zero-length
    // Metal buffer and dispatching an empty grid.
    if x.is_empty() {
        return Ok(());
    }

    // `rotary_dim == 0` has to be rejected before `%` and `/`, which would
    // otherwise panic with a division by zero.
    if rotary_dim == 0 || x.len() % rotary_dim != 0 {
        return Err(GpuRopeError::BufLen {
            got: x.len(),
            expected: rotary_dim,
        });
    }

    // The kernel receives both dimensions as `u32`; refuse sizes that would
    // be silently truncated by the casts below.
    let head_count = x.len() / rotary_dim;
    let max_dim = u32::MAX as usize;
    if head_count > max_dim || rotary_dim > max_dim {
        return Err(GpuRopeError::BufLen {
            got: x.len(),
            expected: rotary_dim,
        });
    }
    let num_heads = head_count as u32;
    let rotary_dim_u32 = rotary_dim as u32;

    // Own a clone of the pipeline so `metal` is free to be borrowed again
    // for the device and queue below.
    let pipe = metal.pipeline("yarn_rope_apply")?.clone();
    let device = metal.device();
    let buf_x = device.new_buffer_with_data(
        x.as_ptr().cast(),
        (x.len() * std::mem::size_of::<f32>()) as NSUInteger,
        MTLResourceOptions::StorageModeShared,
    );
    let buf_inv = device.new_buffer_with_data(
        inv_freq.as_ptr().cast(),
        (inv_freq.len() * std::mem::size_of::<f32>()) as NSUInteger,
        MTLResourceOptions::StorageModeShared,
    );

    let cmdbuf = metal.queue().new_command_buffer();
    encode_yarn_rope_apply(
        cmdbuf,
        &pipe,
        &buf_x,
        &buf_inv,
        num_heads,
        rotary_dim_u32,
        pos,
        mscale,
    )?;
    cmdbuf.commit();
    // NOTE(review): the command buffer's completion status is not inspected,
    // so a GPU-side failure would go unnoticed here.
    cmdbuf.wait_until_completed();

    // SAFETY: `buf_x` was created above from exactly `x.len()` `f32` values
    // and the command buffer has finished executing, so viewing `x.len()`
    // floats of its contents is in bounds and does not race with the GPU.
    // (Assumes `buffer_as_slice` exposes the buffer contents as `[T]` of the
    // given length — confirm against its definition.)
    let result = unsafe { buffer_as_slice::<f32>(&buf_x, x.len()) };
    x.copy_from_slice(result);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::riir::attn::rope::{apply_rotary_emb_yarn, compute_yarn_inv_freq};

    /// Per-head width used by every test in this module.
    const ROTARY_DIM: usize = 64;

    /// Creates a Metal context, or logs why the test is being skipped and
    /// returns `None` when initialisation fails.
    fn metal_or_skip() -> Option<MetalContext> {
        match MetalContext::new() {
            Ok(ctx) => Some(ctx),
            Err(e) => {
                eprintln!("[gpu_rope] skipping: Metal init failed: {e:?}");
                None
            }
        }
    }

    /// With `pos == 0` and `mscale == 1.0` the GPU path must leave the input
    /// unchanged (to within 1e-7).
    #[test]
    fn yarn_rope_gpu_pos_zero_mscale_one_is_identity() {
        let mut metal = match metal_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };

        let pair_count = ROTARY_DIM / 2;
        let mut inv_freq: Vec<f32> = Vec::with_capacity(pair_count);
        for i in 0..pair_count {
            inv_freq.push(1.0 / 10000f32.powf(2.0 * (i as f32) / ROTARY_DIM as f32));
        }

        let head_count = 4;
        let expected: Vec<f32> = (0..head_count * ROTARY_DIM)
            .map(|i| (i as f32) * 0.001)
            .collect();
        let mut x = expected.clone();

        yarn_rope_apply_oneshot(&mut metal, &mut x, ROTARY_DIM, &inv_freq, 0, 1.0).unwrap();

        for (i, (got, want)) in x.iter().zip(&expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-7,
                "x[{i}] = {} but expected identity {}",
                got,
                want,
            );
        }
    }

    /// The GPU result at position 4096 must agree with the CPU reference
    /// implementation to within 5e-6 absolute error.
    #[test]
    fn yarn_rope_gpu_matches_cpu_at_pos_4096() {
        let mut metal = match metal_or_skip() {
            Some(ctx) => ctx,
            None => return,
        };

        let inv_freq = compute_yarn_inv_freq(ROTARY_DIM, 1.0e4, 40.0, 4096.0, 32.0, 1.0);
        assert_eq!(inv_freq.len(), ROTARY_DIM / 2);

        let head_count = 8;
        let pos = 4096;
        let mscale: f32 = 1.0;

        let input: Vec<f32> = (0..head_count * ROTARY_DIM)
            .map(|i| ((i as f32) * 0.0137).sin())
            .collect();
        let mut x_cpu = input.clone();
        let mut x_gpu = input;

        apply_rotary_emb_yarn(pos, &mut x_cpu, ROTARY_DIM, &inv_freq, mscale).unwrap();
        yarn_rope_apply_oneshot(&mut metal, &mut x_gpu, ROTARY_DIM, &inv_freq, pos, mscale)
            .unwrap();

        // Largest absolute element-wise difference between the two results.
        let mut max_drift = 0.0f32;
        for (g, c) in x_gpu.iter().zip(&x_cpu) {
            max_drift = max_drift.max((g - c).abs());
        }
        assert!(
            max_drift < 5e-6,
            "GPU/CPU drift {max_drift} exceeds 4 ULP tolerance"
        );
    }
}