use crate::riir::variants::{ROPE_THETA, VARIANT};
// Single-precision math routines declared with the C ABI. The symbols are
// resolved at link time (normally by the platform's C math library). Because
// the items are not marked `safe`, every call site needs an `unsafe` block.
// NOTE(review): presumably these are used instead of `f32::cos` / `f32::sin` /
// `f32::powf` to match a C reference implementation bit-for-bit — confirm.
unsafe extern "C" {
/// C `cosf`: cosine of `x`, with `x` in radians.
fn cosf(x: f32) -> f32;
/// C `sinf`: sine of `x`, with `x` in radians.
fn sinf(x: f32) -> f32;
/// C `powf`: `base` raised to the power `exp`.
fn powf(base: f32, exp: f32) -> f32;
}
/// Input-validation failures reported by [`apply_rotary_emb`].
#[derive(Debug, thiserror::Error)]
pub enum RopeError {
/// The requested position was below zero; `pos` is the rejected value.
#[error("position must be non-negative (got {pos})")]
NegativePos { pos: i32 },
/// The Q slice length differs from `num_attn_heads * head_dim`.
/// `got` is the slice length that was passed, `expected` the required one.
#[error(
"Q buffer length {got} != num_attn_heads * head_dim ({expected})"
)]
QLen { got: usize, expected: usize },
/// The K slice length differs from `num_kv_heads * head_dim`.
/// `got` is the slice length that was passed, `expected` the required one.
#[error(
"K buffer length {got} != num_kv_heads * head_dim ({expected})"
)]
KLen { got: usize, expected: usize },
}
/// Applies rotary position embedding for token position `pos` to the query
/// and key buffers, in place.
///
/// `q` is treated as `VARIANT.num_attn_heads` consecutive heads and `k` as
/// `VARIANT.num_kv_heads` consecutive heads, each `VARIANT.head_dim` floats
/// long. With `half = VARIANT.rotary_dim() / 2`, channel `i` of every head is
/// paired with channel `i + half`, and the pair is rotated by the angle
/// `pos * ROPE_THETA^(-2i / rotary_dim)` radians. Channels at index
/// `2 * half` and above are left untouched.
///
/// # Errors
///
/// * [`RopeError::NegativePos`] if `pos < 0`.
/// * [`RopeError::QLen`] if `q.len() != num_attn_heads * head_dim`.
/// * [`RopeError::KLen`] if `k.len() != num_kv_heads * head_dim`.
///
/// No element of `q` or `k` is modified when an error is returned.
///
/// # Panics
///
/// Panics on an out-of-bounds index if the rotated span (`2 * half`) is
/// larger than `head_dim` while a buffer is non-empty. That is a
/// configuration error in `VARIANT`, not a caller error.
pub fn apply_rotary_emb(
    pos: i32,
    q: &mut [f32],
    k: &mut [f32],
) -> Result<(), RopeError> {
    if pos < 0 {
        return Err(RopeError::NegativePos { pos });
    }

    let head_dim = VARIANT.head_dim;
    let rotary_dim = VARIANT.rotary_dim();

    let q_expected = VARIANT.num_attn_heads * head_dim;
    if q.len() != q_expected {
        return Err(RopeError::QLen {
            got: q.len(),
            expected: q_expected,
        });
    }
    let k_expected = VARIANT.num_kv_heads * head_dim;
    if k.len() != k_expected {
        return Err(RopeError::KLen {
            got: k.len(),
            expected: k_expected,
        });
    }

    let half = rotary_dim / 2;
    // `i32 -> f32` is exact only up to 2^24; larger positions are rounded.
    let pos_f = pos as f32;
    let rdim_f = rotary_dim as f32;

    // The rotation angle depends only on the channel index, never on the
    // head, so the channel loop is outermost: each sin/cos pair is computed
    // once and reused for every Q and K head. The per-element arithmetic is
    // unchanged, so results are bit-identical to a head-major traversal.
    for i in 0..half {
        // SAFETY: `powf`, `sinf` and `cosf` take and return plain `f32`
        // values by value; they dereference no pointers and are defined for
        // every `f32` input (including infinities and NaN).
        let (cos_a, sin_a) = unsafe {
            let freq = 1.0f32 / powf(ROPE_THETA, (2 * i) as f32 / rdim_f);
            let angle = pos_f * freq;
            (cosf(angle), sinf(angle))
        };
        rotate_channel_pair(q, head_dim, i, half, cos_a, sin_a);
        rotate_channel_pair(k, head_dim, i, half, cos_a, sin_a);
    }
    Ok(())
}

