use std::fmt;
/// Errors produced while building RoPE frequency tables or applying rotations.
///
/// Besides `Debug`, the type is `Clone` and `PartialEq` so callers can store, copy and
/// compare errors (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub enum RopeError {
    /// `head_dim` failed validation; `reason` names the failed check.
    InvalidHeadDim { dim: usize, reason: &'static str },
    /// A scaling factor was rejected (the message states it must be `>= 1.0`).
    InvalidScalingFactor(f64),
    /// More positions were requested than the precomputed table holds.
    SequenceLengthExceeded { seq_len: usize, max: usize },
    /// A buffer's length differs from the length implied by its shape arguments.
    DimensionMismatch { expected: usize, got: usize },
}

impl fmt::Display for RopeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            RopeError::InvalidHeadDim { dim, reason } => {
                write!(f, "Invalid head dimension {dim}: {reason}")
            }
            RopeError::InvalidScalingFactor(factor) => {
                write!(f, "Invalid scaling factor {factor}: must be >= 1.0")
            }
            RopeError::SequenceLengthExceeded { seq_len, max } => {
                write!(
                    f,
                    "Sequence length {seq_len} exceeds precomputed maximum {max}"
                )
            }
            RopeError::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for RopeError {}
/// Frequency-scaling strategy layered on top of plain RoPE frequencies.
#[derive(Debug, Clone, PartialEq)]
pub enum RopeScalingType {
    /// Unscaled RoPE.
    None,
    /// Every frequency is divided by `factor`.
    Linear { factor: f64 },
    /// NTK-aware scaling: the base theta is enlarged by `factor^(d / (d - 2))`.
    Ntk { factor: f64 },
    /// NTK scaling whose effective factor grows with the precomputed sequence length
    /// once it exceeds `original_max_position`.
    DynamicNtk {
        factor: f64,
        original_max_position: usize,
    },
    /// YaRN: a per-pair blend of original and `factor`-interpolated frequencies,
    /// controlled by `beta_fast` / `beta_slow`.
    Yarn {
        factor: f64,
        original_max_position: usize,
        beta_fast: f64,
        beta_slow: f64,
    },
    /// LongRoPE: per-pair divisors, with the long set used beyond `original_max_position`.
    LongRope {
        short_factors: Vec<f64>,
        long_factors: Vec<f64>,
        original_max_position: usize,
    },
}

impl fmt::Display for RopeScalingType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            Self::Linear { factor } => write!(f, "Linear(factor={})", factor),
            Self::Ntk { factor } => write!(f, "NTK(factor={})", factor),
            Self::DynamicNtk {
                factor,
                original_max_position: orig,
            } => write!(f, "DynamicNTK(factor={}, orig_max={})", factor, orig),
            Self::Yarn {
                factor,
                original_max_position: orig,
                beta_fast,
                beta_slow,
            } => write!(
                f,
                "YaRN(factor={}, orig_max={}, beta_fast={}, beta_slow={})",
                factor, orig, beta_fast, beta_slow
            ),
            Self::LongRope {
                short_factors,
                long_factors,
                original_max_position: orig,
            } => write!(
                f,
                "LongRoPE(short_factors=[{} dims], long_factors=[{} dims], orig_max={})",
                short_factors.len(),
                long_factors.len(),
                orig
            ),
        }
    }
}
/// Parameters from which `RopeFrequencies::compute` derives rotation frequencies.
#[derive(Debug, Clone)]
pub struct RopeConfig {
    /// Width of one attention head; `RopeFrequencies::compute` rejects zero or odd values.
    pub head_dim: usize,
    /// Base of the geometric frequency progression (every constructor below uses 10 000).
    pub base_theta: f64,
    /// Frequency-scaling strategy applied on top of the base frequencies.
    pub scaling: RopeScalingType,
    /// Position budget recorded with the config (4096, or `max_pos` for `with_yarn`).
    /// NOTE(review): `RopeFrequencies::compute` in this file does not read it — confirm intent.
    pub max_position_embeddings: usize,
}

impl RopeConfig {
    const DEFAULT_BASE_THETA: f64 = 10_000.0;
    const DEFAULT_MAX_POSITION_EMBEDDINGS: usize = 4096;

    /// Plain RoPE (no scaling) with the default base and position budget.
    pub fn standard(head_dim: usize) -> Self {
        Self::with_default_params(head_dim, RopeScalingType::None)
    }

    /// Linear position interpolation: every frequency divided by `factor`.
    pub fn with_linear_scaling(head_dim: usize, factor: f64) -> Self {
        Self::with_default_params(head_dim, RopeScalingType::Linear { factor })
    }

    /// NTK-aware base scaling by `factor`.
    pub fn with_ntk(head_dim: usize, factor: f64) -> Self {
        Self::with_default_params(head_dim, RopeScalingType::Ntk { factor })
    }

    /// YaRN with `beta_fast = 32` and `beta_slow = 1`; `max_pos` serves both as the
    /// original context length and as `max_position_embeddings`.
    pub fn with_yarn(head_dim: usize, factor: f64, max_pos: usize) -> Self {
        let scaling = RopeScalingType::Yarn {
            factor,
            original_max_position: max_pos,
            beta_fast: 32.0,
            beta_slow: 1.0,
        };
        Self {
            max_position_embeddings: max_pos,
            ..Self::with_default_params(head_dim, scaling)
        }
    }

