// oxidized_transformers/layers/embeddings/rotary_embeddings.rs
use std::sync::RwLock;
use candle_core::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
use snafu::{ensure, ResultExt, Snafu};
/// Configuration from which `build` constructs rotary embeddings.
///
/// Values are set through the builder-style methods `width`, `seq_len` and
/// `base`; `Default` supplies the starting values.
#[derive(Debug, Clone)]
pub struct RotaryEmbeddingsConfig {
    /// Rotary width. `build` rejects odd values.
    width: usize,
    /// Number of positions for which `build` precomputes the cos/sin cache.
    seq_len: usize,
    /// Base that `build` raises to a negative fractional power to derive theta.
    base: usize,
}
impl RotaryEmbeddingsConfig {
    /// Build [`RotaryEmbeddings`] from this configuration.
    ///
    /// Computes `theta[k] = base^(-(2k / width))` for `k` in `0..width / 2`,
    /// places it on the device of `vb`, and precomputes the cosine/sine cache
    /// for `seq_len` positions in the dtype of `vb`.
    ///
    /// # Errors
    ///
    /// * `WidthNotEven` if the configured width is odd.
    /// * `ThetaTensor` if theta cannot be converted to a tensor.
    /// * `Cache` if the cosine/sine cache cannot be created.
    pub fn build(&self, vb: VarBuilder) -> Result<RotaryEmbeddings, RotaryEmbeddingsError> {
        ensure!(self.width % 2 == 0, WidthNotEvenSnafu { width: self.width });

        let base = self.base as f32;
        let width = self.width as f32;
        let n_freqs = self.width / 2;

        // One value per pair of dimensions: exponents 0, 2, 4, ... over width.
        let inv_freqs: Vec<f32> = (0..n_freqs)
            .map(|pair| base.powf(-((2 * pair) as f32 / width)))
            .collect();
        let theta =
            Tensor::from_vec(inv_freqs, (n_freqs,), vb.device()).context(ThetaTensorSnafu)?;

        let (cos, sin) =
            RotaryEmbeddings::create_rotary_embed(&theta, self.width, self.seq_len, vb.dtype())
                .context(CacheSnafu)?;
        let cache = RwLock::new(RotaryEmbeddingsCache { cos, sin });

        Ok(RotaryEmbeddings { cache, theta })
    }

    /// Return the configuration with the rotary width replaced by `width`.
    pub fn width(self, width: usize) -> Self {
        Self { width, ..self }
    }

    /// Return the configuration with the precomputed length replaced by `seq_len`.
    pub fn seq_len(self, seq_len: usize) -> Self {
        Self { seq_len, ..self }
    }

