1use burn_core as burn;
2
3use alloc::vec;
4use burn::config::Config;
5use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
6use burn::tensor::Int;
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use core::ops::Range;
10
11#[cfg(not(feature = "std"))]
12#[allow(unused_imports)]
13use num_traits::Float as _;
14
/// Configuration to create a [`RotaryEncoding`] module.
///
/// Use [`RotaryEncodingConfig::init`] for the standard encoding, or
/// [`RotaryEncodingConfig::init_with_frequency_scaling`] to post-process the base
/// frequencies before the rotation table is built.
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Number of positions for which rotations are precomputed at initialization
    /// (positions `0..max_sequence_length`).
    pub max_sequence_length: usize,

    /// Size of the hidden/embedding dimension the encoding is applied to. Must be even,
    /// because features are rotated in consecutive pairs.
    pub d_model: usize,

    /// Base of the geometric frequency progression: pair `i` rotates with frequency
    /// `1 / theta^(2i / d_model)`. Must be strictly positive.
    #[config(default = "10000.0")]
    pub theta: f32,
}
28
29impl RotaryEncodingConfig {
30 pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
37 self.initialize(|x| x, device)
38 }
39
40 pub fn init_with_frequency_scaling<B: Backend>(
48 &self,
49 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
50 device: &B::Device,
51 ) -> RotaryEncoding<B> {
52 self.initialize(scaling, device)
53 }
54
55 fn initialize<B: Backend>(
62 &self,
63 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
64 device: &B::Device,
65 ) -> RotaryEncoding<B> {
66 assert_eq!(
67 self.d_model % 2,
68 0,
69 "The input embedding dimension must be even"
70 );
71 assert!(
72 self.theta > 0.0,
73 "Theta parameter must be positive (default: 10000)."
74 );
75
76 let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
79 .float()
80 .div_scalar(self.d_model as f32);
81
82 let theta = exponent.mul_scalar(self.theta.ln()).exp().recip();
85
86 let theta = scaling(theta);
87
88 let freq_complex =
89 RotaryEncoding::compute_rotary_frequencies(0..self.max_sequence_length, theta.clone());
90
91 RotaryEncoding {
92 freq_complex,
93 theta,
94 start_offset: 0,
95 }
96 }
97}
98
/// Rotary positional encoding (RoPE) module.
///
/// Consecutive feature pairs of the input are rotated by an angle proportional to their
/// position in the sequence. Rotations are precomputed for a window of positions and can be
/// slid forward with [`RotaryEncoding::shift`].
///
/// Create it with [`RotaryEncodingConfig::init`] or
/// [`RotaryEncodingConfig::init_with_frequency_scaling`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Precomputed rotations of shape `(max_sequence_length, d_model, 2)`.
    /// The last axis holds `[cos, sin]` of `position * theta[i]`, duplicated for both
    /// features of pair `i`. Row `k` corresponds to absolute position `start_offset + k`.
    pub freq_complex: Tensor<B, 3>,
    /// Per-pair rotation frequencies (length `d_model / 2`), after any user scaling.
    pub theta: Tensor<B, 1>,
    // Absolute position represented by row 0 of `freq_complex`; updated by `shift`.
    start_offset: usize,
}
117
118impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
119 fn custom_settings(&self) -> Option<DisplaySettings> {
120 DisplaySettings::new()
121 .with_new_line_after_attribute(false)
122 .optional()
123 }
124
125 fn custom_content(&self, content: Content) -> Option<Content> {
126 let [max_sequence_length, d_model, _] = self.freq_complex.shape().dims();
127 content
128 .add("d_model", &d_model)
129 .add("max_sequence_length", &max_sequence_length)
130 .optional()
131 }
132}
133
134#[allow(clippy::single_range_in_vec_init)]
135impl<B: Backend> RotaryEncoding<B> {
136 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
149 self.apply(x, 0)
150 }
151
152 pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
166 assert!(
167 D >= 2,
168 "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
169 );
170
171 let device = x.device();
172 let input_shape = x.shape();
173
174 let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
177 let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
178
179 let sign_tensor =
182 Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
183
184 let out: Tensor<B, 4> = x
186 .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
187 .matmul(sign_tensor.unsqueeze())
188 .reshape([dummy_dim_size, seq_len, d_model, 2])
189 * self
190 .freq_complex
191 .clone()
192 .slice([start..start + seq_len])
193 .unsqueeze();
194
195 out.sum_dim(-1).reshape(input_shape)
197 }
198
199 pub fn shift(&mut self, start: usize) {
204 let max_seq_len = self.freq_complex.dims()[0];
205 assert!(
206 start > self.start_offset,
207 "Shift start position must be monotonically increasing"
208 );
209
210 let current_end = self.start_offset + max_seq_len;
211
212 if start >= current_end {
213 let new_freqs =
215 Self::compute_rotary_frequencies(start..start + max_seq_len, self.theta.clone());
216 self.freq_complex
217 .inplace(|freqs| freqs.slice_assign([0..max_seq_len], new_freqs));
218 } else {
219 let num_keep = current_end - start;
221 let start_rel = start - self.start_offset;
222 let tail_freqs = self.freq_complex.clone().slice([start_rel..max_seq_len]);
223 self.freq_complex
224 .inplace(|freqs| freqs.slice_assign([0..num_keep], tail_freqs));
225 let new_freqs = Self::compute_rotary_frequencies(
227 current_end..start + max_seq_len,
228 self.theta.clone(),
229 );
230 self.freq_complex
231 .inplace(|freqs| freqs.slice_assign([num_keep..max_seq_len], new_freqs));
232 }
233 self.start_offset = start;
234 }
235
236 fn compute_rotary_frequencies(range: Range<usize>, theta: Tensor<B, 1>) -> Tensor<B, 3> {
245 let d_model = theta.dims()[0] * 2;
246 let num_positions = range.end - range.start;
247
248 let frequencies: Tensor<B, 2> =
250 Tensor::<B, 1, Int>::arange(range.start as i64..range.end as i64, &theta.device())
251 .float()
252 .unsqueeze()
253 .transpose()
254 .repeat_dim(1, d_model / 2)
255 * theta.unsqueeze();
256
257 let p_cos = frequencies.clone().cos();
259 let p_sin = frequencies.sin();
260
261 Tensor::cat(vec![p_cos, p_sin], 1)
262 .reshape([num_positions, 2, d_model / 2])
263 .transpose()
264 .unsqueeze_dim::<4>(2)
265 .repeat_dim(2, 2)
266 .reshape([num_positions, d_model, 2])
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Input shared by several tests: 2 sequences of 2 positions with `d_model = 4`.
    fn sample_input(device: &<TestBackend as Backend>::Device) -> Tensor<TestBackend, 3> {
        Tensor::<TestBackend, 3>::from_floats(
            [
                [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
                [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
            ],
            device,
        )
    }

    #[test]
    fn test_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Add a leading unit dimension to check that extra leading dims are flattened.
        let input = sample_input(&device).unsqueeze::<4>();

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .squeeze_dim::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_3d() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = sample_input(&device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 2.0000, 3.0000, 4.0000],
                    [-2.3473, 7.4492, 6.9197, 8.0696],
                ],
                [
                    [9.0000, 10.0000, 11.0000, 12.0000],
                    [-4.7567, 18.5034, 14.8393, 16.1492],
                ],
            ],
            &device,
        );

        output
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    #[test]
    fn test_zero_input_rotary_encoding_forward() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        // Rotating the zero vector must give the zero vector.
        let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);

        let output = rotary_encoding.forward(input);
        let expected_output = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
                [
                    [0.0000, 0.0000, 0.0000, 0.0000],
                    [0.0000, 0.0000, 0.0000, 0.0000],
                ],
            ],
            &device,
        );

        output
            .squeeze_dim::<3>(0)
            .to_data()
            .assert_approx_eq::<FT>(&expected_output.to_data(), Tolerance::default());
    }

    #[test]
    #[should_panic(expected = "The input embedding dimension must be even")]
    fn test_valid_input_hidden_dim() {
        // An odd hidden dimension cannot be split into feature pairs.
        let d_model = 15;
        let device = Default::default();
        let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
        let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
        let _output = pe.forward(input);
    }

    #[test]
    fn test_rotary_encoding_frequencies() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);

        // [cos, sin] per pair for positions 0 and 1, duplicated for both features of a pair.
        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4147e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [9.9999e-01, 9.9999e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    /// Frequency scaling "by parts": high frequencies are kept, low frequencies are divided
    /// by `scale_factor`, and frequencies in between are smoothly interpolated between the two.
    fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
        let scale_factor = 8.;
        let low_freq_factor = 1.;
        let high_freq_factor = 4.;
        let old_context_len = 8192.;

        let low_freq_wavelen = old_context_len / low_freq_factor;
        let high_freq_wavelen = old_context_len / high_freq_factor;

        let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);

        // Medium band: interpolate between the scaled and the original frequency.
        let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
        let smooth = wavelen
            .clone()
            .recip()
            .mul_scalar(old_context_len)
            .sub_scalar(low_freq_factor)
            .div_scalar(high_freq_factor - low_freq_factor);
        let new_freqs = smooth
            .clone()
            .neg()
            .add_scalar(1.)
            .mul(freqs.clone().div_scalar(scale_factor))
            .add(smooth.clone().mul(freqs.clone()));
        let new_freqs = freqs.clone().mask_where(cond, new_freqs);

        // Low band: fully scaled.
        let cond = wavelen.clone().greater_elem(low_freq_wavelen);
        let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));

        // High band: unchanged.
        let cond = wavelen.lower_elem(high_freq_wavelen);
        new_freqs.mask_where(cond, freqs)
    }

    #[test]
    fn test_rotary_encoding_with_frequency_scaling() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(2, 8)
            .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);

        let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
            [
                [
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                    [1.0000, 0.0000],
                ],
                [
                    [5.4030e-01, 8.4148e-01],
                    [9.9500e-01, 9.9833e-02],
                    [9.9995e-01, 9.9998e-03],
                    [1.0000, 2.1361e-04],
                ],
            ],
            &device,
        )
        .unsqueeze_dim::<4>(2)
        .repeat_dim(2, 2)
        .reshape([2, 8, 2]);

        rotary_encoding
            .freq_complex
            .to_data()
            .assert_approx_eq::<FT>(&expected_freqs.to_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift_full() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = sample_input(&device).unsqueeze::<4>();

        // Reference: positions 6..8 taken directly from a larger cache.
        let expected_output = rotary_encoding.apply(input.clone(), 6);

        // Shift past the whole cached window (no overlap, every row recomputed).
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6);
        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift() {
        let device = Default::default();
        let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);

        let input = sample_input(&device).unsqueeze::<4>();

        // Reference: positions 2..4 taken directly from a larger cache.
        let expected_output = rotary_encoding.apply(input.clone(), 2);

        // Shift within the cached window (overlapping rows are reused).
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2);
        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    #[test]
    fn test_rotary_encoding_shift_multiple() {
        let device = Default::default();
        let input = sample_input(&device).unsqueeze::<4>();

        // Reference: positions 5..7 taken directly from a larger cache.
        let reference = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
        let expected_output = reference.apply(input.clone(), 5);

        // Window 0..4 -> 2..6 -> 5..9. The second shift keeps one row and computes three.
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(2);
        rotary_encoding.shift(5);
        let output = rotary_encoding.apply(input, 0);

        output
            .into_data()
            .assert_approx_eq::<FT>(&expected_output.into_data(), Tolerance::default());
    }

    #[test]
    #[should_panic = "Shift start position must be monotonically increasing"]
    fn test_rotary_encoding_shift_should_increase() {
        let device = Default::default();
        let mut rotary_encoding = RotaryEncodingConfig::new(4, 4).init::<TestBackend>(&device);
        rotary_encoding.shift(6);
        rotary_encoding.shift(4);
    }

    #[test]
    fn display() {
        let config = RotaryEncodingConfig::new(10, 4);
        let pe = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pe}"),
            "RotaryEncoding {d_model: 4, max_sequence_length: 10}"
        );
    }
}