use crate::tensor::Tensor;
/// Rotary position embedding (RoPE) angle table.
///
/// Stores one inverse frequency per channel pair and, via [`forward`](Self::forward),
/// produces the per-position rotation angles `pos * inv_freq[j]`. This type only
/// produces the raw angles; it does not compute cos/sin or rotate any tensor.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Inverse frequencies, one per channel pair: `inv_freq[i] = base^(-2i / dim)`.
    /// Length is `dim / 2` (integer division) when built by a constructor.
    pub inv_freq: Vec<f32>,
    /// Channel count requested at construction. When `dim` is odd, the last
    /// channel has no entry in `inv_freq` and is not covered by `forward`.
    pub dim: usize,
}

impl RotaryEmbedding {
    /// Frequency base used by [`RotaryEmbedding::new`].
    pub const DEFAULT_BASE: f64 = 10000.0;

    /// Creates an angle table for `dim` channels using [`Self::DEFAULT_BASE`].
    ///
    /// Equivalent to `RotaryEmbedding::with_base(dim, RotaryEmbedding::DEFAULT_BASE)`.
    pub fn new(dim: usize) -> Self {
        Self::with_base(dim, Self::DEFAULT_BASE)
    }

    /// Creates an angle table for `dim` channels with a custom frequency `base`.
    ///
    /// Computes `inv_freq[i] = 1 / base^(2i / dim)` for `i in 0..dim / 2`, evaluated
    /// in `f64` and then narrowed to `f32`. The exponent divides by `dim` itself,
    /// so for odd `dim` the frequencies are spaced over `dim`, not `dim - 1`.
    ///
    /// # Panics
    ///
    /// Panics if `base` is not finite or not greater than zero; such bases can
    /// produce infinite or NaN frequencies.
    pub fn with_base(dim: usize, base: f64) -> Self {
        assert!(
            base.is_finite() && base > 0.0,
            "RotaryEmbedding base must be finite and > 0, got {}",
            base
        );
        let inv_freq = (0..dim / 2)
            .map(|i| {
                let exponent = (2 * i) as f64 / dim as f64;
                (1.0 / base.powf(exponent)) as f32
            })
            .collect();
        Self { inv_freq, dim }
    }

    /// Returns the rotation-angle table for positions `0..seq_len`.
    ///
    /// The result has shape `[seq_len, 2 * inv_freq.len()]`. Row `pos` holds the
    /// half-row `pos * inv_freq[j]` for every `j`, followed by an exact copy of that
    /// same half-row: `[θ_0 .. θ_{h-1}, θ_0 .. θ_{h-1}]`.
    ///
    /// # Panics
    ///
    /// Panics if `seq_len * 2 * inv_freq.len()` overflows `usize`.
    pub fn forward(&self, seq_len: usize) -> Tensor {
        let half = self.inv_freq.len();
        let rot_dim = half * 2;
        // Checked so a huge request fails loudly instead of wrapping in release builds.
        let len = seq_len
            .checked_mul(rot_dim)
            .expect("RotaryEmbedding::forward: seq_len * rot_dim overflows usize");
        let mut data = Vec::with_capacity(len);
        for pos in 0..seq_len {
            // `pos as f32` is exact only below 2^24; larger positions get rounded.
            let t = pos as f32;
            let row_start = data.len();
            data.extend(self.inv_freq.iter().map(|&f| t * f));
            // Both halves of a row carry identical angles, so copy the first half
            // instead of recomputing it.
            data.extend_from_within(row_start..);
        }
        Tensor::from_vec(data, vec![seq_len, rot_dim])
    }
}