use std::ops::Range;
use burn::{
Tensor,
config::Config,
module::Module,
prelude::{
Backend,
s,
},
tensor::{
DType,
linalg,
},
};
/// Read-only metadata shared by rotary embedding configs and modules.
pub trait RotaryEmbeddingMeta {
/// Number of sequence positions covered.
fn seq_len(&self) -> usize;
/// Size of the per-head feature dimension.
fn head_dim(&self) -> usize;
}
/// Configuration used to build a [`RotaryEmbedding`] via `init`.
#[derive(Config, Debug)]
pub struct RotaryEmbeddingConfig {
/// Number of positions the precomputed tables cover.
pub seq_len: usize,
/// Per-head feature dimension; `init` asserts that it is even.
pub head_dim: usize,
/// Base of the inverse-frequency progression; defaults to 10000.
#[config(default = 10000)]
pub base: usize,
}
impl RotaryEmbeddingMeta for RotaryEmbeddingConfig {
    /// Returns the configured number of sequence positions.
    fn seq_len(&self) -> usize {
        let Self { seq_len, .. } = *self;
        seq_len
    }

    /// Returns the configured per-head feature dimension.
    fn head_dim(&self) -> usize {
        let Self { head_dim, .. } = *self;
        head_dim
    }
}
impl RotaryEmbeddingConfig {
pub fn init<B: Backend>(
self,
device: &B::Device,
) -> RotaryEmbedding<B> {
assert!(
self.head_dim.is_multiple_of(2),
"Head dimension must be even: {}",
self.head_dim
);
let freq_matrix =
positional_frequency_table(self.seq_len, self.base, self.head_dim, device);
let cos = freq_matrix
.clone()
.cos()
.set_require_grad(false)
.unsqueeze_dim::<3>(1)
.unsqueeze_dim(0);
let sin = freq_matrix
.sin()
.set_require_grad(false)
.unsqueeze_dim::<3>(1)
.unsqueeze_dim(0);
RotaryEmbedding {
head_dim: self.head_dim,
cos,
sin,
}
}
}
/// Rotary positional embedding holding precomputed cosine and sine tables.
#[derive(Module, Debug)]
pub struct RotaryEmbedding<B: Backend> {
/// Per-head feature dimension the tables were built for.
pub head_dim: usize,
/// Cosine table; when built by `RotaryEmbeddingConfig::init` its shape is
/// `[1, seq_len, 1, head_dim / 2]`.
pub cos: Tensor<B, 4>,
/// Sine table; same layout as `cos`.
pub sin: Tensor<B, 4>,
}
impl<B: Backend> RotaryEmbeddingMeta for RotaryEmbedding<B> {
    /// Returns the size of dim 1 of the `cos` table.
    fn seq_len(&self) -> usize {
        let [_, positions, _, _] = self.cos.dims();
        positions
    }

    /// Returns the stored per-head feature dimension.
    fn head_dim(&self) -> usize {
        let Self { head_dim, .. } = self;
        *head_dim
    }
}
impl<B: Backend> RotaryEmbedding<B> {
    /// Converts the `cos` and `sin` tables to `dtype`, keeping `head_dim`.
    pub fn cast(
        self,
        dtype: DType,
    ) -> Self {
        Self {
            cos: self.cos.cast(dtype),
            sin: self.sin.cast(dtype),
            ..self
        }
    }

    /// Returns a new embedding whose tables are restricted to `range`.
    ///
    /// Both tables are sliced along dim 1, the axis that `seq_len()` reads.
    pub fn clip_range(
        &self,
        range: Range<usize>,
    ) -> Self {
        Self {
            head_dim: self.head_dim,
            cos: self.cos.clone().slice_dim(1, range.clone()),
            sin: self.sin.clone().slice_dim(1, range),
        }
    }

