1use alloc::vec::Vec;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6
7use crate::tensor::Tensor;
8use crate::tensor::TensorData;
9use crate::tensor::backend::Backend;
10
11#[cfg(not(feature = "std"))]
12use num_traits::Float;
13
/// Configuration for building a [`PositionalEncoding`] module.
#[derive(Config)]
pub struct PositionalEncodingConfig {
    /// Number of positions covered by the precomputed sinusoid table.
    /// `forward` rejects inputs whose sequence length exceeds it. Defaults to 5,000.
    #[config(default = "5_000")]
    pub max_sequence_size: usize,

    /// Width of each encoding vector. Must be even, and must equal the last
    /// dimension of the inputs passed to `forward`.
    pub d_model: usize,

    /// Base of the geometric progression of sinusoid frequencies.
    /// Must be at least `max_sequence_size`. Defaults to 10,000.
    #[config(default = "10_000")]
    pub max_timescale: usize,
}
28
/// Sinusoidal positional encoding module.
///
/// Holds a precomputed table of sine/cosine values that `forward` adds to its
/// input. Build it with [`PositionalEncodingConfig::init`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
    /// Precomputed encodings. `init` builds this with shape
    /// `[1, max_sequence_size, d_model]`.
    pub sinusoids: Tensor<B, 3>,
    /// Number of positions the table was generated for, as configured at init.
    pub max_sequence_size: usize,
    /// Timescale used when the table was generated, as configured at init.
    pub max_timescale: usize,
}
53
54impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
55 fn custom_settings(&self) -> Option<DisplaySettings> {
56 DisplaySettings::new()
57 .with_new_line_after_attribute(false)
58 .optional()
59 }
60
61 fn custom_content(&self, content: Content) -> Option<Content> {
62 let [_, _, d_model] = self.sinusoids.shape().dims();
63 content
64 .add("d_model", &d_model)
65 .add("max_sequence_size", &self.max_sequence_size)
66 .add("max_timescale", &self.max_timescale)
67 .optional()
68 }
69}
70
71impl PositionalEncodingConfig {
72 pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
74 let sinusoids = generate_sinusoids::<B>(
75 self.max_sequence_size,
76 self.d_model,
77 self.max_timescale,
78 device,
79 )
80 .unsqueeze::<3>();
81
82 PositionalEncoding {
83 sinusoids,
84 max_sequence_size: self.max_sequence_size,
85 max_timescale: self.max_timescale,
86 }
87 }
88}
89
90impl<B: Backend> PositionalEncoding<B> {
91 pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
104 let [_, seq_length, d_model_input] = input.dims();
105
106 let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();
107
108 assert!(
109 max_sequence_size >= seq_length,
110 "max_sequence_size({}) must be greater or equal than length({seq_length})",
111 max_sequence_size,
112 );
113
114 assert!(
115 d_model_input == d_model,
116 "d_model({}) of the input must be equal to d_model of encoding({})",
117 d_model_input,
118 d_model,
119 );
120
121 let slices = [0..batch_size, 0..seq_length, 0..d_model];
122
123 input.add(self.sinusoids.clone().slice(slices))
124 }
125}
126
127pub fn generate_sinusoids<B: Backend>(
144 length: usize,
145 d_model: usize,
146 max_timescale: usize,
147 device: &B::Device,
148) -> Tensor<B, 2> {
149 assert!(d_model % 2 == 0, "d_model must be even");
150 assert!(
151 max_timescale >= length,
152 "max_timescale must be greater than length"
153 );
154
155 let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
157
158 let mut scaled_time_sin_cos = Vec::with_capacity(length);
160
161 for i in 0..length {
163 let mut row = Vec::with_capacity(d_model / 2);
165 for k in (0..d_model).step_by(2) {
167 let div_term = (k as f32 * log_timescale_increment).exp();
169 row.push((div_term * i as f32).sin());
171 row.push((div_term * i as f32).cos());
172 }
173
174 scaled_time_sin_cos.push(row);
176 }
177
178 let data = TensorData::new(
180 scaled_time_sin_cos.into_iter().flatten().collect(),
181 [length, d_model],
182 );
183
184 Tensor::<B, 2>::from_data(data, device)
185}
186
#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// The encoding is added identically to every batch element of a zero input.
    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        let tensor = Tensor::zeros([batch_size, length, d_model], &device);

        let output = pe.forward(tensor);

        assert_eq!(output.shape().dims, [batch_size, length, d_model]);

        let expected = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    /// Reference values for a 12 x 6 table with the default timescale.
    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        let expected = Tensor::<TestBackend, 2>::from_floats(
            [
                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
            ],
            &device,
        );
        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    /// `generate_sinusoids` needs one sin/cos pair per frequency, so odd widths panic.
    #[test]
    #[should_panic]
    fn generate_sinusoids_rejects_odd_d_model() {
        let device = Default::default();
        let _ = generate_sinusoids::<TestBackend>(4, 5, 10_000, &device);
    }

    /// `generate_sinusoids` panics when the timescale is shorter than the length.
    #[test]
    #[should_panic]
    fn generate_sinusoids_rejects_timescale_shorter_than_length() {
        let device = Default::default();
        let _ = generate_sinusoids::<TestBackend>(8, 4, 4, &device);
    }

    /// The input's last dimension must match the encoding width.
    #[test]
    #[should_panic]
    fn d_model_input_should_match() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(input);
    }

    /// Sequences longer than the default 5,000 positions are rejected.
    #[test]
    #[should_panic]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(input);
    }

    /// A sequence exactly as long as the table is accepted (inclusive bound).
    #[test]
    fn input_length_equal_to_max_len_is_accepted() {
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(4)
            .with_max_sequence_size(3)
            .init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([2, 3, 4], &device);
        let output = pe.forward(input);
        assert_eq!(output.dims(), [2, 3, 4]);
    }

    #[test]
    fn display() {
        let config = PositionalEncodingConfig::new(4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", pe),
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}