//! Embedding Module
//!
//! Token, positional, and combined embeddings for transformer models.

use axonml_autograd::Variable;
use axonml_nn::{Module, Embedding, Parameter, Dropout};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{zeros, ones};

/// Token embedding layer.
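///
/// # Example
///
/// A minimal usage sketch (shapes mirror `test_token_embedding` at the bottom
/// of this file):
///
/// ```ignore
/// let embed = TokenEmbedding::new(1000, 64);
/// let ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
/// let out = embed.forward_ids(&ids); // Variable with shape [2, 2, 64]
/// ```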
#[derive(Debug)]
pub struct TokenEmbedding {
    /// Embedding layer
    pub embedding: Embedding,
}

impl TokenEmbedding {
    /// Creates a new token embedding.
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(vocab_size, embed_dim),
        }
    }

    /// Gets embeddings for token IDs.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        // Convert u32 to indices and lookup
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];
        let embed_dim = self.embedding.embedding_dim();

        let ids_vec = input_ids.to_vec();
        let mut output_data = vec![0.0f32; batch_size * seq_len * embed_dim];

        let weight = &self.embedding.weight;
        let weight_data = weight.data().to_vec();

        for b in 0..batch_size {
            for s in 0..seq_len {
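                // Assumption: every token id is < vocab_size; an out-of-range
                // id would index past `weight_data` and panic.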
                let idx = ids_vec[b * seq_len + s] as usize;
                let src_offset = idx * embed_dim;
                let dst_offset = (b * seq_len + s) * embed_dim;

                for e in 0..embed_dim {
                    output_data[dst_offset + e] = weight_data[src_offset + e];
                }
            }
        }

        let output_tensor = Tensor::from_vec(output_data, &[batch_size, seq_len, embed_dim]).unwrap();
        Variable::new(output_tensor, weight.requires_grad())
    }
}

impl Module for TokenEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        self.embedding.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.embedding.parameters()
    }
}

/// Learned positional embedding.
#[derive(Debug)]
pub struct PositionalEmbedding {
    /// Position embedding weights
    pub embedding: Embedding,
    /// Maximum sequence length
    pub max_len: usize,
}

impl PositionalEmbedding {
    /// Creates a new learned positional embedding.
    pub fn new(max_len: usize, embed_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(max_len, embed_dim),
            max_len,
        }
    }

    /// Gets positional embeddings for a sequence length.
    pub fn forward_positions(&self, seq_len: usize, batch_size: usize) -> Variable {
        let embed_dim = self.embedding.embedding_dim();

        // Create position indices [0, 1, 2, ..., seq_len-1]
        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let position_tensor = Tensor::from_vec(positions.clone(), &[1, seq_len]).unwrap();
        let position_var = Variable::new(position_tensor, false);

        // Lookup embeddings
        let pos_embeds = self.embedding.forward(&position_var);
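        // `pos_embeds` is expected to have shape [1, seq_len, embed_dim] here,
        // matching the expand target below.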

        // Expand to batch size
        if batch_size > 1 {
            pos_embeds.expand(&[batch_size, seq_len, embed_dim])
        } else {
            pos_embeds
        }
    }
}

impl Module for PositionalEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        self.embedding.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.embedding.parameters()
    }
}

/// Sinusoidal positional encoding (fixed, not learned).
#[derive(Debug)]
pub struct SinusoidalPositionalEncoding {
    /// Precomputed positional encodings
    pub encodings: Tensor<f32>,
    /// Maximum sequence length
    pub max_len: usize,
    /// Embedding dimension
    pub embed_dim: usize,
}

impl SinusoidalPositionalEncoding {
    /// Creates sinusoidal positional encodings.
    pub fn new(max_len: usize, embed_dim: usize) -> Self {
        let mut encodings = vec![0.0f32; max_len * embed_dim];

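        // Standard sinusoidal encoding: for each position `pos` and pair index `i`,
        //   PE[pos, 2i]   = sin(pos / 10000^(2i / embed_dim))
        //   PE[pos, 2i+1] = cos(pos / 10000^(2i / embed_dim))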
        for pos in 0..max_len {
            for i in 0..embed_dim / 2 {
                let div_term = (10000.0f32).powf(2.0 * i as f32 / embed_dim as f32);
                let angle = pos as f32 / div_term;

                encodings[pos * embed_dim + 2 * i] = angle.sin();
                encodings[pos * embed_dim + 2 * i + 1] = angle.cos();
            }
        }

        Self {
            encodings: Tensor::from_vec(encodings, &[max_len, embed_dim]).unwrap(),
            max_len,
            embed_dim,
        }
    }

    /// Gets positional encodings for a sequence.
    pub fn forward_seq(&self, seq_len: usize) -> Variable {
        let sliced = self.encodings.slice(&[0..seq_len, 0..self.embed_dim]);
        Variable::new(sliced, false)
    }
}

/// BERT-style embeddings (token + position + segment).
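///
/// # Example
///
/// A construction sketch with BERT-base-like hyperparameters (the specific
/// values are illustrative, not required by this module):
///
/// ```ignore
/// let emb = BertEmbedding::new(30_522, 512, 2, 768, 1e-12, 0.1);
/// let ids = Tensor::from_vec(vec![101u32, 2023, 2003, 102], &[1, 4]).unwrap();
/// let hidden = emb.forward_with_ids(&ids, None, None); // [1, 4, 768]
/// ```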
#[derive(Debug)]
pub struct BertEmbedding {
    /// Token embeddings
    pub word_embeddings: Embedding,
    /// Position embeddings
    pub position_embeddings: Embedding,
    /// Token type embeddings (segment embeddings)
    pub token_type_embeddings: Embedding,
    /// Layer normalization
    pub layer_norm: LayerNorm,
    /// Dropout
    pub dropout: Dropout,
    /// Embedding dimension
    pub embed_dim: usize,
}

/// Simple layer norm implementation for embeddings.
#[derive(Debug)]
pub struct LayerNorm {
    weight: Parameter,
    bias: Parameter,
    eps: f32,
}

impl LayerNorm {
    fn new(dim: usize, eps: f32) -> Self {
        let weight = Parameter::new(ones::<f32>(&[dim]), true);
        let bias = Parameter::new(zeros::<f32>(&[dim]), true);
        Self { weight, bias, eps }
    }

    fn forward(&self, x: &Variable) -> Variable {
        // Normalize over last dimension
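        //   y = (x - mean) / sqrt(var + eps) * weight + bias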
        let mean = x.mean_dim(-1, true);
        let variance = x.var_dim(-1, true);

        let x_normalized = x.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());

        // Scale and shift
        let weight_var = Variable::from_tensor_with_grad(self.weight.data().clone(), self.weight.requires_grad());
        let bias_var = Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());

        x_normalized.mul(&weight_var).add(&bias_var)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }
}

impl BertEmbedding {
    /// Creates BERT embeddings.
    pub fn new(
        vocab_size: usize,
        max_position_embeddings: usize,
        type_vocab_size: usize,
        hidden_size: usize,
        layer_norm_eps: f32,
        dropout_prob: f32,
    ) -> Self {
        Self {
            word_embeddings: Embedding::new(vocab_size, hidden_size),
            position_embeddings: Embedding::new(max_position_embeddings, hidden_size),
            token_type_embeddings: Embedding::new(type_vocab_size, hidden_size),
            layer_norm: LayerNorm::new(hidden_size, layer_norm_eps),
            dropout: Dropout::new(dropout_prob),
            embed_dim: hidden_size,
        }
    }

    /// Forward pass with token IDs, position IDs, and token type IDs.
    pub fn forward_with_ids(
        &self,
        input_ids: &Tensor<u32>,
        token_type_ids: Option<&Tensor<u32>>,
        position_ids: Option<&Tensor<u32>>,
    ) -> Variable {
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];

        // Token embeddings
        let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
        let word_embeds = self.word_embeddings.forward(&Variable::new(input_ids_f32, false));

        // Position embeddings
        let pos_ids = if let Some(ids) = position_ids {
            Self::u32_to_f32_tensor(ids)
        } else {
            let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
            let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
            Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap()
        };
        let position_embeds = self.position_embeddings.forward(&Variable::new(pos_ids, false));

        // Token type embeddings
        let type_ids = if let Some(ids) = token_type_ids {
            Self::u32_to_f32_tensor(ids)
        } else {
            zeros::<f32>(&[batch_size, seq_len])
        };
        let token_type_embeds = self.token_type_embeddings.forward(&Variable::new(type_ids, false));

        // Combine embeddings
        let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);

        // Layer norm and dropout
        let embeddings = self.layer_norm.forward(&embeddings);
        self.dropout.forward(&embeddings)
    }

    fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
        let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
        Tensor::from_vec(data, t.shape()).unwrap()
    }
}

impl Module for BertEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        // For Module trait, assume input is already f32 token indices
        let input_data = input.data();
        let shape = input_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];

        let word_embeds = self.word_embeddings.forward(input);

        // Generate position IDs
        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.position_embeddings.forward(&Variable::new(pos_tensor, false));

        // Token type embeddings (assume all zeros)
        let type_tensor = zeros::<f32>(&[batch_size, seq_len]);
        let token_type_embeds = self.token_type_embeddings.forward(&Variable::new(type_tensor, false));

        let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);
        let embeddings = self.layer_norm.forward(&embeddings);
        self.dropout.forward(&embeddings)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.word_embeddings.parameters());
        params.extend(self.position_embeddings.parameters());
        params.extend(self.token_type_embeddings.parameters());
        params.extend(self.layer_norm.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}

/// GPT-2 style embeddings (token + position).
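///
/// # Example
///
/// A construction sketch with GPT-2-small-like sizes (illustrative values only):
///
/// ```ignore
/// let emb = GPT2Embedding::new(50_257, 1024, 768, 0.1);
/// let ids = Tensor::from_vec(vec![464u32, 2068, 7586, 21831], &[1, 4]).unwrap();
/// let hidden = emb.forward_ids(&ids); // [1, 4, 768]
/// ```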
#[derive(Debug)]
pub struct GPT2Embedding {
    /// Token embeddings
    pub wte: Embedding,
    /// Position embeddings
    pub wpe: Embedding,
    /// Dropout
    pub dropout: Dropout,
    /// Embedding dimension
    pub n_embd: usize,
}

impl GPT2Embedding {
    /// Creates GPT-2 embeddings.
    pub fn new(vocab_size: usize, n_ctx: usize, n_embd: usize, dropout: f32) -> Self {
        Self {
            wte: Embedding::new(vocab_size, n_embd),
            wpe: Embedding::new(n_ctx, n_embd),
            dropout: Dropout::new(dropout),
            n_embd,
        }
    }

    /// Forward pass with token IDs.
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];

        // Token embeddings
        let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
        let token_embeds = self.wte.forward(&Variable::new(input_ids_f32, false));

        // Position embeddings
        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));

        // Combine and apply dropout
        let embeddings = token_embeds.add(&position_embeds);
        self.dropout.forward(&embeddings)
    }

    fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
        let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
        Tensor::from_vec(data, t.shape()).unwrap()
    }
}

impl Module for GPT2Embedding {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];

        let token_embeds = self.wte.forward(input);

        // Position embeddings
        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));

        let embeddings = token_embeds.add(&position_embeds);
        self.dropout.forward(&embeddings)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.wte.parameters());
        params.extend(self.wpe.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_embedding() {
        let embed = TokenEmbedding::new(1000, 64);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_ids(&input_ids);

        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }

    #[test]
    fn test_positional_embedding() {
        let embed = PositionalEmbedding::new(128, 64);
        let output = embed.forward_positions(16, 2);

        assert_eq!(output.data().shape(), &[2, 16, 64]);
    }

    #[test]
    fn test_sinusoidal_encoding() {
        let encoding = SinusoidalPositionalEncoding::new(100, 64);
        let output = encoding.forward_seq(16);

        assert_eq!(output.data().shape(), &[16, 64]);
    }

    #[test]
    fn test_gpt2_embedding() {
        let embed = GPT2Embedding::new(1000, 128, 64, 0.0);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_ids(&input_ids);

        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }
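
    // Additional sketch tests: they assume the same [batch, seq, dim] shape
    // conventions as the tests above and the element layout written by
    // `SinusoidalPositionalEncoding::new`.
    #[test]
    fn test_bert_embedding() {
        let embed = BertEmbedding::new(100, 32, 2, 16, 1e-12, 0.0);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_with_ids(&input_ids, None, None);

        assert_eq!(output.data().shape(), &[2, 2, 16]);
    }

    #[test]
    fn test_sinusoidal_first_position() {
        let encoding = SinusoidalPositionalEncoding::new(10, 4);
        let values = encoding.encodings.to_vec();

        // At pos = 0 every angle is 0, so even indices hold sin(0) = 0 and
        // odd indices hold cos(0) = 1.
        assert!(values[0].abs() < 1e-6);
        assert!((values[1] - 1.0).abs() < 1e-6);
        assert!(values[2].abs() < 1e-6);
        assert!((values[3] - 1.0).abs() < 1e-6);
    }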
}