mod argmax;
mod attention;
pub mod backward;
mod bias_activation;
mod conv1d;
mod elementwise;
mod fused;
mod gemm;
mod gemv;
mod layernorm;
pub mod lz4;
#[cfg(test)]
mod lz4_hash_store_test;
mod megakernel;
pub mod optimizer;
mod parity_impls;
mod persistent;
mod quantize;
mod softmax;
pub use argmax::{ArgMaxFinalKernel, ArgMaxKernel};
pub use attention::{
AttentionKernel, BatchedIncrementalAttentionKernel, FlashDecodingChunkKernel,
FlashDecodingReduceKernel, IncrementalAttentionKernel, MultiWarpIncrementalAttentionKernel,
PrefillAttentionKernel, FLASH_DECODE_CHUNK_SIZE,
};
pub use bias_activation::{Activation, BiasActivationKernel};
pub use conv1d::{Conv1dKernel, TiledConv1dKernel};
pub use elementwise::{
BatchedResidualAddKernel,
BatchedRopeBackwardKernel,
BatchedRopeKernel,
BatchedScaleKernel,
BatchedSoftmaxKernel,
BatchedSwigluKernel,
BatchedToInterleavedKernel,
BatchedTransposeKernel,
CopySingleHeadKernel,
ElementwiseMulKernel,
ExtractSingleHeadKernel,
FusedResidualRmsNormKernel,
FusedSwigluKernel,
GeluKernel,
InterleavedToBatchedKernel, KvCacheScatterIndirectKernel,
KvCacheScatterKernel,
PreciseRopeIndirectKernel,
PreciseRopeKernel, ReluKernel, ResidualAddKernel,
RopeIndirectKernel,
RopeKernel,
RopeNeoxIndirectKernel,
RopeNeoxKernel,
ScaleKernel,
SiluKernel,
TransposeKernel, };
pub use fused::{FusedGateUpKernel, FusedGemmBiasGeluKernel, FusedQKVKernel};
pub use gemm::{
Batched4DGemmConfig, Batched4DGemmKernel, BatchedGemmConfig, BatchedGemmKernel, GemmConfig,
GemmKernel,
};
pub use gemv::{CoalescedGemvKernel, GemvKernel};
pub use layernorm::{
BatchedVectorizedRmsNormKernel, LayerNormKernel, PerHeadRmsNormKernel, PreciseRmsNormKernel,
RmsNormKernel, VectorizedRmsNormKernel,
};
pub use lz4::{Lz4WarpCompressKernel, Lz4WarpDecompressKernel};
pub use megakernel::TransformerBlockMegakernel;
pub use optimizer::{
AdamStepKernel, AdamWStepKernel, ClipScaleReduceKernel, GradientClipGpuScaleKernel,
GradientClipKernel, SquaredSumKernel,
};
pub use persistent::PersistentDecoderKernel;
pub use quantize::{
dequantize_nf4, pack_nf4_for_gpu, quantize_nf4, unpack_nf4_from_gpu,
BatchedHwDp4aQ4KGemvKernel, BatchedQ4KGemvKernel, BatchedQ6KGemvKernel,
ChunkedTiledQ4KGemvKernel, CoalescedQ4KGemvKernel, CoalescedQ6KGemvKernel, Dp4aQ4KGemmKernel,
Dp4aQ4KGemvKernel, Dp4aQ6KGemvKernel, Fp16Q4KGemvKernel, FusedGateUpQ4KGemvKernel,
FusedGateUpSwigluHwDp4aQ4KGemvKernel, FusedRmsNormGateUpSwigluQ4KKernel,
FusedRmsNormQ4KGemvKernel, HalfWarpDp4aQ4KGemvKernel, HalfWarpDp4aQ6KGemvKernel,
MultiWarpQ6KGemvKernel, MultiWarpTensorCoreQ4KGemmKernel, MultiWarpVectorizedQ4KGemvKernel,
MwvDp4aQ4KGemvKernel, Nf4GemmKernel, Nf4GemmTransposeKernel, Nf4Quantized,
PackedDp4aQ4KQ8Kernel, Q4KDequantFp16Kernel, Q4KDequantKernel, Q4KGemvKernel, Q4KQ8DotKernel,
Q4_0GemvKernel, Q4_1GemvKernel, Q5KGemvKernel, Q5KKernel, Q5_0GemvKernel, Q6KDequantKernel,
Q6KGemvKernel, Q6KKernel, Q8QuantizeKernel, Q8_0GemvKernel, QuantizeKernel,
TensorCoreQ4KGemmKernel, TiledQ4KGemvKernel, TrueDp4aQ4KGemvKernel, VectorizedQ4KGemvKernel,
WideQ4KGemvKernel, NF4_BLOCK_BYTES, NF4_BLOCK_SIZE, NF4_LUT,
};
pub use softmax::{LongRowSoftmaxKernel, SoftmaxKernel};
use crate::ptx::optimize::barrier_safety::{self, BarrierSafetyResult};
use crate::ptx::parity::{self, ParityResult};
use crate::ptx::{PtxKernel, PtxModule};
/// Common interface for PTX kernel generators.
///
/// Implementors supply [`Kernel::name`] and [`Kernel::build_ptx`]; every other
/// method has a default implementation built on top of those two.
pub trait Kernel {
    /// Returns the kernel's name (used in diagnostics such as the panic
    /// message of [`Kernel::emit_ptx_validated`]).
    fn name(&self) -> &str;

