Skip to main content

axonml_nn/layers/
embedding.rs

//! `Embedding` — trainable lookup table mapping integer indices to dense vectors.
//!
//! `Embedding::new(num_embeddings, embedding_dim)` creates a weight initialized
//! from N(0, 1). `forward` takes a `Variable` of token IDs (stored as `f32`) and
//! returns a `Variable` whose shape is the input shape with `embedding_dim`
//! appended, e.g. `[batch, seq_len]` becomes `[batch, seq_len, embedding_dim]`.
//! An optional `padding_idx` zero-initializes that row of the weight. `lookup`
//! dispatches between the CPU gather and `embedding_gather_cuda`.
//!
//! # File
//! `crates/axonml-nn/src/layers/embedding.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
22
23use std::any::Any;
24use std::collections::HashMap;
25
26use axonml_autograd::{GradFn, GradientFunction, Variable};
27use axonml_tensor::Tensor;
28
29use crate::init::normal;
30use crate::module::Module;
31use crate::parameter::Parameter;
32
33// =============================================================================
34// EmbeddingBackward
35// =============================================================================
36
37/// Gradient function for Embedding lookup.
38///
39/// Scatters the upstream gradient back into a sparse gradient of shape
40/// `[num_embeddings, embedding_dim]` using the indices from the forward pass.
41#[derive(Debug)]
42struct EmbeddingBackward {
43    next_fns: Vec<Option<GradFn>>,
44    /// The indices used during the forward lookup (as usize).
45    indices: Vec<usize>,
46    num_embeddings: usize,
47    embedding_dim: usize,
48}
49
50impl GradientFunction for EmbeddingBackward {
51    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
52        // GPU path: use CUDA scatter-add kernel
53        #[cfg(feature = "cuda")]
54        if grad_output.device().is_gpu() {
55            let indices_u32: Vec<u32> = self.indices.iter().map(|&i| i as u32).collect();
56            let grad_tensor = grad_output.embedding_scatter_add_cuda(
57                &indices_u32,
58                self.num_embeddings,
59                self.embedding_dim,
60            );
61            return vec![Some(grad_tensor)];
62        }
63
64        // CPU fallback
65        let grad_data = grad_output.to_vec();
66        let mut weight_grad = vec![0.0f32; self.num_embeddings * self.embedding_dim];
67
68        // Scatter-add: accumulate gradients for each index
69        for (i, &idx) in self.indices.iter().enumerate() {
70            if idx < self.num_embeddings {
71                let src_offset = i * self.embedding_dim;
72                let dst_offset = idx * self.embedding_dim;
73                for d in 0..self.embedding_dim {
74                    weight_grad[dst_offset + d] += grad_data[src_offset + d];
75                }
76            }
77        }
78
79        let grad_tensor = Tensor::from_vec(weight_grad, &[self.num_embeddings, self.embedding_dim])
80            .expect("tensor creation failed");
81        vec![Some(grad_tensor)]
82    }
83
84    fn name(&self) -> &'static str {
85        "EmbeddingBackward"
86    }
87
88    fn next_functions(&self) -> &[Option<GradFn>] {
89        &self.next_fns
90    }
91
92    fn as_any(&self) -> &dyn Any {
93        self
94    }
95}
96
97// =============================================================================
98// Embedding
99// =============================================================================
100
101/// A simple lookup table that stores embeddings of a fixed dictionary.
102///
103/// This module is often used to store word embeddings and retrieve them
104/// using indices.
105///
106/// # Shape
107/// - Input: (*) - LongTensor of arbitrary shape containing indices
108/// - Output: (*, H) - where H = embedding_dim
109pub struct Embedding {
110    /// Embedding weights of shape (num_embeddings, embedding_dim).
111    pub weight: Parameter,
112    /// Number of embeddings in the dictionary.
113    num_embeddings: usize,
114    /// Dimension of each embedding vector.
115    embedding_dim: usize,
116    /// Index of padding token (if any).
117    padding_idx: Option<usize>,
118}
119
120impl Embedding {
121    /// Creates a new Embedding layer.
122    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
123        Self::with_options(num_embeddings, embedding_dim, None)
124    }
125
126    /// Creates an Embedding with padding index.
127    pub fn with_options(
128        num_embeddings: usize,
129        embedding_dim: usize,
130        padding_idx: Option<usize>,
131    ) -> Self {
132        // Initialize weights from N(0, 1)
133        let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
134
135        // Set padding index to zeros if specified
136        if let Some(pad_idx) = padding_idx {
137            let mut data = weight_data.to_vec();
138            for i in 0..embedding_dim {
139                data[pad_idx * embedding_dim + i] = 0.0;
140            }
141            weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim])
142                .expect("tensor creation failed");
143        }
144
145        Self {
146            weight: Parameter::named("weight", weight_data, true),
147            num_embeddings,
148            embedding_dim,
149            padding_idx,
150        }
151    }
152
153    /// Creates an Embedding from pretrained weights.
154    pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
155        let shape = weights.shape();
156        let num_embeddings = shape[0];
157        let embedding_dim = shape[1];
158
159        Self {
160            weight: Parameter::named("weight", weights, !freeze),
161            num_embeddings,
162            embedding_dim,
163            padding_idx: None,
164        }
165    }
166
167    /// Returns the number of embeddings.
168    pub fn num_embeddings(&self) -> usize {
169        self.num_embeddings
170    }
171
172    /// Returns the embedding dimension.
173    pub fn embedding_dim(&self) -> usize {
174        self.embedding_dim
175    }
176
177    /// Looks up embeddings for the given indices.
178    ///
179    /// # Arguments
180    /// * `indices` - Variable containing integer indices
181    ///
182    /// Note: In a full implementation, indices would be LongTensor.
183    /// Here we use f32 and cast to usize.
184    pub fn lookup(&self, indices: &Variable) -> Variable {
185        let indices_data = indices.data();
186        // Copy indices to CPU (small: batch_size * seq_len values)
187        let indices_vec = indices_data.to_vec();
188        let indices_shape = indices_data.shape().to_vec();
189
190        // Output shape: indices_shape + [embedding_dim]
191        let mut output_shape = indices_shape.clone();
192        output_shape.push(self.embedding_dim);
193        let output_size: usize = output_shape.iter().product();
194
195        // Compute gather indices and validate on CPU (indices are small)
196        let mut safe_indices = Vec::with_capacity(indices_vec.len());
197        // Build flat gather index: for each token index, we need embedding_dim consecutive elements
198        let mut gather_idx = Vec::with_capacity(output_size);
199
200        for &idx_f in &indices_vec {
201            let idx = idx_f as usize;
202            let safe_idx = if idx >= self.num_embeddings {
203                #[cfg(debug_assertions)]
204                eprintln!(
205                    "Warning: embedding index {} out of range (max {}), using padding index 0",
206                    idx,
207                    self.num_embeddings - 1
208                );
209                0
210            } else {
211                idx
212            };
213            safe_indices.push(safe_idx);
214            // Each token maps to embedding_dim elements starting at safe_idx * embedding_dim
215            let base = safe_idx * self.embedding_dim;
216            for d in 0..self.embedding_dim {
217                gather_idx.push((base + d) as u32);
218            }
219        }
220
221        let weight_data = self.weight.data();
222        #[cfg(feature = "cuda")]
223        let weight_device = weight_data.device();
224
225        // GPU path: use gather kernel to avoid copying entire weight matrix
226        #[cfg(feature = "cuda")]
227        let output_tensor = if weight_device.is_gpu() {
228            weight_data.embedding_gather_cuda(&gather_idx, &output_shape)
229        } else {
230            let weight_vec = weight_data.to_vec();
231            let output_data: Vec<f32> =
232                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
233            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
234        };
235
236        #[cfg(not(feature = "cuda"))]
237        let output_tensor = {
238            let weight_vec = weight_data.to_vec();
239            let output_data: Vec<f32> =
240                gather_idx.iter().map(|&i| weight_vec[i as usize]).collect();
241            Tensor::from_vec(output_data, &output_shape).expect("tensor creation failed")
242        };
243
244        if self.weight.requires_grad() {
245            let grad_fn = GradFn::new(EmbeddingBackward {
246                next_fns: vec![self.weight.variable().grad_fn().cloned()],
247                indices: safe_indices,
248                num_embeddings: self.num_embeddings,
249                embedding_dim: self.embedding_dim,
250            });
251            Variable::from_operation(output_tensor, grad_fn, true)
252        } else {
253            Variable::new(output_tensor, false)
254        }
255    }
256}
257
258impl Module for Embedding {
259    fn forward(&self, input: &Variable) -> Variable {
260        self.lookup(input)
261    }
262
263    fn parameters(&self) -> Vec<Parameter> {
264        vec![self.weight.clone()]
265    }
266
267    fn named_parameters(&self) -> HashMap<String, Parameter> {
268        let mut params = HashMap::new();
269        params.insert("weight".to_string(), self.weight.clone());
270        params
271    }
272
273    fn name(&self) -> &'static str {
274        "Embedding"
275    }
276}
277
278impl std::fmt::Debug for Embedding {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        f.debug_struct("Embedding")
281            .field("num_embeddings", &self.num_embeddings)
282            .field("embedding_dim", &self.embedding_dim)
283            .field("padding_idx", &self.padding_idx)
284            .finish()
285    }
286}
287
288// =============================================================================
289// Tests
290// =============================================================================
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction records the requested table dimensions.
    #[test]
    fn test_embedding_creation() {
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    /// A 1-D index tensor yields one embedding row per index.
    #[test]
    fn test_embedding_lookup() {
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0], &[3]).expect("tensor creation failed");
        let result = layer.forward(&Variable::new(ids, false));
        assert_eq!(result.shape(), vec![3, 4]);
    }

    /// A 2-D index tensor keeps its leading dimensions in the output shape.
    #[test]
    fn test_embedding_batch() {
        let layer = Embedding::new(10, 4);
        let ids = Tensor::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3])
            .expect("tensor creation failed");
        let result = layer.forward(&Variable::new(ids, false));
        assert_eq!(result.shape(), vec![2, 3, 4]);
    }

    /// The layer exposes exactly one parameter covering the whole table.
    #[test]
    fn test_embedding_parameters() {
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    /// Looking up the padding index returns an all-zero row.
    #[test]
    fn test_embedding_with_padding() {
        let layer = Embedding::with_options(10, 4, Some(0));
        let ids = Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed");
        let result = layer.forward(&Variable::new(ids, false));
        let values = result.data().to_vec();
        assert!(values.iter().all(|&v| v == 0.0));
    }
}
345}