use crate::device::Device;
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use crate::traits::Layer;
use scirs2_core::ndarray::{Array2, Axis};
/// Lookup-table layer that maps `u32` ids to rows of a weight tensor.
///
/// The table is created by the constructors with shape
/// `[num_embeddings, embedding_dim]`; `forward` returns one row per input id.
#[derive(Debug, Clone)]
pub struct Embedding {
/// Embedding table. Created via `Tensor::randn(&[num_embeddings, embedding_dim])`
/// and replaceable through `set_weight`.
weight: Tensor,
/// Number of rows requested at construction; `forward` rejects ids that are
/// greater than or equal to this value.
num_embeddings: usize,
/// Number of columns requested at construction; width of every output row.
embedding_dim: usize,
/// Device tag for this layer. When it is `Device::Metal(_)` and the weight is a
/// Metal tensor, `forward` moves its result onto this device.
device: Device,
}
impl Embedding {
    /// Creates an embedding layer tagged with [`Device::CPU`].
    ///
    /// Equivalent to calling [`Embedding::new_with_device`] with `Device::CPU`.
    ///
    /// # Errors
    /// Propagates any error from [`Embedding::new_with_device`].
    pub fn new(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
    ) -> Result<Self> {
        Self::new_with_device(num_embeddings, embedding_dim, padding_idx, Device::CPU)
    }

    /// Creates an embedding layer with a `[num_embeddings, embedding_dim]` table
    /// initialised by `Tensor::randn` and tagged with `device`.
    ///
    /// If `padding_idx` is `Some(i)` with `i < num_embeddings`, the table is passed
    /// through `Tensor::zero_padding_embedding(i)` (presumably zeroing row `i` —
    /// confirm against the tensor implementation). An out-of-range `padding_idx`
    /// is ignored.
    ///
    /// Only the device tag is recorded here; the weight is not transferred. Use the
    /// feature-gated `weights_to_gpu*` methods to move it.
    ///
    /// # Errors
    /// Propagates errors from `Tensor::randn` and `Tensor::zero_padding_embedding`.
    pub fn new_with_device(
        num_embeddings: usize,
        embedding_dim: usize,
        padding_idx: Option<usize>,
        device: Device,
    ) -> Result<Self> {
        let initial = Tensor::randn(&[num_embeddings, embedding_dim])?;
        let weight = match padding_idx {
            Some(idx) if idx < num_embeddings => initial.zero_padding_embedding(idx)?,
            // No padding row requested, or the index does not address a table row.
            _ => initial,
        };

        Ok(Self {
            weight,
            num_embeddings,
            embedding_dim,
            device,
        })
    }

