use baracuda_cutlass::{Error, Result};
/// Frequency-scaling scheme applied when `RopeScaledTableBuilder` generates
/// its rotary-embedding `cos` / `sin` tables.
///
/// The enum is `#[non_exhaustive]`: code outside the defining crate must
/// include a wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum RopeScaling {
/// Unscaled schedule: the angle for position `s` and pair `i` is
/// `s * base^(-2i / head_dim)`, with no extra factor applied.
Linear,
/// YaRN-style per-frequency interpolation.
///
/// For each pair the builder computes how many full rotations that pair
/// completes over `original_max_seq_len` positions and then:
/// * rotations `<= alpha`: the inverse frequency is divided by `scale`;
/// * rotations `>= beta`: the inverse frequency is left unchanged;
/// * in between: the two are blended linearly.
///
/// When `scale > 1.0` every table entry is additionally multiplied by
/// `1 / sqrt(1 + 0.1 * ln(scale))`.
YaRN {
/// Context-extension factor; validation requires finite and `> 0`.
scale: f32,
/// Lower rotation-count threshold; validation requires finite and `< beta`.
alpha: f32,
/// Upper rotation-count threshold; validation requires finite and `> alpha`.
beta: f32,
/// Sequence length used to count rotations; validation requires `> 0`.
original_max_seq_len: i32,
},
/// LongRoPE-style explicit per-pair rescaling.
LongRoPE {
/// One factor per rotary pair (`head_dim / 2` entries, each finite and
/// `> 0`). The builder MULTIPLIES the pair's inverse frequency by its
/// factor, so a factor of `2.0` doubles the rotation angle.
per_dim_factors: Vec<f32>,
},
}
/// Host-side generator for rotary position embedding (RoPE) lookup tables.
///
/// Holds the raw, unvalidated parameters; `validate` checks them and
/// `build_host_tables` produces the `cos` / `sin` tables.
#[derive(Debug, Clone)]
pub struct RopeScaledTableBuilder {
// Per-head embedding width; validation requires it to be positive and even.
// Each table row holds `head_dim / 2` entries (one per rotary pair).
head_dim: i32,
// Number of positions (table rows); validation requires it to be positive.
max_seq_len: i32,
// Base of the geometric frequency schedule `base^(-2i / head_dim)`;
// validation requires it to be finite and positive.
base: f32,
// Scaling scheme applied on top of the base schedule.
scaling: RopeScaling,
}
impl RopeScaledTableBuilder {
    /// Creates a builder from raw parameters.
    ///
    /// Nothing is checked here; call [`validate`](Self::validate), or
    /// [`build_host_tables`](Self::build_host_tables) (which validates first),
    /// to find out whether the configuration is usable.
    pub fn new(head_dim: i32, max_seq_len: i32, base: f32, scaling: RopeScaling) -> Self {
        Self {
            head_dim,
            max_seq_len,
            base,
            scaling,
        }
    }

