use burn::prelude::*;
/// Precomputed lookup table for rotary position embeddings.
///
/// The table is filled once by `new` and then indexed by token position via
/// `build_freqs_4d`.
pub struct RotaryEmbedding<B: Backend> {
/// Rotation table. `new` builds it with shape `[max_seqlen, half, 2, 2]`,
/// where `half = (head_dim / rope_dim) / 2` and each trailing 2x2 block is
/// `[[cos, -sin], [sin, cos]]` for one (position, frequency) pair.
pub freqs_cis: Tensor<B, 4>,
/// Number of positions covered by the first axis of `freqs_cis`.
pub max_seqlen: usize,
/// Number of index columns read from `tok_idx` in `build_freqs_4d`; `new`
/// splits `head_dim` evenly across this many groups.
pub rope_dim: usize,
/// Head dimension passed to `new`; must be divisible by `rope_dim`.
pub head_dim: usize,
}
impl<B: Backend> RotaryEmbedding<B> {
    /// Builds the rotation table for `max_seqlen` positions.
    ///
    /// `head_dim` is split evenly across `rope_dim` groups; each group gets
    /// `half = (head_dim / rope_dim) / 2` frequencies
    /// `1 / theta^(2h / (head_dim / rope_dim))` for `h in 0..half`. For every
    /// (position, frequency) pair the table stores the 2x2 block
    /// `[[cos, -sin], [sin, cos]]`, giving `freqs_cis` the shape
    /// `[max_seqlen, half, 2, 2]`.
    ///
    /// Note that `half` uses integer division, so an odd `head_dim / rope_dim`
    /// silently drops its last channel.
    ///
    /// # Panics
    ///
    /// Panics if `rope_dim` is zero or does not divide `head_dim`.
    pub fn new(
        head_dim: usize,
        rope_dim: usize,
        max_seqlen: usize,
        theta: f64,
        device: &B::Device,
    ) -> Self {
        assert!(rope_dim > 0, "rope_dim must be non-zero");
        assert_eq!(
            head_dim % rope_dim,
            0,
            "head_dim ({}) must be divisible by rope_dim ({})",
            head_dim,
            rope_dim
        );
        let dim_per_rope = head_dim / rope_dim;
        let half = dim_per_rope / 2;

        // The inverse frequencies depend only on the pair index, so compute
        // them once instead of once per position.
        let inv_freq: Vec<f32> = (0..half)
            .map(|h| {
                let exponent = (2 * h) as f64 / dim_per_rope as f64;
                1.0 / (theta.powf(exponent) as f32)
            })
            .collect();

        // Row-major layout: position, then frequency, then the 2x2 block.
        let mut table: Vec<f32> = Vec::with_capacity(max_seqlen * half * 4);
        for pos in 0..max_seqlen {
            let position = pos as f32;
            for &freq in &inv_freq {
                let (sin, cos) = (position * freq).sin_cos();
                table.extend_from_slice(&[cos, -sin, sin, cos]);
            }
        }

        let freqs_cis = Tensor::<B, 1>::from_data(
            TensorData::new(table, vec![max_seqlen * half * 4]),
            device,
        )
        .reshape([max_seqlen, half, 2, 2]);

        Self {
            freqs_cis,
            max_seqlen,
            rope_dim,
            head_dim,
        }
    }

    /// Selects the table rows (first axis of `freqs_cis`) named by
    /// `tok_idx_1d`, one row per index.
    fn gather_axis(&self, tok_idx_1d: Tensor<B, 1, Int>) -> Tensor<B, 4> {
        self.freqs_cis.clone().select(0, tok_idx_1d)
    }

