1use core::ops::Range;
2
3use crate as burn;
4use crate::config::Config;
5use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
6use crate::tensor::Int;
7use crate::tensor::Tensor;
8use crate::tensor::backend::Backend;
9use alloc::vec;
10
11#[cfg(not(feature = "std"))]
12use num_traits::Float;
13
/// Configuration used to build a [`RotaryEncoding`] module.
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Number of positions for which rotation frequencies are precomputed.
    pub max_sequence_length: usize,

    /// Size of the last (hidden) dimension of the inputs. Initialization asserts it is even.
    pub d_model: usize,

    /// Base of the frequency schedule. Initialization asserts it is strictly positive.
    /// Defaults to `10000.0`.
    #[config(default = "10000.0")]
    pub theta: f32,
}
27
28impl RotaryEncodingConfig {
29 pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
36 self.initialize(|x| x, device)
37 }
38
39 pub fn init_with_frequency_scaling<B: Backend>(
47 &self,
48 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
49 device: &B::Device,
50 ) -> RotaryEncoding<B> {
51 self.initialize(scaling, device)
52 }
53
54 fn initialize<B: Backend>(
61 &self,
62 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
63 device: &B::Device,
64 ) -> RotaryEncoding<B> {
65 assert_eq!(
66 self.d_model % 2,
67 0,
68 "The input embedding dimension must be even"
69 );
70 assert!(
71 self.theta > 0.0,
72 "Theta parameter must be positive (default: 10000)."
73 );
74
75 let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
78 .float()
79 .div_scalar(self.d_model as f32);
80
81 let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
84
85 let theta = scaling(theta);
86
87 let freq_complex =
88 RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
89
90 RotaryEncoding {
91 freq_complex,
92 theta,
93 start_offset: 0,
94 }
95 }
96}
97
/// Module that applies rotary positional encoding using a precomputed rotation table.
///
/// Instances are created through [`RotaryEncodingConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Precomputed `(cos, sin)` values of shape `[max_sequence_length, d_model, 2]`.
    pub freq_complex: Tensor<B, 3>,
    /// Rotation frequencies (after any scaling), one per rotated pair (`d_model / 2` values).
    pub theta: Tensor<B, 1>,
    /// Absolute position represented by row 0 of `freq_complex`; updated by `shift`.
    start_offset: usize,
}
116
117impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
118 fn custom_settings(&self) -> Option<DisplaySettings> {
119 DisplaySettings::new()
120 .with_new_line_after_attribute(false)
121 .optional()
122 }
123
124 fn custom_content(&self, content: Content) -> Option<Content> {
125 let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
126 content
127 .add("d_model", &d_model)
128 .add("max_sequence_length", &max_sequence_length)
129 .optional()
130 }
131}
132
133#[allow(clippy::single_range_in_vec_init)]
134impl<B: Backend> RotaryEncoding<B> {
135 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
148 self.apply(x, 0)
149 }
150
151 pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
165 assert!(
166 D >= 2,
167 "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
168 );
169
170 let device = x.device();
171 let input_shape = x.shape();
172
173 let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
176 let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
177
178 let sign_tensor =
181 Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
182
183 let out: Tensor<B, 4> = x
185 .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
186 .matmul(sign_tensor.unsqueeze())
187 .reshape([dummy_dim_size, seq_len, d_model, 2])
188 * self
189 .freq_complex
190 .clone()
191 .slice([start..start + seq_len])
192 .unsqueeze();
193
194 out.sum_dim(D - 1).reshape(input_shape)
196 }
197
198 pub fn shift(&mut self, start: usize) {
203 let max_seq_len = self.freq_complex.dims()[0];
204 assert!(
205 start > self.start_offset,
206 "Shift start position must be monotonically increasing"
207 );
208
209 let current_end = self.start_offset + max_seq_len;
210
211 if start >= current_end {
212 let new_freqs =
214 Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
215 self.freq_complex
216 .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
217 } else {
218 let num_keep = current_end - start;
220 let start_rel = start - self.start_offset;
221 let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
222 self.freq_complex
223 .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
224 let new_freqs = Self::compute_rotary_frequencies(
226 current_end..start + max_seq_len,
227 self.theta.clone(),
228 );
229 self.freq_complex
230 .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
231 }
232 self.start_offset = start;
233 }
234
235 fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
244 let d_model = theta.dims()[0] * 2;
245 let num_positions = range.end - range.start;
246
247 let frequencies: Tensor<B, 2> =
249 Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
250 .float()
251 .unsqueeze()
252 .transpose()
253 .repeat_dim(1, d_model / 2)
254 * theta.unsqueeze();
255
256 let p_cos = frequencies.clone().cos();
258 let p_sin = frequencies.sin();
259
260 Tensor::cat(vec![p_cos, p_sin], 1)
261 .reshape([num_positions, 2, d_model / 2])
262 .transpose()
263 .unsqueeze_dim::<4>(2)
264 .repeat_dim(2, 2)
265 .reshape([num_positions, d_model, 2])
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;
    type TestDevice = <TestBackend as Backend>::Device;

    /// Shared `[1, 2, 2, 4]` input (values 1..=16) used by the forward and shift tests.
    fn sample_input(device: &TestDevice) -> Tensor<TestBackend, 4> {
        Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            device,
        )
        .unsqueeze::<4>()
    }

    /// Compares two tensors element-wise with the default tolerance.
    fn assert_close<const D: usize>(actual: Tensor<TestBackend, D>, expected: Tensor<TestBackend, D>) {
        let expected_data = expected.into_data();
        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected_data, Tolerance::default());
    }

    /// Expands `[2, 4, 2]` per-frequency `(cos, sin)` pairs into the `[2, 8, 2]` layout
    /// stored by the module, where every pair appears twice in a row.
    fn duplicate_each_pair(pairs: Tensor<TestBackend, 3>) -> Tensor<TestBackend, 3> {
        pairs
            .unsqueeze_dim::<4>(2)
            .repeat_dim(2, 2)
            .reshape([2, 8, 2])
    }

    /// Checks that shifting a 4-position module to `offset` and applying it from row 0
    /// matches a 10-position module applied at `offset`.
    fn assert_shift_matches_offset_apply(offset: usize) {
        let device = Default::default();
        let reference = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
        let input = sample_input(&device);

        let expected_output = reference.apply(input.clone(), offset);

        let mut shifted = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        shifted.shift(offset);
        let output = shifted.apply(input, 0);

        assert_close(output, expected_output);
    }

    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let module = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let output = module.forward(sample_input(&device));

        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        assert_close(output.squeeze::<3>(0), expected_output);
    }

    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let module = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let zeros_in = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);
        let output = module.forward(zeros_in);

        // Rotating an all-zero input must give an all-zero output.
        let expected_output = Tensor::<TestBackend, 3>::zeros([2, 2, 4], &device);

        assert_close(output.squeeze::<3>(0), expected_output);
    }

    #[test]
    #[should_panic]
    fn test_valid_input_hidden_dim() {
        // An odd hidden size is rejected.
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn test_rotary_encoding_frequencies() {
        let device = Default::default();
        let module = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

        let pairs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4147e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [9.9999e-01, 9.9999e-04],
                ],
            ],
            &device,
        );
        let expected_freqs = duplicate_each_pair(pairs);

        assert_close(module.freq_complex, expected_freqs);
    }

    /// Piecewise frequency scaling used to exercise `init_with_frequency_scaling`:
    /// short wavelengths are kept, long ones are divided by `scale_factor`, and the band
    /// in between is interpolated.
    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
        let scale_factor = 8.;
        let low_freq_factor = 1.;
        let high_freq_factor = 4.;
        let old_context_len = 8192.;

        let low_freq_wavelen = old_context_len / low_freq_factor;
        let high_freq_wavelen = old_context_len / high_freq_factor;

        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

        // Start from the interpolated value wherever wavelen >= high_freq_wavelen.
        let in_or_above_band = wavelen.clone().greater_equal_elem(high_freq_wavelen);
        let smooth = wavelen
            .clone()
            .recip()
            .mul_scalar(old_context_len)
            .sub_scalar(low_freq_factor)
            .div_scalar(high_freq_factor - low_freq_factor);
        // (1 - smooth) * freq / scale_factor + smooth * freq
        let one_minus_smooth = smooth.clone().neg().add_scalar(1.);
        let blended = one_minus_smooth
            .mul(freqs.clone().div_scalar(scale_factor))
            .add(smooth.clone().mul(freqs.clone()));
        let scaled = freqs.clone().mask_where(in_or_above_band, blended);

        // wavelen > low_freq_wavelen: plain division by the scale factor.
        let above_band = wavelen.clone().greater_elem(low_freq_wavelen);
        let scaled = scaled.mask_where(above_band, freqs.clone().div_scalar(scale_factor));

        // wavelen < high_freq_wavelen: frequency left unchanged.
        let below_band = wavelen.lower_elem(high_freq_wavelen);
        scaled.mask_where(below_band, freqs)
    }

    #[test]
    fn test_rotary_encoding_with_frequency_scaling() {
        let device = Default::default();
        let module = RotaryEncodingConfig::new(2, 8)
            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

        let pairs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4148e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [1.0000, 2.1361e-04],
                ],
            ],
            &device,
        );
        let expected_freqs = duplicate_each_pair(pairs);

        assert_close(module.freq_complex, expected_freqs);
    }

    #[test]
    fn test_rotary_encoding_shift_full() {
        // Offset 6 lies past the 4-position window, so the whole table is recomputed.
        assert_shift_matches_offset_apply(6);
    }

    #[test]
    fn test_rotary_encoding_shift() {
        // Offset 2 overlaps the 4-position window, so part of the table is reused.
        assert_shift_matches_offset_apply(2);
    }

    #[test]
    fn test_rotary_encoding_shift_multiple() {
        let device = Default::default();
        let mut module = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        for offset in [2, 5] {
            module.shift(offset);
        }
    }

    #[test]
    #[should_panic = "Shift start position must be monotonically increasing"]
    fn test_rotary_encoding_shift_should_increase() {
        let device = Default::default();
        let mut module = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        module.shift(6);
        // Going back to a smaller offset must panic.
        module.shift(4);
    }

    #[test]
    fn display() {
        let pe = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{pe}");

        assert_eq!(
            rendered,
            "RotaryEncoding {d_model: 4, max_sequence_length: 10}"
        );
    }
}