use burn_core as burn;

use alloc::vec::Vec;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};

use burn::tensor::Tensor;
use burn::tensor::TensorData;
use burn::tensor::backend::Backend;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float as _;

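/// Configuration for building a [PositionalEncoding] module.
///
/// A minimal usage sketch; `MyBackend` is only a placeholder for whichever
/// [Backend] implementation is in scope:
///
/// ```ignore
/// let device = Default::default();
/// let pe = PositionalEncodingConfig::new(64).init::<MyBackend>(&device);
/// let input = Tensor::<MyBackend, 3>::zeros([8, 128, 64], &device);
/// let output = pe.forward(input); // same shape as `input`
/// ```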
#[derive(Config, Debug)]
pub struct PositionalEncodingConfig {
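    /// Maximum sequence length the precomputed encoding can cover. Default: 5_000.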
    #[config(default = "5_000")]
    pub max_sequence_size: usize,

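    /// The size of each vector (the model/embedding dimension). Must be even.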
    pub d_model: usize,

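    /// Largest timescale (wavelength) of the sinusoids. Default: 10_000.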
    #[config(default = "10_000")]
    pub max_timescale: usize,
}

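/// Positional encoding layer that adds a fixed sinusoidal signal to its input,
/// so that a sequence model can recover token positions.
///
/// The table is precomputed by [generate_sinusoids] and stored with shape
/// `[1, max_sequence_size, d_model]`; `forward` slices it to the input length
/// and adds it, broadcasting over the batch dimension.
///
/// Should be created using [PositionalEncodingConfig].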
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
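    /// Precomputed sinusoid table of shape `[1, max_sequence_size, d_model]`.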
    pub sinusoids: Tensor<B, 3>,
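    /// Maximum sequence length supported by the precomputed table.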
    pub max_sequence_size: usize,
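    /// Largest timescale used to generate the sinusoids.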
    pub max_timescale: usize,
}

impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let [_, _, d_model] = self.sinusoids.shape().dims();
        content
            .add("d_model", &d_model)
            .add("max_sequence_size", &self.max_sequence_size)
            .add("max_timescale", &self.max_timescale)
            .optional()
    }
}

impl PositionalEncodingConfig {
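    /// Initializes a new [PositionalEncoding] module, precomputing the sinusoid
    /// table on the given device.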
    pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
        let sinusoids = generate_sinusoids::<B>(
            self.max_sequence_size,
            self.d_model,
            self.max_timescale,
            device,
        )
        .unsqueeze::<3>();

        PositionalEncoding {
            sinusoids,
            max_sequence_size: self.max_sequence_size,
            max_timescale: self.max_timescale,
        }
    }
}

impl<B: Backend> PositionalEncoding<B> {
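    /// Adds the positional encoding to the input tensor.
    ///
    /// Shapes: input `[batch_size, seq_length, d_model]` -> output
    /// `[batch_size, seq_length, d_model]`.
    ///
    /// # Panics
    ///
    /// Panics if `seq_length` exceeds `max_sequence_size` or if the input's last
    /// dimension does not match the encoding's `d_model`.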
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [_, seq_length, d_model_input] = input.dims();

        let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();

        assert!(
            max_sequence_size >= seq_length,
            "max_sequence_size({max_sequence_size}) must be greater than or equal to the sequence length({seq_length})"
        );

        assert!(
            d_model_input == d_model,
            "d_model of the input({d_model_input}) must be equal to d_model of the encoding({d_model})"
        );

        let slices = [0..batch_size, 0..seq_length, 0..d_model];

        input.add(self.sinusoids.clone().slice(slices))
    }
}

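/// Generates a `[length, d_model]` table of sinusoidal position encodings,
/// following the scheme of "Attention Is All You Need" (Vaswani et al., 2017):
///
/// ```text
/// PE(pos, 2i)     = sin(pos / max_timescale^(2i / d_model))
/// PE(pos, 2i + 1) = cos(pos / max_timescale^(2i / d_model))
/// ```
///
/// # Panics
///
/// Panics if `d_model` is odd or if `max_timescale < length`.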
pub fn generate_sinusoids<B: Backend>(
    length: usize,
    d_model: usize,
    max_timescale: usize,
    device: &B::Device,
) -> Tensor<B, 2> {
    assert!(d_model.is_multiple_of(2), "d_model must be even");
    assert!(
        max_timescale >= length,
        "max_timescale must be greater than or equal to length"
    );

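    // Geometric progression of inverse timescales: column pair `k` oscillates
    // with angular frequency max_timescale^(-k / d_model).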
    let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;

    let mut scaled_time_sin_cos = Vec::with_capacity(length);

    for i in 0..length {
        let mut row = Vec::with_capacity(d_model);
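        // Each step of `k` fills one (sin, cos) pair sharing the same frequency,
        // so `row` ends up holding `d_model` values.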
        for k in (0..d_model).step_by(2) {
            let div_term = (k as f32 * log_timescale_increment).exp();
            row.push((div_term * i as f32).sin());
            row.push((div_term * i as f32).cos());
        }

        scaled_time_sin_cos.push(row);
    }

    let data = TensorData::new(
        scaled_time_sin_cos.into_iter().flatten().collect(),
        [length, d_model],
    );

    Tensor::<B, 2>::from_data(data, device)
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_module() {
        let d_model = 6;
        let length = 3;

        let batch_size = 2;

        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);

        let tensor = Tensor::zeros([batch_size, length, d_model], &device);

        let output = pe.forward(tensor);

        assert_eq!(output.shape().dims, [batch_size, length, d_model]);

        let expected = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
                [
                    [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                    [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                    [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    fn test_generate_sinusoids() {
        let device = Default::default();
        let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);

        let expected = Tensor::<TestBackend, 2>::from_floats(
            [
                [0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
                [0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
                [0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
                [0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
                [-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
                [-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
                [-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
                [0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
                [0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
                [0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
                [-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
                [-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
            ],
            &device,
        );
        sinusoids
            .to_data()
            .assert_approx_eq::<FT>(&expected.to_data(), Tolerance::default());
    }

    #[test]
    #[should_panic]
    fn d_model_input_should_match() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 5, 10], &device);
        let _output = pe.forward(input);
    }

    #[test]
    #[should_panic]
    fn input_length_should_be_less_than_max_len() {
        let d_model = 8;
        let device = Default::default();
        let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
        let input = Tensor::zeros([1, 6_000, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn display() {
        let config = PositionalEncodingConfig::new(4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
        );
    }
}