    /// Shared constructor body: default base theta and position budget with `scaling`.
    fn with_default_params(head_dim: usize, scaling: RopeScalingType) -> Self {
        Self {
            head_dim,
            base_theta: Self::DEFAULT_BASE_THETA,
            scaling,
            max_position_embeddings: Self::DEFAULT_MAX_POSITION_EMBEDDINGS,
        }
    }
}
/// Summary of the per-pair rotation frequencies held in a `RopeFrequencies` table.
///
/// Produced by `RopeFrequencies::frequency_stats`, which reads frequencies in radians
/// per position from the table's position-1 row.
#[derive(Clone, Debug)]
pub struct RopeFreqStats {
    /// Smallest per-pair frequency.
    pub min_freq: f32,
    /// Largest per-pair frequency.
    pub max_freq: f32,
    /// Arithmetic mean over all rotation pairs.
    pub mean_freq: f32,
    /// Number of pairs whose frequency is below `0.01`.
    pub num_low_freq_dims: usize,
    /// Number of pairs whose frequency is above `1.0`.
    pub num_high_freq_dims: usize,
}
/// Precomputed RoPE cos/sin tables for positions `0..max_seq_len`.
///
/// `cos` and `sin` are row-major with one row of `head_dim / 2` entries per position,
/// as filled by `RopeFrequencies::compute`. The type derives `Debug` and `Clone`, as all
/// of its fields support both.
#[derive(Debug, Clone)]
pub struct RopeFrequencies {
    /// `cos(pos * freq_i)` stored at index `pos * (head_dim / 2) + i`.
    pub cos: Vec<f32>,
    /// `sin(pos * freq_i)` stored at index `pos * (head_dim / 2) + i`.
    pub sin: Vec<f32>,
    /// Number of positions (rows) covered by the tables.
    pub max_seq_len: usize,
    /// Per-head width; each row holds `head_dim / 2` rotation pairs.
    pub head_dim: usize,
    /// Configuration the tables were built from.
    pub config: RopeConfig,
}
impl RopeFrequencies {
pub fn compute(config: RopeConfig, max_seq_len: usize) -> Result<Self, RopeError> {
if config.head_dim == 0 {
return Err(RopeError::InvalidHeadDim {
dim: config.head_dim,
reason: "head_dim must be > 0",
});
}
if config.head_dim % 2 != 0 {
return Err(RopeError::InvalidHeadDim {
dim: config.head_dim,
reason: "head_dim must be even",
});
}
let half_dim = config.head_dim / 2;
match &config.scaling {
RopeScalingType::Linear { factor }
| RopeScalingType::Ntk { factor }
| RopeScalingType::DynamicNtk { factor, .. }
| RopeScalingType::Yarn { factor, .. } => {
if *factor < 1.0 {
return Err(RopeError::InvalidScalingFactor(*factor));
}
}
_ => {}
}
let freqs: Vec<f64> = match &config.scaling {
RopeScalingType::None => standard_freqs(config.base_theta, half_dim, config.head_dim),
RopeScalingType::Linear { factor } => {
let base_freqs = standard_freqs(config.base_theta, half_dim, config.head_dim);
base_freqs.into_iter().map(|f| f / factor).collect()
}
RopeScalingType::Ntk { factor } => {
let exp = config.head_dim as f64 / (config.head_dim as f64 - 2.0);
let modified_base = config.base_theta * factor.powf(exp);
standard_freqs(modified_base, half_dim, config.head_dim)
}
RopeScalingType::DynamicNtk {
factor,
original_max_position,
} => {
let effective_factor = if max_seq_len <= *original_max_position {
1.0_f64
} else {
(factor * max_seq_len as f64 / *original_max_position as f64)
- (factor - 1.0)
};
let effective_factor = effective_factor.max(1.0);
let exp = config.head_dim as f64 / (config.head_dim as f64 - 2.0);
let modified_base = config.base_theta * effective_factor.powf(exp);
standard_freqs(modified_base, half_dim, config.head_dim)
}
RopeScalingType::Yarn {
factor,
original_max_position,
beta_fast,
beta_slow,
} => {
let base_freqs = standard_freqs(config.base_theta, half_dim, config.head_dim);
let linear_freqs: Vec<f64> = base_freqs.iter().map(|f| f / factor).collect();
let two_pi = 2.0 * std::f64::consts::PI;
let low_thresh = two_pi * factor / beta_slow;
let high_thresh = two_pi / beta_fast;
base_freqs
.iter()
.zip(linear_freqs.iter())
.enumerate()
.map(|(i, (&orig_f, &lin_f))| {
let wavelength = two_pi / orig_f; let _ = i; let _ = original_max_position;
if wavelength < high_thresh {
orig_f
} else if wavelength > low_thresh {
lin_f
} else {
let ramp = yarn_ramp(wavelength, high_thresh, low_thresh);
orig_f * (1.0 - ramp) + lin_f * ramp
}
})
.collect()
}
RopeScalingType::LongRope {
short_factors,
long_factors,
original_max_position,
} => {
let factors = if max_seq_len > *original_max_position {
long_factors
} else {
short_factors
};
let base_freqs = standard_freqs(config.base_theta, half_dim, config.head_dim);
base_freqs
.iter()
.enumerate()
.map(|(i, &base_f)| {
let scale = factors.get(i).copied().unwrap_or(1.0);
base_f / scale
})
.collect()
}
};
let capacity = max_seq_len * half_dim;
let mut cos_table = Vec::with_capacity(capacity);
let mut sin_table = Vec::with_capacity(capacity);
for pos in 0..max_seq_len {
let pos_f = pos as f64;
for i in 0..half_dim {
let angle = pos_f * freqs[i];
cos_table.push(angle.cos() as f32);
sin_table.push(angle.sin() as f32);
}
}
Ok(Self {
cos: cos_table,
sin: sin_table,
max_seq_len,
head_dim: config.head_dim,
config,
})
}
pub fn apply_rope(
&self,
q: &[f32],
seq_len: usize,
num_heads: usize,
) -> Result<Vec<f32>, RopeError> {
if seq_len > self.max_seq_len {
return Err(RopeError::SequenceLengthExceeded {
seq_len,
max: self.max_seq_len,
});
}
let head_dim = self.head_dim;
let expected_len = seq_len * num_heads * head_dim;
if q.len() != expected_len {
return Err(RopeError::DimensionMismatch {
expected: expected_len,
got: q.len(),
});
}
let half_dim = head_dim / 2;
let mut output = vec![0.0_f32; expected_len];
for pos in 0..seq_len {
let cos_row = pos * half_dim;
for head in 0..num_heads {
let base_idx = pos * num_heads * head_dim + head * head_dim;
for i in 0..half_dim {
let cos_val = self.cos[cos_row + i];
let sin_val = self.sin[cos_row + i];
let x0 = q[base_idx + 2 * i];
let x1 = q[base_idx + 2 * i + 1];
output[base_idx + 2 * i] = x0 * cos_val - x1 * sin_val;
output[base_idx + 2 * i + 1] = x0 * sin_val + x1 * cos_val;
}
}
}
Ok(output)
}
pub fn apply_rope_qk(
&self,
q: &[f32],
k: &[f32],
q_seq_len: usize,
k_seq_len: usize,
num_heads: usize,
num_kv_heads: usize,
) -> Result<(Vec<f32>, Vec<f32>), RopeError> {
let rotated_q = self.apply_rope(q, q_seq_len, num_heads)?;
let rotated_k = self.apply_rope(k, k_seq_len, num_kv_heads)?;
Ok((rotated_q, rotated_k))
}
pub fn frequency_stats(&self) -> RopeFreqStats {
if self.max_seq_len == 0 {
return RopeFreqStats {
min_freq: 0.0,
max_freq: 0.0,
mean_freq: 0.0,
num_low_freq_dims: 0,
num_high_freq_dims: 0,
};
}
let half_dim = self.head_dim / 2;
let freqs: Vec<f32> = if self.max_seq_len > 1 {
(0..half_dim)
.map(|i| {
let c = self.cos[half_dim + i]; let s = self.sin[half_dim + i];
s.atan2(c).abs()
})
.collect()
} else {
vec![0.0_f32; half_dim]
};
let min_freq = freqs.iter().cloned().fold(f32::INFINITY, f32::min);
let max_freq = freqs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mean_freq = freqs.iter().sum::<f32>() / half_dim as f32;
let num_low_freq_dims = freqs.iter().filter(|&&f| f < 0.01).count();
let num_high_freq_dims = freqs.iter().filter(|&&f| f > 1.0).count();
RopeFreqStats {
min_freq,
max_freq,
mean_freq,
num_low_freq_dims,
num_high_freq_dims,
}
}
}
/// Unscaled RoPE inverse frequencies: element `i` is `base^(-2i / head_dim)` for
/// `i` in `0..half_dim`.
fn standard_freqs(base: f64, half_dim: usize, head_dim: usize) -> Vec<f64> {
    let dim = head_dim as f64;
    let mut out = Vec::with_capacity(half_dim);
    for pair in 0..half_dim {
        // Multiply first, then divide: the same operation order as `-2i / head_dim`.
        out.push(base.powf((pair as f64 * -2.0) / dim));
    }
    out
}
/// Linear ramp of `value` from `low` (→ 0.0) to `high` (→ 1.0), clamped to `[0, 1]`.
///
/// A degenerate range (`low == high`) is treated as a step at `low`: values below it
/// map to `0.0` and values at or above it to `1.0`. Without this guard, `0 / 0` would
/// produce NaN at `value == low`. For every other input the result matches the plain
/// clamped ratio.
fn yarn_ramp(value: f64, low: f64, high: f64) -> f64 {
    if high == low {
        return if value < low { 0.0 } else { 1.0 };
    }
    ((value - low) / (high - low)).clamp(0.0, 1.0)
}
/// Parameters consumed by `apply_yarn_rope`.
#[derive(Clone, Debug)]
pub struct YarnConfig {
    /// Context length of the original model; feeds the correction-dimension computation.
    pub original_max_position_embeddings: usize,
    /// Context-extension factor. Interpolated frequencies are divided by it, and it
    /// drives the mscale terms.
    pub scaling_factor: f32,
    /// Rotation count fixing one end of the frequency ramp (via `yarn_find_correction_dim`).
    pub beta_fast: f32,
    /// Rotation count fixing the other end of the frequency ramp.
    pub beta_slow: f32,
    /// Multiplier passed to `yarn_get_mscale` for the primary magnitude correction.
    pub mscale: f32,
    /// Second `yarn_get_mscale` multiplier; `0.0` disables this term.
    pub mscale_all_dim: f32,
    /// RoPE base (theta).
    pub base: f32,
    /// Width of one head; must be non-zero and even.
    pub head_dim: usize,
}
/// Fractional rotation-pair index `i` at which `num_rotations` full turns fit into
/// `max_position_embeddings` positions.
///
/// Solves `L / (2π · base^(2i / dim)) = num_rotations` for `i`, which gives
/// `dim · ln(L / (2π · num_rotations)) / (2 · ln base)`.
pub fn yarn_find_correction_dim(
    num_rotations: f32,
    dim: usize,
    base: f32,
    max_position_embeddings: usize,
) -> f32 {
    let turns_span = num_rotations * std::f32::consts::TAU;
    let log_ratio = (max_position_embeddings as f32 / turns_span).ln();
    dim as f32 * log_ratio / (2.0 * base.ln())
}
/// Builds a `dim`-long ramp: element `i` is `(i - min) / (max - min)` clamped to `[0, 1]`.
///
/// Returns an empty vector for `dim == 0`. If `min` and `max` differ by less than
/// `f32::EPSILON`, every element is `0.0`.
pub fn yarn_linear_ramp_mask(min: f32, max: f32, dim: usize) -> Vec<f32> {
    let span = max - min;
    match dim {
        0 => Vec::new(),
        _ if span.abs() < f32::EPSILON => vec![0.0; dim],
        _ => (0..dim)
            .map(|idx| ((idx as f32 - min) / span).clamp(0.0, 1.0))
            .collect(),
    }
}
/// YaRN magnitude correction: `sqrt(0.1 · ln(scale) + 1) · mscale`.
///
/// For `scale <= 1.0` (no context extension), `mscale` is returned unchanged. This
/// equals the formula's value at `scale == 1.0`. It also avoids the negative radicand,
/// and therefore NaN, that `ln(scale)` produces for small positive scales (below about
/// `e^-10`).
pub fn yarn_get_mscale(scale: f32, mscale: f32) -> f32 {
    if scale <= 1.0 {
        return mscale;
    }
    (0.1 * scale.ln() + 1.0).sqrt() * mscale
}
/// Applies YaRN-scaled RoPE to `query` and `key` at the explicit `positions`.
///
/// Layout (implied by the length checks and indexing): `query` is
/// `[positions.len()][num_heads][head_dim]` and `key` is
/// `[positions.len()][num_kv_heads][head_dim]`. Elements `2i` and `2i + 1` of each head
/// vector form rotation pair `i`.
///
/// Pair `i` blends its inverse frequency `base^(-2i / head_dim)` with the same
/// frequency divided by `scaling_factor`:
/// - pairs below the `beta_fast` correction dimension keep the original frequency;
/// - pairs above the `beta_slow` correction dimension are fully interpolated;
/// - a linear ramp covers the band in between.
///
/// The cos/sin values are multiplied by the mscale factor(s), so the rotation
/// preserves norms only when that product is 1.
///
/// # Errors
///
/// * [`RopeError::InvalidHeadDim`] if `head_dim` is zero or odd.
/// * [`RopeError::InvalidScalingFactor`] if `scaling_factor` is below `1.0` or NaN.
/// * [`RopeError::DimensionMismatch`] if `query` or `key` has the wrong length.
pub fn apply_yarn_rope(
    query: &[f32],
    key: &[f32],
    positions: &[usize],
    num_heads: usize,
    num_kv_heads: usize,
    config: &YarnConfig,
) -> Result<(Vec<f32>, Vec<f32>), RopeError> {
    let head_dim = config.head_dim;
    if head_dim == 0 {
        return Err(RopeError::InvalidHeadDim {
            dim: head_dim,
            reason: "head_dim must be > 0",
        });
    }
    if head_dim % 2 != 0 {
        return Err(RopeError::InvalidHeadDim {
            dim: head_dim,
            reason: "head_dim must be even",
        });
    }
    let scale = config.scaling_factor;
    // Same contract as the factor check in `RopeFrequencies::compute`. A zero scale would
    // make `inv_freq / scale` infinite, and `inf * 0.0` turns even unscaled pairs into NaN.
    if scale.is_nan() || scale < 1.0 {
        return Err(RopeError::InvalidScalingFactor(f64::from(scale)));
    }

    let seq_len = positions.len();
    let half_dim = head_dim / 2;
    let expected_q = seq_len * num_heads * head_dim;
    let expected_k = seq_len * num_kv_heads * head_dim;
    if query.len() != expected_q {
        return Err(RopeError::DimensionMismatch {
            expected: expected_q,
            got: query.len(),
        });
    }
    if key.len() != expected_k {
        return Err(RopeError::DimensionMismatch {
            expected: expected_k,
            got: key.len(),
        });
    }

    let base = config.base;
    let orig_max = config.original_max_position_embeddings;
    // Correction dims are fractional pair indices. With beta_fast > beta_slow (and
    // base > 1), the beta_fast dim is the smaller index, on the high-frequency side.
    // The ramp must rise from 0 there to 1 at the beta_slow dim. That way high-frequency
    // pairs keep their original frequency and low-frequency pairs are interpolated.
    let fast_dim = yarn_find_correction_dim(config.beta_fast, head_dim, base, orig_max);
    let slow_dim = yarn_find_correction_dim(config.beta_slow, head_dim, base, orig_max);
    let ramp_mask = yarn_linear_ramp_mask(fast_dim, slow_dim, half_dim);
    let blended_freqs: Vec<f32> = ramp_mask
        .iter()
        .enumerate()
        .map(|(i, &ramp)| {
            let inv_freq = base.powf(-2.0 * i as f32 / head_dim as f32);
            let interpolated = inv_freq / scale;
            inv_freq * (1.0 - ramp) + interpolated * ramp
        })
        .collect();

    let mscale = yarn_get_mscale(scale, config.mscale);
    // NOTE(review): the two mscale terms are multiplied here; some reference
    // implementations take their ratio instead — confirm which is intended.
    let mscale_all = if config.mscale_all_dim == 0.0 {
        1.0
    } else {
        yarn_get_mscale(scale, config.mscale_all_dim)
    };
    let combined_scale = mscale * mscale_all;

    let mut cos_table = Vec::with_capacity(seq_len * half_dim);
    let mut sin_table = Vec::with_capacity(seq_len * half_dim);
    for &pos in positions {
        for &freq in &blended_freqs {
            let angle = pos as f32 * freq;
            cos_table.push(angle.cos() * combined_scale);
            sin_table.push(angle.sin() * combined_scale);
        }
    }

    // Shared by query and key: rotates every head of every row with that row's angles.
    let rotate = |tensor: &[f32], n_heads: usize, expected: usize| -> Vec<f32> {
        let mut out = vec![0.0_f32; expected];
        for si in 0..seq_len {
            let row = si * half_dim;
            for h in 0..n_heads {
                let head_start = si * n_heads * head_dim + h * head_dim;
                for i in 0..half_dim {
                    let (c, s) = (cos_table[row + i], sin_table[row + i]);
                    let x0 = tensor[head_start + 2 * i];
                    let x1 = tensor[head_start + 2 * i + 1];
                    out[head_start + 2 * i] = x0 * c - x1 * s;
                    out[head_start + 2 * i + 1] = x0 * s + x1 * c;
                }
            }
        }
        out
    };
    let out_q = rotate(query, num_heads, expected_q);
    let out_k = rotate(key, num_kv_heads, expected_k);
    Ok((out_q, out_k))
}
/// Parameters for `dynamic_ntk_theta` and `apply_dynamic_ntk_rope`.
#[derive(Clone, Debug)]
pub struct DynamicNtkConfig {
    /// RoPE base, used unchanged for sequences within `max_original_length`.
    pub base_theta: f32,
    /// Scaling strength applied once a sequence exceeds `max_original_length`.
    pub alpha: f32,
    /// Longest sequence handled without rescaling the base.
    pub max_original_length: usize,
    /// Width of one head; `apply_dynamic_ntk_rope` requires it to be non-zero and even.
    pub head_dim: usize,
}
/// Effective RoPE base for dynamic NTK scaling at `seq_len`.
///
/// Returns `base_theta` unchanged when `seq_len <= max_original_length`, or when
/// `head_dim < 3` (where the exponent `d / (d - 2)` is degenerate). Otherwise the base
/// is multiplied by `max(1, alpha · seq_len / max_original_length − alpha + 1)^(d / (d − 2))`.
pub fn dynamic_ntk_theta(seq_len: usize, config: &DynamicNtkConfig) -> f32 {
    let within_original = seq_len <= config.max_original_length;
    if within_original || config.head_dim < 3 {
        return config.base_theta;
    }
    let dim = config.head_dim as f32;
    let stretch = config.alpha * seq_len as f32 / config.max_original_length as f32;
    let ratio = stretch - config.alpha + 1.0;
    config.base_theta * ratio.max(1.0).powf(dim / (dim - 2.0))
}
/// Applies RoPE with a dynamically NTK-scaled base (from `dynamic_ntk_theta`) to
/// `query` and `key` at positions `0..seq_len`.
///
/// Layout (implied by the length checks and indexing): `query` is
/// `[seq_len][num_heads][head_dim]` and `key` is `[seq_len][num_kv_heads][head_dim]`.
/// Elements `2i` and `2i + 1` of each head vector form rotation pair `i`.
///
/// # Errors
///
/// * [`RopeError::InvalidHeadDim`] if `head_dim` is zero or odd.
/// * [`RopeError::DimensionMismatch`] if `query` (checked first) or `key` has the wrong length.
pub fn apply_dynamic_ntk_rope(
    query: &[f32],
    key: &[f32],
    seq_len: usize,
    num_heads: usize,
    num_kv_heads: usize,
    config: &DynamicNtkConfig,
) -> Result<(Vec<f32>, Vec<f32>), RopeError> {
    let head_dim = config.head_dim;
    if head_dim == 0 {
        return Err(RopeError::InvalidHeadDim {
            dim: head_dim,
            reason: "head_dim must be > 0",
        });
    }
    if head_dim % 2 != 0 {
        return Err(RopeError::InvalidHeadDim {
            dim: head_dim,
            reason: "head_dim must be even",
        });
    }
    let half_dim = head_dim / 2;
    let q_len = seq_len * num_heads * head_dim;
    let k_len = seq_len * num_kv_heads * head_dim;
    for (got, expected) in [(query.len(), q_len), (key.len(), k_len)] {
        if got != expected {
            return Err(RopeError::DimensionMismatch { expected, got });
        }
    }

    let theta = dynamic_ntk_theta(seq_len, config);
    let inv_freqs: Vec<f32> = (0..half_dim)
        .map(|i| theta.powf(-2.0 * i as f32 / head_dim as f32))
        .collect();
    // Row-major [seq_len][half_dim] angle tables.
    let (cos_table, sin_table): (Vec<f32>, Vec<f32>) = (0..seq_len)
        .flat_map(|pos| inv_freqs.iter().map(move |&freq| pos as f32 * freq))
        .map(|angle| (angle.cos(), angle.sin()))
        .unzip();

    let rotate = |tensor: &[f32], n_heads: usize| -> Vec<f32> {
        let mut out = vec![0.0_f32; tensor.len()];
        let row_width = n_heads * head_dim;
        // No heads means an empty tensor; also keeps `chunks_exact` away from a zero size.
        if row_width == 0 {
            return out;
        }
        let rows = tensor
            .chunks_exact(row_width)
            .zip(out.chunks_exact_mut(row_width));
        for (pos, (src_row, dst_row)) in rows.enumerate() {
            let start = pos * half_dim;
            let cos_row = &cos_table[start..start + half_dim];
            let sin_row = &sin_table[start..start + half_dim];
            let heads = src_row
                .chunks_exact(head_dim)
                .zip(dst_row.chunks_exact_mut(head_dim));
            for (src_head, dst_head) in heads {
                let pairs = src_head.chunks_exact(2).zip(dst_head.chunks_exact_mut(2));
                for ((src_pair, dst_pair), (&c, &s)) in pairs.zip(cos_row.iter().zip(sin_row)) {
                    let (x0, x1) = (src_pair[0], src_pair[1]);
                    dst_pair[0] = x0 * c - x1 * s;
                    dst_pair[1] = x0 * s + x1 * c;
                }
            }
        }
        out
    };
    Ok((rotate(query, num_heads), rotate(key, num_kv_heads)))
}
/// Per-pair LongRoPE divisors used by `apply_longrope_scaling`.
#[derive(Clone, Debug)]
pub struct LongRopeScaling {
    /// Divisors used while `seq_len <= threshold`.
    pub short_factor: Vec<f32>,
    /// Divisors used once `seq_len > threshold`.
    pub long_factor: Vec<f32>,
    /// Largest sequence length that still uses `short_factor`.
    pub threshold: usize,
    /// NOTE(review): not read by `apply_longrope_scaling`; presumably a magnitude factor
    /// for short sequences — confirm with the caller.
    pub short_mscale: f32,
    /// NOTE(review): not read by `apply_longrope_scaling`; presumably the long-sequence
    /// counterpart of `short_mscale`.
    pub long_mscale: f32,
}
/// Rescales `inv_freq` element-wise with LongRoPE divisors.
///
/// `short_factor` is used while `seq_len <= threshold`, and `long_factor` beyond it.
/// Element `i` is divided by factor `i`. A missing factor (shorter factor vector) or a
/// factor equal to `0.0` leaves the element unchanged. The `*_mscale` fields are not
/// read here.
pub fn apply_longrope_scaling(
    inv_freq: &[f32],
    seq_len: usize,
    config: &LongRopeScaling,
) -> Vec<f32> {
    let factors: &[f32] = if seq_len > config.threshold {
        &config.long_factor
    } else {
        &config.short_factor
    };
    let mut scaled = inv_freq.to_vec();
    for (freq, &factor) in scaled.iter_mut().zip(factors) {
        if factor != 0.0 {
            *freq /= factor;
        }
    }
    scaled
}
// Unit tests for RoPE configuration, frequency-table construction, rotation, and the
// standalone YaRN / dynamic-NTK / LongRoPE helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// `standard` uses base 10_000, no scaling and a 4096-position budget.
#[test]
fn test_rope_config_standard() {
let cfg = RopeConfig::standard(64);
assert_eq!(cfg.head_dim, 64);
assert_eq!(cfg.base_theta, 10_000.0);
assert_eq!(cfg.scaling, RopeScalingType::None);
assert_eq!(cfg.max_position_embeddings, 4096);
}
// The linear constructor stores its factor in `RopeScalingType::Linear`.
#[test]
fn test_rope_config_linear() {
let cfg = RopeConfig::with_linear_scaling(128, 2.0);
assert_eq!(cfg.head_dim, 128);
assert!(matches!(cfg.scaling, RopeScalingType::Linear { factor } if factor == 2.0));
}
// The NTK constructor stores its factor in `RopeScalingType::Ntk`.
#[test]
fn test_rope_config_ntk() {
let cfg = RopeConfig::with_ntk(64, 4.0);
assert_eq!(cfg.head_dim, 64);
assert!(matches!(cfg.scaling, RopeScalingType::Ntk { factor } if factor == 4.0));
}
// Index 32 is position 1, pair 0 (half_dim = 32). Pair 0's frequency is base^0 = 1,
// so the stored value is cos(1 rad).
#[test]
fn test_rope_freq_standard_first_dim() {
let cfg = RopeConfig::standard(64);
let freqs = RopeFrequencies::compute(cfg, 16).expect("compute failed");
let cos_val = freqs.cos[32]; assert!(
(cos_val - 1.0_f32.cos()).abs() < 1e-5,
"Expected cos(1.0) ≈ {}, got {cos_val}",
1.0_f32.cos()
);
}
// The last pair has the smallest frequency, so its position-1 angle is tiny.
#[test]
fn test_rope_freq_standard_last_dim() {
let head_dim = 64usize;
let half_dim = head_dim / 2;
let cfg = RopeConfig::standard(head_dim);
let freqs = RopeFrequencies::compute(cfg, 16).expect("compute failed");
let last_freq_idx = half_dim - 1;
let angle = freqs.sin[half_dim + last_freq_idx].atan2(freqs.cos[half_dim + last_freq_idx]);
assert!(
angle.abs() < 0.01,
"Expected small angle for last dim, got {angle}"
);
}
// Linear scaling divides every frequency by `factor`; checked on pair 0 at position 1.
#[test]
fn test_rope_freq_linear_scaling() {
let head_dim = 64;
let factor = 2.0;
let cfg_none = RopeConfig::standard(head_dim);
let cfg_lin = RopeConfig::with_linear_scaling(head_dim, factor);
let freqs_none = RopeFrequencies::compute(cfg_none, 16).expect("compute failed");
let freqs_lin = RopeFrequencies::compute(cfg_lin, 16).expect("compute failed");
let half_dim = head_dim / 2;
let cos_none = freqs_none.cos[half_dim]; let cos_lin = freqs_lin.cos[half_dim];
let angle_none = freqs_none.sin[half_dim].atan2(cos_none) as f64;
let angle_lin = freqs_lin.sin[half_dim].atan2(cos_lin) as f64;
let ratio = angle_none / angle_lin;
assert!(
(ratio - factor).abs() < 0.01,
"Expected frequency ratio {factor}, got {ratio}"
);
}
// NTK leaves pair 0 at frequency 1 (base^0), while linear scaling divides it by 4.
// The position-1 angles must therefore differ.
#[test]
fn test_rope_freq_ntk_vs_linear() {
let head_dim = 64;
let factor = 4.0;
let cfg_ntk = RopeConfig::with_ntk(head_dim, factor);
let cfg_lin = RopeConfig::with_linear_scaling(head_dim, factor);
let freqs_ntk = RopeFrequencies::compute(cfg_ntk, 8).expect("compute failed");
let freqs_lin = RopeFrequencies::compute(cfg_lin, 8).expect("compute failed");
let half_dim = head_dim / 2;
let angle_ntk = freqs_ntk.sin[half_dim].atan2(freqs_ntk.cos[half_dim]);
let angle_lin = freqs_lin.sin[half_dim].atan2(freqs_lin.cos[half_dim]);
assert!(
(angle_ntk - angle_lin).abs() > 1e-4,
"NTK and linear should differ: ntk={angle_ntk}, lin={angle_lin}"
);
}
// Despite its name, this only asserts that YaRN changes at least one position-1 angle
// relative to unscaled RoPE. It does not check that high-frequency pairs are unchanged.
#[test]
fn test_rope_freq_yarn_high_freq_unchanged() {
let head_dim = 64;
let cfg_none = RopeConfig::standard(head_dim);
let cfg_yarn = RopeConfig::with_yarn(head_dim, 2.0, 2048);
let freqs_none = RopeFrequencies::compute(cfg_none, 8).expect("compute failed");
let freqs_yarn = RopeFrequencies::compute(cfg_yarn, 8).expect("compute failed");
let half_dim = head_dim / 2;
let mut differs = false;
for i in 0..half_dim {
let angle_none = freqs_none.sin[half_dim + i].atan2(freqs_none.cos[half_dim + i]);
let angle_yarn = freqs_yarn.sin[half_dim + i].atan2(freqs_yarn.cos[half_dim + i]);
if (angle_none - angle_yarn).abs() > 1e-4 {
differs = true;
break;
}
}
assert!(differs, "YaRN should produce at least some different frequencies");
}
// Dynamic NTK: with max_seq_len (512) <= original_max_position (4096), the effective
// factor is 1.0, so the tables match standard RoPE. With 4096 > 512, the base grows and
// a mid-band pair's angle changes.
#[test]
fn test_rope_freq_dynamic_ntk() {
let head_dim = 64;
let cfg_dntk = RopeConfig {
head_dim,
base_theta: 10_000.0,
scaling: RopeScalingType::DynamicNtk {
factor: 4.0,
original_max_position: 4096,
},
max_position_embeddings: 4096,
};
let freqs_short = RopeFrequencies::compute(cfg_dntk.clone(), 512).expect("compute failed");
let freqs_std = RopeFrequencies::compute(RopeConfig::standard(head_dim), 512)
.expect("compute failed");
let half_dim = head_dim / 2;
let angle_dntk = freqs_short.sin[half_dim].atan2(freqs_short.cos[half_dim]);
let angle_std = freqs_std.sin[half_dim].atan2(freqs_std.cos[half_dim]);
assert!(
(angle_dntk - angle_std).abs() < 1e-4,
"Dynamic NTK with short seq should match standard: dntk={angle_dntk}, std={angle_std}"
);
let cfg_dntk_long = RopeConfig {
head_dim,
base_theta: 10_000.0,
scaling: RopeScalingType::DynamicNtk {
factor: 4.0,
original_max_position: 512,
},
max_position_embeddings: 512,
};
let freqs_long =
RopeFrequencies::compute(cfg_dntk_long, 4096).expect("compute long failed");
let mid = half_dim / 2;
let angle_long_mid = freqs_long.sin[half_dim + mid].atan2(freqs_long.cos[half_dim + mid]);
let angle_std_mid = freqs_std.sin[half_dim + mid].atan2(freqs_std.cos[half_dim + mid]);
assert!(
(angle_long_mid - angle_std_mid).abs() > 1e-6,
"Dynamic NTK with long seq should differ from standard at mid dim"
);
}
// A single position (pos 0) has angle 0 for every pair, so rotation is the identity.
#[test]
fn test_rope_apply_identity_when_zero_pos() {
let cfg = RopeConfig::standard(8);
let freqs = RopeFrequencies::compute(cfg, 4).expect("compute failed");
let q = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; let rotated = freqs.apply_rope(&q, 1, 1).expect("apply failed");
for (orig, rot) in q.iter().zip(rotated.iter()) {
assert!(
(orig - rot).abs() < 1e-6,
"Position 0 should be identity, orig={orig}, rot={rot}"
);
}
}
// head_dim 4: at position 1, pair 0 (frequency 1) rotates (1, 0) to (cos 1, sin 1).
// Position 0 is left unchanged.
#[test]
fn test_rope_apply_rotation() {
let cfg = RopeConfig::standard(4);
let freqs = RopeFrequencies::compute(cfg, 4).expect("compute failed");
let q = vec![1.0_f32, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0];
let rotated = freqs.apply_rope(&q, 2, 1).expect("apply failed");
assert!((rotated[0] - 1.0).abs() < 1e-6);
assert!((rotated[1] - 0.0).abs() < 1e-6);
let expected_cos = 1.0_f32.cos();
let expected_sin = 1.0_f32.sin();
assert!(
(rotated[4] - expected_cos).abs() < 1e-5,
"Expected {expected_cos}, got {}",
rotated[4]
);
assert!(
(rotated[5] - expected_sin).abs() < 1e-5,
"Expected {expected_sin}, got {}",
rotated[5]
);
}
// Pairwise 2-D rotations preserve each head vector's Euclidean norm.
#[test]
fn test_rope_apply_orthogonality() {
let cfg = RopeConfig::standard(32);
let freqs = RopeFrequencies::compute(cfg, 8).expect("compute failed");
let seq_len = 5;
let num_heads = 3;
let head_dim = 32;
let total = seq_len * num_heads * head_dim;
let q: Vec<f32> = (0..total).map(|i| (i as f32 * 0.1) % 2.0 - 1.0).collect();
let rotated = freqs.apply_rope(&q, seq_len, num_heads).expect("apply failed");
for pos in 0..seq_len {
for head in 0..num_heads {
let base = pos * num_heads * head_dim + head * head_dim;
let orig_norm: f32 = q[base..base + head_dim].iter().map(|x| x * x).sum::<f32>().sqrt();
let rot_norm: f32 = rotated[base..base + head_dim].iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(orig_norm - rot_norm).abs() < 1e-4,
"Norm changed at pos={pos}, head={head}: orig={orig_norm}, rot={rot_norm}"
);
}
}
}
// q and k may use different head counts; the output lengths must match the inputs.
#[test]
fn test_rope_apply_qk() {
let cfg = RopeConfig::standard(16);
let freqs = RopeFrequencies::compute(cfg, 8).expect("compute failed");
let seq_len = 4;
let num_heads = 2;
let num_kv_heads = 1;
let head_dim = 16;
let q: Vec<f32> = vec![0.5_f32; seq_len * num_heads * head_dim];
let k: Vec<f32> = vec![1.0_f32; seq_len * num_kv_heads * head_dim];
let (rotated_q, rotated_k) = freqs
.apply_rope_qk(&q, &k, seq_len, seq_len, num_heads, num_kv_heads)
.expect("apply_qk failed");
assert_eq!(rotated_q.len(), q.len());
assert_eq!(rotated_k.len(), k.len());
}
// Sanity bounds on the summary. Standard RoPE with head_dim 64 has pairs below 0.01.
#[test]
fn test_rope_freq_stats() {
let cfg = RopeConfig::standard(64);
let freqs = RopeFrequencies::compute(cfg, 16).expect("compute failed");
let stats = freqs.frequency_stats();
assert!(stats.min_freq >= 0.0, "Min freq should be non-negative");
assert!(stats.max_freq >= stats.min_freq, "Max >= min");
assert!(stats.mean_freq >= stats.min_freq);
assert!(stats.mean_freq <= stats.max_freq);
assert!(stats.num_low_freq_dims > 0, "Should have some low-freq dims");
}
// Odd head_dim is rejected by `compute`.
#[test]
fn test_rope_error_odd_head_dim() {
let cfg = RopeConfig::standard(63); let result = RopeFrequencies::compute(cfg, 8);
assert!(matches!(result, Err(RopeError::InvalidHeadDim { .. })));
}
// Requesting 8 positions from a 4-position table fails before any length check.
#[test]
fn test_rope_error_seq_exceeded() {
let cfg = RopeConfig::standard(16);
let freqs = RopeFrequencies::compute(cfg, 4).expect("compute failed");
let q = vec![0.0_f32; 8 * 1 * 16]; let result = freqs.apply_rope(&q, 8, 1);
assert!(matches!(result, Err(RopeError::SequenceLengthExceeded { .. })));
}
// Each scaling variant's Display output contains its short name.
#[test]
fn test_rope_scaling_display() {
let none = RopeScalingType::None;
assert!(none.to_string().contains("None"));
let linear = RopeScalingType::Linear { factor: 2.0 };
assert!(linear.to_string().contains("Linear"));
let ntk = RopeScalingType::Ntk { factor: 4.0 };
assert!(ntk.to_string().contains("NTK"));
let dntk = RopeScalingType::DynamicNtk {
factor: 4.0,
original_max_position: 2048,
};
assert!(dntk.to_string().contains("DynamicNTK"));
let yarn = RopeScalingType::Yarn {
factor: 2.0,
original_max_position: 4096,
beta_fast: 32.0,
beta_slow: 1.0,
};
assert!(yarn.to_string().contains("YaRN"));
let lr = RopeScalingType::LongRope {
short_factors: vec![1.0; 32],
long_factors: vec![2.0; 32],
original_max_position: 4096,
};
assert!(lr.to_string().contains("LongRoPE"));
}
// LongRoPE in `compute`:
// - cfg_short: 512 <= 4096 selects the short factors (all 1.0), so it matches standard RoPE.
// - cfg_long: 4096 > 512 selects the long factors (all 2.0), halving the position-1 angle.
#[test]
fn test_rope_long_rope_short_vs_long() {
let head_dim = 32;
let half_dim = head_dim / 2;
let short_factors = vec![1.0_f64; half_dim]; let long_factors = vec![2.0_f64; half_dim];
let cfg_short = RopeConfig {
head_dim,
base_theta: 10_000.0,
scaling: RopeScalingType::LongRope {
short_factors: short_factors.clone(),
long_factors: long_factors.clone(),
original_max_position: 4096,
},
max_position_embeddings: 4096,
};
let cfg_long = RopeConfig {
head_dim,
base_theta: 10_000.0,
scaling: RopeScalingType::LongRope {
short_factors: short_factors.clone(),
long_factors: long_factors.clone(),
original_max_position: 512,
},
max_position_embeddings: 512,
};
let freqs_short = RopeFrequencies::compute(cfg_short, 512).expect("short failed");
let freqs_std = RopeFrequencies::compute(RopeConfig::standard(head_dim), 512)
.expect("std failed");
let angle_short = freqs_short.sin[half_dim].atan2(freqs_short.cos[half_dim]);
let angle_std = freqs_std.sin[half_dim].atan2(freqs_std.cos[half_dim]);
assert!(
(angle_short - angle_std).abs() < 1e-4,
"Short LongRoPE with factor=1 should match standard"
);
let freqs_long = RopeFrequencies::compute(cfg_long, 4096).expect("long failed");
let angle_long = freqs_long.sin[half_dim].atan2(freqs_long.cos[half_dim]);
let expected_angle = angle_std / 2.0;
assert!(
(angle_long - expected_angle).abs() < 1e-4,
"Long LongRoPE with factor=2 should halve frequency: expected {expected_angle}, got {angle_long}"
);
}
// Factors below 1.0 are rejected for the factor-based scaling types.
#[test]
fn test_rope_error_invalid_scaling_factor() {
let cfg = RopeConfig {
head_dim: 32,
base_theta: 10_000.0,
scaling: RopeScalingType::Linear { factor: 0.5 }, max_position_embeddings: 4096,
};
let result = RopeFrequencies::compute(cfg, 8);
assert!(matches!(result, Err(RopeError::InvalidScalingFactor(_))));
}
// head_dim 0 is rejected by `compute`.
#[test]
fn test_rope_error_zero_head_dim() {
let cfg = RopeConfig {
head_dim: 0,
base_theta: 10_000.0,
scaling: RopeScalingType::None,
max_position_embeddings: 4096,
};
let result = RopeFrequencies::compute(cfg, 8);
assert!(matches!(result, Err(RopeError::InvalidHeadDim { .. })));
}
// 2 positions x 1 head x 16 dims needs 32 elements; 10 triggers DimensionMismatch.
#[test]
fn test_rope_error_dimension_mismatch() {
let cfg = RopeConfig::standard(16);
let freqs = RopeFrequencies::compute(cfg, 8).expect("compute failed");
let q = vec![0.0_f32; 10]; let result = freqs.apply_rope(&q, 2, 1);
assert!(matches!(result, Err(RopeError::DimensionMismatch { .. })));
}
// With L = 4096 and 1 rotation, ln(L / 2π) > 0, so the correction dim is positive.
#[test]
fn test_yarn_find_correction_dim_positive() {
let dim = yarn_find_correction_dim(1.0, 64, 10000.0, 4096);
assert!(dim > 0.0, "correction dim should be positive: got {dim}");
}
// More rotations shrink ln(L / (2π·r)), so the correction dim decreases as r grows.
#[test]
fn test_yarn_find_correction_dim_scales_with_rotations() {
let dim_slow = yarn_find_correction_dim(1.0, 64, 10000.0, 4096);
let dim_fast = yarn_find_correction_dim(32.0, 64, 10000.0, 4096);
assert!(
dim_slow > dim_fast,
"slow (1 rotation) should have larger correction dim than fast (32 rotations)"
);
}
// The ramp over [0, 10] with 11 points starts at 0 and ends at 1.
#[test]
fn test_yarn_linear_ramp_mask_boundary_values() {
let mask = yarn_linear_ramp_mask(0.0, 10.0, 11);
assert!((mask[0] - 0.0).abs() < 1e-6, "first element should be 0");
assert!((mask[10] - 1.0).abs() < 1e-6, "last element should be 1");
}
// For min < max, the clamped ramp is non-decreasing.
#[test]
fn test_yarn_linear_ramp_mask_monotone() {
let mask = yarn_linear_ramp_mask(2.0, 8.0, 12);
for i in 1..mask.len() {
assert!(
mask[i] >= mask[i - 1],
"ramp mask must be non-decreasing: mask[{i}]={} < mask[{}]={}",
mask[i],
i - 1,
mask[i - 1]
);
}
}
// A degenerate range (min == max) yields all zeros.
#[test]
fn test_yarn_linear_ramp_mask_zero_range() {
let mask = yarn_linear_ramp_mask(5.0, 5.0, 8);
for v in &mask {
assert!((v - 0.0).abs() < 1e-6, "should be 0 when min==max");
}
}
// ln(e) = 1, so the expected value is sqrt(0.1 + 1) * 1.
#[test]
fn test_yarn_get_mscale_formula() {
let result = yarn_get_mscale(std::f32::consts::E, 1.0);
let expected = (0.1_f32 * 1.0_f32 + 1.0).sqrt(); assert!((result - expected).abs() < 1e-5, "mscale formula mismatch: {result} vs {expected}");
}
// ln(1) = 0, so scale 1 returns mscale itself.
#[test]
fn test_yarn_get_mscale_scale_1_returns_mscale() {
let result = yarn_get_mscale(1.0, 2.5);
assert!((result - 2.5).abs() < 1e-5, "scale=1 should return mscale unchanged");
}
// Output buffers have the same lengths as the inputs.
#[test]
fn test_apply_yarn_rope_output_shape() {
let config = YarnConfig {
original_max_position_embeddings: 4096,
scaling_factor: 4.0,
beta_fast: 32.0,
beta_slow: 1.0,
mscale: 1.0,
mscale_all_dim: 0.0,
base: 10000.0,
head_dim: 16,
};
let seq_len = 4;
let num_heads = 2;
let num_kv_heads = 2;
let q = vec![0.1_f32; seq_len * num_heads * 16];
let k = vec![0.2_f32; seq_len * num_kv_heads * 16];
let positions: Vec<usize> = (0..seq_len).collect();
let (out_q, out_k) =
apply_yarn_rope(&q, &k, &positions, num_heads, num_kv_heads, &config)
.expect("apply_yarn_rope should succeed");
assert_eq!(out_q.len(), q.len());
assert_eq!(out_k.len(), k.len());
}
// head_dim 0 is rejected before any buffer is inspected.
#[test]
fn test_apply_yarn_rope_invalid_head_dim() {
let config = YarnConfig {
original_max_position_embeddings: 4096,
scaling_factor: 2.0,
beta_fast: 32.0,
beta_slow: 1.0,
mscale: 1.0,
mscale_all_dim: 0.0,
base: 10000.0,
head_dim: 0,
};
let result = apply_yarn_rope(&[], &[], &[], 1, 1, &config);
assert!(matches!(result, Err(RopeError::InvalidHeadDim { .. })));
}
// An odd head_dim is rejected before any buffer is inspected.
#[test]
fn test_apply_yarn_rope_odd_head_dim() {
let config = YarnConfig {
original_max_position_embeddings: 4096,
scaling_factor: 2.0,
beta_fast: 32.0,
beta_slow: 1.0,
mscale: 1.0,
mscale_all_dim: 0.0,
base: 10000.0,
head_dim: 7,
};
let result = apply_yarn_rope(&[], &[], &[], 1, 1, &config);
assert!(matches!(result, Err(RopeError::InvalidHeadDim { .. })));
}
// Sequences within max_original_length keep base_theta.
#[test]
fn test_dynamic_ntk_theta_no_scaling_when_short() {
let config = DynamicNtkConfig {
base_theta: 10000.0,
alpha: 8.0,
max_original_length: 4096,
head_dim: 64,
};
let theta = dynamic_ntk_theta(2048, &config);
assert!(
(theta - 10000.0).abs() < 1e-3,
"theta should be base_theta for short sequences: got {theta}"
);
}
// Beyond max_original_length the effective base grows with sequence length.
#[test]
fn test_dynamic_ntk_theta_increases_with_seq_len() {
let config = DynamicNtkConfig {
base_theta: 10000.0,
alpha: 8.0,
max_original_length: 4096,
head_dim: 64,
};
let theta_short = dynamic_ntk_theta(4096, &config);
let theta_long = dynamic_ntk_theta(32768, &config);
assert!(
theta_long > theta_short,
"longer sequences should produce larger theta: {theta_short} vs {theta_long}"
);
}
// Output buffers have the same lengths as the inputs.
#[test]
fn test_apply_dynamic_ntk_rope_output_shape() {
let config = DynamicNtkConfig {
base_theta: 10000.0,
alpha: 4.0,
max_original_length: 512,
head_dim: 16,
};
let seq_len = 8;
let num_heads = 2;
let q = vec![0.0_f32; seq_len * num_heads * 16];
let k = vec![0.0_f32; seq_len * num_heads * 16];
let (out_q, out_k) =
apply_dynamic_ntk_rope(&q, &k, seq_len, num_heads, num_heads, &config)
.expect("should succeed");
assert_eq!(out_q.len(), q.len());
assert_eq!(out_k.len(), k.len());
}
// 1 position x 1 head x 16 dims needs 16 query elements; 5 triggers DimensionMismatch.
#[test]
fn test_apply_dynamic_ntk_rope_dimension_mismatch() {
let config = DynamicNtkConfig {
base_theta: 10000.0,
alpha: 2.0,
max_original_length: 512,
head_dim: 16,
};
let q = vec![0.0_f32; 5]; let k = vec![0.0_f32; 16];
let result = apply_dynamic_ntk_rope(&q, &k, 1, 1, 1, &config);
assert!(matches!(result, Err(RopeError::DimensionMismatch { .. })));
}
// 256 <= threshold 512 selects short factors (2.0): 1.0 / 2.0 = 0.5.
#[test]
fn test_apply_longrope_scaling_uses_short_factor_below_threshold() {
let config = LongRopeScaling {
short_factor: vec![2.0; 4],
long_factor: vec![8.0; 4],
threshold: 512,
short_mscale: 1.0,
long_mscale: 1.0,
};
let inv_freq = vec![1.0_f32; 4];
let result = apply_longrope_scaling(&inv_freq, 256, &config); for v in &result {
assert!((v - 0.5).abs() < 1e-6, "short factor should halve freq: got {v}");
}
}
// 1024 > threshold 512 selects long factors (8.0): 1.0 / 8.0 = 0.125.
#[test]
fn test_apply_longrope_scaling_uses_long_factor_above_threshold() {
let config = LongRopeScaling {
short_factor: vec![2.0; 4],
long_factor: vec![8.0; 4],
threshold: 512,
short_mscale: 1.0,
long_mscale: 1.0,
};
let inv_freq = vec![1.0_f32; 4];
let result = apply_longrope_scaling(&inv_freq, 1024, &config); for v in &result {
assert!((v - 0.125).abs() < 1e-6, "long factor should divide freq by 8: got {v}");
}
}
// Scaling never changes the number of frequencies.
#[test]
fn test_apply_longrope_scaling_preserves_length() {
let config = LongRopeScaling {
short_factor: vec![1.0; 8],
long_factor: vec![2.0; 8],
threshold: 256,
short_mscale: 1.0,
long_mscale: 1.0,
};
let inv_freq = vec![0.5_f32; 8];
let result_short = apply_longrope_scaling(&inv_freq, 128, &config);
let result_long = apply_longrope_scaling(&inv_freq, 512, &config);
assert_eq!(result_short.len(), inv_freq.len());
assert_eq!(result_long.len(), inv_freq.len());
}
// seq_len == threshold still selects the short factors (the comparison is `<=`).
#[test]
fn test_apply_longrope_scaling_at_threshold_uses_short() {
let config = LongRopeScaling {
short_factor: vec![2.0; 2],
long_factor: vec![100.0; 2],
threshold: 512,
short_mscale: 1.0,
long_mscale: 1.0,
};
let inv_freq = vec![1.0_f32; 2];
let result = apply_longrope_scaling(&inv_freq, 512, &config);
for v in &result {
assert!((v - 0.5).abs() < 1e-6, "at threshold should use short factor: got {v}");
}
}
}