use candle_core::{Result, Tensor, Device, DType, D};
fn rotate_half(x: &Tensor) -> Result<Tensor> {
let last_dim = x.dims().len() - 1;
let dim_size = x.dim(last_dim)?;
let half = dim_size / 2;
let x1 = x.narrow(last_dim, 0, half)?;
let x2 = x.narrow(last_dim, half, half)?;
Tensor::cat(&[&x2.neg()?, &x1], last_dim)
}
/// Applies rotary position embeddings to a query and a key tensor.
///
/// `cos` and `sin` are unsqueezed at axes 0 and 2 before broadcasting. A
/// rank-2 `(seq_len, head_dim)` table therefore becomes
/// `(1, seq_len, 1, head_dim)`, which expects `q`/`k` laid out as
/// `(batch, seq_len, num_heads, head_dim)`.
///
/// Each tensor is computed as `x * cos + rotate_half(x) * sin` in the dtype of
/// `cos`. The result is then cast back to the dtype that *that* input had on
/// entry, so `q` and `k` keep their own dtypes even if they differ.
///
/// # Errors
///
/// Propagates any error from dtype conversion, unsqueezing, broadcasting, or
/// `rotate_half` (e.g. a `head_dim` that does not match the tables).
pub fn apply_rotary_pos_emb(
    q: &Tensor,
    k: &Tensor,
    cos: &Tensor,
    sin: &Tensor,
) -> Result<(Tensor, Tensor)> {
    /// Rotates one tensor in `cos`'s dtype and restores its original dtype.
    fn rotate(x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
        let orig_dtype = x.dtype();
        let x = if orig_dtype != cos.dtype() {
            x.to_dtype(cos.dtype())?
        } else {
            x.clone()
        };
        let rotated = rotate_half(&x)?;
        let embed = x.broadcast_mul(cos)?.add(&rotated.broadcast_mul(sin)?)?;
        if embed.dtype() != orig_dtype {
            embed.to_dtype(orig_dtype)
        } else {
            Ok(embed)
        }
    }

    let cos = cos.unsqueeze(0)?.unsqueeze(2)?;
    let sin = sin.unsqueeze(0)?.unsqueeze(2)?;
    let q_embed = rotate(q, &cos, &sin)?;
    // `k` is restored to its own dtype. The previous version reused `q`'s.
    let k_embed = rotate(k, &cos, &sin)?;
    Ok((q_embed, k_embed))
}
/// Precomputed cosine and sine tables for rotary position embeddings.
///
/// Row `p` holds the values for position `p`. For an even `dim`, each table
/// built by `RotaryEmbedding::new` has shape `(max_position_embeddings, dim)`.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// `cos(position * inv_freq)`, with the frequency block repeated twice
    /// along the last axis.
    cos_cached: Tensor,
    /// `sin(position * inv_freq)`, laid out exactly like `cos_cached`.
    sin_cached: Tensor,
}
impl RotaryEmbedding {
pub fn new(
dim: usize,
max_position_embeddings: usize,
base: f32,
device: &Device,
) -> Result<Self> {
let inv_freq: Vec<f32> = (0..dim)
.step_by(2)
.map(|i| {
let exponent = i as f32 / dim as f32;
1.0 / base.powf(exponent)
})
.collect();
let inv_freq = Tensor::new(inv_freq.as_slice(), device)?;
let t: Vec<f32> = (0..max_position_embeddings).map(|i| i as f32).collect();
let t = Tensor::new(t.as_slice(), device)?;
let freqs = t.unsqueeze(1)?.broadcast_mul(&inv_freq.unsqueeze(0)?)?;
let emb = Tensor::cat(&[&freqs, &freqs], 1)?;
let cos_cached = emb.cos()?;
let sin_cached = emb.sin()?;
Ok(Self {
cos_cached,
sin_cached,
})
}
pub fn forward(&self) -> Result<(Tensor, Tensor)> {
Ok((self.cos_cached.clone(), self.sin_cached.clone()))
}
pub fn forward_with_len(&self, seq_len: usize) -> Result<(Tensor, Tensor)> {
let cos = self.cos_cached.narrow(0, 0, seq_len)?;
let sin = self.sin_cached.narrow(0, 0, seq_len)?;
Ok((cos, sin))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rotate_half() -> Result<()> {
        let device = Device::Cpu;
        let x = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?.reshape((1, 4))?;
        let rotated = rotate_half(&x)?;
        let expected = Tensor::new(&[-3.0f32, -4.0, 1.0, 2.0], &device)?.reshape((1, 4))?;
        let diff = rotated.sub(&expected)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-6, "rotate_half failed");
        Ok(())
    }

    #[test]
    fn test_rotary_embedding_shape() -> Result<()> {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward()?;
        assert_eq!(cos.dims(), &[512, 64]);
        assert_eq!(sin.dims(), &[512, 64]);
        Ok(())
    }

    #[test]
    fn test_rotary_embedding_with_len() -> Result<()> {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(128)?;
        assert_eq!(cos.dims(), &[128, 64]);
        assert_eq!(sin.dims(), &[128, 64]);
        Ok(())
    }

    #[test]
    fn test_apply_rotary_pos_emb_shape() -> Result<()> {
        let device = Device::Cpu;
        let q = Tensor::randn(0f32, 1.0, (2, 16, 8, 64), &device)?;
        let k = Tensor::randn(0f32, 1.0, (2, 16, 8, 64), &device)?;
        let rope = RotaryEmbedding::new(64, 512, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(16)?;
        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;
        assert_eq!(q_embed.dims(), q.dims());
        assert_eq!(k_embed.dims(), k.dims());
        Ok(())
    }

    /// At position 0 every angle is zero (cos = 1, sin = 0), so RoPE must be
    /// the identity.
    #[test]
    fn test_apply_rotary_pos_emb_identity_at_position_zero() -> Result<()> {
        let device = Device::Cpu;
        let q = Tensor::randn(0f32, 1.0, (1, 1, 2, 8), &device)?;
        let k = Tensor::randn(0f32, 1.0, (1, 1, 2, 8), &device)?;
        let rope = RotaryEmbedding::new(8, 4, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(1)?;
        let (q_embed, k_embed) = apply_rotary_pos_emb(&q, &k, &cos, &sin)?;
        let dq = q_embed.sub(&q)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        let dk = k_embed.sub(&k)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(dq < 1e-6, "q changed at position 0: {}", dq);
        assert!(dk < 1e-6, "k changed at position 0: {}", dk);
        Ok(())
    }

    /// RoPE applies 2-D rotations to `(x[i], x[i + half])` pairs, so the total
    /// squared norm of the input must be preserved.
    #[test]
    fn test_apply_rotary_pos_emb_preserves_norm() -> Result<()> {
        let device = Device::Cpu;
        let q = Tensor::randn(0f32, 1.0, (2, 16, 4, 32), &device)?;
        let rope = RotaryEmbedding::new(32, 64, 10000.0, &device)?;
        let (cos, sin) = rope.forward_with_len(16)?;
        let (q_embed, _) = apply_rotary_pos_emb(&q, &q, &cos, &sin)?;
        let before = q.sqr()?.sum_all()?.to_scalar::<f32>()?;
        let after = q_embed.sqr()?.sum_all()?.to_scalar::<f32>()?;
        let rel = (before - after).abs() / before;
        assert!(rel < 1e-4, "norm not preserved: {} vs {}", before, after);
        Ok(())
    }
}