use burn::prelude::*;
use burn::tensor::backend::Backend;
/// Configuration for building a [`RotaryPositionEncoding2D`] module.
///
/// `init` precomputes cosine/sine lookup tables with one row per position of a
/// `max_height * max_width` grid.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RotaryPositionEncoding2DConfig {
/// Size of the feature dimension; `init` builds tables that are `embed_dim / 2` wide.
pub embed_dim: usize,
/// Number of grid rows covered by the precomputed tables.
pub max_height: usize,
/// Number of grid columns covered by the precomputed tables.
pub max_width: usize,
/// Base of the frequency schedule `1 / base_freq^(2i / (embed_dim / 2))`; `new` sets it to `10000.0`.
pub base_freq: f64,
}
impl RotaryPositionEncoding2DConfig {
    /// Creates a configuration with the default base frequency of `10000.0`.
    pub fn new(embed_dim: usize, max_height: usize, max_width: usize) -> Self {
        Self {
            embed_dim,
            max_height,
            max_width,
            base_freq: 10000.0,
        }
    }

    /// Builds the encoding module by precomputing the cosine/sine tables on `device`.
    ///
    /// Both tables have shape `[max_height * max_width, embed_dim / 2]`. The row for grid
    /// position `(row, col)` is `row * max_width + col`. Within a row, the first
    /// `embed_dim / 4` slots hold the angles `row * freq_i` and the next `embed_dim / 4`
    /// slots hold the angles `col * freq_i`, with
    /// `freq_i = 1 / base_freq^(2i / (embed_dim / 2))`.
    ///
    /// When `embed_dim / 2` is odd, the single slot that belongs to neither axis is an
    /// identity rotation (`cos = 1`, `sin = 0`).
    ///
    /// # Panics
    ///
    /// Panics if `embed_dim` is odd, because the features could not be split into two
    /// equally sized halves.
    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryPositionEncoding2D<B> {
        assert!(
            self.embed_dim % 2 == 0,
            "embed_dim must be even to split features into rotation pairs, got {}",
            self.embed_dim
        );

        let half_dim = self.embed_dim / 2;
        let quarter_dim = half_dim / 2;
        let max_seq = self.max_height * self.max_width;

        // Frequencies are rounded to f32 before the angles are computed.
        let freqs: Vec<f32> = (0..quarter_dim)
            .map(|i| (1.0 / self.base_freq.powf(2.0 * i as f64 / half_dim as f64)) as f32)
            .collect();

        // Start from the identity rotation (cos = 1, sin = 0) so that a slot which is not
        // assigned to either axis leaves its feature pair untouched instead of zeroing it.
        let mut cos_data = vec![1.0f32; max_seq * half_dim];
        let mut sin_data = vec![0.0f32; max_seq * half_dim];

        for row in 0..self.max_height {
            for col in 0..self.max_width {
                let offset = (row * self.max_width + col) * half_dim;
                for (i, &freq) in freqs.iter().enumerate() {
                    // First quarter: rotation driven by the row index.
                    let row_angle = row as f64 * freq as f64;
                    cos_data[offset + i] = row_angle.cos() as f32;
                    sin_data[offset + i] = row_angle.sin() as f32;

                    // Second quarter: rotation driven by the column index.
                    let col_angle = col as f64 * freq as f64;
                    cos_data[offset + quarter_dim + i] = col_angle.cos() as f32;
                    sin_data[offset + quarter_dim + i] = col_angle.sin() as f32;
                }
            }
        }

        let cos_table = Tensor::from_floats(
            burn::tensor::TensorData::new(cos_data, [max_seq, half_dim]),
            device,
        );
        let sin_table = Tensor::from_floats(
            burn::tensor::TensorData::new(sin_data, [max_seq, half_dim]),
            device,
        );

        RotaryPositionEncoding2D {
            cos_table,
            sin_table,
            embed_dim: self.embed_dim,
        }
    }
}
/// Rotary position encoding over a 2D grid, backed by precomputed lookup tables.
///
/// Instances are built by `RotaryPositionEncoding2DConfig::init`.
#[derive(Module, Debug)]
pub struct RotaryPositionEncoding2D<B: Backend> {
/// Cosines of the rotation angles, shape `[max_height * max_width, embed_dim / 2]`.
cos_table: Tensor<B, 2>,
/// Sines of the rotation angles, same shape as `cos_table`.
sin_table: Tensor<B, 2>,
/// Size of the feature dimension that `forward` splits into two halves and rotates.
embed_dim: usize,
}
impl<B: Backend> RotaryPositionEncoding2D<B> {
    /// Applies the rotary encoding to `x` of shape `[batch, seq_len, embed_dim]`.
    ///
    /// Sequence index `p` is rotated with row `p` of the precomputed tables, so the
    /// sequence is expected to be the row-major flattening of a grid whose width equals
    /// the `max_width` the tables were built with.
    ///
    /// The feature dimension is split into halves `x1 = x[..., ..embed_dim / 2]` and
    /// `x2 = x[..., embed_dim / 2..]`; the result is
    /// `cat(x1 * cos - x2 * sin, x1 * sin + x2 * cos)` along the last axis.
    ///
    /// # Panics
    ///
    /// Panics if the last dimension of `x` differs from `embed_dim`, or if `seq_len`
    /// exceeds the number of positions in the precomputed tables.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq_len, dim] = x.dims();
        let max_seq = self.cos_table.dims()[0];

