use crate::riir::variants::{ROPE_THETA, VARIANT};
// Single-precision C math routines, resolved at link time (presumably against
// the platform's libm). Every call site wraps them in `unsafe` because they are
// foreign functions.
// NOTE(review): these are used instead of the `f32::cos` / `f32::sin` /
// `f32::powf` methods, presumably for bit-for-bit parity with a C reference
// implementation — confirm before swapping them out.
unsafe extern "C" {
// Cosine of `x` (radians).
fn cosf(x: f32) -> f32;
// Sine of `x` (radians).
fn sinf(x: f32) -> f32;
// `base` raised to the power `exp`.
fn powf(base: f32, exp: f32) -> f32;
}
#[derive(Debug, thiserror::Error)]
pub enum RopeError {
#[error("position must be non-negative (got {pos})")]
NegativePos { pos: i32 },
#[error(
"Q buffer length {got} != num_attn_heads * head_dim ({expected})"
)]
QLen { got: usize, expected: usize },
#[error(
"K buffer length {got} != num_kv_heads * head_dim ({expected})"
)]
KLen { got: usize, expected: usize },
}
/// Applies rotary position embeddings (RoPE) in place to one token's query and
/// key vectors.
///
/// `q` holds `VARIANT.num_attn_heads` heads and `k` holds
/// `VARIANT.num_kv_heads` heads, each `VARIANT.head_dim` floats wide and laid
/// out back to back. Within each head, channel `i` (for `i < rotary_dim / 2`)
/// is paired with channel `i + rotary_dim / 2`, and the pair is rotated by the
/// angle `pos * ROPE_THETA^(-2i / rotary_dim)`. Channels at or beyond
/// `2 * (rotary_dim / 2)` are left untouched.
///
/// # Errors
///
/// * [`RopeError::NegativePos`] if `pos < 0`.
/// * [`RopeError::QLen`] / [`RopeError::KLen`] if a buffer does not have the
///   length implied by `VARIANT`.
///
/// # Panics
///
/// If `VARIANT` is inconsistent, i.e. a non-empty head is narrower than
/// `2 * (rotary_dim / 2)` channels.
pub fn apply_rotary_emb(pos: i32, q: &mut [f32], k: &mut [f32]) -> Result<(), RopeError> {
    if pos < 0 {
        return Err(RopeError::NegativePos { pos });
    }
    let head_dim = VARIANT.head_dim;
    let num_heads = VARIANT.num_attn_heads;
    let num_kv_heads = VARIANT.num_kv_heads;
    let rotary_dim = VARIANT.rotary_dim();

    let q_expected = num_heads * head_dim;
    if q.len() != q_expected {
        return Err(RopeError::QLen { got: q.len(), expected: q_expected });
    }
    let k_expected = num_kv_heads * head_dim;
    if k.len() != k_expected {
        return Err(RopeError::KLen { got: k.len(), expected: k_expected });
    }

    // The rotation angle depends only on the pair index and the position, not
    // on the head, so evaluate powf/cosf/sinf once per pair and share the
    // result across every Q and K head (the values are bit-identical to
    // recomputing them per head).
    let half = rotary_dim / 2;
    let pos_f = pos as f32;
    let rdim_f = rotary_dim as f32;
    let cos_sin: Vec<(f32, f32)> = (0..half)
        .map(|i| {
            // SAFETY: powf is a pure C math function on plain f32 values; it
            // touches no Rust-managed memory.
            let freq = 1.0f32 / unsafe { powf(ROPE_THETA, (2 * i) as f32 / rdim_f) };
            let angle = pos_f * freq;
            // SAFETY: as above, cosf/sinf are pure functions of their argument.
            unsafe { (cosf(angle), sinf(angle)) }
        })
        .collect();

    rotate_heads_half_split(q, head_dim, &cos_sin);
    rotate_heads_half_split(k, head_dim, &cos_sin);
    Ok(())
}