    /// Return the configuration with the theta base replaced by `base`.
    pub fn base(self, base: usize) -> Self {
        Self { base, ..self }
    }
}
impl Default for RotaryEmbeddingsConfig {
    /// Default configuration: width 96, 2048 precomputed positions, base 10000.
    fn default() -> Self {
        let (width, seq_len, base) = (96, 2048, 10_000);
        Self {
            width,
            seq_len,
            base,
        }
    }
}
/// Errors that can occur while building or applying rotary embeddings.
#[derive(Debug, Snafu)]
pub enum RotaryEmbeddingsError {
/// A tensor operation failed while combining the input with the
/// cosine/sine tables in `forward`.
#[snafu(display("Cannot apply rotary embeddings to input"))]
ApplyEmbeddings { source: candle_core::Error },
/// The cached cosine tensor could not be read as a 2-D shape.
#[snafu(display("Cannot get cache length"))]
CacheLength { source: candle_core::Error },
/// Computing the cosine/sine cache failed (at build time or on resize).
#[snafu(display("Cannot create rotary embeddings cache"))]
Cache { source: candle_core::Error },
/// The input to `forward` could not be unpacked as a 4-D tensor.
#[snafu(display("Invalid input rank, expected {expected}, got {got}"))]
InvalidRank {
expected: usize,
got: usize,
source: candle_core::Error,
},
/// Splitting, negating or concatenating the input halves failed.
#[snafu(display("Cannot rotate input tensor"))]
Rotate { source: candle_core::Error },
/// Slicing the cache, or selecting cache rows by position, failed.
#[snafu(display("Cannot slice rotary embeddings cache"))]
SliceCache { source: candle_core::Error },
/// The theta values could not be turned into a tensor.
#[snafu(display("Cannot convert theta to candle tensor"))]
ThetaTensor { source: candle_core::Error },
/// The configured rotary width is odd.
#[snafu(display("Rotary width must be even, was {width}"))]
WidthNotEven { width: usize },
}
/// Precomputed cosine and sine tables used by `RotaryEmbeddings`.
///
/// Both tensors are produced together by `create_rotary_embed`; `seq_len`
/// reads the first dimension of `cos` as the number of cached positions.
#[derive(Debug)]
struct RotaryEmbeddingsCache {
// Cosine of position * theta, one row per position.
cos: Tensor,
// Sine of position * theta, one row per position.
sin: Tensor,
}
impl RotaryEmbeddingsCache {
    /// Number of positions currently cached, i.e. the first dimension of `cos`.
    ///
    /// # Errors
    ///
    /// Returns `CacheLength` if `cos` does not have a two-dimensional shape.
    fn seq_len(&self) -> Result<usize, RotaryEmbeddingsError> {
        self.cos
            .shape()
            .dims2()
            .map(|(n_positions, _width)| n_positions)
            .context(CacheLengthSnafu)
    }
}
/// Rotary embeddings layer, constructed by `RotaryEmbeddingsConfig::build`.
#[derive(Debug)]
pub struct RotaryEmbeddings {
// Cosine/sine tables; behind a lock because `forward` takes `&self` but may
// need to replace the tables with longer ones.
cache: RwLock<RotaryEmbeddingsCache>,
// Per-dimension-pair values computed in `build`; reused whenever the cache
// is recomputed.
theta: Tensor,
}
impl RotaryEmbeddings {
fn create_rotary_embed(
theta: &Tensor,
width: usize,
length: usize,
dtype: DType,
) -> Result<(Tensor, Tensor), candle_core::Error> {
let device = theta.device();
let position = Tensor::arange(0.0, length as f32, device)?.unsqueeze(1)?;
let m_theta = position.broadcast_mul(&theta.unsqueeze(0)?)?;
let m_theta = Tensor::cat(&[&m_theta, &m_theta], 1)?;
let re_cos = m_theta.cos()?.reshape(&[length, width])?.to_dtype(dtype)?;
let re_sin = m_theta.sin()?.reshape(&[length, width])?.to_dtype(dtype)?;
Ok((re_cos, re_sin))
}
fn resize_rotary_embed(
&self,
width: usize,
len: usize,
dtype: DType,
) -> Result<(), RotaryEmbeddingsError> {
let (re_cos, re_sin) =
Self::create_rotary_embed(&self.theta, width, len, dtype).context(CacheSnafu)?;
let mut cache = self.cache.write().unwrap();
cache.cos = re_cos;
cache.sin = re_sin;
Ok(())
}
fn rotate(input: &Tensor) -> Result<Tensor, RotaryEmbeddingsError> {
let (_batch_size, _n_heads, _seq_len, n_dims) =
input.shape().dims4().context(RotateSnafu)?;
let half_idx = n_dims / 2;
let input_1 = input
.i((.., .., .., half_idx..))
.and_then(|xs| xs.neg())
.context(RotateSnafu)?;
let input_2 = input.i((.., .., .., ..half_idx)).context(RotateSnafu)?;
Tensor::cat(&[&input_1, &input_2], 3).context(RotateSnafu)
}
pub fn forward(
&self,
input: &Tensor,
positions: Option<&Tensor>,
) -> Result<Tensor, RotaryEmbeddingsError> {
let (batch_size, _, seq_len, width) = input.shape().dims4().context(InvalidRankSnafu {
expected: 4usize,
got: input.rank(),
})?;
let dtype = self.cache.read().unwrap().cos.dtype();
let (rot_cos, rot_sin) = match positions {
None => {
if self.cache.read().unwrap().seq_len()? < seq_len {
self.resize_rotary_embed(width, seq_len, dtype)?;
}
let cache = self.cache.read().unwrap();
let rot_cos = cache
.cos
.i((..seq_len, ..))
.and_then(|xs| xs.reshape((1, 1, seq_len, width)))
.context(SliceCacheSnafu)?;
let rot_sin = cache
.sin
.i((..seq_len, ..))
.and_then(|xs| xs.reshape((1, 1, seq_len, width)))
.context(SliceCacheSnafu)?;
(rot_cos, rot_sin)
}
Some(positions) => {
let positions_flat = positions.flatten_all().context(SliceCacheSnafu)?;
let max_len = positions_flat
.max(0)
.and_then(|xs| xs.to_scalar::<u32>())
.context(SliceCacheSnafu)? as usize
+ 1;
if self.cache.read().unwrap().seq_len()? < max_len {
self.resize_rotary_embed(width, max_len, dtype)?;
}
let cache = self.cache.read().unwrap();
let rot_cos = cache
.cos
.index_select(&positions_flat, 0)
.and_then(|xs| xs.reshape((batch_size, 1, seq_len, width)))
.context(SliceCacheSnafu)?;
let rot_sin = cache
.sin
.index_select(&positions_flat, 0)
.and_then(|xs| xs.reshape((batch_size, 1, seq_len, width)))
.context(SliceCacheSnafu)?;
(rot_cos, rot_sin)
}
};
let input_rot_cos = input
.broadcast_mul(&rot_cos)
.context(ApplyEmbeddingsSnafu)?;
let input_rot_sin = Self::rotate(input)?
.broadcast_mul(&rot_sin)
.context(ApplyEmbeddingsSnafu)?;
(input_rot_cos + input_rot_sin).context(ApplyEmbeddingsSnafu)
}
}