1use alloc::vec::Vec;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6
7use crate::tensor::Tensor;
8use crate::tensor::TensorData;
9use crate::tensor::backend::Backend;
10
11#[cfg(not(feature = "std"))]
12use num_traits::Float;
13
/// Configuration used to create a [`PositionalEncoding`] module via
/// [`PositionalEncodingConfig::init`].
#[derive(Config)]
pub struct PositionalEncodingConfig {
    /// Number of positions for which the sinusoid table is precomputed.
    /// Inputs passed to `forward` may not be longer than this. Defaults to `5_000`.
    #[config(default = "5_000")]
    pub max_sequence_size: usize,

    /// Size of the last (feature) dimension of the encoding table.
    /// `generate_sinusoids` asserts that this value is even.
    pub d_model: usize,

    /// Base of the geometric progression of frequencies used by
    /// `generate_sinusoids`; it must be at least `max_sequence_size`.
    /// Defaults to `10_000`.
    #[config(default = "10_000")]
    pub max_timescale: usize,
}
28
/// Module that adds a precomputed table of sinusoids to its input.
///
/// Instances are created with [`PositionalEncodingConfig::init`]; the table is
/// built once at construction time and sliced to the input length in `forward`.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
    /// Precomputed sinusoid table, read in `forward` as
    /// `[batch, max_sequence_size, d_model]`. It is produced by unsqueezing a
    /// `[max_sequence_size, d_model]` tensor to rank 3 (presumably a leading
    /// dimension of size 1 — confirm against `Tensor::unsqueeze`).
    pub sinusoids: Tensor<B, 3>,
    /// Number of positions covered by `sinusoids` (copied from the config).
    pub max_sequence_size: usize,
    /// Timescale the table was generated with (copied from the config).
    pub max_timescale: usize,
}
53
54impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
55 fn custom_settings(&self) -> Option<DisplaySettings> {
56 DisplaySettings::new()
57 .with_new_line_after_attribute(false)
58 .optional()
59 }
60
61 fn custom_content(&self, content: Content) -> Option<Content> {
62 let [_, _, d_model] = self.sinusoids.shape().dims();
63 content
64 .add("d_model", &d_model)
65 .add("max_sequence_size", &self.max_sequence_size)
66 .add("max_timescale", &self.max_timescale)
67 .optional()
68 }
69}
70
71impl PositionalEncodingConfig {
72 pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
74 let sinusoids = generate_sinusoids::<B>(
75 self.max_sequence_size,
76 self.d_model,
77 self.max_timescale,
78 device,
79 )
80 .unsqueeze::<3>();
81
82 PositionalEncoding {
83 sinusoids,
84 max_sequence_size: self.max_sequence_size,
85 max_timescale: self.max_timescale,
86 }
87 }
88}
89
90impl<B: Backend> PositionalEncoding<B> {
91 pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
104 let [_, seq_length, d_model_input] = input.dims();
105
106 let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();
107
108 assert!(
109 max_sequence_size >= seq_length,
110 "max_sequence_size({max_sequence_size}) must be greater or equal than length({seq_length})"
111 );
112
113 assert!(
114 d_model_input == d_model,
115 "d_model({d_model_input}) of the input must be equal to d_model of encoding({d_model})"
116 );
117
118 let slices = [0..batch_size, 0..seq_length, 0..d_model];
119
120 input.add(self.sinusoids.clone().slice(slices))
121 }
122}
123
124pub fn generate_sinusoids<B: Backend>(
141 length: usize,
142 d_model: usize,
143 max_timescale: usize,
144 device: &B::Device,
145) -> Tensor<B, 2> {
146 assert!(d_model % 2 == 0, "d_model must be even");
147 assert!(
148 max_timescale >= length,
149 "max_timescale must be greater than length"
150 );
151
152 let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
154
155 let mut scaled_time_sin_cos = Vec::with_capacity(length);
157
158 for i in 0..length {
160 let mut row = Vec::with_capacity(d_model / 2);
162 for k in (0..d_model).step_by(2) {
164 let div_term = (k as f32 * log_timescale_increment).exp();
166 row.push((div_term * i as f32).sin());
168 row.push((div_term * i as f32).cos());
169 }
170
171 scaled_time_sin_cos.push(row);
173 }
174
175 let data = TensorData::new(
177 scaled_time_sin_cos.into_iter().flatten().collect(),
178 [length, d_model],
179 );
180
181 Tensor::<B, 2>::from_data(data, device)
182}
183
#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Adding the encoding to an all-zero input must return the first
    /// `length` rows of the sinusoid table, repeated for every batch item.
    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        let tensor = Tensor::zeros([batch_size, length, d_model], &device);

        let output = pe.forward(tensor);

        assert_eq!(output.shape().dims, [batch_size, length, d_model]);

        let expected = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// Checks the raw table for 12 positions and 6 features against
    /// reference values.
    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        let expected = Tensor::<TestBackend, 2>::from_floats(
            [
                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
            ],
            &device,
        );
        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    /// An input whose feature size differs from the encoding's `d_model`
    /// must be rejected. The expected message is pinned so that an unrelated
    /// panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "d_model(10) of the input must be equal to d_model of encoding(8)")]
    fn d_model_input_should_match() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(input);
    }

    /// An input longer than the default `max_sequence_size` (5000) must be
    /// rejected. The expected message is pinned so that an unrelated panic
    /// cannot make this test pass.
    #[test]
    #[should_panic(expected = "max_sequence_size(5000) must be greater or equal than length(6000)")]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(input);
    }

    /// The custom display prints the module on one line with `d_model`,
    /// `max_sequence_size` and `max_timescale`.
    #[test]
    fn display() {
        let config = PositionalEncodingConfig::new(4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}