use crate::dtype::{Dtype as WeightDtype, TypedPtr};
use crate::error::{Error, Result};
use crate::kernels::{CUptr, Kernels as GpuKernels};
use cudarc::driver::PushKernelArg;
use std::sync::Arc;
/// Total number of thread blocks the split-M TN path tries to reach.
const SPLITM_TN_TARGET_GRID_FACTOR: u32 = 284;
/// Upper bound, in elements, on the split-M partial results (`splits * n_in * n_out`).
const SPLITM_TN_SCRATCH_CAP: usize = 1 << 23;
/// Row alignment applied to every split-M chunk of the batch dimension.
const SPLITM_TN_BK_ALIGN: u32 = 16;

/// Decides whether the TN (weight-gradient) GEMM should be split along the batch (M) axis.
///
/// Returns `Some((m_chunk, splits))` when the split is worthwhile. `m_chunk` is the number
/// of batch rows per chunk, a multiple of `SPLITM_TN_BK_ALIGN`. `splits` is the resulting
/// chunk count, which is at least 2. A split is chosen when all of the following hold:
/// * `batch >= 256`, `n_in >= 128` and `n_out >= 128`;
/// * the 128x128 output tiling has fewer than `SPLITM_TN_TARGET_GRID_FACTOR` tiles;
/// * `splits * n_in * n_out` fits in `SPLITM_TN_SCRATCH_CAP`.
///
/// Returns `None` otherwise.
#[inline]
fn splitm_tn_partition(batch: usize, n_in: usize, n_out: usize) -> Option<(usize, usize)> {
    let large_enough = batch >= 256 && n_in >= 128 && n_out >= 128;
    if !large_enough {
        return None;
    }

    let out_tiles = (n_in as u32).div_ceil(128) * (n_out as u32).div_ceil(128);
    if out_tiles == 0 || out_tiles >= SPLITM_TN_TARGET_GRID_FACTOR {
        return None;
    }

    let per_split = n_in * n_out;
    // Splits needed to reach the grid target, capped by what the scratch buffer can hold.
    let wanted = SPLITM_TN_TARGET_GRID_FACTOR.div_ceil(out_tiles);
    let affordable = (SPLITM_TN_SCRATCH_CAP / per_split) as u32;
    let splits = wanted.min(affordable).max(1);

    // Round the chunk height up to the alignment, then recount how many chunks that gives.
    let m_chunk = (batch as u32)
        .div_ceil(splits)
        .next_multiple_of(SPLITM_TN_BK_ALIGN);
    let chunks = (batch as u32).div_ceil(m_chunk);

    (chunks >= 2 && (chunks as usize) * per_split <= SPLITM_TN_SCRATCH_CAP)
        .then_some((m_chunk as usize, chunks as usize))
}
/// Entry threshold, on M and/or N, for the generic tiled (slim/big) SGEMM kernels.
const SGEMM_CUSTOM_MIN: usize = 128;
/// Any N extent up to this value always takes the slim (64-wide N tile) kernel.
const SGEMM_SLIM_MAX: usize = 512;
/// Largest `n_in` accepted by the slim split-K NT path in backward-dx.
const SGEMM_SLIM_NT_NIN_MAX: usize = 768;
/// When M is below this, the slim kernel is chosen for any N >= `SGEMM_CUSTOM_MIN`.
const SGEMM_M_SLIM_FORCE: usize = 512;
/// Element budget for split-K partial sums written to `splitk_scratch_ptr`.
/// The buffer is allocated elsewhere, presumably with at least this many f32 values.
pub(crate) const SPLITK_SCRATCH_CAP: usize = 8 * 1024 * 1024;
/// SM count assumed by the occupancy ("underfill") heuristics.
/// It is hard-coded and presumably matches the target GPU.
pub(crate) const NUM_SMS: u32 = 142;
/// Picks the slim or the big variant of a tiled SGEMM kernel.
///
/// Returns the chosen function and its N-tile width:
/// * `64` for the slim variant;
/// * `128` for the big variant.
///
/// The slim variant is chosen when either condition holds:
/// * `n_out <= SGEMM_SLIM_MAX`;
/// * `m < SGEMM_M_SLIM_FORCE` and `n_out >= SGEMM_CUSTOM_MIN`.
///
/// `_kernels` is currently unused.
fn dispatch_slim_or_big<'k>(
    _kernels: &'k GpuKernels,
    m: usize,
    n_out: usize,
    func_slim: &'k cudarc::driver::CudaFunction,
    func_big: &'k cudarc::driver::CudaFunction,
) -> (&'k cudarc::driver::CudaFunction, u32) {
    let narrow_n = n_out <= SGEMM_SLIM_MAX;
    let short_m = m < SGEMM_M_SLIM_FORCE && n_out >= SGEMM_CUSTOM_MIN;
    if narrow_n || short_m {
        (func_slim, 64)
    } else {
        (func_big, 128)
    }
}
pub fn sgemm_bi_forward(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
y_ptr: CUptr,
x_ptr: CUptr,
w_ptr: CUptr,
bias_ptr: CUptr, dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
if (1..32).contains(&batch) && (32..=2048).contains(&n_in) && n_out >= 32 {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((n_out as u32).div_ceil(32), batch as u32, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: (n_in * std::mem::size_of::<f32>()) as u32,
};
let mut builder = stream.launch_builder(&kernels.sgemm_nn_ultra_thin);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&n_i);
builder.arg(&k_i);
builder.arg(&k_i); builder.arg(&n_i); builder.arg(&n_i); unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_ultra_thin forward: {:?}", e)))?;
return Ok(());
}
if (2..=127).contains(&n_out) && (1..=64).contains(&batch) && n_in >= 1 {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let post_op: i32 = 0;
let num_pid_m = (batch as u32).div_ceil(16);
let num_pid_n = (n_out as u32).div_ceil(16);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n, 1, 1),
block_dim: (64, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_nn_narrow_small);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&n_i);
builder.arg(&k_i);
builder.arg(&k_i);
builder.arg(&n_i);
builder.arg(&n_i);
builder.arg(&post_op);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_narrow_small forward: {:?}", e)))?;
return Ok(());
}
if (2..=127).contains(&n_out) && batch >= 1 && n_in >= 1 {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let post_op: i32 = 0;
let num_pid_m = (batch as u32).div_ceil(64);
let num_pid_n = (n_out as u32).div_ceil(32);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_nn_narrow);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&n_i);
builder.arg(&k_i);
builder.arg(&k_i);
builder.arg(&n_i);
builder.arg(&n_i);
builder.arg(&post_op);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_narrow forward: {:?}", e)))?;
return Ok(());
}
if n_out == 1 && batch >= 1 && n_in >= 32 {
let m_i = batch as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let lda_i = n_in as i32;
let ldy_i: i32 = 1;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((batch as u32).div_ceil(4), 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_nn_gemv);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&k_i);
builder.arg(&lda_i);
builder.arg(&ldy_i);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_gemv forward: {:?}", e)))?;
return Ok(());
}
let plain_slim_blocks_nn_ktail = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(64);
let underfill_nn_ktail = plain_slim_blocks_nn_ktail < NUM_SMS;
if (32..=1024).contains(&batch)
&& (64..=2048).contains(&n_out)
&& n_out.is_multiple_of(4)
&& n_in >= 33
&& !n_in.is_multiple_of(32)
&& underfill_nn_ktail
{
let k_tail = n_in % 32;
let k_main = n_in - k_tail;
let partial_size = (k_main / 32) * batch * n_out;
if k_main >= 32 && partial_size <= SPLITK_SCRATCH_CAP {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_chunks = (k_main / 32) as i32;
let lda_i = n_in as i32; let alpha: f32 = 1.0;
let num_pid_m = (batch as u32).div_ceil(32);
let num_pid_n = (n_out as u32).div_ceil(64);
let partial_cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n * k_chunks as u32, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let partial_ptr = kernels.splitk_scratch_ptr;
let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk32_partial);
pb.arg(&partial_ptr);
pb.arg(&x_ptr);
pb.arg(&w_ptr);
pb.arg(&m_i);
pb.arg(&n_i);
pb.arg(&k_chunks);
pb.arg(&lda_i);
unsafe { pb.launch(partial_cfg) }.map_err(|e| {
Error::Cuda(format!(
"sgemm_bi_nn_splitk32_partial (K-tail main): {:?}",
e
))
})?;
let total = (batch * n_out) as u32;
let reduce_cfg = cudarc::driver::LaunchConfig {
grid_dim: (total.div_ceil(256), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let zero_i32: i32 = 0;
let tail_cnt_i = k_tail as i32;
let x_base_ptr = x_ptr;
let x_tail_ptr: u64 = x_base_ptr + (k_main as u64) * 4; let w_tail_ptr: u64 = w_ptr + ((k_main * n_out) as u64) * 4; let x_tail_stride_i = n_in as i32; let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
rb.arg(&y_ptr);
rb.arg(&partial_ptr);
rb.arg(&bias_ptr);
rb.arg(&x_tail_ptr);
rb.arg(&w_tail_ptr);
rb.arg(&alpha);
rb.arg(&m_i);
rb.arg(&n_i);
rb.arg(&k_chunks);
rb.arg(&x_tail_stride_i);
rb.arg(&zero_i32); rb.arg(&tail_cnt_i);
unsafe { rb.launch(reduce_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_splitk_reduce (K-tail): {:?}", e)))?;
return Ok(());
}
}
let partial_size = (n_in / 32) * batch * n_out;
let plain_slim_blocks_nn_main = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(64);
let underfill_nn_main = plain_slim_blocks_nn_main < NUM_SMS;
if (32..=1024).contains(&batch)
&& (64..=2048).contains(&n_out)
&& n_out.is_multiple_of(4)
&& n_in >= 32
&& n_in.is_multiple_of(32)
&& partial_size <= SPLITK_SCRATCH_CAP
&& underfill_nn_main
{
let m_i = batch as i32;
let n_i = n_out as i32;
let k_chunks = (n_in / 32) as i32;
let alpha: f32 = 1.0;
let num_pid_m = (batch as u32).div_ceil(32);
let num_pid_n = (n_out as u32).div_ceil(64);
let partial_cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n * k_chunks as u32, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let partial_ptr = {
kernels.splitk_scratch_ptr
};
let lda_i = n_in as i32; let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk32_partial);
pb.arg(&partial_ptr);
pb.arg(&x_ptr);
pb.arg(&w_ptr);
pb.arg(&m_i);
pb.arg(&n_i);
pb.arg(&k_chunks);
pb.arg(&lda_i);
unsafe { pb.launch(partial_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_splitk32_partial: {:?}", e)))?;
let total = (batch * n_out) as u32;
let reduce_cfg = cudarc::driver::LaunchConfig {
grid_dim: (total.div_ceil(256), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let null_tail: u64 = 0;
let zero_i32: i32 = 0;
let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
rb.arg(&y_ptr);
rb.arg(&partial_ptr);
rb.arg(&bias_ptr);
rb.arg(&null_tail); rb.arg(&null_tail); rb.arg(&alpha);
rb.arg(&m_i);
rb.arg(&n_i);
rb.arg(&k_chunks);
rb.arg(&zero_i32); rb.arg(&zero_i32); rb.arg(&zero_i32); unsafe { rb.launch(reduce_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_splitk_reduce: {:?}", e)))?;
return Ok(());
}
const SPLITK_SLIM_K_CHUNK: u32 = 64; if batch > 1024
&& (128..=SGEMM_SLIM_MAX).contains(&n_out)
&& n_in >= SPLITK_SLIM_K_CHUNK as usize && n_in.is_multiple_of(32)
{
let f_final = (n_in as u32).div_ceil(SPLITK_SLIM_K_CHUNK);
if f_final >= 6 && (f_final as usize) * batch * n_out <= SPLITK_SCRATCH_CAP {
let m_tiles = (batch as u32).div_ceil(128);
let n_tiles = (n_out as u32).div_ceil(64);
let base_blocks = m_tiles * n_tiles;
if base_blocks > 0 && base_blocks < 3 * NUM_SMS {
let k_chunk = SPLITK_SLIM_K_CHUNK;
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let lda_i = n_in as i32; let ldb_i = n_out as i32; let k_chunk_i = k_chunk as i32;
let alpha: f32 = 1.0;
let partial_ptr = kernels.splitk_scratch_ptr;
let partial_cfg = cudarc::driver::LaunchConfig {
grid_dim: (base_blocks, 1, f_final),
block_dim: (128, 1, 1), shared_mem_bytes: 0, };
let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk_slim_partial);
pb.arg(&partial_ptr);
pb.arg(&x_ptr);
pb.arg(&w_ptr);
pb.arg(&m_i);
pb.arg(&n_i);
pb.arg(&k_i);
pb.arg(&lda_i);
pb.arg(&ldb_i);
pb.arg(&k_chunk_i);
unsafe { pb.launch(partial_cfg) }.map_err(|e| {
Error::Cuda(format!("sgemm_bi_nn_splitk_slim_partial: {:?}", e))
})?;
let total = (batch * n_out) as u32;
let reduce_cfg = cudarc::driver::LaunchConfig {
grid_dim: (total.div_ceil(256), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let null_tail: u64 = 0;
let zero_i32_local: i32 = 0;
let f_i = f_final as i32;
let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
rb.arg(&y_ptr);
rb.arg(&partial_ptr);
rb.arg(&bias_ptr);
rb.arg(&null_tail); rb.arg(&null_tail); rb.arg(&alpha);
rb.arg(&m_i);
rb.arg(&n_i);
rb.arg(&f_i); rb.arg(&zero_i32_local); rb.arg(&zero_i32_local); rb.arg(&zero_i32_local); unsafe { rb.launch(reduce_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_splitk_reduce (slim): {:?}", e)))?;
return Ok(());
}
}
}
if batch < 128 && n_out >= 128 && n_in >= 1 {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let post_op: i32 = 0;
let num_pid_m = (batch as u32).div_ceil(64);
let num_pid_n = (n_out as u32).div_ceil(32);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_nn_narrow);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&n_i);
builder.arg(&k_i);
builder.arg(&k_i);
builder.arg(&n_i);
builder.arg(&n_i);
builder.arg(&post_op);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Cuda(format!(
"sgemm_bi_nn_narrow (gap-fill thin-M wide-N): {:?}",
e
))
})?;
return Ok(());
}
if batch >= SGEMM_CUSTOM_MIN && n_out >= SGEMM_CUSTOM_MIN && n_in >= 1 {
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let (func, bn) = dispatch_slim_or_big(
kernels,
batch,
n_out,
&kernels.sgemm_nn_slim,
&kernels.sgemm_nn,
);
let slim = bn == 64;
let threads = if slim { 128u32 } else { 256u32 };
let smem_bytes: u32 = if slim { 0 } else { 34 * 1024 };
let total_tiles = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(bn);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (total_tiles, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: smem_bytes,
};
let mut builder = stream.launch_builder(func);
builder.arg(&y_ptr);
builder.arg(&x_ptr);
builder.arg(&w_ptr);
builder.arg(&bias_ptr);
builder.arg(&alpha);
builder.arg(&beta);
builder.arg(&m_i);
builder.arg(&n_i);
builder.arg(&k_i);
builder.arg(&k_i); builder.arg(&n_i); builder.arg(&n_i); unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Cuda(format!(
"sgemm_bi_nn{} forward: {:?}",
if slim { "_slim" } else { "" },
e
))
})?;
return Ok(());
}
Err(Error::Uncovered {
op: "sgemm_bi_forward (f32)",
m: batch,
k: n_in,
n: n_out,
})
}
pub fn sgemm_bi_backward_dw(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
dw_ptr: CUptr, dy_ptr: CUptr,
x_saved_ptr: CUptr,
dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
if n_out == 1 && n_in >= 4 && batch >= 32 {
let m_i = batch as i32;
let k_i = n_in as i32;
let alpha: f32 = 1.0;
let lda_i = n_in as i32;
let ldy_i: i32 = 1;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((n_in as u32).div_ceil(4), 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_tn_gemv);
builder.arg(&dw_ptr);
builder.arg(&x_saved_ptr);
builder.arg(&dy_ptr);
builder.arg(&alpha);
builder.arg(&m_i);
builder.arg(&k_i);
builder.arg(&lda_i);
builder.arg(&ldy_i);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_tn_gemv backward_dw: {:?}", e)))?;
return Ok(());
}
if (2..=127).contains(&n_out) && n_in >= 1 && batch >= 1 {
let m_i = batch as i32;
let k_i = n_in as i32;
let n_i = n_out as i32;
let alpha: f32 = 1.0;
let num_pid_m = (n_in as u32).div_ceil(64);
let num_pid_n = (n_out as u32).div_ceil(32);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (num_pid_m * num_pid_n, 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut builder = stream.launch_builder(&kernels.sgemm_tn_narrow);
builder.arg(&dw_ptr);
builder.arg(&x_saved_ptr);
builder.arg(&dy_ptr);
builder.arg(&alpha);
builder.arg(&m_i);
builder.arg(&k_i);
builder.arg(&n_i);
unsafe { builder.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_tn_narrow backward_dw: {:?}", e)))?;
return Ok(());
}
if let Some((m_chunk, f_final)) = splitm_tn_partition(batch, n_in, n_out) {
let base_blocks = (n_in as u32).div_ceil(128) * (n_out as u32).div_ceil(128);
let m_i = batch as i32;
let k_i = n_in as i32;
let n_i = n_out as i32;
let m_chunk_i = m_chunk as i32;
let alpha: f32 = 1.0;
let f_i = f_final as i32;
let f_final_u32 = f_final as u32;
let partial_ptr = kernels.splitk_scratch_ptr;
let partial_cfg = cudarc::driver::LaunchConfig {
grid_dim: (base_blocks, 1, f_final_u32),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let mut pb = stream.launch_builder(&kernels.sgemm_tn_splitm_partial);
pb.arg(&partial_ptr);
pb.arg(&x_saved_ptr);
pb.arg(&dy_ptr);
pb.arg(&m_i);
pb.arg(&k_i);
pb.arg(&n_i);
pb.arg(&m_chunk_i);
unsafe { pb.launch(partial_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_tn_splitm_partial: {:?}", e)))?;
let total = (n_in * n_out) as u32;
let reduce_cfg = cudarc::driver::LaunchConfig {
grid_dim: (total.div_ceil(256), 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
let mut rb = stream.launch_builder(&kernels.sgemm_splitm_reduce);
rb.arg(&dw_ptr);
rb.arg(&partial_ptr);
rb.arg(&alpha);
rb.arg(&k_i);
rb.arg(&n_i);
rb.arg(&f_i);
unsafe { rb.launch(reduce_cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_splitm_reduce: {:?}", e)))?;
return Ok(());
}
if n_in >= 1 && n_out >= SGEMM_CUSTOM_MIN {
let m_i = batch as i32;
let k_i = n_in as i32;
let n_i = n_out as i32;
let alpha: f32 = 1.0;
let (func, bn) = dispatch_slim_or_big(
kernels,
n_in, n_out,
&kernels.sgemm_tn_slim,
&kernels.sgemm_tn,
);
let slim = bn == 64;
let threads = if slim { 128u32 } else { 256u32 };
let smem_bytes: u32 = if slim { 0 } else { 34 * 1024 };
let total_tiles = (n_in as u32).div_ceil(128) * (n_out as u32).div_ceil(bn);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (total_tiles, 1, 1),
block_dim: (threads, 1, 1),
shared_mem_bytes: smem_bytes,
};
let mut builder = stream.launch_builder(func);
builder.arg(&dw_ptr);
builder.arg(&x_saved_ptr);
builder.arg(&dy_ptr);
builder.arg(&alpha);
builder.arg(&m_i);
builder.arg(&k_i);
builder.arg(&n_i);
unsafe { builder.launch(cfg) }.map_err(|e| {
Error::Cuda(format!(
"sgemm_bi_tn{} backward_dw: {:?}",
if slim { "_slim" } else { "" },
e
))
})?;
return Ok(());
}
Err(Error::Uncovered {
op: "sgemm_bi_backward_dw (f32)",
m: batch,
k: n_in,
n: n_out,
})
}
/// Input gradient of an f32 linear layer, presumably `dx[batch, n_in] = dy · wᵀ`.
///
/// `dims` is `(batch, n_in, n_out)`; `dy_ptr` is the upstream gradient and `w_ptr` is the
/// forward weight. The pointer arithmetic in the K-tail and N-tail paths implies these
/// layouts:
/// * `w` is row-major `[n_in, n_out]`;
/// * `dy` is row-major `[batch, n_out]`;
/// * `dx` is written with row stride `n_in`.
///
/// Several paths share the same two-step scheme:
/// * They first transpose `w` into `kernels.transpose_scratch_ptr`, bounded by
///   `SPLITK_NT_TRANSPOSE_CAP` elements.
/// * They then run an NN split-K kernel whose partial sums go to
///   `kernels.splitk_scratch_ptr`.
///
/// Device pointers are not validated here: the caller must pass buffers sized for `dims`.
///
/// # Errors
/// * [`Error::Cuda`] if a kernel launch fails.
/// * [`Error::Uncovered`] if no path accepts the shape, including `batch == 0`.
pub fn sgemm_bi_backward_dx(
    stream: &Arc<cudarc::driver::CudaStream>,
    kernels: &GpuKernels,
    dx_ptr: CUptr,
    dy_ptr: CUptr,
    w_ptr: CUptr,
    dims: (usize, usize, usize),
) -> Result<()> {
    // Element budget for the transposed copy of `w` written to `transpose_scratch_ptr`.
    const SPLITK_NT_TRANSPOSE_CAP: usize = 1 << 22;
    let (batch, n_in, n_out) = dims;

    // Narrow N (2..=127 output features): 64x32 dx tiles.
    if (2..=127).contains(&n_out) && n_in >= 1 && batch >= 1 {
        let m_i = batch as i32;
        let n_i = n_out as i32;
        let k_i = n_in as i32;
        let alpha: f32 = 1.0;
        let num_pid_m = (batch as u32).div_ceil(64);
        let num_pid_n = (n_in as u32).div_ceil(32);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (num_pid_m * num_pid_n, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&kernels.sgemm_nt_narrow);
        builder.arg(&dx_ptr);
        builder.arg(&dy_ptr);
        builder.arg(&w_ptr);
        builder.arg(&alpha);
        builder.arg(&m_i);
        builder.arg(&n_i);
        builder.arg(&k_i);
        unsafe { builder.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_nt_narrow backward_dx: {:?}", e)))?;
        return Ok(());
    }

    // Small batch, wide N, via the narrow kernel.
    // `batch >= 1` is required: with `batch == 0`, `num_pid_m` would be 0. The grid would
    // then have zero blocks, which the CUDA driver rejects as an invalid launch.
    if (1..32).contains(&batch) && n_in >= 1 && n_out >= 128 {
        let m_i = batch as i32;
        let n_i = n_out as i32;
        let k_i = n_in as i32;
        let alpha: f32 = 1.0;
        let num_pid_m = (batch as u32).div_ceil(64);
        let num_pid_n = (n_in as u32).div_ceil(32);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (num_pid_m * num_pid_n, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&kernels.sgemm_nt_narrow);
        builder.arg(&dx_ptr);
        builder.arg(&dy_ptr);
        builder.arg(&w_ptr);
        builder.arg(&alpha);
        builder.arg(&m_i);
        builder.arg(&n_i);
        builder.arg(&k_i);
        unsafe { builder.launch(cfg) }.map_err(|e| {
            Error::Cuda(format!(
                "sgemm_bi_nt_narrow (small-batch wide-N) backward_dx: {:?}",
                e
            ))
        })?;
        return Ok(());
    }

    // Single output feature: the grid covers batch * n_in threads (one per dx element).
    if n_out == 1 && n_in >= 1 && batch >= 1 {
        let m_i = batch as i32;
        let k_i = n_in as i32;
        let alpha: f32 = 1.0;
        let ldx_i = n_in as i32;
        let ldy_i: i32 = 1;
        let total = (batch * n_in) as u32;
        let block = 256u32;
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total.div_ceil(block), 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&kernels.sgemm_nt_gemv);
        builder.arg(&dx_ptr);
        builder.arg(&dy_ptr);
        builder.arg(&w_ptr);
        builder.arg(&alpha);
        builder.arg(&m_i);
        builder.arg(&k_i);
        builder.arg(&ldx_i);
        builder.arg(&ldy_i);
        unsafe { builder.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_nt_gemv backward_dx: {:?}", e)))?;
        return Ok(());
    }

    // n_in not 32-aligned, computed in three steps:
    // 1. Transpose the first k_main rows of w.
    // 2. Split-K GEMM + reduce into dx[:, 0..k_main].
    // 3. One column GEMV per tail column of dx.
    let plain_slim_blocks_nt_ktail = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(64);
    let underfill_nt_ktail = plain_slim_blocks_nt_ktail < NUM_SMS;
    if (32..=1024).contains(&batch)
        && (64..=4096).contains(&n_in)
        && n_in >= 33
        && !n_in.is_multiple_of(32)
        && (32..=2048).contains(&n_out)
        && n_out.is_multiple_of(32)
        && underfill_nt_ktail
    {
        let k_tail_cnt = n_in % 32;
        let k_main = n_in - k_tail_cnt;
        let w_size_main = k_main * n_out;
        let partial_size_main = (n_out / 32) * batch * k_main;
        if k_main >= 32
            && w_size_main <= SPLITK_NT_TRANSPOSE_CAP
            && partial_size_main <= SPLITK_SCRATCH_CAP
        {
            let rows_i = k_main as i32;
            let cols_i = n_out as i32;
            let t_grid_x = (n_out as u32).div_ceil(32);
            let t_grid_y = (k_main as u32).div_ceil(32);
            let t_cfg = cudarc::driver::LaunchConfig {
                grid_dim: (t_grid_x, t_grid_y, 1),
                block_dim: (32, 32, 1),
                shared_mem_bytes: 0,
            };
            let w_t_ptr = kernels.transpose_scratch_ptr;
            let mut tb = stream.launch_builder(&kernels.sgemm_transpose_f32_2d);
            tb.arg(&w_t_ptr);
            tb.arg(&w_ptr);
            tb.arg(&rows_i);
            tb.arg(&cols_i);
            unsafe { tb.launch(t_cfg) }
                .map_err(|e| Error::Cuda(format!("sgemm_transpose_f32_2d (K-tail): {:?}", e)))?;

            let m_i = batch as i32;
            let k_main_i = k_main as i32;
            let k_chunks = (n_out / 32) as i32;
            let lda_dy_i = n_out as i32;
            let num_pid_m = (batch as u32).div_ceil(32);
            let num_pid_n = (k_main as u32).div_ceil(64);
            let partial_cfg = cudarc::driver::LaunchConfig {
                grid_dim: (num_pid_m * num_pid_n * k_chunks as u32, 1, 1),
                block_dim: (128, 1, 1),
                shared_mem_bytes: 0,
            };
            let partial_ptr = kernels.splitk_scratch_ptr;
            let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk32_partial);
            pb.arg(&partial_ptr);
            pb.arg(&dy_ptr);
            pb.arg(&w_t_ptr);
            pb.arg(&m_i);
            pb.arg(&k_main_i);
            pb.arg(&k_chunks);
            pb.arg(&lda_dy_i);
            unsafe { pb.launch(partial_cfg) }.map_err(|e| {
                Error::Cuda(format!(
                    "sgemm_bi_nn_splitk32_partial (NT K-tail main): {:?}",
                    e
                ))
            })?;

            let null_tail: u64 = 0;
            let alpha: f32 = 1.0;
            let null_bias: u64 = 0;
            let zero_i32: i32 = 0;
            // dx rows are n_in wide even though only k_main columns are written here.
            let out_stride_i = n_in as i32;
            let total_main = (batch * k_main) as u32;
            let reduce_cfg = cudarc::driver::LaunchConfig {
                grid_dim: (total_main.div_ceil(256), 1, 1),
                block_dim: (256, 1, 1),
                shared_mem_bytes: 0,
            };
            let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
            rb.arg(&dx_ptr);
            rb.arg(&partial_ptr);
            rb.arg(&null_bias);
            rb.arg(&null_tail);
            rb.arg(&null_tail);
            rb.arg(&alpha);
            rb.arg(&m_i);
            rb.arg(&k_main_i);
            rb.arg(&k_chunks);
            rb.arg(&zero_i32);
            rb.arg(&out_stride_i);
            rb.arg(&zero_i32);
            unsafe { rb.launch(reduce_cfg) }.map_err(|e| {
                Error::Cuda(format!("sgemm_bi_splitk_reduce (NT K-tail main): {:?}", e))
            })?;

            let n_i = n_out as i32;
            let block = 128u32;
            let tail_cfg = cudarc::driver::LaunchConfig {
                grid_dim: ((batch as u32).div_ceil(block), 1, 1),
                block_dim: (block, 1, 1),
                shared_mem_bytes: 0,
            };
            for k in 0..k_tail_cnt {
                let k_tail_col = k_main + k;
                // Row `k_tail_col` of row-major w (n_out f32 values per row).
                let w_tail_row_ptr: u64 = w_ptr + (k_tail_col * n_out) as u64 * 4;
                let col_idx_i = k_tail_col as i32;
                let mut gb = stream.launch_builder(&kernels.sgemm_dx_col_gemv);
                gb.arg(&dx_ptr);
                gb.arg(&dy_ptr);
                gb.arg(&w_tail_row_ptr);
                gb.arg(&m_i);
                gb.arg(&n_i);
                gb.arg(&col_idx_i);
                gb.arg(&out_stride_i);
                unsafe { gb.launch(tail_cfg) }.map_err(|e| {
                    Error::Cuda(format!(
                        "sgemm_bi_dx_col_gemv (NT K-tail col={}): {:?}",
                        k, e
                    ))
                })?;
            }
            return Ok(());
        }
    }

    // Transpose + split-K on 32-aligned n_in. The 32-aligned part of n_out runs through
    // the split-K partial kernel. The reduce kernel receives the `n_out % 32` tail operands.
    let n_tail_nt = n_out % 32;
    let n_main_nt = n_out - n_tail_nt;
    let w_size_nt = n_in * n_out;
    let partial_size_nt = if n_main_nt > 0 {
        (n_main_nt / 32) * batch * n_in
    } else {
        0
    };
    let plain_slim_blocks_nt_main = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(64);
    let underfill_nt_main = plain_slim_blocks_nt_main < NUM_SMS;
    if (32..=1024).contains(&batch)
        && (64..=4096).contains(&n_in)
        && n_in.is_multiple_of(4)
        && n_in.is_multiple_of(32)
        && (32..=2048).contains(&n_out)
        && n_main_nt >= 32
        && w_size_nt <= SPLITK_NT_TRANSPOSE_CAP
        && partial_size_nt <= SPLITK_SCRATCH_CAP
        && underfill_nt_main
    {
        let rows_i = n_in as i32;
        let cols_i = n_out as i32;
        let t_grid_x = (n_out as u32).div_ceil(32);
        let t_grid_y = (n_in as u32).div_ceil(32);
        let t_cfg = cudarc::driver::LaunchConfig {
            grid_dim: (t_grid_x, t_grid_y, 1),
            block_dim: (32, 32, 1),
            shared_mem_bytes: 0,
        };
        let w_t_ptr = kernels.transpose_scratch_ptr;
        let mut tb = stream.launch_builder(&kernels.sgemm_transpose_f32_2d);
        tb.arg(&w_t_ptr);
        tb.arg(&w_ptr);
        tb.arg(&rows_i);
        tb.arg(&cols_i);
        unsafe { tb.launch(t_cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_transpose_f32_2d: {:?}", e)))?;

        let m_i = batch as i32;
        let k_out_i = n_in as i32;
        let k_chunks = (n_main_nt / 32) as i32;
        let num_pid_m = (batch as u32).div_ceil(32);
        let num_pid_n = (n_in as u32).div_ceil(64);
        let partial_cfg = cudarc::driver::LaunchConfig {
            grid_dim: (num_pid_m * num_pid_n * k_chunks as u32, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let partial_ptr = kernels.splitk_scratch_ptr;
        let lda_i = n_out as i32;
        let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk32_partial);
        pb.arg(&partial_ptr);
        pb.arg(&dy_ptr);
        pb.arg(&w_t_ptr);
        pb.arg(&m_i);
        pb.arg(&k_out_i);
        pb.arg(&k_chunks);
        pb.arg(&lda_i);
        unsafe { pb.launch(partial_cfg) }.map_err(|e| {
            Error::Cuda(format!(
                "sgemm_bi_nn_splitk32_partial (NT-via-T N-tail): {:?}",
                e
            ))
        })?;

        let alpha: f32 = 1.0;
        let null_bias: u64 = 0;
        let total = (batch * n_in) as u32;
        let reduce_cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total.div_ceil(256), 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        let zero_i32: i32 = 0;
        let tail_cnt_i = n_tail_nt as i32;
        let dy_tail_stride_i = n_out as i32;
        // Tail operands, or null pointers when there is no tail:
        // - dy_tail: columns [n_main_nt, n_out) of dy;
        // - wt_tail: rows [n_main_nt, n_out) of the transposed weight (row length n_in).
        let (dy_tail_ptr, wt_tail_ptr): (u64, u64) = if n_tail_nt > 0 {
            (
                dy_ptr + (n_main_nt as u64) * 4,
                w_t_ptr + (n_main_nt as u64 * n_in as u64) * 4,
            )
        } else {
            (0, 0)
        };
        let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
        rb.arg(&dx_ptr);
        rb.arg(&partial_ptr);
        rb.arg(&null_bias);
        rb.arg(&dy_tail_ptr);
        rb.arg(&wt_tail_ptr);
        rb.arg(&alpha);
        rb.arg(&m_i);
        rb.arg(&k_out_i);
        rb.arg(&k_chunks);
        rb.arg(&dy_tail_stride_i);
        rb.arg(&zero_i32);
        rb.arg(&tail_cnt_i);
        unsafe { rb.launch(reduce_cfg) }.map_err(|e| {
            Error::Cuda(format!("sgemm_bi_splitk_reduce (NT-via-T N-tail): {:?}", e))
        })?;
        return Ok(());
    }

    // Slim split-K NT for very tall inputs: transpose w, then split n_out into 64-wide
    // chunks, with one grid z-slice per chunk.
    const SLIM_NT_K_CHUNK: u32 = 64;
    if batch > 1024
        && (128..=SGEMM_SLIM_NT_NIN_MAX).contains(&n_in)
        && n_out >= SLIM_NT_K_CHUNK as usize
        && n_out.is_multiple_of(32)
        && (n_in * n_out) <= SPLITK_NT_TRANSPOSE_CAP
    {
        let f_final = (n_out as u32).div_ceil(SLIM_NT_K_CHUNK);
        if f_final >= 2 && (f_final as usize) * batch * n_in <= SPLITK_SCRATCH_CAP {
            let m_tiles = (batch as u32).div_ceil(128);
            let k_out_tiles = (n_in as u32).div_ceil(64);
            let base_blocks = m_tiles * k_out_tiles;
            if base_blocks > 0 && base_blocks < 3 * NUM_SMS {
                let k_chunk = SLIM_NT_K_CHUNK;
                let rows_i = n_in as i32;
                let cols_i = n_out as i32;
                let t_grid_x = (n_out as u32).div_ceil(32);
                let t_grid_y = (n_in as u32).div_ceil(32);
                let t_cfg = cudarc::driver::LaunchConfig {
                    grid_dim: (t_grid_x, t_grid_y, 1),
                    block_dim: (32, 32, 1),
                    shared_mem_bytes: 0,
                };
                let w_t_ptr = kernels.transpose_scratch_ptr;
                let mut tb = stream.launch_builder(&kernels.sgemm_transpose_f32_2d);
                tb.arg(&w_t_ptr);
                tb.arg(&w_ptr);
                tb.arg(&rows_i);
                tb.arg(&cols_i);
                unsafe { tb.launch(t_cfg) }.map_err(|e| {
                    Error::Cuda(format!("sgemm_transpose_f32_2d (slim NT): {:?}", e))
                })?;

                let m_i = batch as i32;
                let k_out_i = n_in as i32;
                let k_full_i = n_out as i32;
                let lda_i = n_out as i32;
                let ldb_i = n_in as i32;
                let k_chunk_i = k_chunk as i32;
                let partial_ptr = kernels.splitk_scratch_ptr;
                let partial_cfg = cudarc::driver::LaunchConfig {
                    grid_dim: (base_blocks, 1, f_final),
                    block_dim: (128, 1, 1),
                    shared_mem_bytes: 0,
                };
                let mut pb = stream.launch_builder(&kernels.sgemm_nn_splitk_slim_partial);
                pb.arg(&partial_ptr);
                pb.arg(&dy_ptr);
                pb.arg(&w_t_ptr);
                pb.arg(&m_i);
                pb.arg(&k_out_i);
                pb.arg(&k_full_i);
                pb.arg(&lda_i);
                pb.arg(&ldb_i);
                pb.arg(&k_chunk_i);
                unsafe { pb.launch(partial_cfg) }.map_err(|e| {
                    Error::Cuda(format!(
                        "sgemm_bi_nn_splitk_slim_partial (slim NT): {:?}",
                        e
                    ))
                })?;

                let alpha: f32 = 1.0;
                let null_bias: u64 = 0;
                let null_tail: u64 = 0;
                let zero_i32_nt: i32 = 0;
                let f_i = f_final as i32;
                let total = (batch * n_in) as u32;
                let reduce_cfg = cudarc::driver::LaunchConfig {
                    grid_dim: (total.div_ceil(256), 1, 1),
                    block_dim: (256, 1, 1),
                    shared_mem_bytes: 0,
                };
                let mut rb = stream.launch_builder(&kernels.sgemm_splitk_reduce);
                rb.arg(&dx_ptr);
                rb.arg(&partial_ptr);
                rb.arg(&null_bias);
                rb.arg(&null_tail);
                rb.arg(&null_tail);
                rb.arg(&alpha);
                rb.arg(&m_i);
                rb.arg(&k_out_i);
                rb.arg(&f_i);
                rb.arg(&zero_i32_nt);
                rb.arg(&zero_i32_nt);
                rb.arg(&zero_i32_nt);
                unsafe { rb.launch(reduce_cfg) }.map_err(|e| {
                    Error::Cuda(format!("sgemm_bi_splitk_reduce (slim NT): {:?}", e))
                })?;
                return Ok(());
            }
        }
    }

    // Mid-batch, wide-N gap-fill via the narrow kernel.
    if (32..128).contains(&batch) && n_in >= 1 && n_out >= 128 {
        let m_i = batch as i32;
        let n_i = n_out as i32;
        let k_i = n_in as i32;
        let alpha: f32 = 1.0;
        let num_pid_m = (batch as u32).div_ceil(64);
        let num_pid_n = (n_in as u32).div_ceil(32);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (num_pid_m * num_pid_n, 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut builder = stream.launch_builder(&kernels.sgemm_nt_narrow);
        builder.arg(&dx_ptr);
        builder.arg(&dy_ptr);
        builder.arg(&w_ptr);
        builder.arg(&alpha);
        builder.arg(&m_i);
        builder.arg(&n_i);
        builder.arg(&k_i);
        unsafe { builder.launch(cfg) }.map_err(|e| {
            Error::Cuda(format!(
                "sgemm_bi_nt_narrow (gap-fill mid-batch wide-N): {:?}",
                e
            ))
        })?;
        return Ok(());
    }

    // Generic tiled path; the N role of the slim/big choice is played by n_in.
    if batch >= SGEMM_CUSTOM_MIN && n_in >= 1 {
        let m_i = batch as i32;
        let n_i = n_out as i32;
        let k_i = n_in as i32;
        let alpha: f32 = 1.0;
        let (func, bn) = dispatch_slim_or_big(
            kernels,
            batch,
            n_in,
            &kernels.sgemm_nt_slim,
            &kernels.sgemm_nt,
        );
        let slim = bn == 64;
        let threads = if slim { 128u32 } else { 256u32 };
        let smem_bytes: u32 = if slim { 0 } else { 34 * 1024 };
        let total_tiles = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(bn);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total_tiles, 1, 1),
            block_dim: (threads, 1, 1),
            shared_mem_bytes: smem_bytes,
        };
        let mut builder = stream.launch_builder(func);
        builder.arg(&dx_ptr);
        builder.arg(&dy_ptr);
        builder.arg(&w_ptr);
        builder.arg(&alpha);
        builder.arg(&m_i);
        builder.arg(&n_i);
        builder.arg(&k_i);
        unsafe { builder.launch(cfg) }.map_err(|e| {
            Error::Cuda(format!(
                "sgemm_bi_nt{} backward_dx: {:?}",
                if slim { "_slim" } else { "" },
                e
            ))
        })?;
        return Ok(());
    }

    Err(Error::Uncovered {
        op: "sgemm_bi_backward_dx (f32)",
        m: batch,
        k: n_in,
        n: n_out,
    })
}
/// Reports whether `sgemm_bi_forward` would dispatch `(batch, n_in, n_out)` to the big
/// (non-slim) NN kernel. In the forward dispatcher that variant uses a 128-wide N tile,
/// 256 threads and 34 KiB of dynamic shared memory.
///
/// The checks mirror the forward dispatcher's branch order and must stay in sync with it.
/// Any shape that an earlier special-case path takes, or that the dispatcher rejects as
/// uncovered, yields `false`.
pub(crate) fn nn_routes_to_big(batch: usize, n_in: usize, n_out: usize) -> bool {
    // GEMV path (or uncovered when n_in < 32).
    if n_out == 1 {
        return false;
    }
    // Ultra-thin path.
    if (1..32).contains(&batch) && (32..=2048).contains(&n_in) && n_out >= 32 {
        return false;
    }
    // Narrow / narrow-small paths.
    if (2..=127).contains(&n_out) {
        return false;
    }
    let plain_slim_blocks = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(64);
    let underfill = plain_slim_blocks < NUM_SMS;
    // Split-K with a K tail.
    if (32..=1024).contains(&batch)
        && (64..=2048).contains(&n_out)
        && n_out.is_multiple_of(4)
        && n_in >= 33
        && !n_in.is_multiple_of(32)
        && underfill
    {
        let k_main = n_in - n_in % 32;
        if k_main >= 32 && (k_main / 32) * batch * n_out <= SPLITK_SCRATCH_CAP {
            return false;
        }
    }
    // Split-K on 32-aligned K.
    if (32..=1024).contains(&batch)
        && (64..=2048).contains(&n_out)
        && n_out.is_multiple_of(4)
        && n_in >= 32
        && n_in.is_multiple_of(32)
        && (n_in / 32) * batch * n_out <= SPLITK_SCRATCH_CAP
        && underfill
    {
        return false;
    }
    // Slim split-K for very tall inputs (64-wide K chunks).
    if batch > 1024
        && (128..=SGEMM_SLIM_MAX).contains(&n_out)
        && n_in >= 64
        && n_in.is_multiple_of(32)
    {
        let f_final = (n_in as u32).div_ceil(64);
        if f_final >= 6 && (f_final as usize) * batch * n_out <= SPLITK_SCRATCH_CAP {
            let base_blocks = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(64);
            if base_blocks > 0 && base_blocks < 3 * NUM_SMS {
                return false;
            }
        }
    }
    // Thin-M gap-fill (narrow kernel) or uncovered.
    if batch < 128 {
        return false;
    }
    // Generic tiled path. The forward dispatcher also requires `n_in >= 1` here, so a
    // zero-K shape is uncovered there and must not be reported as big.
    if !(batch >= SGEMM_CUSTOM_MIN && n_out >= SGEMM_CUSTOM_MIN && n_in >= 1) {
        return false;
    }
    // Same slim test as `dispatch_slim_or_big(batch, n_out)`.
    let slim = n_out <= SGEMM_SLIM_MAX || (batch < SGEMM_M_SLIM_FORCE && n_out >= SGEMM_CUSTOM_MIN);
    !slim
}
/// Reports whether `sgemm_bi_backward_dw` would dispatch `(batch, n_in, n_out)` to the
/// big (non-slim) TN kernel.
///
/// The checks mirror the dW dispatcher's branch order:
/// * GEMV/narrow shapes return `false`;
/// * shapes accepted by `splitm_tn_partition` return `false`;
/// * shapes outside the tiled path's `n_in >= 1`, `n_out >= SGEMM_CUSTOM_MIN` guard
///   return `false`.
pub(crate) fn tn_routes_to_big(batch: usize, n_in: usize, n_out: usize) -> bool {
    // n_out == 1 and 2..=127 go to the GEMV/narrow kernels (or are rejected), never big.
    let gemv_or_narrow = n_out == 1 || (2..=127).contains(&n_out);
    if gemv_or_narrow || splitm_tn_partition(batch, n_in, n_out).is_some() {
        return false;
    }
    if n_in == 0 || n_out < SGEMM_CUSTOM_MIN {
        return false;
    }
    // The slim/big decision of `dispatch_slim_or_big`, with n_in in the M role.
    n_out > SGEMM_SLIM_MAX && !(n_in < SGEMM_M_SLIM_FORCE && n_out >= SGEMM_CUSTOM_MIN)
}
/// Reports whether `sgemm_bi_backward_dx` would dispatch `(batch, n_in, n_out)` to the
/// big (non-slim) NT kernel.
///
/// The checks mirror the dX dispatcher's branch order. Every earlier special-case path
/// yields `false`: the narrow kernels, GEMV, the three split-K variants and the
/// mid-batch gap-fill.
pub(crate) fn nt_routes_to_big(batch: usize, n_in: usize, n_out: usize) -> bool {
    const SPLITK_NT_TRANSPOSE_CAP: usize = 1 << 22;

    // Narrow N, small-batch wide-N, and the single-output GEMV.
    if (2..=127).contains(&n_out) || (batch < 32 && n_out >= 128) || n_out == 1 {
        return false;
    }

    let slim_tile_blocks = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(64);
    let underfilled = slim_tile_blocks < NUM_SMS;
    let mid_batch = (32..=1024).contains(&batch);
    let n_in_in_range = (64..=4096).contains(&n_in);
    let n_out_in_range = (32..=2048).contains(&n_out);

    // K-tail split-K path (n_in not a multiple of 32).
    if mid_batch
        && n_in_in_range
        && !n_in.is_multiple_of(32)
        && n_out_in_range
        && n_out.is_multiple_of(32)
        && underfilled
    {
        let k_main = n_in - n_in % 32;
        if k_main >= 32
            && k_main * n_out <= SPLITK_NT_TRANSPOSE_CAP
            && (n_out / 32) * batch * k_main <= SPLITK_SCRATCH_CAP
        {
            return false;
        }
    }

    // Transpose + split-K path on 32-aligned n_in (any N tail handled by the reduce).
    let n_main = n_out - n_out % 32;
    if mid_batch
        && n_in_in_range
        && n_in.is_multiple_of(32)
        && n_out_in_range
        && n_main >= 32
        && n_in * n_out <= SPLITK_NT_TRANSPOSE_CAP
        && (n_main / 32) * batch * n_in <= SPLITK_SCRATCH_CAP
        && underfilled
    {
        return false;
    }

    // Slim split-K NT path for very tall inputs.
    if batch > 1024
        && (128..=SGEMM_SLIM_NT_NIN_MAX).contains(&n_in)
        && n_out >= 64
        && n_out.is_multiple_of(32)
        && n_in * n_out <= SPLITK_NT_TRANSPOSE_CAP
    {
        let splits = (n_out as u32).div_ceil(64);
        if splits >= 2 && (splits as usize) * batch * n_in <= SPLITK_SCRATCH_CAP {
            let tiles = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(64);
            if tiles > 0 && tiles < 3 * NUM_SMS {
                return false;
            }
        }
    }

    // The mid-batch gap-fill takes 32..128; the tiled path needs batch >= 128 and n_in >= 1.
    if (32..128).contains(&batch) || batch < SGEMM_CUSTOM_MIN || n_in == 0 {
        return false;
    }
    // Big unless `dispatch_slim_or_big(batch, n_in)` would pick the slim kernel.
    n_in > SGEMM_SLIM_MAX && !(batch < SGEMM_M_SLIM_FORCE && n_in >= SGEMM_CUSTOM_MIN)
}
fn require_half(dt: WeightDtype, what: &str) -> Result<()> {
if dt == WeightDtype::F32 {
let _ = what;
return Err(Error::DtypeMismatch(
"operand is f32 — use the f32 entry points",
));
}
Ok(())
}
/// Square output-tile size used by the tensor-core (`*_tc`) GEMM launchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TcTile {
    /// 128×128 output tile.
    Tile128,
    /// 64×64 output tile.
    Tile64,
}

/// Number of 128×128 tiles at or above which `TcTile::Tile128` is chosen. Below it the 64×64
/// tile is used (presumably so that small outputs still launch enough blocks — confirm).
const TC64_PREFER_MAX_TILES128: u32 = 72;

/// Picks the tensor-core tile for a `rows × cols` output.
///
/// Returns `None` when either side is below 64, because no tile fits. Returns `Tile64` when
/// either side is below 128. Otherwise returns `Tile128` if the output spans at least
/// `TC64_PREFER_MAX_TILES128` 128×128 tiles, and `Tile64` if not.
fn tc_pick_tile(rows: usize, cols: usize) -> Option<TcTile> {
    if rows < 64 || cols < 64 {
        return None;
    }
    if rows < 128 || cols < 128 {
        return Some(TcTile::Tile64);
    }
    let tiles128 = (rows as u32).div_ceil(128) * (cols as u32).div_ceil(128);
    let tile = if tiles128 < TC64_PREFER_MAX_TILES128 {
        TcTile::Tile64
    } else {
        TcTile::Tile128
    };
    Some(tile)
}
impl TcTile {
    /// Edge length of the square output tile each thread block covers.
    fn edge(self) -> u32 {
        match self {
            Self::Tile128 => 128,
            Self::Tile64 => 64,
        }
    }

    /// Launch configuration for a `rows × cols` output.
    ///
    /// The grid is one-dimensional with `ceil(rows / edge) * ceil(cols / edge)` blocks of
    /// `2 * edge` threads. Only the 128 tile receives `dyn_bytes128` bytes of dynamic shared
    /// memory; the 64 tile receives none.
    fn launch_cfg(
        self,
        rows: usize,
        cols: usize,
        dyn_bytes128: u32,
    ) -> cudarc::driver::LaunchConfig {
        let edge = self.edge();
        let tiles_down = (rows as u32).div_ceil(edge);
        let tiles_across = (cols as u32).div_ceil(edge);
        let shared_mem_bytes = match self {
            Self::Tile128 => dyn_bytes128,
            Self::Tile64 => 0,
        };
        cudarc::driver::LaunchConfig {
            grid_dim: (tiles_down * tiles_across, 1, 1),
            block_dim: (2 * edge, 1, 1),
            shared_mem_bytes,
        }
    }
}
/// Bias-linear forward GEMM on the tensor-core (`*_tc`) kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. Presumably computes `y = x · w + bias`, with `x` of
/// shape `batch × n_in`, `w` of shape `n_in × n_out` and `y` of shape `batch × n_out` —
/// confirm against the kernel.
///
/// The kernel receives `alpha = 1`, `beta = 0`, `m = batch`, `n = n_out` and `k = n_in`,
/// followed by the integers `n_in, n_out, n_out` (presumably the leading dimensions of x, w
/// and y).
///
/// The tile comes from `tc_pick_tile(batch, n_out)`. The 128-tile kernel is launched with
/// 71 680 bytes of dynamic shared memory; the 64-tile kernel gets none.
///
/// # Errors
/// - `Error::DtypeMismatch` if `y` is f32, or if `x` or `w` has a different dtype than `y`.
/// - `Error::Uncovered` if `batch < 64` or `n_out < 64` (no tensor-core tile fits).
/// - `Error::Cuda` if the launch fails.
pub fn sgemm_bi_forward_tc(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
y: TypedPtr,
x: TypedPtr,
w: TypedPtr,
bias_ptr: CUptr,
dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
require_half(y.dtype, "output")?;
if x.dtype != y.dtype || w.dtype != y.dtype {
return Err(Error::DtypeMismatch("sgemm_bi_forward_tc: mixed dtypes"));
}
// The tile is chosen from the output shape batch × n_out.
let Some(tile) = tc_pick_tile(batch, n_out) else {
return Err(Error::Uncovered {
op: "sgemm_bi_forward_tc",
m: batch,
k: n_in,
n: n_out,
});
};
let dt = y.dtype;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
// Dynamic shared memory (71 680 B) applies to the 128 tile only.
let cfg = tile.launch_cfg(batch, n_out, 71_680);
let func = match tile {
TcTile::Tile128 => kernels.sgemm_nn_tc_typed.get(dt),
TcTile::Tile64 => kernels.sgemm_nn_tc64_typed.get(dt),
};
// Argument order must match the kernel's parameter list exactly.
let mut b = stream.launch_builder(func);
b.arg(&y.ptr);
b.arg(&x.ptr);
b.arg(&w.ptr);
b.arg(&bias_ptr);
b.arg(&alpha);
b.arg(&beta);
b.arg(&m_i);
b.arg(&n_i);
b.arg(&k_i);
// Trailing integers: presumably lda = n_in, ldb = n_out, ldc = n_out — confirm.
b.arg(&k_i);
b.arg(&n_i);
b.arg(&n_i);
// SAFETY (review): relies on the argument list above matching the kernel signature, and on
// the caller's device pointers covering `dims`. Neither is checked here.
unsafe { b.launch(cfg) }.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_tc: {e:?}")))?;
Ok(())
}
/// Weight-gradient GEMM on the tensor-core (`*_tc`) TN kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. The launch grid tiles an `n_in × n_out` output.
/// Presumably this computes `dW = x_savedᵀ · dY`, reducing over `batch` — confirm against
/// the kernel. No `beta` is passed, so whether `dw_ptr` is overwritten or accumulated into is
/// decided by the kernel (TODO confirm). The dtype of `dw_ptr` is not checked here.
///
/// The tile comes from `tc_pick_tile(n_in, n_out)`. The 128-tile kernel is launched with
/// 69 632 bytes of dynamic shared memory; the 64-tile kernel gets none.
///
/// # Errors
/// - `Error::DtypeMismatch` if `dy` is f32 or `dy` and `x_saved` differ in dtype.
/// - `Error::Uncovered` if `n_in < 64` or `n_out < 64`.
/// - `Error::Cuda` if the launch fails.
pub fn sgemm_bi_backward_dw_tc(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
dw_ptr: CUptr,
dy: TypedPtr,
x_saved: TypedPtr,
dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
require_half(dy.dtype, "dY")?;
if dy.dtype != x_saved.dtype {
return Err(Error::DtypeMismatch(
"sgemm_bi_backward_dw_tc: mixed dtypes",
));
}
// The tile is chosen from the dW shape n_in × n_out; batch is the reduced dimension.
let Some(tile) = tc_pick_tile(n_in, n_out) else {
return Err(Error::Uncovered {
op: "sgemm_bi_backward_dw_tc",
m: batch,
k: n_in,
n: n_out,
});
};
let dt = dy.dtype;
let alpha: f32 = 1.0;
let m_red_i = batch as i32;
let k_out_i = n_in as i32;
let n_i = n_out as i32;
// Dynamic shared memory (69 632 B) applies to the 128 tile only.
let cfg = tile.launch_cfg(n_in, n_out, 69_632);
let func = match tile {
TcTile::Tile128 => kernels.sgemm_tn_tc_typed.get(dt),
TcTile::Tile64 => kernels.sgemm_tn_tc64_typed.get(dt),
};
// Argument order must match the kernel's parameter list exactly.
let mut b = stream.launch_builder(func);
b.arg(&dw_ptr);
b.arg(&x_saved.ptr);
b.arg(&dy.ptr);
b.arg(&alpha);
b.arg(&m_red_i);
b.arg(&k_out_i);
b.arg(&n_i);
// SAFETY (review): relies on the argument list above matching the kernel signature, and on
// the caller's device pointers covering `dims`. Neither is checked here.
unsafe { b.launch(cfg) }.map_err(|e| Error::Cuda(format!("sgemm_bi_tn_tc: {e:?}")))?;
Ok(())
}
/// Input-gradient GEMM on the tensor-core (`*_tc`) NT kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. The launch grid tiles a `batch × n_in` output.
/// Presumably this computes `dX = dY · Wᵀ`, reducing over `n_out` — confirm against the
/// kernel.
///
/// The tile comes from `tc_pick_tile(batch, n_in)`. The 128-tile kernel is launched with
/// 73 728 bytes of dynamic shared memory; the 64-tile kernel gets none.
///
/// # Errors
/// - `Error::DtypeMismatch` if `dx` is f32, or `dx`, `dy` and `w` do not share one dtype.
/// - `Error::Uncovered` if `batch < 64` or `n_in < 64`.
/// - `Error::Cuda` if the launch fails.
pub fn sgemm_bi_backward_dx_tc(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
dx: TypedPtr,
dy: TypedPtr,
w: TypedPtr,
dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
require_half(dx.dtype, "dX")?;
if dx.dtype != dy.dtype || dy.dtype != w.dtype {
return Err(Error::DtypeMismatch(
"sgemm_bi_backward_dx_tc: mixed dtypes",
));
}
// The tile is chosen from the dX shape batch × n_in.
let Some(tile) = tc_pick_tile(batch, n_in) else {
return Err(Error::Uncovered {
op: "sgemm_bi_backward_dx_tc",
m: batch,
k: n_in,
n: n_out,
});
};
let dt = dx.dtype;
let alpha: f32 = 1.0;
let m_i = batch as i32;
let n_i = n_out as i32;
let k_out_i = n_in as i32;
// Dynamic shared memory (73 728 B) applies to the 128 tile only.
let cfg = tile.launch_cfg(batch, n_in, 73_728);
let func = match tile {
TcTile::Tile128 => kernels.sgemm_nt_tc_typed.get(dt),
TcTile::Tile64 => kernels.sgemm_nt_tc64_typed.get(dt),
};
// Argument order must match the kernel's parameter list exactly.
let mut b = stream.launch_builder(func);
b.arg(&dx.ptr);
b.arg(&dy.ptr);
b.arg(&w.ptr);
b.arg(&alpha);
b.arg(&m_i);
b.arg(&n_i);
b.arg(&k_out_i);
// SAFETY (review): relies on the argument list above matching the kernel signature, and on
// the caller's device pointers covering `dims`. Neither is checked here.
unsafe { b.launch(cfg) }.map_err(|e| Error::Cuda(format!("sgemm_bi_nt_tc: {e:?}")))?;
Ok(())
}
/// Bias-linear forward GEMM on the non-tensor-core typed kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. Presumably computes `y = x · w + bias` — confirm against
/// the kernels. The first matching route wins:
/// 1. `n_out == 1`, `batch >= 1`, `n_in >= 32` → `sgemm_nn_gemv`.
/// 2. `1 <= batch < 32`, `32 <= n_in <= 2048`, `n_out >= 32` → `sgemm_nn_ultra_thin`.
/// 3. `2 <= n_out <= 127` → `sgemm_nn_narrow_small` (`batch <= 64`) or `sgemm_nn_narrow`.
/// 4. `nn_routes_to_big(batch, n_in, n_out)` → `sgemm_nn_big` (128×128-tiled grid).
///
/// # Errors
/// - `Error::DtypeMismatch` if `y` is f32, or if `x` or `w` has a different dtype than `y`.
/// - `Error::Uncovered` if no route matches.
/// - `Error::Cuda` if a launch fails.
pub fn sgemm_bi_forward_typed(
stream: &Arc<cudarc::driver::CudaStream>,
kernels: &GpuKernels,
y: TypedPtr,
x: TypedPtr,
w: TypedPtr,
bias_ptr: CUptr, dims: (usize, usize, usize),
) -> Result<()> {
let (batch, n_in, n_out) = dims;
require_half(y.dtype, "output")?;
if x.dtype != y.dtype || w.dtype != y.dtype {
return Err(Error::DtypeMismatch("sgemm_bi_forward_typed: mixed dtypes"));
}
// Scalars shared by every route below.
let dt = y.dtype;
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
let m_i = batch as i32;
let n_i = n_out as i32;
let k_i = n_in as i32;
// Route 1: single output column. Grid of ceil(batch / 4) blocks of 128 threads; the leading
// dimension of x is n_in and that of y is 1.
if n_out == 1 && batch >= 1 && n_in >= 32 {
let lda_i = n_in as i32;
let ldy_i: i32 = 1;
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((batch as u32).div_ceil(4), 1, 1),
block_dim: (128, 1, 1),
shared_mem_bytes: 0,
};
let mut b = stream.launch_builder(kernels.sgemm_nn_gemv_typed.get(dt));
b.arg(&y.ptr);
b.arg(&x.ptr);
b.arg(&w.ptr);
b.arg(&bias_ptr);
b.arg(&alpha);
b.arg(&beta);
b.arg(&m_i);
b.arg(&k_i);
b.arg(&lda_i);
b.arg(&ldy_i);
unsafe { b.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_gemv typed: {e:?}")))?;
return Ok(());
}
// Route 2: very small batch. The 2-D grid is (ceil(n_out / 32), batch) with 256 threads.
// Dynamic shared memory is one f32 per input feature (presumably an x row staged as f32 —
// confirm).
if (1..32).contains(&batch) && (32..=2048).contains(&n_in) && n_out >= 32 {
let cfg = cudarc::driver::LaunchConfig {
grid_dim: ((n_out as u32).div_ceil(32), batch as u32, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: (n_in * std::mem::size_of::<f32>()) as u32,
};
let mut b = stream.launch_builder(kernels.sgemm_nn_ultra_thin_typed.get(dt));
b.arg(&y.ptr);
b.arg(&x.ptr);
b.arg(&w.ptr);
b.arg(&bias_ptr);
b.arg(&alpha);
b.arg(&beta);
b.arg(&m_i);
b.arg(&n_i);
b.arg(&k_i);
b.arg(&k_i);
b.arg(&n_i);
b.arg(&n_i);
unsafe { b.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_ultra_thin typed: {e:?}")))?;
return Ok(());
}
// Route 3: narrow output (2..=127 columns).
// - batch <= 64: grid of ceil(batch / 16) * ceil(n_out / 16) blocks, 64 threads each.
// - otherwise: ceil(batch / 64) * ceil(n_out / 32) blocks, 128 threads each.
if (2..=127).contains(&n_out) && batch >= 1 && n_in >= 1 {
// Extra kernel flag; presumably 0 means "no fused epilogue" — confirm.
let post_op: i32 = 0;
let small = batch <= 64;
let (grid, block, func) = if small {
(
(batch as u32).div_ceil(16) * (n_out as u32).div_ceil(16),
64u32,
kernels.sgemm_nn_narrow_small_typed.get(dt),
)
} else {
(
(batch as u32).div_ceil(64) * (n_out as u32).div_ceil(32),
128u32,
kernels.sgemm_nn_narrow_typed.get(dt),
)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (grid, 1, 1),
block_dim: (block, 1, 1),
shared_mem_bytes: 0,
};
let mut b = stream.launch_builder(func);
b.arg(&y.ptr);
b.arg(&x.ptr);
b.arg(&w.ptr);
b.arg(&bias_ptr);
b.arg(&alpha);
b.arg(&beta);
b.arg(&m_i);
b.arg(&n_i);
b.arg(&k_i);
b.arg(&k_i);
b.arg(&n_i);
b.arg(&n_i);
b.arg(&post_op);
unsafe { b.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_narrow typed: {e:?}")))?;
return Ok(());
}
// Route 4: 128×128 tiles over batch × n_out, 256 threads each, 34 KiB of dynamic shared
// memory.
if nn_routes_to_big(batch, n_in, n_out) {
let total_tiles = (batch as u32).div_ceil(128) * (n_out as u32).div_ceil(128);
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (total_tiles, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 34 * 1024,
};
let mut b = stream.launch_builder(kernels.sgemm_nn_big_typed.get(dt));
b.arg(&y.ptr);
b.arg(&x.ptr);
b.arg(&w.ptr);
b.arg(&bias_ptr);
b.arg(&alpha);
b.arg(&beta);
b.arg(&m_i);
b.arg(&n_i);
b.arg(&k_i);
b.arg(&k_i);
b.arg(&n_i);
b.arg(&n_i);
unsafe { b.launch(cfg) }
.map_err(|e| Error::Cuda(format!("sgemm_bi_nn_big typed: {e:?}")))?;
return Ok(());
}
// No typed kernel covers this shape. The caller decides the fallback.
Err(Error::Uncovered {
op: "sgemm_bi_forward_typed",
m: batch,
k: n_in,
n: n_out,
})
}
/// Weight-gradient GEMM on the non-tensor-core typed TN kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. Presumably computes `dW = x_savedᵀ · dY`, reducing over
/// `batch` — confirm against the `sgemm_tn_*` kernels. The dtype of `dw_ptr` is not checked
/// here. The first matching route wins:
/// 1. `n_out == 1`, `n_in >= 4`, `batch >= 32` → `sgemm_tn_gemv`.
/// 2. `2 <= n_out <= 127` → `sgemm_tn_narrow`.
/// 3. `tn_routes_to_big(batch, n_in, n_out)` → `sgemm_tn_big` (128×128 tiles over
///    `n_in × n_out`).
///
/// # Errors
/// - `Error::DtypeMismatch` if `dy` is f32 or `dy` and `x_saved` differ in dtype.
/// - `Error::Uncovered` if no route matches.
/// - `Error::Cuda` if a launch fails.
pub fn sgemm_bi_backward_dw_typed(
    stream: &Arc<cudarc::driver::CudaStream>,
    kernels: &GpuKernels,
    dw_ptr: CUptr,
    dy: TypedPtr,
    x_saved: TypedPtr,
    dims: (usize, usize, usize),
) -> Result<()> {
    let (batch, n_in, n_out) = dims;
    require_half(dy.dtype, "dY")?;
    if x_saved.dtype != dy.dtype {
        return Err(Error::DtypeMismatch(
            "sgemm_bi_backward_dw_typed: mixed dtypes",
        ));
    }
    let dt = dy.dtype;
    // Kernel scalars shared by every route. They are hoisted once, as in
    // `sgemm_bi_forward_typed`, instead of being re-declared (and `alpha` re-shadowed) per
    // branch.
    let alpha: f32 = 1.0;
    let m_red_i = batch as i32;
    let k_out_i = n_in as i32;
    let n_i = n_out as i32;

    // Route 1: single output column. Grid of ceil(n_in / 4) blocks of 128 threads; the
    // leading dimension of x_saved is n_in and that of dY is 1.
    if n_out == 1 && n_in >= 4 && batch >= 32 {
        let lda_i = k_out_i;
        let ldy_i: i32 = 1;
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: ((n_in as u32).div_ceil(4), 1, 1),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut b = stream.launch_builder(kernels.sgemm_tn_gemv_typed.get(dt));
        b.arg(&dw_ptr);
        b.arg(&x_saved.ptr);
        b.arg(&dy.ptr);
        b.arg(&alpha);
        b.arg(&m_red_i);
        b.arg(&k_out_i);
        b.arg(&lda_i);
        b.arg(&ldy_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_tn_gemv typed: {e:?}")))?;
        return Ok(());
    }

    // Route 2: narrow output. Grid of ceil(n_in / 64) * ceil(n_out / 32) blocks, 128 threads
    // each.
    if (2..=127).contains(&n_out) && batch >= 1 && n_in >= 1 {
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (
                (n_in as u32).div_ceil(64) * (n_out as u32).div_ceil(32),
                1,
                1,
            ),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut b = stream.launch_builder(kernels.sgemm_tn_narrow_typed.get(dt));
        b.arg(&dw_ptr);
        b.arg(&x_saved.ptr);
        b.arg(&dy.ptr);
        b.arg(&alpha);
        b.arg(&m_red_i);
        b.arg(&k_out_i);
        b.arg(&n_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_tn_narrow typed: {e:?}")))?;
        return Ok(());
    }

    // Route 3: 128×128 tiles over n_in × n_out, 256 threads each, 34 KiB of dynamic shared
    // memory.
    if tn_routes_to_big(batch, n_in, n_out) {
        let total_tiles = (n_in as u32).div_ceil(128) * (n_out as u32).div_ceil(128);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total_tiles, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 34 * 1024,
        };
        let mut b = stream.launch_builder(kernels.sgemm_tn_big_typed.get(dt));
        b.arg(&dw_ptr);
        b.arg(&x_saved.ptr);
        b.arg(&dy.ptr);
        b.arg(&alpha);
        b.arg(&m_red_i);
        b.arg(&k_out_i);
        b.arg(&n_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_tn_big typed: {e:?}")))?;
        return Ok(());
    }

    // No typed kernel covers this shape. The caller decides the fallback.
    Err(Error::Uncovered {
        op: "sgemm_bi_backward_dw_typed",
        m: batch,
        k: n_in,
        n: n_out,
    })
}
/// Input-gradient GEMM on the non-tensor-core typed NT kernels for non-f32 operands.
///
/// `dims` is `(batch, n_in, n_out)`. Presumably computes `dX = dY · Wᵀ`, reducing over
/// `n_out` — confirm against the `sgemm_nt_*` kernels. The first matching route wins:
/// 1. `n_out == 1` (with `batch, n_in >= 1`) → `sgemm_nt_gemv`.
/// 2. `2 <= n_out <= 127` → `sgemm_nt_narrow`.
/// 3. `nt_routes_to_big(batch, n_in, n_out)` → `sgemm_nt_big` (128×128 tiles over
///    `batch × n_in`).
///
/// # Errors
/// - `Error::DtypeMismatch` if `dx` is f32, or `dx`, `dy` and `w` do not share one dtype.
/// - `Error::Uncovered` if no route matches.
/// - `Error::Cuda` if a launch fails.
pub fn sgemm_bi_backward_dx_typed(
    stream: &Arc<cudarc::driver::CudaStream>,
    kernels: &GpuKernels,
    dx: TypedPtr,
    dy: TypedPtr,
    w: TypedPtr,
    dims: (usize, usize, usize),
) -> Result<()> {
    let (batch, n_in, n_out) = dims;
    require_half(dx.dtype, "dX")?;
    if dy.dtype != dx.dtype || w.dtype != dx.dtype {
        return Err(Error::DtypeMismatch(
            "sgemm_bi_backward_dx_typed: mixed dtypes",
        ));
    }
    let dt = dx.dtype;
    // Kernel scalars shared by every route. They are hoisted once, as in
    // `sgemm_bi_forward_typed`, instead of being re-declared (and `alpha` re-shadowed) per
    // branch.
    let alpha: f32 = 1.0;
    let m_i = batch as i32;
    let n_i = n_out as i32;
    let k_out_i = n_in as i32;

    // Route 1: single column of dY. Grid of 256-thread blocks covering batch * n_in threads
    // (presumably one per dX element). The leading dimension of dX is n_in and that of dY is 1.
    if n_out == 1 && batch >= 1 && n_in >= 1 {
        let ldx_i = k_out_i;
        let ldy_i: i32 = 1;
        let total = (batch * n_in) as u32;
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total.div_ceil(256), 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut b = stream.launch_builder(kernels.sgemm_nt_gemv_typed.get(dt));
        b.arg(&dx.ptr);
        b.arg(&dy.ptr);
        b.arg(&w.ptr);
        b.arg(&alpha);
        b.arg(&m_i);
        b.arg(&k_out_i);
        b.arg(&ldx_i);
        b.arg(&ldy_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_nt_gemv typed: {e:?}")))?;
        return Ok(());
    }

    // Route 2: narrow reduction. Grid of ceil(batch / 64) * ceil(n_in / 32) blocks, 128
    // threads each.
    if (2..=127).contains(&n_out) && batch >= 1 && n_in >= 1 {
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (
                (batch as u32).div_ceil(64) * (n_in as u32).div_ceil(32),
                1,
                1,
            ),
            block_dim: (128, 1, 1),
            shared_mem_bytes: 0,
        };
        let mut b = stream.launch_builder(kernels.sgemm_nt_narrow_typed.get(dt));
        b.arg(&dx.ptr);
        b.arg(&dy.ptr);
        b.arg(&w.ptr);
        b.arg(&alpha);
        b.arg(&m_i);
        b.arg(&n_i);
        b.arg(&k_out_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_nt_narrow typed: {e:?}")))?;
        return Ok(());
    }

    // Route 3: 128×128 tiles over batch × n_in, 256 threads each, 34 KiB of dynamic shared
    // memory.
    if nt_routes_to_big(batch, n_in, n_out) {
        let total_tiles = (batch as u32).div_ceil(128) * (n_in as u32).div_ceil(128);
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (total_tiles, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 34 * 1024,
        };
        let mut b = stream.launch_builder(kernels.sgemm_nt_big_typed.get(dt));
        b.arg(&dx.ptr);
        b.arg(&dy.ptr);
        b.arg(&w.ptr);
        b.arg(&alpha);
        b.arg(&m_i);
        b.arg(&n_i);
        b.arg(&k_out_i);
        unsafe { b.launch(cfg) }
            .map_err(|e| Error::Cuda(format!("sgemm_bi_nt_big typed: {e:?}")))?;
        return Ok(());
    }

    // No typed kernel covers this shape. The caller decides the fallback.
    Err(Error::Uncovered {
        op: "sgemm_bi_backward_dx_typed",
        m: batch,
        k: n_in,
        n: n_out,
    })
}