use std::f64::consts::PI;
/// Standard RoPE inverse frequencies: `1 / rope_theta^(2k / head_dim)` for each
/// rotary pair `k`.
///
/// One value is produced per even index `0, 2, 4, …` below `head_dim`, i.e.
/// `ceil(head_dim / 2)` values; an empty vector for `head_dim == 0`.
pub fn default_inv_freq(rope_theta: f64, head_dim: usize) -> Vec<f64> {
    let dim = head_dim as f64;
    let pairs = head_dim / 2 + head_dim % 2;
    (0..pairs)
        .map(|k| {
            let exponent = (2 * k) as f64 / dim;
            1.0 / rope_theta.powf(exponent)
        })
        .collect()
}
/// Divides each base inverse frequency by its per-pair factor
/// (`base[i] / factors[i]`).
///
/// # Panics
///
/// Panics if `base` and `factors` differ in length.
pub fn inv_freq_with_factors(base: &[f64], factors: &[f32]) -> Vec<f64> {
    assert_eq!(
        base.len(),
        factors.len(),
        "rope_freqs.weight length must match head_dim/2"
    );
    let mut scaled = Vec::with_capacity(base.len());
    for (&freq, &factor) in base.iter().zip(factors) {
        scaled.push(freq / f64::from(factor));
    }
    scaled
}
/// Parameters for Llama 3 style RoPE frequency scaling, as consumed by
/// `llama3_scaled_inv_freq`.
///
/// `PartialEq` is derived so configurations can be compared (e.g. in tests or
/// when checking whether cached tables are still valid).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Llama3Scaling {
    /// Divisor applied to frequencies whose wavelength exceeds
    /// `original_max_position_embeddings / low_freq_factor`.
    pub factor: f32,
    /// Defines the long-wavelength threshold
    /// `original_max_position_embeddings / low_freq_factor`.
    pub low_freq_factor: f32,
    /// Defines the short-wavelength threshold
    /// `original_max_position_embeddings / high_freq_factor`; shorter
    /// wavelengths are left unscaled.
    pub high_freq_factor: f32,
    /// Context length the wavelength thresholds are measured against.
    pub original_max_position_embeddings: u32,
}
/// Applies Llama 3 style scaling to a set of base inverse frequencies.
///
/// Each frequency is classified by its wavelength `2π / freq` against two
/// thresholds derived from `s.original_max_position_embeddings` (`ctx`):
///
/// * shorter than `ctx / high_freq_factor`: returned unchanged;
/// * longer than `ctx / low_freq_factor`: divided by `s.factor`;
/// * otherwise: blended as `(1 - smooth) * freq / factor + smooth * freq`, with
///   `smooth = (ctx / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)`.
///
/// The output has the same length and order as `base`.
pub fn llama3_scaled_inv_freq(base: &[f64], s: &Llama3Scaling) -> Vec<f64> {
    let ctx = s.original_max_position_embeddings as f64;
    let factor = s.factor as f64;
    let lo = s.low_freq_factor as f64;
    let hi = s.high_freq_factor as f64;
    let long_wavelen = ctx / lo;
    let short_wavelen = ctx / hi;

    let rescale = |freq: f64| -> f64 {
        let wavelen = 2.0 * PI / freq;
        if wavelen < short_wavelen {
            return freq;
        }
        if wavelen > long_wavelen {
            return freq / factor;
        }
        let smooth = (ctx / wavelen - lo) / (hi - lo);
        (1.0 - smooth) * freq / factor + smooth * freq
    };

    base.iter().copied().map(rescale).collect()
}
/// NTK-aware RoPE scaling: returns [`default_inv_freq`] computed with the base
/// raised to `rope_theta * factor^(head_dim / (head_dim - 2))`.
pub fn ntk_scaled_inv_freq(rope_theta: f64, head_dim: usize, factor: f32) -> Vec<f64> {
    let dim = head_dim as f64;
    let exponent = dim / (dim - 2.0);
    let scaled_theta = rope_theta * f64::from(factor).powf(exponent);
    default_inv_freq(scaled_theta, head_dim)
}
/// Parameters for YaRN RoPE scaling, as consumed by `yarn_scaled_inv_freq`.
///
/// `PartialEq` is derived so configurations can be compared (e.g. in tests or
/// when checking whether cached tables are still valid).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct YarnScaling {
    /// Frequency scaling factor.
    pub factor: f32,
    /// Rotation-count threshold used to derive the lower bound of the YaRN
    /// blending ramp.
    pub beta_fast: f32,
    /// Rotation-count threshold used to derive the upper bound of the YaRN
    /// blending ramp.
    pub beta_slow: f32,
    /// Position horizon passed to the YaRN correction-range computation.
    pub original_max_position_embeddings: u32,
}
/// Computes YaRN ("NTK-by-parts") scaled inverse frequencies.
///
/// Each rotary pair `i` (same order as [`default_inv_freq`]) blends two
/// candidates with a ramp weight `r` in `[0, 1]`:
///
/// * extrapolation: the unscaled base frequency. It is used where `r == 0`,
///   i.e. for pairs that complete more than `beta_fast` rotations over the
///   original context.
/// * interpolation: the base frequency divided by `s.factor`. It is used where
///   `r == 1`, i.e. for pairs that complete fewer than `beta_slow` rotations.
///
/// The result is `base / factor * r + base * (1 - r)`, with `r` ramping
/// linearly between the two correction dimensions. Only frequencies are
/// produced; no attention-magnitude correction is applied here.
pub fn yarn_scaled_inv_freq(base_theta: f64, head_dim: usize, s: &YarnScaling) -> Vec<f64> {
    let base = default_inv_freq(base_theta, head_dim);
    let factor = s.factor as f64;
    let (low, high) = yarn_correction_range(
        s.beta_fast,
        s.beta_slow,
        head_dim,
        base_theta,
        s.original_max_position_embeddings as f64,
    );
    base.iter()
        .enumerate()
        .map(|(i, &extrapolated)| {
            // 0 => high-frequency pair, keep as-is; 1 => low-frequency pair, interpolate.
            let ramp = yarn_linear_ramp_mask(low, high, i as f64);
            let interpolated = extrapolated / factor;
            interpolated * ramp + extrapolated * (1.0 - ramp)
        })
        .collect()
}

