use candle_core::{DType, Device, Tensor};
use crate::error::Result;
/// Precomputed cosine and sine lookup tables for rotary position embeddings.
///
/// Both tensors are built by [`RopeCache::new`] as the element-wise `cos` /
/// `sin` of a `(max_position, head_dim / 2)` angle matrix, so row `p` holds
/// the values for position `p`.
pub struct RopeCache {
/// Cosines of the per-position rotation angles.
cos: Tensor,
/// Sines of the per-position rotation angles.
sin: Tensor,
}
impl RopeCache {
pub fn new(
head_dim: usize,
max_position: usize,
theta: f64,
device: &Device,
dtype: DType,
) -> Result<Self> {
let half_dim = head_dim / 2;
let inv_freq: Vec<f32> = (0..half_dim)
.map(|i| {
#[allow(clippy::cast_precision_loss, clippy::as_conversions)]
let freq = 1.0 / theta.powf(2.0 * i as f64 / head_dim as f64);
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
let freq_f32 = freq as f32;
freq_f32
})
.collect();
let inv_freq_tensor = Tensor::from_vec(inv_freq, (1, half_dim), device)?.to_dtype(dtype)?;
#[allow(clippy::cast_possible_truncation, clippy::as_conversions)]
let pos_tensor = Tensor::arange(0u32, max_position as u32, device)?
.to_dtype(dtype)?
.reshape((max_position, 1))?;
let freqs = pos_tensor.matmul(&inv_freq_tensor)?;
let cos = freqs.cos()?;
let sin = freqs.sin()?;
Ok(Self { cos, sin })
}
pub fn apply(&self, x: &Tensor, start_pos: usize) -> Result<Tensor> {
let (_, _, seq_len, _) = x.dims4()?;
let cos = self.cos.narrow(0, start_pos, seq_len)?;
let sin = self.sin.narrow(0, start_pos, seq_len)?;
Ok(candle_nn::rotary_emb::rope(&x.contiguous()?, &cos, &sin)?)
}
}