use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// Embedding layer: a lookup table that maps integer-valued indices to rows
/// of a `[vocab_size, embed_size]` weight matrix.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct Embedding {
/// Number of rows in the embedding table (one row per index value).
pub vocab_size: u64,
/// Length of each embedding vector (number of columns in the table).
pub embed_size: u64,
/// The embedding table. `Embedding::new` stores it reshaped to
/// `[1, 1, vocab_size, embed_size]`; `Embedding::from_params` stores the
/// tensor it was given without reshaping it.
pub weight: Tensor,
/// Index column produced by `Tensor::arange(0, vocab_size, 1)`, reshaped to
/// `[1, 1, vocab_size, 1]` and cast to the dtype of `weight`. `forward`
/// compares it against the input indices to select rows of `weight`.
pub arange: Tensor,
}
impl Embedding {
    /// Creates an embedding layer with a Glorot-uniform initialised table.
    ///
    /// * `vocab_size` - number of distinct indices (rows of the table).
    /// * `embed_size` - length of each embedding vector.
    /// * `dtype` - dtype of the table and of the cached index column.
    ///
    /// The weight is stored with shape `[1, 1, vocab_size, embed_size]`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while creating or reshaping the tensors.
    pub fn new(vocab_size: u64, embed_size: u64, dtype: DType) -> Result<Embedding, ZyxError> {
        let weight = Tensor::glorot_uniform([vocab_size, embed_size], dtype)?
            .reshape([1, 1, vocab_size, embed_size])?;
        let arange = Self::index_column(vocab_size, dtype)?;
        Ok(Embedding {
            vocab_size,
            embed_size,
            weight,
            arange,
        })
    }

    /// Builds an embedding layer around an existing weight tensor.
    ///
    /// `weight` must have shape `[vocab_size, embed_size]`, optionally preceded
    /// by axes of size 1. This means the `[1, 1, vocab_size, embed_size]` layout
    /// that [`Embedding::new`] stores is accepted as well. The tensor is stored
    /// as given; the sizes are read from its last two axes.
    ///
    /// # Panics
    ///
    /// Panics if `weight` has fewer than two axes or if any axis before the
    /// last two is not of size 1.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while building the index column.
    pub fn from_params(weight: Tensor) -> Result<Embedding, ZyxError> {
        let sh = weight.shape();
        assert!(
            sh.len() >= 2 && sh[..sh.len() - 2].iter().all(|&d| d == 1),
            "Embedding::from_params expects weight of shape [vocab_size, embed_size], \
             optionally with leading axes of size 1, but got {sh:?}"
        );
        let vocab_size = sh[sh.len() - 2];
        let embed_size = sh[sh.len() - 1];
        // Built before `weight` is moved into the struct, because it needs the dtype.
        let arange = Self::index_column(vocab_size, weight.dtype())?;
        Ok(Embedding {
            vocab_size,
            embed_size,
            weight,
            arange,
        })
    }

    /// Looks up the embedding vector for every index in `x`.
    ///
    /// `x` may have any shape; the result has the shape of `x` with
    /// `embed_size` appended. `x` must have the same dtype as the weight.
    ///
    /// The lookup is expressed with tensor ops only. The index column is
    /// compared against the broadcast input to form a one-hot mask over the
    /// vocabulary axis, the mask is multiplied with the broadcast weight, and
    /// the vocabulary axis is summed away. An index that matches no entry of
    /// the index column therefore yields an all-zero vector.
    ///
    /// An input with zero elements short-circuits to a zero tensor of the
    /// output shape, using the dtype of `x`, before the dtype check runs.
    ///
    /// # Errors
    ///
    /// Returns `ZyxError::DTypeError` if the dtype of `x` differs from the
    /// dtype of the weight, and propagates any error from the tensor ops.
    pub fn forward(&self, x: impl Into<Tensor>) -> Result<Tensor, ZyxError> {
        let x: Tensor = x.into();
        let x_sh = x.shape();
        let xdt = x.dtype();

        if x.numel() == 0 {
            let out_shape: Vec<u64> = x_sh.iter().copied().chain([self.embed_size]).collect();
            return Ok(Tensor::zeros(out_shape, xdt));
        }

        let wdt = self.weight.dtype();
        if xdt != wdt {
            return Err(ZyxError::DTypeError(
                format!("Embedding::forward input x has dtype {xdt} but weight has dtype {wdt}")
                    .into(),
            ));
        }

        // Target shape: input shape followed by [vocab_size, embed_size].
        let big_shp: Vec<u64> = x_sh
            .iter()
            .copied()
            .chain([self.vocab_size, self.embed_size])
            .collect();
        // Input shape followed by two singleton axes, to broadcast over vocab and embed.
        let idx_shp: Vec<u64> = x_sh.iter().copied().chain([1, 1]).collect();

        // The stored tensors may carry leading singleton axes. Drop them so they
        // broadcast against inputs of any rank, not only rank 2.
        let arange = self
            .arange
            .reshape([self.vocab_size, 1])?
            .expand(big_shp.clone())?;
        let idx = x.reshape(idx_shp)?.expand(big_shp.clone())?;
        let vals = self
            .weight
            .reshape([self.vocab_size, self.embed_size])?
            .expand(big_shp)?;

        // The vocabulary axis is always second to last, whatever the rank of `x`.
        (arange.equal(idx)?.cast(xdt) * vals).sum([-2])
    }

    /// Builds the index column `0..vocab_size` with shape
    /// `[1, 1, vocab_size, 1]`, cast to `dtype`.
    fn index_column(vocab_size: u64, dtype: DType) -> Result<Tensor, ZyxError> {
        // `arange` is driven by i64 bounds here; vocabulary sizes are assumed to fit.
        Ok(Tensor::arange(0, vocab_size as i64, 1)?
            .reshape([1, 1, vocab_size, 1])?
            .cast(dtype))
    }
}