1use std::any::Any;
18use std::collections::HashMap;
19
20use axonml_autograd::{GradFn, GradientFunction, Variable};
21use axonml_tensor::Tensor;
22
23use crate::init::normal;
24use crate::module::Module;
25use crate::parameter::Parameter;
26
/// Backward pass for the embedding lookup.
///
/// Scatter-adds each row of the incoming gradient into a dense
/// `[num_embeddings, embedding_dim]` gradient for the weight matrix,
/// selected by the index used in the forward lookup.
#[derive(Debug)]
struct EmbeddingBackward {
    // Gradient functions of the inputs; here only the weight parameter's.
    next_fns: Vec<Option<GradFn>>,
    // Flattened lookup indices recorded during the forward pass
    // (already clamped to range by `Embedding::lookup`).
    indices: Vec<usize>,
    // Number of rows in the embedding table.
    num_embeddings: usize,
    // Width of each embedding vector.
    embedding_dim: usize,
}
43
44impl GradientFunction for EmbeddingBackward {
45 fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
46 #[cfg(feature = "cuda")]
48 if grad_output.device().is_gpu() {
49 let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
50 let grad_tensor = grad_output.embedding_scatter_add_cuda(
51 &indices_u32,
52 self.num_embeddings,
53 self.embedding_dim,
54 );
55 return vec![Some(grad_tensor)];
56 }
57
58 let grad_data = grad_output.to_vec();
60 let mut weight_grad = vec![0.0f32; self.num_embeddings * self.embedding_dim];
61
62 for (i, &idx) in self.indices.iter().enumerate() {
64 if idx < self.num_embeddings {
65 let src_offset = i * self.embedding_dim;
66 let dst_offset = idx * self.embedding_dim;
67 for d in 0..self.embedding_dim {
68 weight_grad[dst_offset + d] += grad_data[src_offset + d];
69 }
70 }
71 }
72
73 let grad_tensor =
74 Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim]).unwrap();
75 vec![Some(grad_tensor)]
76 }
77
78 fn name(&self) -> &'static str {
79 "EmbeddingBackward"
80 }
81
82 fn next_functions(&self) -> &[Option<GradFn>] {
83 &self.next_fns
84 }
85
86 fn as_any(&self) -> &dyn Any {
87 self
88 }
89}
90
/// A lookup table mapping integer indices to dense embedding vectors.
///
/// The table is a single learnable `[num_embeddings, embedding_dim]`
/// weight parameter; `lookup`/`forward` select rows by index.
pub struct Embedding {
    /// The embedding matrix, shape `[num_embeddings, embedding_dim]`.
    pub weight: Parameter,
    // Number of rows (vocabulary size) in the table.
    num_embeddings: usize,
    // Width of each embedding vector.
    embedding_dim: usize,
    // Optional index whose row is zero-initialized at construction time.
    // NOTE(review): unlike PyTorch's padding_idx, the backward pass does not
    // mask the gradient for this row (EmbeddingBackward never sees it).
    padding_idx: Option<usize>,
}
113
114impl Embedding {
115 pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
117 Self::with_options(num_embeddings, embedding_dim, None)
118 }
119
120 pub fn with_options(
122 num_embeddings: usize,
123 embedding_dim: usize,
124 padding_idx: Option<usize>,
125 ) -> Self {
126 let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
128
129 if let Some(pad_idx) = padding_idx {
131 let mut data = weight_data.to_vec();
132 for i in 0..embedding_dim {
133 data[pad_idx * embedding_dim + i] = 0.0;
134 }
135 weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim]).unwrap();
136 }
137
138 Self {
139 weight: Parameter::named("weight", weight_data, true),
140 num_embeddings,
141 embedding_dim,
142 padding_idx,
143 }
144 }
145
146 pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
148 let shape = weights.shape();
149 let num_embeddings = shape[0];
150 let embedding_dim = shape[1];
151
152 Self {
153 weight: Parameter::named("weight", weights, !freeze),
154 num_embeddings,
155 embedding_dim,
156 padding_idx: None,
157 }
158 }
159
160 pub fn num_embeddings(&self) -> usize {
162 self.num_embeddings
163 }
164
165 pub fn embedding_dim(&self) -> usize {
167 self.embedding_dim
168 }
169
170 pub fn lookup(&self, indices: &Variable) -> Variable {
178 let indices_data = indices.data();
179 let indices_vec = indices_data.to_vec();
181 let indices_shape = indices_data.shape().to_vec();
182
183 let mut output_shape = indices_shape.clone();
185 output_shape.push(self.embedding_dim);
186 let output_size: usize = output_shape.iter().product();
187
188 let mut safe_indices = Vec::with_capacity(indices_vec.len());
190 let mut gather_idx = Vec::with_capacity(output_size);
192
193 for &idx_f in &indices_vec {
194 let idx = idx_f as usize;
195 let safe_idx = if idx >= self.num_embeddings {
196 #[cfg(debug_assertions)]
197 eprintln!(
198 "Warning: embedding index {} out of range (max {}), using padding index 0",
199 idx,
200 self.num_embeddings - 1
201 );
202 0
203 } else {
204 idx
205 };
206 safe_indices.push(safe_idx);
207 let base = safe_idx * self.embedding_dim;
209 for d in 0..self.embedding_dim {
210 gather_idx.push((base + d) as u32);
211 }
212 }
213
214 let weight_data = self.weight.data();
215 #[allow(unused_variables)]
216 let weight_device = weight_data.device();
217
218 #[cfg(feature = "cuda")]
220 let output_tensor = if weight_device.is_gpu() {
221 weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
222 } else {
223 let weight_vec = weight_data.to_vec();
224 let output_data: Vec<f32> =
225 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
226 Tensor::from_vec(output_data, &output_shape).unwrap()
227 };
228
229 #[cfg(not(feature = "cuda"))]
230 let output_tensor = {
231 let weight_vec = weight_data.to_vec();
232 let output_data: Vec<f32> =
233 gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
234 Tensor::from_vec(output_data, &output_shape).unwrap()
235 };
236
237 if self.weight.requires_grad() {
238 let grad_fn = GradFn::new(EmbeddingBackward {
239 next_fns: vec![self.weight.variable().grad_fn().cloned()],
240 indices: safe_indices,
241 num_embeddings: self.num_embeddings,
242 embedding_dim: self.embedding_dim,
243 });
244 Variable::from_operation(output_tensor, grad_fn, true)
245 } else {
246 Variable::new(output_tensor, false)
247 }
248 }
249}
250
251impl Module for Embedding {
252 fn forward(&self, input: &Variable) -> Variable {
253 self.lookup(input)
254 }
255
256 fn parameters(&self) -> Vec<Parameter> {
257 vec![self.weight.clone()]
258 }
259
260 fn named_parameters(&self) -> HashMap<String, Parameter> {
261 let mut params = HashMap::new();
262 params.insert("weight".to_string(), self.weight.clone());
263 params
264 }
265
266 fn name(&self) -> &'static str {
267 "Embedding"
268 }
269}
270
// Manual Debug impl: the weight field is deliberately omitted — presumably to
// avoid dumping the full embedding table (and `Parameter` may not implement
// `Debug`; confirm before switching to `#[derive(Debug)]`).
impl std::fmt::Debug for Embedding {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Embedding")
            .field("num_embeddings", &self.num_embeddings)
            .field("embedding_dim", &self.embedding_dim)
            .field("padding_idx", &self.padding_idx)
            .finish()
    }
}
280
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_creation() {
        // Dimensions reported by the getters must match construction.
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        // A flat index vector of length 3 yields a [3, dim] output.
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).unwrap();
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    #[test]
    fn test_embedding_batch() {
        // A [2, 3] index batch yields a [2, 3, dim] output.
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]).unwrap();
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        // Exactly one parameter, sized rows * dim.
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    #[test]
    fn test_embedding_with_padding() {
        // Row 0 is zero-initialized, so looking up index 0 must yield zeros.
        let layer = Embedding::with_options(10, 4, Some(0));
        let ids = Tensor::from_vec(vec![0.0], &[1]).unwrap();
        let out = layer.forward(&Variable::new(ids, false));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}