    /// Applies the rotary embedding to `input`.
    ///
    /// `input` is laid out as `["B", "T", "H", "D"]`. The last axis is split
    /// at `head_dim / 2` into halves `x1` and `x2`, and the result is the
    /// concatenation along dim 3 of:
    ///
    /// - `y1 = x1 * cos + x2 * sin`
    /// - `y2 = x2 * cos - x1 * sin`
    ///
    /// In debug builds, shape contracts check that `T == self.seq_len()` and
    /// `D == self.head_dim()` on the input, and that the output keeps the
    /// input's shape. Release builds skip these checks.
    pub fn apply(
        &self,
        input: Tensor<B, 4>,
    ) -> Tensor<B, 4> {
        #[cfg(debug_assertions)]
        let [b, h] = bimm_contracts::unpack_shape_contract!(
            ["B", "T", "H", "D"],
            &input.dims(),
            &["B", "H"],
            &[("T", self.seq_len()), ("D", self.head_dim())]
        );

        let pivot = self.head_dim() / 2;
        let x1 = input.clone().slice_dim(3, s![..pivot]);
        // Last use of `input`: move it rather than cloning a second handle.
        let x2 = input.slice_dim(3, s![pivot..]);

        let y1 = x1.clone() * self.cos.clone() + x2.clone() * self.sin.clone();
        // Same value as `x1 * (-sin) + x2 * cos`, without materialising a
        // negated copy of the sin table on every call.
        let y2 = x2 * self.cos.clone() - x1 * self.sin.clone();

        let output = Tensor::cat(vec![y1, y2], 3);

        #[cfg(debug_assertions)]
        bimm_contracts::assert_shape_contract_periodically!(
            ["B", "T", "H", "D"],
            &output.dims(),
            &[
                ("B", b),
                ("T", self.seq_len()),
                ("H", h),
                ("D", self.head_dim())
            ]
        );

        output
    }
}
/// Builds the inverse-frequency vector `base^(-k / head_dim)` for
/// `k = 0, 2, 4, ...` below `head_dim`.
///
/// # Arguments
///
/// - `base`: base of the progression.
/// - `head_dim`: per-head feature dimension; even indices below it are used.
/// - `device`: device on which the tensors are created.
///
/// # Returns
///
/// A 1-D float tensor with one entry per even index in `0..head_dim`.
pub fn inverse_frequency_table<B: Backend>(
    base: usize,
    head_dim: usize,
    device: &B::Device,
) -> Tensor<B, 1> {
    let base_tensor: Tensor<B, 1> = Tensor::from_data([base as f32], device);

    // Even indices 0, 2, 4, ... scaled by 1 / head_dim.
    let exponents: Tensor<B, 1> = Tensor::arange_step(0..head_dim as i64, 2, device)
        .float()
        .div_scalar(head_dim as f32);

    base_tensor.powf(-exponents)
}
/// Builds the position-by-frequency angle table.
///
/// Entry `[p, i]` is position `p` (for `p` in `0..seq_len`) multiplied by the
/// `i`-th value of [`inverse_frequency_table`].
///
/// # Arguments
///
/// - `seq_len`: number of positions (rows).
/// - `base`: base forwarded to [`inverse_frequency_table`].
/// - `head_dim`: per-head feature dimension forwarded to
///   [`inverse_frequency_table`].
/// - `device`: device on which the tensors are created.
pub fn positional_frequency_table<B: Backend>(
    seq_len: usize,
    base: usize,
    head_dim: usize,
    device: &B::Device,
) -> Tensor<B, 2> {
    let frequencies = inverse_frequency_table::<B>(base, head_dim, device);
    let positions: Tensor<B, 1> = Tensor::arange(0..seq_len as i64, device).float();

    // Outer product: one row per position, one column per frequency.
    linalg::outer::<_, 1, 2, _>(positions, frequencies)
}
// Unit tests for the frequency tables and the rotary embedding module.
// All tests run on the `Wgpu` backend with its default device.
#[cfg(test)]
mod tests {
use bimm_contracts::assert_shape_contract;
use burn::{
backend::Wgpu,
tensor::{
Distribution,
Tolerance,
},
};
use super::*;
// With head_dim = 4 the even indices are 0 and 2, so the table should hold
// base^(-0/4) and base^(-2/4).
#[test]
fn test_inverse_frequency_table() {
type B = Wgpu;
let device = Default::default();
let base = 10000;
let head_dim = 4;
let base_f = base as f32;
let head_dim_f = head_dim as f32;
inverse_frequency_table::<B>(base, head_dim, &device)
.to_data()
.assert_approx_eq(
&Tensor::<B, 1>::from_data(
[
1.0 / base_f.powf(0.0 / head_dim_f),
1.0 / base_f.powf(2.0 / head_dim_f),
],
&device,
)
.to_data(),
Tolerance::<f32>::default(),
);
}
// Row p of the table should be position p times the inverse frequencies;
// checked for positions 0, 1 and 2.
#[test]
fn test_frequency_matrix() {
type B = Wgpu;
let device = Default::default();
let base = 10000;
let head_dim = 4;
let base_f = base as f32;
let head_dim_f = head_dim as f32;
positional_frequency_table::<B>(3, base, head_dim, &device)
.to_data()
.assert_approx_eq(
&Tensor::<B, 2>::from_data(
[
[0.0, 0.0],
[
1.0 / base_f.powf(0.0 / head_dim_f),
1.0 / base_f.powf(2.0 / head_dim_f),
],
[
2.0 / base_f.powf(0.0 / head_dim_f),
2.0 / base_f.powf(2.0 / head_dim_f),
],
],
&device,
)
.to_data(),
Tolerance::<f32>::default(),
);
}
// Clipping to 10..20 should leave 10 positions and match a direct slice of
// the original sin / cos tables along dim 1.
#[test]
fn test_clip_range() {
type B = Wgpu;
let device = Default::default();
let config = RotaryEmbeddingConfig::new(1024, 64);
let re: RotaryEmbedding<B> = config.init(&device);
assert_eq!(re.seq_len(), 1024);
assert_eq!(re.head_dim(), 64);
let clip_re = re.clip_range(10..20);
assert_eq!(clip_re.seq_len(), 10);
// NOTE(review): the trailing `true` presumably requests a strict
// comparison - confirm against burn's `TensorData::assert_eq`.
clip_re
.sin
.clone()
.to_data()
.assert_eq(&re.sin.clone().slice_dim(1, 10..20).to_data(), true);
clip_re
.cos
.clone()
.to_data()
.assert_eq(&re.cos.clone().slice_dim(1, 10..20).to_data(), true);
}
// End-to-end check: config accessors and default base, module accessors,
// output shape, and output values against the rotation recomputed inline
// from `re.cos` / `re.sin`.
#[test]
fn test_rotary_embedding() {
type B = Wgpu;
let device = Default::default();
let batch = 1;
let heads = 2;
let seq_len = 1024;
let head_dim = 64;
let config = RotaryEmbeddingConfig::new(seq_len, head_dim);
assert_eq!(config.seq_len(), seq_len);
assert_eq!(config.head_dim(), head_dim);
assert_eq!(config.base, 10000);
let re: RotaryEmbedding<B> = config.init(&device);
assert_eq!(re.seq_len(), seq_len);
assert_eq!(re.head_dim(), head_dim);
let input: Tensor<B, 4> = Tensor::random(
[batch, seq_len, heads, head_dim],
Distribution::Default,
&device,
);
let output = re.apply(input.clone());
// The output must keep the input's [B, T, H, D] shape.
assert_shape_contract!(
["B", "T", "H", "D"],
&output.dims(),
&[("B", batch), ("T", seq_len), ("H", heads), ("D", head_dim)]
);
// Reference computation: split the last axis in half and rotate.
let x1 = input.clone().slice_dim(3, s![..head_dim / 2]);
let x2 = input.clone().slice_dim(3, s![head_dim / 2..]);
let y1 = x1.clone() * re.cos.clone() + x2.clone() * re.sin.clone();
let y2 = x1 * (-re.sin.clone()) + x2 * re.cos.clone();
let expected = Tensor::cat(vec![y1, y2], 3);
expected
.to_data()
.assert_approx_eq(&output.to_data(), Tolerance::<f32>::default());
}
}