    /// Gathers rotation blocks for every token and concatenates them across
    /// the `rope_dim` index columns.
    ///
    /// `tok_idx` is read as `[tokens, columns]`; columns `0..rope_dim` are
    /// each used to look up rows of the table, and the per-column results are
    /// concatenated along axis 1. The result therefore has shape
    /// `[tokens, rope_dim * half, 2, 2]`.
    pub fn build_freqs_4d(&self, tok_idx: Tensor<B, 2, Int>) -> Tensor<B, 4> {
        let s = tok_idx.dims()[0];
        let parts: Vec<Tensor<B, 4>> = (0..self.rope_dim)
            .map(|axis| {
                let col = tok_idx.clone().narrow(1, axis, 1).reshape([s]);
                self.gather_axis(col)
            })
            .collect();
        Tensor::cat(parts, 1)
    }
}
/// Applies rotary position embeddings to a query/key pair.
///
/// `xq` and `xk` are read as `[batch, seq, heads, head_dim]`, and adjacent
/// elements `(2i, 2i + 1)` of the last axis are rotated together. From each
/// trailing 2x2 block of `freqs`, entry `[0][0]` is used as the cosine and
/// entry `[1][0]` as the sine; after that selection `freqs` must hold exactly
/// `seq * head_dim / 2` values.
///
/// The key tensor is reshaped using its own batch and head counts, so it may
/// carry a different number of heads than the query (e.g. grouped-query or
/// multi-query attention).
///
/// # Panics
///
/// Panics if `xq` and `xk` disagree on sequence length or head dimension, or
/// if `freqs` cannot be reshaped to `[1, seq, 1, head_dim / 2]`.
pub fn apply_rope<B: Backend>(
    xq: Tensor<B, 4>,
    xk: Tensor<B, 4>,
    freqs: Tensor<B, 4>,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    let [b, s, h, d] = xq.dims();
    // Take the key's own dims rather than assuming they mirror the query's.
    let [kb, ks, kh, kd] = xk.dims();
    assert_eq!(ks, s, "apply_rope: query and key sequence lengths differ");
    assert_eq!(kd, d, "apply_rope: query and key head dimensions differ");
    let half = d / 2;

    // Column 0 of every 2x2 block holds (cos, sin) in rows 0 and 1.
    let first_col = freqs.narrow(3, 0, 1);
    let cos = first_col.clone().narrow(2, 0, 1).reshape([1, s, 1, half]);
    let sin = first_col.narrow(2, 1, 1).reshape([1, s, 1, half]);

    (
        rotate_half(xq, cos.clone(), sin.clone(), b, s, h, half),
        rotate_half(xk, cos, sin, kb, ks, kh, half),
    )
}
/// Rotates interleaved `(even, odd)` pairs along the last axis of `x`.
///
/// `x` is viewed as `[b, s, h, half, 2]`; for every pair the result is
/// `(even * cos - odd * sin, even * sin + odd * cos)`, re-interleaved into a
/// tensor of shape `[b, s, h, half * 2]`. `cos` and `sin` are combined
/// element-wise with `[b, s, h, half]` tensors.
fn rotate_half<B: Backend>(
    x: Tensor<B, 4>,
    cos: Tensor<B, 4>,
    sin: Tensor<B, 4>,
    b: usize,
    s: usize,
    h: usize,
    half: usize,
) -> Tensor<B, 4> {
    let lane_shape = [b, s, h, half];

    // Split the last axis into its two interleaved lanes.
    let paired = x.reshape([b, s, h, half, 2]);
    let lane0: Tensor<B, 4> = paired.clone().narrow(4, 0, 1).reshape(lane_shape);
    let lane1: Tensor<B, 4> = paired.narrow(4, 1, 1).reshape(lane_shape);

    // Standard 2-D rotation applied lane-wise.
    let rotated0 = lane0.clone() * cos.clone() - lane1.clone() * sin.clone();
    let rotated1 = lane0 * sin + lane1 * cos;

    // Stack on a new trailing axis, then flatten it to restore interleaving.
    let interleaved = Tensor::stack::<5>(vec![rotated0, rotated1], 4);
    interleaved.reshape([b, s, h, half * 2])
}