use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
/// Configuration for [`RotaryPositionEmbedding`].
#[derive(Debug, Clone)]
pub struct RopeConfig {
/// Base frequency ("theta") of the rotary embedding; 10 000 is the common default.
pub theta: f64,
/// Rotary dimension — the per-head embedding size the rotation is applied to.
pub dim: usize,
/// Number of positions to precompute cos/sin tables for.
pub max_seq_len: usize,
/// Optional NTK-style factor multiplied into `theta` for context extension.
pub ntk_factor: Option<f64>,
}
impl Default for RopeConfig {
fn default() -> Self {
Self {
theta: 10000.0,
dim: 64,
max_seq_len: 8192,
ntk_factor: None,
}
}
}
/// Rotary position embedding (RoPE) with precomputed cos/sin lookup tables.
#[derive(Clone)]
pub struct RotaryPositionEmbedding<B: Backend> {
// [max_cached_len, dim/2] table of cos(pos * inv_freq[i]).
cos_cached: Tensor<B, 2>,
// [max_cached_len, dim/2] table of sin(pos * inv_freq[i]).
sin_cached: Tensor<B, 2>,
// Number of rows in the cached tables (config.max_seq_len at build time).
max_cached_len: usize,
// Rotary dimension the tables were built for (config.dim).
dim: usize,
}
impl<B: Backend> RotaryPositionEmbedding<B> {
    /// Builds the embedding by precomputing cos/sin tables on `device` for
    /// every position in `0..config.max_seq_len`.
    ///
    /// Frequencies follow the standard RoPE formula
    /// `inv_freq[i] = theta^(-2i / dim)` for `i in 0..dim / 2`.
    pub fn new(device: &B::Device, config: RopeConfig) -> Self {
        let dim = config.dim;
        let max_seq_len = config.max_seq_len;
        let half_dim = dim / 2;
        // NOTE(review): `theta * factor` is a coarse NTK adjustment; the
        // "NTK-aware" variant scales the base by factor^(dim / (dim - 2)).
        // Confirm which behavior downstream models expect before changing it.
        let theta = match config.ntk_factor {
            Some(factor) => config.theta * factor,
            None => config.theta,
        };
        // Keep the per-channel inverse frequencies in f64 so that
        // `position * inv_freq` stays accurate for large positions; only the
        // final cos/sin values are narrowed to f32.
        let inv_freq: Vec<f64> = (0..half_dim)
            .map(|i| theta.powf(-2.0 * (i as f64) / (dim as f64)))
            .collect();
        // Row-major [max_seq_len, half_dim] tables: entry (pos, i) holds
        // cos/sin of `pos * inv_freq[i]`.
        let mut cos_vals = Vec::with_capacity(max_seq_len * half_dim);
        let mut sin_vals = Vec::with_capacity(max_seq_len * half_dim);
        for pos in 0..max_seq_len {
            for inv_f in &inv_freq {
                let angle = pos as f64 * inv_f;
                cos_vals.push(angle.cos() as f32);
                sin_vals.push(angle.sin() as f32);
            }
        }
        let cos_cached = Tensor::<B, 2>::from_data(
            burn::tensor::TensorData::new(cos_vals, [max_seq_len, half_dim]),
            device,
        );
        let sin_cached = Tensor::<B, 2>::from_data(
            burn::tensor::TensorData::new(sin_vals, [max_seq_len, half_dim]),
            device,
        );
        Self {
            cos_cached,
            sin_cached,
            max_cached_len: max_seq_len,
            dim,
        }
    }

    /// Rotates `query` and `key` (`[batch, seq_len, num_heads, head_dim]`)
    /// by their absolute positions, starting at `position_offset` (non-zero
    /// during incremental decoding with a KV cache).
    ///
    /// Returns the rotated `(query, key)` pair with unchanged shapes.
    pub fn apply(
        &self,
        query: Tensor<B, 4>,
        key: Tensor<B, 4>,
        position_offset: usize,
    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
        let [_batch, seq_len, _num_heads, head_dim] = query.dims();
        let half_dim = head_dim / 2;
        // The cached tables have `self.dim / 2` columns; a mismatched
        // head_dim would slice out of bounds or use the wrong frequencies.
        debug_assert_eq!(head_dim, self.dim, "head_dim must equal RoPE dim");
        debug_assert!(
            position_offset + seq_len <= self.max_cached_len,
            "requested positions exceed the precomputed RoPE cache"
        );
        let positions = position_offset..(position_offset + seq_len);
        // Reshape to [1, seq, 1, half] so the tables broadcast over the
        // batch and head dimensions.
        let cos = self
            .cos_cached
            .clone()
            .slice([positions.clone(), 0..half_dim])
            .reshape([1, seq_len, 1, half_dim]);
        let sin = self
            .sin_cached
            .clone()
            .slice([positions, 0..half_dim])
            .reshape([1, seq_len, 1, half_dim]);
        let q_rotated = self.rotate_half(query, cos.clone(), sin.clone());
        let k_rotated = self.rotate_half(key, cos, sin);
        (q_rotated, k_rotated)
    }

