unsloth_rs/kernels/
rope.rs1use candle_core::{Device, Tensor};
23
24use crate::error::Result;
25
26pub struct RotaryEmbedding {
30 cos_cache: Tensor,
32 sin_cache: Tensor,
34 head_dim: usize,
36}
37
38impl RotaryEmbedding {
39 pub fn new(head_dim: usize, max_seq_len: usize, base: f32, device: &Device) -> Result<Self> {
47 let inv_freq: Vec<f32> = (0..head_dim)
49 .step_by(2)
50 .map(|i| 1.0 / base.powf(i as f32 / head_dim as f32))
51 .collect();
52
53 let inv_freq = Tensor::from_vec(inv_freq, (head_dim / 2,), device)?;
54
55 let positions: Vec<f32> = (0..max_seq_len).map(|i| i as f32).collect();
57 let positions = Tensor::from_vec(positions, (max_seq_len, 1), device)?;
58
59 let freqs = positions.matmul(&inv_freq.unsqueeze(0)?)?;
61
62 let cos_cache = freqs.cos()?;
64 let sin_cache = freqs.sin()?;
65
66 Ok(Self {
67 cos_cache,
68 sin_cache,
69 head_dim,
70 })
71 }
72
73 pub fn forward(
83 &self,
84 q: &Tensor,
85 k: &Tensor,
86 _position_ids: &Tensor,
87 ) -> Result<(Tensor, Tensor)> {
88 let device = q.device();
89
90 if device.is_cuda() {
91 self.forward_cuda(q, k)
92 } else {
93 self.forward_cpu(q, k)
94 }
95 }
96
97 fn forward_cpu(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
99 let seq_len = q.dim(2)?;
100
101 let cos = self.cos_cache.narrow(0, 0, seq_len)?;
103 let sin = self.sin_cache.narrow(0, 0, seq_len)?;
104
105 let q_rotated = self.apply_rotary(q, &cos, &sin)?;
106 let k_rotated = self.apply_rotary(k, &cos, &sin)?;
107
108 Ok((q_rotated, k_rotated))
109 }
110
111 fn forward_cuda(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
116 tracing::debug!("Using CUDA RoPE path for Q shape {:?}", q.shape());
117 self.forward_cpu(q, k)
118 }
119
120 fn apply_rotary(&self, x: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
121 let half_dim = self.head_dim / 2;
122
123 let x1 = x.narrow(3, 0, half_dim)?;
125 let x2 = x.narrow(3, half_dim, half_dim)?;
126
127 let rotated_x1 = (x1.broadcast_mul(cos)? - x2.broadcast_mul(sin)?)?;
129 let rotated_x2 = (x2.broadcast_mul(cos)? + x1.broadcast_mul(sin)?)?;
130
131 Tensor::cat(&[&rotated_x1, &rotated_x2], 3).map_err(Into::into)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_rope_creation() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device);
        assert!(rope.is_ok());
    }

    #[test]
    fn test_rope_rejects_odd_head_dim() {
        let device = Device::Cpu;
        assert!(RotaryEmbedding::new(63, 16, 10000.0, &device).is_err());
    }

    #[test]
    fn test_rope_preserves_shape() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(64, 2048, 10000.0, &device).unwrap();

        let q = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let k = Tensor::zeros(&[1, 12, 10, 64], DType::F32, &device).unwrap();
        let pos = Tensor::zeros(&[1, 10], DType::I64, &device).unwrap();

        let (q_rot, k_rot) = rope.forward(&q, &k, &pos).unwrap();

        assert_eq!(q_rot.shape().dims(), &[1, 12, 10, 64]);
        assert_eq!(k_rot.shape().dims(), &[1, 12, 10, 64]);
    }

    /// Checks the rotation values, not just the shape.
    ///
    /// With `head_dim = 2` the single inverse frequency is 1, so position `p`
    /// rotates `[1, 0]` to `[cos p, sin p]`.
    #[test]
    fn test_rope_rotates_by_position() {
        let device = Device::Cpu;
        let rope = RotaryEmbedding::new(2, 4, 10000.0, &device).unwrap();

        // Two positions, each holding the unit vector [1, 0].
        let x = Tensor::from_vec(vec![1.0f32, 0.0, 1.0, 0.0], (1, 1, 2, 2), &device).unwrap();
        let pos = Tensor::zeros(&[1, 2], DType::I64, &device).unwrap();

        let (q_rot, k_rot) = rope.forward(&x, &x, &pos).unwrap();

        let expected = [1.0f32, 0.0, 1.0f32.cos(), 1.0f32.sin()];
        for rotated in [q_rot, k_rot] {
            let values = rotated.flatten_all().unwrap().to_vec1::<f32>().unwrap();
            assert_eq!(values.len(), expected.len());
            for (got, want) in values.iter().zip(expected) {
                assert!((got - want).abs() < 1e-5, "got {got}, want {want}");
            }
        }
    }
}