/// Rotates the channel pair `(i, i + half)` of every `head_dim`-long head in
/// `buf` by the angle whose cosine and sine are `cos_a` and `sin_a`.
fn rotate_channel_pair(
    buf: &mut [f32],
    head_dim: usize,
    i: usize,
    half: usize,
    cos_a: f32,
    sin_a: f32,
) {
    // An empty buffer holds no heads. Returning here also guarantees that
    // `chunks_exact_mut` is never given a zero chunk size: the caller has
    // already checked that `buf.len()` is a multiple of `head_dim`, so a
    // non-empty buffer implies `head_dim != 0`.
    if buf.is_empty() {
        return;
    }
    for head in buf.chunks_exact_mut(head_dim) {
        let x0 = head[i];
        let x1 = head[i + half];
        head[i] = x0 * cos_a - x1 * sin_a;
        head[i + half] = x0 * sin_a + x1 * cos_a;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds correctly sized Q and K buffers filled with a small ramp
    /// (`index * 0.001`), so every element is distinct and finite.
    fn ramp_buffers() -> (Vec<f32>, Vec<f32>) {
        let ramp = |len: usize| -> Vec<f32> {
            (0..len).map(|i| i as f32 * 0.001).collect()
        };
        (
            ramp(VARIANT.num_attn_heads * VARIANT.head_dim),
            ramp(VARIANT.num_kv_heads * VARIANT.head_dim),
        )
    }

    /// At position 0 every angle is 0 (cos = 1, sin = 0), so the rotated
    /// channels must come back unchanged. The pass-through channels are never
    /// written, so the whole buffer is compared, not just the rotated prefix.
    #[test]
    fn rope_at_pos_zero_is_identity_on_rotated_channels() {
        let (mut q, mut k) = ramp_buffers();
        let q_orig = q.clone();
        let k_orig = k.clone();

        apply_rotary_emb(0, &mut q, &mut k).unwrap();

        for (idx, (got, want)) in q.iter().zip(&q_orig).enumerate() {
            assert_eq!(got, want, "Q element {idx} changed at pos 0");
        }
        for (idx, (got, want)) in k.iter().zip(&k_orig).enumerate() {
            assert_eq!(got, want, "K element {idx} changed at pos 0");
        }
    }

    /// A negative position is rejected and neither buffer is modified.
    #[test]
    fn negative_position_is_rejected() {
        let (mut q, mut k) = ramp_buffers();
        let q_orig = q.clone();
        let k_orig = k.clone();

        let err = apply_rotary_emb(-1, &mut q, &mut k).unwrap_err();

        assert!(matches!(err, RopeError::NegativePos { pos: -1 }));
        assert_eq!(q, q_orig);
        assert_eq!(k, k_orig);
    }

    /// A Q buffer that is one element too long yields `QLen` with both
    /// lengths reported.
    #[test]
    fn wrong_q_length_is_rejected() {
        let (mut q, mut k) = ramp_buffers();
        let expected_len = q.len();
        q.push(0.0);

        match apply_rotary_emb(0, &mut q, &mut k) {
            Err(RopeError::QLen { got, expected }) => {
                assert_eq!(got, expected_len + 1);
                assert_eq!(expected, expected_len);
            }
            other => panic!("expected QLen error, got {other:?}"),
        }
    }

    /// A K buffer that is one element too long yields `KLen` (the Q buffer
    /// is valid, so the K check is the one that fires).
    #[test]
    fn wrong_k_length_is_rejected() {
        let (mut q, mut k) = ramp_buffers();
        let expected_len = k.len();
        k.push(0.0);

        match apply_rotary_emb(0, &mut q, &mut k) {
            Err(RopeError::KLen { got, expected }) => {
                assert_eq!(got, expected_len + 1);
                assert_eq!(expected, expected_len);
            }
            other => panic!("expected KLen error, got {other:?}"),
        }
    }
}