1use crate::error::{LmError, LmResult};
5use crate::weights::WeightTensor;
6
/// Lookup table that maps token ids to dense embedding vectors.
///
/// The table is stored flat and row-major in `weight`: the vector for token
/// `t` occupies `weight.data[t * embed_dim..(t + 1) * embed_dim]`.
#[derive(Debug, Clone)]
pub struct TokenEmbedding {
    /// Number of rows in the table; valid token ids are `0..vocab_size`.
    pub vocab_size: usize,
    /// Length of each embedding vector.
    pub embed_dim: usize,
    /// Embedding table with shape `[vocab_size, embed_dim]`.
    pub weight: WeightTensor,
}
22
23impl TokenEmbedding {
24 pub fn new(vocab_size: usize, embed_dim: usize) -> LmResult<Self> {
26 if vocab_size == 0 || embed_dim == 0 {
27 return Err(LmError::InvalidConfig {
28 msg: "TokenEmbedding: vocab_size and embed_dim must be > 0".into(),
29 });
30 }
31 let weight = WeightTensor::zeros(&[vocab_size, embed_dim]);
32 Ok(Self {
33 vocab_size,
34 embed_dim,
35 weight,
36 })
37 }
38
39 pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
41 if weight.shape.len() != 2 {
42 return Err(LmError::DimensionMismatch {
43 expected: 2,
44 got: weight.shape.len(),
45 });
46 }
47 let vocab_size = weight.shape[0];
48 let embed_dim = weight.shape[1];
49 if vocab_size == 0 || embed_dim == 0 {
50 return Err(LmError::InvalidConfig {
51 msg: "TokenEmbedding weight must be non-empty".into(),
52 });
53 }
54 Ok(Self {
55 vocab_size,
56 embed_dim,
57 weight,
58 })
59 }
60
61 pub fn forward(&self, token_ids: &[u32]) -> LmResult<Vec<f32>> {
65 if token_ids.is_empty() {
66 return Err(LmError::EmptyInput {
67 context: "token_ids",
68 });
69 }
70 let mut out = vec![0.0_f32; token_ids.len() * self.embed_dim];
71 for (pos, &tid) in token_ids.iter().enumerate() {
72 if tid as usize >= self.vocab_size {
73 return Err(LmError::OutOfVocab { token: tid });
74 }
75 let src_start = tid as usize * self.embed_dim;
76 let dst_start = pos * self.embed_dim;
77 out[dst_start..dst_start + self.embed_dim]
78 .copy_from_slice(&self.weight.data[src_start..src_start + self.embed_dim]);
79 }
80 Ok(out)
81 }
82}
83
/// Lookup table that maps absolute sequence positions to embedding vectors.
///
/// The table is stored flat and row-major in `weight`: the vector for
/// position `p` occupies `weight.data[p * embed_dim..(p + 1) * embed_dim]`.
#[derive(Debug, Clone)]
pub struct LearnedPositionalEmbedding {
    /// Number of rows in the table; valid positions are `0..max_positions`.
    pub max_positions: usize,
    /// Length of each positional vector.
    pub embed_dim: usize,
    /// Positional table with shape `[max_positions, embed_dim]`.
    pub weight: WeightTensor,
}
98
99impl LearnedPositionalEmbedding {
100 pub fn new(max_positions: usize, embed_dim: usize) -> LmResult<Self> {
102 if max_positions == 0 || embed_dim == 0 {
103 return Err(LmError::InvalidConfig {
104 msg: "LearnedPositionalEmbedding: max_positions and embed_dim must be > 0".into(),
105 });
106 }
107 let weight = WeightTensor::zeros(&[max_positions, embed_dim]);
108 Ok(Self {
109 max_positions,
110 embed_dim,
111 weight,
112 })
113 }
114
115 pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
117 if weight.shape.len() != 2 {
118 return Err(LmError::DimensionMismatch {
119 expected: 2,
120 got: weight.shape.len(),
121 });
122 }
123 let max_positions = weight.shape[0];
124 let embed_dim = weight.shape[1];
125 Ok(Self {
126 max_positions,
127 embed_dim,
128 weight,
129 })
130 }
131
132 pub fn forward(&self, seq_len: usize, offset: usize) -> LmResult<Vec<f32>> {
136 if offset + seq_len > self.max_positions {
137 return Err(LmError::SequenceTooLong {
138 total_len: offset + seq_len,
139 max_pos: self.max_positions,
140 });
141 }
142 let mut out = vec![0.0_f32; seq_len * self.embed_dim];
143 for i in 0..seq_len {
144 let pos = offset + i;
145 let src = pos * self.embed_dim;
146 let dst = i * self.embed_dim;
147 out[dst..dst + self.embed_dim]
148 .copy_from_slice(&self.weight.data[src..src + self.embed_dim]);
149 }
150 Ok(out)
151 }
152}
153
/// Rotary positional embedding with cosine/sine tables precomputed at
/// construction time.
///
/// For position `pos` and pair index `i` (with `0 <= i < head_dim / 2`) the
/// tables hold the cosine and sine of `pos * theta^(-2i / head_dim)`.
#[derive(Debug, Clone)]
pub struct RotaryEmbedding {
    /// Size of one attention head; `new` requires it to be even and non-zero.
    pub head_dim: usize,
    /// Number of positions covered by the precomputed tables.
    pub max_positions: usize,
    /// Base of the geometric frequency progression.
    pub theta: f32,
    /// Flat table of cosines; the entry for `(pos, i)` is at
    /// `pos * (head_dim / 2) + i`.
    cos_table: Vec<f32>,
    /// Flat table of sines, laid out exactly like `cos_table`.
    sin_table: Vec<f32>,
}
183
184impl RotaryEmbedding {
185 pub fn new(head_dim: usize, max_positions: usize, theta: f32) -> LmResult<Self> {
187 if head_dim == 0 || head_dim % 2 != 0 {
188 return Err(LmError::InvalidConfig {
189 msg: format!("RotaryEmbedding: head_dim={head_dim} must be even and > 0"),
190 });
191 }
192 if max_positions == 0 {
193 return Err(LmError::InvalidConfig {
194 msg: "RotaryEmbedding: max_positions must be > 0".into(),
195 });
196 }
197 if theta <= 0.0 {
198 return Err(LmError::InvalidConfig {
199 msg: "RotaryEmbedding: theta must be > 0".into(),
200 });
201 }
202
203 let half_dim = head_dim / 2;
204 let n = max_positions * half_dim;
205 let mut cos_table = Vec::with_capacity(n);
206 let mut sin_table = Vec::with_capacity(n);
207
208 for pos in 0..max_positions {
209 for i in 0..half_dim {
210 let freq = theta.powf(-((2 * i) as f32) / head_dim as f32);
212 let angle = pos as f32 * freq;
213 cos_table.push(angle.cos());
214 sin_table.push(angle.sin());
215 }
216 }
217
218 Ok(Self {
219 head_dim,
220 max_positions,
221 theta,
222 cos_table,
223 sin_table,
224 })
225 }
226
227 pub fn apply(
232 &self,
233 x: &mut [f32],
234 n_heads: usize,
235 n_tokens: usize,
236 offset: usize,
237 ) -> LmResult<()> {
238 if offset + n_tokens > self.max_positions {
241 return Err(LmError::SequenceTooLong {
242 total_len: offset + n_tokens,
243 max_pos: self.max_positions,
244 });
245 }
246 let expected = n_tokens * n_heads * self.head_dim;
247 if x.len() != expected {
248 return Err(LmError::DimensionMismatch {
249 expected,
250 got: x.len(),
251 });
252 }
253
254 let half_dim = self.head_dim / 2;
255
256 for t in 0..n_tokens {
257 let abs_pos = offset + t;
258 let cos_row_start = abs_pos * half_dim;
259 for h in 0..n_heads {
260 let base = (t * n_heads + h) * self.head_dim;
261 for i in 0..half_dim {
262 let cos = self.cos_table[cos_row_start + i];
263 let sin = self.sin_table[cos_row_start + i];
264 let x0 = x[base + 2 * i];
265 let x1 = x[base + 2 * i + 1];
266 x[base + 2 * i] = x0 * cos - x1 * sin;
267 x[base + 2 * i + 1] = x0 * sin + x1 * cos;
268 }
269 }
270 }
271 Ok(())
272 }
273
274 pub fn cos_at(&self, pos: usize, i: usize) -> f32 {
276 self.cos_table[pos * (self.head_dim / 2) + i]
277 }
278
279 pub fn sin_at(&self, pos: usize, i: usize) -> f32 {
281 self.sin_table[pos * (self.head_dim / 2) + i]
282 }
283}
284
/// Unit tests for the token, learned-positional and rotary embeddings.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- TokenEmbedding -------------------------------------------------

    #[test]
    fn token_embedding_lookup() {
        let mut emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        // Row 2 of a [4, 3] table starts at flat index 2 * 3 = 6.
        emb.weight.data[6] = 1.0;
        emb.weight.data[7] = 2.0;
        emb.weight.data[8] = 3.0;
        let out = emb
            .forward(&[2])
            .expect("token id 2 within vocab_size=4 should succeed");
        assert_eq!(out, vec![1.0_f32, 2.0, 3.0]);
    }

    #[test]
    fn token_embedding_multi_token() {
        let mut emb = TokenEmbedding::new(3, 2).expect("vocab_size=3 embed_dim=2 should be valid");
        emb.weight.data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = emb
            .forward(&[0, 2])
            .expect("token ids 0 and 2 within vocab_size=3 should succeed");
        assert_eq!(out, vec![1.0_f32, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn token_embedding_out_of_vocab_error() {
        let emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        assert!(matches!(
            emb.forward(&[5]),
            Err(LmError::OutOfVocab { token: 5 })
        ));
    }

    #[test]
    fn token_embedding_empty_error() {
        let emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        assert!(matches!(emb.forward(&[]), Err(LmError::EmptyInput { .. })));
    }

    #[test]
    fn token_embedding_from_weight() {
        let w = WeightTensor::zeros(&[10, 4]);
        let emb = TokenEmbedding::from_weight(w)
            .expect("2-D weight tensor [10,4] should be valid for TokenEmbedding");
        assert_eq!(emb.vocab_size, 10);
        assert_eq!(emb.embed_dim, 4);
    }

    // ---- LearnedPositionalEmbedding -------------------------------------

    #[test]
    fn pos_embedding_lookup() {
        let mut pe = LearnedPositionalEmbedding::new(4, 2)
            .expect("max_positions=4 embed_dim=2 should be valid");
        // Position 1 of a [4, 2] table occupies flat indices 2 and 3.
        pe.weight.data[2] = 3.0;
        pe.weight.data[3] = 4.0;
        let out = pe
            .forward(2, 0)
            .expect("seq_len=2 offset=0 within max_positions=4 should succeed");
        assert_eq!(out, vec![0.0_f32, 0.0, 3.0, 4.0]);
    }

    #[test]
    fn pos_embedding_with_offset() {
        let mut pe = LearnedPositionalEmbedding::new(8, 2)
            .expect("max_positions=8 embed_dim=2 should be valid");
        // Positions 4 and 5 occupy flat indices 8..12.
        for i in 8..12 {
            pe.weight.data[i] = 10.0;
        }
        let out = pe
            .forward(2, 4)
            .expect("seq_len=2 offset=4 within max_positions=8 should succeed");
        // Compare the whole vector so an empty or short result cannot pass.
        assert_eq!(out, vec![10.0_f32; 4]);
    }

    #[test]
    fn pos_embedding_too_long_error() {
        let pe = LearnedPositionalEmbedding::new(4, 2)
            .expect("max_positions=4 embed_dim=2 should be valid");
        assert!(matches!(
            pe.forward(5, 0),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    // ---- RotaryEmbedding ------------------------------------------------

    #[test]
    fn rope_pos0_is_identity() {
        let rope = RotaryEmbedding::new(4, 16, 10_000.0)
            .expect("even head_dim=4 max_pos=16 should be valid");
        let mut x = vec![1.0_f32, 2.0, 3.0, 4.0];
        rope.apply(&mut x, 1, 1, 0)
            .expect("1 token at offset 0 within max_positions=16 should succeed");
        // Every angle at position 0 is zero, so the rotation is a no-op.
        assert!((x[0] - 1.0).abs() < 1e-5, "x[0]={}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-5, "x[1]={}", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-5, "x[2]={}", x[2]);
        assert!((x[3] - 4.0).abs() < 1e-5, "x[3]={}", x[3]);
    }

    #[test]
    fn rope_rotation_preserves_norm() {
        let rope = RotaryEmbedding::new(4, 32, 10_000.0)
            .expect("even head_dim=4 max_pos=32 should be valid");
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut x = original.clone();
        rope.apply(&mut x, 1, 1, 5)
            .expect("1 token at offset 5 within max_positions=32 should succeed");
        let norm_before: f32 = original.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let norm_after: f32 = x.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm_before - norm_after).abs() < 1e-4,
            "norm {norm_before} ≠ {norm_after}"
        );
    }

    #[test]
    fn rope_multi_head_multi_token() {
        let (n_tokens, n_heads, head_dim) = (2, 3, 4);
        let rope = RotaryEmbedding::new(head_dim, 32, 10_000.0)
            .expect("even head_dim=4 max_pos=32 for multi-head test should be valid");
        let mut x = vec![1.0_f32; n_tokens * n_heads * head_dim];
        rope.apply(&mut x, n_heads, n_tokens, 0)
            .expect("2 tokens 3 heads at offset 0 within max_positions=32 should succeed");
        assert_eq!(x.len(), 24);

        // Token 0 sits at position 0 and must be left untouched in every head.
        let token_stride = n_heads * head_dim;
        for (idx, &v) in x[..token_stride].iter().enumerate() {
            assert!((v - 1.0).abs() < 1e-6, "x[{idx}]={v}");
        }

        // Token 1 sits at position 1: every head gets the same rotation of
        // the all-ones input, i.e. (cos - sin, sin + cos) for each pair.
        for h in 0..n_heads {
            for i in 0..head_dim / 2 {
                let (cos, sin) = (rope.cos_at(1, i), rope.sin_at(1, i));
                let base = token_stride + h * head_dim + 2 * i;
                assert!(
                    (x[base] - (cos - sin)).abs() < 1e-6,
                    "head {h} pair {i}: x0={}",
                    x[base]
                );
                assert!(
                    (x[base + 1] - (sin + cos)).abs() < 1e-6,
                    "head {h} pair {i}: x1={}",
                    x[base + 1]
                );
            }
        }
    }

    #[test]
    fn rope_odd_head_dim_error() {
        assert!(RotaryEmbedding::new(3, 16, 10_000.0).is_err());
    }

    #[test]
    fn rope_sequence_too_long_error() {
        let rope = RotaryEmbedding::new(4, 4, 10_000.0)
            .expect("even head_dim=4 max_pos=4 should be valid");
        // Correctly sized for 2 tokens x 1 head x 4 dims, so the only thing
        // wrong with the call is that offset 3 + 2 tokens exceeds max_pos=4.
        let mut x = vec![0.0_f32; 8];
        assert!(matches!(
            rope.apply(&mut x, 1, 2, 3),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    #[test]
    fn rope_cos_sin_tables_at_zero() {
        let rope = RotaryEmbedding::new(4, 8, 10_000.0)
            .expect("even head_dim=4 max_pos=8 should be valid");
        // angle = 0 at position 0, so cos = 1 and sin = 0.
        assert!((rope.cos_at(0, 0) - 1.0).abs() < 1e-6);
        assert!(rope.sin_at(0, 0).abs() < 1e-6);
    }

    #[test]
    fn rope_tables_have_correct_dimensions() {
        let head_dim = 8;
        let max_pos = 16;
        let rope = RotaryEmbedding::new(head_dim, max_pos, 10_000.0)
            .expect("even head_dim and positive max_pos should produce valid RoPE");
        assert_eq!(rope.cos_table.len(), max_pos * (head_dim / 2));
        assert_eq!(rope.sin_table.len(), max_pos * (head_dim / 2));
    }
}