1use burn::prelude::*;
28use burn::tensor::backend::Backend;
29
30#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct RotaryPositionEncoding2DConfig {
48 pub embed_dim: usize,
50 pub max_height: usize,
52 pub max_width: usize,
54 pub base_freq: f64,
56}
57
58impl RotaryPositionEncoding2DConfig {
59 pub fn new(embed_dim: usize, max_height: usize, max_width: usize) -> Self {
61 Self {
62 embed_dim,
63 max_height,
64 max_width,
65 base_freq: 10000.0,
66 }
67 }
68
69 pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryPositionEncoding2D<B> {
71 let half_dim = self.embed_dim / 2;
72 let quarter_dim = half_dim / 2;
73 let max_seq = self.max_height * self.max_width;
74
75 let mut freqs_data = Vec::with_capacity(quarter_dim);
77 for i in 0..quarter_dim {
78 let freq = 1.0 / self.base_freq.powf(2.0 * i as f64 / half_dim as f64);
79 freqs_data.push(freq as f32);
80 }
81
82 let mut cos_data = vec![0.0f32; max_seq * half_dim];
84 let mut sin_data = vec![0.0f32; max_seq * half_dim];
85
86 for row in 0..self.max_height {
87 for col in 0..self.max_width {
88 let pos = row * self.max_width + col;
89 for (i, &freq) in freqs_data.iter().enumerate() {
91 let angle = row as f64 * freq as f64;
92 cos_data[pos * half_dim + i] = angle.cos() as f32;
93 sin_data[pos * half_dim + i] = angle.sin() as f32;
94 }
95 for (i, &freq) in freqs_data.iter().enumerate() {
97 let angle = col as f64 * freq as f64;
98 cos_data[pos * half_dim + quarter_dim + i] = angle.cos() as f32;
99 sin_data[pos * half_dim + quarter_dim + i] = angle.sin() as f32;
100 }
101 }
102 }
103
104 let cos_table = Tensor::from_floats(
105 burn::tensor::TensorData::new(cos_data, [max_seq, half_dim]),
106 device,
107 );
108 let sin_table = Tensor::from_floats(
109 burn::tensor::TensorData::new(sin_data, [max_seq, half_dim]),
110 device,
111 );
112
113 RotaryPositionEncoding2D {
114 cos_table,
115 sin_table,
116 embed_dim: self.embed_dim,
117 }
118 }
119}
120
121#[derive(Module, Debug)]
126pub struct RotaryPositionEncoding2D<B: Backend> {
127 cos_table: Tensor<B, 2>,
129 sin_table: Tensor<B, 2>,
131 embed_dim: usize,
133}
134
135impl<B: Backend> RotaryPositionEncoding2D<B> {
136 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
144 let [batch, seq_len, _dim] = x.dims();
145 let half_dim = self.embed_dim / 2;
146
147 let cos = self.cos_table.clone().slice([0..seq_len, 0..half_dim]); let sin = self.sin_table.clone().slice([0..seq_len, 0..half_dim]); let cos = cos.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
153 let sin = sin.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
154
155 let x1 = x.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
157 let x2 = x
158 .clone()
159 .slice([0..batch, 0..seq_len, half_dim..self.embed_dim]);
160
161 let out1 = x1.clone() * cos.clone() - x2.clone() * sin.clone();
163 let out2 = x1 * sin + x2 * cos;
164
165 Tensor::cat(vec![out1, out2], 2)
166 }
167
168 pub fn embed_dim(&self) -> usize {
170 self.embed_dim
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, ElementConversion};
    use burn_ndarray::{NdArray, NdArrayDevice};
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;

    /// CPU device shared by every test.
    fn device() -> NdArrayDevice {
        NdArrayDevice::Cpu
    }

    /// Builds the module under test for the given embedding size and grid.
    fn build_rope(
        embed_dim: usize,
        height: usize,
        width: usize,
    ) -> RotaryPositionEncoding2D<TestBackend> {
        RotaryPositionEncoding2DConfig::new(embed_dim, height, width).init::<TestBackend>(&device())
    }

    /// Sum of squared elements of `t`.
    fn squared_sum(t: &Tensor<TestBackend, 3>) -> f32 {
        (t.clone() * t.clone()).sum().into_scalar().elem()
    }

    /// Standard-normal random input of the given shape.
    fn random_input(shape: [usize; 3]) -> Tensor<TestBackend, 3> {
        Tensor::random(shape, Distribution::Normal(0.0, 1.0), &device())
    }

    #[test]
    fn test_rope_output_shape() {
        let rope = build_rope(64, 14, 14);

        let input: Tensor<TestBackend, 3> = Tensor::ones([2, 196, 64], &device());
        let output = rope.forward(input);

        assert_eq!(output.dims(), [2, 196, 64]);
    }

    #[test]
    fn test_rope_preserves_norm_approximately() {
        let rope = build_rope(32, 4, 4);

        let input = random_input([1, 16, 32]);
        let x_norm = squared_sum(&input);

        let output = rope.forward(input);
        let out_norm = squared_sum(&output);

        let ratio = out_norm / x_norm;
        assert!(
            (ratio - 1.0).abs() < 0.01,
            "RoPE should approximately preserve norm, ratio: {ratio}"
        );
    }

    #[test]
    fn test_rope_different_positions_give_different_outputs() {
        let rope = build_rope(16, 4, 4);

        // Identical features at every position: any difference comes from the encoding.
        let input: Tensor<TestBackend, 3> = Tensor::ones([1, 16, 16], &device());
        let output = rope.forward(input);

        let at_first = output.clone().slice([0..1, 0..1, 0..16]);
        let at_second = output.clone().slice([0..1, 1..2, 0..16]);

        let diff: f32 = (at_first - at_second).abs().sum().into_scalar().elem();
        assert!(
            diff > 1e-6,
            "different positions should produce different outputs"
        );
    }

    #[test]
    fn test_rope_small_grid() {
        let rope = build_rope(8, 2, 2);

        let input: Tensor<TestBackend, 3> = Tensor::ones([1, 4, 8], &device());
        let output = rope.forward(input);

        assert_eq!(output.dims(), [1, 4, 8]);
    }

    proptest! {
        #[test]
        fn prop_rope_preserves_shape(
            grid_h in 2usize..5,
            grid_w in 2usize..5,
            embed_dim in proptest::sample::select(vec![8usize, 16, 32]),
        ) {
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;

            let input: Tensor<TestBackend, 3> = Tensor::ones([1, seq_len, embed_dim], &device());
            let output = rope.forward(input);

            prop_assert_eq!(output.dims(), [1, seq_len, embed_dim]);
        }

        #[test]
        fn prop_rope_preserves_norm(
            grid_h in 2usize..4,
            grid_w in 2usize..4,
        ) {
            let embed_dim = 16;
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;

            let input = random_input([1, seq_len, embed_dim]);
            let x_norm = squared_sum(&input);

            let output = rope.forward(input);
            let out_norm = squared_sum(&output);

            let ratio = out_norm / x_norm;
            prop_assert!((ratio - 1.0).abs() < 0.01, "RoPE norm ratio: {}", ratio);
        }
    }
}