ghostflow_nn/
embedding.rs

//! Embedding layers

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::init;

/// Embedding layer - a lookup table mapping discrete token indices to dense vectors
pub struct Embedding {
    weight: Tensor,
    num_embeddings: usize,
    embedding_dim: usize,
    padding_idx: Option<usize>,
    training: bool,
}

impl Embedding {
    /// Create a new embedding layer; weights are initialized from N(0, 1).
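    ///
    /// # Example
    ///
    /// A minimal usage sketch (assuming `Embedding` and `Module` are
    /// re-exported at the crate root as used here):
    ///
    /// ```ignore
    /// use ghostflow_nn::{Embedding, Module};
    /// use ghostflow_core::Tensor;
    ///
    /// // 100-token vocabulary, 64-dimensional embeddings.
    /// let emb = Embedding::new(100, 64);
    /// let indices = Tensor::from_slice(&[0.0f32, 5.0, 10.0], &[3]).unwrap();
    /// let output = emb.forward(&indices);
    /// assert_eq!(output.dims(), &[3, 64]);
    /// ```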
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let weight = init::normal(&[num_embeddings, embedding_dim], 0.0, 1.0);

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: true,
        }
    }

    /// Create an embedding with a padding index; the row at `padding_idx`
    /// is zeroed at construction time.
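    ///
    /// # Example
    ///
    /// A sketch (same re-export assumption as above):
    ///
    /// ```ignore
    /// use ghostflow_nn::Embedding;
    ///
    /// // Reserve index 0 for padding; its row starts out all-zero.
    /// let emb = Embedding::with_padding(1000, 32, 0);
    /// let pad_row = emb.forward_indices(&[0]);
    /// assert!(pad_row.data_f32().iter().all(|&x| x == 0.0));
    /// ```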
    pub fn with_padding(num_embeddings: usize, embedding_dim: usize, padding_idx: usize) -> Self {
        assert!(
            padding_idx < num_embeddings,
            "padding_idx {} out of range for {} embeddings",
            padding_idx, num_embeddings
        );

        let mut emb = Self::new(num_embeddings, embedding_dim);
        emb.padding_idx = Some(padding_idx);

        // Zero out the padding row so padded positions contribute nothing.
        let mut weight_data = emb.weight.data_f32();
        let start = padding_idx * embedding_dim;
        for i in 0..embedding_dim {
            weight_data[start + i] = 0.0;
        }
        emb.weight = Tensor::from_slice(&weight_data, &[num_embeddings, embedding_dim]).unwrap();

        emb
    }

    /// Create an embedding from pretrained weights; if `freeze` is true the
    /// weights are excluded from `parameters()` and will not be updated.
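    ///
    /// # Example
    ///
    /// A sketch (same re-export assumption as above):
    ///
    /// ```ignore
    /// use ghostflow_nn::{Embedding, Module};
    /// use ghostflow_core::Tensor;
    ///
    /// // Two pretrained 3-dimensional vectors, kept frozen.
    /// let weight = Tensor::from_slice(&[0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6], &[2, 3]).unwrap();
    /// let emb = Embedding::from_pretrained(weight, true);
    /// assert!(emb.parameters().is_empty()); // frozen: nothing to optimize
    /// ```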
    pub fn from_pretrained(weight: Tensor, freeze: bool) -> Self {
        let dims = weight.dims();
        assert_eq!(dims.len(), 2, "pretrained weight must be 2-D [num_embeddings, embedding_dim]");
        let num_embeddings = dims[0];
        let embedding_dim = dims[1];

        // `training` doubles as the freeze switch: a frozen embedding reports
        // no trainable parameters. Calling `train()` later will un-freeze it.
        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: !freeze,
        }
    }

    /// Get embedding dimension
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    /// Get number of embeddings
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    /// Forward pass over a slice of integer indices; returns a
    /// `[indices.len(), embedding_dim]` tensor.
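    ///
    /// # Example
    ///
    /// A sketch (same re-export assumption as above):
    ///
    /// ```ignore
    /// use ghostflow_nn::Embedding;
    ///
    /// let emb = Embedding::new(10, 4);
    /// let rows = emb.forward_indices(&[1, 3, 3, 7]);
    /// assert_eq!(rows.dims(), &[4, 4]);
    /// ```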
    pub fn forward_indices(&self, indices: &[usize]) -> Tensor {
        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in indices {
            assert!(
                idx < self.num_embeddings,
                "index {} out of range for {} embeddings",
                idx, self.num_embeddings
            );
            let start = idx * self.embedding_dim;
            output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
        }

        Tensor::from_slice(&output, &[indices.len(), self.embedding_dim]).unwrap()
    }
}

impl Module for Embedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Input is expected to hold integer indices (stored as f32).
        let indices: Vec<usize> = input.data_f32()
            .iter()
            .map(|&x| x as usize)
            .collect();

        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in &indices {
            if idx >= self.num_embeddings {
                // Out-of-range index: fall back to a zero vector.
                output.extend(std::iter::repeat(0.0f32).take(self.embedding_dim));
            } else {
                let start = idx * self.embedding_dim;
                output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
            }
        }

        // Output shape: input_shape + [embedding_dim].
        let mut output_shape = input.dims().to_vec();
        output_shape.push(self.embedding_dim);

        Tensor::from_slice(&output, &output_shape).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        if self.training {
            vec![self.weight.clone()]
        } else {
            vec![] // Frozen: expose no trainable parameters
        }
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Token + position embedding (a common pattern in transformers): the output
/// is the sum of a token lookup and a learned position lookup.
pub struct TokenPositionEmbedding {
    token_embedding: Embedding,
    position_embedding: Embedding,
    #[allow(dead_code)]
    dropout_p: f32,
    #[allow(dead_code)]
    max_seq_len: usize,
}

impl TokenPositionEmbedding {
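    /// Create token and position embedding tables. Note that `dropout` is
    /// stored but not currently applied in `forward`.
    ///
    /// # Example
    ///
    /// A sketch (same re-export assumption as above):
    ///
    /// ```ignore
    /// use ghostflow_nn::{TokenPositionEmbedding, Module};
    /// use ghostflow_core::Tensor;
    ///
    /// // Vocab of 1000, 64-dim embeddings, sequences up to 128 tokens.
    /// let emb = TokenPositionEmbedding::new(1000, 64, 128, 0.1);
    /// let tokens = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
    /// let out = emb.forward(&tokens); // token + position, shape [3, 64]
    /// ```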
    pub fn new(vocab_size: usize, embed_dim: usize, max_seq_len: usize, dropout: f32) -> Self {
        TokenPositionEmbedding {
            token_embedding: Embedding::new(vocab_size, embed_dim),
            position_embedding: Embedding::new(max_seq_len, embed_dim),
            dropout_p: dropout,
            max_seq_len,
        }
    }
}

impl Module for TokenPositionEmbedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.dims()[input.ndim() - 1];

        // Token embeddings: input_shape + [embed_dim].
        let token_emb = self.token_embedding.forward(input);

        // Position indices 0..seq_len; positions beyond `max_seq_len` would
        // fall back to zero vectors in `Embedding::forward`.
        let positions: Vec<f32> = (0..seq_len).map(|i| i as f32).collect();
        let pos_tensor = Tensor::from_slice(&positions, &[seq_len]).unwrap();
        let pos_emb = self.position_embedding.forward(&pos_tensor);

        // Add token and position embeddings (for batched input this relies on
        // `add` broadcasting `pos_emb` over the leading batch dimensions).
        token_emb.add(&pos_emb).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.token_embedding.parameters();
        params.extend(self.position_embedding.parameters());
        params
    }

    fn train(&mut self) {
        self.token_embedding.train();
        self.position_embedding.train();
    }

    fn eval(&mut self) {
        self.token_embedding.eval();
        self.position_embedding.eval();
    }

    fn is_training(&self) -> bool {
        self.token_embedding.is_training()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 5.0, 10.0], &[3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[3, 64]);
    }

    #[test]
    fn test_embedding_batch() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[2, 3, 64]);
    }
}