burn_core/nn/pos_encoding.rs

use alloc::vec::Vec;

use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};

use crate::tensor::Tensor;
use crate::tensor::TensorData;
use crate::tensor::backend::Backend;

#[cfg(not(feature = "std"))]
use num_traits::Float;

/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).
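///
/// # Example
///
/// A minimal construction sketch; the `with_*` setters for the defaulted fields
/// are assumed to follow the usual `#[derive(Config)]` builder pattern:
///
/// ```ignore
/// // d_model = 512, with a custom maximum sequence size.
/// let config = PositionalEncodingConfig::new(512).with_max_sequence_size(1_024);
/// ```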
#[derive(Config)]
pub struct PositionalEncodingConfig {
    /// Maximum sequence size to use.
    #[config(default = "5_000")]
    pub max_sequence_size: usize,

    /// The size of each vector.
    pub d_model: usize,

    /// Max time scale to use.
    #[config(default = "10_000")]
    pub max_timescale: usize,
}

/// Positional encoding layer for transformer models.
///
/// This layer adds positional information to the input embeddings, allowing the transformer model
/// to account for the order of tokens in a sequence. The encoding is a set of sinusoids with
/// different frequencies and phases, added element-wise to the input embeddings.
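///
/// Concretely, for position `pos` and dimension pair `2i`, with `T = max_timescale`
/// (the paper uses `T = 10_000`):
///
/// ```text
/// PE(pos, 2i)     = sin(pos / T^(2i / d_model))
/// PE(pos, 2i + 1) = cos(pos / T^(2i / d_model))
/// ```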
///
/// Sinusoidal positional embeddings were introduced in
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found here:
/// [Language Modeling with nn.Transformer and torchtext](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
///
/// Should be created using [PositionalEncodingConfig].
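///
/// # Example
///
/// A minimal usage sketch, assuming a backend `B` with a default device:
///
/// ```ignore
/// let device = Default::default();
/// let pe = PositionalEncodingConfig::new(64).init::<B>(&device);
///
/// // Input shape: [batch_size, seq_length, d_model].
/// let input = Tensor::<B, 3>::zeros([2, 10, 64], &device);
/// let output = pe.forward(input); // same shape, with positional information added
/// ```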
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
    /// The sinusoids used to add positional information to the input embeddings.
    pub sinusoids: Tensor<B, 3>,
    /// The maximum sequence size to use.
    pub max_sequence_size: usize,
    /// Max time scale to use.
    pub max_timescale: usize,
}

impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [_, _, d_model] = self.sinusoids.shape().dims();
        content
            .add("d_model", &d_model)
            .add("max_sequence_size", &self.max_sequence_size)
            .add("max_timescale", &self.max_timescale)
            .optional()
    }
}

impl PositionalEncodingConfig {
    /// Initialize a new [PositionalEncoding](PositionalEncoding) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
        // Unsqueeze [max_sequence_size, d_model] into [1, max_sequence_size, d_model]
        // so the leading dimension can broadcast over the batch in `forward`.
        let sinusoids = generate_sinusoids::<B>(
            self.max_sequence_size,
            self.d_model,
            self.max_timescale,
            device,
        )
        .unsqueeze::<3>();

        PositionalEncoding {
            sinusoids,
            max_sequence_size: self.max_sequence_size,
            max_timescale: self.max_timescale,
        }
    }
}

impl<B: Backend> PositionalEncoding<B> {
    /// Applies the forward pass on the input tensor by adding the sinusoids to the input.
    ///
    /// # Shapes
    ///
    /// * input: [batch_size, seq_length, d_model]
    /// * output: [batch_size, seq_length, d_model]
    ///
    /// # Panics
    ///
    /// * Panics if the input sequence length is greater than the maximum sequence size.
    /// * Panics if the input d_model is not equal to the d_model of the sinusoids.
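    ///
    /// # Example
    ///
    /// Continuing the sketch from the struct-level example: with a zero input,
    /// the output is just the sinusoids, broadcast over the batch dimension.
    ///
    /// ```ignore
    /// let input = Tensor::<B, 3>::zeros([2, 10, 64], &device);
    /// let output = pe.forward(input); // shape [2, 10, 64]
    /// ```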
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, seq_length, d_model_input] = input.dims();

        // `self.sinusoids` has shape [1, max_sequence_size, d_model]; its leading
        // dimension of size 1 broadcasts over the input batch dimension.
        let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();

        assert!(
            max_sequence_size >= seq_length,
            "max_sequence_size({max_sequence_size}) must be greater than or equal to seq_length({seq_length})",
        );

        assert!(
            d_model_input == d_model,
            "d_model of the input ({d_model_input}) must be equal to d_model of the encoding ({d_model})",
        );

        let slices = [0..batch_size, 0..seq_length, 0..d_model];

        input.add(self.sinusoids.clone().slice(slices))
    }
}

/// Returns sinusoids for positional embedding introduced in
/// [Attention is all you need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found here:
/// [Language Modeling with nn.Transformer and torchtext](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)
///
/// # Arguments
///
/// * `length` - The length of the sequence.
/// * `d_model` - The size of each vector.
/// * `max_timescale` - The maximum time scale to use.
/// * `device` - The device on which to create the tensor.
///
/// # Returns
///
/// A tensor of shape [length, d_model] containing the sinusoids.
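///
/// # Example
///
/// A minimal sketch, assuming a backend `B`; the arguments match the values
/// verified in the tests below:
///
/// ```ignore
/// let device = Default::default();
/// let sinusoids = generate_sinusoids::<B>(12, 6, 10_000, &device);
/// assert_eq!(sinusoids.dims(), [12, 6]);
/// ```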
pub fn generate_sinusoids<B: Backend>(
    length: usize,
    d_model: usize,
    max_timescale: usize,
    device: &B::Device,
) -> Tensor<B, 2> {
    assert!(d_model % 2 == 0, "d_model must be even");
    assert!(
        max_timescale >= length,
        "max_timescale must be greater than or equal to length"
    );

    // Calculate the increment for the logarithmic timescale
    let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
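    // For an even index k = 2i, exp(k * log_timescale_increment) = max_timescale^(-k / d_model),
    // which is the 1 / T^(2i / d_model) frequency term from the paper (with T = max_timescale).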

    // Create a vector to hold the sinusoids
    let mut scaled_time_sin_cos = Vec::with_capacity(length);

    // Loop over each position in the sequence
    for i in 0..length {
        // Create a vector to hold the d_model sinusoid values for this position
        let mut row = Vec::with_capacity(d_model);
        // Loop over each pair of dimensions of the sinusoids
        for k in (0..d_model).step_by(2) {
            // Calculate the division term for this dimension
            let div_term = (k as f32 * log_timescale_increment).exp();
            // Calculate the sine and cosine values for this dimension and position
            row.push((div_term * i as f32).sin());
            row.push((div_term * i as f32).cos());
        }

        // Add the sinusoids for this position to the vector
        scaled_time_sin_cos.push(row);
    }

    // Convert the sinusoids to a tensor and return it
    let data = TensorData::new(
        scaled_time_sin_cos.into_iter().flatten().collect(),
        [length, d_model],
    );

    Tensor::<B, 2>::from_data(data, device)
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        // expected to broadcast
        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        // Use a tensor of zeros as input for easy verification of the output:
        // the output should be the sinusoids broadcast to the input shape.
        let tensor = Tensor::zeros([batch_size, length, d_model], &device);

        let output = pe.forward(tensor);

        assert_eq!(output.shape().dims, [batch_size, length, d_model]);

        let expected = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        // The values are taken from the PyTorch reference implementation
        let expected = Tensor::<TestBackend, 2>::from_floats(
            [
                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
            ],
            &device,
        );
        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    #[should_panic]
    fn d_model_input_should_match() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(input);
    }

    #[test]
    #[should_panic]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn display() {
        let config = PositionalEncodingConfig::new(4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", pe),
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}