ghostflow_nn/embedding.rs

//! Embedding layers

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::init;

/// Embedding layer - lookup table for discrete tokens
pub struct Embedding {
    weight: Tensor,
    num_embeddings: usize,
    embedding_dim: usize,
    padding_idx: Option<usize>,
    training: bool,
}

impl Embedding {
    /// Create a new embedding layer with weights drawn from a standard normal distribution
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let weight = init::normal(&[num_embeddings, embedding_dim], 0.0, 1.0);

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: true,
        }
    }

    /// Create embedding with padding index (the embedding at `padding_idx` is zeroed)
    pub fn with_padding(num_embeddings: usize, embedding_dim: usize, padding_idx: usize) -> Self {
        assert!(
            padding_idx < num_embeddings,
            "padding_idx {} out of range for {} embeddings",
            padding_idx, num_embeddings
        );
        let mut emb = Self::new(num_embeddings, embedding_dim);
        emb.padding_idx = Some(padding_idx);

        // Zero out the row for the padding token
        let mut weight_data = emb.weight.data_f32();
        let start = padding_idx * embedding_dim;
        for i in 0..embedding_dim {
            weight_data[start + i] = 0.0;
        }
        emb.weight = Tensor::from_slice(&weight_data, &[num_embeddings, embedding_dim]).unwrap();

        emb
    }

    /// Create embedding from pretrained weights (must be 2-D: `[num_embeddings, embedding_dim]`).
    /// If `freeze` is true the weight is hidden from `parameters()`, so optimizers skip it.
    pub fn from_pretrained(weight: Tensor, freeze: bool) -> Self {
        let dims = weight.dims();
        assert_eq!(dims.len(), 2, "pretrained embedding weight must be 2-D");
        let num_embeddings = dims[0];
        let embedding_dim = dims[1];

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            // Freezing is implemented via the training flag; note that calling
            // `train()` later re-enables updates.
            training: !freeze,
        }
    }

    /// Get embedding dimension
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Get number of embeddings
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    /// Forward pass with integer indices. Out-of-range indices yield zero
    /// vectors, matching the behavior of `Module::forward`.
    pub fn forward_indices(&self, indices: &[usize]) -> Tensor {
        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in indices {
            if idx >= self.num_embeddings {
                output.extend(vec![0.0f32; self.embedding_dim]);
            } else {
                let start = idx * self.embedding_dim;
                output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
            }
        }

        Tensor::from_slice(&output, &[indices.len(), self.embedding_dim]).unwrap()
    }
}

impl Module for Embedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Input holds integer indices stored as f32
        let indices: Vec<usize> = input.data_f32()
            .iter()
            .map(|&x| x as usize)
            .collect();

        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in &indices {
            if idx >= self.num_embeddings {
                // Out of bounds - use zeros
                output.extend(vec![0.0f32; self.embedding_dim]);
            } else {
                let start = idx * self.embedding_dim;
                output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
            }
        }

        // Output shape: input_shape + [embedding_dim]
        let mut output_shape = input.dims().to_vec();
        output_shape.push(self.embedding_dim);

        Tensor::from_slice(&output, &output_shape).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        if self.training {
            vec![self.weight.clone()]
        } else {
            // Frozen or in eval mode: expose nothing to the optimizer
            vec![]
        }
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Token + position embedding (the common transformer input stack)
pub struct TokenPositionEmbedding {
    token_embedding: Embedding,
    position_embedding: Embedding,
    /// Stored for a future dropout pass; not applied in `forward` yet
    #[allow(dead_code)]
    dropout_p: f32,
    max_seq_len: usize,
}

impl TokenPositionEmbedding {
    pub fn new(vocab_size: usize, embed_dim: usize, max_seq_len: usize, dropout: f32) -> Self {
        TokenPositionEmbedding {
            token_embedding: Embedding::new(vocab_size, embed_dim),
            position_embedding: Embedding::new(max_seq_len, embed_dim),
            dropout_p: dropout,
            max_seq_len,
        }
    }
}

impl Module for TokenPositionEmbedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.dims()[input.ndim() - 1];
        assert!(
            seq_len <= self.max_seq_len,
            "sequence length {} exceeds max_seq_len {}",
            seq_len, self.max_seq_len
        );

        // Token embeddings
        let token_emb = self.token_embedding.forward(input);

        // Position indices 0..seq_len
        let positions: Vec<f32> = (0..seq_len).map(|i| i as f32).collect();
        let pos_tensor = Tensor::from_slice(&positions, &[seq_len]).unwrap();
        let pos_emb = self.position_embedding.forward(&pos_tensor);

        // Add token and position embeddings. For batched input this relies on
        // `add` broadcasting `[seq_len, embed_dim]` across leading dimensions.
        token_emb.add(&pos_emb).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.token_embedding.parameters();
        params.extend(self.position_embedding.parameters());
        params
    }

    fn train(&mut self) {
        self.token_embedding.train();
        self.position_embedding.train();
    }

    fn eval(&mut self) {
        self.token_embedding.eval();
        self.position_embedding.eval();
    }

    fn is_training(&self) -> bool {
        self.token_embedding.is_training()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 5.0, 10.0], &[3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[3, 64]);
    }
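
    // Added check for `with_padding`: looking up the padding index should give
    // an all-zero vector. Assumes `data_f32` returns the output values in
    // row-major order, as the lookup code above produces them.
    #[test]
    fn test_embedding_with_padding() {
        let emb = Embedding::with_padding(10, 4, 2);
        let indices = Tensor::from_slice(&[2.0f32], &[1]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[1, 4]);
        assert!(output.data_f32().iter().all(|&x| x == 0.0));
    }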
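
    // Shape check for the raw-slice path: `forward_indices` should produce one
    // row per index, mirroring `forward`.
    #[test]
    fn test_forward_indices() {
        let emb = Embedding::new(10, 4);
        let output = emb.forward_indices(&[1, 3, 5]);

        assert_eq!(output.dims(), &[3, 4]);
    }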

    #[test]
    fn test_embedding_batch() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[2, 3, 64]);
    }
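
    // Out-of-range indices are mapped to zero vectors by `forward` rather than
    // panicking; this exercises that branch directly.
    #[test]
    fn test_embedding_out_of_range() {
        let emb = Embedding::new(4, 8);
        let indices = Tensor::from_slice(&[100.0f32], &[1]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[1, 8]);
        assert!(output.data_f32().iter().all(|&x| x == 0.0));
    }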
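
    // A frozen pretrained embedding should expose no trainable parameters.
    // The weight values here are arbitrary placeholders.
    #[test]
    fn test_from_pretrained_frozen() {
        let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let emb = Embedding::from_pretrained(weight, true);

        assert_eq!(emb.num_embeddings(), 3);
        assert_eq!(emb.embedding_dim(), 2);
        assert!(emb.parameters().is_empty());
    }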
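
    // Minimal shape check for TokenPositionEmbedding. A 1-D input is used so
    // token and position embeddings have identical shapes and the `add` call
    // needs no broadcasting.
    #[test]
    fn test_token_position_embedding() {
        let tpe = TokenPositionEmbedding::new(50, 16, 8, 0.0);
        let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let output = tpe.forward(&input);

        assert_eq!(output.dims(), &[3, 16]);
    }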
}