burn_nn/modules/pos_encoding.rs

use burn_core as burn;

use alloc::vec::Vec;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};

use burn::tensor::Tensor;
use burn::tensor::TensorData;
use burn::tensor::backend::Backend;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

/// Configuration to create a [PositionalEncoding](PositionalEncoding) layer using the [init function](PositionalEncodingConfig::init).
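///
/// # Example
///
/// A minimal sketch of building the configuration, assuming the `new` constructor and
/// `with_*` setters generated by the `Config` derive, and some backend `B` in scope:
///
/// ```rust,ignore
/// // d_model = 64; max_timescale keeps its default of 10_000.
/// let config = PositionalEncodingConfig::new(64).with_max_sequence_size(1_024);
/// let pe = config.init::<B>(&device);
/// ```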
#[derive(Config, Debug)]
pub struct PositionalEncodingConfig {
    /// Maximum sequence size to use.
    #[config(default = "5_000")]
    pub max_sequence_size: usize,

    /// The size of each vector.
    pub d_model: usize,

    /// Max time scale to use.
    #[config(default = "10_000")]
    pub max_timescale: usize,
}

/// Positional encoding layer for transformer models.
///
/// This layer adds positional information to the input embeddings, allowing the transformer model
/// to take the order of the sequence into account. The positional encoding is added to the input
/// embeddings by computing a set of sinusoidal functions with different frequencies and phases.
///
/// This sinusoidal positional embedding was introduced in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found in the PyTorch tutorial
/// [Language Modeling with nn.Transformer and torchtext](https://pytorch.org/tutorials/beginner/transformer_tutorial.html).
///
/// Should be created using [PositionalEncodingConfig].
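///
/// # Example
///
/// A minimal usage sketch (backend `B` and `device` are assumed to be in scope):
///
/// ```rust,ignore
/// let pe = PositionalEncodingConfig::new(64).init::<B>(&device);
///
/// // Input embeddings of shape [batch_size, seq_length, d_model].
/// let embeddings = Tensor::<B, 3>::zeros([8, 128, 64], &device);
///
/// // The sinusoids are added to the embeddings; the shape is unchanged.
/// let with_positions = pe.forward(embeddings);
/// ```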
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
    /// The sinusoids used to add positional information to the input embeddings.
    pub sinusoids: Tensor<B, 3>,
    /// The maximum sequence size to use.
    pub max_sequence_size: usize,
    /// Max time scale to use.
    pub max_timescale: usize,
}

impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [_, _, d_model] = self.sinusoids.shape().dims();
        content
            .add("d_model", &d_model)
            .add("max_sequence_size", &self.max_sequence_size)
            .add("max_timescale", &self.max_timescale)
            .optional()
    }
}

impl PositionalEncodingConfig {
    /// Initialize a new [PositionalEncoding](PositionalEncoding) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
        let sinusoids = generate_sinusoids::<B>(
            self.max_sequence_size,
            self.d_model,
            self.max_timescale,
            device,
        )
        .unsqueeze::<3>();

        PositionalEncoding {
            sinusoids,
            max_sequence_size: self.max_sequence_size,
            max_timescale: self.max_timescale,
        }
    }
}

impl<B: Backend> PositionalEncoding<B> {
    /// Applies the forward pass on the input tensor by adding the sinusoids to the input.
    ///
    /// # Shapes
    ///
    /// * input: [batch_size, seq_length, d_model]
    /// * output: [batch_size, seq_length, d_model]
    ///
    /// # Panics
    ///
    /// * Panics if the input sequence length is greater than the maximum sequence size.
    /// * Panics if the input d_model is not equal to the d_model of the sinusoids.
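    ///
    /// # Example
    ///
    /// A shape-only sketch (`pe`, backend `B`, and `device` as in the module-level example):
    ///
    /// ```rust,ignore
    /// let input = Tensor::<B, 3>::zeros([2, 10, 64], &device);
    /// let output = pe.forward(input);
    /// assert_eq!(output.dims(), [2, 10, 64]);
    /// ```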
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, seq_length, d_model_input] = input.dims();

        let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();

        assert!(
            max_sequence_size >= seq_length,
            "max_sequence_size({max_sequence_size}) must be greater than or equal to the input sequence length({seq_length})"
        );

        assert!(
            d_model_input == d_model,
            "d_model({d_model_input}) of the input must be equal to the d_model of the encoding({d_model})"
        );

        let slices = [0..batch_size, 0..seq_length, 0..d_model];

        input.add(self.sinusoids.clone().slice(slices))
    }
}

/// Returns the sinusoids for the positional embedding introduced in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// The reference implementation can be found in the PyTorch tutorial
/// [Language Modeling with nn.Transformer and torchtext](https://pytorch.org/tutorials/beginner/transformer_tutorial.html).
///
/// # Arguments
///
/// * `length` - The length of the sequence.
/// * `d_model` - The size of each vector.
/// * `max_timescale` - The maximum time scale to use.
/// * `device` - The device on which to allocate the tensor.
///
/// # Returns
///
/// A tensor of shape [length, d_model] containing the sinusoids.
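///
/// # Formula
///
/// A restatement of what the loop below computes: for position `pos` and pair index `i`
/// (column `2i` holds the sine, column `2i + 1` the cosine),
///
/// ```text
/// PE(pos, 2i)     = sin(pos / max_timescale^(2i / d_model))
/// PE(pos, 2i + 1) = cos(pos / max_timescale^(2i / d_model))
/// ```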
pub fn generate_sinusoids<B: Backend>(
    length: usize,
    d_model: usize,
    max_timescale: usize,
    device: &B::Device,
) -> Tensor<B, 2> {
    assert!(d_model.is_multiple_of(2), "d_model must be even");
    assert!(
        max_timescale >= length,
        "max_timescale({max_timescale}) must be greater than or equal to length({length})"
    );

    // Calculate the increment for the logarithmic timescale
    let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;

    // Create a vector to hold the sinusoids
    let mut scaled_time_sin_cos = Vec::with_capacity(length);

    // Loop over each position in the sequence
    for i in 0..length {
        // Create a vector to hold the interleaved sin/cos values for this position
        let mut row = Vec::with_capacity(d_model);
        // Loop over each pair of dimensions (one sine column, one cosine column)
        for k in (0..d_model).step_by(2) {
            // Calculate the division term for this dimension
            let div_term = (k as f32 * log_timescale_increment).exp();
            // Calculate the sine and cosine values for this dimension and position
            row.push((div_term * i as f32).sin());
            row.push((div_term * i as f32).cos());
        }

        // Add the sinusoids for this position to the vector
        scaled_time_sin_cos.push(row);
    }

    // Convert the sinusoids to a tensor and return it
    let data = TensorData::new(
        scaled_time_sin_cos.into_iter().flatten().collect(),
        [length, d_model],
    );

    Tensor::<B, 2>::from_data(data, device)
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        // The encoding is expected to broadcast over the batch dimension
        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        // Use a tensor of zeros as input for easy verification of the output
        // The output should be the sinusoids broadcast to the input shape
        let tensor = Tensor::zeros([batch_size, length, d_model], &device);

        let output = pe.forward(tensor);

        assert_eq!(output.shape().dims, [batch_size, length, d_model]);

        let expected = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        // The values are taken from the pytorch reference implementation
        let expected = Tensor::<TestBackend, 2>::from_floats(
            [
                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
            ],
            &device,
        );
        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    #[should_panic]
    fn d_model_input_should_match() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(input);
    }

    #[test]
    #[should_panic]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn display() {
        let config = PositionalEncodingConfig::new(4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}