1use crate as burn;
2use crate::config::Config;
3use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
4use crate::tensor::backend::Backend;
5use crate::tensor::Int;
6use crate::tensor::Tensor;
7use alloc::vec;
8
9#[cfg(not(feature = "std"))]
10use num_traits::Float;
11
/// Configuration to create a [`RotaryEncoding`] layer using [`RotaryEncodingConfig::init`]
/// or [`RotaryEncodingConfig::init_with_frequency_scaling`].
#[derive(Config, Debug)]
pub struct RotaryEncodingConfig {
    /// Number of positions for which rotation factors are precomputed.
    pub max_sequence_length: usize,

    /// Size of the embedding (hidden) dimension. Initialization asserts that it is even.
    pub d_model: usize,

    /// Base of the frequency schedule. Initialization asserts that it is positive.
    /// Defaults to 10000.0.
    #[config(default = "10000.0")]
    pub theta: f32,
}
25
26impl RotaryEncodingConfig {
27 pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
34 self.initialize(|x| x, device)
35 }
36
37 pub fn init_with_frequency_scaling<B: Backend>(
45 &self,
46 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
47 device: &B::Device,
48 ) -> RotaryEncoding<B> {
49 self.initialize(scaling, device)
50 }
51
52 fn initialize<B: Backend>(
59 &self,
60 scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
61 device: &B::Device,
62 ) -> RotaryEncoding<B> {
63 assert_eq!(
64 self.d_model % 2,
65 0,
66 "The input embedding dimension must be even"
67 );
68 assert!(
69 self.theta > 0.0,
70 "Theta parameter must be positive (default: 10000)."
71 );
72
73 let exponent = Tensor::<B, 1, Int>::arange_step(0..self.d_model as i64, 2, device)
76 .float()
77 .div_scalar(self.d_model as f32);
78
79 let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
82 let theta_i = theta_i.powf_scalar(-1.0);
83
84 let theta_i = scaling(theta_i);
85
86 let frequencies: Tensor<B, 2> =
88 Tensor::<B, 1, Int>::arange(0..self.max_sequence_length as i64, device)
89 .float()
90 .unsqueeze()
91 .transpose()
92 .repeat_dim(1, self.d_model / 2)
93 * theta_i.unsqueeze();
94
95 let p_cos = frequencies.clone().cos();
97 let p_sin = frequencies.sin();
98
99 let freq_complex: Tensor<B, 3> = Tensor::cat(vec![p_cos, p_sin], 1)
102 .reshape([self.max_sequence_length, 2, self.d_model / 2])
103 .transpose()
104 .unsqueeze_dim::<4>(2)
105 .repeat_dim(2, 2)
106 .reshape([self.max_sequence_length, self.d_model, 2]);
107
108 RotaryEncoding {
109 freq_complex,
110 max_sequence_length: self.max_sequence_length,
111 theta: self.theta,
112 }
113 }
114}
115
/// Rotary position encoding layer holding precomputed cosine/sine rotation factors.
///
/// Should be created using [`RotaryEncodingConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RotaryEncoding<B: Backend> {
    /// Precomputed rotation factors of shape `[max_sequence_length, d_model, 2]`;
    /// the last axis holds the (cos, sin) pair for each position and feature.
    pub freq_complex: Tensor<B, 3>,
    /// Maximum sequence length the layer was initialized with.
    pub max_sequence_length: usize,
    /// Frequency base `theta` the layer was initialized with.
    pub theta: f32,
}
134
135impl<B: Backend> ModuleDisplay for RotaryEncoding<B> {
136 fn custom_settings(&self) -> Option<DisplaySettings> {
137 DisplaySettings::new()
138 .with_new_line_after_attribute(false)
139 .optional()
140 }
141
142 fn custom_content(&self, content: Content) -> Option<Content> {
143 let [_, _, d_model] = self.freq_complex.shape().dims();
144 content
145 .add("d_model", &d_model)
146 .add("max_sequence_length", &self.max_sequence_length)
147 .add("theta", &self.theta)
148 .optional()
149 }
150}
151
152#[allow(clippy::single_range_in_vec_init)]
153impl<B: Backend> RotaryEncoding<B> {
154 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
166 self.apply(x, 0)
167 }
168
169 pub fn apply<const D: usize>(&self, x: Tensor<B, D>, start: usize) -> Tensor<B, D> {
182 assert!(
183 D >= 2,
184 "Input tensor must have at least 2 dimensions for sequence length and hidden dimension"
185 );
186
187 let device = x.device();
188 let input_shape = x.shape();
189
190 let (seq_len, d_model) = (x.dims()[D - 2], x.dims()[D - 1]);
193 let dummy_dim_size = input_shape.num_elements() / (seq_len * d_model);
194
195 let sign_tensor =
198 Tensor::<B, 2>::from_floats([[1.0, 0.0, 0.0, 1.0], [0.0, -1.0, 1.0, 0.0]], &device);
199
200 let out: Tensor<B, 4> = x
202 .reshape([dummy_dim_size, seq_len, d_model / 2, 2])
203 .matmul(sign_tensor.unsqueeze())
204 .reshape([dummy_dim_size, seq_len, d_model, 2])
205 * self
206 .freq_complex
207 .clone()
208 .slice([start..start + seq_len])
209 .unsqueeze();
210
211 out.sum_dim(D - 1).reshape(input_shape)
213 }
214}
215
216#[cfg(test)]
217mod tests {
218 use super::*;
219 use crate::TestBackend;
220
221 #[test]
222 fn test_rotary_encoding_forward() {
223 let device = Default::default();
224 let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
225
226 let input = Tensor::<TestBackend, 3>::from_floats(
227 [
228 [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
229 [[9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]],
230 ],
231 &device,
232 );
233
234 let input = input.unsqueeze::<4>();
236
237 let output = rotary_encoding.forward(input);
238 let expected_output = Tensor::<TestBackend, 3>::from_floats(
239 [
240 [
241 [1.0000, 2.0000, 3.0000, 4.0000],
242 [-2.3473, 7.4492, 6.9197, 8.0696],
243 ],
244 [
245 [9.0000, 10.0000, 11.0000, 12.0000],
246 [-4.7567, 18.5034, 14.8393, 16.1492],
247 ],
248 ],
249 &device,
250 );
251
252 output
253 .squeeze::<3>(0)
254 .to_data()
255 .assert_approx_eq(&expected_output.to_data(), 4);
256 }
257
258 #[test]
259 fn test_zero_input_rotary_encoding_forward() {
260 let device = Default::default();
261 let rotary_encoding = RotaryEncodingConfig::new(10, 4).init::<TestBackend>(&device);
262
263 let input = Tensor::<TestBackend, 4>::zeros([1, 2, 2, 4], &device);
265
266 let output = rotary_encoding.forward(input);
267 let expected_output = Tensor::<TestBackend, 3>::from_floats(
268 [
269 [
270 [0.0000, 0.0000, 0.0000, 0.0000],
271 [0.0000, 0.0000, 0.0000, 0.0000],
272 ],
273 [
274 [0.0000, 0.0000, 0.0000, 0.0000],
275 [0.0000, 0.0000, 0.0000, 0.0000],
276 ],
277 ],
278 &device,
279 );
280
281 output
282 .squeeze::<3>(0)
283 .to_data()
284 .assert_approx_eq(&expected_output.to_data(), 4);
285 }
286
287 #[test]
288 #[should_panic]
289 fn test_valid_input_hidden_dim() {
290 let d_model = 15;
293 let device = Default::default();
294 let pe = RotaryEncodingConfig::new(10, d_model).init::<TestBackend>(&device);
295 let input = Tensor::<TestBackend, 3>::zeros([1, 5, d_model], &device);
296 let _output = pe.forward(input);
297 }
298
299 #[test]
300 fn test_rotary_encoding_frequencies() {
301 let device = Default::default();
302 let rotary_encoding = RotaryEncodingConfig::new(2, 8).init::<TestBackend>(&device);
303
304 let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
305 [
306 [
307 [1.0000, 0.0000],
308 [1.0000, 0.0000],
309 [1.0000, 0.0000],
310 [1.0000, 0.0000],
311 ],
312 [
313 [5.4030e-01, 8.4147e-01],
314 [9.9500e-01, 9.9833e-02],
315 [9.9995e-01, 9.9998e-03],
316 [9.9999e-01, 9.9999e-04],
317 ],
318 ],
319 &device,
320 )
321 .unsqueeze_dim::<4>(2)
322 .repeat_dim(2, 2)
323 .reshape([2, 8, 2]);
324
325 rotary_encoding
326 .freq_complex
327 .to_data()
328 .assert_approx_eq(&expected_freqs.to_data(), 4);
329 }
330
331 fn apply_freq_scaling_by_parts<B: Backend>(freqs: Tensor<B, 1>) -> Tensor<B, 1> {
332 let scale_factor = 8.;
334 let low_freq_factor = 1.;
335 let high_freq_factor = 4.;
336 let old_context_len = 8192.;
337
338 let low_freq_wavelen = old_context_len / low_freq_factor;
339 let high_freq_wavelen = old_context_len / high_freq_factor;
340
341 let wavelen = freqs.clone().recip().mul_scalar(2. * core::f32::consts::PI);
342
343 let cond = wavelen.clone().greater_equal_elem(high_freq_wavelen);
345 let smooth = wavelen
346 .clone()
347 .recip()
348 .mul_scalar(old_context_len)
349 .sub_scalar(low_freq_factor)
350 .div_scalar(high_freq_factor - low_freq_factor);
351 let new_freqs = smooth
353 .clone()
354 .neg()
355 .add_scalar(1.)
356 .mul(freqs.clone().div_scalar(scale_factor))
357 .add(smooth.clone().mul(freqs.clone()));
358 let new_freqs = freqs.clone().mask_where(cond, new_freqs);
359
360 let cond = wavelen.clone().greater_elem(low_freq_wavelen);
362 let new_freqs = new_freqs.mask_where(cond, freqs.clone().div_scalar(scale_factor));
363
364 let cond = wavelen.lower_elem(high_freq_wavelen);
366 new_freqs.mask_where(cond, freqs)
367 }
368
369 #[test]
370 fn test_rotary_encoding_with_frequency_scaling() {
371 let device = Default::default();
372 let rotary_encoding = RotaryEncodingConfig::new(2, 8)
373 .init_with_frequency_scaling::<TestBackend>(apply_freq_scaling_by_parts, &device);
374
375 let expected_freqs = Tensor::<TestBackend, 3>::from_floats(
376 [
377 [
378 [1.0000, 0.0000],
379 [1.0000, 0.0000],
380 [1.0000, 0.0000],
381 [1.0000, 0.0000],
382 ],
383 [
384 [5.4030e-01, 8.4148e-01],
385 [9.9500e-01, 9.9833e-02],
386 [9.9995e-01, 9.9998e-03],
387 [1.0000, 2.1361e-04],
388 ],
389 ],
390 &device,
391 )
392 .unsqueeze_dim::<4>(2)
393 .repeat_dim(2, 2)
394 .reshape([2, 8, 2]);
395
396 rotary_encoding
397 .freq_complex
398 .to_data()
399 .assert_approx_eq(&expected_freqs.to_data(), 4);
400 }
401
402 #[test]
403 fn display() {
404 let config = RotaryEncodingConfig::new(10, 4);
405 let pe = config.init::<TestBackend>(&Default::default());
406
407 assert_eq!(
408 alloc::format!("{}", pe),
409 "RotaryEncoding {d_model: 2, max_sequence_length: 10, theta: 10000}"
410 );
411 }
412}