Skip to main content

god_graph/transformer/layers/
embedding.rs

1//! Rotary Positional Embedding (RoPE) implementation
2
3use crate::tensor::DenseTensor;
4use crate::tensor::traits::TensorBase;
5
6/// Rotary Positional Embedding (RoPE)
7///
8/// RoPE encodes positional information by rotating the query and key vectors
9/// based on their positions. This is used in LLaMA, Mistral, and other modern LLMs.
10///
11/// Formula: RoPE(x, pos) = x * cos(pos * theta) + rotate_half(x) * sin(pos * theta)
12///
13/// where theta = base^(-2i/dim) for i = 0, 1, ..., dim/2-1
14#[derive(Debug, Clone)]
15pub struct RoPE {
16    /// Pre-computed cosine cache [max_seq_len, dim/2]
17    pub cos_cache: DenseTensor,
18    /// Pre-computed sine cache [max_seq_len, dim/2]
19    pub sin_cache: DenseTensor,
20    /// Rotation dimension (typically head_dim)
21    pub dim: usize,
22    /// Maximum sequence length
23    pub max_seq_len: usize,
24    /// Base frequency (LLaMA uses 10000)
25    pub base: f64,
26}
27
28impl RoPE {
29    /// Create a new RoPE module
30    ///
31    /// # Arguments
32    /// * `dim` - Rotation dimension (typically head_dim)
33    /// * `max_seq_len` - Maximum sequence length to support
34    /// * `base` - Base frequency for theta calculation (default: 10000)
35    pub fn new(dim: usize, max_seq_len: usize, base: f64) -> Self {
36        // Pre-compute theta frequencies: theta = base^(-2i/dim)
37        let theta = Self::compute_theta(dim, base);
38        
39        // Compute positions: [0, 1, ..., max_seq_len-1]
40        let positions: Vec<f64> = (0..max_seq_len).map(|i| i as f64).collect();
41        
42        // Compute outer product: positions * theta
43        let freqs = Self::compute_freqs(&positions, &theta);
44        
45        // Compute cos and sin caches
46        let cos_cache = freqs.cos();
47        let sin_cache = freqs.sin();
48        
49        Self {
50            cos_cache,
51            sin_cache,
52            dim,
53            max_seq_len,
54            base,
55        }
56    }
57
58    /// Create RoPE with default base (10000)
59    pub fn default(dim: usize, max_seq_len: usize) -> Self {
60        Self::new(dim, max_seq_len, 10000.0)
61    }
62
63    /// Compute theta frequencies
64    fn compute_theta(dim: usize, base: f64) -> Vec<f64> {
65        let half_dim = dim / 2;
66        let mut theta = Vec::with_capacity(half_dim);
67        
68        for i in 0..half_dim {
69            let exponent = -2.0 * i as f64 / dim as f64;
70            theta.push(base.powf(exponent));
71        }
72        
73        theta
74    }
75
76    /// Compute frequency matrix (outer product of positions and theta)
77    fn compute_freqs(positions: &[f64], theta: &[f64]) -> DenseTensor {
78        let max_seq_len = positions.len();
79        let half_dim = theta.len();
80        
81        let mut data = Vec::with_capacity(max_seq_len * half_dim);
82        
83        for &pos in positions {
84            for &t in theta {
85                data.push(pos * t);
86            }
87        }
88        
89        DenseTensor::new(data, vec![max_seq_len, half_dim])
90    }
91
92    /// Apply RoPE to input tensor
93    ///
94    /// # Arguments
95    /// * `x` - Input tensor [batch_size, seq_len, dim]
96    /// * `positions` - Optional position indices [seq_len] or [batch_size * seq_len]
97    ///
98    /// # Returns
99    /// Rotated tensor [batch_size, seq_len, dim]
100    pub fn forward(&self, x: &DenseTensor, positions: Option<&[usize]>) -> DenseTensor {
101        let batch_size = x.shape()[0];
102        let seq_len = x.shape()[1];
103        
104        // Default to sequential positions if not provided
105        let default_positions: Vec<usize> = (0..seq_len).collect();
106        let positions = positions.unwrap_or(&default_positions);
107        
108        let mut output = Vec::with_capacity(batch_size * seq_len * self.dim);
109        let half_dim = self.dim / 2;
110        
111        for b in 0..batch_size {
112            for s in 0..seq_len {
113                let pos = positions[s % positions.len()];
114                
115                // Get cos/sin for this position
116                let cos = self.cos_cache.get_row(pos.min(self.max_seq_len - 1));
117                let sin = self.sin_cache.get_row(pos.min(self.max_seq_len - 1));
118                
119                // Get input slice for this position
120                let x_start = (b * seq_len + s) * self.dim;
121                let x_slice = &x.data()[x_start..x_start + self.dim];
122                
123                // Apply RoPE: x * cos + rotate_half(x) * sin
124                for i in 0..half_dim {
125                    let x1 = x_slice[i];
126                    let x2 = x_slice[i + half_dim];
127                    
128                    // rotate_half: [-x2, x1]
129                    let rotated_x1 = -x2;
130                    let rotated_x2 = x1;
131                    
132                    // Apply rotation
133                    let out1 = x1 * cos.data()[i] + rotated_x1 * sin.data()[i];
134                    let out2 = x2 * cos.data()[i] + rotated_x2 * sin.data()[i];
135                    
136                    output.push(out1);
137                    output.push(out2);
138                }
139            }
140        }
141        
142        DenseTensor::new(output, vec![batch_size, seq_len, self.dim])
143    }
144
145    /// Apply RoPE to query and key tensors
146    ///
147    /// # Arguments
148    /// * `q` - Query tensor [batch_size, seq_len, dim]
149    /// * `k` - Key tensor [batch_size, seq_len, dim]
150    /// * `positions` - Optional position indices
151    ///
152    /// # Returns
153    /// Tuple of (rotated_q, rotated_k)
154    pub fn forward_qk(&self, q: &DenseTensor, k: &DenseTensor, positions: Option<&[usize]>) -> (DenseTensor, DenseTensor) {
155        let rotated_q = self.forward(q, positions);
156        let rotated_k = self.forward(k, positions);
157        (rotated_q, rotated_k)
158    }
159
160    /// Get the dimension
161    pub fn dim(&self) -> usize {
162        self.dim
163    }
164
165    /// Get the maximum sequence length
166    pub fn max_seq_len(&self) -> usize {
167        self.max_seq_len
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Rotation dimension shared by every test.
    const DIM: usize = 8;
    /// Cache length shared by every test.
    const MAX_SEQ_LEN: usize = 512;

    /// Build the RoPE instance used throughout the tests (default base).
    fn make_rope() -> RoPE {
        RoPE::default(DIM, MAX_SEQ_LEN)
    }

    #[test]
    fn test_rope_creation() {
        let rope = make_rope();

        // Accessors echo the constructor arguments.
        assert_eq!(rope.dim(), DIM);
        assert_eq!(rope.max_seq_len(), MAX_SEQ_LEN);

        // Both caches cover every position and every frequency pair.
        assert_eq!(rope.cos_cache.shape(), &[MAX_SEQ_LEN, DIM / 2]);
        assert_eq!(rope.sin_cache.shape(), &[MAX_SEQ_LEN, DIM / 2]);
    }

    #[test]
    fn test_rope_forward() {
        let rope = make_rope();
        let (batch_size, seq_len) = (2, 4);

        let input = DenseTensor::ones(vec![batch_size, seq_len, DIM]);
        let rotated = rope.forward(&input, None);

        // Implicit sequential positions keep the input shape.
        assert_eq!(rotated.shape(), &[batch_size, seq_len, DIM]);
    }

    #[test]
    fn test_rope_with_positions() {
        let rope = make_rope();
        let (batch_size, seq_len) = (1, 3);

        let input = DenseTensor::ones(vec![batch_size, seq_len, DIM]);
        let positions: Vec<usize> = vec![0, 2, 4];
        let rotated = rope.forward(&input, Some(&positions[..]));

        // Explicit positions keep the input shape as well.
        assert_eq!(rotated.shape(), &[batch_size, seq_len, DIM]);
    }

    #[test]
    fn test_rope_preserves_norm() {
        // RoPE is a rotation, so it should preserve the L2 norm
        let rope = make_rope();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = DenseTensor::new(values, vec![1, 1, DIM]);

        let rotated = rope.forward(&input, None);

        // L2 norms before and after the rotation
        let norm_before: f64 = input.data().iter().map(|v| v * v).sum::<f64>().sqrt();
        let norm_after: f64 = rotated.data().iter().map(|v| v * v).sum::<f64>().sqrt();

        // Equal up to numerical precision
        let drift = (norm_before - norm_after).abs();
        assert!(drift < 1e-5);
    }
}