axonml_nn/layers/
embedding.rs1use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::init::normal;
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17pub struct Embedding {
30 pub weight: Parameter,
32 num_embeddings: usize,
34 embedding_dim: usize,
36 padding_idx: Option<usize>,
38}
39
40impl Embedding {
41 pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
43 Self::with_options(num_embeddings, embedding_dim, None)
44 }
45
46 pub fn with_options(
48 num_embeddings: usize,
49 embedding_dim: usize,
50 padding_idx: Option<usize>,
51 ) -> Self {
52 let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
54
55 if let Some(pad_idx) = padding_idx {
57 let mut data = weight_data.to_vec();
58 for i in 0..embedding_dim {
59 data[pad_idx * embedding_dim + i] = 0.0;
60 }
61 weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim]).unwrap();
62 }
63
64 Self {
65 weight: Parameter::named("weight", weight_data, true),
66 num_embeddings,
67 embedding_dim,
68 padding_idx,
69 }
70 }
71
72 pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
74 let shape = weights.shape();
75 let num_embeddings = shape[0];
76 let embedding_dim = shape[1];
77
78 Self {
79 weight: Parameter::named("weight", weights, !freeze),
80 num_embeddings,
81 embedding_dim,
82 padding_idx: None,
83 }
84 }
85
86 pub fn num_embeddings(&self) -> usize {
88 self.num_embeddings
89 }
90
91 pub fn embedding_dim(&self) -> usize {
93 self.embedding_dim
94 }
95
96 pub fn lookup(&self, indices: &Variable) -> Variable {
104 let indices_data = indices.data();
105 let indices_vec = indices_data.to_vec();
106 let indices_shape = indices_data.shape().to_vec();
107
108 let weight_vec = self.weight.data().to_vec();
109
110 let mut output_shape = indices_shape.clone();
112 output_shape.push(self.embedding_dim);
113 let output_size: usize = output_shape.iter().product();
114
115 let mut output_data = vec![0.0f32; output_size];
116
117 for (i, &idx_f) in indices_vec.iter().enumerate() {
118 let idx = idx_f as usize;
119 let safe_idx = if idx >= self.num_embeddings {
122 #[cfg(debug_assertions)]
123 eprintln!(
124 "Warning: embedding index {} out of range (max {}), using padding index 0",
125 idx,
126 self.num_embeddings - 1
127 );
128 0
129 } else {
130 idx
131 };
132
133 for d in 0..self.embedding_dim {
134 output_data[i * self.embedding_dim + d] =
135 weight_vec[safe_idx * self.embedding_dim + d];
136 }
137 }
138
139 Variable::new(
140 Tensor::from_vec(output_data, &output_shape).unwrap(),
141 self.weight.requires_grad(),
142 )
143 }
144}
145
146impl Module for Embedding {
147 fn forward(&self, input: &Variable) -> Variable {
148 self.lookup(input)
149 }
150
151 fn parameters(&self) -> Vec<Parameter> {
152 vec![self.weight.clone()]
153 }
154
155 fn named_parameters(&self) -> HashMap<String, Parameter> {
156 let mut params = HashMap::new();
157 params.insert("weight".to_string(), self.weight.clone());
158 params
159 }
160
161 fn name(&self) -> &'static str {
162 "Embedding"
163 }
164}
165
166impl std::fmt::Debug for Embedding {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 f.debug_struct("Embedding")
169 .field("num_embeddings", &self.num_embeddings)
170 .field("embedding_dim", &self.embedding_dim)
171 .field("padding_idx", &self.padding_idx)
172 .finish()
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_creation() {
        let emb = Embedding::new(1000, 128);
        assert_eq!(emb.num_embeddings(), 1000);
        assert_eq!(emb.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        let emb = Embedding::new(10, 4);
        let indices = Variable::new(Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap(), false);
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![3, 4]);
    }

    /// Lookup must return the exact rows of the weight table, in index order.
    #[test]
    fn test_embedding_lookup_returns_weight_rows() {
        let weights = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();
        let emb = Embedding::from_pretrained(weights, false);
        let indices = Variable::new(Tensor::from_vec(vec![2.0, 0.0, 1.0], &[3]).unwrap(), false);
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![3, 2]);
        assert_eq!(output.data().to_vec(), vec![5.0, 6.0, 1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_embedding_batch() {
        let emb = Embedding::new(10, 4);
        let indices = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap(),
            false,
        );
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        let emb = Embedding::new(100, 64);
        assert_eq!(emb.parameters().len(), 1);
        assert_eq!(emb.num_parameters(), 100 * 64);
    }

    /// `from_pretrained` takes its sizes from the tensor and honours `freeze`.
    #[test]
    fn test_embedding_from_pretrained() {
        let weights = Tensor::from_vec(vec![0.0; 12], &[4, 3]).unwrap();
        let emb = Embedding::from_pretrained(weights, true);
        assert_eq!(emb.num_embeddings(), 4);
        assert_eq!(emb.embedding_dim(), 3);
        assert!(!emb.weight.requires_grad());
    }

    #[test]
    fn test_embedding_with_padding() {
        let emb = Embedding::with_options(10, 4, Some(0));
        let indices = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let output = emb.forward(&indices);
        let output_vec = output.data().to_vec();
        assert!(output_vec.iter().all(|&x| x == 0.0));
    }

    /// A padding row other than 0 must also be zeroed.
    #[test]
    fn test_embedding_with_nonzero_padding() {
        let emb = Embedding::with_options(10, 4, Some(3));
        let indices = Variable::new(Tensor::from_vec(vec![3.0], &[1]).unwrap(), false);
        let output = emb.forward(&indices);
        assert_eq!(output.shape(), vec![1, 4]);
        assert!(output.data().to_vec().iter().all(|&x| x == 0.0));
    }
}