1use std::any::Any;
18use std::collections::HashMap;
19
20use axonml_autograd::{GradFn, GradientFunction, Variable};
21use axonml_tensor::Tensor;
22
23use crate::init::normal;
24use crate::module::Module;
25use crate::parameter::Parameter;
26
/// Backward pass for the embedding lookup.
///
/// Scatter-adds each row of the incoming gradient into the row of a dense
/// `[num_embeddings, embedding_dim]` weight-gradient buffer selected by the
/// corresponding forward-pass index.
#[derive(Debug)]
struct EmbeddingBackward {
    // Gradient functions of the inputs (here: only the weight parameter).
    next_fns: Vec<Option<GradFn>>,
    // Flattened token indices used in the forward lookup. These have already
    // been clamped into range by `Embedding::lookup`.
    indices: Vec<usize>,
    // Number of rows in the embedding table.
    num_embeddings: usize,
    // Width of each embedding row.
    embedding_dim: usize,
}
43
44impl GradientFunction for EmbeddingBackward {
45 fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
46 #[cfg(feature = "cuda")]
48 if grad_output.device().is_gpu() {
49 let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
50 let grad_tensor = grad_output.embedding_scatter_add_cuda(
51 &indices_u32,
52 self.num_embeddings,
53 self.embedding_dim,
54 );
55 return vec![Some(grad_tensor)];
56 }
57
58 let grad_data = grad_output.to_vec();
60 let mut weight_grad = vec![0.0f32; self.num_embeddings * self.embedding_dim];
61
62 for (i, &idx) in self.indices.iter().enumerate() {
64 if idx < self.num_embeddings {
65 let src_offset = i * self.embedding_dim;
66 let dst_offset = idx * self.embedding_dim;
67 for d in 0..self.embedding_dim {
68 weight_grad[dst_offset + d] += grad_data[src_offset + d];
69 }
70 }
71 }
72
73 let grad_tensor = Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim])
74 .expect("tensor creation failed");
75 vec![Some(grad_tensor)]
76 }
77
78 fn name(&self) -> &'static str {
79 "EmbeddingBackward"
80 }
81
82 fn next_functions(&self) -> &[Option<GradFn>] {
83 &self.next_fns
84 }
85
86 fn as_any(&self) -> &dyn Any {
87 self
88 }
89}
90
/// A learnable lookup table mapping integer token ids to dense vectors.
///
/// Holds a `[num_embeddings, embedding_dim]` weight matrix; `lookup` gathers
/// rows by index and wires up `EmbeddingBackward` for autograd.
pub struct Embedding {
    /// The embedding matrix, shape `[num_embeddings, embedding_dim]`.
    pub weight: Parameter,
    // Number of rows in the table (vocabulary size).
    num_embeddings: usize,
    // Width of each embedding vector.
    embedding_dim: usize,
    // Optional row that is zero-initialized at construction time.
    // NOTE(review): the row is only zeroed at init — the backward pass does
    // not exclude it, so gradients can still update it; confirm intended.
    padding_idx: Option<usize>,
}
113
114impl Embedding {
115 pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
117 Self::with_options(num_embeddings, embedding_dim, None)
118 }
119
120 pub fn with_options(
122 num_embeddings: usize,
123 embedding_dim: usize,
124 padding_idx: Option<usize>,
125 ) -> Self {
126 let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
128
129 if let Some(pad_idx) = padding_idx {
131 let mut data = weight_data.to_vec();
132 for i in 0..embedding_dim {
133 data[pad_idx * embedding_dim + i] = 0.0;
134 }
135 weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim])
136 .expect("tensor creation failed");
137 }
138
139 Self {
140 weight: Parameter::named("weight", weight_data, true),
141 num_embeddings,
142 embedding_dim,
143 padding_idx,
144 }
145 }
146
147 pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
149 let shape = weights.shape();
150 let num_embeddings = shape[0];
151 let embedding_dim = shape[1];
152
153 Self {
154 weight: Parameter::named("weight", weights, !freeze),
155 num_embeddings,
156 embedding_dim,
157 padding_idx: None,
158 }
159 }
160
161 pub fn num_embeddings(&self) -> usize {
163 self.num_embeddings
164 }
165
166 pub fn embedding_dim(&self) -> usize {
168 self.embedding_dim
169 }
170
171 pub fn lookup(&self, indices: &Variable) -> Variable {
179 let indices_data = indices.data();
180 let indices_vec = indices_data.to_vec();
182 let indices_shape = indices_data.shape().to_vec();
183
184 let mut output_shape = indices_shape.clone();
186 output_shape.push(self.embedding_dim);
187 let output_size: usize = output_shape.iter().product();
188
189 let mut safe_indices = Vec::with_capacity(indices_vec.len());
191 let mut gather_idx = Vec::with_capacity(output_size);
193
194 for &idx_f in &indices_vec {
195 let idx = idx_f as usize;
196 let safe_idx = if idx >= self.num_embeddings {
197 #[cfg(debug_assertions)]
198 eprintln!(
199 "Warning: embedding index {} out of range (max {}), using padding index 0",
200 idx,
201 self.num_embeddings - 1
202 );
203 0
204 } else {
205 idx
206 };
207 safe_indices.push(safe_idx);
208 let base = safe_idx * self.embedding_dim;
210 for d in 0..self.embedding_dim {
211 gather_idx.push((base + d) as u32);
212 }
213 }
214
215 let weight_data = self.weight.data();
216 #[cfg(feature = "cuda")]
217 let weight_device = weight_data.device();
218
219 #[cfg(feature = "cuda")]
221 let output_tensor = if weight_device.is_gpu() {
222 weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
223 } else {
224 let weight_vec = weight_data.to_vec();
225 let output_data: Vec<f32> =
226 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
227 Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
228 };
229
230 #[cfg(not(feature = "cuda"))]
231 let output_tensor = {
232 let weight_vec = weight_data.to_vec();
233 let output_data: Vec<f32> =
234 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
235 Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
236 };
237
238 if self.weight.requires_grad() {
239 let grad_fn = GradFn::new(EmbeddingBackward {
240 next_fns: vec![self.weight.variable().grad_fn().cloned()],
241 indices: safe_indices,
242 num_embeddings: self.num_embeddings,
243 embedding_dim: self.embedding_dim,
244 });
245 Variable::from_operation(output_tensor, grad_fn, true)
246 } else {
247 Variable::new(output_tensor, false)
248 }
249 }
250}
251
252impl Module for Embedding {
253 fn forward(&self, input: &Variable) -> Variable {
254 self.lookup(input)
255 }
256
257 fn parameters(&self) -> Vec<Parameter> {
258 vec![self.weight.clone()]
259 }
260
261 fn named_parameters(&self) -> HashMap<String, Parameter> {
262 let mut params = HashMap::new();
263 params.insert("weight".to_string(), self.weight.clone());
264 params
265 }
266
267 fn name(&self) -> &'static str {
268 "Embedding"
269 }
270}
271
272impl std::fmt::Debug for Embedding {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 f.debug_struct("Embedding")
275 .field("num_embeddings", &self.num_embeddings)
276 .field("embedding_dim", &self.embedding_dim)
277 .field("padding_idx", &self.padding_idx)
278 .finish()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    // Helper: wrap a flat f32 vector in a non-grad Variable of `shape`.
    fn index_var(values: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(
            Tensor::from_vec(values, shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_embedding_creation() {
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        let layer = Embedding::new(10, 4);
        let out = layer.forward(&index_var(vec![0.0, 1.0, 2.0], &[3]));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    #[test]
    fn test_embedding_batch() {
        let layer = Embedding::new(10, 4);
        let out = layer.forward(&index_var(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]));
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    #[test]
    fn test_embedding_with_padding() {
        // Row 0 is the padding row, so looking up index 0 must be all zeros.
        let layer = Embedding::with_options(10, 4, Some(0));
        let out = layer.forward(&index_var(vec![0.0], &[1]));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}