    /// Checks that the configuration can be turned into tables.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidProblem` when any of the following holds:
    /// * `head_dim` is not a positive even number;
    /// * `max_seq_len` is not positive;
    /// * `base` is not finite and positive;
    /// * YaRN: `scale` is not finite and positive, `alpha` / `beta` are not
    ///   finite, `alpha >= beta`, or `original_max_seq_len` is not positive;
    /// * LongRoPE: `per_dim_factors` does not have exactly `head_dim / 2`
    ///   entries, or any entry is not finite and positive.
    pub fn validate(&self) -> Result<()> {
        if self.head_dim <= 0 || self.head_dim % 2 != 0 {
            return Err(Error::InvalidProblem(
                "RopeScaledTableBuilder: head_dim must be positive + even",
            ));
        }
        if self.max_seq_len <= 0 {
            return Err(Error::InvalidProblem(
                "RopeScaledTableBuilder: max_seq_len must be positive",
            ));
        }
        if !self.base.is_finite() || self.base <= 0.0 {
            return Err(Error::InvalidProblem(
                "RopeScaledTableBuilder: base must be finite and positive",
            ));
        }
        match &self.scaling {
            RopeScaling::Linear => {}
            RopeScaling::YaRN {
                scale,
                alpha,
                beta,
                original_max_seq_len,
            } => {
                if !scale.is_finite() || *scale <= 0.0 {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::YaRN: scale must be finite + positive",
                    ));
                }
                // Finiteness is checked first so the ordering test below
                // never has to reason about NaN comparisons.
                if !alpha.is_finite() || !beta.is_finite() {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::YaRN: alpha + beta must be finite",
                    ));
                }
                if *alpha >= *beta {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::YaRN: alpha must be < beta \
                         (paper convention: alpha=1, beta=32)",
                    ));
                }
                if *original_max_seq_len <= 0 {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::YaRN: original_max_seq_len must be positive",
                    ));
                }
            }
            RopeScaling::LongRoPE { per_dim_factors } => {
                // head_dim is known to be positive here, so the cast is lossless.
                if per_dim_factors.len() != (self.head_dim / 2) as usize {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::LongRoPE: per_dim_factors length must \
                         equal head_dim / 2",
                    ));
                }
                if per_dim_factors.iter().any(|f| !f.is_finite() || *f <= 0.0) {
                    return Err(Error::InvalidProblem(
                        "RopeScaledTableBuilder::LongRoPE: per_dim_factors must be \
                         finite and positive",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Validates the configuration and builds the `(cos, sin)` tables.
    ///
    /// Both vectors have [`table_len`](Self::table_len) entries and are laid
    /// out row-major by position: the entry for position `s` and rotary pair
    /// `p` lives at index `s * (head_dim / 2) + p`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`validate`](Self::validate).
    pub fn build_host_tables(&self) -> Result<(Vec<f32>, Vec<f32>)> {
        self.validate()?;
        // Single source of truth for the table size.
        let total = self.table_len();
        let mut cos_tab = vec![0f32; total];
        let mut sin_tab = vec![0f32; total];
        match &self.scaling {
            RopeScaling::Linear => self.fill_linear(&mut cos_tab, &mut sin_tab),
            RopeScaling::YaRN {
                scale,
                alpha,
                beta,
                original_max_seq_len,
            } => self.fill_yarn(
                &mut cos_tab,
                &mut sin_tab,
                *scale,
                *alpha,
                *beta,
                *original_max_seq_len,
            ),
            RopeScaling::LongRoPE { per_dim_factors } => {
                self.fill_longrope(&mut cos_tab, &mut sin_tab, per_dim_factors)
            }
        }
        Ok((cos_tab, sin_tab))
    }

    /// Unscaled inverse frequency of rotary pair `pair`:
    /// `base^(-2 * pair / head_dim)`.
    ///
    /// No validation is performed and `pair` is not range-checked.
    #[inline]
    pub fn inv_freq(&self, pair: usize) -> f32 {
        let inv_d = 1.0f32 / (self.head_dim as f32);
        let exponent = -((2 * pair) as f32) * inv_d;
        self.base.powf(exponent)
    }

    /// Writes `cos(s * w_p) * gain` / `sin(s * w_p) * gain` for every position
    /// `s` and pair `p`, where `w_p = inv_freqs[p]`.
    ///
    /// The tables are walked row by row (one row per position), so writes are
    /// contiguous and each per-pair frequency is computed by the caller only
    /// once instead of once per table cell.
    fn fill_from_inv_freqs(
        cos_tab: &mut [f32],
        sin_tab: &mut [f32],
        inv_freqs: &[f32],
        gain: f32,
    ) {
        let half_d = inv_freqs.len();
        if half_d == 0 {
            // `chunks_exact_mut(0)` panics; an empty row means nothing to fill.
            return;
        }
        let rows = cos_tab
            .chunks_exact_mut(half_d)
            .zip(sin_tab.chunks_exact_mut(half_d));
        for (s, (cos_row, sin_row)) in rows.enumerate() {
            let pos = s as f32;
            let cells = cos_row.iter_mut().zip(sin_row.iter_mut()).zip(inv_freqs);
            for ((cos_out, sin_out), &inv_freq) in cells {
                let (sin_v, cos_v) = (pos * inv_freq).sin_cos();
                *cos_out = cos_v * gain;
                *sin_out = sin_v * gain;
            }
        }
    }

    /// Fills the tables with the unscaled schedule `theta = s * inv_freq(p)`.
    fn fill_linear(&self, cos_tab: &mut [f32], sin_tab: &mut [f32]) {
        let half_d = (self.head_dim / 2) as usize;
        let inv_freqs: Vec<f32> = (0..half_d).map(|pair| self.inv_freq(pair)).collect();
        // A gain of exactly 1.0 leaves cos/sin bit-for-bit untouched.
        Self::fill_from_inv_freqs(cos_tab, sin_tab, &inv_freqs, 1.0);
    }

    /// Fills the tables using YaRN-style per-frequency interpolation.
    ///
    /// Each pair's rotation count over `original_max_seq_len` positions picks
    /// a blend weight between the original inverse frequency and that
    /// frequency divided by `scale`; see the inline comments for the exact
    /// thresholds.
    fn fill_yarn(
        &self,
        cos_tab: &mut [f32],
        sin_tab: &mut [f32],
        scale: f32,
        alpha: f32,
        beta: f32,
        original_max_seq_len: i32,
    ) {
        let half_d = (self.head_dim / 2) as usize;
        let l_orig = original_max_seq_len as f32;

        // NOTE(review): YaRN reference implementations usually multiply
        // cos/sin by `0.1 * ln(scale) + 1`; this code multiplies by
        // `1 / sqrt(1 + 0.1 * ln(scale))`. The file's own tests pin the
        // current value, so it is preserved here -- confirm the intended
        // convention with the consumers of these tables.
        let attn_temp = if scale > 1.0 {
            (1.0f32 + 0.1 * scale.ln()).sqrt()
        } else {
            1.0
        };
        let inv_attn_temp = 1.0 / attn_temp;

        let inv_freqs: Vec<f32> = (0..half_d)
            .map(|pair| {
                let inv_freq = self.inv_freq(pair);
                // Full rotations this pair completes across the original context.
                let rotations = (l_orig * inv_freq) / (2.0 * core::f32::consts::PI);
                // ramp == 1.0 -> fully interpolated (inv_freq / scale);
                // ramp == 0.0 -> original frequency kept.
                let ramp = if rotations >= beta {
                    0.0
                } else if rotations <= alpha {
                    1.0
                } else {
                    (beta - rotations) / (beta - alpha)
                };
                (1.0 - ramp) * inv_freq + ramp * (inv_freq / scale)
            })
            .collect();

        Self::fill_from_inv_freqs(cos_tab, sin_tab, &inv_freqs, inv_attn_temp);
    }

    /// Fills the tables with per-pair rescaled frequencies:
    /// `theta = s * inv_freq(p) * per_dim_factors[p]`.
    ///
    /// NOTE(review): the factor multiplies the inverse frequency (a factor of
    /// 2 doubles the angle). Some LongRoPE implementations divide instead --
    /// confirm that callers supply factors in this convention.
    fn fill_longrope(
        &self,
        cos_tab: &mut [f32],
        sin_tab: &mut [f32],
        per_dim_factors: &[f32],
    ) {
        let inv_freqs: Vec<f32> = per_dim_factors
            .iter()
            .enumerate()
            .map(|(pair, &factor)| self.inv_freq(pair) * factor)
            .collect();
        Self::fill_from_inv_freqs(cos_tab, sin_tab, &inv_freqs, 1.0);
    }

    /// Number of `f32` entries in each of the `cos` / `sin` tables:
    /// `max_seq_len * (head_dim / 2)`.
    ///
    /// This never panics, even on a builder that has not been validated: a
    /// non-positive dimension counts as 0 and the product saturates at
    /// `usize::MAX` instead of overflowing.
    #[inline]
    pub fn table_len(&self) -> usize {
        let seq = self.max_seq_len.max(0) as usize;
        let half_d = (self.head_dim / 2).max(0) as usize;
        seq.saturating_mul(half_d)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// RoPE base shared by every numeric test below.
    const BASE: f32 = 10000.0;

    /// Builds `(cos, sin)` tables for the given shape and scaling, panicking
    /// with `what` if the configuration is rejected.
    fn build(head_dim: i32, seq: i32, scaling: RopeScaling, what: &str) -> (Vec<f32>, Vec<f32>) {
        RopeScaledTableBuilder::new(head_dim, seq, BASE, scaling)
            .build_host_tables()
            .expect(what)
    }

    /// `Linear` must reproduce the closed-form schedule
    /// `theta = s * base^(-2 * pair / head_dim)`.
    #[test]
    fn linear_matches_default_schedule() {
        let head_dim = 16i32;
        let seq = 4i32;
        let (cos, sin) = build(head_dim, seq, RopeScaling::Linear, "build");
        assert_eq!(cos.len(), (seq * head_dim / 2) as usize);
        assert_eq!(sin.len(), (seq * head_dim / 2) as usize);

        let half_d = (head_dim / 2) as usize;
        let inv_d = 1.0f32 / head_dim as f32;
        for s in 0..seq as usize {
            for pair in 0..half_d {
                let exponent = -((2 * pair) as f32) * inv_d;
                let theta = (s as f32) * BASE.powf(exponent);
                let expected_cos = theta.cos();
                let expected_sin = theta.sin();
                let i = s * half_d + pair;
                assert!(
                    (cos[i] - expected_cos).abs() < 1e-6,
                    "linear cos mismatch @ ({s},{pair}): got {} expected {}",
                    cos[i],
                    expected_cos
                );
                assert!(
                    (sin[i] - expected_sin).abs() < 1e-6,
                    "linear sin mismatch @ ({s},{pair}): got {} expected {}",
                    sin[i],
                    expected_sin
                );
            }
        }
    }

    /// With `scale == 1` YaRN neither changes any frequency nor rescales the
    /// table entries, so it must coincide with `Linear`.
    #[test]
    fn yarn_scale_one_matches_linear() {
        let head_dim = 32i32;
        let seq = 8i32;
        let linear = build(head_dim, seq, RopeScaling::Linear, "build linear");
        let yarn = build(
            head_dim,
            seq,
            RopeScaling::YaRN {
                scale: 1.0,
                alpha: 1.0,
                beta: 32.0,
                original_max_seq_len: 2048,
            },
            "build yarn",
        );
        for i in 0..linear.0.len() {
            assert!(
                (linear.0[i] - yarn.0[i]).abs() < 1e-6,
                "cos mismatch @ {i}: linear={} yarn={}",
                linear.0[i],
                yarn.0[i]
            );
            assert!(
                (linear.1[i] - yarn.1[i]).abs() < 1e-6,
                "sin mismatch @ {i}: linear={} yarn={}",
                linear.1[i],
                yarn.1[i]
            );
        }
    }

    /// All-ones LongRoPE factors leave every frequency unchanged.
    #[test]
    fn longrope_unit_factors_match_linear() {
        let head_dim = 16i32;
        let seq = 4i32;
        let linear = build(head_dim, seq, RopeScaling::Linear, "build linear");
        let long_rope = build(
            head_dim,
            seq,
            RopeScaling::LongRoPE {
                per_dim_factors: vec![1.0; (head_dim / 2) as usize],
            },
            "build longrope",
        );
        for i in 0..linear.0.len() {
            assert!((linear.0[i] - long_rope.0[i]).abs() < 1e-6);
            assert!((linear.1[i] - long_rope.1[i]).abs() < 1e-6);
        }
    }

    /// For the lowest-frequency pair YaRN must (a) shrink the (cos, sin)
    /// magnitude to `1 / attn_temp` and (b) divide the rotation angle by
    /// `scale`.
    #[test]
    fn yarn_scaled_reduces_low_freq_angle() {
        let head_dim = 32i32;
        let seq = 8i32;
        let linear = build(head_dim, seq, RopeScaling::Linear, "linear");
        let yarn = build(
            head_dim,
            seq,
            RopeScaling::YaRN {
                scale: 4.0,
                alpha: 1.0,
                beta: 32.0,
                original_max_seq_len: 2048,
            },
            "yarn",
        );
        let half_d = (head_dim / 2) as usize;
        let last_pair = half_d - 1;
        let s = 1usize;
        let idx = s * half_d + last_pair;

        let yarn_mag = (yarn.0[idx].powi(2) + yarn.1[idx].powi(2)).sqrt();
        let linear_mag = (linear.0[idx].powi(2) + linear.1[idx].powi(2)).sqrt();
        assert!(
            (linear_mag - 1.0).abs() < 1e-5,
            "linear (cos,sin) must have unit magnitude"
        );
        let expected_attn_temp = (1.0f32 + 0.1 * 4.0f32.ln()).sqrt();
        let expected_yarn_mag = 1.0 / expected_attn_temp;
        assert!(
            (yarn_mag - expected_yarn_mag).abs() < 1e-4,
            "YaRN magnitude should be 1/attn_temp ≈ {expected_yarn_mag}, got {yarn_mag}"
        );

        // The last pair completes far fewer than `alpha` rotations over the
        // original context, so it is fully interpolated: its angle must be
        // the linear angle divided by `scale` (= 4). `atan2` recovers the
        // angle independently of the common magnitude factor.
        let yarn_angle = yarn.1[idx].atan2(yarn.0[idx]);
        let linear_angle = linear.1[idx].atan2(linear.0[idx]);
        assert!(
            yarn_angle < linear_angle,
            "YaRN must reduce the low-frequency angle: yarn={yarn_angle} linear={linear_angle}"
        );
        assert!(
            (yarn_angle * 4.0 - linear_angle).abs() < 1e-7,
            "YaRN angle should be linear/4: yarn={yarn_angle} linear={linear_angle}"
        );
    }

    /// A LongRoPE factor of 2 on pair 0 doubles that pair's rotation angle.
    #[test]
    fn longrope_factor_two_doubles_angle() {
        let head_dim = 8i32;
        let seq = 4i32;
        let half_d = (head_dim / 2) as usize;
        let mut factors = vec![1.0f32; half_d];
        factors[0] = 2.0;

        let linear = build(head_dim, seq, RopeScaling::Linear, "linear");
        let lr = build(
            head_dim,
            seq,
            RopeScaling::LongRoPE {
                per_dim_factors: factors,
            },
            "longrope",
        );

        let s = 1usize;
        let pair = 0usize;
        let idx = s * half_d + pair;
        // Pair 0 has inverse frequency base^0 == 1, so theta == s == 1.
        let expected_linear_theta = 1.0f32;
        let expected_lr_theta = 2.0f32;
        assert!((linear.0[idx] - expected_linear_theta.cos()).abs() < 1e-6);
        assert!((lr.0[idx] - expected_lr_theta.cos()).abs() < 1e-6);
        assert!((lr.1[idx] - expected_lr_theta.sin()).abs() < 1e-6);
    }

    #[test]
    fn validate_rejects_odd_head_dim() {
        let b = RopeScaledTableBuilder::new(7, 4, BASE, RopeScaling::Linear);
        assert!(b.validate().is_err());
    }

    #[test]
    fn validate_rejects_longrope_factor_mismatch() {
        // head_dim 16 needs 8 factors; only 3 are supplied.
        let b = RopeScaledTableBuilder::new(
            16,
            4,
            BASE,
            RopeScaling::LongRoPE {
                per_dim_factors: vec![1.0; 3],
            },
        );
        assert!(b.validate().is_err());
    }

    #[test]
    fn validate_rejects_yarn_alpha_ge_beta() {
        let b = RopeScaledTableBuilder::new(
            16,
            4,
            BASE,
            RopeScaling::YaRN {
                scale: 4.0,
                alpha: 32.0,
                beta: 1.0,
                original_max_seq_len: 2048,
            },
        );
        assert!(b.validate().is_err());
    }
}