use candle_core::{Result, Tensor};
#[cfg(feature = "cuda")]
/// Fused MoE (mixture-of-experts) GEMM on CUDA.
///
/// Multiplies `input` (`[size_m, size_k]`) by per-expert weight matrices
/// `weights` (`[num_experts, size_n, size_k]`), routing rows according to
/// `sorted_token_ids` / `experts_ids`, and returns a `[size_m, size_n]`
/// tensor (with `size_m` scaled by `topk` when `topk_weights` is `None`).
///
/// * `topk_weights` — optional f32 per-(token, expert) scaling factors.
/// * `is_prefill` — selects the WMMA kernel; otherwise a GEMV kernel is used
///   for small batches (M <= 8) and a plain GEMM kernel above that.
///
/// # Errors
/// Fails when any tensor is not on CUDA, when the dtype is not f16/bf16, or
/// when the K dimensions of `input` and `weights` disagree.
pub fn moe_gemm(
    input: &Tensor,
    weights: &Tensor,
    topk_weights: &Option<Tensor>,
    sorted_token_ids: &Tensor,
    experts_ids: &Tensor,
    topk: usize,
    is_prefill: bool,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use candle_core::DType;
    use half::{bf16, f16};
    // Monomorphised worker; `T` is the device element type (f16 or bf16).
    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        input: &Tensor,
        weights: &Tensor,
        topk_weights: &Option<Tensor>,
        sorted_token_ids: &Tensor,
        experts_ids: &Tensor,
        topk: usize,
        is_prefill: bool,
    ) -> Result<Tensor> {
        let (mut size_m, size_k1) = input.dims2()?;
        // Without explicit top-k scaling weights, each token row is expanded
        // to all of its top-k experts, so the effective M grows by `topk`.
        if topk_weights.is_none() {
            size_m *= topk;
        }
        let (num_experts, size_n, size_k) = weights.dims3()?;
        // A K-dim mismatch is a recoverable caller error: report it through
        // `Result` (consistent with the dtype check below) instead of panicking.
        if size_k != size_k1 {
            candle_core::bail!(
                "input {:?} and weight {:?} last dim mismatch!",
                size_k1,
                size_k
            );
        }
        let dev = input.device().as_cuda_device()?;
        // Dtype tag forwarded to the C kernels: 0 = f16, 1 = bf16.
        let data_type = match input.dtype() {
            DType::F16 => 0,
            DType::BF16 => 1,
            _ => {
                candle_core::bail!("moe_gemm_wmma only accept f16/bf16 inputs!")
            }
        };
        let (input, input_l) = input.storage_and_layout();
        let input = match &*input {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("input must be a cuda tensor"),
        };
        let input_offset = input_l.start_offset();
        let (weights, weights_l) = weights.storage_and_layout();
        let weights = match &*weights {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("weight must be a cuda tensor"),
        };
        let weights_offset = weights_l.start_offset();
        let (sorted_token_ids, sti_l) = sorted_token_ids.storage_and_layout();
        let sorted_token_ids = match &*sorted_token_ids {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
            _ => candle::bail!("sorted_token_ids must be a cuda tensor"),
        };
        let sti_offset = sti_l.start_offset();
        let (experts_ids, ei_l) = experts_ids.storage_and_layout();
        let experts_ids = match &*experts_ids {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
            _ => candle::bail!("experts_ids must be a cuda tensor"),
        };
        let ei_offset = ei_l.start_offset();
        // Raw device pointer to the optional scaling weights (null when absent).
        // NOTE(review): the storage guard is dropped at the end of this block;
        // the pointer presumably stays valid because the caller keeps the
        // tensor alive for the duration of the kernel — confirm.
        let topk_weights_ptr = if let Some(topk_weights) = &topk_weights {
            let (topk_weights, tw_l) = topk_weights.storage_and_layout();
            let topk_weights = match &*topk_weights {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!("topk_weights must be a cuda tensor"),
            };
            let tw_offset = tw_l.start_offset();
            let topk_w_ptr = topk_weights
                .slice(tw_offset..)
                .device_ptr(topk_weights.stream())
                .0 as *const f32;
            topk_w_ptr
        } else {
            std::ptr::null()
        };
        // Output is allocated uninitialized; the kernel is expected to write
        // every element of the [size_m, size_n] result — TODO confirm.
        let output = unsafe { dev.alloc::<T>(size_m * size_n) }?;
        let stream = dev.cuda_stream().cu_stream() as i64;
        use core::ffi::c_void;
        // At or below this M the dedicated GEMV kernel is used instead of GEMM.
        const GEMV_THRESHOLD: i32 = 8;
        // Dimensions beyond i32::MAX are unsupported by the C kernels;
        // overflow here is a programming error, hence `expect`.
        let num_experts_i32 = i32::try_from(num_experts).expect("num_experts too large for i32");
        let topk_i32 = i32::try_from(topk).expect("topk too large for i32");
        let size_m_i32 = i32::try_from(size_m).expect("size_m too large for i32");
        let size_n_i32 = i32::try_from(size_n).expect("size_n too large for i32");
        let size_k_i32 = i32::try_from(size_k).expect("size_k too large for i32");
        // Kernel selection: WMMA for prefill, GEMV for tiny M, GEMM otherwise.
        let moe_func = if is_prefill {
            crate::cuda::ffi::moe_gemm_wmma
        } else if size_m_i32 <= GEMV_THRESHOLD {
            crate::cuda::ffi::moe_gemv
        } else {
            crate::cuda::ffi::moe_gemm
        };
        // SAFETY: every pointer is a device pointer derived from a live CUDA
        // slice (offset past its layout's start), `output` was allocated on
        // `dev` with size_m * size_n elements, and the integer dimensions
        // describe those extents; the launch stream belongs to `dev`.
        unsafe {
            moe_func(
                input.slice(input_offset..).device_ptr(input.stream()).0 as *const c_void,
                weights
                    .slice(weights_offset..)
                    .device_ptr(weights.stream())
                    .0 as *const c_void,
                sorted_token_ids
                    .slice(sti_offset..)
                    .device_ptr(sorted_token_ids.stream())
                    .0 as *const i32,
                experts_ids
                    .slice(ei_offset..)
                    .device_ptr(experts_ids.stream())
                    .0 as *const i32,
                topk_weights_ptr,
                output.device_ptr(output.stream()).0 as *mut c_void,
                num_experts_i32,
                topk_i32,
                size_m_i32,
                size_n_i32,
                size_k_i32,
                data_type as i32,
                stream,
            );
        }
        // Wrap the raw allocation back into a candle tensor of shape [m, n].
        let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone());
        let output = Tensor::from((candle::Storage::Cuda(output), (size_m, size_n)));
        Ok(output)
    }
    // Dispatch on element type; only half-precision formats are supported.
    match input.dtype() {
        DType::F16 => cuda_fwd::<f16>(
            input,
            weights,
            topk_weights,
            sorted_token_ids,
            experts_ids,
            topk,
            is_prefill,
        ),
        DType::BF16 => cuda_fwd::<bf16>(
            input,
            weights,
            topk_weights,
            sorted_token_ids,
            experts_ids,
            topk,
            is_prefill,
        ),
        _ => {
            candle_core::bail!("moe_gemm only accept f16/bf16 inputs!")
        }
    }
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
/// Non-CUDA stub for [`moe_gemm`]: this build was compiled without the
/// `cuda` feature, so the MoE GEMM kernels are unavailable and any call
/// returns an error.
pub fn moe_gemm(
    _input: &Tensor,
    _weights: &Tensor,
    _topk_weights: &Option<Tensor>,
    _sorted_token_ids: &Tensor,
    _experts_ids: &Tensor,
    _topk: usize,
    _is_prefill: bool,
) -> Result<Tensor> {
    candle_core::bail!("moe_gemm is not implemented on this platform!")
}
#[cfg(feature = "cuda")]
/// Fused MoE GEMM on CUDA with transposed per-expert weights.
///
/// Identical to [`moe_gemm`] except that `weights` is laid out as
/// `[num_experts, size_k, size_n]` (K-major) instead of
/// `[num_experts, size_n, size_k]`. Returns a `[size_m, size_n]` tensor
/// (with `size_m` scaled by `topk` when `topk_weights` is `None`).
///
/// * `topk_weights` — optional f32 per-(token, expert) scaling factors.
/// * `is_prefill` — selects the WMMA kernel; otherwise a GEMV kernel is used
///   for small batches (M <= 8) and a plain GEMM kernel above that.
///
/// # Errors
/// Fails when any tensor is not on CUDA, when the dtype is not f16/bf16, or
/// when the K dimensions of `input` and `weights` disagree.
pub fn moe_gemm_transposed(
    input: &Tensor,
    weights: &Tensor,
    topk_weights: &Option<Tensor>,
    sorted_token_ids: &Tensor,
    experts_ids: &Tensor,
    topk: usize,
    is_prefill: bool,
) -> Result<Tensor> {
    use candle::cuda_backend::cudarc::driver::DevicePtr;
    use candle_core as candle;
    use candle_core::DType;
    use half::{bf16, f16};
    // Monomorphised worker; `T` is the device element type (f16 or bf16).
    fn cuda_fwd<
        T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr,
    >(
        input: &Tensor,
        weights: &Tensor,
        topk_weights: &Option<Tensor>,
        sorted_token_ids: &Tensor,
        experts_ids: &Tensor,
        topk: usize,
        is_prefill: bool,
    ) -> Result<Tensor> {
        let (mut size_m, size_k1) = input.dims2()?;
        // Without explicit top-k scaling weights, each token row is expanded
        // to all of its top-k experts, so the effective M grows by `topk`.
        if topk_weights.is_none() {
            size_m *= topk;
        }
        // Transposed layout: K is the middle dim, N the last.
        let (num_experts, size_k, size_n) = weights.dims3()?;
        // A K-dim mismatch is a recoverable caller error: report it through
        // `Result` (consistent with the dtype check below) instead of panicking.
        if size_k != size_k1 {
            candle_core::bail!(
                "input {:?} and weight {:?} K dim mismatch!",
                size_k1,
                size_k
            );
        }
        let dev = input.device().as_cuda_device()?;
        // Dtype tag forwarded to the C kernels: 0 = f16, 1 = bf16.
        let data_type = match input.dtype() {
            DType::F16 => 0,
            DType::BF16 => 1,
            _ => {
                candle_core::bail!("moe_gemm_transposed only accept f16/bf16 inputs!")
            }
        };
        let (input, input_l) = input.storage_and_layout();
        let input = match &*input {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("input must be a cuda tensor"),
        };
        let input_offset = input_l.start_offset();
        let (weights, weights_l) = weights.storage_and_layout();
        let weights = match &*weights {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<T>()?,
            _ => candle::bail!("weight must be a cuda tensor"),
        };
        let weights_offset = weights_l.start_offset();
        let (sorted_token_ids, sti_l) = sorted_token_ids.storage_and_layout();
        let sorted_token_ids = match &*sorted_token_ids {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
            _ => candle::bail!("sorted_token_ids must be a cuda tensor"),
        };
        let sti_offset = sti_l.start_offset();
        let (experts_ids, ei_l) = experts_ids.storage_and_layout();
        let experts_ids = match &*experts_ids {
            candle::Storage::Cuda(c) => c.as_cuda_slice::<u32>()?,
            _ => candle::bail!("experts_ids must be a cuda tensor"),
        };
        let ei_offset = ei_l.start_offset();
        // Raw device pointer to the optional scaling weights (null when absent).
        // NOTE(review): the storage guard is dropped at the end of this block;
        // the pointer presumably stays valid because the caller keeps the
        // tensor alive for the duration of the kernel — confirm.
        let topk_weights_ptr = if let Some(topk_weights) = &topk_weights {
            let (topk_weights, tw_l) = topk_weights.storage_and_layout();
            let topk_weights = match &*topk_weights {
                candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
                _ => candle::bail!("topk_weights must be a cuda tensor"),
            };
            let tw_offset = tw_l.start_offset();
            let topk_w_ptr = topk_weights
                .slice(tw_offset..)
                .device_ptr(topk_weights.stream())
                .0 as *const f32;
            topk_w_ptr
        } else {
            std::ptr::null()
        };
        // Output is allocated uninitialized; the kernel is expected to write
        // every element of the [size_m, size_n] result — TODO confirm.
        let output = unsafe { dev.alloc::<T>(size_m * size_n) }?;
        let stream = dev.cuda_stream().cu_stream() as i64;
        use core::ffi::c_void;
        // At or below this M the dedicated GEMV kernel is used instead of GEMM.
        const GEMV_THRESHOLD: i32 = 8;
        // Dimensions beyond i32::MAX are unsupported by the C kernels;
        // overflow here is a programming error, hence `expect`.
        let num_experts_i32 = i32::try_from(num_experts).expect("num_experts too large for i32");
        let topk_i32 = i32::try_from(topk).expect("topk too large for i32");
        let size_m_i32 = i32::try_from(size_m).expect("size_m too large for i32");
        let size_n_i32 = i32::try_from(size_n).expect("size_n too large for i32");
        let size_k_i32 = i32::try_from(size_k).expect("size_k too large for i32");
        // Kernel selection: WMMA for prefill, GEMV for tiny M, GEMM otherwise.
        let moe_func = if is_prefill {
            crate::cuda::ffi::moe_gemm_wmma_transposed
        } else if size_m_i32 <= GEMV_THRESHOLD {
            crate::cuda::ffi::moe_gemv_transposed
        } else {
            crate::cuda::ffi::moe_gemm_transposed
        };
        // SAFETY: every pointer is a device pointer derived from a live CUDA
        // slice (offset past its layout's start), `output` was allocated on
        // `dev` with size_m * size_n elements, and the integer dimensions
        // describe those extents; the launch stream belongs to `dev`.
        unsafe {
            moe_func(
                input.slice(input_offset..).device_ptr(input.stream()).0 as *const c_void,
                weights
                    .slice(weights_offset..)
                    .device_ptr(weights.stream())
                    .0 as *const c_void,
                sorted_token_ids
                    .slice(sti_offset..)
                    .device_ptr(sorted_token_ids.stream())
                    .0 as *const i32,
                experts_ids
                    .slice(ei_offset..)
                    .device_ptr(experts_ids.stream())
                    .0 as *const i32,
                topk_weights_ptr,
                output.device_ptr(output.stream()).0 as *mut c_void,
                num_experts_i32,
                topk_i32,
                size_m_i32,
                size_n_i32,
                size_k_i32,
                data_type as i32,
                stream,
            );
        }
        // Wrap the raw allocation back into a candle tensor of shape [m, n].
        let output = candle::CudaStorage::wrap_cuda_slice(output, dev.clone());
        let output = Tensor::from((candle::Storage::Cuda(output), (size_m, size_n)));
        Ok(output)
    }
    // Dispatch on element type; only half-precision formats are supported.
    match input.dtype() {
        DType::F16 => cuda_fwd::<f16>(
            input,
            weights,
            topk_weights,
            sorted_token_ids,
            experts_ids,
            topk,
            is_prefill,
        ),
        DType::BF16 => cuda_fwd::<bf16>(
            input,
            weights,
            topk_weights,
            sorted_token_ids,
            experts_ids,
            topk,
            is_prefill,
        ),
        _ => {
            candle_core::bail!("moe_gemm_transposed only accept f16/bf16 inputs!")
        }
    }
}
#[cfg(not(feature = "cuda"))]
#[allow(unused)]
/// Non-CUDA stub for [`moe_gemm_transposed`]: this build was compiled without
/// the `cuda` feature, so the MoE GEMM kernels are unavailable and any call
/// returns an error.
pub fn moe_gemm_transposed(
    _input: &Tensor,
    _weights: &Tensor,
    _topk_weights: &Option<Tensor>,
    _sorted_token_ids: &Tensor,
    _experts_ids: &Tensor,
    _topk: usize,
    _is_prefill: bool,
) -> Result<Tensor> {
    candle_core::bail!("moe_gemm_transposed is not implemented on this platform!")
}