use candle_core::{DType, Device, Result, Tensor};
/// Precomputed rotary position embedding (RoPE) cos/sin tables.
///
/// Both tables have shape `(max_seq_len, 2 * (head_dim / 2))`. Each row is the
/// per-frequency angle table concatenated with itself along the last axis.
/// NOTE(review): this duplicated layout is the "rotate-half" convention; confirm
/// the function that applies these tables to q/k expects it.
#[derive(Debug, Clone)]
pub struct SimpleRotaryEmbedding {
    /// Cosine table, shape `(max_seq_len, 2 * (head_dim / 2))`, in the requested dtype.
    cos: Tensor,
    /// Sine table, same shape and dtype as `cos`.
    sin: Tensor,
}

impl SimpleRotaryEmbedding {
    /// Builds cos/sin tables for positions `0..max_seq_len`.
    ///
    /// For `i in 0..head_dim / 2`, `inv_freq[i] = rope_theta^(-i / (head_dim / 2))`.
    /// The angle for position `p` is `p * inv_freq[i]`.
    ///
    /// All of this math is done in `f32`, and only the finished tables are cast to
    /// `dtype`. Casting the positions to a half-precision dtype before the
    /// multiply would round large positions (BF16 cannot represent every integer
    /// above 256) and corrupt the angles for long sequences.
    ///
    /// `head_dim` is expected to be even. For an odd value the last channel is
    /// dropped, and the tables are `head_dim - 1` wide.
    ///
    /// # Errors
    /// Propagates any error from tensor creation or tensor ops on `device`.
    pub fn new(
        head_dim: usize,
        max_seq_len: usize,
        rope_theta: f64,
        device: &Device,
        dtype: DType,
    ) -> Result<Self> {
        let half_dim = head_dim / 2;
        // Compute the exponent in f64 for accuracy, then store it as f32.
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| (1.0 / rope_theta.powf(i as f64 / half_dim as f64)) as f32)
            .collect();
        // Keep these tensors in f32: see the precision note above.
        let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
        let positions: Vec<f32> = (0..max_seq_len).map(|p| p as f32).collect();
        let positions = Tensor::new(positions.as_slice(), device)?;
        // (max_seq_len, 1) * (1, half_dim) -> (max_seq_len, half_dim)
        let freqs = positions
            .unsqueeze(1)?
            .broadcast_mul(&inv_freq.unsqueeze(0)?)?;
        let cos_half = freqs.cos()?;
        let sin_half = freqs.sin()?;
        // Duplicate along the last axis, then cast once to the target dtype.
        let cos = Tensor::cat(&[&cos_half, &cos_half], 1)?.to_dtype(dtype)?;
        let sin = Tensor::cat(&[&sin_half, &sin_half], 1)?.to_dtype(dtype)?;
        Ok(Self { cos, sin })
    }

    /// Returns the `(cos, sin)` rows for positions `offset..offset + seq_len`.
    ///
    /// Each returned tensor has shape `(seq_len, table_width)`, where
    /// `table_width` is the full width of the stored tables.
    ///
    /// # Errors
    /// Returns an error if `offset + seq_len` exceeds the `max_seq_len` the
    /// tables were built with (propagated from `Tensor::narrow`).
    pub fn get(&self, seq_len: usize, offset: usize) -> Result<(Tensor, Tensor)> {
        // Only the position axis needs slicing; the width is always taken in full.
        let cos = self.cos.narrow(0, offset, seq_len)?;
        let sin = self.sin.narrow(0, offset, seq_len)?;
        Ok((cos, sin))
    }

    /// Returns `(cos, sin)` for positions `0..seq_len` with a leading unit axis.
    ///
    /// Each returned tensor has shape `(1, seq_len, table_width)`.
    ///
    /// # Errors
    /// Same conditions as [`SimpleRotaryEmbedding::get`] with `offset = 0`.
    pub fn forward(&self, seq_len: usize) -> Result<(Tensor, Tensor)> {
        let (cos, sin) = self.get(seq_len, 0)?;
        Ok((cos.unsqueeze(0)?, sin.unsqueeze(0)?))
    }
}