mod argmax;
mod attention;
pub mod backward;
mod bias_activation;
mod conv1d;
mod elementwise;
mod fused;
pub(crate) mod gemm;
mod gemv;
mod layernorm;
pub mod lz4;
#[cfg(test)]
mod lz4_hash_store_test;
mod megakernel;
pub mod optimizer;
mod parity_impls;
mod persistent;
mod quantize;
mod softmax;
pub use argmax::{ArgMaxFinalKernel, ArgMaxKernel};
pub use attention::{
AttentionKernel, BatchedIncrementalAttentionKernel, FlashDecodingChunkKernel,
FlashDecodingChunkKernel2Warp, FlashDecodingReduceKernel, IncrementalAttentionKernel,
MultiWarpIncrementalAttentionKernel, PrefillAttentionKernel, FLASH_DECODE_CHUNK_SIZE,
};
pub use bias_activation::{Activation, BiasActivationKernel};
pub use conv1d::{Conv1dKernel, TiledConv1dKernel};
pub use elementwise::{
BatchedFusedResidualRmsNormKernel,
BatchedResidualAddKernel,
BatchedRopeBackwardKernel,
BatchedRopeKernel,
BatchedScaleKernel,
BatchedSoftmaxKernel,
BatchedSwigluKernel,
BatchedToInterleavedKernel,
BatchedTransposeKernel,
CastF32ToF16Kernel,
CopySingleHeadKernel,
ElementwiseMulKernel,
ExtractSingleHeadKernel,
FusedResidualRmsNormKernel,
FusedSwigluKernel,
GeluKernel,
InterleavedToBatchedKernel, KvCacheScatterIndirectKernel,
KvCacheScatterKernel,
PreciseRopeIndirectKernel,
PreciseRopeKernel, ReluKernel, ResidualAddKernel,
RopeIndirectKernel,
RopeKernel,
RopeNeoxIndirectKernel,
RopeNeoxKernel,
ScaleKernel,
SiluKernel,
TransposeKernel, };
pub use fused::{FusedGateUpKernel, FusedGemmBiasGeluKernel, FusedQKVKernel};
pub use gemm::basic::tensor_core::cta64_wmma::build_cta64x128_mma_pipeline_fp16;
pub use gemm::{
Batched4DGemmConfig, Batched4DGemmKernel, BatchedGemmConfig, BatchedGemmKernel, GemmConfig,
GemmKernel,
};
pub use gemv::{CoalescedGemvKernel, GemvKernel};
pub use layernorm::{
BatchedVectorizedRmsNormKernel, LayerNormKernel, PerHeadRmsNormKernel, PreciseRmsNormKernel,
RmsNormKernel, VectorizedRmsNormKernel,
};
pub use lz4::{Lz4WarpCompressKernel, Lz4WarpDecompressKernel};
pub use megakernel::TransformerBlockMegakernel;
pub use optimizer::{
AdamStepKernel, AdamWStepKernel, ClipScaleReduceKernel, GradientClipGpuScaleKernel,
GradientClipKernel, SquaredSumKernel,
};
pub use persistent::PersistentDecoderKernel;
pub use quantize::fused_kv_scatter::FusedKvScatterKernel;
pub use quantize::{
dequantize_nf4, pack_nf4_for_gpu, quantize_nf4, repack_q4k_interleaved, repack_q4k_w4a16,
unpack_nf4_from_gpu, BatchedHwDp4aQ4KGemvKernel, BatchedQ4KGemvKernel, BatchedQ6KGemvKernel,
ChunkedTiledQ4KGemvKernel, CoalescedQ4KGemvKernel, CoalescedQ6KGemvKernel, Dp4aQ4KGemmKernel,
Dp4aQ4KGemvKernel, Dp4aQ6KGemvKernel, Fp16Q4KGemvKernel, FusedFp32Q4KGemvKernel,
FusedGateUpQ4KGemvKernel, FusedGateUpSwigluHwDp4aQ4KGemvKernel, FusedNf4GateUpGemmKernel,
FusedQKVHwDp4aQ4KGemvKernel, FusedRmsNormGateUpSwigluQ4KKernel, FusedRmsNormNf4GemvKernel,
FusedRmsNormQ4KGemvKernel, HalfWarpDp4aQ4KGemvKernel, HalfWarpDp4aQ6KGemvKernel,
InlineQ8Dp4aQ4KGemvKernel, InterleavedWmmaQ4KGemmKernel, MultiWarpQ6KGemvKernel,
MultiWarpTensorCoreQ4KGemmKernel, MultiWarpVectorizedQ4KGemvKernel, MwvDp4aQ4KGemvKernel,
Nf4GemmKernel, Nf4GemmTransposeKernel, Nf4Quantized, Nf4TensorCoreGemmKernel,
PackedDp4aQ4KQ8Kernel, Q4KDequantFp16Kernel, Q4KDequantKernel, Q4KGemvKernel, Q4KQ8DotKernel,
Q4_0GemvKernel, Q4_1GemvKernel, Q5KGemvKernel, Q5KKernel, Q5_0GemvKernel, Q6KDequantKernel,
Q6KGemvKernel, Q6KKernel, Q8QuantizeKernel, Q8_0GemvKernel, QuantizeKernel,
TensorCoreQ4KGemmKernel, TiledQ4KGemvKernel, TrueDp4aQ4KGemvKernel, VectorizedQ4KGemvKernel,
W4a16WmmaQ4KGemmKernel, WideQ4KGemvKernel, NF4_BLOCK_BYTES, NF4_BLOCK_SIZE, NF4_LUT,
};
pub use softmax::{LongRowSoftmaxKernel, SoftmaxKernel};
use crate::ptx::optimize::barrier_safety::{self, BarrierSafetyResult};
use crate::ptx::parity::{self, ParityResult};
use crate::ptx::{PtxKernel, PtxModule};
/// Map a `sm_NN` target string to the `(major, minor)` PTX ISA version to emit.
///
/// Targets at or above `sm_100` get PTX 8.8; every older (or unrecognized)
/// target gets PTX 8.0. A string that does not parse as `sm_<number>` is
/// treated as `sm_70`.
pub fn ptx_version_for_target(target: &str) -> (u32, u32) {
    // Pull the numeric SM generation out of the target triple; fall back to 70
    // when the prefix is missing or the digits fail to parse.
    let generation = target
        .strip_prefix("sm_")
        .and_then(|digits| digits.parse::<u32>().ok())
        .unwrap_or(70);
    match generation {
        g if g >= 100 => (8, 8),
        _ => (8, 0),
    }
}
/// Common interface for every PTX kernel builder in this module.
///
/// Implementors provide a [`Kernel::name`] and a [`Kernel::build_ptx`]; the
/// default methods wrap the built kernel in a [`PtxModule`] for a given SM
/// target, emit the PTX text, and run barrier-safety analysis over it.
pub trait Kernel {
    /// Stable kernel name used in module emission and diagnostics.
    fn name(&self) -> &str;

