1use std::any::Any;
24use std::collections::HashMap;
25
26use axonml_autograd::{GradFn, GradientFunction, Variable};
27use axonml_tensor::Tensor;
28
29use crate::init::normal;
30use crate::module::Module;
31use crate::parameter::Parameter;
32
/// Backward pass for the embedding lookup.
///
/// Scatter-adds the incoming gradient rows back into a dense gradient of
/// shape `[num_embeddings, embedding_dim]` for the weight matrix.
#[derive(Debug)]
struct EmbeddingBackward {
    // Upstream gradient functions; here only the weight's grad_fn slot.
    next_fns: Vec<Option<GradFn>>,
    // Flattened lookup indices from the forward pass (already clamped
    // in-range by `Embedding::lookup`).
    indices: Vec<usize>,
    // Number of rows in the embedding table.
    num_embeddings: usize,
    // Width of each embedding vector.
    embedding_dim: usize,
}
49
50impl GradientFunction for EmbeddingBackward {
51 fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
52 #[cfg(feature = "cuda")]
54 if grad_output.device().is_gpu() {
55 let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
56 let grad_tensor = grad_output.embedding_scatter_add_cuda(
57 &indices_u32,
58 self.num_embeddings,
59 self.embedding_dim,
60 );
61 return vec![Some(grad_tensor)];
62 }
63
64 let grad_data = grad_output.to_vec();
66 let mut weight_grad = vec![0.0f32; self.num_embeddings * self.embedding_dim];
67
68 for (i, &idx) in self.indices.iter().enumerate() {
70 if idx < self.num_embeddings {
71 let src_offset = i * self.embedding_dim;
72 let dst_offset = idx * self.embedding_dim;
73 for d in 0..self.embedding_dim {
74 weight_grad[dst_offset + d] += grad_data[src_offset + d];
75 }
76 }
77 }
78
79 let grad_tensor = Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim])
80 .expect("tensor creation failed");
81 vec![Some(grad_tensor)]
82 }
83
84 fn name(&self) -> &'static str {
85 "EmbeddingBackward"
86 }
87
88 fn next_functions(&self) -> &[Option<GradFn>] {
89 &self.next_fns
90 }
91
92 fn as_any(&self) -> &dyn Any {
93 self
94 }
95}
96
/// A lookup table mapping integer indices to dense vectors.
///
/// Holds a learnable weight matrix of shape `[num_embeddings, embedding_dim]`;
/// `lookup`/`forward` gather rows of this matrix by index.
pub struct Embedding {
    /// The embedding matrix, shape `[num_embeddings, embedding_dim]`.
    pub weight: Parameter,
    // Number of rows in the table (vocabulary size).
    num_embeddings: usize,
    // Width of each embedding vector.
    embedding_dim: usize,
    // Optional row that is zero-initialized at construction. Note: it is
    // only zeroed once, at init — the backward pass does not mask its
    // gradient, so training can move it away from zero.
    padding_idx: Option<usize>,
}
119
120impl Embedding {
121 pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
123 Self::with_options(num_embeddings, embedding_dim, None)
124 }
125
126 pub fn with_options(
128 num_embeddings: usize,
129 embedding_dim: usize,
130 padding_idx: Option<usize>,
131 ) -> Self {
132 let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
134
135 if let Some(pad_idx) = padding_idx {
137 let mut data = weight_data.to_vec();
138 for i in 0..embedding_dim {
139 data[pad_idx * embedding_dim + i] = 0.0;
140 }
141 weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim])
142 .expect("tensor creation failed");
143 }
144
145 Self {
146 weight: Parameter::named("weight", weight_data, true),
147 num_embeddings,
148 embedding_dim,
149 padding_idx,
150 }
151 }
152
153 pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
155 let shape = weights.shape();
156 let num_embeddings = shape[0];
157 let embedding_dim = shape[1];
158
159 Self {
160 weight: Parameter::named("weight", weights, !freeze),
161 num_embeddings,
162 embedding_dim,
163 padding_idx: None,
164 }
165 }
166
167 pub fn num_embeddings(&self) -> usize {
169 self.num_embeddings
170 }
171
172 pub fn embedding_dim(&self) -> usize {
174 self.embedding_dim
175 }
176
177 pub fn lookup(&self, indices: &Variable) -> Variable {
185 let indices_data = indices.data();
186 let indices_vec = indices_data.to_vec();
188 let indices_shape = indices_data.shape().to_vec();
189
190 let mut output_shape = indices_shape.clone();
192 output_shape.push(self.embedding_dim);
193 let output_size: usize = output_shape.iter().product();
194
195 let mut safe_indices = Vec::with_capacity(indices_vec.len());
197 let mut gather_idx = Vec::with_capacity(output_size);
199
200 for &idx_f in &indices_vec {
201 let idx = idx_f as usize;
202 let safe_idx = if idx >= self.num_embeddings {
203 #[cfg(debug_assertions)]
204 eprintln!(
205 "Warning: embedding index {} out of range (max {}), using padding index 0",
206 idx,
207 self.num_embeddings - 1
208 );
209 0
210 } else {
211 idx
212 };
213 safe_indices.push(safe_idx);
214 let base = safe_idx * self.embedding_dim;
216 for d in 0..self.embedding_dim {
217 gather_idx.push((base + d) as u32);
218 }
219 }
220
221 let weight_data = self.weight.data();
222 #[cfg(feature = "cuda")]
223 let weight_device = weight_data.device();
224
225 #[cfg(feature = "cuda")]
227 let output_tensor = if weight_device.is_gpu() {
228 weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
229 } else {
230 let weight_vec = weight_data.to_vec();
231 let output_data: Vec<f32> =
232 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
233 Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
234 };
235
236 #[cfg(not(feature = "cuda"))]
237 let output_tensor = {
238 let weight_vec = weight_data.to_vec();
239 let output_data: Vec<f32> =
240 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
241 Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
242 };
243
244 if self.weight.requires_grad() {
245 let grad_fn = GradFn::new(EmbeddingBackward {
246 next_fns: vec![self.weight.variable().grad_fn().cloned()],
247 indices: safe_indices,
248 num_embeddings: self.num_embeddings,
249 embedding_dim: self.embedding_dim,
250 });
251 Variable::from_operation(output_tensor, grad_fn, true)
252 } else {
253 Variable::new(output_tensor, false)
254 }
255 }
256}
257
258impl Module for Embedding {
259 fn forward(&self, input: &Variable) -> Variable {
260 self.lookup(input)
261 }
262
263 fn parameters(&self) -> Vec<Parameter> {
264 vec![self.weight.clone()]
265 }
266
267 fn named_parameters(&self) -> HashMap<String, Parameter> {
268 let mut params = HashMap::new();
269 params.insert("weight".to_string(), self.weight.clone());
270 params
271 }
272
273 fn name(&self) -> &'static str {
274 "Embedding"
275 }
276}
277
278impl std::fmt::Debug for Embedding {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 f.debug_struct("Embedding")
281 .field("num_embeddings", &self.num_embeddings)
282 .field("embedding_dim", &self.embedding_dim)
283 .field("padding_idx", &self.padding_idx)
284 .finish()
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_creation() {
        // Accessors report the construction-time dimensions.
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        // A 1-D index tensor of length 3 yields a [3, 4] output.
        let layer = Embedding::new(10, 4);
        let idx = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(idx, false));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    #[test]
    fn test_embedding_batch() {
        // Batched [2, 3] indices gain a trailing embedding dim: [2, 3, 4].
        let layer = Embedding::new(10, 4);
        let idx = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3])
            .expect("tensor creation failed");
        let out = layer.forward(&Variable::new(idx, false));
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        // One parameter tensor, with num_embeddings * embedding_dim scalars.
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    #[test]
    fn test_embedding_with_padding() {
        // The padding row is zero-initialized, so looking it up yields zeros.
        let layer = Embedding::with_options(10, 4, Some(0));
        let idx = Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(idx, false));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}