mod argmax;
mod attention;
pub mod backward;
mod bias_activation;
mod conv1d;
mod elementwise;
mod fused;
mod gemm;
mod gemv;
mod layernorm;
pub mod lz4;
#[cfg(test)]
mod lz4_hash_store_test;
mod megakernel;
pub mod optimizer;
mod parity_impls;
mod persistent;
mod quantize;
mod softmax;
pub use argmax::{ArgMaxFinalKernel, ArgMaxKernel};
pub use attention::{
AttentionKernel, BatchedIncrementalAttentionKernel, FlashDecodingChunkKernel,
FlashDecodingReduceKernel, IncrementalAttentionKernel, MultiWarpIncrementalAttentionKernel,
FLASH_DECODE_CHUNK_SIZE,
};
pub use bias_activation::{Activation, BiasActivationKernel};
pub use conv1d::{Conv1dKernel, TiledConv1dKernel};
pub use elementwise::{
BatchedResidualAddKernel,
BatchedRopeKernel,
BatchedScaleKernel,
BatchedSoftmaxKernel,
BatchedSwigluKernel,
BatchedToInterleavedKernel,
BatchedTransposeKernel,
CopySingleHeadKernel,
ElementwiseMulKernel,
ExtractSingleHeadKernel,
FusedResidualRmsNormKernel,
FusedSwigluKernel,
GeluKernel,
InterleavedToBatchedKernel, KvCacheScatterIndirectKernel,
KvCacheScatterKernel,
PreciseRopeIndirectKernel,
PreciseRopeKernel, ReluKernel, ResidualAddKernel,
RopeIndirectKernel,
RopeKernel,
RopeNeoxIndirectKernel,
RopeNeoxKernel,
ScaleKernel,
SiluKernel,
TransposeKernel, };
pub use fused::{FusedGateUpKernel, FusedGemmBiasGeluKernel, FusedQKVKernel};
pub use gemm::{
Batched4DGemmConfig, Batched4DGemmKernel, BatchedGemmConfig, BatchedGemmKernel, GemmConfig,
GemmKernel,
};
pub use gemv::{CoalescedGemvKernel, GemvKernel};
pub use layernorm::{
BatchedVectorizedRmsNormKernel, LayerNormKernel, PerHeadRmsNormKernel, PreciseRmsNormKernel,
RmsNormKernel, VectorizedRmsNormKernel,
};
pub use lz4::{Lz4WarpCompressKernel, Lz4WarpDecompressKernel};
pub use megakernel::TransformerBlockMegakernel;
pub use optimizer::{
AdamStepKernel, AdamWStepKernel, ClipScaleReduceKernel, GradientClipGpuScaleKernel,
GradientClipKernel, SquaredSumKernel,
};
pub use persistent::PersistentDecoderKernel;
pub use quantize::{
dequantize_nf4, pack_nf4_for_gpu, quantize_nf4, unpack_nf4_from_gpu, BatchedQ4KGemvKernel,
BatchedQ6KGemvKernel, ChunkedTiledQ4KGemvKernel, CoalescedQ4KGemvKernel,
CoalescedQ6KGemvKernel, Dp4aQ4KGemvKernel, Dp4aQ6KGemvKernel, Fp16Q4KGemvKernel,
FusedGateUpQ4KGemvKernel, FusedGateUpSwigluHwDp4aQ4KGemvKernel,
FusedRmsNormGateUpSwigluQ4KKernel, FusedRmsNormQ4KGemvKernel, HalfWarpDp4aQ4KGemvKernel,
HalfWarpDp4aQ6KGemvKernel, MultiWarpQ6KGemvKernel, MultiWarpVectorizedQ4KGemvKernel,
MwvDp4aQ4KGemvKernel, Nf4GemmKernel, Nf4Quantized, PackedDp4aQ4KQ8Kernel, Q4KDequantKernel,
Q4KGemvKernel, Q4KQ8DotKernel, Q4_0GemvKernel, Q4_1GemvKernel, Q5KGemvKernel, Q5KKernel,
Q5_0GemvKernel, Q6KDequantKernel, Q6KGemvKernel, Q6KKernel, Q8QuantizeKernel, Q8_0GemvKernel,
QuantizeKernel, TensorCoreQ4KGemmKernel, TiledQ4KGemvKernel, TrueDp4aQ4KGemvKernel,
VectorizedQ4KGemvKernel, WideQ4KGemvKernel, NF4_BLOCK_BYTES, NF4_BLOCK_SIZE, NF4_LUT,
};
pub use softmax::{LongRowSoftmaxKernel, SoftmaxKernel};
use crate::ptx::optimize::barrier_safety::{self, BarrierSafetyResult};
use crate::ptx::parity::{self, ParityResult};
use crate::ptx::{PtxKernel, PtxModule};
/// Common interface for anything that can build a PTX kernel.
///
/// Implementors provide [`Kernel::name`] and [`Kernel::build_ptx`]; every
/// other method has a default body layered on top of those two.
pub trait Kernel {
    /// Returns the kernel's name (used in the panic message of
    /// [`Kernel::emit_ptx_validated`]).
    fn name(&self) -> &str;

