use super::config::GemmaConfig;
/// Computes the default RoPE inverse frequencies for a head of width `head_dim`.
///
/// For every even index `i` in `0..head_dim` the entry is
/// `1 / rope_theta^(i / head_dim)`, so the result has `ceil(head_dim / 2)`
/// elements and its first entry is always `1.0` (when `head_dim > 0`).
pub fn default_inv_freq(rope_theta: f64, head_dim: usize) -> Vec<f64> {
    let dim = head_dim as f64;
    // One entry per even index, i.e. ceil(head_dim / 2) entries.
    let mut inv_freq = Vec::with_capacity(head_dim / 2 + head_dim % 2);
    for even_idx in (0..head_dim).step_by(2) {
        let exponent = even_idx as f64 / dim;
        inv_freq.push(rope_theta.powf(exponent).recip());
    }
    inv_freq
}
/// Returns the inverse frequencies for `cfg`, optionally rescaled by per-frequency factors.
///
/// The base frequencies come from [`default_inv_freq`] using the config's
/// `rope_theta` and `head_dim()`. When `rope_freq_factors` is provided, each
/// base frequency `inv_freq[i]` is divided by `rope_freq_factors[i]`.
///
/// The returned vector always has the same length as the default frequencies:
/// - an empty or absent factor slice leaves the frequencies unchanged;
/// - a factor slice shorter than the frequency vector rescales only the leading
///   entries and leaves the rest as-is (as if their factor were `1.0`), instead
///   of silently dropping the trailing frequencies;
/// - factors beyond the frequency count are ignored.
pub fn resolve_inv_freq(cfg: &GemmaConfig, rope_freq_factors: Option<&[f32]>) -> Vec<f64> {
    let mut inv_freq = default_inv_freq(cfg.rope_theta, cfg.head_dim());
    if let Some(factors) = rope_freq_factors {
        // Divide in place so a length mismatch can never shrink the output;
        // downstream tables derive their row width from this vector's length.
        for (freq, &factor) in inv_freq.iter_mut().zip(factors) {
            *freq /= f64::from(factor);
        }
    }
    inv_freq
}
/// Precomputes flat cosine and sine tables for positions `0..max_pos`.
///
/// Both tables have `max_pos * inv_freq.len()` entries laid out row by row:
/// entry `pos * inv_freq.len() + i` holds `cos`/`sin` of `pos * inv_freq[i]`,
/// computed in `f64` and narrowed to `f32`.
pub fn build_rope_tables(inv_freq: &[f64], max_pos: usize) -> (Vec<f32>, Vec<f32>) {
    let total = max_pos * inv_freq.len();
    let mut cos_table = Vec::with_capacity(total);
    let mut sin_table = Vec::with_capacity(total);
    for position in 0..max_pos {
        let p = position as f64;
        // Rows are appended in position order, giving the row-major layout.
        for &freq in inv_freq {
            let (s, c) = (p * freq).sin_cos();
            cos_table.push(c as f32);
            sin_table.push(s as f32);
        }
    }
    (cos_table, sin_table)
}
/// Computes the cosine and sine values for a single position `pos`.
///
/// Returns two vectors of length `inv_freq.len()` whose `i`-th entries are
/// `cos`/`sin` of `pos * inv_freq[i]`, computed in `f64` and narrowed to `f32`.
pub fn rope_slice(inv_freq: &[f64], pos: usize) -> (Vec<f32>, Vec<f32>) {
    let p = pos as f64;
    inv_freq
        .iter()
        .map(|&freq| {
            let (sin_v, cos_v) = (p * freq).sin_cos();
            (cos_v as f32, sin_v as f32)
        })
        .unzip()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The flat tables hold one row of `inv_freq.len()` entries per position.
    #[test]
    fn tables_len_matches_max_pos() {
        let inv = default_inv_freq(10_000.0, 8);
        let (c, s) = build_rope_tables(&inv, 4);
        assert_eq!(c.len(), 4 * inv.len());
        assert_eq!(s.len(), c.len());
    }

    /// One frequency per even index (rounded up for odd dims); the first is always 1.
    #[test]
    fn inv_freq_len_and_first_entry() {
        assert_eq!(default_inv_freq(10_000.0, 8).len(), 4);
        assert_eq!(default_inv_freq(10_000.0, 5).len(), 3);
        assert_eq!(default_inv_freq(10_000.0, 8)[0], 1.0);
        assert!(default_inv_freq(10_000.0, 0).is_empty());
    }

    /// Each row of the precomputed tables must equal the on-the-fly slice for that position.
    #[test]
    fn rope_slice_matches_table_rows() {
        let inv = default_inv_freq(10_000.0, 8);
        let half = inv.len();
        let max_pos = 5;
        let (cos_tab, sin_tab) = build_rope_tables(&inv, max_pos);
        for pos in 0..max_pos {
            let (c, s) = rope_slice(&inv, pos);
            assert_eq!(&cos_tab[pos * half..(pos + 1) * half], c.as_slice());
            assert_eq!(&sin_tab[pos * half..(pos + 1) * half], s.as_slice());
        }
    }

    /// Position zero applies no rotation: every cosine is 1 and every sine is 0.
    #[test]
    fn position_zero_is_identity() {
        let inv = default_inv_freq(10_000.0, 6);
        let (c, s) = rope_slice(&inv, 0);
        assert!(c.iter().all(|&v| v == 1.0));
        assert!(s.iter().all(|&v| v == 0.0));
    }
}