god_graph/transformer/layers/
embedding.rs1use crate::tensor::DenseTensor;
4use crate::tensor::traits::TensorBase;
5
6#[derive(Debug, Clone)]
15pub struct RoPE {
16 pub cos_cache: DenseTensor,
18 pub sin_cache: DenseTensor,
20 pub dim: usize,
22 pub max_seq_len: usize,
24 pub base: f64,
26}
27
28impl RoPE {
29 pub fn new(dim: usize, max_seq_len: usize, base: f64) -> Self {
36 let theta = Self::compute_theta(dim, base);
38
39 let positions: Vec<f64> = (0..max_seq_len).map(|i| i as f64).collect();
41
42 let freqs = Self::compute_freqs(&positions, &theta);
44
45 let cos_cache = freqs.cos();
47 let sin_cache = freqs.sin();
48
49 Self {
50 cos_cache,
51 sin_cache,
52 dim,
53 max_seq_len,
54 base,
55 }
56 }
57
58 pub fn default(dim: usize, max_seq_len: usize) -> Self {
60 Self::new(dim, max_seq_len, 10000.0)
61 }
62
63 fn compute_theta(dim: usize, base: f64) -> Vec<f64> {
65 let half_dim = dim / 2;
66 let mut theta = Vec::with_capacity(half_dim);
67
68 for i in 0..half_dim {
69 let exponent = -2.0 * i as f64 / dim as f64;
70 theta.push(base.powf(exponent));
71 }
72
73 theta
74 }
75
76 fn compute_freqs(positions: &[f64], theta: &[f64]) -> DenseTensor {
78 let max_seq_len = positions.len();
79 let half_dim = theta.len();
80
81 let mut data = Vec::with_capacity(max_seq_len * half_dim);
82
83 for &pos in positions {
84 for &t in theta {
85 data.push(pos * t);
86 }
87 }
88
89 DenseTensor::new(data, vec![max_seq_len, half_dim])
90 }
91
92 pub fn forward(&self, x: &DenseTensor, positions: Option<&[usize]>) -> DenseTensor {
101 let batch_size = x.shape()[0];
102 let seq_len = x.shape()[1];
103
104 let default_positions: Vec<usize> = (0..seq_len).collect();
106 let positions = positions.unwrap_or(&default_positions);
107
108 let mut output = Vec::with_capacity(batch_size * seq_len * self.dim);
109 let half_dim = self.dim / 2;
110
111 for b in 0..batch_size {
112 for s in 0..seq_len {
113 let pos = positions[s % positions.len()];
114
115 let cos = self.cos_cache.get_row(pos.min(self.max_seq_len - 1));
117 let sin = self.sin_cache.get_row(pos.min(self.max_seq_len - 1));
118
119 let x_start = (b * seq_len + s) * self.dim;
121 let x_slice = &x.data()[x_start..x_start + self.dim];
122
123 for i in 0..half_dim {
125 let x1 = x_slice[i];
126 let x2 = x_slice[i + half_dim];
127
128 let rotated_x1 = -x2;
130 let rotated_x2 = x1;
131
132 let out1 = x1 * cos.data()[i] + rotated_x1 * sin.data()[i];
134 let out2 = x2 * cos.data()[i] + rotated_x2 * sin.data()[i];
135
136 output.push(out1);
137 output.push(out2);
138 }
139 }
140 }
141
142 DenseTensor::new(output, vec![batch_size, seq_len, self.dim])
143 }
144
145 pub fn forward_qk(&self, q: &DenseTensor, k: &DenseTensor, positions: Option<&[usize]>) -> (DenseTensor, DenseTensor) {
155 let rotated_q = self.forward(q, positions);
156 let rotated_k = self.forward(k, positions);
157 (rotated_q, rotated_k)
158 }
159
160 pub fn dim(&self) -> usize {
162 self.dim
163 }
164
165 pub fn max_seq_len(&self) -> usize {
167 self.max_seq_len
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 8;
    const MAX_SEQ_LEN: usize = 512;

    /// RoPE configuration shared by every test.
    fn make_rope() -> RoPE {
        RoPE::default(DIM, MAX_SEQ_LEN)
    }

    /// Euclidean norm over all elements of `t`.
    fn l2_norm(t: &DenseTensor) -> f64 {
        t.data().iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    #[test]
    fn test_rope_creation() {
        let rope = make_rope();

        assert_eq!(rope.dim(), DIM);
        assert_eq!(rope.max_seq_len(), MAX_SEQ_LEN);

        let table_shape = [MAX_SEQ_LEN, DIM / 2];
        assert_eq!(rope.cos_cache.shape(), &table_shape);
        assert_eq!(rope.sin_cache.shape(), &table_shape);
    }

    #[test]
    fn test_rope_forward() {
        let rope = make_rope();
        let (batch, seq) = (2, 4);
        let input = DenseTensor::ones(vec![batch, seq, DIM]);

        let rotated = rope.forward(&input, None);

        assert_eq!(rotated.shape(), &[batch, seq, DIM]);
    }

    #[test]
    fn test_rope_with_positions() {
        let rope = make_rope();
        let (batch, seq) = (1, 3);
        let input = DenseTensor::ones(vec![batch, seq, DIM]);
        let explicit_positions: [usize; 3] = [0, 2, 4];

        let rotated = rope.forward(&input, Some(&explicit_positions[..]));

        assert_eq!(rotated.shape(), &[batch, seq, DIM]);
    }

    #[test]
    fn test_rope_preserves_norm() {
        let rope = make_rope();
        let values: Vec<f64> = (1u8..=8).map(f64::from).collect();
        let input = DenseTensor::new(values, vec![1, 1, DIM]);

        let rotated = rope.forward(&input, None);

        let drift = (l2_norm(&input) - l2_norm(&rotated)).abs();
        assert!(drift < 1e-5);
    }
}