    /// Builds the [`PtxKernel`] that the module-level helpers wrap.
    fn build_ptx(&self) -> PtxKernel;

    /// Wraps the result of [`Kernel::build_ptx`] in a [`PtxModule`] for `target`.
    ///
    /// The module is configured through `.version(8, 0)`, `.target(target)`
    /// and `.address_size(64)` before the kernel is added.
    fn as_module_for_target(&self, target: &str) -> PtxModule {
        // Configure the module header first, then build and attach the kernel.
        let header = PtxModule::new().version(8, 0).target(target).address_size(64);
        header.add_kernel(self.build_ptx())
    }

    /// Same as [`Kernel::as_module_for_target`] with the target fixed to `"sm_70"`.
    fn as_module(&self) -> PtxModule {
        let default_target = "sm_70";
        self.as_module_for_target(default_target)
    }

    /// Emits the module produced by [`Kernel::as_module_for_target`] as a string.
    fn emit_ptx_for_target(&self, target: &str) -> String {
        let module = self.as_module_for_target(target);
        module.emit()
    }

    /// Emits the module produced by [`Kernel::as_module`] as a string.
    fn emit_ptx(&self) -> String {
        let module = self.as_module();
        module.emit()
    }

    /// Runs `barrier_safety::analyze` over the output of [`Kernel::emit_ptx`]
    /// and returns its result unchanged.
    fn analyze_barrier_safety(&self) -> BarrierSafetyResult {
        barrier_safety::analyze(&self.emit_ptx())
    }

    /// Runs `barrier_safety::validate` over the output of [`Kernel::emit_ptx`].
    ///
    /// # Errors
    ///
    /// Propagates whatever error string `barrier_safety::validate` reports.
    fn validate_barrier_safety(&self) -> Result<(), String> {
        barrier_safety::validate(&self.emit_ptx())
    }

    /// Emits PTX via [`Kernel::emit_ptx`] and returns it only if
    /// `barrier_safety::validate` accepts it.
    ///
    /// # Panics
    ///
    /// Panics with a `PARITY-114` message naming this kernel when validation
    /// returns an error.
    fn emit_ptx_validated(&self) -> String {
        let source = self.emit_ptx();
        let verdict = barrier_safety::validate(&source);
        match verdict {
            Ok(()) => source,
            Err(violation) => {
                panic!(
                    "PARITY-114: Barrier safety violation in kernel '{}': {}",
                    self.name(),
                    violation
                );
            }
        }
    }
}
/// Extension of [`Kernel`] for kernels that have an associated
/// single-vector counterpart to be checked against.
pub trait KernelParity: Kernel {
    /// Kernel type returned by [`KernelParity::single_vector_reference`].
    type SingleVector: Kernel;

    /// Returns the single-vector kernel that `self` is compared with in
    /// [`KernelParity::validate_parity`].
    fn single_vector_reference(&self) -> Self::SingleVector;

    /// Emits PTX for both the single-vector reference and `self`, then hands
    /// the two PTX strings and the two kernel names to
    /// `parity::validate_parity` (reference first, `self` second).
    fn validate_parity(&self) -> ParityResult {
        let reference = self.single_vector_reference();
        // The reference PTX is emitted before our own, matching the argument order below.
        let reference_ptx = reference.emit_ptx();
        let own_ptx = self.emit_ptx();
        parity::validate_parity(&reference_ptx, &own_ptx, reference.name(), self.name())
    }

    /// Passes this kernel's emitted PTX and its name to
    /// `parity::validate_batched_kernel` and returns the result unchanged.
    fn validate_batch_dispatch(&self) -> ParityResult {
        parity::validate_batched_kernel(&self.emit_ptx(), self.name())
    }
}
#[cfg(test)]
mod tests;