use crate::error::{LmError, LmResult};
use core::f32::consts::PI;
/// Strategy used to rescale rotary position-embedding frequencies.
///
/// The variant is consumed by `RopeScaling::new`, which validates the
/// parameters (every `scale` must be `> 0`) and derives the inverse-frequency
/// table from it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RopeScalingKind {
/// No rescaling: the table is `base^(-2d / dim)` unchanged.
None,
/// Every inverse frequency is divided by `scale`.
Linear {
/// Divisor applied to each inverse frequency.
scale: f32,
},
/// The table is rebuilt from an enlarged base,
/// `base * scale^(dim / (dim - 2))`.
NtkAware {
/// Factor raised to `dim / (dim - 2)` and multiplied into the base.
scale: f32,
},
/// Per-frequency blend between the `Linear` result (`w / scale`) and the
/// unscaled frequency `w`, driven by a clamped linear ramp over the number
/// of rotations a frequency completes within `original_max_pos` positions.
/// This variant also yields a magnitude factor derived from `ln(scale)`.
Yarn {
/// Divisor used on the interpolated side of the blend.
scale: f32,
/// Position count over which rotations are measured; must be `> 0`.
original_max_pos: usize,
/// Rotation count at or above which a frequency is left unscaled
/// (ramp = 1). Must be strictly greater than `beta_slow`.
beta_fast: f32,
/// Rotation count at or below which a frequency is fully divided by
/// `scale` (ramp = 0).
beta_slow: f32,
},
}
/// Precomputed rotary-embedding frequencies for one scaling configuration.
///
/// Instances are built by `RopeScaling::new`, which validates the inputs and
/// fills `inv_freqs` and `mscale`; all fields are private and read through
/// accessor methods.
#[derive(Debug, Clone)]
pub struct RopeScaling {
// Per-head embedding width; validated to be even and non-zero.
dim: usize,
// Frequency base the table was derived from; validated to be > 0.
base: f32,
// Scaling strategy this instance was constructed with.
kind: RopeScalingKind,
// One inverse frequency per rotated pair (`dim / 2` entries).
inv_freqs: Vec<f32>,
// Factor multiplied into every cosine and sine; 1.0 for all kinds except Yarn.
mscale: f32,
}
impl RopeScaling {
    /// Validates the configuration and precomputes the inverse-frequency table.
    ///
    /// One inverse frequency is produced per rotated pair (`dim / 2` entries),
    /// starting from `base^(-2d / dim)` and adjusted according to `kind`:
    ///
    /// * `None` keeps the table unchanged.
    /// * `Linear` divides every entry by `scale`.
    /// * `NtkAware` rebuilds the table from `base * scale^(dim / (dim - 2))`.
    /// * `Yarn` blends `w / scale` and `w` per entry with a clamped linear
    ///   ramp over the number of rotations completed in `original_max_pos`
    ///   positions, and sets the magnitude factor to `0.1 * ln(scale) + 1`
    ///   when `scale > 1` (otherwise `1.0`).
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidConfig`] when `dim` is zero or odd, when
    /// `base` or a `scale` is non-positive or not finite, or when the YaRN
    /// parameters are invalid (`original_max_pos == 0`, non-finite betas, or
    /// `beta_fast <= beta_slow`).
    pub fn new(dim: usize, base: f32, kind: RopeScalingKind) -> LmResult<Self> {
        if dim == 0 || dim % 2 != 0 {
            return Err(LmError::InvalidConfig {
                msg: format!("RopeScaling: dim={dim} must be even and > 0"),
            });
        }
        if base <= 0.0 {
            return Err(LmError::InvalidConfig {
                msg: "RopeScaling: base must be > 0".into(),
            });
        }
        // `base <= 0.0` is false for NaN, so NaN (and +inf) need their own guard.
        if !base.is_finite() {
            return Err(LmError::InvalidConfig {
                msg: "RopeScaling: base must be finite".into(),
            });
        }

        let (inv_freqs, mscale) = match kind {
            RopeScalingKind::None => (Self::plain_inv_freqs(dim, base), 1.0_f32),
            RopeScalingKind::Linear { scale } => {
                Self::check_scale(scale)?;
                let table = Self::plain_inv_freqs(dim, base)
                    .into_iter()
                    .map(|w| w / scale)
                    .collect();
                (table, 1.0_f32)
            }
            RopeScalingKind::NtkAware { scale } => {
                Self::check_scale(scale)?;
                let exponent = dim as f32 / (dim as f32 - 2.0);
                let new_base = base * scale.powf(exponent);
                (Self::plain_inv_freqs(dim, new_base), 1.0_f32)
            }
            RopeScalingKind::Yarn {
                scale,
                original_max_pos,
                beta_fast,
                beta_slow,
            } => {
                Self::check_scale(scale)?;
                if original_max_pos == 0 {
                    return Err(LmError::InvalidConfig {
                        msg: "RopeScaling::Yarn: original_max_pos must be > 0".into(),
                    });
                }
                // A NaN beta would slip past the ordering check below and
                // poison every blended frequency.
                if !beta_fast.is_finite() || !beta_slow.is_finite() {
                    return Err(LmError::InvalidConfig {
                        msg: "RopeScaling::Yarn: beta_fast and beta_slow must be finite".into(),
                    });
                }
                if beta_fast <= beta_slow {
                    return Err(LmError::InvalidConfig {
                        msg: format!(
                            "RopeScaling::Yarn: beta_fast={beta_fast} must be > beta_slow={beta_slow}"
                        ),
                    });
                }
                let orig = original_max_pos as f32;
                let denom = beta_fast - beta_slow;
                let table = Self::plain_inv_freqs(dim, base)
                    .into_iter()
                    .map(|w| {
                        let wavelength = 2.0 * PI / w;
                        let rotations = orig / wavelength;
                        // ramp = 0 -> fully interpolated (w / scale); ramp = 1 -> untouched.
                        let ramp = ((rotations - beta_slow) / denom).clamp(0.0, 1.0);
                        (1.0 - ramp) * (w / scale) + ramp * w
                    })
                    .collect();
                // The logarithmic correction only applies when the context is
                // extended; for scale <= 1 it would shrink (or even negate)
                // every rotation, so the factor stays at 1.
                let mscale = if scale > 1.0 {
                    0.1 * scale.ln() + 1.0
                } else {
                    1.0
                };
                (table, mscale)
            }
        };

        Ok(Self {
            dim,
            base,
            kind,
            inv_freqs,
            mscale,
        })
    }

    /// Unscaled table: `base^(-2d / dim)` for `d` in `0..dim / 2`.
    fn plain_inv_freqs(dim: usize, base: f32) -> Vec<f32> {
        (0..dim / 2)
            .map(|d| base.powf(-((2 * d) as f32) / dim as f32))
            .collect()
    }

    /// Rejects scale factors that are non-positive or not finite.
    fn check_scale(scale: f32) -> LmResult<()> {
        if scale <= 0.0 {
            return Err(LmError::InvalidConfig {
                msg: "RopeScaling: scale must be > 0".into(),
            });
        }
        // NaN fails every comparison, so it is not caught by the check above.
        if !scale.is_finite() {
            return Err(LmError::InvalidConfig {
                msg: "RopeScaling: scale must be finite".into(),
            });
        }
        Ok(())
    }

    /// Per-head embedding width this instance was built for.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Frequency base passed to [`RopeScaling::new`].
    pub fn base(&self) -> f32 {
        self.base
    }

    /// Scaling strategy passed to [`RopeScaling::new`].
    pub fn kind(&self) -> RopeScalingKind {
        self.kind
    }

    /// Inverse frequencies after scaling, one per rotated pair (`dim / 2` entries).
    pub fn inv_freqs(&self) -> &[f32] {
        &self.inv_freqs
    }

    /// Factor multiplied into every cosine and sine value.
    pub fn mscale(&self) -> f32 {
        self.mscale
    }

    /// Builds cosine and sine tables for positions `0..n_positions`.
    ///
    /// Both vectors have `n_positions * dim / 2` entries laid out position-major:
    /// the value for position `p` and pair `i` is at index `p * (dim / 2) + i`.
    /// Each entry is already multiplied by [`RopeScaling::mscale`].
    pub fn cos_sin(&self, n_positions: usize) -> (Vec<f32>, Vec<f32>) {
        let half = self.dim / 2;
        let n = n_positions * half;
        let mut cos_table = Vec::with_capacity(n);
        let mut sin_table = Vec::with_capacity(n);
        for pos in 0..n_positions {
            for &freq in &self.inv_freqs {
                let angle = pos as f32 * freq;
                cos_table.push(self.mscale * angle.cos());
                sin_table.push(self.mscale * angle.sin());
            }
        }
        (cos_table, sin_table)
    }

    /// Rotates `x` in place using the absolute position of each token.
    ///
    /// `x` is read as `positions.len()` tokens, each holding `n_heads`
    /// contiguous heads of `dim` values. Within a head, elements `2i` and
    /// `2i + 1` form the pair rotated by `positions[t] * inv_freqs[i]`, with
    /// cosine and sine scaled by [`RopeScaling::mscale`].
    ///
    /// # Errors
    ///
    /// * [`LmError::EmptyInput`] when `positions` is empty.
    /// * [`LmError::InvalidConfig`] when `n_heads` is zero or the expected
    ///   buffer length does not fit in `usize`.
    /// * [`LmError::DimensionMismatch`] when `x.len()` differs from
    ///   `positions.len() * n_heads * dim`.
    pub fn apply(&self, x: &mut [f32], n_heads: usize, positions: &[usize]) -> LmResult<()> {
        if positions.is_empty() {
            return Err(LmError::EmptyInput {
                context: "RopeScaling::apply positions",
            });
        }
        if n_heads == 0 {
            return Err(LmError::InvalidConfig {
                msg: "RopeScaling::apply: n_heads must be > 0".into(),
            });
        }
        // Checked so an absurd shape is reported instead of panicking in debug
        // builds or wrapping in release builds.
        let expected = positions
            .len()
            .checked_mul(n_heads)
            .and_then(|n| n.checked_mul(self.dim))
            .ok_or_else(|| LmError::InvalidConfig {
                msg: "RopeScaling::apply: positions.len() * n_heads * dim overflows usize".into(),
            })?;
        if x.len() != expected {
            return Err(LmError::DimensionMismatch {
                expected,
                got: x.len(),
            });
        }

        // Cannot overflow: it divides `expected`. Non-zero because the
        // constructor rejects dim == 0 and n_heads was checked above.
        let token_stride = n_heads * self.dim;
        // The rotation depends only on (position, pair), so it is computed once
        // per token and shared by every head; the buffer is reused across tokens.
        let mut rotations: Vec<(f32, f32)> = Vec::with_capacity(self.inv_freqs.len());
        for (token, &abs_pos) in x.chunks_exact_mut(token_stride).zip(positions) {
            rotations.clear();
            rotations.extend(self.inv_freqs.iter().map(|&freq| {
                let angle = abs_pos as f32 * freq;
                (self.mscale * angle.cos(), self.mscale * angle.sin())
            }));
            for head in token.chunks_exact_mut(self.dim) {
                for (pair, &(cos, sin)) in head.chunks_exact_mut(2).zip(&rotations) {
                    // `chunks_exact_mut(2)` always yields two elements.
                    if let [first, second] = pair {
                        let (x0, x1) = (*first, *second);
                        *first = x0 * cos - x1 * sin;
                        *second = x0 * sin + x1 * cos;
                    }
                }
            }
        }
        Ok(())
    }
}
// Unit tests for `RopeScaling`: table construction for every
// `RopeScalingKind`, `cos_sin` / `apply` consistency, and rejection of
// invalid configurations.
#[cfg(test)]
mod tests {
use super::*;
use crate::layer::embedding::RotaryEmbedding;
// Head width and frequency base shared by most tests.
const DIM: usize = 8;
const BASE: f32 = 10_000.0;
// Unscaled table `base^(-2d / dim)`, written with the same expression as the
// implementation so exact float comparisons are meaningful.
fn reference_inv_freqs(dim: usize, base: f32) -> Vec<f32> {
(0..dim / 2)
.map(|d| base.powf(-((2 * d) as f32) / dim as f32))
.collect()
}
// `None` must yield exactly the unscaled reference table.
// NOTE(review): despite the name, this compares against the local
// `reference_inv_freqs` helper, not against `RotaryEmbedding` itself.
#[test]
fn none_reproduces_rotary_embedding_inv_freqs_exactly() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None)
.expect("None with even dim and positive base is valid");
let want = reference_inv_freqs(DIM, BASE);
assert_eq!(rs.inv_freqs().len(), want.len());
for (got, exp) in rs.inv_freqs().iter().zip(want.iter()) {
assert_eq!(
*got, *exp,
"inv_freq must match RotaryEmbedding bit-for-bit"
);
}
}
// `cos_sin` for `None` must equal `RotaryEmbedding::cos_at` / `sin_at`
// exactly for the first 16 positions.
#[test]
fn none_reproduces_rotary_embedding_cos_sin_exactly() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("None config is valid");
let rope = RotaryEmbedding::new(DIM, 16, BASE).expect("RotaryEmbedding is valid");
let (cos, sin) = rs.cos_sin(16);
let half = DIM / 2;
for pos in 0..16 {
for i in 0..half {
// Tables are position-major: one row of `half` entries per position.
let idx = pos * half + i;
assert_eq!(cos[idx], rope.cos_at(pos, i), "cos at ({pos},{i})");
assert_eq!(sin[idx], rope.sin_at(pos, i), "sin at ({pos},{i})");
}
}
}
// Dividing by a scale of 1.0 must leave the table bit-identical to `None`.
#[test]
fn linear_scale_one_equals_none() {
let none = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let lin = RopeScaling::new(DIM, BASE, RopeScalingKind::Linear { scale: 1.0 })
.expect("linear scale=1");
for (a, b) in none.inv_freqs().iter().zip(lin.inv_freqs().iter()) {
assert_eq!(*a, *b, "Linear(1.0) must equal None");
}
assert_eq!(lin.mscale(), 1.0);
}
// NTK-aware scaling with scale 1.0 must match `None` within 1e-6.
#[test]
fn ntk_scale_one_equals_none() {
let none = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let ntk = RopeScaling::new(DIM, BASE, RopeScalingKind::NtkAware { scale: 1.0 })
.expect("ntk scale=1");
for (a, b) in none.inv_freqs().iter().zip(ntk.inv_freqs().iter()) {
assert!(
(a - b).abs() < 1e-6,
"NtkAware(1.0) must equal None: {a} vs {b}"
);
}
}
// Linear scaling must be equivalent to evaluating the unscaled frequencies
// at position `p / scale`.
#[test]
fn linear_scale_shifts_position() {
let scale = 4.0_f32;
let lin = RopeScaling::new(DIM, BASE, RopeScalingKind::Linear { scale }).expect("linear");
let base_inv = reference_inv_freqs(DIM, BASE);
let half = DIM / 2;
let (cos, sin) = lin.cos_sin(16);
let p = 12usize;
for (i, &w) in base_inv.iter().enumerate() {
let want_angle = (p as f32 / scale) * w;
let idx = p * half + i;
assert!((cos[idx] - want_angle.cos()).abs() < 1e-5);
assert!((sin[idx] - want_angle.sin()).abs() < 1e-5);
}
}
// Recomputes the NTK-aware table by hand with the same formula and expects
// agreement within 1e-9.
#[test]
fn ntk_base_matches_hand_computation() {
let scale = 8.0_f32;
let ntk = RopeScaling::new(DIM, BASE, RopeScalingKind::NtkAware { scale }).expect("ntk");
let exponent = DIM as f32 / (DIM as f32 - 2.0);
let new_base = BASE * scale.powf(exponent);
let want: Vec<f32> = (0..DIM / 2)
.map(|d| new_base.powf(-((2 * d) as f32) / DIM as f32))
.collect();
for (got, exp) in ntk.inv_freqs().iter().zip(want.iter()) {
assert!((got - exp).abs() < 1e-9, "{got} vs {exp}");
}
}
// Every YaRN frequency must lie between the fully interpolated value
// (`w / scale`) and the unscaled value (`w`).
#[test]
fn yarn_ramp_within_unit_interval() {
let scale = 16.0_f32;
let yarn = RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale,
original_max_pos: 2048,
beta_fast: 32.0,
beta_slow: 1.0,
},
)
.expect("yarn");
let base_inv = reference_inv_freqs(DIM, BASE);
for (got, &w) in yarn.inv_freqs().iter().zip(base_inv.iter()) {
let lo = (w / scale).min(w);
let hi = (w / scale).max(w);
assert!(
*got >= lo - 1e-6 && *got <= hi + 1e-6,
"{got} not in [{lo},{hi}]"
);
}
}
// YaRN with scale 1.0 must match `None` and report an mscale of ~1.
#[test]
fn yarn_scale_one_is_near_identity() {
let none = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let yarn = RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale: 1.0,
original_max_pos: 2048,
beta_fast: 32.0,
beta_slow: 1.0,
},
)
.expect("yarn scale=1");
for (a, b) in none.inv_freqs().iter().zip(yarn.inv_freqs().iter()) {
assert!((a - b).abs() < 1e-6, "{a} vs {b}");
}
assert!((yarn.mscale() - 1.0).abs() < 1e-6);
}
// For scale 32 the YaRN mscale must equal `0.1 * ln(scale) + 1`.
#[test]
fn mscale_formula() {
let scale = 32.0_f32;
let yarn = RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale,
original_max_pos: 4096,
beta_fast: 32.0,
beta_slow: 1.0,
},
)
.expect("yarn");
let want = 0.1 * scale.ln() + 1.0;
assert!(
(yarn.mscale() - want).abs() < 1e-6,
"{} vs {want}",
yarn.mscale()
);
}
// Only YaRN alters the magnitude factor; the other kinds report exactly 1.0.
#[test]
fn non_yarn_mscale_is_one() {
for kind in [
RopeScalingKind::None,
RopeScalingKind::Linear { scale: 4.0 },
RopeScalingKind::NtkAware { scale: 4.0 },
] {
let rs = RopeScaling::new(DIM, BASE, kind).expect("valid kind");
assert_eq!(rs.mscale(), 1.0, "non-YaRN mscale must be 1.0");
}
}
// Tables covering positions 0..=8192 must contain only finite values for
// every kind.
#[test]
fn large_position_no_nan_inf() {
for kind in [
RopeScalingKind::None,
RopeScalingKind::Linear { scale: 8.0 },
RopeScalingKind::NtkAware { scale: 8.0 },
RopeScalingKind::Yarn {
scale: 8.0,
original_max_pos: 2048,
beta_fast: 32.0,
beta_slow: 1.0,
},
] {
let rs = RopeScaling::new(DIM, BASE, kind).expect("valid");
let (cos, sin) = rs.cos_sin(8193);
assert!(cos.iter().all(|v| v.is_finite()), "cos finite for {kind:?}");
assert!(sin.iter().all(|v| v.is_finite()), "sin finite for {kind:?}");
}
}
// The table always holds one frequency per rotated pair.
#[test]
fn inv_freqs_length_is_half_dim() {
for d in [2usize, 4, 8, 16, 64, 128] {
let rs = RopeScaling::new(d, BASE, RopeScalingKind::None).expect("valid");
assert_eq!(rs.inv_freqs().len(), d / 2);
}
}
// Frequencies must be non-increasing with pair index for every kind.
#[test]
fn freqs_monotone_decreasing() {
for kind in [
RopeScalingKind::None,
RopeScalingKind::Linear { scale: 4.0 },
RopeScalingKind::NtkAware { scale: 4.0 },
RopeScalingKind::Yarn {
scale: 4.0,
original_max_pos: 2048,
beta_fast: 32.0,
beta_slow: 1.0,
},
] {
let rs = RopeScaling::new(64, BASE, kind).expect("valid");
let f = rs.inv_freqs();
for w in f.windows(2) {
assert!(w[0] >= w[1], "freqs must be non-increasing for {kind:?}");
}
}
}
// Applying the rotation and then the transposed rotation by hand must
// return the original vector (mscale is 1.0 for `Linear`).
#[test]
fn apply_then_inverse_recovers_input() {
let rs =
RopeScaling::new(DIM, BASE, RopeScalingKind::Linear { scale: 2.0 }).expect("linear");
let original = vec![0.3_f32, -1.2, 0.7, 2.1, -0.5, 1.1, 0.9, -0.2];
let mut x = original.clone();
let positions = [5usize];
rs.apply(&mut x, 1, &positions).expect("apply forward");
let half = DIM / 2;
let inv = rs.inv_freqs();
for i in 0..half {
let angle = positions[0] as f32 * inv[i];
let cos = angle.cos();
let sin = angle.sin();
let x0 = x[2 * i];
let x1 = x[2 * i + 1];
// Inverse rotation: same angle with the sine terms negated.
x[2 * i] = x0 * cos + x1 * sin;
x[2 * i + 1] = -x0 * sin + x1 * cos;
}
for (a, b) in x.iter().zip(original.iter()) {
assert!((a - b).abs() < 1e-5, "{a} vs {b}");
}
}
// `apply` must agree with rotating by the values read from `cos_sin`.
#[test]
fn apply_matches_cos_sin_tables() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let original = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mut x = original.clone();
let positions = [7usize];
rs.apply(&mut x, 1, &positions).expect("apply");
let (cos, sin) = rs.cos_sin(8);
let half = DIM / 2;
for i in 0..half {
let idx = positions[0] * half + i;
let x0 = original[2 * i];
let x1 = original[2 * i + 1];
let want0 = x0 * cos[idx] - x1 * sin[idx];
let want1 = x0 * sin[idx] + x1 * cos[idx];
assert!((x[2 * i] - want0).abs() < 1e-5);
assert!((x[2 * i + 1] - want1).abs() < 1e-5);
}
}
// For `None`, `apply` must match `RotaryEmbedding::apply` within 1e-6.
// NOTE(review): the meaning of the `(1, 1, 9)` arguments depends on
// `RotaryEmbedding::apply`, which is defined outside this file.
#[test]
fn apply_matches_rotary_embedding_for_none() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let rope = RotaryEmbedding::new(DIM, 32, BASE).expect("rope");
let original = vec![0.5_f32, -0.3, 1.7, 0.2, -2.1, 0.8, 1.0, -1.0];
let mut a = original.clone();
let mut b = original.clone();
rs.apply(&mut a, 1, &[9]).expect("rope_scaling apply");
rope.apply(&mut b, 1, 1, 9).expect("rotary apply");
for (x, y) in a.iter().zip(b.iter()) {
assert!((x - y).abs() < 1e-6, "{x} vs {y}");
}
}
// Smoke test: three tokens with two heads each are accepted and stay finite.
#[test]
fn apply_multi_token_multi_head() {
let rs =
RopeScaling::new(DIM, BASE, RopeScalingKind::NtkAware { scale: 2.0 }).expect("ntk");
let mut x = vec![0.5_f32; 3 * 2 * DIM]; rs.apply(&mut x, 2, &[0, 1, 2]).expect("apply");
assert_eq!(x.len(), 3 * 2 * DIM);
assert!(x.iter().all(|v| v.is_finite()));
}
// An odd `dim` is rejected.
#[test]
fn err_dim_odd() {
assert!(RopeScaling::new(7, BASE, RopeScalingKind::None).is_err());
}
// A zero `dim` is rejected.
#[test]
fn err_dim_zero() {
assert!(RopeScaling::new(0, BASE, RopeScalingKind::None).is_err());
}
// Zero and negative bases are rejected.
#[test]
fn err_base_non_positive() {
assert!(RopeScaling::new(DIM, 0.0, RopeScalingKind::None).is_err());
assert!(RopeScaling::new(DIM, -10.0, RopeScalingKind::None).is_err());
}
// Zero and negative scales are rejected for every scaled kind.
#[test]
fn err_scale_non_positive() {
assert!(RopeScaling::new(DIM, BASE, RopeScalingKind::Linear { scale: 0.0 }).is_err());
assert!(RopeScaling::new(DIM, BASE, RopeScalingKind::NtkAware { scale: -1.0 }).is_err());
assert!(
RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale: 0.0,
original_max_pos: 2048,
beta_fast: 32.0,
beta_slow: 1.0,
}
)
.is_err()
);
}
// YaRN requires `beta_fast` strictly greater than `beta_slow`: both the
// reversed and the equal case are rejected.
#[test]
fn err_yarn_beta_order() {
assert!(
RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale: 4.0,
original_max_pos: 2048,
beta_fast: 1.0,
beta_slow: 32.0,
}
)
.is_err()
);
assert!(
RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale: 4.0,
original_max_pos: 2048,
beta_fast: 8.0,
beta_slow: 8.0,
}
)
.is_err()
);
}
// YaRN rejects `original_max_pos == 0`.
#[test]
fn err_yarn_original_max_pos_zero() {
assert!(
RopeScaling::new(
DIM,
BASE,
RopeScalingKind::Yarn {
scale: 4.0,
original_max_pos: 0,
beta_fast: 32.0,
beta_slow: 1.0,
}
)
.is_err()
);
}
// `apply` reports `EmptyInput` when no positions are supplied.
#[test]
fn err_apply_empty_positions() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let mut x = vec![0.0_f32; DIM];
assert!(matches!(
rs.apply(&mut x, 1, &[]),
Err(LmError::EmptyInput { .. })
));
}
// `apply` reports `DimensionMismatch` when the buffer length is wrong.
#[test]
fn err_apply_dim_mismatch() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let mut x = vec![0.0_f32; DIM + 1];
assert!(matches!(
rs.apply(&mut x, 1, &[0]),
Err(LmError::DimensionMismatch { .. })
));
}
// `apply` rejects a head count of zero.
#[test]
fn err_apply_zero_heads() {
let rs = RopeScaling::new(DIM, BASE, RopeScalingKind::None).expect("none");
let mut x = vec![0.0_f32; DIM];
assert!(rs.apply(&mut x, 0, &[0]).is_err());
}
}