use axonml_autograd::Variable;
use axonml_nn::{Module, Embedding, Parameter, Dropout};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{zeros, ones};

#[derive(Debug)]
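/// Token embedding layer that looks up a dense vector for each integer token ID.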
pub struct TokenEmbedding {
    pub embedding: Embedding,
}

impl TokenEmbedding {
    pub fn new(vocab_size: usize, embed_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(vocab_size, embed_dim),
        }
    }

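    /// Embeds a `[batch_size, seq_len]` tensor of token IDs by gathering rows of the
    /// embedding weight matrix, returning a `[batch_size, seq_len, embed_dim]` output.
    /// The result is wrapped in a fresh `Variable`, so this manual gather does not
    /// connect to the autograd graph of the embedding weight.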
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];
        let embed_dim = self.embedding.embedding_dim();

        let ids_vec = input_ids.to_vec();
        let mut output_data = vec![0.0f32; batch_size * seq_len * embed_dim];

        let weight = &self.embedding.weight;
        let weight_data = weight.data().to_vec();

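        // Copy the embedding row for each token ID into the flat output buffer.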
        for b in 0..batch_size {
            for s in 0..seq_len {
                let idx = ids_vec[b * seq_len + s] as usize;
                let src_offset = idx * embed_dim;
                let dst_offset = (b * seq_len + s) * embed_dim;

                for e in 0..embed_dim {
                    output_data[dst_offset + e] = weight_data[src_offset + e];
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[batch_size, seq_len, embed_dim]).unwrap();
        Variable::new(output_tensor, weight.requires_grad())
    }
}

impl Module for TokenEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        self.embedding.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.embedding.parameters()
    }
}

#[derive(Debug)]
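/// Learned positional embedding: one trainable vector per position, up to `max_len` positions.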
pub struct PositionalEmbedding {
    pub embedding: Embedding,
    pub max_len: usize,
}

impl PositionalEmbedding {
    pub fn new(max_len: usize, embed_dim: usize) -> Self {
        Self {
            embedding: Embedding::new(max_len, embed_dim),
            max_len,
        }
    }

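    /// Returns position embeddings for positions `0..seq_len`, expanded to
    /// shape `[batch_size, seq_len, embed_dim]` when `batch_size > 1`.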
    pub fn forward_positions(&self, seq_len: usize, batch_size: usize) -> Variable {
        let embed_dim = self.embedding.embedding_dim();

        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let position_tensor = Tensor::from_vec(positions, &[1, seq_len]).unwrap();
        let position_var = Variable::new(position_tensor, false);

        let pos_embeds = self.embedding.forward(&position_var);

        if batch_size > 1 {
            pos_embeds.expand(&[batch_size, seq_len, embed_dim])
        } else {
            pos_embeds
        }
    }
}

impl Module for PositionalEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        self.embedding.forward(input)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.embedding.parameters()
    }
}

#[derive(Debug)]
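/// Fixed (non-trainable) sinusoidal positional encoding table, precomputed up to `max_len` positions.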
pub struct SinusoidalPositionalEncoding {
    pub encodings: Tensor<f32>,
    pub max_len: usize,
    pub embed_dim: usize,
}

impl SinusoidalPositionalEncoding {
    pub fn new(max_len: usize, embed_dim: usize) -> Self {
        let mut encodings = vec![0.0f32; max_len * embed_dim];

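        // PE(pos, 2i)   = sin(pos / 10000^(2i / embed_dim))
        // PE(pos, 2i+1) = cos(pos / 10000^(2i / embed_dim))
        // If embed_dim is odd, the final column stays at zero.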
        for pos in 0..max_len {
            for i in 0..embed_dim / 2 {
                let div_term = (10000.0f32).powf(2.0 * i as f32 / embed_dim as f32);
                let angle = pos as f32 / div_term;

                encodings[pos * embed_dim + 2 * i] = angle.sin();
                encodings[pos * embed_dim + 2 * i + 1] = angle.cos();
            }
        }

        Self {
            encodings: Tensor::from_vec(encodings, &[max_len, embed_dim]).unwrap(),
            max_len,
            embed_dim,
        }
    }

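    /// Returns the first `seq_len` rows of the precomputed table as a
    /// non-trainable `Variable` of shape `[seq_len, embed_dim]`.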
    pub fn forward_seq(&self, seq_len: usize) -> Variable {
        let sliced = self.encodings.slice(&[0..seq_len, 0..self.embed_dim]);
        Variable::new(sliced, false)
    }
}

#[derive(Debug)]
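/// BERT-style embedding block: word, position, and token-type embeddings summed,
/// then layer-normalized and passed through dropout.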
pub struct BertEmbedding {
    pub word_embeddings: Embedding,
    pub position_embeddings: Embedding,
    pub token_type_embeddings: Embedding,
    pub layer_norm: LayerNorm,
    pub dropout: Dropout,
    pub embed_dim: usize,
}

#[derive(Debug)]
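/// Minimal layer normalization over the last dimension, used by `BertEmbedding`.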
pub struct LayerNorm {
    weight: Parameter,
    bias: Parameter,
    eps: f32,
}

impl LayerNorm {
    fn new(dim: usize, eps: f32) -> Self {
        let weight = Parameter::new(ones::<f32>(&[dim]), true);
        let bias = Parameter::new(zeros::<f32>(&[dim]), true);
        Self { weight, bias, eps }
    }

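    /// Normalizes `x` over its last dimension, then applies the learned scale and shift.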
    fn forward(&self, x: &Variable) -> Variable {
        let mean = x.mean_dim(-1, true);
        let variance = x.var_dim(-1, true);

        let x_normalized = x.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());

        let weight_var =
            Variable::from_tensor_with_grad(self.weight.data().clone(), self.weight.requires_grad());
        let bias_var =
            Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());

        x_normalized.mul(&weight_var).add(&bias_var)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }
}

impl BertEmbedding {
    pub fn new(
        vocab_size: usize,
        max_position_embeddings: usize,
        type_vocab_size: usize,
        hidden_size: usize,
        layer_norm_eps: f32,
        dropout_prob: f32,
    ) -> Self {
        Self {
            word_embeddings: Embedding::new(vocab_size, hidden_size),
            position_embeddings: Embedding::new(max_position_embeddings, hidden_size),
            token_type_embeddings: Embedding::new(type_vocab_size, hidden_size),
            layer_norm: LayerNorm::new(hidden_size, layer_norm_eps),
            dropout: Dropout::new(dropout_prob),
            embed_dim: hidden_size,
        }
    }

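    /// Computes BERT embeddings from raw ID tensors. Missing `position_ids` default to
    /// `0..seq_len` for every batch element; missing `token_type_ids` default to all zeros.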
    pub fn forward_with_ids(
        &self,
        input_ids: &Tensor<u32>,
        token_type_ids: Option<&Tensor<u32>>,
        position_ids: Option<&Tensor<u32>>,
    ) -> Variable {
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];

        let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
        let word_embeds = self.word_embeddings.forward(&Variable::new(input_ids_f32, false));

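        // Use the provided position IDs, or default to 0..seq_len for each batch element.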
        let pos_ids = if let Some(ids) = position_ids {
            Self::u32_to_f32_tensor(ids)
        } else {
            let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
            let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
            Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap()
        };
        let position_embeds = self.position_embeddings.forward(&Variable::new(pos_ids, false));

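        // Token type (segment) IDs default to all zeros for single-segment input.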
        let type_ids = if let Some(ids) = token_type_ids {
            Self::u32_to_f32_tensor(ids)
        } else {
            zeros::<f32>(&[batch_size, seq_len])
        };
        let token_type_embeds = self.token_type_embeddings.forward(&Variable::new(type_ids, false));

        let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);

        let embeddings = self.layer_norm.forward(&embeddings);
        self.dropout.forward(&embeddings)
    }

    fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
        let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
        Tensor::from_vec(data, t.shape()).unwrap()
    }
}

impl Module for BertEmbedding {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];

        let word_embeds = self.word_embeddings.forward(input);

        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.position_embeddings.forward(&Variable::new(pos_tensor, false));

        let type_tensor = zeros::<f32>(&[batch_size, seq_len]);
        let token_type_embeds = self.token_type_embeddings.forward(&Variable::new(type_tensor, false));

        let embeddings = word_embeds.add(&position_embeds).add(&token_type_embeds);
        let embeddings = self.layer_norm.forward(&embeddings);
        self.dropout.forward(&embeddings)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.word_embeddings.parameters());
        params.extend(self.position_embeddings.parameters());
        params.extend(self.token_type_embeddings.parameters());
        params.extend(self.layer_norm.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}

#[derive(Debug)]
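/// GPT-2 style embedding block: token embeddings (`wte`) plus learned position
/// embeddings (`wpe`), followed by dropout.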
pub struct GPT2Embedding {
    pub wte: Embedding,
    pub wpe: Embedding,
    pub dropout: Dropout,
    pub n_embd: usize,
}

impl GPT2Embedding {
    pub fn new(vocab_size: usize, n_ctx: usize, n_embd: usize, dropout: f32) -> Self {
        Self {
            wte: Embedding::new(vocab_size, n_embd),
            wpe: Embedding::new(n_ctx, n_embd),
            dropout: Dropout::new(dropout),
            n_embd,
        }
    }

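    /// Embeds a `[batch_size, seq_len]` tensor of token IDs and adds learned
    /// position embeddings for positions `0..seq_len`, then applies dropout.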
    pub fn forward_ids(&self, input_ids: &Tensor<u32>) -> Variable {
        let batch_size = input_ids.shape()[0];
        let seq_len = input_ids.shape()[1];

        let input_ids_f32 = Self::u32_to_f32_tensor(input_ids);
        let token_embeds = self.wte.forward(&Variable::new(input_ids_f32, false));

        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));

        let embeddings = token_embeds.add(&position_embeds);
        self.dropout.forward(&embeddings)
    }

    fn u32_to_f32_tensor(t: &Tensor<u32>) -> Tensor<f32> {
        let data: Vec<f32> = t.to_vec().iter().map(|&x| x as f32).collect();
        Tensor::from_vec(data, t.shape()).unwrap()
    }
}

impl Module for GPT2Embedding {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];

        let token_embeds = self.wte.forward(input);

        let positions: Vec<f32> = (0..seq_len).map(|p| p as f32).collect();
        let pos_data: Vec<f32> = (0..batch_size).flat_map(|_| positions.iter().cloned()).collect();
        let pos_tensor = Tensor::from_vec(pos_data, &[batch_size, seq_len]).unwrap();
        let position_embeds = self.wpe.forward(&Variable::new(pos_tensor, false));

        let embeddings = token_embeds.add(&position_embeds);
        self.dropout.forward(&embeddings)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.wte.parameters());
        params.extend(self.wpe.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_embedding() {
        let embed = TokenEmbedding::new(1000, 64);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_ids(&input_ids);

        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }

    #[test]
    fn test_positional_embedding() {
        let embed = PositionalEmbedding::new(128, 64);
        let output = embed.forward_positions(16, 2);

        assert_eq!(output.data().shape(), &[2, 16, 64]);
    }

    #[test]
    fn test_sinusoidal_encoding() {
        let encoding = SinusoidalPositionalEncoding::new(100, 64);
        let output = encoding.forward_seq(16);

        assert_eq!(output.data().shape(), &[16, 64]);
    }

    #[test]
    fn test_gpt2_embedding() {
        let embed = GPT2Embedding::new(1000, 128, 64, 0.0);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_ids(&input_ids);

        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }
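
    // Additional shape check for BertEmbedding in the same style as the tests above.
    // A minimal sketch assuming the default (None) token-type and position IDs are used.
    #[test]
    fn test_bert_embedding() {
        let embed = BertEmbedding::new(1000, 128, 2, 64, 1e-12, 0.0);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3, 4], &[2, 2]).unwrap();
        let output = embed.forward_with_ids(&input_ids, None, None);

        assert_eq!(output.data().shape(), &[2, 2, 64]);
    }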
}