    /// Construct the kernel body as a [`PtxKernel`].
    fn build_ptx(&self) -> PtxKernel;

    /// Wrap the kernel in a module configured for `target` (e.g. `"sm_70"`),
    /// choosing the PTX ISA version via [`ptx_version_for_target`].
    fn as_module_for_target(&self, target: &str) -> PtxModule {
        let (major, minor) = ptx_version_for_target(target);
        let module = PtxModule::new()
            .version(major, minor)
            .target(target)
            .address_size(64);
        module.add_kernel(self.build_ptx())
    }

    /// Convenience wrapper: module for the default `sm_70` target.
    fn as_module(&self) -> PtxModule {
        self.as_module_for_target("sm_70")
    }

    /// Emit PTX text for an explicit target.
    fn emit_ptx_for_target(&self, target: &str) -> String {
        self.as_module_for_target(target).emit()
    }

    /// Emit PTX text for the default target.
    fn emit_ptx(&self) -> String {
        self.as_module().emit()
    }

    /// Run barrier-safety analysis over the emitted PTX and return the report.
    fn analyze_barrier_safety(&self) -> BarrierSafetyResult {
        let emitted = self.emit_ptx();
        barrier_safety::analyze(&emitted)
    }

    /// Validate barrier safety of the emitted PTX, returning the first
    /// violation as an error string.
    fn validate_barrier_safety(&self) -> Result<(), String> {
        let emitted = self.emit_ptx();
        barrier_safety::validate(&emitted)
    }

    /// Emit PTX, panicking if barrier-safety validation fails.
    ///
    /// # Panics
    /// Panics with a `PARITY-114` message when the emitted PTX violates
    /// barrier safety.
    fn emit_ptx_validated(&self) -> String {
        let emitted = self.emit_ptx();
        match barrier_safety::validate(&emitted) {
            Ok(()) => emitted,
            Err(e) => panic!(
                "PARITY-114: Barrier safety violation in kernel '{}': {}",
                self.name(),
                e
            ),
        }
    }
}
/// Parity-checking extension for batched kernels that have a single-vector
/// reference implementation.
///
/// The batched kernel's PTX is compared against the reference kernel's PTX
/// via the `crate::ptx::parity` analyses.
pub trait KernelParity: Kernel {
    /// The single-vector reference kernel type this batched kernel mirrors.
    type SingleVector: Kernel;

    /// Build the single-vector reference kernel instance.
    fn single_vector_reference(&self) -> Self::SingleVector;

    /// Compare this kernel's emitted PTX against its single-vector reference.
    fn validate_parity(&self) -> ParityResult {
        let reference = self.single_vector_reference();
        parity::validate_parity(
            &reference.emit_ptx(),
            &self.emit_ptx(),
            reference.name(),
            self.name(),
        )
    }

    /// Check that this kernel's PTX dispatches correctly over the batch
    /// dimension, independent of any reference kernel.
    fn validate_batch_dispatch(&self) -> ParityResult {
        let emitted = self.emit_ptx();
        parity::validate_batched_kernel(&emitted, self.name())
    }
}
#[cfg(test)]
mod tests;