//! Embedding Layer - Lookup Table for Indices
//!
//! # File
//! `crates/axonml-nn/src/layers/embedding.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::any::Any;
use std::collections::HashMap;

use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;

use crate::init::normal;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// EmbeddingBackward
// =============================================================================

/// Gradient function for Embedding lookup.
///
/// Scatters the upstream gradient back into a sparse gradient of shape
/// `[num_embeddings, embedding_dim]` using the indices from the forward pass.
#[derive(Debug)]
struct EmbeddingBackward {
    /// Gradient functions of this node's inputs (here: only the weight's,
    /// captured from the weight variable during `lookup`).
    next_fns: Vec<Option<GradFn>>,
    /// The indices used during the forward lookup (as usize).
    indices: Vec<usize>,
    /// Number of rows in the embedding table (gradient's first dimension).
    num_embeddings: usize,
    /// Width of each embedding row (gradient's second dimension).
    embedding_dim: usize,
}

44impl GradientFunction for EmbeddingBackward {
45    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
46        // GPU path: use CUDA scatter-add kernel
47        #[cfg(feature = "cuda")]
48        if grad_output.device().is_gpu() {
49            let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
50            let grad_tensor = grad_output.embedding_scatter_add_cuda(
51                &indices_u32,
52                self.num_embeddings,
53                self.embedding_dim,
54            );
55            return vec![Some(grad_tensor)];
56        }
57
58        // CPU fallback
59        let grad_data = grad_output.to_vec();
60        let mut weight_grad = vec![0.0f32; self.num_embeddings * self.embedding_dim];
61
62        // Scatter-add: accumulate gradients for each index
63        for (i, &idx) in self.indices.iter().enumerate() {
64            if idx < self.num_embeddings {
65                let src_offset = i * self.embedding_dim;
66                let dst_offset = idx * self.embedding_dim;
67                for d in 0..self.embedding_dim {
68                    weight_grad[dst_offset + d] += grad_data[src_offset + d];
69                }
70            }
71        }
72
73        let grad_tensor = Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim])
74            .expect("tensor creation failed");
75        vec![Some(grad_tensor)]
76    }
77
78    fn name(&self) -> &'static str {
79        "EmbeddingBackward"
80    }
81
82    fn next_functions(&self) -> &[Option<GradFn>] {
83        &self.next_fns
84    }
85
86    fn as_any(&self) -> &dyn Any {
87        self
88    }
89}
90
// =============================================================================
// Embedding
// =============================================================================

/// A simple lookup table that stores embeddings of a fixed dictionary.
///
/// This module is often used to store word embeddings and retrieve them
/// using indices.
///
/// # Shape
/// - Input: (*) - LongTensor of arbitrary shape containing indices
/// - Output: (*, H) - where H = embedding_dim
pub struct Embedding {
    /// Embedding weights of shape (num_embeddings, embedding_dim).
    pub weight: Parameter,
    /// Number of embeddings in the dictionary.
    num_embeddings: usize,
    /// Dimension of each embedding vector.
    embedding_dim: usize,
    /// Index of padding token (if any); its row is zero-initialized at
    /// construction time.
    padding_idx: Option<usize>,
}

