use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::ptx_helpers;
use crate::types::{TensorDesc, TensorDescMut};
const PERM_BLOCK_X: u32 = 256;
fn generate_permute_ptx<T: GpuFloat>(sm: SmVersion) -> DnnResult<String> {
let kernel_name = format!("moe_permute_tokens_{}", T::NAME);
let elem_bytes = T::SIZE as u32;
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("input_ptr", PtxType::U64)
.param("perm_ptr", PtxType::U64)
.param("output_ptr", PtxType::U64)
.param("num_rows", PtxType::U32)
.param("hidden_dim", PtxType::U32)
.body(move |b| {
let col = b.global_thread_id_x();
let row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {row}, %ctaid.y;"));
let num_rows = b.load_param_u32("num_rows");
let hidden = b.load_param_u32("hidden_dim");
let exit_lbl = b.fresh_label("exit");
let pred_row = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {pred_row}, {row}, {num_rows};"));
b.branch_if(pred_row, &exit_lbl);
let pred_col = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {pred_col}, {col}, {hidden};"));
b.branch_if(pred_col, &exit_lbl);
let input_ptr = b.load_param_u64("input_ptr");
let perm_ptr = b.load_param_u64("perm_ptr");
let output_ptr = b.load_param_u64("output_ptr");
let perm_addr = b.byte_offset_addr(perm_ptr, row.clone(), 4);
let dest_row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("ld.global.u32 {dest_row}, [{perm_addr}];"));
let row_off = b.mul_lo_u32(row, hidden.clone());
let src_idx = b.add_u32(row_off, col.clone());
let src_addr = b.byte_offset_addr(input_ptr, src_idx, elem_bytes);
let dst_off = b.mul_lo_u32(dest_row, hidden);
let dst_idx = b.add_u32(dst_off, col);
let dst_addr = b.byte_offset_addr(output_ptr, dst_idx, elem_bytes);
let val = ptx_helpers::load_global_float::<T>(b, src_addr);
ptx_helpers::store_global_float::<T>(b, dst_addr, val);
b.label(&exit_lbl);
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn permute_tokens<T: GpuFloat>(
handle: &DnnHandle,
input: &TensorDesc<T>,
permutation: &DeviceBuffer<i32>,
output: &mut TensorDescMut<T>,
) -> DnnResult<()> {
if input.ndim() != 2 {
return Err(DnnError::InvalidDimension(format!(
"input must be 2D, got {}D",
input.ndim()
)));
}
if output.ndim() != 2 {
return Err(DnnError::InvalidDimension(format!(
"output must be 2D, got {}D",
output.ndim()
)));
}
let num_rows = input.dims[0];
let hidden_dim = input.dims[1];
if output.dims[0] != num_rows || output.dims[1] != hidden_dim {
return Err(DnnError::InvalidDimension(format!(
"output dims [{}, {}] must match input dims [{}, {}]",
output.dims[0], output.dims[1], num_rows, hidden_dim
)));
}
if permutation.len() < num_rows as usize {
return Err(DnnError::BufferTooSmall {
expected: num_rows as usize,
actual: permutation.len(),
});
}
let ptx = generate_permute_ptx::<T>(handle.sm_version())?;
let kernel_name = format!("moe_permute_tokens_{}", T::NAME);
let module = Arc::new(Module::from_ptx(&ptx)?);
let kernel = Kernel::from_module(module, &kernel_name)?;
let grid_x = hidden_dim.div_ceil(PERM_BLOCK_X);
let grid = Dim3::new(grid_x, num_rows, 1);
let block = Dim3::new(PERM_BLOCK_X, 1, 1);
let params = LaunchParams::new(grid, block);
let args = (
input.ptr,
permutation.as_device_ptr(),
output.ptr,
num_rows,
hidden_dim,
);
kernel.launch(¶ms, handle.stream(), &args)?;
Ok(())
}
fn generate_unpermute_ptx<T: GpuFloat>(sm: SmVersion, top_k: u32) -> DnnResult<String> {
let kernel_name = format!("moe_unpermute_tokens_{}", T::NAME);
let elem_bytes = T::SIZE as u32;
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("expert_out_ptr", PtxType::U64)
.param("perm_ptr", PtxType::U64)
.param("weights_ptr", PtxType::U64)
.param("output_ptr", PtxType::U64)
.param("num_tokens", PtxType::U32)
.param("hidden_dim", PtxType::U32)
.param("top_k", PtxType::U32)
.body(move |b| {
let col = b.global_thread_id_x();
let token = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {token}, %ctaid.y;"));
let num_tok = b.load_param_u32("num_tokens");
let hidden = b.load_param_u32("hidden_dim");
let p_topk = b.load_param_u32("top_k");
let exit_lbl = b.fresh_label("exit");
let pred_tok = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {pred_tok}, {token}, {num_tok};"));
b.branch_if(pred_tok, &exit_lbl);
let pred_col = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {pred_col}, {col}, {hidden};"));
b.branch_if(pred_col, &exit_lbl);
let expert_out = b.load_param_u64("expert_out_ptr");
let perm_ptr = b.load_param_u64("perm_ptr");
let weights_ptr = b.load_param_u64("weights_ptr");
let output_ptr = b.load_param_u64("output_ptr");
let acc = ptx_helpers::load_float_imm::<T>(b, 0.0);
let base_slot = b.mul_lo_u32(token.clone(), p_topk);
b.unroll(top_k, |b, j| {
let j_reg = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {j_reg}, {j};"));
let slot = b.add_u32(base_slot.clone(), j_reg);
let wt_addr = b.byte_offset_addr(weights_ptr.clone(), slot.clone(), elem_bytes);
let weight = ptx_helpers::load_global_float::<T>(b, wt_addr);
let perm_addr = b.byte_offset_addr(perm_ptr.clone(), slot, 4);
let src_row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("ld.global.u32 {src_row}, [{perm_addr}];"));
let row_off = b.mul_lo_u32(src_row, hidden.clone());
let idx = b.add_u32(row_off, col.clone());
let src_addr = b.byte_offset_addr(expert_out.clone(), idx, elem_bytes);
let val = ptx_helpers::load_global_float::<T>(b, src_addr);
let contribution = ptx_helpers::fma_float::<T>(b, weight, val, acc.clone());
let ty_str = if T::PTX_TYPE == PtxType::F32 {
"f32"
} else {
"f64"
};
b.raw_ptx(&format!("mov.{ty_str} {}, {contribution};", acc.clone()));
});
let out_off = b.mul_lo_u32(token, hidden);
let out_idx = b.add_u32(out_off, col);
let out_addr = b.byte_offset_addr(output_ptr, out_idx, elem_bytes);
ptx_helpers::store_global_float::<T>(b, out_addr, acc);
b.label(&exit_lbl);
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
pub fn unpermute_tokens<T: GpuFloat>(
handle: &DnnHandle,
expert_output: &TensorDesc<T>,
permutation: &DeviceBuffer<i32>,
expert_weights: &DeviceBuffer<T>,
output: &mut TensorDescMut<T>,
top_k: u32,
) -> DnnResult<()> {
if expert_output.ndim() != 2 {
return Err(DnnError::InvalidDimension(format!(
"expert_output must be 2D, got {}D",
expert_output.ndim()
)));
}
if output.ndim() != 2 {
return Err(DnnError::InvalidDimension(format!(
"output must be 2D, got {}D",
output.ndim()
)));
}
if top_k == 0 {
return Err(DnnError::InvalidArgument("top_k must be positive".into()));
}
let num_tokens = output.dims[0];
let hidden_dim = output.dims[1];
if expert_output.dims[1] != hidden_dim {
return Err(DnnError::InvalidDimension(format!(
"expert_output hidden_dim ({}) != output hidden_dim ({})",
expert_output.dims[1], hidden_dim
)));
}
let total_slots = num_tokens as usize * top_k as usize;
if permutation.len() < total_slots {
return Err(DnnError::BufferTooSmall {
expected: total_slots,
actual: permutation.len(),
});
}
if expert_weights.len() < total_slots {
return Err(DnnError::BufferTooSmall {
expected: total_slots,
actual: expert_weights.len(),
});
}
let ptx = generate_unpermute_ptx::<T>(handle.sm_version(), top_k)?;
let kernel_name = format!("moe_unpermute_tokens_{}", T::NAME);
let module = Arc::new(Module::from_ptx(&ptx)?);
let kernel = Kernel::from_module(module, &kernel_name)?;
let grid_x = hidden_dim.div_ceil(PERM_BLOCK_X);
let grid = Dim3::new(grid_x, num_tokens, 1);
let block = Dim3::new(PERM_BLOCK_X, 1, 1);
let params = LaunchParams::new(grid, block);
let args = (
expert_output.ptr,
permutation.as_device_ptr(),
expert_weights.as_device_ptr(),
output.ptr,
num_tokens,
hidden_dim,
top_k,
);
kernel.launch(¶ms, handle.stream(), &args)?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn permute_ptx_gen_f32() {
let ptx = generate_permute_ptx::<f32>(SmVersion::Sm80);
assert!(ptx.is_ok());
let text = ptx.unwrap_or_default();
assert!(text.contains(".entry moe_permute_tokens_f32"));
}
#[test]
fn unpermute_ptx_gen_f32() {
let ptx = generate_unpermute_ptx::<f32>(SmVersion::Sm80, 2);
assert!(ptx.is_ok());
let text = ptx.unwrap_or_default();
assert!(text.contains(".entry moe_unpermute_tokens_f32"));
}
}