    /// Replaces the embedding table with `weight`.
    ///
    /// The tensor is stored as-is: its shape is not checked here and
    /// `num_embeddings` / `embedding_dim` are not updated.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    pub fn set_weight(&mut self, weight: Tensor) -> Result<()> {
        self.weight = weight;
        Ok(())
    }

    /// Runs the embedding lookup for a slice of ids.
    ///
    /// Convenience wrapper that copies `input_ids` into a `Vec<u32>` and delegates
    /// to [`Layer::forward`].
    ///
    /// # Errors
    /// Propagates any error from [`Layer::forward`].
    pub fn forward_ids(&self, input_ids: &[u32]) -> Result<Tensor> {
        self.forward(input_ids.to_vec())
    }

    /// Returns the device tag currently associated with this layer.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Consumes the layer and returns it with its device tag set to `device`.
    ///
    /// Only the tag changes; the weight tensor itself is not transferred.
    pub fn to_device(mut self, device: Device) -> Self {
        self.device = device;
        self
    }

    /// Returns `num_embeddings * embedding_dim`, the element count of the table
    /// as requested at construction.
    pub fn parameter_count(&self) -> usize {
        self.num_embeddings * self.embedding_dim
    }

    /// Moves the weight tensor to a Metal device and records that device.
    ///
    /// Does nothing and returns `Ok(())` unless `device` is `Device::Metal(_)`.
    ///
    /// # Errors
    /// Propagates the error from `Tensor::to_device_enum`. On failure the layer is
    /// left unchanged (neither the weight nor the device tag is modified).
    #[cfg(all(target_os = "macos", feature = "metal"))]
    pub fn weights_to_gpu(&mut self, device: &Device) -> Result<()> {
        if !matches!(device, Device::Metal(_)) {
            return Ok(());
        }
        // Transfer first so that a failed transfer cannot leave the device tag
        // pointing at a device the weight never reached.
        let moved = self.weight.to_device_enum(device)?;
        self.weight = moved;
        self.device = *device;
        Ok(())
    }

    /// Moves the weight tensor to a CUDA device and records that device.
    ///
    /// Does nothing and returns `Ok(())` unless `device` is `Device::CUDA(_)`.
    ///
    /// # Errors
    /// Propagates the error from `Tensor::to_device_enum`. On failure the layer is
    /// left unchanged (neither the weight nor the device tag is modified).
    #[cfg(all(feature = "cuda", any(target_os = "linux", target_os = "windows")))]
    pub fn weights_to_gpu_cuda(&mut self, device: &Device) -> Result<()> {
        if !matches!(device, Device::CUDA(_)) {
            return Ok(());
        }
        // Transfer first so that a failed transfer cannot leave the device tag
        // pointing at a device the weight never reached.
        let moved = self.weight.to_device_enum(device)?;
        self.weight = moved;
        self.device = *device;
        Ok(())
    }
}
impl Embedding {
    /// Gathers one row of `table` per entry of `indices` into a
    /// `[indices.len(), embedding_dim]` `Tensor::F32`.
    ///
    /// Shared by the CPU path and the Metal path of [`Layer::forward`].
    ///
    /// # Errors
    /// Returns a tensor-op error (context `"Embedding::forward"`) when:
    /// - `table` is not a `Tensor::F32`;
    /// - `table` is not 2-D or its column count differs from `embedding_dim`;
    /// - an index is not below both `num_embeddings` and the table's row count.
    fn gather_rows(&self, table: &Tensor, indices: &[u32]) -> Result<Tensor> {
        let weight_arr = match table {
            Tensor::F32(arr) => arr,
            _ => {
                return Err(TrustformersError::tensor_op_error(
                    "Unsupported tensor type for embedding",
                    "Embedding::forward",
                ));
            },
        };

        // `set_weight` stores whatever tensor it is given, so the table may no
        // longer match the dimensions recorded at construction. Validate here so
        // a mismatch surfaces as an error instead of a panic inside ndarray.
        let table_shape = weight_arr.shape();
        if table_shape.len() != 2 || table_shape[1] != self.embedding_dim {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Embedding weight has shape {:?}, expected [{}, {}]",
                    table_shape, self.num_embeddings, self.embedding_dim
                ),
                "Embedding::forward",
            ));
        }
        // Ids must address a row that exists in the table and is within the
        // configured vocabulary size.
        let valid_rows = self.num_embeddings.min(table_shape[0]);

        let mut output = Array2::<f32>::zeros((indices.len(), self.embedding_dim));
        for (i, &idx) in indices.iter().enumerate() {
            let row = idx as usize;
            if row >= valid_rows {
                return Err(TrustformersError::tensor_op_error(
                    &format!(
                        "Index {} out of range for embedding table of size {}",
                        idx, valid_rows
                    ),
                    "Embedding::forward",
                ));
            }
            let embedding = weight_arr.index_axis(Axis(0), row);
            output.row_mut(i).assign(&embedding);
        }
        Ok(Tensor::F32(output.into_dyn()))
    }
}

impl Layer for Embedding {
    type Input = Vec<u32>;
    type Output = Tensor;

    /// Looks up the embedding row for every id in `input`.
    ///
    /// Returns a tensor with one row per id and `embedding_dim` columns.
    ///
    /// With the `metal` feature on macOS, a Metal-resident weight is first copied
    /// to the CPU, the lookup runs there, and the result is moved back to
    /// `self.device` when that device is `Device::Metal(_)`.
    ///
    /// NOTE(review): there is no equivalent branch for a weight moved by
    /// `weights_to_gpu_cuda`; such a weight appears to reach the
    /// "Unsupported tensor type" error — confirm whether that is intended.
    ///
    /// # Errors
    /// Returns a tensor-op error if an id is out of range, if the weight is not an
    /// F32 table of the expected shape, or if a device transfer fails.
    fn forward(&self, input: Self::Input) -> Result<Self::Output> {
        #[cfg(all(target_os = "macos", feature = "metal"))]
        if let Tensor::Metal(_) = &self.weight {
            let cpu_weight = self.weight.to_device_enum(&Device::CPU)?;
            let result_tensor = self.gather_rows(&cpu_weight, &input)?;

            // Debug-build-only diagnostics for the Metal path.
            #[cfg(debug_assertions)]
            {
                if matches!(self.device, Device::Metal(_)) {
                    if let Tensor::F32(ref arr) = result_tensor {
                        let data: Vec<f32> = arr.iter().cloned().collect();
                        eprintln!(
                            "🔍 Embedding lookup (CPU) first 10: {:?}",
                            &data[..10.min(data.len())]
                        );
                        eprintln!(
                            "🔍 Embedding stats: min={:.4}, max={:.4}, mean={:.4}",
                            data.iter().fold(f32::INFINITY, |a, &b| a.min(b)),
                            data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)),
                            data.iter().sum::<f32>() / data.len() as f32
                        );
                    }
                }
            }

            if matches!(self.device, Device::Metal(_)) {
                let metal_result = result_tensor.to_device_enum(&self.device)?;
                return Ok(metal_result);
            }
            return Ok(result_tensor);
        }

        self.gather_rows(&self.weight, &input)
    }
}