#![allow(clippy::similar_names)]
#![allow(clippy::too_many_lines)]
use super::Kernel;
use crate::ptx::PtxKernel;
mod dot;
mod fp16_tensor;
mod fused;
mod fused_gemm;
mod legacy;
mod nf4;
mod nf4_cpu;
mod q4k;
mod q5k;
mod q6k;
mod q8;
pub use dot::{PackedDp4aQ4KQ8Kernel, Q4KQ8DotKernel};
pub use fp16_tensor::{Fp16Q4KGemvKernel, TensorCoreQ4KGemmKernel};
pub use fused::{
FusedGateUpQ4KGemvKernel, FusedRmsNormGateUpSwigluQ4KKernel, FusedRmsNormQ4KGemvKernel,
};
pub use legacy::{Q4_0GemvKernel, Q4_1GemvKernel, Q5_0GemvKernel, Q8_0GemvKernel};
pub use nf4::Nf4GemmKernel;
pub use nf4_cpu::{
dequantize_nf4, pack_nf4_for_gpu, quantize_nf4, unpack_nf4_from_gpu, Nf4Quantized,
NF4_BLOCK_BYTES, NF4_BLOCK_SIZE, NF4_LUT,
};
pub use q4k::{
BatchedQ4KGemvKernel, ChunkedTiledQ4KGemvKernel, CoalescedQ4KGemvKernel, Dp4aQ4KGemvKernel,
FusedGateUpSwigluHwDp4aQ4KGemvKernel, HalfWarpDp4aQ4KGemvKernel,
MultiWarpVectorizedQ4KGemvKernel, MwvDp4aQ4KGemvKernel, Q4KDequantKernel, Q4KGemvKernel,
TiledQ4KGemvKernel, TrueDp4aQ4KGemvKernel, VectorizedQ4KGemvKernel, WideQ4KGemvKernel,
};
pub use q5k::{Q5KGemvKernel, Q5KKernel};
pub use q6k::{
BatchedQ6KGemvKernel, CoalescedQ6KGemvKernel, Dp4aQ6KGemvKernel, HalfWarpDp4aQ6KGemvKernel,
MultiWarpQ6KGemvKernel, Q6KDequantKernel, Q6KGemvKernel, Q6KKernel,
};
pub use q8::Q8QuantizeKernel;
/// Elements per block in the simplified Q4_K layout; `QuantizeKernel::new`
/// uses this as its `block_size`.
const Q4K_BLOCK_SIZE: u32 = 32;
/// Elements per Q4_K super-block; `QuantizeKernel::ggml` uses this as its
/// `block_size` and `num_super_blocks_per_row` always divides `k` by it.
pub(crate) const Q4K_SUPER_BLOCK_SIZE: u32 = 256;
/// Bytes per Q4_K super-block.
/// NOTE(review): presumably the GGML `block_q4_K` storage size — confirm.
pub(crate) const Q4K_SUPER_BLOCK_BYTES: u32 = 144;
/// Bytes per simplified Q4_K block.
/// NOTE(review): not referenced in this part of the file; layout presumably
/// 32 packed 4-bit values plus a 2-byte header — confirm against the kernels.
const Q4K_BLOCK_BYTES: u32 = 18;
/// Elements per Q5_K super-block.
pub(crate) const Q5K_SUPER_BLOCK_SIZE: u32 = 256;
/// Bytes per Q5_K super-block.
/// NOTE(review): presumably the GGML `block_q5_K` storage size — confirm.
pub(crate) const Q5K_SUPER_BLOCK_BYTES: u32 = 176;
/// Elements per Q6_K super-block.
pub(crate) const Q6K_SUPER_BLOCK_SIZE: u32 = 256;
/// Bytes per Q6_K super-block.
/// NOTE(review): presumably the GGML `block_q6_K` storage size — confirm.
pub(crate) const Q6K_SUPER_BLOCK_BYTES: u32 = 210;
/// Elements per Q8_0 block.
const Q8_0_BLOCK_SIZE: u32 = 32;
/// Bytes per Q8_0 block.
/// NOTE(review): presumably a 2-byte scale plus 32 one-byte values — confirm.
const Q8_0_BLOCK_BYTES: u32 = 34;
/// Elements per Q5_0 block.
const Q5_0_BLOCK_SIZE: u32 = 32;
/// Bytes per Q5_0 block.
/// NOTE(review): presumably the GGML `block_q5_0` storage size — confirm.
const Q5_0_BLOCK_BYTES: u32 = 22;
/// Selects which Q4_K weight layout a [`QuantizeKernel`] targets.
///
/// The variant decides both the kernel name reported by the `Kernel`
/// implementation and which PTX builder is invoked.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Q4KFormat {
/// Simplified layout: set by `QuantizeKernel::new`, which pairs it with a
/// block size of `Q4K_BLOCK_SIZE`; the kernel is named `"q4k_gemm_fused"`.
Simplified,
/// Super-block layout: set by `QuantizeKernel::ggml`, which pairs it with a
/// block size of `Q4K_SUPER_BLOCK_SIZE`; the kernel is named `"q4k_gemm_ggml"`.
GgmlSuperBlock,
}
/// Parameters describing a fused Q4_K GEMM kernel to generate.
///
/// Build one with `QuantizeKernel::new` (simplified layout) or
/// `QuantizeKernel::ggml` (super-block layout). All fields are public, so
/// callers may also adjust them directly after construction.
#[derive(Debug, Clone)]
pub struct QuantizeKernel {
/// First problem dimension.
/// NOTE(review): presumably the number of output rows — confirm against the PTX builders.
pub m: u32,
/// Second problem dimension.
/// NOTE(review): presumably the number of output columns — confirm against the PTX builders.
pub n: u32,
/// Length of the quantized dimension; divided by `block_size` (and by the
/// super-block size) to count blocks per row.
pub k: u32,
/// Tile size; both constructors set it to 32 and `with_tile_size` overrides it.
/// NOTE(review): how the PTX builders consume it is not visible here.
pub tile_size: u32,
/// Elements per quantization block; the divisor used by `num_blocks_per_row`.
pub block_size: u32,
/// Weight layout, which selects the kernel name and PTX builder.
pub format: Q4KFormat,
}
impl QuantizeKernel {
    /// Creates a kernel description for the simplified Q4_K layout.
    ///
    /// `tile_size` starts at 32, `block_size` at `Q4K_BLOCK_SIZE`, and
    /// `format` is [`Q4KFormat::Simplified`]. Usable in `const` contexts,
    /// like the other methods on this type.
    #[must_use]
    pub const fn new(m: u32, n: u32, k: u32) -> Self {
        Self {
            m,
            n,
            k,
            tile_size: 32,
            block_size: Q4K_BLOCK_SIZE,
            format: Q4KFormat::Simplified,
        }
    }