    /// Builds the [`PtxKernel`] for this kernel.
    fn build_ptx(&self) -> PtxKernel;

    /// Wraps the kernel from [`Kernel::build_ptx`] in a [`PtxModule`] for the
    /// given `target`.
    ///
    /// The module is configured with version `8.0` and a 64-bit address size.
    fn as_module_for_target(&self, target: &str) -> PtxModule {
        PtxModule::new()
            .version(8, 0)
            .target(target)
            .address_size(64)
            .add_kernel(self.build_ptx())
    }

    /// Wraps the kernel in a [`PtxModule`] using the default target `sm_70`.
    fn as_module(&self) -> PtxModule {
        self.as_module_for_target("sm_70")
    }

    /// Emits the module text for the given `target`.
    fn emit_ptx_for_target(&self, target: &str) -> String {
        let module = self.as_module_for_target(target);
        module.emit()
    }

    /// Emits the module text for the default module (see [`Kernel::as_module`]).
    fn emit_ptx(&self) -> String {
        let module = self.as_module();
        module.emit()
    }

    /// Runs the barrier-safety analysis over the emitted PTX and returns its
    /// result.
    fn analyze_barrier_safety(&self) -> BarrierSafetyResult {
        barrier_safety::analyze(&self.emit_ptx())
    }

    /// Validates barrier safety of the emitted PTX.
    ///
    /// # Errors
    ///
    /// Returns the error string produced by `barrier_safety::validate` when
    /// validation fails.
    fn validate_barrier_safety(&self) -> Result<(), String> {
        barrier_safety::validate(&self.emit_ptx())
    }

    /// Emits PTX and returns it only if it passes barrier-safety validation.
    ///
    /// # Panics
    ///
    /// Panics with a `PARITY-114` message naming this kernel when
    /// `barrier_safety::validate` reports an error.
    fn emit_ptx_validated(&self) -> String {
        let ptx = self.emit_ptx();
        match barrier_safety::validate(&ptx) {
            Ok(()) => ptx,
            Err(violation) => {
                panic!(
                    "PARITY-114: Barrier safety violation in kernel '{}': {}",
                    self.name(),
                    violation
                )
            }
        }
    }
}
/// Parity checks between a kernel and its single-vector reference kernel.
pub trait KernelParity: Kernel {
    /// Kernel type returned by [`KernelParity::single_vector_reference`].
    type SingleVector: Kernel;

    /// Returns the single-vector kernel that this kernel is compared against.
    fn single_vector_reference(&self) -> Self::SingleVector;

    /// Emits PTX for the reference kernel and for `self`, then hands both
    /// texts and both kernel names to `parity::validate_parity`.
    fn validate_parity(&self) -> ParityResult {
        let reference = self.single_vector_reference();
        // Emit the reference PTX first, then this kernel's PTX.
        let reference_ptx = reference.emit_ptx();
        let own_ptx = self.emit_ptx();
        parity::validate_parity(
            &reference_ptx,
            &own_ptx,
            reference.name(),
            self.name(),
        )
    }

    /// Emits this kernel's PTX and passes it, together with the kernel name,
    /// to `parity::validate_batched_kernel`.
    fn validate_batch_dispatch(&self) -> ParityResult {
        parity::validate_batched_kernel(&self.emit_ptx(), self.name())
    }
}
#[cfg(test)]
mod tests;