// ghostflow_nn/embedding.rs

use ghostflow_core::Tensor;

use crate::module::Module;
use crate::init;

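/// A learned lookup table that maps integer token indices to dense
/// `embedding_dim`-sized vectors, stored as a
/// `[num_embeddings, embedding_dim]` weight matrix.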
pub struct Embedding {
    weight: Tensor,
    num_embeddings: usize,
    embedding_dim: usize,
    padding_idx: Option<usize>,
    training: bool,
}

impl Embedding {
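    /// Creates an embedding table with weights drawn from a standard
    /// normal distribution, N(0, 1).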
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        let weight = init::normal(&[num_embeddings, embedding_dim], 0.0, 1.0);

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: true,
        }
    }

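    /// Like `new`, but zeroes the row at `padding_idx` so that padded
    /// positions start as zero vectors. Note the row is only zeroed at
    /// construction; it is not excluded from later gradient updates.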
    pub fn with_padding(num_embeddings: usize, embedding_dim: usize, padding_idx: usize) -> Self {
        let mut emb = Self::new(num_embeddings, embedding_dim);
        emb.padding_idx = Some(padding_idx);

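        // Zero out the padding row, then rebuild the weight tensor.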
        let mut weight_data = emb.weight.data_f32();
        let start = padding_idx * embedding_dim;
        for i in 0..embedding_dim {
            weight_data[start + i] = 0.0;
        }
        emb.weight = Tensor::from_slice(&weight_data, &[num_embeddings, embedding_dim]).unwrap();

        emb
    }

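    /// Builds an embedding from an existing `[num_embeddings, embedding_dim]`
    /// weight tensor. Passing `freeze = true` starts the module in eval mode,
    /// which hides the weight from `parameters()`; calling `train()` later
    /// makes it trainable again.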
    pub fn from_pretrained(weight: Tensor, freeze: bool) -> Self {
        let dims = weight.dims();
        let num_embeddings = dims[0];
        let embedding_dim = dims[1];

        Embedding {
            weight,
            num_embeddings,
            embedding_dim,
            padding_idx: None,
            training: !freeze,
        }
    }

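    /// Returns the size of each embedding vector.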
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }

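    /// Returns the number of rows in the table (the vocabulary size).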
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

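    /// Looks up a flat slice of indices, returning a `[len, embedding_dim]`
    /// tensor. Unlike `forward`, this panics if an index is out of range.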
    pub fn forward_indices(&self, indices: &[usize]) -> Tensor {
        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

        for &idx in indices {
            let start = idx * self.embedding_dim;
            output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
        }

        Tensor::from_slice(&output, &[indices.len(), self.embedding_dim]).unwrap()
    }
}

impl Module for Embedding {
    fn forward(&self, input: &Tensor) -> Tensor {
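        // Indices arrive as f32 data; `as usize` truncates, and negative
        // values saturate to 0 under Rust's float-to-int cast rules.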
        let indices: Vec<usize> = input.data_f32()
            .iter()
            .map(|&x| x as usize)
            .collect();

        let input_shape = input.dims();
        let batch_dims: Vec<usize> = input_shape.to_vec();

        let weight_data = self.weight.data_f32();
        let mut output = Vec::with_capacity(indices.len() * self.embedding_dim);

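        // Out-of-range indices map to a zero vector instead of panicking.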
        for &idx in &indices {
            if idx >= self.num_embeddings {
                output.extend(vec![0.0f32; self.embedding_dim]);
            } else {
                let start = idx * self.embedding_dim;
                output.extend_from_slice(&weight_data[start..start + self.embedding_dim]);
            }
        }

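        // The output shape is the input shape with embedding_dim appended.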
        let mut output_shape = batch_dims;
        output_shape.push(self.embedding_dim);

        Tensor::from_slice(&output, &output_shape).unwrap()
    }

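    /// When frozen (eval mode), the weight is hidden from the optimizer.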
    fn parameters(&self) -> Vec<Tensor> {
        if self.training {
            vec![self.weight.clone()]
        } else {
            vec![]
        }
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

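/// Combines a learned token embedding with a learned position embedding,
/// summing the two. `dropout_p` and `max_seq_len` are stored for future use
/// but are not currently applied in `forward`.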
pub struct TokenPositionEmbedding {
    token_embedding: Embedding,
    position_embedding: Embedding,
    #[allow(dead_code)]
    dropout_p: f32,
    #[allow(dead_code)]
    max_seq_len: usize,
}

impl TokenPositionEmbedding {
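    /// Creates token and position tables; the position table has one row
    /// per position up to `max_seq_len`.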
    pub fn new(vocab_size: usize, embed_dim: usize, max_seq_len: usize, dropout: f32) -> Self {
        TokenPositionEmbedding {
            token_embedding: Embedding::new(vocab_size, embed_dim),
            position_embedding: Embedding::new(max_seq_len, embed_dim),
            dropout_p: dropout,
            max_seq_len,
        }
    }
}

impl Module for TokenPositionEmbedding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.dims()[input.ndim() - 1];

        let token_emb = self.token_embedding.forward(input);

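        // Each position 0..seq_len looks up its own row of the position table.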
        let positions: Vec<f32> = (0..seq_len).map(|i| i as f32).collect();
        let pos_tensor = Tensor::from_slice(&positions, &[seq_len]).unwrap();
        let pos_emb = self.position_embedding.forward(&pos_tensor);

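        // Assumes `Tensor::add` broadcasts the [seq_len, embed_dim] position
        // embeddings across any leading batch dimensions of `token_emb`.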
        token_emb.add(&pos_emb).unwrap()
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.token_embedding.parameters();
        params.extend(self.position_embedding.parameters());
        params
    }

    fn train(&mut self) {
        self.token_embedding.train();
        self.position_embedding.train();
    }

    fn eval(&mut self) {
        self.token_embedding.eval();
        self.position_embedding.eval();
    }

    fn is_training(&self) -> bool {
        self.token_embedding.is_training()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 5.0, 10.0], &[3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[3, 64]);
    }

    #[test]
    fn test_embedding_batch() {
        let emb = Embedding::new(100, 64);
        let indices = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap();
        let output = emb.forward(&indices);

        assert_eq!(output.dims(), &[2, 3, 64]);
    }
}