burn_core/nn/
rope_encoding.rs

1use crate as burn;
2use crate::config::Config;
3use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
4use crate::tensor::Int;
5use crate::tensor::Tensor;
6use crate::tensor::backend::Backend;
7use alloc::vec;
8
9#[cfg(not(feature = "std"))]
10use num_traits::Float;
11
12/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
13#[derive(Config, Debug)]
14pub struct RotaryEncodingConfig {
15    /// Maximum sequence length of input
16    pub max_sequence_length: usize,
17
18    /// Size of the input embedding or hidden dimension
19    pub d_model: usize,
20
21    /// Scaling factor for frequency computation. Defaults to 10000.0
22    #[config(default = "10000.0")]
23    pub theta: f32,
24}
25
26impl RotaryEncodingConfig {
27    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
28    ///
29    /// # Panics
30    ///
31    /// Panics if the size of input embedding dimension is not even.
32    /// Panics if the theta parameter is not positive.
33    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
34        self.initialize(|x| x, device)
35    }
36
37    /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
38    /// This is useful to apply different RoPE extensions.
39    ///
40    /// # Panics
41    ///
42    /// Panics if the size of input embedding dimension is not even.
43    /// Panics if the theta parameter is not positive.
44    pub fn init_with_frequency_scaling<B: Backend>(
45        &self,
46        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
47        device: &B::Device,
48    ) -> RotaryEncoding<B> {
49        self.initialize(scaling, device)
50    }
51
52    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
53    ///
54    /// # Panics
55    ///
56    /// Panics if the size of input embedding dimension is not even.
57    /// Panics if the theta parameter is not positive.
58    fn initialize<B: Backend>(
59        &self,
60        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
61        device: &B::Device,
62    ) -> RotaryEncoding<B> {
63        assert_eq!(
64            self.d_model % 2,
65            0,
66            "The input embedding dimension must be even"
67        );
68        assert!(
69            self.theta > 0.0,
70            "Theta parameter must be positive (default: 10000)."
71        );
72
73        // Calculate the rotation frequencies for positional embeddings based on the formula
74        // `theta_i = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
75        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
76            .float()
77            .div_scalar(self.d_model as f32);
78
79        // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
80        // This is done since burn doesn't support exponentiation of scalar to tensor
81        let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
82        let theta_i = theta_i.powf_scalar(-1.0);
83
84        let theta_i = scaling(theta_i);
85
86        // Generate frequency values for positional embeddings
87        let frequencies: Tensor<B, 2> =
88            Tensor::<B, 1, Int>::arange(0..self.max_sequence_length as i64, device)
89                .float()
90                .unsqueeze()
91                .transpose()
92                .repeat_dim(1, self.d_model / 2)
93                * theta_i.unsqueeze();
94
95        // Convert frequency values to complex numbers (polar form)
96        let p_cos = frequencies.clone().cos();
97        let p_sin = frequencies.sin();
98
99        // Create the frequency tensor of shape (max_sequence_length, d_model, 2) with the real(cos)
100        // and imaginary(sin) components along last dimension
101        let freq_complex: Tensor<B, 3> = Tensor::cat(vec![p_cos, p_sin], 1)
102            .reshape([self.max_sequence_length, 2, self.d_model / 2])
103            .transpose()
104            .unsqueeze_dim::<4>(2)
105            .repeat_dim(2, 2)
106            .reshape([self.max_sequence_length, self.d_model, 2]);
107
108        RotaryEncoding {
109            freq_complex,
110            max_sequence_length: self.max_sequence_length,
111            theta: self.theta,
112        }
113    }
114}
115
116/// A module that applies rotary positional encoding to a tensor.
117/// Rotary Position Encoding or Embedding (RoPE), is a type of position embedding which encodes
118/// absolute positional information with rotation matrix and naturally incorporates
119/// explicit relative position dependency in self-attention formulation.
120///
121/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
122///
123/// Should be created using [RotaryEncodingConfig].
124#[derive(Module, Debug)]
125#[module(custom_display)]
126pub struct RotaryEncoding<B: Backend> {
127    /// Frequency Tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
128    pub freq_complex: Tensor<B, 3>,
129    /// Maximum sequence length of input
130    pub max_sequence_length: usize,
131    /// Scaling factor for frequency computation.
132    pub theta: f32,
133}
134
135impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
136    fn custom_settings(&self) -> Option<DisplaySettings> {
137        DisplaySettings::new()
138            .with_new_line_after_attribute(false)
139            .optional()
140    }
141
142    fn custom_content(&self, content: Content) -> Option<Content> {
143        let [_, _, d_model] = self.freq_complex.shape().dims();
144        content
145            .add("d_model", &d_model)
146            .add("max_sequence_length", &self.max_sequence_length)
147            .add("theta", &self.theta)
148            .optional()
149    }
150}
151
152#[allow(clippy::single_range_in_vec_init)]
153impl<B: Backend> RotaryEncoding<B> {
154    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
155    ///
156    /// Arguments:
157    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
158    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
159    ///   respectively.
160    ///
161    /// Returns:
162    /// * Output tensor with the same shape as input tensor after applying rotary encoding.
163    ///
164    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
165    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
166        self.apply(x, 0)
167    }
168
169    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
170    ///
171    /// Arguments:
172    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
173    ///   for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
174    ///   respectively.
175    /// * `start` - Sequence start position index.
176    ///
177    /// Returns:
178    /// * Output tensor with the same shape as input tensor after applying rotary encoding.
179    ///
180    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
181    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
182        assert!(
183            D >= 2,
184            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
185        );
186
187        let device = x.device();
188        let input_shape = x.shape();
189
190        // Extract the sequence length and embedding dimension, other dimensions are kept generic
191        // to allow both 3D and 4D tensors i.e. batch_size or (batch_size, num_heads)
192        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
193        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
194
195        // Create a dummy tensor with signed ones based on the 2D rotation matrix
196        // [[cos, -sin], [sin, cos]]
197        let sign_tensor =
198            Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
199
200        // Rotate input using the frequency tensor. Slice the frequencies till input sequence length
201        let out: Tensor<B, 4> = x
202            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
203            .matmul(sign_tensor.unsqueeze())
204            .reshape([dummy_dim_size, seq_len, d_model, 2])
205            * self
206                .freq_complex
207                .clone()
208                .slice([start..start + seq_len])
209                .unsqueeze();
210
211        // Sum the real and imaginary components to get output tensor and reshape to original shape
212        out.sum_dim(D - 1).reshape(input_shape)
213    }
214}
215
216#[cfg(test)]
217mod tests {
218    use super::*;
219    use crate::TestBackend;
220    use burn_tensor::{Tolerance, ops::FloatElem};
221    type FT = FloatElem<TestBackend>;
222
223    #[test]
224    fn test_rotary_encoding_forward() {
225        let device = Default::default();
226        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
227
228        let input = Tensor::<TestBackend, 3>::from_floats(
229            [
230                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
231                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
232            ],
233            &device,
234        );
235
236        // Input = [Batch size, Num of heads, Seq_len, d_model]
237        let input = input.unsqueeze::<4>();
238
239        let output = rotary_encoding.forward(input);
240        let expected_output = Tensor::<TestBackend, 3>::from_floats(
241            [
242                [
243                    [1.0000, 2.0000, 3.0000, 4.0000],
244                    [-2.3473, 7.4492, 6.9197, 8.0696],
245                ],
246                [
247                    [9.0000, 10.0000, 11.0000, 12.0000],
248                    [-4.7567, 18.5034, 14.8393, 16.1492],
249                ],
250            ],
251            &device,
252        );
253
254        output
255            .squeeze::<3>(0)
256            .to_data()
257            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
258    }
259
260    #[test]
261    fn test_zero_input_rotary_encoding_forward() {
262        let device = Default::default();
263        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
264
265        // Use a tensor of exact zeros as input. The output rotary embedding should be zeros as well
266        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);
267
268        let output = rotary_encoding.forward(input);
269        let expected_output = Tensor::<TestBackend, 3>::from_floats(
270            [
271                [
272                    [0.0000, 0.0000, 0.0000, 0.0000],
273                    [0.0000, 0.0000, 0.0000, 0.0000],
274                ],
275                [
276                    [0.0000, 0.0000, 0.0000, 0.0000],
277                    [0.0000, 0.0000, 0.0000, 0.0000],
278                ],
279            ],
280            &device,
281        );
282
283        output
284            .squeeze::<3>(0)
285            .to_data()
286            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
287    }
288
289    #[test]
290    #[should_panic]
291    fn test_valid_input_hidden_dim() {
292        // Hidden dimension must be even to be able to split into real and imaginary components
293        // for rotation
294        let d_model = 15;
295        let device = Default::default();
296        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
297        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
298        let _output = pe.forward(input);
299    }
300
301    #[test]
302    fn test_rotary_encoding_frequencies() {
303        let device = Default::default();
304        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);
305
306        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
307            [
308                [
309                    [1.0000, 0.0000],
310                    [1.0000, 0.0000],
311                    [1.0000, 0.0000],
312                    [1.0000, 0.0000],
313                ],
314                [
315                    [5.4030e-01, 8.4147e-01],
316                    [9.9500e-01, 9.9833e-02],
317                    [9.9995e-01, 9.9998e-03],
318                    [9.9999e-01, 9.9999e-04],
319                ],
320            ],
321            &device,
322        )
323        .unsqueeze_dim::<4>(2)
324        .repeat_dim(2, 2)
325        .reshape([2, 8, 2]);
326
327        rotary_encoding
328            .freq_complex
329            .to_data()
330            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::rel_abs(1e-4, 1e-5));
331    }
332
333    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
334        // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
335        let scale_factor = 8.;
336        let low_freq_factor = 1.;
337        let high_freq_factor = 4.;
338        let old_context_len = 8192.;
339
340        let low_freq_wavelen = old_context_len / low_freq_factor;
341        let high_freq_wavelen = old_context_len / high_freq_factor;
342
343        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
344
345        // if wavelen >= high_freq_wavelen
346        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
347        let smooth = wavelen
348            .clone()
349            .recip()
350            .mul_scalar(old_context_len)
351            .sub_scalar(low_freq_factor)
352            .div_scalar(high_freq_factor - low_freq_factor);
353        // (1 - smooth) * freq / scale_factor + smooth * freq
354        let new_freqs = smooth
355            .clone()
356            .neg()
357            .add_scalar(1.)
358            .mul(freqs.clone().div_scalar(scale_factor))
359            .add(smooth.clone().mul(freqs.clone()));
360        let new_freqs = freqs.clone().mask_where(cond, new_freqs);
361
362        // if wavelen > low_freq_wavelen
363        let cond = wavelen.clone().greater_elem(low_freq_wavelen);
364        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
365
366        // if wavelen < high_freq_wavelen
367        let cond = wavelen.lower_elem(high_freq_wavelen);
368        new_freqs.mask_where(cond, freqs)
369    }
370
371    #[test]
372    fn test_rotary_encoding_with_frequency_scaling() {
373        let device = Default::default();
374        let rotary_encoding = RotaryEncodingConfig::new(2, 8)
375            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);
376
377        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
378            [
379                [
380                    [1.0000, 0.0000],
381                    [1.0000, 0.0000],
382                    [1.0000, 0.0000],
383                    [1.0000, 0.0000],
384                ],
385                [
386                    [5.4030e-01, 8.4148e-01],
387                    [9.9500e-01, 9.9833e-02],
388                    [9.9995e-01, 9.9998e-03],
389                    [1.0000, 2.1361e-04],
390                ],
391            ],
392            &device,
393        )
394        .unsqueeze_dim::<4>(2)
395        .repeat_dim(2, 2)
396        .reshape([2, 8, 2]);
397
398        rotary_encoding
399            .freq_complex
400            .to_data()
401            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::rel_abs(1e-4, 1e-5));
402    }
403
404    #[test]
405    fn display() {
406        let config = RotaryEncodingConfig::new(10, 4);
407        let pe = config.init::<TestBackend>(&Default::default());
408
409        assert_eq!(
410            alloc::format!("{}", pe),
411            "RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}"
412        );
413    }
414}