//! Rotary positional encoding (RoPE) layer.
//!
//! File: burn_nn/modules/rope_encoding.rs

1use burn_core as burn;
2
3use alloc::vec;
4use burn::config::Config;
5use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
6use burn::tensor::Int;
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use core::ops::Range;
10
11#[cfg(not(feature = "std"))]
12#[allow(unused_imports)]
13use num_traits::Float as _;
14
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Maximum sequence length of input
    pub max_sequence_length: usize,

    /// Size of the input embedding or hidden dimension.
    /// Must be even, since RoPE rotates pairs of components (`init` panics otherwise).
    pub d_model: usize,

    /// Scaling factor for frequency computation. Defaults to 10000.0
    /// Must be positive; a larger theta yields lower base rotation frequencies
    /// (frequencies are computed as `1 / theta^(2i / d_model)`).
    #[config(default = "10000.0")]
    pub theta: f32,
}
28
29impl RotaryEncodingConfig {
30    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
31    ///
32    /// # Panics
33    ///
34    /// Panics if the size of input embedding dimension is not even.
35    /// Panics if the theta parameter is not positive.
36    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
37        self.initialize(|x| x, device)
38    }
39
40    /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
41    /// This is useful to apply different RoPE extensions.
42    ///
43    /// # Panics
44    ///
45    /// Panics if the size of input embedding dimension is not even.
46    /// Panics if the theta parameter is not positive.
47    pub fn init_with_frequency_scaling<B: Backend>(
48        &self,
49        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
50        device: &B::Device,
51    ) -> RotaryEncoding<B> {
52        self.initialize(scaling, device)
53    }
54
55    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
56    ///
57    /// # Panics
58    ///
59    /// Panics if the size of input embedding dimension is not even.
60    /// Panics if the theta parameter is not positive.
61    fn initialize<B: Backend>(
62        &self,
63        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
64        device: &B::Device,
65    ) -> RotaryEncoding<B> {
66        assert_eq!(
67            self.d_model % 2,
68            0,
69            "The input embedding dimension must be even"
70        );
71        assert!(
72            self.theta > 0.0,
73            "Theta parameter must be positive (default: 10000)."
74        );
75
76        // Calculate the rotation frequencies for positional embeddings based on the formula
77        // `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
78        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
79            .float()
80            .div_scalar(self.d_model as f32);
81
82        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
83        // This is done since burn doesn't support exponentiation of scalar to tensor
84        let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
85
86        let theta = scaling(theta);
87
88        let freq_complex =
89            RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
90
91        RotaryEncoding {
92            freq_complex,
93            theta,
94            start_offset: 0,
95        }
96    }
97}
98
/// A module that applies rotary positional encoding to a tensor.
/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes
/// absolute positional information with rotation matrix and naturally incorporates
/// explicit relative position dependency in self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
    // Essentially a cache of pre-computed RoPE values.
    pub freq_complex: Tensor<B, 3>,
    /// Frequency vector used to compute/apply the complex rotations.
    /// Shape (d_model / 2); kept so `shift` can recompute cache entries.
    pub theta: Tensor<B, 1>,
    // Absolute position of the first cached entry in `freq_complex`;
    // starts at 0 and is advanced monotonically by `shift`.
    start_offset: usize,
}
117
118impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
119    fn custom_settings(&self) -> Option<DisplaySettings> {
120        DisplaySettings::new()
121            .with_new_line_after_attribute(false)
122            .optional()
123    }
124
125    fn custom_content(&self, content: Content) -> Option<Content> {
126        let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
127        content
128            .add("d_model", &d_model)
129            .add("max_sequence_length", &max_sequence_length)
130            .optional()
131    }
132}
133
134#[allow(clippy::single_range_in_vec_init)]
135impl<B: Backend> RotaryEncoding<B> {
136    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
137    ///
138    /// # Arguments:
139    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
140    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
141    ///   respectively.
142    ///
143    /// # Returns:
144    /// Output tensor with the same shape as input tensor after applying rotary encoding.
145    ///
146    /// # Panics
147    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
148    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
149        self.apply(x, 0)
150    }
151
152    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
153    ///
154    /// # Arguments:
155    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
156    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
157    ///   respectively.
158    /// * `start` - Sequence start position index.
159    ///
160    /// # Returns:
161    /// Output tensor with the same shape as input tensor after applying rotary encoding.
162    ///
163    /// # Panics
164    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
165    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
166        assert!(
167            D >= 2,
168            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
169        );
170
171        let device = x.device();
172        let input_shape = x.shape();
173
174        // Extract the sequence length and embedding dimension, other dimensions are kept generic
175        // to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)
176        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
177        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
178
179        // Create a dummy tensor with signed ones based on the 2D rotation matrix
180        // [[cos, -sin], [sin, cos]]
181        let sign_tensor =
182            Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
183
184        // Rotate input using the frequency tensor. Slice the frequencies till input sequence length
185        let out: Tensor<B, 4> = x
186            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
187            .matmul(sign_tensor.unsqueeze())
188            .reshape([dummy_dim_size, seq_len, d_model, 2])
189            * self
190                .freq_complex
191                .clone()
192                .slice([start..start + seq_len])
193                .unsqueeze();
194
195        // Sum the real and imaginary components to get output tensor and reshape to original shape
196        out.sum_dim(-1).reshape(input_shape)
197    }
198
199    /// Shifts the pre-computed rotary frequency to cover a new range of positions.
200    ///
201    /// This method updates the internal frequency tensor `freq_complex` to store
202    /// the rotary positional encodings for a new window of positions starting at `start`.
203    pub fn shift(&mut self, start: usize) {
204        let max_seq_len = self.freq_complex.dims()[0];
205        assert!(
206            start > self.start_offset,
207            "Shift start position must be monotonically increasing"
208        );
209
210        let current_end = self.start_offset + max_seq_len;
211
212        if start >= current_end {
213            // Overwrite the whole buffer
214            let new_freqs =
215                Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
216            self.freq_complex
217                .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
218        } else {
219            // Shift the tail
220            let num_keep = current_end - start;
221            let start_rel = start - self.start_offset;
222            let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
223            self.freq_complex
224                .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
225            // Compute the rest and assign
226            let new_freqs = Self::compute_rotary_frequencies(
227                current_end..start + max_seq_len,
228                self.theta.clone(),
229            );
230            self.freq_complex
231                .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
232        }
233        self.start_offset = start;
234    }
235
236    /// Computes the positional rotation frequencies (cosine and sine values) used in RoPE.
237    ///
238    /// # Arguments
239    /// - `range`: Range of position indices `[start, end)`.
240    /// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies.
241    ///
242    /// # Returns
243    /// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency.
244    fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
245        let d_model = theta.dims()[0] * 2;
246        let num_positions = range.end - range.start;
247
248        // Generate frequency values for positional embeddings
249        let frequencies: Tensor<B, 2> =
250            Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
251                .float()
252                .unsqueeze()
253                .transpose()
254                .repeat_dim(1, d_model / 2)
255                * theta.unsqueeze();
256
257        // Convert frequency values to complex numbers (polar form)
258        let p_cos = frequencies.clone().cos();
259        let p_sin = frequencies.sin();
260
261        Tensor::cat(vec![p_cos, p_sin], 1)
262            .reshape([num_positions, 2, d_model / 2])
263            .transpose()
264            .unsqueeze_dim::<4>(2)
265            .repeat_dim(2, 2)
266            .reshape([num_positions, d_model, 2])
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Forward pass on a 4D (batch, heads, seq, d_model) input; expected values
    // are pre-computed reference outputs for positions 0..2.
    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        );

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = input.unsqueeze::<4>();

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .squeeze_dim::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    // Same computation as above but on the raw 3D (batch, seq, d_model) input;
    // forward must handle both ranks and produce identical values.
    #[test]
    fn test_rotary_encoding_3d() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        );

        // Intentionally NOT unsqueezed to 4D (contrast with test_rotary_encoding_forward):
        // Input = [Batch size, Num of heads, Seq_len, d_model]
        // let input = input.unsqueeze::<4>();

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    // The rotation is linear in the input, so an all-zero input must map to an
    // all-zero output regardless of position.
    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
            ],
            &device,
        );

        output
            .squeeze_dim::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    // An odd d_model must panic — the panic fires in init (d_model % 2 assert),
    // before forward is ever reached.
    #[test]
    #[should_panic]
    fn test_valid_input_hidden_dim() {
        // Hidden dimension must be even to be able to split into real and imaginary components
        // for rotation
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }

    // Checks the pre-computed cache directly: position 0 is the identity rotation
    // ([cos, sin] = [1, 0]); position 1 holds [cos(theta_i), sin(theta_i)].
    // The unsqueeze/repeat/reshape mirrors the interleaving in compute_rotary_frequencies.
    #[test]
    fn test_rotary_encoding_frequencies() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4147e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [9.9999e-01, 9.9999e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    // Llama 3-style piecewise frequency scaling used as a custom scaling function
    // in test_rotary_encoding_with_frequency_scaling.
    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
        // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
        let scale_factor = 8.;
        let low_freq_factor = 1.;
        let high_freq_factor = 4.;
        let old_context_len = 8192.;

        let low_freq_wavelen = old_context_len / low_freq_factor;
        let high_freq_wavelen = old_context_len / high_freq_factor;

        // wavelen = 2*pi / freq
        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

        // if wavelen >= high_freq_wavelen
        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
        let smooth = wavelen
            .clone()
            .recip()
            .mul_scalar(old_context_len)
            .sub_scalar(low_freq_factor)
            .div_scalar(high_freq_factor - low_freq_factor);
        // (1 - smooth) * freq / scale_factor + smooth * freq
        let new_freqs = smooth
            .clone()
            .neg()
            .add_scalar(1.)
            .mul(freqs.clone().div_scalar(scale_factor))
            .add(smooth.clone().mul(freqs.clone()));
        let new_freqs = freqs.clone().mask_where(cond, new_freqs);

        // if wavelen > low_freq_wavelen
        let cond = wavelen.clone().greater_elem(low_freq_wavelen);
        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));

        // if wavelen < high_freq_wavelen
        let cond = wavelen.lower_elem(high_freq_wavelen);
        new_freqs.mask_where(cond, freqs)
    }

    // Same cache check as test_rotary_encoding_frequencies, but with the scaling
    // function applied; only the lowest-frequency entry changes noticeably.
    #[test]
    fn test_rotary_encoding_with_frequency_scaling() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8)
            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4148e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [1.0000, 2.1361e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    // shift with start >= current window end: the whole cache is recomputed.
    #[test]
    fn test_rotary_encoding_shift_full() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        )
        .unsqueeze::<4>();

        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
        // initial position
        let expected_output = rotary_encoding.apply(input.clone(), 6);

        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6); // start > 4 will perform a full re-compute

        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    // shift with start inside the current window: overlapping entries are moved,
    // only the missing tail is recomputed.
    #[test]
    fn test_rotary_encoding_shift() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        )
        .unsqueeze::<4>();

        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
        // initial position
        let expected_output = rotary_encoding.apply(input.clone(), 2);

        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2); // start < 4 will shift the (current_end - start) freqs and compute the rest

        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    // Consecutive increasing shifts (partial overlap then overlap with the
    // shifted window) must not panic.
    #[test]
    fn test_rotary_encoding_shift_multiple() {
        let device = Default::default();
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2);
        rotary_encoding.shift(5);
    }

    // Shifting backwards violates the monotonicity assert in shift.
    #[test]
    #[should_panic = "Shift start position must be monotonically increasing"]
    fn test_rotary_encoding_shift_should_increase() {
        let device = Default::default();
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6);
        rotary_encoding.shift(4); // should be monotonically increasing
    }

    // ModuleDisplay output: d_model and max_sequence_length are recovered from
    // the freq_complex cache shape.
    #[test]
    fn display() {
        let config = RotaryEncodingConfig::new(10, 4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "RotaryEncoding {d_model: 4, max_sequence_length: 10}"
        );
    }
}