use burn::prelude::*;
/// Precomputed rotary position embedding lookup tables.
///
/// `new_yarn` fills `cos` and `sin` with one row per position and one column
/// per rotated pair, i.e. shape `[max_seq_len, head_dim / 2]`.
#[derive(Debug)]
pub struct RotaryEmbedding<B: Backend> {
/// Cosine table. As built by `new_yarn`, every entry has already been
/// multiplied by `attention_scaling`.
pub cos: Tensor<B, 2>,
/// Sine table. As built by `new_yarn`, every entry has already been
/// multiplied by `attention_scaling`.
pub sin: Tensor<B, 2>,
/// Scale factor chosen by the constructor (`1.0` when `factor <= 1.0`,
/// otherwise `0.1 * ln(factor) + 1.0` in `new_yarn`). It is kept here for
/// reference; `new_yarn` has already folded it into `cos` and `sin`.
pub attention_scaling: f64,
}
impl<B: Backend> RotaryEmbedding<B> {
    /// Builds the cos/sin tables using YaRN-style frequency scaling.
    ///
    /// Each rotary frequency is a blend of the unscaled ("extrapolation")
    /// frequency and the frequency divided by `factor` ("interpolation"). The
    /// blend weight is a linear ramp over the pair index, clamped to `[0, 1]`.
    ///
    /// # Parameters
    /// - `head_dim`: per-head feature size; `head_dim / 2` frequencies are produced.
    /// - `max_seq_len`: number of positions (rows) to precompute.
    /// - `rope_theta`: base of the geometric frequency progression.
    /// - `factor`: context extension factor. Values `<= 1.0` leave the
    ///   attention scaling at `1.0`; larger values use `0.1 * ln(factor) + 1.0`.
    /// - `beta_fast`, `beta_slow`: rotation counts that determine the lower and
    ///   upper bound of the ramp, respectively.
    /// - `original_max_pos`: context length used in the ramp-bound formula.
    /// - `truncate`: when `true`, the lower bound is floored and the upper
    ///   bound is ceiled before clamping.
    /// - `device`: device on which the resulting tensors are created.
    ///
    /// # Returns
    /// A `RotaryEmbedding` whose `cos` and `sin` tensors have shape
    /// `[max_seq_len, head_dim / 2]` and are pre-multiplied by the attention
    /// scaling factor.
    pub fn new_yarn(
        head_dim: usize,
        max_seq_len: usize,
        rope_theta: f64,
        factor: f64,
        beta_fast: f64,
        beta_slow: f64,
        original_max_pos: usize,
        truncate: bool,
        device: &B::Device,
    ) -> Self {
        let half_dim = head_dim / 2;

        let attention_scaling = if factor <= 1.0 {
            1.0
        } else {
            0.1 * factor.ln() + 1.0
        };

        let inv_freq = Self::yarn_inv_freq(
            head_dim,
            rope_theta,
            factor,
            beta_fast,
            beta_slow,
            original_max_pos,
            truncate,
        );

        // Row-major tables: the entry for (pos, i) lives at `pos * half_dim + i`.
        let scale = attention_scaling as f32;
        let table_len = max_seq_len * half_dim;
        let mut cos_data: Vec<f32> = Vec::with_capacity(table_len);
        let mut sin_data: Vec<f32> = Vec::with_capacity(table_len);
        for pos in 0..max_seq_len {
            let pos = pos as f32;
            for &freq in &inv_freq {
                let angle = pos * freq;
                cos_data.push(angle.cos() * scale);
                sin_data.push(angle.sin() * scale);
            }
        }

        let shape = [max_seq_len, half_dim];
        let cos = Tensor::<B, 2>::from_data(TensorData::new(cos_data, shape), device);
        let sin = Tensor::<B, 2>::from_data(TensorData::new(sin_data, shape), device);

        Self {
            cos,
            sin,
            attention_scaling,
        }
    }

