use crate::config::LLaDA2MoeConfig;
/// Returns the rotary-embedding inverse frequencies for `cfg`.
///
/// For every even index `i` in `0..cfg.rope_dim()` the entry is
/// `1 / rope_theta^(i / rope_dim)`, computed in `f32`. The result therefore has
/// `ceil(rope_dim / 2)` entries and is empty when `rope_dim` is 0.
pub fn inv_freq(cfg: &LLaDA2MoeConfig) -> Vec<f32> {
    let rope_dim = cfg.rope_dim();
    let base = cfg.rope_theta as f32;
    let mut freqs = Vec::with_capacity(rope_dim / 2 + rope_dim % 2);
    for even in (0..rope_dim).filter(|i| i % 2 == 0) {
        let exponent = even as f32 / rope_dim as f32;
        freqs.push(1.0 / base.powf(exponent));
    }
    freqs
}
/// Precomputes cosine and sine tables for rotary position embeddings.
///
/// Returns `(cos, sin)`. Each is a flat, row-major buffer of `max_seq` rows with
/// `head_dim / 2` columns, where `head_dim = cfg.head_dim()`. In row `pos`:
///
/// * columns `0..rope_dim / 2` hold `cos`/`sin` of `pos * inv_freq[j]`;
/// * when `rope_dim > 2`, columns `rope_dim / 2 + j` repeat column `j`, but only
///   while they still fall inside the row;
/// * every remaining column stays `0.0`.
///
/// NOTE(review): with this layout, `rope_dim == head_dim` yields rows that hold
/// only the half-width frequencies. A partial rotary dim (`rope_dim <= head_dim / 2`)
/// yields rows that hold the full duplicated `[freqs, freqs]` pattern. Confirm the
/// consumer of these tables expects both shapes.
///
/// # Panics
///
/// * If `rope_dim / 2 > head_dim / 2`, i.e. the rotary part does not fit in a row.
/// * If `inv_freq` has fewer than `rope_dim / 2` entries.
pub fn build_rope_tables(
    cfg: &LLaDA2MoeConfig,
    inv_freq: &[f32],
    max_seq: usize,
) -> (Vec<f32>, Vec<f32>) {
    let head_dim = cfg.head_dim();
    let rope_dim = cfg.rope_dim();
    let tab_half = head_dim / 2;
    let rot_half = rope_dim / 2;
    assert!(
        rot_half <= tab_half,
        "rope_dim ({rope_dim}) must not exceed head_dim ({head_dim})"
    );
    assert!(
        inv_freq.len() >= rot_half,
        "inv_freq has {} entries but rope_dim / 2 = {rot_half} are required",
        inv_freq.len()
    );

    let mut cos = vec![0f32; max_seq * tab_half];
    let mut sin = vec![0f32; max_seq * tab_half];
    // Zero-width rows: nothing to fill (and `chunks_exact_mut(0)` would panic).
    if tab_half == 0 {
        return (cos, sin);
    }

    let freqs = &inv_freq[..rot_half];
    // NOTE(review): the duplicated half is skipped when rope_dim <= 2; the reason
    // is not visible here — confirm against the consumer of these tables.
    let duplicate = rope_dim > 2;

    // Iterate row-by-row so every write is confined to its own position's row.
    for (pos, (cos_row, sin_row)) in cos
        .chunks_exact_mut(tab_half)
        .zip(sin.chunks_exact_mut(tab_half))
        .enumerate()
    {
        for (j, &freq) in freqs.iter().enumerate() {
            let angle = pos as f32 * freq;
            let c = angle.cos();
            let s = angle.sin();
            cos_row[j] = c;
            sin_row[j] = s;
            // Mirror into the second half only while it stays inside this row.
            // Without this guard, the write would spill into row `pos + 1`.
            let mirror = rot_half + j;
            if duplicate && mirror < tab_half {
                cos_row[mirror] = c;
                sin_row[mirror] = s;
            }
        }
    }
    (cos, sin)
}