use candle_core::{Device, Tensor};
use crate::error::Result;
/// Rotary position embedding (RoPE) with precomputed cosine/sine tables.
///
/// The tables are built once by [`RotaryEmbedding::new`] for a fixed `head_dim`
/// and maximum sequence length; [`RotaryEmbedding::forward`] then rotates
/// query/key tensors using slices of those tables.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// `cos(position * inv_freq)`, shape `(max_seq_len, head_dim / 2)`, dtype `f32`.
    cos_cache: Tensor,
    /// `sin(position * inv_freq)`, shape `(max_seq_len, head_dim / 2)`, dtype `f32`.
    sin_cache: Tensor,
    /// Size of the rotated feature dimension; it is split into two equal halves.
    head_dim: usize,
}
impl RotaryEmbedding {
pub fn new(head_dim: usize, max_seq_len: usize, base: f32, device: &Device) -> Result<Self> {
let inv_freq: Vec<f32> = (0..head_dim)
.step_by(2)
.map(|i| 1.0 / base.powf(i as f32 / head_dim as f32))
.collect();
let inv_freq = Tensor::from_vec(inv_freq, (head_dim / 2,), device)?;
let positions: Vec<f32> = (0..max_seq_len).map(|i| i as f32).collect();
let positions = Tensor::from_vec(positions, (max_seq_len, 1), device)?;
let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
let cos_cache = freqs.cos()?;
let sin_cache = freqs.sin()?;
Ok(Self {
cos_cache,
sin_cache,
head_dim,
})
}
pub fn forward(
&self,
q: &Tensor,
k: &Tensor,
_position_ids: &Tensor,
) -> Result<(Tensor, Tensor)> {
let device = q.device();
if device.is_cuda() {
self.forward_cuda(q, k)
} else {
self.forward_cpu(q, k)
}
}
fn forward_cpu(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let seq_len = q.dim(2)?;
let cos = self.cos_cache.narrow(0, 0, seq_len)?;
let sin = self.sin_cache.narrow(0, 0, seq_len)?;
let q_rotated = self.apply_rotary(q, &cos, &sin)?;
let k_rotated = self.apply_rotary(k, &cos, &sin)?;
Ok((q_rotated, k_rotated))
}
fn forward_cuda(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
tracing::debug!("Using CUDA RoPE path for Q shape {:?}", q.shape());
self.forward_cpu(q, k)
}
fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
let half_dim = self.head_dim / 2;
let x1 = x.narrow(3, 0, half_dim)?;
let x2 = x.narrow(3, half_dim, half_dim)?;
let rotated_x1 = (x1.broadcast_mul(cos)? - x2.broadcast_mul(sin)?)?;
let rotated_x2 = (x2.broadcast_mul(cos)? + x1.broadcast_mul(sin)?)?;
Tensor::cat(&[&rotated_x1, &rotated_x2], 3).map_err(Into::into)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_rope_creation() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device);
        assert!(rope.is_ok());
    }

    #[test]
    fn test_rope_preserves_shape() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device).unwrap();
        let q = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let k = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let pos = Tensor::zeros(&[1, 10], DType::I64, &device).unwrap();
        let (q_rot, k_rot) = rope.forward(&q, &k, &pos).unwrap();
        assert_eq!(q_rot.shape().dims(), &[1, 12, 10, 64]);
        assert_eq!(k_rot.shape().dims(), &[1, 12, 10, 64]);
    }

    /// Non-zero inputs so the numeric rotation is actually exercised.
    /// The zero tensors above would rotate to zeros regardless of correctness.
    #[test]
    fn test_rope_rotation_values() {
        let device = Device::Cpu;
        let head_dim = 8;
        let seq_len = 4;
        let half = head_dim / 2;
        let rope = RotaryEmbedding::new(head_dim, 16, 10000.0, &device).unwrap();
        let data: Vec<f32> = (0..seq_len * head_dim)
            .map(|i| i as f32 * 0.1 - 1.0)
            .collect();
        let q = Tensor::from_vec(data, (1, 1, seq_len, head_dim), &device).unwrap();
        let pos = Tensor::zeros((1, seq_len), DType::I64, &device).unwrap();
        let (q_rot, k_rot) = rope.forward(&q, &q, &pos).unwrap();

        let orig = q.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        let rot = q_rot.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        let rot_k = k_rot.flatten_all().unwrap().to_vec1::<f32>().unwrap();

        // q and k receive the same angles.
        assert_eq!(rot, rot_k);

        // Position 0 has angle 0 for every frequency, so its row is unchanged.
        for d in 0..head_dim {
            assert!((orig[d] - rot[d]).abs() < 1e-6);
        }

        // Position 1, pair 0 is rotated by 1 radian, so its values must change.
        assert!((orig[head_dim] - rot[head_dim]).abs() > 1e-3);

        // Each (i, i + half) pair is a 2-D rotation, so its norm is preserved.
        for s in 0..seq_len {
            for i in 0..half {
                let a = s * head_dim + i;
                let b = a + half;
                let before = orig[a].hypot(orig[b]);
                let after = rot[a].hypot(rot[b]);
                assert!((before - after).abs() < 1e-4);
            }
        }
    }

    #[test]
    fn test_rope_rejects_sequence_longer_than_cache() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(8, 4, 10000.0, &device).unwrap();
        let q = Tensor::zeros((1, 1, 5, 8), DType::F32, &device).unwrap();
        let pos = Tensor::zeros((1, 5), DType::I64, &device).unwrap();
        assert!(rope.forward(&q, &q, &pos).is_err());
    }

    #[test]
    fn test_rope_rejects_odd_head_dim() {
        assert!(RotaryEmbedding::new(7, 16, 10000.0, &Device::Cpu).is_err());
    }
}