114impl Embedding {
115    /// Creates a new Embedding layer.
116    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
117        Self::with_options(num_embeddings, embedding_dim, None)
118    }
119
120    /// Creates an Embedding with padding index.
121    pub fn with_options(
122        num_embeddings: usize,
123        embedding_dim: usize,
124        padding_idx: Option<usize>,
125    ) -> Self {
126        // Initialize weights from N(0, 1)
127        let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
128
129        // Set padding index to zeros if specified
130        if let Some(pad_idx) = padding_idx {
131            let mut data = weight_data.to_vec();
132            for i in 0..embedding_dim {
133                data[pad_idx * embedding_dim + i] = 0.0;
134            }
135            weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim])
136                .expect("tensor creation failed");
137        }
138
139        Self {
140            weight: Parameter::named("weight", weight_data, true),
141            num_embeddings,
142            embedding_dim,
143            padding_idx,
144        }
145    }
146
147    /// Creates an Embedding from pretrained weights.
148    pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
149        let shape = weights.shape();
150        let num_embeddings = shape[0];
151        let embedding_dim = shape[1];
152
153        Self {
154            weight: Parameter::named("weight", weights, !freeze),
155            num_embeddings,
156            embedding_dim,
157            padding_idx: None,
158        }
159    }
160
161    /// Returns the number of embeddings.
162    pub fn num_embeddings(&self) -> usize {
163        self.num_embeddings
164    }
165
166    /// Returns the embedding dimension.
167    pub fn embedding_dim(&self) -> usize {
168        self.embedding_dim
169    }
170
171    /// Looks up embeddings for the given indices.
172    ///
173    /// # Arguments
174    /// * `indices` - Variable containing integer indices
175    ///
176    /// Note: In a full implementation, indices would be LongTensor.
177    /// Here we use f32 and cast to usize.
178    pub fn lookup(&self, indices: &Variable) -> Variable {
179        let indices_data = indices.data();
180        // Copy indices to CPU (small: batch_size * seq_len values)
181        let indices_vec = indices_data.to_vec();
182        let indices_shape = indices_data.shape().to_vec();
183
184        // Output shape: indices_shape + [embedding_dim]
185        let mut output_shape = indices_shape.clone();
186        output_shape.push(self.embedding_dim);
187        let output_size: usize = output_shape.iter().product();
188
189        // Compute gather indices and validate on CPU (indices are small)
190        let mut safe_indices = Vec::with_capacity(indices_vec.len());
191        // Build flat gather index: for each token index, we need embedding_dim consecutive elements
192        let mut gather_idx = Vec::with_capacity(output_size);
193
194        for &idx_f in &indices_vec {
195            let idx = idx_f as usize;
196            let safe_idx = if idx >= self.num_embeddings {
197                #[cfg(debug_assertions)]
198                eprintln!(
199                    "Warning: embedding index {} out of range (max {}), using padding index 0",
200                    idx,
201                    self.num_embeddings - 1
202                );
203                0
204            } else {
205                idx
206            };
207            safe_indices.push(safe_idx);
208            // Each token maps to embedding_dim elements starting at safe_idx * embedding_dim
209            let base = safe_idx * self.embedding_dim;
210            for d in 0..self.embedding_dim {
211                gather_idx.push((base + d) as u32);
212            }
213        }
214
215        let weight_data = self.weight.data();
216        #[cfg(feature = "cuda")]
217        let weight_device = weight_data.device();
218
219        // GPU path: use gather kernel to avoid copying entire weight matrix
220        #[cfg(feature = "cuda")]
221        let output_tensor = if weight_device.is_gpu() {
222            weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
223        } else {
224            let weight_vec = weight_data.to_vec();
225            let output_data: Vec<f32> =
226                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
227            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
228        };
229
230        #[cfg(not(feature = "cuda"))]
231        let output_tensor = {
232            let weight_vec = weight_data.to_vec();
233            let output_data: Vec<f32> =
234                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
235            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
236        };
237
238        if self.weight.requires_grad() {
239            let grad_fn = GradFn::new(EmbeddingBackward {
240                next_fns: vec![self.weight.variable().grad_fn().cloned()],
241                indices: safe_indices,
242                num_embeddings: self.num_embeddings,
243                embedding_dim: self.embedding_dim,
244            });
245            Variable::from_operation(output_tensor, grad_fn, true)
246        } else {
247            Variable::new(output_tensor, false)
248        }
249    }
250}
251
252impl Module for Embedding {
253    fn forward(&self, input: &Variable) -> Variable {
254        self.lookup(input)
255    }
256
257    fn parameters(&self) -> Vec<Parameter> {
258        vec![self.weight.clone()]
259    }
260
261    fn named_parameters(&self) -> HashMap<String, Parameter> {
262        let mut params = HashMap::new();
263        params.insert("weight".to_string(), self.weight.clone());
264        params
265    }
266
267    fn name(&self) -> &'static str {
268        "Embedding"
269    }
270}
271
272impl std::fmt::Debug for Embedding {
273    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274        f.debug_struct("Embedding")
275            .field("num_embeddings", &self.num_embeddings)
276            .field("embedding_dim", &self.embedding_dim)
277            .field("padding_idx", &self.padding_idx)
278            .finish()
279    }
280}
281
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensions passed to the constructor are reported back unchanged.
    #[test]
    fn test_embedding_creation() {
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    /// A flat index tensor of length 3 produces a [3, 4] output.
    #[test]
    fn test_embedding_lookup() {
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    /// Batched [2, 3] indices produce a [2, 3, 4] output.
    #[test]
    fn test_embedding_batch() {
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3])
            .expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    /// The layer exposes exactly one parameter covering every weight.
    #[test]
    fn test_embedding_parameters() {
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    /// Looking up the padding index yields an all-zero vector.
    #[test]
    fn test_embedding_with_padding() {
        let layer = Embedding::with_options(10, 4, Some(0));
        let ids = Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed");
        let out = layer.forward(&Variable::new(ids, false));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}