    /// Computes the `dim / 2` blended inverse frequencies used by `new_yarn`.
    ///
    /// The arithmetic is done in `f64` up to the point where each frequency is
    /// narrowed to `f32`; the blend itself is carried out in `f32`.
    fn yarn_inv_freq(
        dim: usize,
        rope_theta: f64,
        factor: f64,
        beta_fast: f64,
        beta_slow: f64,
        original_max_pos: usize,
        truncate: bool,
    ) -> Vec<f32> {
        // Maps a rotation count to a (fractional) index used as a ramp bound.
        let find_correction_dim = |num_rotations: f64| -> f64 {
            (dim as f64
                * (original_max_pos as f64 / (num_rotations * 2.0 * std::f64::consts::PI)).ln())
                / (2.0 * rope_theta.ln())
        };

        let (low, high) = {
            let low = find_correction_dim(beta_fast);
            let high = find_correction_dim(beta_slow);
            if truncate {
                (low.floor(), high.ceil())
            } else {
                (low, high)
            }
        };
        let low = low.max(0.0);
        // `saturating_sub` avoids the usize underflow that `dim - 1` hits when
        // `dim == 0` (debug panic / release wrap-around); for `dim >= 1` the
        // value is identical.
        let high = high.min(dim.saturating_sub(1) as f64);

        // Keep the ramp denominator non-zero when both bounds coincide.
        let ramp_end = if (high - low).abs() < 1e-9 {
            high + 0.001
        } else {
            high
        };

        (0..dim / 2)
            .map(|i| {
                let freq = rope_theta.powf(2.0 * (i as f64) / (dim as f64));
                let extrapolation = 1.0_f32 / (freq as f32);
                let interpolation = 1.0_f32 / ((factor * freq) as f32);
                // Ramp of 0 keeps the unscaled frequency, 1 the scaled-down one.
                let ramp = ((i as f64 - low) / (ramp_end - low)).clamp(0.0, 1.0) as f32;
                interpolation * ramp + extrapolation * (1.0 - ramp)
            })
            .collect()
    }

    /// Returns the first `seq_len` rows of the cos and sin tables, each with a
    /// leading axis of size 1 inserted (rank 3).
    ///
    /// NOTE(review): what happens when `seq_len` exceeds the number of cached
    /// rows is determined by the tensor library's `slice` — confirm whether it
    /// panics or clamps for the burn version in use.
    pub fn get(&self, seq_len: usize) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let window = |table: &Tensor<B, 2>| -> Tensor<B, 3> {
            table.clone().slice([0..seq_len]).unsqueeze_dim::<3>(0)
        };
        (window(&self.cos), window(&self.sin))
    }
}
/// Applies the same rotary position embedding to a query and a key tensor.
///
/// `q` and `k` are each rotated independently using the shared `cos` and
/// `sin` tables; the rotated tensors are returned as `(q, k)` in that order.
pub fn apply_rotary_emb<B: Backend>(
    q: Tensor<B, 4>,
    k: Tensor<B, 4>,
    cos: &Tensor<B, 3>,
    sin: &Tensor<B, 3>,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    (
        apply_rotary_emb_single(q, cos, sin),
        apply_rotary_emb_single(k, cos, sin),
    )
}
/// Rotates one rank-4 tensor by the given cos/sin tables.
///
/// The last axis of `x` is viewed as `head_dim / 2` adjacent pairs
/// `(x[2i], x[2i + 1])`. Each pair `(a, b)` becomes
/// `(a * cos - b * sin, b * cos + a * sin)`, and the pairs are written back
/// in the same interleaved layout, so the output has the same shape as `x`.
///
/// `cos` and `sin` get a singleton axis inserted at dimension 1 and must then
/// broadcast against `[d0, d1, d2, head_dim / 2]` — for example the
/// `[1, seq_len, head_dim / 2]` tensors returned by `RotaryEmbedding::get`.
///
/// NOTE(review): this is the interleaved-pair convention, not the
/// split-in-halves one — confirm it matches the layout of the weights in use.
fn apply_rotary_emb_single<B: Backend>(
    x: Tensor<B, 4>,
    cos: &Tensor<B, 3>,
    sin: &Tensor<B, 3>,
) -> Tensor<B, 4> {
    let [batch, heads, seq_len, head_dim] = x.dims();
    let pairs = head_dim / 2;
    let lane_shape = [batch, heads, seq_len, pairs];

    // Insert a singleton axis at dimension 1 so the tables broadcast over it.
    let cos_b = cos.clone().unsqueeze_dim::<4>(1);
    let sin_b = sin.clone().unsqueeze_dim::<4>(1);

    // Split the trailing axis into (pair index, element-within-pair).
    let grouped = x.reshape([batch, heads, seq_len, pairs, 2]);
    let even = grouped
        .clone()
        .slice([0..batch, 0..heads, 0..seq_len, 0..pairs, 0..1])
        .reshape(lane_shape);
    let odd = grouped
        .slice([0..batch, 0..heads, 0..seq_len, 0..pairs, 1..2])
        .reshape(lane_shape);

    // 2-D rotation of every (even, odd) pair.
    let even_out = even.clone() * cos_b.clone() - odd.clone() * sin_b.clone();
    let odd_out = odd * cos_b + even * sin_b;

    // Re-interleave: stack the two lanes on a new trailing axis, then flatten.
    let restacked = Tensor::cat(
        vec![
            even_out.unsqueeze_dim::<5>(4),
            odd_out.unsqueeze_dim::<5>(4),
        ],
        4,
    );
    restacked.reshape([batch, heads, seq_len, head_dim])
}