burn_core/nn/rope_encoding.rs

1use core::ops::Range;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6use crate::tensor::Int;
7use crate::tensor::Tensor;
8use crate::tensor::backend::Backend;
9use alloc::vec;
10
11#[cfg(not(feature = "std"))]
12use num_traits::Float;
13
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Maximum sequence length of input.
    ///
    /// This is the number of positions whose rotary frequencies are pre-computed and cached
    /// when the module is initialized.
    pub max_sequence_length: usize,

    /// Size of the input embedding or hidden dimension. Must be even, since the dimension is
    /// rotated in consecutive (even, odd) pairs.
    pub d_model: usize,

    /// Base used to compute the per-pair rotation frequencies `1 / theta^(2i / d_model)`.
    /// Must be positive. Defaults to 10000.0
    #[config(default = "10000.0")]
    pub theta: f32,
}
27
28impl RotaryEncodingConfig {
29    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
30    ///
31    /// # Panics
32    ///
33    /// Panics if the size of input embedding dimension is not even.
34    /// Panics if the theta parameter is not positive.
35    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
36        self.initialize(|x| x, device)
37    }
38
39    /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
40    /// This is useful to apply different RoPE extensions.
41    ///
42    /// # Panics
43    ///
44    /// Panics if the size of input embedding dimension is not even.
45    /// Panics if the theta parameter is not positive.
46    pub fn init_with_frequency_scaling<B: Backend>(
47        &self,
48        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
49        device: &B::Device,
50    ) -> RotaryEncoding<B> {
51        self.initialize(scaling, device)
52    }
53
54    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
55    ///
56    /// # Panics
57    ///
58    /// Panics if the size of input embedding dimension is not even.
59    /// Panics if the theta parameter is not positive.
60    fn initialize<B: Backend>(
61        &self,
62        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
63        device: &B::Device,
64    ) -> RotaryEncoding<B> {
65        assert_eq!(
66            self.d_model % 2,
67            0,
68            "The input embedding dimension must be even"
69        );
70        assert!(
71            self.theta > 0.0,
72            "Theta parameter must be positive (default: 10000)."
73        );
74
75        // Calculate the rotation frequencies for positional embeddings based on the formula
76        // `theta = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
77        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
78            .float()
79            .div_scalar(self.d_model as f32);
80
81        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
82        // This is done since burn doesn't support exponentiation of scalar to tensor
83        let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
84
85        let theta = scaling(theta);
86
87        let freq_complex =
88            RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
89
90        RotaryEncoding {
91            freq_complex,
92            theta,
93            start_offset: 0,
94        }
95    }
96}
97
/// A module that applies rotary positional encoding to a tensor.
/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes
/// absolute positional information with rotation matrix and naturally incorporates
/// explicit relative position dependency in self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Complex frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
    // Essentially a cache of pre-computed RoPE values: entry `[p, j]` holds `[cos, sin]` of the
    // rotation angle for cached row `p` (absolute position `start_offset + p`) and hidden index `j`.
    pub freq_complex: Tensor<B, 3>,
    /// Frequency vector used to compute/apply the complex rotations: one base angular
    /// frequency per (even, odd) pair of the hidden dimension.
    pub theta: Tensor<B, 1>,
    /// Absolute position corresponding to the first row of `freq_complex`; starts at 0 and is
    /// advanced by `shift`.
    start_offset: usize,
}
116
117impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
118    fn custom_settings(&self) -> Option<DisplaySettings> {
119        DisplaySettings::new()
120            .with_new_line_after_attribute(false)
121            .optional()
122    }
123
124    fn custom_content(&self, content: Content) -> Option<Content> {
125        let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
126        content
127            .add("d_model", &d_model)
128            .add("max_sequence_length", &max_sequence_length)
129            .optional()
130    }
131}
132
133#[allow(clippy::single_range_in_vec_init)]
134impl<B: Backend> RotaryEncoding<B> {
135    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
136    ///
137    /// # Arguments:
138    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
139    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
140    ///   respectively.
141    ///
142    /// # Returns:
143    /// Output tensor with the same shape as input tensor after applying rotary encoding.
144    ///
145    /// # Panics
146    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
147    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
148        self.apply(x, 0)
149    }
150
151    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
152    ///
153    /// # Arguments:
154    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
155    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
156    ///   respectively.
157    /// * `start` - Sequence start position index.
158    ///
159    /// # Returns:
160    /// Output tensor with the same shape as input tensor after applying rotary encoding.
161    ///
162    /// # Panics
163    /// If the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
164    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
165        assert!(
166            D >= 2,
167            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
168        );
169
170        let device = x.device();
171        let input_shape = x.shape();
172
173        // Extract the sequence length and embedding dimension, other dimensions are kept generic
174        // to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)
175        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
176        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
177
178        // Create a dummy tensor with signed ones based on the 2D rotation matrix
179        // [[cos, -sin], [sin, cos]]
180        let sign_tensor =
181            Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
182
183        // Rotate input using the frequency tensor. Slice the frequencies till input sequence length
184        let out: Tensor<B, 4> = x
185            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
186            .matmul(sign_tensor.unsqueeze())
187            .reshape([dummy_dim_size, seq_len, d_model, 2])
188            * self
189                .freq_complex
190                .clone()
191                .slice([start..start + seq_len])
192                .unsqueeze();
193
194        // Sum the real and imaginary components to get output tensor and reshape to original shape
195        out.sum_dim(D - 1).reshape(input_shape)
196    }
197
198    /// Shifts the pre-computed rotary frequency to cover a new range of positions.
199    ///
200    /// This method updates the internal frequency tensor `freq_complex` to store
201    /// the rotary positional encodings for a new window of positions starting at `start`.
202    pub fn shift(&mut self, start: usize) {
203        let max_seq_len = self.freq_complex.dims()[0];
204        assert!(
205            start > self.start_offset,
206            "Shift start position must be monotonically increasing"
207        );
208
209        let current_end = self.start_offset + max_seq_len;
210
211        if start >= current_end {
212            // Overwrite the whole buffer
213            let new_freqs =
214                Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
215            self.freq_complex
216                .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
217        } else {
218            // Shift the tail
219            let num_keep = current_end - start;
220            let start_rel = start - self.start_offset;
221            let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
222            self.freq_complex
223                .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
224            // Compute the rest and assign
225            let new_freqs = Self::compute_rotary_frequencies(
226                current_end..start + max_seq_len,
227                self.theta.clone(),
228            );
229            self.freq_complex
230                .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
231        }
232        self.start_offset = start;
233    }
234
235    /// Computes the positional rotation frequencies (cosine and sine values) used in RoPE.
236    ///
237    /// # Arguments
238    /// - `range`: Range of position indices `[start, end)`.
239    /// - `theta`: 1D tensor of shape `(d_model / 2)` containing base angular frequencies.
240    ///
241    /// # Returns
242    /// Tensor of shape `(range.len(), d_model, 2)` containing `[cos, sin]` pairs for each position and frequency.
243    fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
244        let d_model = theta.dims()[0] * 2;
245        let num_positions = range.end - range.start;
246
247        // Generate frequency values for positional embeddings
248        let frequencies: Tensor<B, 2> =
249            Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
250                .float()
251                .unsqueeze()
252                .transpose()
253                .repeat_dim(1, d_model / 2)
254                * theta.unsqueeze();
255
256        // Convert frequency values to complex numbers (polar form)
257        let p_cos = frequencies.clone().cos();
258        let p_sin = frequencies.sin();
259
260        Tensor::cat(vec![p_cos, p_sin], 1)
261            .reshape([num_positions, 2, d_model / 2])
262            .transpose()
263            .unsqueeze_dim::<4>(2)
264            .repeat_dim(2, 2)
265            .reshape([num_positions, d_model, 2])
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Shared 4D input of shape `[1, 2, 2, 4]` (batch size, num heads, seq_len, d_model) used by
    /// several tests.
    fn sample_input(device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 4> {
        Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            device,
        )
        .unsqueeze::<4>()
    }

    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = sample_input(&device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .squeeze::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
            ],
            &device,
        );

        output
            .squeeze::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    #[test]
    #[should_panic = "The input embedding dimension must be even"]
    fn test_valid_input_hidden_dim() {
        // Hidden dimension must be even to be able to split into real and imaginary components
        // for rotation. Initialization must reject it with the dedicated message, not some
        // unrelated panic further down.
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn test_rotary_encoding_frequencies() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4147e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [9.9999e-01, 9.9999e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
        // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
        let scale_factor = 8.;
        let low_freq_factor = 1.;
        let high_freq_factor = 4.;
        let old_context_len = 8192.;

        let low_freq_wavelen = old_context_len / low_freq_factor;
        let high_freq_wavelen = old_context_len / high_freq_factor;

        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

        // if wavelen >= high_freq_wavelen
        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
        let smooth = wavelen
            .clone()
            .recip()
            .mul_scalar(old_context_len)
            .sub_scalar(low_freq_factor)
            .div_scalar(high_freq_factor - low_freq_factor);
        // (1 - smooth) * freq / scale_factor + smooth * freq
        let new_freqs = smooth
            .clone()
            .neg()
            .add_scalar(1.)
            .mul(freqs.clone().div_scalar(scale_factor))
            .add(smooth.clone().mul(freqs.clone()));
        let new_freqs = freqs.clone().mask_where(cond, new_freqs);

        // if wavelen > low_freq_wavelen
        let cond = wavelen.clone().greater_elem(low_freq_wavelen);
        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));

        // if wavelen < high_freq_wavelen
        let cond = wavelen.lower_elem(high_freq_wavelen);
        new_freqs.mask_where(cond, freqs)
    }

    #[test]
    fn test_rotary_encoding_with_frequency_scaling() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8)
            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4148e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [1.0000, 2.1361e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift_full() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = sample_input(&device);

        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
        // initial position
        let expected_output = rotary_encoding.apply(input.clone(), 6);

        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6); // start >= 4 (current cache end) will perform a full re-compute

        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Input = [Batch size, Num of heads, Seq_len, d_model]
        let input = sample_input(&device);

        // Initializing for a bigger cache (e.g., max_seq_len = 10) should give the same result
        // as using a smaller cache of pre-computed RoPE frequencies that are shifted to the same
        // initial position
        let expected_output = rotary_encoding.apply(input.clone(), 2);

        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2); // start < 4 will shift the (current_end - start) freqs and compute the rest

        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift_multiple() {
        let device = Default::default();
        let input = sample_input(&device);

        // Reference: a cache large enough to cover every absolute position used below.
        let reference = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
        let expected_at_5 = reference.apply(input.clone(), 5);
        let expected_at_7 = reference.apply(input.clone(), 7);

        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2); // partial shift: cache now covers positions 2..6
        rotary_encoding.shift(5); // partial shift again: cache now covers positions 5..9

        rotary_encoding
            .apply(input.clone(), 0)
            .into_data()
            .assert_approx_eq::<FT>(&expected_at_5.into_data(), Tolerance::default());
        rotary_encoding
            .apply(input, 2)
            .into_data()
            .assert_approx_eq::<FT>(&expected_at_7.into_data(), Tolerance::default());
    }

    #[test]
    #[should_panic = "Shift start position must be monotonically increasing"]
    fn test_rotary_encoding_shift_should_increase() {
        let device = Default::default();
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6);
        rotary_encoding.shift(4); // should be monotonically increasing
    }

    #[test]
    fn display() {
        let config = RotaryEncodingConfig::new(10, 4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "RotaryEncoding {d_model: 4, max_sequence_length: 10}"
        );
    }
}