use oxicuda_blas::GpuFloat;
use oxicuda_blas::batched::{GroupedGemmProblem, gemm_grouped};
use oxicuda_blas::types::Transpose;
use oxicuda_memory::DeviceBuffer;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::types::{TensorDesc, TensorDescMut};
/// Runs one GEMM per expert for a mixture-of-experts layer using a single grouped-GEMM launch.
///
/// For every expert `e` that receives at least one token row, this computes
/// `output[rows_e, :] = tokens[rows_e, :] * weights[e, :, :]` (alpha = 1, beta = 0), where
/// `rows_e` is a contiguous block of token rows assigned to that expert.
///
/// # Arguments
///
/// * `handle` - DNN handle; its BLAS handle (`handle.blas()`) is used for the launch.
/// * `tokens` - 2D tensor `[total_tokens, K]`.
/// * `weights` - 3D tensor `[num_experts, K, N]`.
/// * `output` - 2D tensor `[total_tokens, N]`; written with beta = 0.
/// * `expert_offsets` - device buffer that must hold at least `num_experts + 1` entries.
/// * `num_experts` - number of experts; must equal `weights.dims[0]`.
///
/// Byte strides are derived as `K * T::SIZE` per token row, `N * T::SIZE` per output row and
/// `K * N * T::SIZE` per expert weight matrix. This presumes densely packed rows.
/// TODO confirm against `TensorDesc` stride semantics.
///
/// NOTE(review): the per-expert row partition is computed on the host as a uniform split.
/// Each expert gets `total_tokens / num_experts` rows, and the first
/// `total_tokens % num_experts` experts get one extra row. The *contents* of
/// `expert_offsets` are never read; only its length is validated. Results therefore match the
/// offsets only when they describe exactly this uniform split.
///
/// # Errors
///
/// * [`DnnError::InvalidDimension`] in any of these cases:
///   - a tensor has the wrong rank;
///   - the shapes disagree;
///   - `num_experts` is zero while there are tokens to route.
/// * [`DnnError::BufferTooSmall`] if `expert_offsets` has fewer than `num_experts + 1` entries.
/// * Any error propagated from the grouped GEMM launch.
pub fn moe_grouped_gemm<T: GpuFloat>(
    handle: &DnnHandle,
    tokens: &TensorDesc<T>,
    weights: &TensorDesc<T>,
    output: &mut TensorDescMut<T>,
    expert_offsets: &DeviceBuffer<i32>,
    num_experts: u32,
) -> DnnResult<()> {
    // Rank checks.
    if tokens.ndim() != 2 {
        return Err(DnnError::InvalidDimension(format!(
            "tokens must be 2D, got {}D",
            tokens.ndim()
        )));
    }
    if weights.ndim() != 3 {
        return Err(DnnError::InvalidDimension(format!(
            "weights must be 3D [num_experts, K, N], got {}D",
            weights.ndim()
        )));
    }
    if output.ndim() != 2 {
        return Err(DnnError::InvalidDimension(format!(
            "output must be 2D, got {}D",
            output.ndim()
        )));
    }

    let total_tokens = tokens.dims[0];
    let k_dim = tokens.dims[1];
    let n_dim = weights.dims[2];

    // Shape-agreement checks.
    if weights.dims[0] != num_experts {
        return Err(DnnError::InvalidDimension(format!(
            "weights dim[0] ({}) != num_experts ({})",
            weights.dims[0], num_experts
        )));
    }
    if weights.dims[1] != k_dim {
        return Err(DnnError::InvalidDimension(format!(
            "weights dim[1] ({}) != tokens dim[1] ({})",
            weights.dims[1], k_dim
        )));
    }
    if output.dims[0] != total_tokens {
        return Err(DnnError::InvalidDimension(format!(
            "output rows ({}) != total_tokens ({})",
            output.dims[0], total_tokens
        )));
    }
    if output.dims[1] != n_dim {
        return Err(DnnError::InvalidDimension(format!(
            "output cols ({}) != weights N dim ({})",
            output.dims[1], n_dim
        )));
    }

    let required_offsets = num_experts as usize + 1;
    if expert_offsets.len() < required_offsets {
        return Err(DnnError::BufferTooSmall {
            expected: required_offsets,
            actual: expert_offsets.len(),
        });
    }

    // With zero experts there is nowhere to route tokens. Without this check the problem list
    // would be empty and we would return Ok(()) while leaving `output` unwritten.
    if num_experts == 0 && total_tokens > 0 {
        return Err(DnnError::InvalidDimension(format!(
            "num_experts must be > 0 to route {} tokens",
            total_tokens
        )));
    }

    // Byte strides, computed directly in u64 so they cannot overflow a 32-bit usize.
    let elem_bytes = T::SIZE as u64;
    let weight_stride = u64::from(k_dim) * u64::from(n_dim) * elem_bytes;
    let token_stride = u64::from(k_dim) * elem_bytes;
    let output_stride = u64::from(n_dim) * elem_bytes;

    let problems = build_expert_problems(
        tokens.ptr,
        weights.ptr,
        output.ptr,
        total_tokens,
        k_dim,
        n_dim,
        num_experts,
        weight_stride,
        token_stride,
        output_stride,
    );
    if problems.is_empty() {
        // Only reachable with zero tokens: nothing to compute.
        return Ok(());
    }

    gemm_grouped(handle.blas(), &problems, T::gpu_one(), T::gpu_zero())?;
    Ok(())
}
/// Splits `total_tokens` rows as evenly as possible across `num_experts` experts.
///
/// Builds one grouped-GEMM problem for every expert that receives at least one row.
///
/// Rows are handed out in contiguous blocks, in expert order. Every expert gets
/// `total_tokens / num_experts` rows, and the first `total_tokens % num_experts` experts get one
/// more. Experts left with zero rows produce no problem. For expert `e` whose block starts at
/// row `r`:
///
/// * A = `tokens_ptr + r * token_row_bytes`, `lda = k_dim`
/// * B = `weights_ptr + e * weight_stride`, `ldb = n_dim`
/// * C = D = `output_ptr + r * output_row_bytes`, `ldc = ldd = n_dim`
///
/// Both transposes are `NoTrans`. Returns an empty vector when either `total_tokens` or
/// `num_experts` is zero.
#[allow(clippy::too_many_arguments)]
fn build_expert_problems(
    tokens_ptr: u64,
    weights_ptr: u64,
    output_ptr: u64,
    total_tokens: u32,
    k_dim: u32,
    n_dim: u32,
    num_experts: u32,
    weight_stride: u64,
    token_row_bytes: u64,
    output_row_bytes: u64,
) -> Vec<GroupedGemmProblem> {
    if total_tokens == 0 || num_experts == 0 {
        return Vec::new();
    }

    let rows_each = total_tokens / num_experts;
    let extra_rows = total_tokens % num_experts;

    (0..num_experts)
        // Walk the experts while tracking the first row of each expert's contiguous block.
        .scan(0u64, |next_row, expert| {
            // The first `extra_rows` experts absorb the remainder, one row each.
            let rows = rows_each + u32::from(expert < extra_rows);
            let first_row = *next_row;
            *next_row += u64::from(rows);
            Some((expert, first_row, rows))
        })
        .filter(|&(_, _, rows)| rows > 0)
        .map(|(expert, first_row, rows)| {
            let a_ptr = tokens_ptr + first_row * token_row_bytes;
            let b_ptr = weights_ptr + u64::from(expert) * weight_stride;
            let out_ptr = output_ptr + first_row * output_row_bytes;
            GroupedGemmProblem {
                trans_a: Transpose::NoTrans,
                trans_b: Transpose::NoTrans,
                m: rows,
                n: n_dim,
                k: k_dim,
                a_ptr,
                lda: k_dim,
                b_ptr,
                ldb: n_dim,
                // C aliases D; with beta = 0 the prior contents of C are not used.
                c_ptr: out_ptr,
                ldc: n_dim,
                d_ptr: out_ptr,
                ldd: n_dim,
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds problems with base pointers 0x1000 (tokens), 0x2000 (weights) and
    /// 0x3000 (output), using 4-byte elements.
    fn build(total_tokens: u32, k: u32, n: u32, experts: u32) -> Vec<GroupedGemmProblem> {
        build_expert_problems(
            0x1000,
            0x2000,
            0x3000,
            total_tokens,
            k,
            n,
            experts,
            u64::from(k) * u64::from(n) * 4,
            u64::from(k) * 4,
            u64::from(n) * 4,
        )
    }

    #[test]
    fn build_problems_uniform() {
        let problems = build(16, 64, 128, 4);
        assert_eq!(problems.len(), 4);
        for (e, p) in problems.iter().enumerate() {
            let e = e as u64;
            assert!(matches!(p.trans_a, Transpose::NoTrans));
            assert!(matches!(p.trans_b, Transpose::NoTrans));
            assert_eq!(p.m, 4);
            assert_eq!(p.n, 128);
            assert_eq!(p.k, 64);
            assert_eq!(p.lda, 64);
            assert_eq!(p.ldb, 128);
            assert_eq!(p.ldc, 128);
            assert_eq!(p.ldd, 128);
            // Expert e owns rows [4e, 4e + 4).
            assert_eq!(p.a_ptr, 0x1000 + e * 4 * 64 * 4);
            assert_eq!(p.b_ptr, 0x2000 + e * 64 * 128 * 4);
            assert_eq!(p.d_ptr, 0x3000 + e * 4 * 128 * 4);
            assert_eq!(p.c_ptr, p.d_ptr);
        }
    }

    #[test]
    fn build_problems_with_remainder() {
        let problems = build(10, 32, 64, 3);
        let ms: Vec<u32> = problems.iter().map(|p| p.m).collect();
        assert_eq!(ms, [4, 3, 3]);
        // Block starts follow the cumulative row counts: 0, 4, 7.
        let row_starts = [0u64, 4, 7];
        for (e, (p, &row)) in problems.iter().zip(&row_starts).enumerate() {
            assert_eq!(p.a_ptr, 0x1000 + row * 32 * 4);
            assert_eq!(p.b_ptr, 0x2000 + e as u64 * 32 * 64 * 4);
            assert_eq!(p.d_ptr, 0x3000 + row * 64 * 4);
            assert_eq!(p.c_ptr, p.d_ptr);
        }
    }

    #[test]
    fn build_problems_fewer_tokens_than_experts() {
        // 2 tokens over 4 experts: experts 0 and 1 get one row each; experts 2 and 3 are skipped.
        let problems = build(2, 8, 16, 4);
        assert_eq!(problems.len(), 2);
        assert_eq!(problems[0].m, 1);
        assert_eq!(problems[1].m, 1);
        assert_eq!(problems[0].a_ptr, 0x1000);
        assert_eq!(problems[1].a_ptr, 0x1000 + 8 * 4);
        assert_eq!(problems[0].b_ptr, 0x2000);
        assert_eq!(problems[1].b_ptr, 0x2000 + 8 * 16 * 4);
        assert_eq!(problems[0].d_ptr, 0x3000);
        assert_eq!(problems[1].d_ptr, 0x3000 + 16 * 4);
    }

    #[test]
    fn build_problems_zero_tokens() {
        let problems = build_expert_problems(0, 0, 0, 0, 32, 64, 4, 32 * 64 * 4, 32 * 4, 64 * 4);
        assert!(problems.is_empty());
    }

    #[test]
    fn build_problems_zero_experts() {
        assert!(build(8, 32, 64, 0).is_empty());
    }
}