use crate::error::ModelResult;
#[derive(Debug)]
pub struct RopeTable {
cos: Vec<f32>,
sin: Vec<f32>,
half_dim: usize,
max_seq_len: usize,
}
impl RopeTable {
pub fn new(head_dim: usize, max_seq_len: usize, freq_base: f32) -> Self {
let half_dim = head_dim / 2;
let mut cos = vec![0.0f32; max_seq_len * half_dim];
let mut sin = vec![0.0f32; max_seq_len * half_dim];
for pos in 0..max_seq_len {
for i in 0..half_dim {
let freq = 1.0 / freq_base.powf(2.0 * i as f32 / head_dim as f32);
let angle = pos as f32 * freq;
cos[pos * half_dim + i] = angle.cos();
sin[pos * half_dim + i] = angle.sin();
}
}
Self {
cos,
sin,
half_dim,
max_seq_len,
}
}
pub fn apply(&self, vec: &[f32], output: &mut [f32], pos: usize) -> ModelResult<()> {
debug_assert!(pos < self.max_seq_len);
debug_assert_eq!(vec.len(), self.half_dim * 2);
debug_assert!(output.len() >= self.half_dim * 2);
let cos_row = &self.cos[pos * self.half_dim..(pos + 1) * self.half_dim];
let sin_row = &self.sin[pos * self.half_dim..(pos + 1) * self.half_dim];
oxibonsai_kernels::rope_apply_simd(vec, output, cos_row, sin_row);
Ok(())
}
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}
pub fn cos_at(&self, pos: usize) -> &[f32] {
&self.cos[pos * self.half_dim..(pos + 1) * self.half_dim]
}
pub fn sin_at(&self, pos: usize) -> &[f32] {
&self.sin[pos * self.half_dim..(pos + 1) * self.half_dim]
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rope_at_position_zero_is_identity() {
let table = RopeTable::new(4, 16, 10000.0);
let input = vec![1.0, 2.0, 3.0, 4.0];
let mut output = vec![0.0; 4];
table
.apply(&input, &mut output, 0)
.expect("rope apply should succeed");
assert!((output[0] - 1.0).abs() < 1e-5);
assert!((output[1] - 2.0).abs() < 1e-5);
assert!((output[2] - 3.0).abs() < 1e-5);
assert!((output[3] - 4.0).abs() < 1e-5);
}
#[test]
fn rope_preserves_norm() {
let table = RopeTable::new(4, 16, 10000.0);
let input = vec![1.0, 0.0, 0.0, 1.0];
let mut output = vec![0.0; 4];
table
.apply(&input, &mut output, 5)
.expect("rope apply should succeed");
let input_norm: f32 = input.iter().map(|x| x * x).sum::<f32>().sqrt();
let output_norm: f32 = output.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(input_norm - output_norm).abs() < 1e-4,
"RoPE should preserve vector norm"
);
}
}