/// Returns the pair-index range `(low, high)` over which the YaRN ramp moves
/// from extrapolation to interpolation.
///
/// As in the reference implementations, the bound from `beta_fast` is floored
/// and clamped at 0. The bound from `beta_slow` is ceiled and clamped at
/// `dim - 1`. The pair is swapped if needed so that `low <= high`.
fn yarn_correction_range(
    beta_fast: f32,
    beta_slow: f32,
    dim: usize,
    base: f64,
    max_pos: f64,
) -> (f64, f64) {
    let low = yarn_correction_dim(beta_fast, dim, base, max_pos).floor().max(0.0);
    let high = yarn_correction_dim(beta_slow, dim, base, max_pos)
        .ceil()
        .min(dim as f64 - 1.0);
    if low > high {
        (high, low)
    } else {
        (low, high)
    }
}

/// Unrounded, unclamped pair index `i` at which a rotary pair completes
/// `num_rot` full rotations over `max_pos` positions.
///
/// It solves `max_pos / (2π · base^(2i/dim)) = num_rot` for `i`, where `dim`
/// is the rotary dimension count and `base` the RoPE base.
fn yarn_correction_dim(num_rot: f32, dim: usize, base: f64, max_pos: f64) -> f64 {
    let num = (max_pos / (num_rot as f64 * 2.0 * PI)).ln();
    (dim as f64 * num) / (2.0 * base.ln())
}

