Skip to main content

mlx_nn/
embedding.rs

1//! Embedding layer — looks up rows from a weight matrix by index.
2
3use mlx_core::{Result, Tensor};
4
5use crate::Module;
6
7/// Embedding layer: maps integer indices to dense vectors.
8///
9/// Weight has shape `[vocab_size, embed_dim]`. Input indices are 1D `[seq_len]`
10/// (stored as f32, cast to usize internally). Output is `[seq_len, embed_dim]`.
11pub struct Embedding {
12    weight: Tensor,
13}
14
15impl Embedding {
16    /// Create a new Embedding from a pre-existing weight tensor `[vocab_size, embed_dim]`.
17    pub fn new(weight: Tensor) -> Self {
18        Self { weight }
19    }
20
21    /// Get a reference to the weight tensor.
22    pub fn weight(&self) -> &Tensor {
23        &self.weight
24    }
25}
26
27impl Module for Embedding {
28    fn forward(&self, input: &Tensor) -> Result<Tensor> {
29        self.weight.embedding_lookup(input)
30    }
31}