use crate::error::{Error, Result};
use numr::dtype::DType;
use numr::ops::{
BinaryOps, IndexingOps, ReduceOps, ScalarOps, ShapeOps, SortingOps, TensorOps,
TypeConversionOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Permutes tokens for MoE expert dispatch: reorders the `[num_tokens, hidden]`
/// rows of `tokens` so that all rows routed to the same expert are contiguous.
///
/// `indices` is `[num_tokens, k]`: for each token, the ids of the `k` experts
/// it is routed to. Each token row is replicated `k` times (once per routed
/// expert) before grouping.
///
/// Returns `(permuted, expert_offsets, sort_perm)`:
/// - `permuted`: `[num_tokens * k, hidden]`, rows grouped by expert id;
/// - `expert_offsets`: `[num_experts + 1]` I32 prefix sums — expert `e` owns
///   rows `expert_offsets[e]..expert_offsets[e + 1]` of `permuted`;
/// - `sort_perm`: `[num_tokens * k]` I32 sort permutation, consumed by
///   `moe_unpermute_tokens_impl` to restore the original row order.
///
/// # Errors
/// Returns `Error::InvalidArgument` when `tokens` or `indices` is not 2-D or
/// their leading (token) dimensions disagree; runtime failures are propagated
/// as `Error::Numr`.
pub fn moe_permute_tokens_impl<R, C>(
    client: &C,
    tokens: &Tensor<R>,
    indices: &Tensor<R>,
    num_experts: usize,
) -> Result<(Tensor<R>, Tensor<R>, Tensor<R>)>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>
        + BinaryOps<R>
        + IndexingOps<R>
        + ScalarOps<R>
        + SortingOps<R>
        + ShapeOps<R>
        + TensorOps<R>
        + TypeConversionOps<R>,
{
    let tok_shape = tokens.shape();
    let idx_shape = indices.shape();
    if tok_shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "tokens",
            reason: format!("expected 2D [num_tokens, hidden], got {}D", tok_shape.len()),
        });
    }
    if idx_shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "indices",
            reason: format!("expected 2D [num_tokens, k], got {}D", idx_shape.len()),
        });
    }
    let num_tokens = tok_shape[0];
    let _hidden_dim = tok_shape[1];
    // The two tensors must describe the same set of tokens; a mismatch here
    // would silently misalign the replicated token-source map below.
    if idx_shape[0] != num_tokens {
        return Err(Error::InvalidArgument {
            arg: "indices",
            reason: format!(
                "expected {} rows to match tokens, got {}",
                num_tokens, idx_shape[0]
            ),
        });
    }
    let k = idx_shape[1];
    let total = num_tokens * k;
    let device = tokens.device();
    // Flatten [num_tokens, k] -> [total] so a single stable ordering covers
    // every (token, expert) routing pair.
    let flat_indices = indices.reshape(&[total]).map_err(Error::Numr)?.contiguous();
    // Ascending sort by expert id; sort_perm[i] is the flat routing slot that
    // lands at output row i.
    let sort_perm = client
        .argsort(&flat_indices, 0, false)
        .map_err(Error::Numr)?;
    // token_source[i] = original token row for flat routing slot i
    // (each token id repeated k times, matching the flattened indices layout).
    let token_source_data: Vec<i32> = (0..num_tokens)
        .flat_map(|t| std::iter::repeat_n(t as i32, k))
        .collect();
    let token_source = Tensor::<R>::from_slice(&token_source_data, &[total], device);
    let sorted_token_indices = client
        .index_select(&token_source, 0, &sort_perm)
        .map_err(Error::Numr)?;
    // Gather the token rows into expert-grouped order.
    let permuted = client
        .index_select(tokens, 0, &sorted_token_indices)
        .map_err(Error::Numr)?;
    // Expert id of each permuted row, needed to count rows per expert.
    let sorted_expert_ids = client
        .index_select(&flat_indices, 0, &sort_perm)
        .map_err(Error::Numr)?;
    let counts = client
        .bincount(&sorted_expert_ids, None, num_experts)
        .map_err(Error::Numr)?;
    let counts_i32 = if counts.dtype() == DType::I32 {
        counts
    } else {
        client.cast(&counts, DType::I32).map_err(Error::Numr)?
    };
    // Prefix-sum the per-expert counts and prepend 0 to form the
    // [num_experts + 1] offsets table.
    let cumsum = client.cumsum(&counts_i32, 0).map_err(Error::Numr)?;
    let zero = Tensor::<R>::zeros(&[1], DType::I32, device);
    let expert_offsets = client.cat(&[&zero, &cumsum], 0).map_err(Error::Numr)?;
    // Normalize the permutation dtype so downstream consumers see I32.
    let sort_perm_i32 = if sort_perm.dtype() != DType::I32 {
        client.cast(&sort_perm, DType::I32).map_err(Error::Numr)?
    } else {
        sort_perm
    };
    Ok((permuted, expert_offsets, sort_perm_i32))
}
/// Reverses `moe_permute_tokens_impl`: restores the original token order of
/// `expert_output` and combines each token's `k` expert results into one row
/// via a weighted sum.
///
/// - `expert_output`: `[num_tokens * k, hidden]` rows in expert-grouped order;
/// - `sort_indices`: the `[num_tokens * k]` sort permutation returned by the
///   permute step;
/// - `weights`: `[num_tokens, k]` combination weights (one per routed expert);
/// - `num_tokens`: number of original tokens.
///
/// Returns a `[num_tokens, hidden]` tensor.
///
/// # Errors
/// Returns `Error::InvalidArgument` when the tensor shapes are malformed or
/// mutually inconsistent; runtime failures are propagated as `Error::Numr`.
pub fn moe_unpermute_tokens_impl<R, C>(
    client: &C,
    expert_output: &Tensor<R>,
    sort_indices: &Tensor<R>,
    weights: &Tensor<R>,
    num_tokens: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R> + IndexingOps<R> + ShapeOps<R> + ScalarOps<R> + ReduceOps<R> + TensorOps<R>,
{
    let out_shape = expert_output.shape();
    if out_shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "expert_output",
            reason: format!("expected 2D [N*k, hidden], got {}D", out_shape.len()),
        });
    }
    let total = out_shape[0];
    let hidden_dim = out_shape[1];
    // sort_indices must be a flat permutation over all N*k rows; the scatter
    // below would silently misplace rows (or fail late) otherwise.
    let si_shape = sort_indices.shape();
    if si_shape.len() != 1 || si_shape[0] != total {
        return Err(Error::InvalidArgument {
            arg: "sort_indices",
            reason: format!(
                "expected 1D [{}] to match expert_output rows, got {:?}",
                total, si_shape
            ),
        });
    }
    let w_shape = weights.shape();
    if w_shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "weights",
            reason: format!("expected 2D [num_tokens, k], got {}D", w_shape.len()),
        });
    }
    if w_shape[0] != num_tokens {
        return Err(Error::InvalidArgument {
            arg: "weights",
            reason: format!("expected {} rows, got {}", num_tokens, w_shape[0]),
        });
    }
    let k = w_shape[1];
    if num_tokens * k != total {
        return Err(Error::InvalidArgument {
            arg: "num_tokens",
            reason: format!(
                "num_tokens({}) * k({}) = {} != expert_output rows({})",
                num_tokens,
                k,
                num_tokens * k,
                total
            ),
        });
    }
    let device = expert_output.device();
    // Build the inverse permutation: scattering arange(total) with sort_indices
    // as the destination gives inv_perm[sort_indices[i]] = i, so inv_perm maps
    // each original flat routing slot to its row in expert_output.
    let arange_data: Vec<i32> = (0..total as i32).collect();
    let values = Tensor::<R>::from_slice(&arange_data, &[total], device);
    let inv_perm_base = Tensor::<R>::zeros(&[total], DType::I32, device);
    let inv_perm = client
        .scatter(&inv_perm_base, 0, sort_indices, &values)
        .map_err(Error::Numr)?;
    // Undo the expert-grouped ordering.
    let unsorted = client
        .index_select(expert_output, 0, &inv_perm)
        .map_err(Error::Numr)?;
    // [num_tokens, k, hidden] * [num_tokens, k, 1] -> weighted per-expert
    // contributions, then reduce over the k axis to combine them.
    let reshaped = unsorted
        .reshape(&[num_tokens, k, hidden_dim])
        .map_err(Error::Numr)?;
    let w_expanded = weights.reshape(&[num_tokens, k, 1]).map_err(Error::Numr)?;
    let weighted = client.mul(&reshaped, &w_expanded).map_err(Error::Numr)?;
    let output = client.sum(&weighted, &[1], false).map_err(Error::Numr)?;
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;

    /// Round-trips tokens through permute + unpermute. With uniform weights of
    /// 0.5 and k = 2, each token's combined output equals the original row, so
    /// the round trip must reproduce the input exactly (within float tolerance).
    #[test]
    fn test_permute_unpermute_roundtrip() {
        let (client, device) = cpu_setup();
        const NUM_TOKENS: usize = 4;
        const HIDDEN_DIM: usize = 8;
        const NUM_EXPERTS: usize = 3;
        const K: usize = 2;

        let tokens_data: Vec<f32> = (0..NUM_TOKENS * HIDDEN_DIM)
            .map(|i| i as f32 * 0.1)
            .collect();
        let tokens =
            Tensor::<CpuRuntime>::from_slice(&tokens_data, &[NUM_TOKENS, HIDDEN_DIM], &device);

        let routing: Vec<i32> = vec![0, 1, 2, 0, 1, 2, 0, 1];
        let indices = Tensor::<CpuRuntime>::from_slice(&routing, &[NUM_TOKENS, K], &device);

        let gate_weights: Vec<f32> = vec![0.5; NUM_TOKENS * K];
        let weights = Tensor::<CpuRuntime>::from_slice(&gate_weights, &[NUM_TOKENS, K], &device);

        let (permuted, offsets, sort_indices) =
            moe_permute_tokens_impl(&client, &tokens, &indices, NUM_EXPERTS).unwrap();
        assert_eq!(permuted.shape(), &[NUM_TOKENS * K, HIDDEN_DIM]);
        assert_eq!(offsets.shape(), &[NUM_EXPERTS + 1]);

        let result =
            moe_unpermute_tokens_impl(&client, &permuted, &sort_indices, &weights, NUM_TOKENS)
                .unwrap();
        assert_eq!(result.shape(), &[NUM_TOKENS, HIDDEN_DIM]);

        let result_vec = result.to_vec::<f32>();
        for i in 0..result_vec.len() {
            let got = result_vec[i];
            let expected = tokens_data[i];
            assert!(
                (got - expected).abs() < 1e-5,
                "roundtrip mismatch at {}: got {}, expected {}",
                i,
                got,
                expected
            );
        }
    }
}