/// Rotates every `head_dim`-wide head of `buf` in place: channel `i` is paired
/// with channel `i + cos_sin.len()` and rotated by `cos_sin[i] = (cos, sin)`.
fn rotate_heads_half_split(buf: &mut [f32], head_dim: usize, cos_sin: &[(f32, f32)]) {
    let half = cos_sin.len();
    if half == 0 || head_dim == 0 {
        // Nothing to rotate (and `chunks_exact_mut(0)` would panic).
        return;
    }
    for head in buf.chunks_exact_mut(head_dim) {
        let (lo, hi) = head[..2 * half].split_at_mut(half);
        for ((a, b), &(cos_a, sin_a)) in lo.iter_mut().zip(hi.iter_mut()).zip(cos_sin) {
            let (x0, x1) = (*a, *b);
            *a = x0 * cos_a - x1 * sin_a;
            *b = x0 * sin_a + x1 * cos_a;
        }
    }
}
/// Returns the (fractional) dimension index `d` at which the RoPE frequency
/// `base^(-2d / dim)` completes exactly `num_rotations` full turns over
/// `max_position_embeddings` positions.
///
/// Closed form: `dim * ln(max_position_embeddings / (num_rotations * 2π)) / (2 * ln(base))`.
pub fn yarn_find_correction_dim(
    num_rotations: f32,
    dim: usize,
    base: f32,
    max_position_embeddings: f32,
) -> f32 {
    let ratio = max_position_embeddings / (num_rotations * std::f32::consts::TAU);
    let numerator = dim as f32 * ratio.ln();
    let denominator = 2.0 * base.ln();
    numerator / denominator
}
/// Returns the inclusive range `(low, high)` of dimension indices over which
/// YaRN blends extrapolated and interpolated frequencies.
///
/// `low` is the correction dimension for `beta_fast` rounded down and `high`
/// the one for `beta_slow` rounded up (see [`yarn_find_correction_dim`]). Both
/// are clamped to `[0, dim - 1]`, or to `0` when `dim` is zero.
pub fn yarn_find_correction_range(
    beta_fast: f32,
    beta_slow: f32,
    dim: usize,
    base: f32,
    max_position_embeddings: f32,
) -> (f32, f32) {
    let low = yarn_find_correction_dim(beta_fast, dim, base, max_position_embeddings).floor();
    let high = yarn_find_correction_dim(beta_slow, dim, base, max_position_embeddings).ceil();
    // `saturating_sub` keeps `dim == 0` from underflowing (a panic in debug
    // builds and an absurd upper bound in release builds).
    let dim_max = dim.saturating_sub(1) as f32;
    (low.clamp(0.0, dim_max), high.clamp(0.0, dim_max))
}
/// YaRN attention-scale factor for a context stretch of `scale`.
///
/// Returns `1.0` when `scale <= 1.0` (no stretch); otherwise
/// `0.1 * mscale * ln(scale) + 1.0`. A NaN `scale` fails the `<= 1.0` test and
/// therefore propagates NaN.
pub fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
    if scale <= 1.0 {
        return 1.0;
    }
    let log_scale = scale.ln();
    0.1 * mscale * log_scale + 1.0
}
/// Ratio of the YaRN attention-scale factors computed with `mscale` and with
/// `mscale_all_dim`, both for the same context stretch `scale`.
///
/// Both factors come from [`yarn_get_mscale`], so the result is `1.0` whenever
/// `scale <= 1.0`.
pub fn yarn_get_mscale_full(scale: f32, mscale: f32, mscale_all_dim: f32) -> f32 {
    let numerator = yarn_get_mscale(scale, mscale);
    let denominator = yarn_get_mscale(scale, mscale_all_dim);
    numerator / denominator
}
/// Builds a `dim`-element ramp clamped to `[0.0, 1.0]`.
///
/// For `min < max` the value is `0.0` at indices `<= min`, rises linearly to
/// `1.0` at index `max`, and stays `1.0` beyond it. When `min == max`, the
/// upper bound is nudged by `0.001` to avoid dividing by zero, which
/// effectively turns the ramp into a step.
pub fn yarn_linear_ramp_mask(min: f32, max: f32, dim: usize) -> Vec<f32> {
    let upper = if max == min { max + 0.001 } else { max };
    let span = upper - min;
    (0..dim)
        .map(|i| ((i as f32 - min) / span).clamp(0.0, 1.0))
        .collect()
}
/// Computes the YaRN-scaled inverse frequencies for a RoPE with `dim` rotary
/// channels, returning `dim / 2` values (empty when `dim < 2`).
///
/// For pair index `i`, the extrapolated (unscaled) frequency is
/// `base^(-2i / dim)`, and the interpolated frequency is that value divided by
/// `factor`. The two are blended with the ramp from [`yarn_linear_ramp_mask`]
/// over the range returned by [`yarn_find_correction_range`] (driven by
/// `beta_fast`, `beta_slow` and `original_max_position`):
/// * pairs at or below the range's low end keep the extrapolated frequency;
/// * pairs at or above its high end use the interpolated frequency;
/// * pairs in between are linearly mixed.
pub fn compute_yarn_inv_freq(
    dim: usize,
    base: f32,
    factor: f32,
    original_max_position: f32,
    beta_fast: f32,
    beta_slow: f32,
) -> Vec<f32> {
    let half = dim / 2;
    if half == 0 {
        // No channel pairs; this also keeps `dim == 0` away from the
        // `dim - 1` arithmetic in the correction-range computation.
        return Vec::new();
    }
    let (low, high) =
        yarn_find_correction_range(beta_fast, beta_slow, dim, base, original_max_position);
    let ramp = yarn_linear_ramp_mask(low, high, half);
    ramp.iter()
        .enumerate()
        .map(|(i, &r)| {
            let exp_i = (2 * i) as f32 / dim as f32;
            // SAFETY: powf is a pure C math function on plain f32 values.
            let pos_freq = unsafe { powf(base, exp_i) };
            let freq_extra = 1.0 / pos_freq;
            // Position interpolation divides each frequency by `factor`:
            // 1 / (factor * base^exp). Writing 1 / (factor * base)^exp instead
            // would scale pair i by only factor^exp (pair 0 not at all).
            let freq_inter = 1.0 / (factor * pos_freq);
            freq_inter * r + freq_extra * (1.0 - r)
        })
        .collect()
}
#[derive(Debug, thiserror::Error)]
pub enum YarnError {
#[error("position must be non-negative (got {pos})")]
NegativePos { pos: i32 },
#[error("buffer length {got} != num_heads * rotary_dim ({expected})")]
BufLen { got: usize, expected: usize },
#[error("inv_freq length {got} != rotary_dim/2 ({expected})")]
InvFreqLen { got: usize, expected: usize },
}
/// Applies YaRN-scaled rotary position embeddings in place to a buffer of
/// back-to-back heads, each exactly `rotary_dim` floats wide.
///
/// Channel `i` of each head is paired with channel `i + rotary_dim / 2`. The
/// pair is rotated by the angle `pos * inv_freq[i]`, and the rotation's
/// cos/sin are both multiplied by `mscale`. For an odd `rotary_dim`, the last
/// channel of each head is left untouched. `inv_freq` can come from e.g.
/// [`compute_yarn_inv_freq`].
///
/// # Errors
///
/// * [`YarnError::NegativePos`] if `pos < 0`.
/// * [`YarnError::InvFreqLen`] if `inv_freq.len() != rotary_dim / 2`.
/// * [`YarnError::BufLen`] if `x.len()` is not a whole number of heads. With
///   `rotary_dim == 0`, only an empty `x` is accepted.
pub fn apply_rotary_emb_yarn(
    pos: i32,
    x: &mut [f32],
    rotary_dim: usize,
    inv_freq: &[f32],
    mscale: f32,
) -> Result<(), YarnError> {
    if pos < 0 {
        return Err(YarnError::NegativePos { pos });
    }
    let half = rotary_dim / 2;
    if inv_freq.len() != half {
        return Err(YarnError::InvFreqLen { got: inv_freq.len(), expected: half });
    }
    // `checked_rem` avoids the divide-by-zero panic of `x.len() % 0`. Treating
    // the remainder as `x.len()` in that case accepts only an empty buffer.
    if x.len().checked_rem(rotary_dim).unwrap_or(x.len()) != 0 {
        return Err(YarnError::BufLen { got: x.len(), expected: rotary_dim });
    }
    if half == 0 {
        // rotary_dim is 0 or 1: there are no channel pairs to rotate.
        return Ok(());
    }

    // The angle depends only on the pair index, so evaluate cos/sin once per
    // pair rather than once per (head, pair).
    let pos_f = pos as f32;
    let cos_sin: Vec<(f32, f32)> = inv_freq
        .iter()
        .map(|&freq| {
            let angle = pos_f * freq;
            // SAFETY: cosf/sinf are pure C math functions on plain f32 values.
            let (cos_a, sin_a) = unsafe { (cosf(angle), sinf(angle)) };
            (cos_a * mscale, sin_a * mscale)
        })
        .collect();

    for head in x.chunks_exact_mut(rotary_dim) {
        let (lo, hi) = head[..2 * half].split_at_mut(half);
        for ((a, b), &(cos_a, sin_a)) in lo.iter_mut().zip(hi.iter_mut()).zip(&cos_sin) {
            let (x0, x1) = (*a, *b);
            *a = x0 * cos_a - x1 * sin_a;
            *b = x0 * sin_a + x1 * cos_a;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rope_at_pos_zero_is_identity_on_rotated_channels() {
        let head_dim = VARIANT.head_dim;
        let num_heads = VARIANT.num_attn_heads;
        let num_kv_heads = VARIANT.num_kv_heads;
        let rotary_dim = VARIANT.rotary_dim();
        let mut q: Vec<f32> = (0..num_heads * head_dim).map(|i| i as f32 * 0.001).collect();
        let mut k: Vec<f32> = (0..num_kv_heads * head_dim).map(|i| i as f32 * 0.001).collect();
        let q_orig = q.clone();
        let k_orig = k.clone();
        apply_rotary_emb(0, &mut q, &mut k).unwrap();
        // cos(0) = 1 and sin(0) = 0, so the rotation must be exact.
        for h in 0..num_heads {
            for i in 0..rotary_dim {
                assert_eq!(q[h * head_dim + i], q_orig[h * head_dim + i]);
            }
        }
        for h in 0..num_kv_heads {
            for i in 0..rotary_dim {
                assert_eq!(k[h * head_dim + i], k_orig[h * head_dim + i]);
            }
        }
    }

    #[test]
    fn yarn_mscale_is_identity_at_scale_one() {
        assert_eq!(yarn_get_mscale(1.0, 1.0), 1.0);
        assert_eq!(yarn_get_mscale(1.0, 5.0), 1.0);
        assert_eq!(yarn_get_mscale(0.5, 5.0), 1.0);
    }

    #[test]
    fn yarn_mscale_full_is_identity_at_scale_one() {
        assert_eq!(yarn_get_mscale_full(1.0, 1.0, 1.0), 1.0);
    }

    #[test]
    fn yarn_ramp_mask_rises_linearly_between_bounds() {
        assert_eq!(yarn_linear_ramp_mask(1.0, 3.0, 5), vec![0.0, 0.0, 0.5, 1.0, 1.0]);
        assert_eq!(yarn_linear_ramp_mask(2.0, 2.0, 4), vec![0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn yarn_correction_range_typical_values() {
        assert_eq!(yarn_find_correction_range(32.0, 1.0, 64, 10_000.0, 4096.0), (10.0, 23.0));
    }

    #[test]
    fn yarn_inv_freq_at_factor_one_collapses_to_vanilla() {
        let dim = 64;
        let base = 10_000.0f32;
        let factor = 1.0f32;
        let original_max = 4096.0f32;
        let beta_fast = 32.0f32;
        let beta_slow = 1.0f32;
        let inv = compute_yarn_inv_freq(dim, base, factor, original_max, beta_fast, beta_slow);
        assert_eq!(inv.len(), dim / 2);
        for (i, &got) in inv.iter().enumerate() {
            // SAFETY: powf is a pure C math function on plain f32 values.
            let expected = 1.0 / unsafe { powf(base, (2 * i) as f32 / dim as f32) };
            // Inside the ramp, `e * r + e * (1 - r)` can be off from `e` by
            // about an ulp, so allow a few f32 epsilons of relative error
            // (a fixed 1e-7 is tighter than f32 epsilon and can spuriously fail).
            let tol = 4.0 * f32::EPSILON * expected.abs();
            assert!(
                (got - expected).abs() <= tol,
                "inv[{i}] = {} vs vanilla {}",
                got,
                expected,
            );
        }
    }

    #[test]
    fn yarn_inv_freq_is_monotone_decreasing() {
        let inv = compute_yarn_inv_freq(64, 10_000.0, 40.0, 4096.0, 32.0, 1.0);
        for i in 1..inv.len() {
            assert!(
                inv[i] < inv[i - 1] || (inv[i] - inv[i - 1]).abs() < 1e-12,
                "inv_freq not monotone at i={i}: {} vs {}",
                inv[i - 1],
                inv[i]
            );
        }
    }

    #[test]
    fn yarn_rope_at_pos_zero_mscale_one_is_identity() {
        let rotary_dim = 64;
        let num_heads = 4;
        let inv_freq = compute_yarn_inv_freq(rotary_dim, 10_000.0, 40.0, 4096.0, 32.0, 1.0);
        let mut x: Vec<f32> = (0..num_heads * rotary_dim).map(|i| i as f32 * 0.01).collect();
        let x_orig = x.clone();
        apply_rotary_emb_yarn(0, &mut x, rotary_dim, &inv_freq, 1.0).unwrap();
        for i in 0..x.len() {
            assert!(
                (x[i] - x_orig[i]).abs() < 1e-6,
                "x[{i}] changed: {} → {}",
                x_orig[i],
                x[i]
            );
        }
    }
}