    /// Applies the rotation directly to `hidden_states`
    /// (`[batch, seq_len, hidden_size]`), treating the whole hidden vector
    /// as a single rotary head of size `self.dim`.
    pub fn apply_to_hidden_states(
        &self,
        hidden_states: Tensor<B, 3>,
        position_offset: usize,
    ) -> Tensor<B, 3> {
        let [batch, seq_len, hidden_size] = hidden_states.dims();
        let half_dim = self.dim / 2;
        // The slices below split the hidden axis at self.dim; a wider input
        // would be silently truncated by the final `cat`.
        debug_assert_eq!(hidden_size, self.dim, "hidden_size must equal RoPE dim");
        debug_assert!(
            position_offset + seq_len <= self.max_cached_len,
            "requested positions exceed the precomputed RoPE cache"
        );
        let positions = position_offset..(position_offset + seq_len);
        let cos = self
            .cos_cached
            .clone()
            .slice([positions.clone(), 0..half_dim])
            .reshape([1, seq_len, half_dim]);
        let sin = self
            .sin_cached
            .clone()
            .slice([positions, 0..half_dim])
            .reshape([1, seq_len, half_dim]);
        // Standard "rotate half": (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin).
        let x1 = hidden_states.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
        // Last use of hidden_states — no clone needed for the second slice.
        let x2 = hidden_states.slice([0..batch, 0..seq_len, half_dim..self.dim]);
        let x1_rotated = x1.clone() * cos.clone() - x2.clone() * sin.clone();
        let x2_rotated = x2 * cos + x1 * sin;
        Tensor::cat(vec![x1_rotated, x2_rotated], 2)
    }

    /// Rotates the last axis of `x` using the broadcastable `cos`/`sin`
    /// tables: with `x1`/`x2` the first/second halves of `head_dim`,
    /// produces `(x1*cos - x2*sin, x2*cos + x1*sin)` concatenated back.
    fn rotate_half(
        &self,
        x: Tensor<B, 4>,
        cos: Tensor<B, 4>,
        sin: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        let [batch, seq_len, num_heads, head_dim] = x.dims();
        let half_dim = head_dim / 2;
        let x1 = x.clone().slice([0..batch, 0..seq_len, 0..num_heads, 0..half_dim]);
        // Last use of x — no clone needed for the second slice.
        let x2 = x.slice([0..batch, 0..seq_len, 0..num_heads, half_dim..head_dim]);
        let x1_rotated = x1.clone() * cos.clone() - x2.clone() * sin.clone();
        let x2_rotated = x2 * cos + x1 * sin;
        Tensor::cat(vec![x1_rotated, x2_rotated], 3)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArray;

    type TestBackend = NdArray<f32>;

    /// Builds a RoPE instance with default config but a custom cache length.
    fn build_rope(max_seq_len: usize) -> RotaryPositionEmbedding<TestBackend> {
        let device = <TestBackend as Backend>::Device::default();
        let config = RopeConfig {
            max_seq_len,
            ..RopeConfig::default()
        };
        RotaryPositionEmbedding::new(&device, config)
    }

    #[test]
    fn rope_creates_frequency_tables() {
        let rope = build_rope(512);
        // Tables are [max_seq_len, dim / 2].
        assert_eq!(rope.cos_cached.dims(), [512, 32]);
        assert_eq!(rope.sin_cached.dims(), [512, 32]);
    }

    #[test]
    fn rope_apply_preserves_shape() {
        let device = <TestBackend as Backend>::Device::default();
        let rope = build_rope(512);
        let shape = [2, 10, 8, 64];
        let query = Tensor::<TestBackend, 4>::zeros(shape, &device);
        let key = Tensor::<TestBackend, 4>::zeros(shape, &device);
        let (q_rot, k_rot) = rope.apply(query, key, 0);
        assert_eq!(q_rot.dims(), shape);
        assert_eq!(k_rot.dims(), shape);
    }
}