#![allow(clippy::similar_names)]
mod naive;
mod tiled;
mod tiled_unrolled;
mod wmma_fp16;
use crate::kernels::Kernel;
use crate::ptx::PtxKernel;
/// Size parameters carried by every [`BatchedGemmKernel`] variant.
///
/// This is plain data: nothing in this type validates the values. How each
/// field is interpreted is up to the PTX builders that read the config.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchedGemmConfig {
    /// Batch count (presumably the number of independent matrix products —
    /// confirm against the kernel builders).
    pub batch: u32,
    /// The `m` problem dimension (conventionally the row count of the output
    /// in GEMM notation — confirm against the kernel builders).
    pub m: u32,
    /// The `n` problem dimension (conventionally the column count of the
    /// output in GEMM notation — confirm against the kernel builders).
    pub n: u32,
    /// The `k` problem dimension (conventionally the shared inner dimension
    /// in GEMM notation — confirm against the kernel builders).
    pub k: u32,
    /// Tile size handed to the kernel builders; the `Default` impl uses 16.
    pub tile_size: u32,
}
impl Default for BatchedGemmConfig {
    /// Returns a configuration with `batch = 1`, `m = n = k = 1024` and
    /// `tile_size = 16`.
    fn default() -> Self {
        // The three problem dimensions all share one default extent.
        const DEFAULT_DIM: u32 = 1024;
        const DEFAULT_TILE_SIZE: u32 = 16;

        Self {
            batch: 1,
            m: DEFAULT_DIM,
            n: DEFAULT_DIM,
            k: DEFAULT_DIM,
            tile_size: DEFAULT_TILE_SIZE,
        }
    }
}
/// A batched GEMM kernel: a [`BatchedGemmConfig`] paired with the variant
/// that decides which kernel name and PTX builder are used.
#[derive(Debug, Clone)]
pub struct BatchedGemmKernel {
    /// Size parameters for this kernel; public, so callers can read or
    /// modify them after construction.
    pub config: BatchedGemmConfig,
    /// Implementation selector; private and set by the constructors.
    variant: BatchedGemmVariant,
}
/// Selects which batched GEMM implementation a `BatchedGemmKernel` uses.
///
/// The enum is private to this module; callers choose a variant through the
/// `BatchedGemmKernel` constructor of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BatchedGemmVariant {
    /// Reported as `batched_gemm_naive`; built by `build_naive`.
    Naive,
    /// Reported as `batched_gemm_tiled`; built by `build_tiled`.
    Tiled,
    /// Reported as `batched_gemm_tiled_unrolled`; built by `build_tiled_unrolled`.
    TiledUnrolled,
    /// Reported as `batched_gemm_wmma_fp16`; built by `build_wmma_fp16`.
    WmmaFp16,
}
impl BatchedGemmKernel {
#[must_use]
pub fn naive(batch: u32, m: u32, n: u32, k: u32) -> Self {
Self {
config: BatchedGemmConfig { batch, m, n, k, ..Default::default() },
variant: BatchedGemmVariant::Naive,
}
}
#[must_use]
pub fn tiled(batch: u32, m: u32, n: u32, k: u32, tile_size: u32) -> Self {
Self {
config: BatchedGemmConfig { batch, m, n, k, tile_size },
variant: BatchedGemmVariant::Tiled,
}
}
#[must_use]
pub fn tiled_unrolled(batch: u32, m: u32, n: u32, k: u32, tile_size: u32) -> Self {
Self {
config: BatchedGemmConfig { batch, m, n, k, tile_size },
variant: BatchedGemmVariant::TiledUnrolled,
}
}
#[must_use]
pub fn wmma_fp16(batch: u32, m: u32, n: u32, k: u32) -> Self {
Self {
config: BatchedGemmConfig {
batch,
m,
n,
k,
tile_size: 16, },
variant: BatchedGemmVariant::WmmaFp16,
}
}
}
impl Kernel for BatchedGemmKernel {
    /// Returns the kernel name that corresponds to the selected variant.
    fn name(&self) -> &str {
        // The variant is `Copy`, so take it out of `self` before matching.
        let variant = self.variant;
        match variant {
            BatchedGemmVariant::Naive => "batched_gemm_naive",
            BatchedGemmVariant::Tiled => "batched_gemm_tiled",
            BatchedGemmVariant::TiledUnrolled => "batched_gemm_tiled_unrolled",
            BatchedGemmVariant::WmmaFp16 => "batched_gemm_wmma_fp16",
        }
    }

    /// Builds the PTX kernel by delegating to the builder method that
    /// matches the selected variant.
    ///
    /// The `build_*` methods are defined outside this view (presumably in
    /// the sibling submodules declared at the top of the file).
    fn build_ptx(&self) -> PtxKernel {
        let variant = self.variant;
        // No catch-all arm: adding a variant must force an update here.
        match variant {
            BatchedGemmVariant::Naive => self.build_naive(),
            BatchedGemmVariant::Tiled => self.build_tiled(),
            BatchedGemmVariant::TiledUnrolled => self.build_tiled_unrolled(),
            BatchedGemmVariant::WmmaFp16 => self.build_wmma_fp16(),
        }
    }
}
#[cfg(test)]
mod tests;