    /// Creates a kernel description for the super-block Q4_K layout.
    ///
    /// `tile_size` starts at 32, `block_size` at `Q4K_SUPER_BLOCK_SIZE`, and
    /// `format` is [`Q4KFormat::GgmlSuperBlock`]. Usable in `const` contexts,
    /// like the other methods on this type.
    #[must_use]
    pub const fn ggml(m: u32, n: u32, k: u32) -> Self {
        Self {
            m,
            n,
            k,
            tile_size: 32,
            block_size: Q4K_SUPER_BLOCK_SIZE,
            format: Q4KFormat::GgmlSuperBlock,
        }
    }

    /// Returns the kernel with `tile_size` replaced, leaving every other
    /// field untouched (builder style).
    #[must_use]
    pub const fn with_tile_size(mut self, tile_size: u32) -> Self {
        self.tile_size = tile_size;
        self
    }

    /// Number of whole blocks of `block_size` elements that fit in `k`.
    ///
    /// The division truncates, so a trailing partial block is not counted.
    ///
    /// # Panics
    ///
    /// Panics if the public `block_size` field has been set to 0.
    #[must_use]
    pub const fn num_blocks_per_row(&self) -> u32 {
        self.k / self.block_size
    }

    /// Number of whole super-blocks of `Q4K_SUPER_BLOCK_SIZE` elements that
    /// fit in `k`.
    ///
    /// This always divides by the super-block size, regardless of `format`
    /// or `block_size`; the division truncates.
    #[must_use]
    pub const fn num_super_blocks_per_row(&self) -> u32 {
        self.k / Q4K_SUPER_BLOCK_SIZE
    }
}
impl Kernel for QuantizeKernel {
    /// Returns the kernel's name, which depends solely on the chosen
    /// weight layout.
    fn name(&self) -> &str {
        let layout = self.format;
        match layout {
            Q4KFormat::GgmlSuperBlock => "q4k_gemm_ggml",
            Q4KFormat::Simplified => "q4k_gemm_fused",
        }
    }

    /// Builds the PTX for this kernel by delegating to the builder that
    /// matches the chosen weight layout.
    fn build_ptx(&self) -> PtxKernel {
        let layout = self.format;
        match layout {
            Q4KFormat::GgmlSuperBlock => self.build_fused_gemm_ggml(),
            Q4KFormat::Simplified => self.build_fused_gemm_simplified(),
        }
    }
}
#[cfg(test)]
mod tests;