// axonml_nn/layers/embedding.rs

1//! Embedding Layer - Lookup Table for Indices
2//!
3//! Maps discrete indices to dense vectors.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use std::collections::HashMap;
9
10use axonml_autograd::Variable;
11use axonml_tensor::Tensor;
12
13use crate::init::normal;
14use crate::module::Module;
15use crate::parameter::Parameter;
16
17// =============================================================================
18// Embedding
19// =============================================================================
20
/// A simple lookup table that stores embeddings of a fixed dictionary.
///
/// This module is often used to store word embeddings and retrieve them
/// using indices. If a padding index is configured, that row of the
/// weight matrix is zero-initialized (see `with_options`).
///
/// # Shape
/// - Input: (*) - LongTensor of arbitrary shape containing indices
/// - Output: (*, H) - where H = embedding_dim
///
/// Note: in this implementation indices arrive as f32 values and are
/// cast to usize during lookup (see `lookup`).
pub struct Embedding {
    /// Embedding weights of shape (num_embeddings, embedding_dim).
    pub weight: Parameter,
    /// Number of embeddings in the dictionary (vocabulary size).
    num_embeddings: usize,
    /// Dimension of each embedding vector.
    embedding_dim: usize,
    /// Index of padding token (if any); its row is zeroed at construction.
    padding_idx: Option<usize>,
}
39
40impl Embedding {
    /// Creates a new Embedding layer.
    ///
    /// Weights are initialized from N(0, 1) and no padding index is set;
    /// equivalent to `with_options(num_embeddings, embedding_dim, None)`.
    pub fn new(num_embeddings: usize, embedding_dim: usize) -> Self {
        Self::with_options(num_embeddings, embedding_dim, None)
    }
45
46    /// Creates an Embedding with padding index.
47    pub fn with_options(
48        num_embeddings: usize,
49        embedding_dim: usize,
50        padding_idx: Option<usize>,
51    ) -> Self {
52        // Initialize weights from N(0, 1)
53        let mut weight_data = normal(&[num_embeddings, embedding_dim], 0.0, 1.0);
54
55        // Set padding index to zeros if specified
56        if let Some(pad_idx) = padding_idx {
57            let mut data = weight_data.to_vec();
58            for i in 0..embedding_dim {
59                data[pad_idx * embedding_dim + i] = 0.0;
60            }
61            weight_data = Tensor::from_vec(data, &[num_embeddings, embedding_dim]).unwrap();
62        }
63
64        Self {
65            weight: Parameter::named("weight", weight_data, true),
66            num_embeddings,
67            embedding_dim,
68            padding_idx,
69        }
70    }
71
72    /// Creates an Embedding from pretrained weights.
73    pub fn from_pretrained(weights: Tensor<f32>, freeze: bool) -> Self {
74        let shape = weights.shape();
75        let num_embeddings = shape[0];
76        let embedding_dim = shape[1];
77
78        Self {
79            weight: Parameter::named("weight", weights, !freeze),
80            num_embeddings,
81            embedding_dim,
82            padding_idx: None,
83        }
84    }
85
    /// Returns the number of embeddings in the dictionary (vocabulary size).
    pub fn num_embeddings(&self) -> usize {
        self.num_embeddings
    }

    /// Returns the dimension of each embedding vector.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
95
96    /// Looks up embeddings for the given indices.
97    ///
98    /// # Arguments
99    /// * `indices` - Variable containing integer indices
100    ///
101    /// Note: In a full implementation, indices would be LongTensor.
102    /// Here we use f32 and cast to usize.
103    pub fn lookup(&self, indices: &Variable) -> Variable {
104        let indices_data = indices.data();
105        let indices_vec = indices_data.to_vec();
106        let indices_shape = indices_data.shape().to_vec();
107
108        let weight_vec = self.weight.data().to_vec();
109
110        // Output shape: indices_shape + [embedding_dim]
111        let mut output_shape = indices_shape.clone();
112        output_shape.push(self.embedding_dim);
113        let output_size: usize = output_shape.iter().product();
114
115        let mut output_data = vec![0.0f32; output_size];
116
117        for (i, &idx_f) in indices_vec.iter().enumerate() {
118            let idx = idx_f as usize;
119            // Clamp out-of-bounds indices to the padding index (0)
120            // This prevents panics while still producing valid output
121            let safe_idx = if idx >= self.num_embeddings {
122                #[cfg(debug_assertions)]
123                eprintln!(
124                    "Warning: embedding index {} out of range (max {}), using padding index 0",
125                    idx,
126                    self.num_embeddings - 1
127                );
128                0
129            } else {
130                idx
131            };
132
133            for d in 0..self.embedding_dim {
134                output_data[i * self.embedding_dim + d] =
135                    weight_vec[safe_idx * self.embedding_dim + d];
136            }
137        }
138
139        Variable::new(
140            Tensor::from_vec(output_data, &output_shape).unwrap(),
141            self.weight.requires_grad(),
142        )
143    }
144}
145
146impl Module for Embedding {
147    fn forward(&self, input: &Variable) -> Variable {
148        self.lookup(input)
149    }
150
151    fn parameters(&self) -> Vec<Parameter> {
152        vec![self.weight.clone()]
153    }
154
155    fn named_parameters(&self) -> HashMap<String, Parameter> {
156        let mut params = HashMap::new();
157        params.insert("weight".to_string(), self.weight.clone());
158        params
159    }
160
161    fn name(&self) -> &'static str {
162        "Embedding"
163    }
164}
165
166impl std::fmt::Debug for Embedding {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        f.debug_struct("Embedding")
169            .field("num_embeddings", &self.num_embeddings)
170            .field("embedding_dim", &self.embedding_dim)
171            .field("padding_idx", &self.padding_idx)
172            .finish()
173    }
174}
175
176// =============================================================================
177// Tests
178// =============================================================================
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Helper: wrap raw f32 index data in a non-grad Variable.
    fn index_var(data: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(Tensor::from_vec(data, shape).unwrap(), false)
    }

    #[test]
    fn test_embedding_creation() {
        let layer = Embedding::new(1000, 128);
        assert_eq!(layer.num_embeddings(), 1000);
        assert_eq!(layer.embedding_dim(), 128);
    }

    #[test]
    fn test_embedding_lookup() {
        let layer = Embedding::new(10, 4);
        let out = layer.forward(&index_var(vec![0.0, 1.0, 2.0], &[3]));
        assert_eq!(out.shape(), vec![3, 4]);
    }

    #[test]
    fn test_embedding_batch() {
        let layer = Embedding::new(10, 4);
        let ids = index_var(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[2, 3]);
        let out = layer.forward(&ids);
        assert_eq!(out.shape(), vec![2, 3, 4]);
    }

    #[test]
    fn test_embedding_parameters() {
        let layer = Embedding::new(100, 64);
        assert_eq!(layer.parameters().len(), 1);
        assert_eq!(layer.num_parameters(), 100 * 64);
    }

    #[test]
    fn test_embedding_with_padding() {
        // The padding row must have been zeroed at construction.
        let layer = Embedding::with_options(10, 4, Some(0));
        let out = layer.forward(&index_var(vec![0.0], &[1]));
        assert!(out.data().to_vec().iter().all(|&x| x == 0.0));
    }
}
226}