        // Fail early with a clear message instead of silently dropping trailing features
        // or panicking inside a slice with an unrelated error.
        assert_eq!(
            dim, self.embed_dim,
            "input feature dimension {} does not match embed_dim {}",
            dim, self.embed_dim
        );
        assert!(
            seq_len <= max_seq,
            "sequence length {} exceeds the {} precomputed positions",
            seq_len,
            max_seq
        );

        let half_dim = self.embed_dim / 2;

        // Take the first `seq_len` table rows and broadcast them over the batch.
        let cos = self
            .cos_table
            .clone()
            .slice([0..seq_len, 0..half_dim])
            .unsqueeze::<3>()
            .expand([batch, seq_len, half_dim]);
        let sin = self
            .sin_table
            .clone()
            .slice([0..seq_len, 0..half_dim])
            .unsqueeze::<3>()
            .expand([batch, seq_len, half_dim]);

        let x1 = x.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
        let x2 = x.slice([0..batch, 0..seq_len, half_dim..self.embed_dim]);

        // Feature `i` of the first half is paired with feature `i` of the second half.
        let out1 = x1.clone() * cos.clone() - x2.clone() * sin.clone();
        let out2 = x1 * sin + x2 * cos;
        Tensor::cat(vec![out1, out2], 2)
    }

    /// Returns the size of the feature dimension this module was built for.
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;

    /// CPU device used by every test.
    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Builds the module under test for the given feature size and grid.
    fn build_rope(
        embed_dim: usize,
        height: usize,
        width: usize,
    ) -> RotaryPositionEncoding2D<TestBackend> {
        RotaryPositionEncoding2DConfig::new(embed_dim, height, width)
            .init::<TestBackend>(&device())
    }

    /// Sum of squared elements of `t`.
    fn squared_norm(t: Tensor<TestBackend, 3>) -> f32 {
        (t.clone() * t).sum().into_scalar().elem()
    }

    /// Standard-normal input of the given shape.
    fn random_input(shape: [usize; 3]) -> Tensor<TestBackend, 3> {
        Tensor::random(
            shape,
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        )
    }

    #[test]
    fn test_rope_output_shape() {
        let rope = build_rope(64, 14, 14);
        let input: Tensor<TestBackend, 3> = Tensor::ones([2, 196, 64], &device());
        let encoded = rope.forward(input);
        assert_eq!(encoded.dims(), [2, 196, 64]);
    }

    #[test]
    fn test_rope_preserves_norm_approximately() {
        let rope = build_rope(32, 4, 4);
        let input = random_input([1, 16, 32]);
        let x_norm = squared_norm(input.clone());
        let out_norm = squared_norm(rope.forward(input));
        let ratio = out_norm / x_norm;
        assert!(
            (ratio - 1.0).abs() < 0.01,
            "RoPE should approximately preserve norm, ratio: {ratio}"
        );
    }

    #[test]
    fn test_rope_different_positions_give_different_outputs() {
        let rope = build_rope(16, 4, 4);
        let input: Tensor<TestBackend, 3> = Tensor::ones([1, 16, 16], &device());
        let encoded = rope.forward(input);
        let first = encoded.clone().slice([0..1, 0..1, 0..16]);
        let second = encoded.clone().slice([0..1, 1..2, 0..16]);
        let diff: f32 = (first - second).abs().sum().into_scalar().elem();
        assert!(
            diff > 1e-6,
            "different positions should produce different outputs"
        );
    }

    #[test]
    fn test_rope_small_grid() {
        let rope = build_rope(8, 2, 2);
        let input: Tensor<TestBackend, 3> = Tensor::ones([1, 4, 8], &device());
        let encoded = rope.forward(input);
        assert_eq!(encoded.dims(), [1, 4, 8]);
    }

    proptest! {
        #[test]
        fn prop_rope_preserves_shape(
            grid_h in 2usize..5,
            grid_w in 2usize..5,
            embed_dim in proptest::sample::select(vec![8usize, 16, 32]),
        ) {
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;
            let input: Tensor<TestBackend, 3> = Tensor::ones([1, seq_len, embed_dim], &device());
            let encoded = rope.forward(input);
            prop_assert_eq!(encoded.dims(), [1, seq_len, embed_dim]);
        }

        #[test]
        fn prop_rope_preserves_norm(
            grid_h in 2usize..4,
            grid_w in 2usize..4,
        ) {
            let embed_dim = 16;
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;
            let input = random_input([1, seq_len, embed_dim]);
            let x_norm = squared_norm(input.clone());
            let out_norm = squared_norm(rope.forward(input));
            let ratio = out_norm / x_norm;
            prop_assert!((ratio - 1.0).abs() < 0.01, "RoPE norm ratio: {}", ratio);
        }
    }
}