Skip to main content

jepa_vision/
rope.rs

1//! Rotary Position Embedding (RoPE) for 2D spatial positions.
2//!
3//! Implements position encoding from RFC-002 (Encoder Module).
4//!
5//! Rotary Position Embedding encodes absolute token positions by rotating
6//! query and key vectors in attention. Unlike learned positional
7//! embeddings, RoPE is parameter-free, extrapolates to unseen lengths,
8//! and makes relative distances naturally emerge from the dot product.
9//!
10//! For 2D images the embedding dimension is split into two halves:
11//! one half encodes the row position, the other half encodes the column
12//! position, giving each patch a unique spatial signature.
13//!
14//! ```text
//! half_dim = [── height freqs ──|── width freqs ──]   (half_dim = embed_dim / 2)
//!             quarter_dim         quarter_dim
17//! ```
18//!
19//! Sin/cos tables are **precomputed** at init time for a fixed maximum
20//! grid size, then sliced to `seq_len` at forward time.
21//!
22//! For 3D video RoPE, see [`crate::video`].
23//!
24//! Reference: Su et al. (2021), *RoFormer: Enhanced Transformer with
25//! Rotary Position Embedding*.
26
27use burn::prelude::*;
28use burn::tensor::backend::Backend;
29
30/// Configuration for 2D Rotary Position Embedding.
31///
32/// # Example
33///
34/// ```
35/// use jepa_vision::rope::RotaryPositionEncoding2DConfig;
36/// use burn_ndarray::NdArray;
37/// use burn::prelude::*;
38///
39/// type B = NdArray<f32>;
40/// let device = burn_ndarray::NdArrayDevice::Cpu;
41///
42/// let config = RotaryPositionEncoding2DConfig::new(64, 14, 14);
43/// let rope = config.init::<B>(&device);
44/// assert_eq!(rope.embed_dim(), 64);
45/// ```
46#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
47pub struct RotaryPositionEncoding2DConfig {
48    /// Embedding dimension (must be even for rotation pairs).
49    pub embed_dim: usize,
50    /// Maximum grid height (number of patch rows).
51    pub max_height: usize,
52    /// Maximum grid width (number of patch columns).
53    pub max_width: usize,
54    /// Base frequency for the sinusoidal encoding (default: 10000.0).
55    pub base_freq: f64,
56}
57
58impl RotaryPositionEncoding2DConfig {
59    /// Create a new config.
60    pub fn new(embed_dim: usize, max_height: usize, max_width: usize) -> Self {
61        Self {
62            embed_dim,
63            max_height,
64            max_width,
65            base_freq: 10000.0,
66        }
67    }
68
69    /// Initialize the position encoding, precomputing sin/cos tables.
70    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryPositionEncoding2D<B> {
71        let half_dim = self.embed_dim / 2;
72        let quarter_dim = half_dim / 2;
73        let max_seq = self.max_height * self.max_width;
74
75        // Compute frequency bands: freq_i = 1 / (base ^ (2i / dim))
76        let mut freqs_data = Vec::with_capacity(quarter_dim);
77        for i in 0..quarter_dim {
78            let freq = 1.0 / self.base_freq.powf(2.0 * i as f64 / half_dim as f64);
79            freqs_data.push(freq as f32);
80        }
81
82        // Build position-frequency tables for height and width
83        let mut cos_data = vec![0.0f32; max_seq * half_dim];
84        let mut sin_data = vec![0.0f32; max_seq * half_dim];
85
86        for row in 0..self.max_height {
87            for col in 0..self.max_width {
88                let pos = row * self.max_width + col;
89                // First quarter_dim: height frequencies
90                for (i, &freq) in freqs_data.iter().enumerate() {
91                    let angle = row as f64 * freq as f64;
92                    cos_data[pos * half_dim + i] = angle.cos() as f32;
93                    sin_data[pos * half_dim + i] = angle.sin() as f32;
94                }
95                // Second quarter_dim: width frequencies
96                for (i, &freq) in freqs_data.iter().enumerate() {
97                    let angle = col as f64 * freq as f64;
98                    cos_data[pos * half_dim + quarter_dim + i] = angle.cos() as f32;
99                    sin_data[pos * half_dim + quarter_dim + i] = angle.sin() as f32;
100                }
101            }
102        }
103
104        let cos_table = Tensor::from_floats(
105            burn::tensor::TensorData::new(cos_data, [max_seq, half_dim]),
106            device,
107        );
108        let sin_table = Tensor::from_floats(
109            burn::tensor::TensorData::new(sin_data, [max_seq, half_dim]),
110            device,
111        );
112
113        RotaryPositionEncoding2D {
114            cos_table,
115            sin_table,
116            embed_dim: self.embed_dim,
117        }
118    }
119}
120
121/// 2D Rotary Position Embedding.
122///
123/// Applies rotary encoding to query/key tensors by rotating pairs of
124/// dimensions using precomputed sin/cos tables derived from 2D grid positions.
125#[derive(Module, Debug)]
126pub struct RotaryPositionEncoding2D<B: Backend> {
127    /// Precomputed cosine table. Shape: `[max_seq, half_dim]`
128    cos_table: Tensor<B, 2>,
129    /// Precomputed sine table. Shape: `[max_seq, half_dim]`
130    sin_table: Tensor<B, 2>,
131    /// Full embedding dimension.
132    embed_dim: usize,
133}
134
135impl<B: Backend> RotaryPositionEncoding2D<B> {
136    /// Apply rotary encoding to a tensor.
137    ///
138    /// # Arguments
139    /// * `x` - Input tensor. Shape: `[batch, seq_len, embed_dim]`
140    ///
141    /// # Returns
142    /// Rotated tensor with position information encoded. Same shape as input.
143    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
144        let [batch, seq_len, _dim] = x.dims();
145        let half_dim = self.embed_dim / 2;
146
147        // Slice cos/sin tables to current seq_len
148        let cos = self.cos_table.clone().slice([0..seq_len, 0..half_dim]); // [seq_len, half_dim]
149        let sin = self.sin_table.clone().slice([0..seq_len, 0..half_dim]); // [seq_len, half_dim]
150
151        // Unsqueeze for broadcasting over batch: [1, seq_len, half_dim]
152        let cos = cos.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
153        let sin = sin.unsqueeze::<3>().expand([batch, seq_len, half_dim]);
154
155        // Split x into two halves
156        let x1 = x.clone().slice([0..batch, 0..seq_len, 0..half_dim]);
157        let x2 = x
158            .clone()
159            .slice([0..batch, 0..seq_len, half_dim..self.embed_dim]);
160
161        // Apply rotation: [x1 * cos - x2 * sin, x1 * sin + x2 * cos]
162        let out1 = x1.clone() * cos.clone() - x2.clone() * sin.clone();
163        let out2 = x1 * sin + x2 * cos;
164
165        Tensor::cat(vec![out1, out2], 2)
166    }
167
168    /// Get the embedding dimension.
169    pub fn embed_dim(&self) -> usize {
170        self.embed_dim
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;
    type Tensor3 = Tensor<TestBackend, 3>;

    /// CPU device shared by every test.
    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Build a RoPE module for the given embedding size and maximum grid.
    fn build_rope(
        embed_dim: usize,
        height: usize,
        width: usize,
    ) -> RotaryPositionEncoding2D<TestBackend> {
        RotaryPositionEncoding2DConfig::new(embed_dim, height, width)
            .init::<TestBackend>(&device())
    }

    /// Sum of squared elements of `t`.
    fn squared_norm(t: &Tensor3) -> f32 {
        (t.clone() * t.clone()).sum().into_scalar().elem::<f32>()
    }

    /// Tensor of the given shape filled with standard-normal samples.
    fn gaussian(shape: [usize; 3]) -> Tensor3 {
        Tensor::random(
            shape,
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        )
    }

    #[test]
    fn test_rope_output_shape() {
        let rope = build_rope(64, 14, 14);
        let input: Tensor3 = Tensor::ones([2, 196, 64], &device());

        assert_eq!(rope.forward(input).dims(), [2, 196, 64]);
    }

    #[test]
    fn test_rope_preserves_norm_approximately() {
        // A rotation must leave the overall vector norm (almost) untouched.
        let rope = build_rope(32, 4, 4);
        let input = gaussian([1, 16, 32]);

        let x_norm = squared_norm(&input);
        let out_norm = squared_norm(&rope.forward(input));

        let ratio = out_norm / x_norm;
        assert!(
            (ratio - 1.0).abs() < 0.01,
            "RoPE should approximately preserve norm, ratio: {ratio}"
        );
    }

    #[test]
    fn test_rope_different_positions_give_different_outputs() {
        let rope = build_rope(16, 4, 4);

        // Feed the identical vector at every position.
        let input: Tensor3 = Tensor::ones([1, 16, 16], &device());
        let out = rope.forward(input);

        // Compare the encodings of the first two positions.
        let first = out.clone().slice([0..1, 0..1, 0..16]);
        let second = out.slice([0..1, 1..2, 0..16]);

        // Position encoding must make them diverge.
        let diff: f32 = (first - second).abs().sum().into_scalar().elem();
        assert!(
            diff > 1e-6,
            "different positions should produce different outputs"
        );
    }

    #[test]
    fn test_rope_small_grid() {
        let rope = build_rope(8, 2, 2);
        let input: Tensor3 = Tensor::ones([1, 4, 8], &device());

        assert_eq!(rope.forward(input).dims(), [1, 4, 8]);
    }

    proptest! {
        #[test]
        fn prop_rope_preserves_shape(
            grid_h in 2usize..5,
            grid_w in 2usize..5,
            embed_dim in proptest::sample::select(vec![8usize, 16, 32]),
        ) {
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;
            let input: Tensor3 = Tensor::ones([1, seq_len, embed_dim], &device());

            prop_assert_eq!(rope.forward(input).dims(), [1, seq_len, embed_dim]);
        }

        #[test]
        fn prop_rope_preserves_norm(
            grid_h in 2usize..4,
            grid_w in 2usize..4,
        ) {
            let embed_dim = 16;
            let rope = build_rope(embed_dim, grid_h, grid_w);
            let seq_len = grid_h * grid_w;
            let input = gaussian([1, seq_len, embed_dim]);

            let x_norm = squared_norm(&input);
            let out_norm = squared_norm(&rope.forward(input));

            let ratio = out_norm / x_norm;
            prop_assert!((ratio - 1.0).abs() < 0.01, "RoPE norm ratio: {}", ratio);
        }
    }
}