use std::sync::Arc;
use crate::distributed::comm_utils::all_reduce_tensor;
use crate::error::{Error, Result};
use numr::autograd::Var;
use numr::dtype::DType;
use numr::ops::{CompareOps, IndexingOps, ScalarOps, TensorOps, TypeConversionOps, UtilityOps};
use numr::runtime::{Communicator, ReduceOp, Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Embedding layer whose table is split along the vocabulary dimension, one
/// shard per rank of a communicator.
///
/// An instance stores only the rows for global vocabulary indices in
/// `[vocab_start, vocab_end)`; `forward` combines the partial lookups of all
/// ranks with a sum all-reduce.
pub struct VocabParallelEmbedding<R>
where
    R: Runtime,
{
    /// This rank's slice of the embedding table (rows of the full table).
    weight: Var<R>,
    /// Communicator queried for the world size and used for the all-reduce.
    comm: Arc<dyn Communicator>,
    /// First global vocabulary index owned by this shard (inclusive).
    vocab_start: usize,
    /// End of the owned global vocabulary range (exclusive).
    vocab_end: usize,
    /// Size of the second dimension of the weight, i.e. the embedding width.
    embed_dim: usize,
}
impl<R: Runtime<DType = DType>> VocabParallelEmbedding<R> {
/// Builds this rank's shard by slicing the full embedding table.
///
/// The table's first dimension (the vocabulary) is split into
/// `world_size` equal, consecutive ranges and the range belonging to
/// `comm.rank()` is kept.
///
/// # Arguments
/// * `full_weight` - the complete 2-D table, `[vocab_size, embed_dim]`.
/// * `comm` - communicator supplying `rank()` and `world_size()`; it is stored
///   for use in `forward`.
/// * `trainable` - forwarded unchanged to `Var::new` (presumably toggles
///   gradient tracking — confirm against `Var`).
///
/// # Errors
/// Returns [`Error::DistributedError`] when
/// * `full_weight` is not 2-D,
/// * the communicator reports a world size of zero,
/// * `vocab_size` is not divisible by the world size, or
/// * slicing the shard out of `full_weight` fails.
pub fn new(
    full_weight: &Tensor<R>,
    comm: Arc<dyn Communicator>,
    trainable: bool,
) -> Result<Self> {
    let shape = full_weight.shape();
    if shape.len() != 2 {
        return Err(Error::DistributedError {
            reason: format!(
                "VocabParallelEmbedding expects 2D weight, got {}D",
                shape.len()
            ),
        });
    }
    let vocab_size = shape[0];
    let embed_dim = shape[1];

    let rank = comm.rank();
    let world_size = comm.world_size();

    // Reject an empty world up front: the `%` and `/` below would otherwise
    // panic with a division by zero instead of reporting an error.
    if world_size == 0 {
        return Err(Error::DistributedError {
            reason: "VocabParallelEmbedding requires world_size > 0, got 0".to_string(),
        });
    }
    if vocab_size % world_size != 0 {
        return Err(Error::DistributedError {
            reason: format!(
                "vocab_size ({}) not divisible by world_size ({})",
                vocab_size, world_size
            ),
        });
    }

    // Equal, consecutive row ranges: rank `r` owns
    // `[r * shard_size, (r + 1) * shard_size)`.
    let shard_size = vocab_size / world_size;
    let vocab_start = rank * shard_size;
    let vocab_end = vocab_start + shard_size;

    // `narrow` selects the rows along dim 0; `contiguous()` is applied to the
    // resulting view before it is wrapped (presumably materialising a dense
    // copy — confirm against numr).
    let shard = full_weight
        .narrow(0, vocab_start, shard_size)
        .map_err(|e| Error::DistributedError {
            reason: format!("embedding narrow failed: {e}"),
        })?
        .contiguous();

    Ok(Self {
        weight: Var::new(shard, trainable),
        comm,
        vocab_start,
        vocab_end,
        embed_dim,
    })
}
/// Wraps an already-sliced weight shard without touching the full table.
///
/// The caller supplies the global vocabulary range `[vocab_start, vocab_end)`
/// that `weight` covers. Nothing is validated here: the range is stored as
/// given and is not checked against the number of rows in `weight`.
///
/// `trainable` is forwarded unchanged to `Var::new`.
///
/// # Panics
/// Panics if `weight` has fewer than two dimensions, because the embedding
/// width is read from `shape()[1]`.
pub fn from_shard(
    weight: Tensor<R>,
    comm: Arc<dyn Communicator>,
    vocab_start: usize,
    vocab_end: usize,
    trainable: bool,
) -> Self {
    // Read the row width first; `weight` is moved into the `Var` right after.
    let embed_dim = weight.shape()[1];
    let weight = Var::new(weight, trainable);
    Self {
        weight,
        comm,
        vocab_start,
        vocab_end,
        embed_dim,
    }
}
pub fn forward<C>(&self, client: &C, indices: &Tensor<R>) -> Result<Var<R>>
where
C: RuntimeClient<R>
+ IndexingOps<R>
+ TensorOps<R>
+ CompareOps<R>
+ ScalarOps<R>
+ UtilityOps<R>
+ TypeConversionOps<R>,
R::Client: IndexingOps<R> + TensorOps<R>,
{
let idx_shape = indices.shape().to_vec();
let n: usize = idx_shape.iter().product();
let shard_size = self.vocab_end - self.vocab_start;
let flat_idx = indices.reshape(&[n]).map_err(Error::Numr)?;
let start_tensor = Tensor::<R>::full_scalar(
&[1],
flat_idx.dtype(),
self.vocab_start as f64,
indices.device(),
);
let end_tensor = Tensor::<R>::full_scalar(
&[1],
flat_idx.dtype(),
self.vocab_end as f64,
indices.device(),
);
let ge_start = client.ge(&flat_idx, &start_tensor).map_err(Error::Numr)?;
let lt_end = client.lt(&flat_idx, &end_tensor).map_err(Error::Numr)?;
let mask_i64 = client.mul(&ge_start, <_end).map_err(Error::Numr)?;
let mask_f32 = client.cast(&mask_i64, DType::F32).map_err(Error::Numr)?;
let local_idx = client
.sub_scalar(&flat_idx, self.vocab_start as f64)
.map_err(Error::Numr)?;
let local_idx = client
.clamp(&local_idx, 0.0, (shard_size - 1) as f64)
.map_err(Error::Numr)?;
let expanded = local_idx.unsqueeze(1).map_err(Error::Numr)?;
let expanded = expanded
.broadcast_to(&[n, self.embed_dim])
.map_err(Error::Numr)?;
let gathered =
numr::autograd::var_gather(&self.weight, 0, &expanded, client).map_err(Error::Numr)?;
let mask_2d = mask_f32.unsqueeze(1).map_err(Error::Numr)?;
let mask_broadcast = mask_2d
.broadcast_to(&[n, self.embed_dim])
.map_err(Error::Numr)?;
let mask_var = Var::new(mask_broadcast, false);
let masked = numr::autograd::var_mul(&gathered, &mask_var, client).map_err(Error::Numr)?;
if self.comm.world_size() > 1 {
let tensor = masked.tensor();
all_reduce_tensor(self.comm.as_ref(), tensor, ReduceOp::Sum)?;
self.comm.sync().map_err(|e| Error::DistributedError {
reason: format!("sync after embedding all_reduce failed: {e}"),
})?;
}
let mut out_shape = idx_shape;
out_shape.push(self.embed_dim);
let result = numr::autograd::var_reshape(&masked, &out_shape).map_err(Error::Numr)?;
Ok(result)
}
/// Returns a reference to the weight shard held by this instance.
pub fn weight(&self) -> &Var<R> {
&self.weight
}
/// Returns the first global vocabulary index covered by this shard
/// (inclusive lower bound used by `forward`).
pub fn vocab_start(&self) -> usize {
self.vocab_start
}
/// Returns the end of the global vocabulary range covered by this shard
/// (exclusive upper bound used by `forward`).
pub fn vocab_end(&self) -> usize {
self.vocab_end
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nn::Embedding;
use numr::runtime::NoOpCommunicator;
use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};
/// Creates a CPU device and a client built from a clone of it.
fn cpu_setup() -> (CpuClient, CpuDevice) {
    let device = CpuDevice::new();
    (CpuClient::new(device.clone()), device)
}
/// With `NoOpCommunicator`, the sharded layer must produce exactly the same
/// output as the plain `Embedding` built from the same table.
#[test]
fn test_vocab_parallel_single_rank_matches_embedding() {
    let (client, device) = cpu_setup();
    let comm = Arc::new(NoOpCommunicator);

    // 3 x 4 table laid out one row per line.
    #[rustfmt::skip]
    let table = Tensor::<CpuRuntime>::from_slice(
        &[
            1.0f32, 2.0,  3.0,  4.0,
            5.0,    6.0,  7.0,  8.0,
            9.0,    10.0, 11.0, 12.0,
        ],
        &[3, 4],
        &device,
    );

    let reference = Embedding::new(table.clone(), false);
    let sharded = VocabParallelEmbedding::new(&table, comm, false).unwrap();

    let lookup = Tensor::<CpuRuntime>::from_slice(&[0i64, 2, 1], &[3], &device);
    let expected = reference.forward(&client, &lookup).unwrap();
    let actual = sharded.forward(&client, &lookup).unwrap();

    assert_eq!(expected.shape(), actual.shape());
    assert_eq!(
        expected.tensor().to_vec::<f32>(),
        actual.tensor().to_vec::<f32>()
    );
}
/// The output shape is the index shape with the embedding width appended:
/// indices `[2, 3]` on a `[4, 5]` table give `[2, 3, 5]`.
#[test]
fn test_vocab_parallel_forward_shape() {
    let (client, device) = cpu_setup();
    let comm = Arc::new(NoOpCommunicator);

    let table = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 20], &[4, 5], &device);
    let layer = VocabParallelEmbedding::new(&table, comm, false).unwrap();

    let lookup = Tensor::<CpuRuntime>::from_slice(&[0i64, 1, 2, 3, 0, 1], &[2, 3], &device);
    let output = layer.forward(&client, &lookup).unwrap();

    assert_eq!(output.shape(), &[2, 3, 5]);
}
/// Constructs the layer from a 3-row table with `NoOpCommunicator` and expects
/// success.
///
/// NOTE(review): despite the name, this asserts the success path. Presumably
/// `NoOpCommunicator` reports a world size of 1, so every vocabulary size is
/// divisible — confirm; exercising the error branch needs a communicator with
/// a larger world size.
#[test]
fn test_vocab_parallel_not_divisible() {
    let (_client, device) = cpu_setup();
    let table = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 12], &[3, 4], &device);
    let comm = Arc::new(NoOpCommunicator);

    let built = VocabParallelEmbedding::new(&table, comm, false);
    assert!(built.is_ok());
}
/// Compile-time check that the layer over `CpuRuntime` is `Send + Sync`.
#[test]
fn test_vocab_parallel_send_sync() {
    // Instantiating this generic only type-checks when `T` meets both bounds.
    fn assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }

    assert_send_sync::<VocabParallelEmbedding<CpuRuntime>>();
}
}