/// Linear ramp over pair index `i`: 0 at or below `low`, 1 at or above `high`.
///
/// The span is floored at 0.001, as in the reference implementations. A
/// degenerate `low == high` range therefore becomes a step at `low` instead
/// of a division by zero.
fn yarn_linear_ramp_mask(low: f64, high: f64, i: f64) -> f64 {
    ((i - low) / (high - low).max(0.001)).clamp(0.0, 1.0)
}
/// Precomputes RoPE cosine/sine tables for positions `0..max_pos`.
///
/// Both vectors have `max_pos * inv_freq.len()` entries, laid out row-major as
/// `[pos][pair]`. Entry `(pos, i)` holds the cosine (resp. sine) of
/// `pos * inv_freq[i]`, evaluated in `f64` and narrowed to `f32`.
pub fn build_tables(inv_freq: &[f64], max_pos: usize) -> (Vec<f32>, Vec<f32>) {
    let total = max_pos * inv_freq.len();
    let mut cos = Vec::with_capacity(total);
    let mut sin = Vec::with_capacity(total);
    for pos in 0..max_pos {
        let p = pos as f64;
        for &freq in inv_freq {
            let (s, c) = (p * freq).sin_cos();
            cos.push(c as f32);
            sin.push(s as f32);
        }
    }
    (cos, sin)
}
/// Convenience wrapper: builds cos/sin tables from the standard
/// [`default_inv_freq`] frequencies for the given base and head dimension.
pub fn build_default_tables(
    rope_theta: f64,
    head_dim: usize,
    max_pos: usize,
) -> (Vec<f32>, Vec<f32>) {
    let inv_freq = default_inv_freq(rope_theta, head_dim);
    build_tables(&inv_freq, max_pos)
}
/// Normalises an M-RoPE section list to exactly four entries.
///
/// It copies the first (up to) four values, fills missing slots with 0, and
/// drops any extra values.
pub fn mrope_sections4(sections: &[usize]) -> [usize; 4] {
    let mut padded = [0usize; 4];
    let n = sections.len().min(padded.len());
    padded[..n].copy_from_slice(&sections[..n]);
    padded
}
/// Returns the index of the section containing rotary pair `global_pair_j`.
///
/// Sections are laid out consecutively, each covering `sections[k]` pair
/// indices. Zero-length sections are never selected. Indices past the end of
/// all sections fall back to section 3.
pub fn mrope_section_for_pair(global_pair_j: usize, sections: [usize; 4]) -> usize {
    // Exclusive end of the sections visited so far. A zero-length section
    // leaves it unchanged, so it cannot contain the pair.
    let mut end = 0usize;
    sections
        .iter()
        .position(|&len| {
            end += len;
            global_pair_j < end
        })
        .unwrap_or(3)
}
/// Builds one row of M-RoPE cos/sin values (length `head_half`).
///
/// For each of the first `min(n_rot / 2, head_half)` pairs `j`, the position is
/// taken from `section_pos` for the section that `mrope_section_for_pair`
/// assigns to `j`. The angle is `pos / rope_theta^(2j / n_rot)`. The remaining
/// pairs keep the identity rotation (cos = 1, sin = 0).
pub fn mrope_row_for_sections(
    rope_theta: f64,
    n_rot: usize,
    sections: [usize; 4],
    section_pos: [usize; 4],
    head_half: usize,
) -> (Vec<f32>, Vec<f32>) {
    let mut cos = vec![1f32; head_half];
    let mut sin = vec![0f32; head_half];
    let rotated = (n_rot / 2).min(head_half);
    let pairs = cos[..rotated].iter_mut().zip(sin[..rotated].iter_mut());
    for (j, (c_out, s_out)) in pairs.enumerate() {
        let pos = section_pos[mrope_section_for_pair(j, sections)] as f64;
        let freq = 1.0 / rope_theta.powf((2 * j) as f64 / n_rot as f64);
        let (s, c) = (pos * freq).sin_cos();
        *c_out = c as f32;
        *s_out = s as f32;
    }
    (cos, sin)
}
pub fn build_mrope_text_tables(
rope_theta: f64,
head_dim: usize,
sections: [usize; 4],
max_pos: usize,
) -> (Vec<f32>, Vec<f32>) {
let half = head_dim / 2;
let mut cos = vec![1f32; max_pos * half];
let mut sin = vec![0f32; max_pos * half];
let inv = default_inv_freq(rope_theta, head_dim);
let mut offset = 0usize;
for §ion in sections.iter() {
if section == 0 {
continue;
}
let section_half = section / 2;
for pos in 0..max_pos {
for i in 0..section_half {
let idx = offset + i;
if idx >= half || i >= inv.len() {
break;
}
let angle = pos as f64 * inv[idx];
cos[pos * half + idx] = angle.cos() as f32;
sin[pos * half + idx] = angle.sin() as f32;
}
}
offset += section_half;
}
(cos, sin)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_freq_lengths() {
        let freqs = default_inv_freq(10_000.0, 64);
        assert_eq!(freqs.len(), 32);
        // The first pair has exponent 0, so its frequency is exactly 1.
        assert!((freqs[0] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn tables_shape() {
        let (max_pos, half) = (16usize, 32usize);
        let (cos, sin) = build_default_tables(10_000.0, 2 * half, max_pos);
        assert_eq!(cos.len(), max_pos * half);
        assert_eq!(sin.len(), max_pos * half);
        // Row 0 (position 0) is the identity rotation for every pair.
        assert!(cos[..half].iter().all(|c| (c - 1.0).abs() < 1e-6));
        assert!(sin[..half].iter().all(|s| s.abs() < 1e-6));
    }

    #[test]
    fn llama3_scaling_high_freq_passthrough() {
        let base = default_inv_freq(500_000.0, 128);
        let scaling = Llama3Scaling {
            factor: 8.0,
            low_freq_factor: 1.0,
            high_freq_factor: 4.0,
            original_max_position_embeddings: 8192,
        };
        let scaled = llama3_scaled_inv_freq(&base, &scaling);
        // The highest-frequency pair is below the short-wavelength threshold.
        assert!((scaled[0] - base[0]).abs() < 1e-12);
    }

    #[test]
    fn mrope_sections_clamp() {
        let cases: [(&[usize], [usize; 4]); 2] = [
            (&[24, 20, 20, 0, 5], [24, 20, 20, 0]),
            (&[8], [8, 0, 0, 0]),
        ];
        for (input, expected) in cases {
            assert_eq!(mrope_sections4(input), expected);
        }
    }
}