Skip to main content

oxicuda_lm/layer/
embedding.rs

1//! Embedding layers: token embedding, learned positional embedding, and
2//! Rotary Positional Embedding (RoPE).
3
4use crate::error::{LmError, LmResult};
5use crate::weights::WeightTensor;
6
7// ─── TokenEmbedding ──────────────────────────────────────────────────────────
8
9/// Token embedding table: maps token ids to dense vectors.
10///
11/// Weight shape: `[vocab_size × embed_dim]`.
12/// Output shape: `[seq_len × embed_dim]`.
13#[derive(Debug, Clone)]
14pub struct TokenEmbedding {
15    /// Vocabulary size.
16    pub vocab_size: usize,
17    /// Embedding dimension.
18    pub embed_dim: usize,
19    /// Weight table: `[vocab_size × embed_dim]`, row-major.
20    pub weight: WeightTensor,
21}
22
23impl TokenEmbedding {
24    /// Construct with zero-initialised weights.
25    pub fn new(vocab_size: usize, embed_dim: usize) -> LmResult<Self> {
26        if vocab_size == 0 || embed_dim == 0 {
27            return Err(LmError::InvalidConfig {
28                msg: "TokenEmbedding: vocab_size and embed_dim must be > 0".into(),
29            });
30        }
31        let weight = WeightTensor::zeros(&[vocab_size, embed_dim]);
32        Ok(Self {
33            vocab_size,
34            embed_dim,
35            weight,
36        })
37    }
38
39    /// Construct from an existing weight tensor.
40    pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
41        if weight.shape.len() != 2 {
42            return Err(LmError::DimensionMismatch {
43                expected: 2,
44                got: weight.shape.len(),
45            });
46        }
47        let vocab_size = weight.shape[0];
48        let embed_dim = weight.shape[1];
49        if vocab_size == 0 || embed_dim == 0 {
50            return Err(LmError::InvalidConfig {
51                msg: "TokenEmbedding weight must be non-empty".into(),
52            });
53        }
54        Ok(Self {
55            vocab_size,
56            embed_dim,
57            weight,
58        })
59    }
60
61    /// Lookup embeddings for `token_ids`.
62    ///
63    /// Returns a flat buffer of shape `[token_ids.len() × embed_dim]`.
64    pub fn forward(&self, token_ids: &[u32]) -> LmResult<Vec<f32>> {
65        if token_ids.is_empty() {
66            return Err(LmError::EmptyInput {
67                context: "token_ids",
68            });
69        }
70        let mut out = vec![0.0_f32; token_ids.len() * self.embed_dim];
71        for (pos, &tid) in token_ids.iter().enumerate() {
72            if tid as usize >= self.vocab_size {
73                return Err(LmError::OutOfVocab { token: tid });
74            }
75            let src_start = tid as usize * self.embed_dim;
76            let dst_start = pos * self.embed_dim;
77            out[dst_start..dst_start + self.embed_dim]
78                .copy_from_slice(&self.weight.data[src_start..src_start + self.embed_dim]);
79        }
80        Ok(out)
81    }
82}
83
84// ─── LearnedPositionalEmbedding ───────────────────────────────────────────────
85
86/// Learned positional embedding table (GPT-2 style).
87///
88/// Weight shape: `[max_positions × embed_dim]`.
89#[derive(Debug, Clone)]
90pub struct LearnedPositionalEmbedding {
91    /// Maximum number of positions.
92    pub max_positions: usize,
93    /// Embedding dimension.
94    pub embed_dim: usize,
95    /// Weight table.
96    pub weight: WeightTensor,
97}
98
99impl LearnedPositionalEmbedding {
100    /// Construct with zero-initialised weights.
101    pub fn new(max_positions: usize, embed_dim: usize) -> LmResult<Self> {
102        if max_positions == 0 || embed_dim == 0 {
103            return Err(LmError::InvalidConfig {
104                msg: "LearnedPositionalEmbedding: max_positions and embed_dim must be > 0".into(),
105            });
106        }
107        let weight = WeightTensor::zeros(&[max_positions, embed_dim]);
108        Ok(Self {
109            max_positions,
110            embed_dim,
111            weight,
112        })
113    }
114
115    /// Construct from an existing weight tensor.
116    pub fn from_weight(weight: WeightTensor) -> LmResult<Self> {
117        if weight.shape.len() != 2 {
118            return Err(LmError::DimensionMismatch {
119                expected: 2,
120                got: weight.shape.len(),
121            });
122        }
123        let max_positions = weight.shape[0];
124        let embed_dim = weight.shape[1];
125        Ok(Self {
126            max_positions,
127            embed_dim,
128            weight,
129        })
130    }
131
132    /// Return positional embeddings for positions `[offset, offset + seq_len)`.
133    ///
134    /// Returns flat buffer of shape `[seq_len × embed_dim]`.
135    pub fn forward(&self, seq_len: usize, offset: usize) -> LmResult<Vec<f32>> {
136        if offset + seq_len > self.max_positions {
137            return Err(LmError::SequenceTooLong {
138                total_len: offset + seq_len,
139                max_pos: self.max_positions,
140            });
141        }
142        let mut out = vec![0.0_f32; seq_len * self.embed_dim];
143        for i in 0..seq_len {
144            let pos = offset + i;
145            let src = pos * self.embed_dim;
146            let dst = i * self.embed_dim;
147            out[dst..dst + self.embed_dim]
148                .copy_from_slice(&self.weight.data[src..src + self.embed_dim]);
149        }
150        Ok(out)
151    }
152}
153
154// ─── RotaryEmbedding ─────────────────────────────────────────────────────────
155
156/// Rotary Positional Embedding (RoPE).
157///
158/// Precomputes `cos` and `sin` tables for all positions up to `max_positions`.
159/// The rotation applies to pairs of dimensions `(x_{2i}, x_{2i+1})` as:
160///
161/// ```text
162/// x_out[2i]   = x[2i]*cos(θ_i*pos) − x[2i+1]*sin(θ_i*pos)
163/// x_out[2i+1] = x[2i]*sin(θ_i*pos) + x[2i+1]*cos(θ_i*pos)
164/// ```
165///
166/// where `θ_i = theta ^ (-2i / head_dim)`.
167///
168/// This embeds position information directly into the attention dot product
169/// without requiring separate positional embeddings.
170#[derive(Debug, Clone)]
171pub struct RotaryEmbedding {
172    /// Head dimension (must be even).
173    pub head_dim: usize,
174    /// Maximum sequence length for which tables are precomputed.
175    pub max_positions: usize,
176    /// RoPE base frequency (typically 10 000 for LLaMA-2, 500 000 for LLaMA-3).
177    pub theta: f32,
178    /// Cos table: `[max_positions × head_dim/2]`, row-major.
179    cos_table: Vec<f32>,
180    /// Sin table: `[max_positions × head_dim/2]`, row-major.
181    sin_table: Vec<f32>,
182}
183
184impl RotaryEmbedding {
185    /// Build RoPE tables for the given configuration.
186    pub fn new(head_dim: usize, max_positions: usize, theta: f32) -> LmResult<Self> {
187        if head_dim == 0 || head_dim % 2 != 0 {
188            return Err(LmError::InvalidConfig {
189                msg: format!("RotaryEmbedding: head_dim={head_dim} must be even and > 0"),
190            });
191        }
192        if max_positions == 0 {
193            return Err(LmError::InvalidConfig {
194                msg: "RotaryEmbedding: max_positions must be > 0".into(),
195            });
196        }
197        if theta <= 0.0 {
198            return Err(LmError::InvalidConfig {
199                msg: "RotaryEmbedding: theta must be > 0".into(),
200            });
201        }
202
203        let half_dim = head_dim / 2;
204        let n = max_positions * half_dim;
205        let mut cos_table = Vec::with_capacity(n);
206        let mut sin_table = Vec::with_capacity(n);
207
208        for pos in 0..max_positions {
209            for i in 0..half_dim {
210                // θ_i = theta ^ (-2i / head_dim)
211                let freq = theta.powf(-((2 * i) as f32) / head_dim as f32);
212                let angle = pos as f32 * freq;
213                cos_table.push(angle.cos());
214                sin_table.push(angle.sin());
215            }
216        }
217
218        Ok(Self {
219            head_dim,
220            max_positions,
221            theta,
222            cos_table,
223            sin_table,
224        })
225    }
226
227    /// Apply RoPE in-place to a QKV projection.
228    ///
229    /// `x` has shape `[n_tokens × n_heads × head_dim]`.
230    /// `offset` is the absolute position of the first token (for KV-cache decode).
231    pub fn apply(
232        &self,
233        x: &mut [f32],
234        n_heads: usize,
235        n_tokens: usize,
236        offset: usize,
237    ) -> LmResult<()> {
238        // Check positional bounds before buffer size so callers get a more
239        // informative error when both conditions are violated simultaneously.
240        if offset + n_tokens > self.max_positions {
241            return Err(LmError::SequenceTooLong {
242                total_len: offset + n_tokens,
243                max_pos: self.max_positions,
244            });
245        }
246        let expected = n_tokens * n_heads * self.head_dim;
247        if x.len() != expected {
248            return Err(LmError::DimensionMismatch {
249                expected,
250                got: x.len(),
251            });
252        }
253
254        let half_dim = self.head_dim / 2;
255
256        for t in 0..n_tokens {
257            let abs_pos = offset + t;
258            let cos_row_start = abs_pos * half_dim;
259            for h in 0..n_heads {
260                let base = (t * n_heads + h) * self.head_dim;
261                for i in 0..half_dim {
262                    let cos = self.cos_table[cos_row_start + i];
263                    let sin = self.sin_table[cos_row_start + i];
264                    let x0 = x[base + 2 * i];
265                    let x1 = x[base + 2 * i + 1];
266                    x[base + 2 * i] = x0 * cos - x1 * sin;
267                    x[base + 2 * i + 1] = x0 * sin + x1 * cos;
268                }
269            }
270        }
271        Ok(())
272    }
273
274    /// Cosine value for `(position, half_dim_index)`.
275    pub fn cos_at(&self, pos: usize, i: usize) -> f32 {
276        self.cos_table[pos * (self.head_dim / 2) + i]
277    }
278
279    /// Sine value for `(position, half_dim_index)`.
280    pub fn sin_at(&self, pos: usize, i: usize) -> f32 {
281        self.sin_table[pos * (self.head_dim / 2) + i]
282    }
283}
284
285// ─── Tests ───────────────────────────────────────────────────────────────────
286
#[cfg(test)]
mod tests {
    use super::*;

