burn_core/nn/rope_encoding.rs

use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Int;
use crate::tensor::Tensor;
use alloc::vec;

#[cfg(not(feature = "std"))]
use num_traits::Float;
/// Configuration to create a [RotaryEncoding](RotaryEncoding) layer using the [init function](RotaryEncodingConfig::init).
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Maximum sequence length of input
    pub max_sequence_length: usize,

    /// Size of the input embedding or hidden dimension
    pub d_model: usize,

    /// Scaling factor for frequency computation. Defaults to 10000.0.
    #[config(default = "10000.0")]
    pub theta: f32,
}

impl RotaryEncodingConfig {
    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
    ///
    /// # Panics
    ///
    /// Panics if the input embedding dimension (`d_model`) is not even.
    /// Panics if the theta parameter is not positive.
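    ///
    /// # Example
    ///
    /// A minimal usage sketch (not part of the original source; `MyBackend` stands in
    /// for any concrete backend type):
    ///
    /// ```ignore
    /// // RoPE for sequences of up to 512 tokens with a head dimension of 64.
    /// let device = Default::default();
    /// let rope = RotaryEncodingConfig::new(512, 64).init::<MyBackend>(&device);
    /// ```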
    pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
        self.initialize(|x| x, device)
    }

    /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
    /// This is useful to apply different RoPE extensions.
    ///
    /// # Panics
    ///
    /// Panics if the input embedding dimension (`d_model`) is not even.
    /// Panics if the theta parameter is not positive.
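    ///
    /// # Example
    ///
    /// An illustrative sketch (not part of the original source): a simple
    /// position-interpolation style scaling that uniformly divides all frequencies,
    /// stretching the usable context. `MyBackend` is a placeholder backend type and
    /// the factor 4.0 is arbitrary.
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let rope = RotaryEncodingConfig::new(4096, 64)
    ///     .init_with_frequency_scaling::<MyBackend>(|freqs| freqs.div_scalar(4.0), &device);
    /// ```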
    pub fn init_with_frequency_scaling<B: Backend>(
        &self,
        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
        device: &B::Device,
    ) -> RotaryEncoding<B> {
        self.initialize(scaling, device)
    }

    /// Initialize a new [RotaryEncoding](RotaryEncoding) module.
    ///
    /// # Panics
    ///
    /// Panics if the input embedding dimension (`d_model`) is not even.
    /// Panics if the theta parameter is not positive.
    fn initialize<B: Backend>(
        &self,
        scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
        device: &B::Device,
    ) -> RotaryEncoding<B> {
        assert_eq!(
            self.d_model % 2,
            0,
            "The input embedding dimension must be even"
        );
        assert!(
            self.theta > 0.0,
            "Theta parameter must be positive (default: 10000)."
        );

        // Calculate the rotation frequencies for positional embeddings based on the formula
        // `theta_i = 1 / (theta ^ (2i / d_model)) for i in [0..d_model/2]`
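        //
        // For example (illustrative numbers, not part of the original comments), with
        // d_model = 4 and theta = 10000: the even indices are [0, 2], the exponents are
        // [0.0, 0.5], and theta_i = [1 / 10000^0.0, 1 / 10000^0.5] = [1.0, 0.01].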
        let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
            .float()
            .div_scalar(self.d_model as f32);

        // Calculate (theta ^ (2i / d_model)) using the identity `a^b = exp(ln(a) * b)`,
        // since burn doesn't support raising a scalar to a tensor power
        let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
        let theta_i = theta_i.powf_scalar(-1.0);

        let theta_i = scaling(theta_i);

        // Generate frequency values for positional embeddings
        let frequencies: Tensor<B, 2> =
            Tensor::<B, 1, Int>::arange(0..self.max_sequence_length as i64, device)
                .float()
                .unsqueeze()
                .transpose()
                .repeat_dim(1, self.d_model / 2)
                * theta_i.unsqueeze();

        // Convert frequency values to complex numbers (polar form)
        let p_cos = frequencies.clone().cos();
        let p_sin = frequencies.sin();

        // Create the frequency tensor of shape (max_sequence_length, d_model, 2) with the real (cos)
        // and imaginary (sin) components along the last dimension
        let freq_complex: Tensor<B, 3> = Tensor::cat(vec![p_cos, p_sin], 1)
            .reshape([self.max_sequence_length, 2, self.d_model / 2])
            .transpose()
            .unsqueeze_dim::<4>(2)
            .repeat_dim(2, 2)
            .reshape([self.max_sequence_length, self.d_model, 2]);
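
        // Resulting layout (explanatory note, not in the original comments): for each
        // position t, rows 2k and 2k+1 along the d_model axis both hold
        // [cos(theta_k * t), sin(theta_k * t)], matching the pairwise rotation applied
        // in `apply` below.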

        RotaryEncoding {
            freq_complex,
            max_sequence_length: self.max_sequence_length,
            theta: self.theta,
        }
    }
}

/// A module that applies rotary positional encoding to a tensor.
/// Rotary Position Embedding (RoPE) is a type of position embedding that encodes
/// absolute positional information with a rotation matrix and naturally incorporates
/// explicit relative position dependency in the self-attention formulation.
///
/// Introduced in the paper: [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864)
///
/// Should be created using [RotaryEncodingConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Frequency tensor of shape (max_sequence_length, d_model, 2) with real and imaginary components
    pub freq_complex: Tensor<B, 3>,
    /// Maximum sequence length of input
    pub max_sequence_length: usize,
    /// Scaling factor for frequency computation.
    pub theta: f32,
}

impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // freq_complex has shape (max_sequence_length, d_model, 2), so d_model is the
        // second dimension, not the last
        let [_, d_model, _] = self.freq_complex.shape().dims();
        content
            .add("d_model", &d_model)
            .add("max_sequence_length", &self.max_sequence_length)
            .add("theta", &self.theta)
            .optional()
    }
}

#[allow(clippy::single_range_in_vec_init)]
impl<B: Backend> RotaryEncoding<B> {
    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model)
    ///
    /// Arguments:
    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
    ///    for (batch_size, seq_len, hidden_dim) or (batch_size, num_heads, seq_len, hidden_dim)
    ///    respectively.
    ///
    /// Returns:
    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
    ///
    /// # Panics
    ///
    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
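    ///
    /// # Example
    ///
    /// A minimal sketch (not part of the original source; `MyBackend` is a placeholder
    /// backend type). The output has the same shape as the input:
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let rope = RotaryEncodingConfig::new(512, 64).init::<MyBackend>(&device);
    ///
    /// // (batch_size = 2, seq_len = 10, d_model = 64)
    /// let x = Tensor::<MyBackend, 3>::ones([2, 10, 64], &device);
    /// let y = rope.forward(x); // shape (2, 10, 64)
    /// ```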
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        self.apply(x, 0)
    }

    /// Applies rotary positional encoding to a tensor of dimensions (..., seq_len, d_model),
    /// starting at the given sequence position.
    ///
    /// Arguments:
    /// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodates both 3D and 4D tensors
    ///    for (batch_size, seq_len, hidden_dim) or (batch_size, num_heads, seq_len, hidden_dim)
    ///    respectively.
    /// * `start` - Sequence start position index.
    ///
    /// Returns:
    /// * Output tensor with the same shape as the input tensor after applying rotary encoding.
    ///
    /// # Panics
    ///
    /// Panics if the input tensor does not have at least 2 dimensions for sequence length and hidden dimension.
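    ///
    /// # Example
    ///
    /// A sketch of autoregressive decoding with a KV cache (not part of the original
    /// source; `MyBackend` is a placeholder backend type): each new token is rotated
    /// at its absolute position by passing the number of tokens processed so far.
    ///
    /// ```ignore
    /// let device = Default::default();
    /// let rope = RotaryEncodingConfig::new(512, 64).init::<MyBackend>(&device);
    ///
    /// // A single new token at absolute position 9 (9 tokens already in the cache).
    /// let x = Tensor::<MyBackend, 3>::ones([1, 1, 64], &device);
    /// let y = rope.apply(x, 9);
    /// ```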
    pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
        assert!(
            D >= 2,
            "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
        );

        let device = x.device();
        let input_shape = x.shape();

        // Extract the sequence length and embedding dimension; other dimensions are kept generic
        // to allow both 3D and 4D tensors, i.e. (batch_size) or (batch_size, num_heads)
        let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
        let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);

        // Create a dummy tensor with signed ones based on the 2D rotation matrix
        // [[cos, -sin], [sin, cos]]
        let sign_tensor =
            Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);

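        // Explanatory note (not in the original comments): each consecutive pair
        // (x0, x1) along d_model must be rotated to (x0*cos - x1*sin, x1*cos + x0*sin).
        // Multiplying the row vector [x0, x1] by `sign_tensor` produces
        // [x0, -x1, x1, x0]; the element-wise product with the repeated [cos, sin]
        // entries of `freq_complex`, followed by `sum_dim` over the last dimension,
        // yields exactly those two rotated components.
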
        // Rotate the input using the frequency tensor. Slice the frequencies to the input sequence length
        let out: Tensor<B, 4> = x
            .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
            .matmul(sign_tensor.unsqueeze())
            .reshape([dummy_dim_size, seq_len, d_model, 2])
            * self
                .freq_complex
                .clone()
                .slice([start..start + seq_len])
                .unsqueeze();

        // Sum the real and imaginary components to get the output tensor and reshape to the original shape
        out.sum_dim(D - 1).reshape(input_shape)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            &device,
        );

        // Input = [batch_size, num_heads, seq_len, d_model]
        let input = input.unsqueeze::<4>();

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .squeeze::<3>(0)
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Use a tensor of exact zeros as input: the rotary-encoded output should be all zeros as well
        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
            ],
            &device,
        );

        output
            .squeeze::<3>(0)
            .to_data()
            .assert_approx_eq(&expected_output.to_data(), 4);
    }

    #[test]
    #[should_panic]
    fn test_invalid_input_hidden_dim() {
        // The hidden dimension must be even so it can be split into real and imaginary
        // components for rotation
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn test_rotary_encoding_frequencies() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4147e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [9.9999e-01, 9.9999e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq(&expected_freqs.to_data(), 4);
    }

    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
        // Adapted from: https://github.com/meta-llama/llama-models/blob/main/models/llama3/reference_impl/model.py#L45
        let scale_factor = 8.;
        let low_freq_factor = 1.;
        let high_freq_factor = 4.;
        let old_context_len = 8192.;

        let low_freq_wavelen = old_context_len / low_freq_factor;
        let high_freq_wavelen = old_context_len / high_freq_factor;

        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

        // if wavelen >= high_freq_wavelen
        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
        let smooth = wavelen
            .clone()
            .recip()
            .mul_scalar(old_context_len)
            .sub_scalar(low_freq_factor)
            .div_scalar(high_freq_factor - low_freq_factor);
        // (1 - smooth) * freq / scale_factor + smooth * freq
        let new_freqs = smooth
            .clone()
            .neg()
            .add_scalar(1.)
            .mul(freqs.clone().div_scalar(scale_factor))
            .add(smooth.clone().mul(freqs.clone()));
        let new_freqs = freqs.clone().mask_where(cond, new_freqs);

        // if wavelen > low_freq_wavelen
        let cond = wavelen.clone().greater_elem(low_freq_wavelen);
        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));

        // if wavelen < high_freq_wavelen
        let cond = wavelen.lower_elem(high_freq_wavelen);
        new_freqs.mask_where(cond, freqs)
    }

    #[test]
    fn test_rotary_encoding_with_frequency_scaling() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8)
            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4148e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [1.0000, 2.1361e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq(&expected_freqs.to_data(), 4);
    }

    #[test]
    fn display() {
        let config = RotaryEncodingConfig::new(10, 4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", pe),
            "RotaryEncoding {d_model: 4, max_sequence_length: 10, theta: 10000}"
        );
    }
}