use super::config::{Llama32Config, Llama32RopeScaling, Llama32RopeType};
use std::f64::consts::PI;
/// Computes the unscaled RoPE inverse frequencies.
///
/// Entry `k` equals `1 / rope_theta^(2k / head_dim)` for every even offset
/// `2k < head_dim`. The result therefore has `ceil(head_dim / 2)` entries,
/// one per rotated pair of dimensions. Returns an empty vector when
/// `head_dim == 0`.
pub fn default_inv_freq(rope_theta: f64, head_dim: usize) -> Vec<f64> {
    let dim = head_dim as f64;
    let pair_count = head_dim / 2 + head_dim % 2;
    let mut inv_freq = Vec::with_capacity(pair_count);
    for pair in 0..pair_count {
        let exponent = (2 * pair) as f64 / dim;
        inv_freq.push(1.0 / rope_theta.powf(exponent));
    }
    inv_freq
}
/// Computes RoPE inverse frequencies with Llama-3-style frequency-band scaling.
///
/// Starts from [`default_inv_freq`]. If `cfg.rope_scaling` is absent, or its
/// `rope_type` is not `Llama3`, the default frequencies are returned unchanged.
/// Otherwise each frequency is classified by its wavelength `2π / freq`:
/// - below `original_max_position_embeddings / high_freq_factor`: kept as-is;
/// - above `original_max_position_embeddings / low_freq_factor`: divided by `factor`;
/// - in between: linearly blended between the two, with weight
///   `smooth = (original_max_position_embeddings / wavelen - low_freq_factor)
///   / (high_freq_factor - low_freq_factor)` on the unscaled frequency.
pub fn llama3_inv_freq(cfg: &Llama32Config) -> Vec<f64> {
    let base = default_inv_freq(cfg.rope_theta, cfg.head_dim());
    let (factor, low_ff, high_ff, old_ctx) = match cfg.rope_scaling.as_ref() {
        Some(Llama32RopeScaling {
            factor,
            low_freq_factor,
            high_freq_factor,
            original_max_position_embeddings,
            rope_type: Llama32RopeType::Llama3,
        }) => (
            *factor as f64,
            *low_freq_factor as f64,
            *high_freq_factor as f64,
            *original_max_position_embeddings as f64,
        ),
        _ => return base,
    };

    let low_freq_wavelen = old_ctx / low_ff;
    let high_freq_wavelen = old_ctx / high_ff;

    base.into_iter()
        .map(|freq| {
            let wavelen = 2.0 * PI / freq;
            // High-frequency band: short wavelengths are left untouched.
            if wavelen < high_freq_wavelen {
                return freq;
            }
            // Low-frequency band: long wavelengths are fully rescaled.
            if wavelen > low_freq_wavelen {
                return freq / factor;
            }
            // Transition band: interpolate between scaled and unscaled.
            let smooth = (old_ctx / wavelen - low_ff) / (high_ff - low_ff);
            (1.0 - smooth) * freq / factor + smooth * freq
        })
        .collect()
}
/// Divides each base inverse frequency by its per-dimension factor.
///
/// `factors` are widened losslessly from `f32` to `f64` before dividing.
///
/// # Panics
/// Panics if `base` and `factors` have different lengths.
pub fn inv_freq_with_factors(base: &[f64], factors: &[f32]) -> Vec<f64> {
    assert_eq!(
        base.len(),
        factors.len(),
        "rope_freqs.weight length must match head_dim/2"
    );
    let mut scaled = Vec::with_capacity(base.len());
    for (&freq, &divisor) in base.iter().zip(factors) {
        scaled.push(freq / f64::from(divisor));
    }
    scaled
}
pub fn resolve_inv_freq(cfg: &Llama32Config, rope_freq_factors: Option<&[f32]>) -> Vec<f64> {
let base = llama3_inv_freq(cfg);
match rope_freq_factors {
Some(f) if !f.is_empty() => inv_freq_with_factors(&base, f),
_ => base,
}
}
/// Precomputes RoPE cosine and sine tables for positions `0..max_pos`.
///
/// Both returned vectors are row-major with `max_pos` rows and
/// `inv_freq.len()` columns. Entry `[pos * inv_freq.len() + i]` holds
/// `cos`/`sin` of `pos * inv_freq[i]`, computed in `f64` and narrowed to `f32`.
pub fn build_rope_tables(inv_freq: &[f64], max_pos: usize) -> (Vec<f32>, Vec<f32>) {
    let width = inv_freq.len();
    let mut cos_table = Vec::with_capacity(max_pos * width);
    let mut sin_table = Vec::with_capacity(max_pos * width);
    for pos in 0..max_pos {
        let position = pos as f64;
        for &freq in inv_freq {
            let (sin_v, cos_v) = (position * freq).sin_cos();
            cos_table.push(cos_v as f32);
            sin_table.push(sin_v as f32);
        }
    }
    (cos_table, sin_table)
}
/// Computes the RoPE cosine/sine row for a single position `pos`.
///
/// Returns `(cos, sin)`, each with `inv_freq.len()` entries. This is the same
/// row that [`build_rope_tables`] would store at index `pos`.
pub fn rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
    let position = pos as f64;
    inv_freq
        .iter()
        .map(|&freq| {
            let (sin_v, cos_v) = (position * freq).sin_cos();
            (cos_v as f32, sin_v as f32)
        })
        .unzip()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Llama32Config;

    /// Tiny config with Llama-3-style rope scaling (factor 32, low/high
    /// frequency factors 1/4, original context 8192).
    fn llama3_scaled_cfg() -> Llama32Config {
        Llama32Config {
            hidden_size: 64,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            rope_theta: 500_000.0,
            max_position_embeddings: 131_072,
            rope_scaling: Some(Llama32RopeScaling {
                factor: 32.0,
                low_freq_factor: 1.0,
                high_freq_factor: 4.0,
                original_max_position_embeddings: 8192,
                rope_type: Llama32RopeType::Llama3,
            }),
            ..Llama32Config::tiny_test()
        }
    }

    #[test]
    fn llama3_scaling_reduces_low_freq_inv_freq() {
        let cfg = llama3_scaled_cfg();
        let plain = default_inv_freq(cfg.rope_theta, cfg.head_dim());
        let scaled = llama3_inv_freq(&cfg);
        assert!(scaled.last().unwrap() < plain.last().unwrap());
    }

    #[test]
    fn llama3_scaling_keeps_high_freq_and_never_increases() {
        let cfg = llama3_scaled_cfg();
        let plain = default_inv_freq(cfg.rope_theta, cfg.head_dim());
        let scaled = llama3_inv_freq(&cfg);
        assert_eq!(scaled.len(), plain.len());
        // Index 0 has inv_freq 1.0 (wavelength 2π), far below the
        // high-frequency cutoff, so it must pass through unchanged.
        assert_eq!(scaled[0], plain[0]);
        // With factor >= 1 every band is either unchanged or reduced.
        assert!(scaled.iter().zip(&plain).all(|(s, p)| s <= p));
    }

    #[test]
    fn no_rope_scaling_matches_default() {
        let cfg = Llama32Config {
            rope_scaling: None,
            ..llama3_scaled_cfg()
        };
        let plain = default_inv_freq(cfg.rope_theta, cfg.head_dim());
        assert_eq!(llama3_inv_freq(&cfg), plain);
        assert_eq!(resolve_inv_freq(&cfg, None), plain);
        // An empty factor slice is treated the same as no factors.
        let empty: &[f32] = &[];
        assert_eq!(resolve_inv_freq(&cfg, Some(empty)), plain);
    }

    #[test]
    fn inv_freq_with_factors_divides_elementwise() {
        assert_eq!(
            inv_freq_with_factors(&[1.0, 0.5], &[1.0, 2.0]),
            vec![1.0, 0.25]
        );
    }

    #[test]
    fn rope_tables_match_per_position_slices() {
        let inv_freq = default_inv_freq(10_000.0, 8);
        let half = inv_freq.len();
        let max_pos = 5;
        let (cos, sin) = build_rope_tables(&inv_freq, max_pos);
        assert_eq!(cos.len(), max_pos * half);
        assert_eq!(sin.len(), max_pos * half);
        for pos in 0..max_pos {
            let (c, s) = rope_slice(&inv_freq, pos);
            assert_eq!(&cos[pos * half..(pos + 1) * half], c.as_slice());
            assert_eq!(&sin[pos * half..(pos + 1) * half], s.as_slice());
        }
    }
}