    // ── TokenEmbedding ────────────────────────────────────────────────────

    #[test]
    fn token_embedding_lookup() {
        let mut emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        // Set row 2 to [1,2,3]
        emb.weight.data[6] = 1.0;
        emb.weight.data[7] = 2.0;
        emb.weight.data[8] = 3.0;
        let out = emb
            .forward(&[2])
            .expect("token id 2 within vocab_size=4 should succeed");
        assert_eq!(out, vec![1.0_f32, 2.0, 3.0]);
    }

    #[test]
    fn token_embedding_multi_token() {
        let mut emb = TokenEmbedding::new(3, 2).expect("vocab_size=3 embed_dim=2 should be valid");
        emb.weight.data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        // tokens [0, 2] → [[1,2], [5,6]]
        let out = emb
            .forward(&[0, 2])
            .expect("token ids 0 and 2 within vocab_size=3 should succeed");
        assert_eq!(out, vec![1.0_f32, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn token_embedding_out_of_vocab_error() {
        let emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        assert!(matches!(
            emb.forward(&[5]),
            Err(LmError::OutOfVocab { token: 5 })
        ));
    }

    #[test]
    fn token_embedding_empty_error() {
        let emb = TokenEmbedding::new(4, 3).expect("vocab_size=4 embed_dim=3 should be valid");
        assert!(matches!(emb.forward(&[]), Err(LmError::EmptyInput { .. })));
    }

    #[test]
    fn token_embedding_zero_dims_error() {
        assert!(matches!(
            TokenEmbedding::new(0, 3),
            Err(LmError::InvalidConfig { .. })
        ));
        assert!(matches!(
            TokenEmbedding::new(4, 0),
            Err(LmError::InvalidConfig { .. })
        ));
    }

    #[test]
    fn token_embedding_from_weight() {
        let w = WeightTensor::zeros(&[10, 4]);
        let emb = TokenEmbedding::from_weight(w)
            .expect("2-D weight tensor [10,4] should be valid for TokenEmbedding");
        assert_eq!(emb.vocab_size, 10);
        assert_eq!(emb.embed_dim, 4);
    }

    #[test]
    fn token_embedding_from_weight_rank_error() {
        let w = WeightTensor::zeros(&[12]);
        assert!(matches!(
            TokenEmbedding::from_weight(w),
            Err(LmError::DimensionMismatch { expected: 2, got: 1 })
        ));
    }

    // ── LearnedPositionalEmbedding ────────────────────────────────────────

    #[test]
    fn pos_embedding_lookup() {
        let mut pe = LearnedPositionalEmbedding::new(4, 2)
            .expect("max_positions=4 embed_dim=2 should be valid");
        // Set position 1 to [3.0, 4.0]
        pe.weight.data[2] = 3.0;
        pe.weight.data[3] = 4.0;
        let out = pe
            .forward(2, 0)
            .expect("seq_len=2 offset=0 within max_positions=4 should succeed");
        // pos 0: [0,0], pos 1: [3,4]
        assert_eq!(out, vec![0.0_f32, 0.0, 3.0, 4.0]);
    }

    #[test]
    fn pos_embedding_with_offset() {
        let mut pe = LearnedPositionalEmbedding::new(8, 2)
            .expect("max_positions=8 embed_dim=2 should be valid");
        // Positions 4 and 5 occupy flat indices 8..12; set them to 10.
        for i in 8..12 {
            pe.weight.data[i] = 10.0;
        }
        let out = pe
            .forward(2, 4)
            .expect("seq_len=2 offset=4 within max_positions=8 should succeed"); // positions 4..6
        assert_eq!(out.len(), 4);
        assert!(out.iter().all(|&v| v == 10.0));
    }

    #[test]
    fn pos_embedding_reaches_last_position() {
        let pe = LearnedPositionalEmbedding::new(4, 2)
            .expect("max_positions=4 embed_dim=2 should be valid");
        // offset + seq_len == max_positions is allowed; one more is not.
        assert!(pe.forward(2, 2).is_ok());
        assert!(matches!(
            pe.forward(1, 4),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    #[test]
    fn pos_embedding_too_long_error() {
        let pe = LearnedPositionalEmbedding::new(4, 2)
            .expect("max_positions=4 embed_dim=2 should be valid");
        assert!(matches!(
            pe.forward(5, 0),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    #[test]
    fn pos_embedding_from_weight_rank_error() {
        let w = WeightTensor::zeros(&[8]);
        assert!(matches!(
            LearnedPositionalEmbedding::from_weight(w),
            Err(LmError::DimensionMismatch { expected: 2, got: 1 })
        ));
    }

    // ── RotaryEmbedding ───────────────────────────────────────────────────

    #[test]
    fn rope_pos0_is_identity() {
        // At position 0 every angle is 0 (cos = 1, sin = 0), so each pair maps to
        // itself: x'[2i] = x[2i]·1 − x[2i+1]·0 and x'[2i+1] = x[2i]·0 + x[2i+1]·1.
        let rope = RotaryEmbedding::new(4, 16, 10_000.0)
            .expect("even head_dim=4 max_pos=16 should be valid");
        let mut x = vec![1.0_f32, 2.0, 3.0, 4.0]; // 1 token, 1 head, head_dim=4
        rope.apply(&mut x, 1, 1, 0)
            .expect("1 token at offset 0 within max_positions=16 should succeed");
        assert!((x[0] - 1.0).abs() < 1e-5, "x[0]={}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-5, "x[1]={}", x[1]);
        assert!((x[2] - 3.0).abs() < 1e-5, "x[2]={}", x[2]);
        assert!((x[3] - 4.0).abs() < 1e-5, "x[3]={}", x[3]);
    }

    #[test]
    fn rope_rotation_matches_formula() {
        // head_dim=2 → a single pair with θ_0 = theta^0 = 1, so position 1
        // rotates (1, 0) by exactly 1 radian.
        let rope = RotaryEmbedding::new(2, 4, 10_000.0)
            .expect("even head_dim=2 max_pos=4 should be valid");
        let mut x = vec![1.0_f32, 0.0];
        rope.apply(&mut x, 1, 1, 1)
            .expect("1 token at offset 1 within max_positions=4 should succeed");
        assert!((x[0] - 1.0_f32.cos()).abs() < 1e-6, "x[0]={}", x[0]);
        assert!((x[1] - 1.0_f32.sin()).abs() < 1e-6, "x[1]={}", x[1]);
    }

    #[test]
    fn rope_rotation_preserves_norm() {
        // Rotation is orthogonal → norm preserved.
        let rope = RotaryEmbedding::new(4, 32, 10_000.0)
            .expect("even head_dim=4 max_pos=32 should be valid");
        let original = vec![1.0_f32, 2.0, 3.0, 4.0];
        let mut x = original.clone();
        rope.apply(&mut x, 1, 1, 5)
            .expect("1 token at offset 5 within max_positions=32 should succeed"); // pos=5
        let norm_before: f32 = original.iter().map(|&v| v * v).sum::<f32>().sqrt();
        let norm_after: f32 = x.iter().map(|&v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm_before - norm_after).abs() < 1e-4,
            "norm {norm_before} ≠ {norm_after}"
        );
    }

    #[test]
    fn rope_multi_head_multi_token() {
        // 2 tokens × 3 heads × head_dim=4, all ones, starting at position 0.
        let rope = RotaryEmbedding::new(4, 32, 10_000.0)
            .expect("even head_dim=4 max_pos=32 for multi-head test should be valid");
        let mut x = vec![1.0_f32; 2 * 3 * 4];
        rope.apply(&mut x, 3, 2, 0)
            .expect("2 tokens 3 heads at offset 0 within max_positions=32 should succeed");
        // Token 0 sits at position 0, where the rotation is the identity.
        assert!(
            x[..12].iter().all(|&v| v == 1.0),
            "token 0 changed: {:?}",
            &x[..12]
        );
        // All heads of token 1 share position 1, so identical inputs rotate identically.
        assert_eq!(x[12..16], x[16..20]);
        assert_eq!(x[16..20], x[20..24]);
        // Pair 0 uses θ_0 = 1, so at position 1 its angle is exactly 1 radian.
        let expected = 1.0_f32.cos() - 1.0_f32.sin();
        assert!((x[12] - expected).abs() < 1e-6, "x[12]={}", x[12]);
    }

    #[test]
    fn rope_odd_head_dim_error() {
        assert!(RotaryEmbedding::new(3, 16, 10_000.0).is_err());
    }

    #[test]
    fn rope_invalid_config_errors() {
        assert!(RotaryEmbedding::new(0, 16, 10_000.0).is_err());
        assert!(RotaryEmbedding::new(4, 0, 10_000.0).is_err());
        assert!(RotaryEmbedding::new(4, 16, 0.0).is_err());
        assert!(RotaryEmbedding::new(4, 16, -1.0).is_err());
    }

    #[test]
    fn rope_sequence_too_long_error() {
        let rope = RotaryEmbedding::new(4, 4, 10_000.0)
            .expect("even head_dim=4 max_pos=4 should be valid");
        let mut x = vec![0.0_f32; 4];
        // offset=3 + seq_len=2 = 5 > max_positions=4
        assert!(matches!(
            rope.apply(&mut x, 1, 2, 3),
            Err(LmError::SequenceTooLong { .. })
        ));
    }

    #[test]
    fn rope_buffer_size_mismatch_error() {
        let rope = RotaryEmbedding::new(4, 8, 10_000.0)
            .expect("even head_dim=4 max_pos=8 should be valid");
        // 1 token × 1 head × head_dim=4 needs 4 elements, not 3.
        let mut x = vec![0.0_f32; 3];
        assert!(matches!(
            rope.apply(&mut x, 1, 1, 0),
            Err(LmError::DimensionMismatch { expected: 4, got: 3 })
        ));
    }

    #[test]
    fn rope_cos_sin_tables_at_zero() {
        let rope = RotaryEmbedding::new(4, 8, 10_000.0)
            .expect("even head_dim=4 max_pos=8 should be valid");
        // At position 0, cos=1, sin=0 for every pair index.
        for i in 0..2 {
            assert!((rope.cos_at(0, i) - 1.0).abs() < 1e-6, "cos_at(0, {i})");
            assert!(rope.sin_at(0, i).abs() < 1e-6, "sin_at(0, {i})");
        }
    }

    #[test]
    fn rope_tables_have_correct_dimensions() {
        let head_dim = 8;
        let max_pos = 16;
        let rope = RotaryEmbedding::new(head_dim, max_pos, 10_000.0)
            .expect("even head_dim and positive max_pos should produce valid RoPE");
        assert_eq!(rope.cos_table.len(), max_pos * (head_dim / 2));
        assert_eq!(rope.sin_table.len(), max_pos * (head_dim / 2));
    }
}