mlx_nn/embedding.rs
1//! Embedding layer — looks up rows from a weight matrix by index.
2
3use mlx_core::{Result, Tensor};
4
5use crate::Module;
6
7/// Embedding layer: maps integer indices to dense vectors.
8///
9/// Weight has shape `[vocab_size, embed_dim]`. Input indices are 1D `[seq_len]`
10/// (stored as f32, cast to usize internally). Output is `[seq_len, embed_dim]`.
11pub struct Embedding {
12 weight: Tensor,
13}
14
15impl Embedding {
16 /// Create a new Embedding from a pre-existing weight tensor `[vocab_size, embed_dim]`.
17 pub fn new(weight: Tensor) -> Self {
18 Self { weight }
19 }
20
21 /// Get a reference to the weight tensor.
22 pub fn weight(&self) -> &Tensor {
23 &self.weight
24 }
25}
26
27impl Module for Embedding {
28 fn forward(&self, input: &Tensor) -> Result<Tensor> {
29 self.weight.embedding_lookup(input)
30 }
31}