ghostflow_nn/embedding.rs

use ghostflow_core::Tensor;

use crate::module::Module;
use crate::init;

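/// A trainable lookup table mapping integer token indices to dense
/// `embedding_dim`-dimensional vectors.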
pub struct Embedding {
    weight: Tensor,
    num_embeddings: usize,
    embedding_dim: usize,
    padding_idx: Option<usize>,
    training: bool,
}

impl Embedding {
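    /// Creates a new embedding table with weights drawn from a standard
    /// normal distribution, N(0, 1).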
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let weight = init::normal(&[num_embeddings, embedding_dim], 0.0, 1.0);

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: true,
        }
    }

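    /// Creates an embedding table whose row at `padding_idx` is zeroed out,
    /// so padding tokens contribute a zero vector.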
    pub fn with_padding(num_embeddings: usize, embedding_dim: usize, padding_idx: usize) -> Self {
        assert!(
            padding_idx < num_embeddings,
            "padding_idx {} out of range for {} embeddings",
            padding_idx,
            num_embeddings
        );

        let mut emb = Self::new(num_embeddings, embedding_dim);
        emb.padding_idx = Some(padding_idx);

        // Zero out the padding row so padding tokens embed to the zero vector.
        let mut weight_data = emb.weight.data_f32();
        let start = padding_idx * embedding_dim;
        for i in 0..embedding_dim {
            weight_data[start + i] = 0.0;
        }
        emb.weight = Tensor::from_slice(&weight_data, &[num_embeddings, embedding_dim]).unwrap();

        emb
    }

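    /// Builds an embedding from an existing `[num_embeddings, embedding_dim]`
    /// weight tensor, optionally freezing it.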
    pub fn from_pretrained(weight: Tensor, freeze: bool) -> Self {
        let dims = weight.dims();
        let num_embeddings = dims[0];
        let embedding_dim = dims[1];

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            // Freezing is implemented via the training flag, which excludes
            // the weight from `parameters()` below.
            training: !freeze,
        }
    }

    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

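    /// Looks up a slice of indices directly, returning a
    /// `[indices.len(), embedding_dim]` tensor. Unlike `Module::forward`,
    /// an out-of-range index here is a caller error and panics.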
    pub fn forward_indices(&self, indices: &[usize]) -> Tensor {
        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in indices {
            assert!(
                idx < self.num_embeddings,
                "index {} out of range for embedding with {} entries",
                idx,
                self.num_embeddings
            );
            let start = idx * self.embedding_dim;
            output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
        }

        Tensor::from_slice(&output, &[indices.len(), self.embedding_dim]).unwrap()
    }
}

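// `Module::forward` accepts indices encoded as f32 values (matching the rest
// of the Tensor API); out-of-range indices produce zero vectors rather than
// panicking.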
impl Module for Embedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let indices: Vec<usize> = input.data_f32()
            .iter()
            .map(|&x| x as usize)
            .collect();

        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in &indices {
            if idx >= self.num_embeddings {
                // Out-of-range index: emit a zero vector instead of panicking.
                output.extend(std::iter::repeat(0.0f32).take(self.embedding_dim));
            } else {
                let start = idx * self.embedding_dim;
                output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
            }
        }

        // The output shape is the input shape with `embedding_dim` appended.
        let mut output_shape = input.dims().to_vec();
        output_shape.push(self.embedding_dim);

        Tensor::from_slice(&output, &output_shape).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        // A frozen (non-training) embedding exposes no trainable parameters.
        if self.training {
            vec![self.weight.clone()]
        } else {
            vec![]
        }
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

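/// Combines learned token embeddings with learned absolute position
/// embeddings, as used at the input of transformer models.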
pub struct TokenPositionEmbedding {
    token_embedding: Embedding,
    position_embedding: Embedding,
    /// Stored for future dropout support; not yet applied in `forward`.
    #[allow(dead_code)]
    dropout_p: f32,
    max_seq_len: usize,
}

impl TokenPositionEmbedding {
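    /// Creates embeddings for `vocab_size` tokens and positions
    /// `0..max_seq_len`; the dropout probability is stored but unused for now.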
    pub fn new(vocab_size: usize, embed_dim: usize, max_seq_len: usize, dropout: f32) -> Self {
        TokenPositionEmbedding {
            token_embedding: Embedding::new(vocab_size, embed_dim),
            position_embedding: Embedding::new(max_seq_len, embed_dim),
            dropout_p: dropout,
            max_seq_len,
        }
    }
}

impl Module for TokenPositionEmbedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.dims()[input.ndim() - 1];
        // Positions beyond max_seq_len would silently embed to zero vectors
        // (see Embedding::forward), so fail loudly instead.
        assert!(
            seq_len <= self.max_seq_len,
            "sequence length {} exceeds max_seq_len {}",
            seq_len,
            self.max_seq_len
        );

        let token_emb = self.token_embedding.forward(input);

        // Positions 0..seq_len, encoded as f32 indices.
        let positions: Vec<f32> = (0..seq_len).map(|i| i as f32).collect();
        let pos_tensor = Tensor::from_slice(&positions, &[seq_len]).unwrap();
        let pos_emb = self.position_embedding.forward(&pos_tensor);

        // For batched inputs this relies on `Tensor::add` broadcasting
        // [seq_len, embed_dim] across the leading batch dimensions.
        token_emb.add(&pos_emb).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.token_embedding.parameters();
        params.extend(self.position_embedding.parameters());
        params
    }

    fn train(&mut self) {
        self.token_embedding.train();
        self.position_embedding.train();
    }

    fn eval(&mut self) {
        self.token_embedding.eval();
        self.position_embedding.eval();
    }

    fn is_training(&self) -> bool {
        self.token_embedding.is_training()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 5.0, 10.0], &[3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[3, 64]);
    }

    #[test]
    fn test_embedding_batch() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[2, 3, 64]);
    }

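    // Additional test sketches for the padding, freeze, and token+position
    // paths; they assume `Tensor::data_f32` (as used above) returns the
    // buffer contents as f32 values.
    #[test]
    fn test_padding_row_is_zero() {
        let emb = Embedding::with_padding(10, 4, 2);
        let indices = Tensor::from_slice(&[2.0f32], &[1]).unwrap();
        let output = emb.forward(&indices);

        assert!(output.data_f32().iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_from_pretrained_freeze() {
        let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let emb = Embedding::from_pretrained(weight, true);

        // A frozen embedding exposes no trainable parameters...
        assert!(emb.parameters().is_empty());

        // ...but still performs lookups: row 1 of the weight is [3.0, 4.0].
        let row = emb.forward_indices(&[1]);
        let data = row.data_f32();
        assert_eq!(data[0], 3.0);
        assert_eq!(data[1], 4.0);
    }

    #[test]
    fn test_token_position_embedding_shape() {
        let emb = TokenPositionEmbedding::new(100, 32, 16, 0.1);
        let indices = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[3, 32]);
    }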
}