use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace, U8,
};
use half::{bf16, f16};
use crate::quantize::map_status;
/// Marker trait for the element types usable as NF4 activation/output dtypes.
///
/// The `sealed::Sealed` supertrait lives in a private module, so code outside
/// this module cannot name it and therefore cannot add further implementations.
pub trait Nf4Activation: Element + sealed::Sealed {}
/// Private module holding the supertrait that seals `Nf4Activation`.
mod sealed {
/// Supertrait of `Nf4Activation`; implemented here for `f32`, `f16` and `bf16` only.
pub trait Sealed {}
impl Sealed for f32 {}
impl Sealed for half::f16 {}
impl Sealed for half::bf16 {}
}
// Element types admitted by the `Nf4Activation` bound. Note that the MMVQ plans
// additionally reject `f32` at `select` time; only the dequantize plan accepts all three.
impl Nf4Activation for f32 {}
impl Nf4Activation for f16 {}
impl Nf4Activation for bf16 {}
/// Problem description shared by the NF4 plans in this file.
///
/// `validate_desc` requires `n >= 0`, `k >= 0`, `block_size > 0`, an even `n`,
/// and `k` to be a multiple of `block_size`.
#[derive(Copy, Clone, Debug)]
pub struct Nf4Descriptor {
/// Number of weight rows (`N`). Two rows share one packed byte, so it must be even.
pub n: i32,
/// Length of each weight row (`K`).
pub k: i32,
/// Number of consecutive elements along `K` that share one `absmax` scale.
pub block_size: i32,
}
impl Default for Nf4Descriptor {
    /// Returns an empty problem (`n == 0`, `k == 0`) with a block size of 64.
    fn default() -> Self {
        let (n, k) = (0, 0);
        Self {
            n,
            k,
            block_size: 64,
        }
    }
}
/// Descriptor for the multi-row MMVQ plan: the base NF4 problem plus a row count `m`.
#[derive(Copy, Clone, Debug)]
pub struct Nf4MmvqMultiMDescriptor {
/// `N`, `K` and block size of the quantized weight.
pub base: Nf4Descriptor,
/// Number of activation rows (`M`); `Nf4MmvqMultiMPlan::select` accepts only 1, 2, 4 or 8.
pub m: i32,
}
impl Default for Nf4MmvqMultiMDescriptor {
    /// Returns the default base descriptor combined with a single activation row (`m == 1`).
    fn default() -> Self {
        let base = Nf4Descriptor::default();
        Self { base, m: 1 }
    }
}
/// Tensor arguments consumed by `Nf4DequantizePlan::run`.
pub struct Nf4DequantizeArgs<'a, T: Nf4Activation> {
/// Packed NF4 weight bytes; `can_implement` requires a length of `(N / 2) * K`.
pub weight: TensorRef<'a, U8, 1>,
/// Per-block scales; `can_implement` requires a length of `N * (K / block_size)`.
pub absmax: TensorRef<'a, f32, 1>,
/// Destination of shape `[N, K]` with unit stride along `K`.
pub output: TensorMut<'a, T, 2>,
}
/// Tensor arguments consumed by `Nf4MmvqPlan::run`.
pub struct Nf4MmvqArgs<'a, T: Nf4Activation> {
/// Packed NF4 weight bytes; `can_implement` requires a length of `(N / 2) * K`.
pub weight: TensorRef<'a, U8, 1>,
/// Per-block scales; `can_implement` requires a length of `N * (K / block_size)`.
pub absmax: TensorRef<'a, f32, 1>,
/// Input vector of shape `[K]` with unit stride.
pub activation: TensorRef<'a, T, 1>,
/// Result vector of shape `[N]` with unit stride.
pub output: TensorMut<'a, T, 1>,
}
/// Tensor arguments consumed by `Nf4MmvqMultiMPlan::run`.
pub struct Nf4MmvqMultiMArgs<'a, T: Nf4Activation> {
/// Packed NF4 weight bytes; `can_implement` requires a length of `(N / 2) * K`.
pub weight: TensorRef<'a, U8, 1>,
/// Per-block scales; `can_implement` requires a length of `N * (K / block_size)`.
pub absmax: TensorRef<'a, f32, 1>,
/// Input rows of shape `[M, K]` with unit stride along `K`.
pub activations: TensorRef<'a, T, 2>,
/// Result rows of shape `[M, N]` with unit stride along `N`.
pub output: TensorMut<'a, T, 2>,
}
/// Plan that expands NF4-packed weights into a dense `[N, K]` tensor of `T`.
pub struct Nf4DequantizePlan<T: Nf4Activation> {
/// Descriptor copied in when the plan was selected.
desc: Nf4Descriptor,
/// Kernel identity and precision metadata produced by `build_sku`.
sku: KernelSku,
/// Ties the plan to its element type without storing a `T`.
_phantom: PhantomData<T>,
}
impl<T: Nf4Activation> Nf4DequantizePlan<T> {
    /// Validates `desc` and creates a dequantize plan producing elements of type `T`.
    ///
    /// `_stream` and `_pref` are accepted but not consulted.
    ///
    /// # Errors
    /// Propagates the `Error::InvalidProblem` raised by `validate_desc` (negative
    /// dims, non-positive block size, odd `n`, or `k` not a multiple of `block_size`).
    pub fn select(
        _stream: &Stream,
        desc: &Nf4Descriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        validate_desc(desc, "Nf4DequantizePlan")?;
        Ok(Self {
            desc: *desc,
            sku: build_sku(T::KIND, QuantizeKind::DequantizePerGroup),
            _phantom: PhantomData,
        })
    }

    /// Checks `args` against the descriptor captured at `select` time.
    ///
    /// # Errors
    /// Returns `Error::InvalidProblem` when
    /// * `weight` does not hold `(N / 2) * K` bytes,
    /// * `absmax` does not hold `N * (K / block_size)` scales,
    /// * `output` is not shaped `[N, K]`,
    /// * `output` is not unit-stride along `K`, or
    /// * `output` rows are not packed back to back (row stride != `K`).
    pub fn can_implement(&self, args: &Nf4DequantizeArgs<'_, T>) -> Result<()> {
        let n = self.desc.n;
        let k = self.desc.k;
        let bs = self.desc.block_size;
        // Two rows share each byte, hence N / 2 packed rows of K bytes.
        let expected_packed_bytes = ((n / 2) as i64) * (k as i64);
        if (args.weight.shape[0] as i64) != expected_packed_bytes {
            return Err(Error::InvalidProblem(
                "Nf4DequantizePlan: weight bytes != (N/2) * K",
            ));
        }
        let expected_absmax = (n as i64) * ((k / bs) as i64);
        if (args.absmax.shape[0] as i64) != expected_absmax {
            return Err(Error::InvalidProblem(
                "Nf4DequantizePlan: absmax length != N * (K / block_size)",
            ));
        }
        if args.output.shape != [n, k] {
            return Err(Error::InvalidProblem(
                "Nf4DequantizePlan: output shape != [N, K]",
            ));
        }
        if args.output.stride[1] != 1 {
            return Err(Error::InvalidProblem(
                "Nf4DequantizePlan: output must be contiguous along K",
            ));
        }
        // The FFI launcher receives only the base pointer plus N/K/block_size, so a
        // padded row stride cannot be honoured and must be rejected here. Empty
        // problems never launch, so their strides are left unconstrained.
        if n > 1 && k > 0 && (args.output.stride[0] as i64) != (k as i64) {
            return Err(Error::InvalidProblem(
                "Nf4DequantizePlan: output row stride must equal K",
            ));
        }
        Ok(())
    }

    /// Returns the kernel SKU chosen for this plan.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Scratch bytes required by `run`; this plan needs none.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Re-validates `args` and forwards the problem to the dtype-specific FFI launcher.
    ///
    /// Returns `Ok(())` without launching anything when `n` or `k` is zero.
    /// `_workspace` is unused.
    ///
    /// # Errors
    /// Any error from `can_implement`, or whatever `map_status` derives from the
    /// launcher's status code.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: Nf4DequantizeArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        if self.desc.n == 0 || self.desc.k == 0 {
            return Ok(());
        }
        let w_ptr = args.weight.data.as_raw().0 as *const c_void;
        let amax_ptr = args.absmax.data.as_raw().0 as *const c_void;
        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;
        // SAFETY: the pointers come from the caller's tensor views, whose shapes and
        // strides `can_implement` has just checked against the dimensions passed here.
        let status = unsafe {
            dispatch_dequant::<T>(
                self.desc.n,
                self.desc.k,
                self.desc.block_size,
                w_ptr,
                amax_ptr,
                out_ptr,
                stream_ptr,
            )
        };
        map_status(status)
    }
}
/// Plan for the single-vector NF4 GEMV: a `[K]` activation in, an `[N]` result out.
pub struct Nf4MmvqPlan<T: Nf4Activation> {
/// Descriptor copied in when the plan was selected.
desc: Nf4Descriptor,
/// Kernel identity and precision metadata produced by `build_sku`.
sku: KernelSku,
/// Ties the plan to its element type without storing a `T`.
_phantom: PhantomData<T>,
}
impl<T: Nf4Activation> Nf4MmvqPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Nf4Descriptor,
_pref: PlanPreference,
) -> Result<Self> {
validate_desc(desc, "Nf4MmvqPlan")?;
if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
return Err(Error::Unsupported(
"Nf4MmvqPlan: activation dtype must be f16 or bf16",
));
}
Ok(Self {
desc: *desc,
sku: build_sku(T::KIND, QuantizeKind::GgufMmvq),
_phantom: PhantomData,
})
}
pub fn can_implement(&self, args: &Nf4MmvqArgs<'_, T>) -> Result<()> {
let n = self.desc.n;
let k = self.desc.k;
let bs = self.desc.block_size;
let expected_packed_bytes = ((n / 2) as i64) * (k as i64);
if (args.weight.shape[0] as i64) != expected_packed_bytes {
return Err(Error::InvalidProblem(
"Nf4MmvqPlan: weight bytes != (N/2) * K",
));
}
let expected_absmax = (n as i64) * ((k / bs) as i64);
if (args.absmax.shape[0] as i64) != expected_absmax {
return Err(Error::InvalidProblem(
"Nf4MmvqPlan: absmax length != N * (K / block_size)",
));
}
if args.activation.shape != [k] {
return Err(Error::InvalidProblem(
"Nf4MmvqPlan: activation shape != [K]",
));
}
if args.output.shape != [n] {
return Err(Error::InvalidProblem("Nf4MmvqPlan: output shape != [N]"));
}
if args.activation.stride[0] != 1 {
return Err(Error::InvalidProblem(
"Nf4MmvqPlan: activation must be contig",
));
}
if args.output.stride[0] != 1 {
return Err(Error::InvalidProblem(
"Nf4MmvqPlan: output must be contig",
));
}
Ok(())
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Nf4MmvqArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.n == 0 || self.desc.k == 0 {
return Ok(());
}
let w_ptr = args.weight.data.as_raw().0 as *const c_void;
let amax_ptr = args.absmax.data.as_raw().0 as *const c_void;
let y_ptr = args.activation.data.as_raw().0 as *const c_void;
let out_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe {
dispatch_gemv_m1::<T>(
self.desc.n,
self.desc.k,
self.desc.block_size,
w_ptr,
amax_ptr,
y_ptr,
out_ptr,
stream_ptr,
)
};
map_status(status)
}
}
/// Plan for the multi-row NF4 GEMV: `[M, K]` activations in, `[M, N]` results out.
pub struct Nf4MmvqMultiMPlan<T: Nf4Activation> {
/// Descriptor (base problem plus `m`) copied in when the plan was selected.
desc: Nf4MmvqMultiMDescriptor,
/// Kernel identity and precision metadata produced by `build_sku`.
sku: KernelSku,
/// Ties the plan to its element type without storing a `T`.
_phantom: PhantomData<T>,
}
impl<T: Nf4Activation> Nf4MmvqMultiMPlan<T> {
    /// Validates `desc` and creates a multi-row NF4 GEMV plan for activation dtype `T`.
    ///
    /// `_stream` and `_pref` are accepted but not consulted.
    ///
    /// # Errors
    /// * whatever `validate_desc` reports for the base descriptor;
    /// * `Error::Unsupported` when `T` is neither `f16` nor `bf16`;
    /// * `Error::Unsupported` when `desc.m` is not one of 1, 2, 4 or 8.
    pub fn select(
        _stream: &Stream,
        desc: &Nf4MmvqMultiMDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        validate_desc(&desc.base, "Nf4MmvqMultiMPlan")?;
        if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
            return Err(Error::Unsupported(
                "Nf4MmvqMultiMPlan: activation dtype must be f16 or bf16",
            ));
        }
        // Only these row counts have a dedicated launcher in `dispatch_gemv_multim`.
        if !matches!(desc.m, 1 | 2 | 4 | 8) {
            return Err(Error::Unsupported(
                "Nf4MmvqMultiMPlan: M must be one of {1, 2, 4, 8}",
            ));
        }
        Ok(Self {
            desc: *desc,
            sku: build_sku(T::KIND, QuantizeKind::GgufMmvq),
            _phantom: PhantomData,
        })
    }

    /// Checks `args` against the descriptor captured at `select` time.
    ///
    /// # Errors
    /// Returns `Error::InvalidProblem` when
    /// * `weight` does not hold `(N / 2) * K` bytes,
    /// * `absmax` does not hold `N * (K / block_size)` scales,
    /// * `activations` is not shaped `[M, K]` or `output` is not shaped `[M, N]`,
    /// * either 2-D tensor is not unit-stride along its last dimension, or
    /// * with more than one row, either 2-D tensor's rows are not packed back to back.
    pub fn can_implement(&self, args: &Nf4MmvqMultiMArgs<'_, T>) -> Result<()> {
        let n = self.desc.base.n;
        let k = self.desc.base.k;
        let bs = self.desc.base.block_size;
        let m = self.desc.m;
        // Two rows share each byte, hence N / 2 packed rows of K bytes.
        let expected_packed_bytes = ((n / 2) as i64) * (k as i64);
        if (args.weight.shape[0] as i64) != expected_packed_bytes {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: weight bytes != (N/2) * K",
            ));
        }
        let expected_absmax = (n as i64) * ((k / bs) as i64);
        if (args.absmax.shape[0] as i64) != expected_absmax {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: absmax length != N * (K / block_size)",
            ));
        }
        if args.activations.shape != [m, k] {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: activations shape != [M, K]",
            ));
        }
        if args.output.shape != [m, n] {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: output shape != [M, N]",
            ));
        }
        if args.activations.stride[1] != 1 {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: activations must be contig along K",
            ));
        }
        if args.output.stride[1] != 1 {
            return Err(Error::InvalidProblem(
                "Nf4MmvqMultiMPlan: output must be contig along N",
            ));
        }
        // The FFI launchers receive only base pointers plus N/K/block_size, so a
        // padded row stride cannot be honoured. With a single row the row stride is
        // never used, and empty problems never launch, so both are left unconstrained.
        if m > 1 && n > 0 && k > 0 {
            if (args.activations.stride[0] as i64) != (k as i64) {
                return Err(Error::InvalidProblem(
                    "Nf4MmvqMultiMPlan: activations row stride must equal K",
                ));
            }
            if (args.output.stride[0] as i64) != (n as i64) {
                return Err(Error::InvalidProblem(
                    "Nf4MmvqMultiMPlan: output row stride must equal N",
                ));
            }
        }
        Ok(())
    }

    /// Returns the kernel SKU chosen for this plan.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Scratch bytes required by `run`; this plan needs none.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Re-validates `args` and forwards the problem to the launcher for this dtype and `m`.
    ///
    /// Returns `Ok(())` without launching anything when `n`, `k` or `m` is zero.
    /// `_workspace` is unused.
    ///
    /// # Errors
    /// Any error from `can_implement`, or whatever `map_status` derives from the
    /// launcher's status code.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: Nf4MmvqMultiMArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        if self.desc.base.n == 0 || self.desc.base.k == 0 || self.desc.m == 0 {
            return Ok(());
        }
        let w_ptr = args.weight.data.as_raw().0 as *const c_void;
        let amax_ptr = args.absmax.data.as_raw().0 as *const c_void;
        let y_ptr = args.activations.data.as_raw().0 as *const c_void;
        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;
        // SAFETY: the pointers come from the caller's tensor views, whose shapes and
        // strides `can_implement` has just checked against the dimensions passed here.
        let status = unsafe {
            dispatch_gemv_multim::<T>(
                self.desc.m,
                self.desc.base.n,
                self.desc.base.k,
                self.desc.base.block_size,
                w_ptr,
                amax_ptr,
                y_ptr,
                out_ptr,
                stream_ptr,
            )
        };
        map_status(status)
    }
}
fn validate_desc(desc: &Nf4Descriptor, plan_name: &'static str) -> Result<()> {
if desc.n < 0 || desc.k < 0 || desc.block_size <= 0 {
return Err(Error::InvalidProblem(
match plan_name {
"Nf4DequantizePlan" => "Nf4DequantizePlan: invalid dims",
"Nf4MmvqPlan" => "Nf4MmvqPlan: invalid dims",
_ => "Nf4 plan: invalid dims",
},
));
}
if (desc.n & 1) != 0 {
return Err(Error::InvalidProblem(
"Nf4 plan: N must be even (pair-packed nibbles)",
));
}
if desc.k % desc.block_size != 0 {
return Err(Error::InvalidProblem(
"Nf4 plan: K must be a multiple of block_size",
));
}
Ok(())
}
/// Assembles the `KernelSku` shared by the NF4 plans.
///
/// Only the activation element kind and the operation id vary; every other
/// field is fixed: quantization category, `U8` auxiliary element, `Sm80`
/// architecture, bespoke backend, and an F32 math/accumulator guarantee flagged
/// as deterministic and bit-stable on the same hardware.
fn build_sku(act_kind: ElementKind, op: QuantizeKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };
    let op_id = op as u16;

    KernelSku {
        category: OpCategory::Quantization,
        op: op_id,
        element: act_kind,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}
/// Routes a dequantize request to the FFI launcher matching `T::KIND`.
///
/// Returns the launcher's status code, or `3` for element kinds that have no
/// launcher (presumably "unsupported" in `map_status`'s encoding; confirm there).
///
/// # Safety
/// The pointers are handed to the FFI launcher unchanged. The caller must make
/// sure they satisfy whatever that launcher requires for the given `n`, `k` and
/// `block_size`, and that `stream` is a valid stream handle.
#[cfg(feature = "bnb_nf4")]
#[inline]
unsafe fn dispatch_dequant<T: Nf4Activation>(
    n: i32,
    k: i32,
    block_size: i32,
    w_ptr: *const c_void,
    absmax: *const c_void,
    out: *mut c_void,
    stream: *mut c_void,
) -> i32 {
    use baracuda_kernels_sys as sys;
    // SAFETY: the arguments are forwarded verbatim; the caller upholds the
    // launcher's pointer requirements per this function's contract.
    unsafe {
        match T::KIND {
            ElementKind::F16 => sys::baracuda_kernels_nf4_dequantize_f16_run(
                n, k, block_size, w_ptr, absmax, out, stream,
            ),
            ElementKind::Bf16 => sys::baracuda_kernels_nf4_dequantize_bf16_run(
                n, k, block_size, w_ptr, absmax, out, stream,
            ),
            ElementKind::F32 => sys::baracuda_kernels_nf4_dequantize_f32_run(
                n, k, block_size, w_ptr, absmax, out, stream,
            ),
            _ => 3,
        }
    }
}
/// Stand-in used when the `bnb_nf4` feature is disabled.
///
/// Ignores every argument and reports status `3` (presumably "unsupported" in
/// `map_status`'s encoding; confirm there).
///
/// # Safety
/// No pointer is dereferenced; the function is `unsafe` only so that its
/// signature matches the feature-enabled variant.
#[cfg(not(feature = "bnb_nf4"))]
#[inline]
unsafe fn dispatch_dequant<T: Nf4Activation>(
    _n: i32,
    _k: i32,
    _block_size: i32,
    _w_ptr: *const c_void,
    _absmax: *const c_void,
    _out: *mut c_void,
    _stream: *mut c_void,
) -> i32 {
    3
}
/// Routes a single-vector GEMV request to the FFI launcher matching `T::KIND`.
///
/// Only `F16` and `Bf16` have launchers; any other kind yields status `3`
/// (presumably "unsupported" in `map_status`'s encoding; confirm there).
///
/// # Safety
/// The pointers are handed to the FFI launcher unchanged. The caller must make
/// sure they satisfy whatever that launcher requires for the given `n`, `k` and
/// `block_size`, and that `stream` is a valid stream handle.
#[cfg(feature = "bnb_nf4")]
#[inline]
unsafe fn dispatch_gemv_m1<T: Nf4Activation>(
    n: i32,
    k: i32,
    block_size: i32,
    w_ptr: *const c_void,
    absmax: *const c_void,
    y: *const c_void,
    out: *mut c_void,
    stream: *mut c_void,
) -> i32 {
    use baracuda_kernels_sys as sys;
    // SAFETY: the arguments are forwarded verbatim; the caller upholds the
    // launcher's pointer requirements per this function's contract.
    unsafe {
        match T::KIND {
            ElementKind::F16 => sys::baracuda_kernels_nf4_gemv_m1_f16_run(
                n, k, block_size, w_ptr, absmax, y, out, stream,
            ),
            ElementKind::Bf16 => sys::baracuda_kernels_nf4_gemv_m1_bf16_run(
                n, k, block_size, w_ptr, absmax, y, out, stream,
            ),
            _ => 3,
        }
    }
}
/// Stand-in used when the `bnb_nf4` feature is disabled.
///
/// Ignores every argument and reports status `3` (presumably "unsupported" in
/// `map_status`'s encoding; confirm there).
///
/// # Safety
/// No pointer is dereferenced; the function is `unsafe` only so that its
/// signature matches the feature-enabled variant.
#[cfg(not(feature = "bnb_nf4"))]
#[inline]
unsafe fn dispatch_gemv_m1<T: Nf4Activation>(
    _n: i32,
    _k: i32,
    _block_size: i32,
    _w_ptr: *const c_void,
    _absmax: *const c_void,
    _y: *const c_void,
    _out: *mut c_void,
    _stream: *mut c_void,
) -> i32 {
    3
}
/// Routes a multi-row GEMV request to the FFI launcher for `T::KIND` and row count `m`.
///
/// Launchers exist for `F16`/`Bf16` with `m` in {1, 2, 4, 8}; every other
/// combination yields status `3` (presumably "unsupported" in `map_status`'s
/// encoding; confirm there).
///
/// # Safety
/// The pointers are handed to the FFI launcher unchanged. The caller must make
/// sure they satisfy whatever that launcher requires for the given `m`, `n`, `k`
/// and `block_size`, and that `stream` is a valid stream handle.
#[cfg(feature = "bnb_nf4")]
#[inline]
unsafe fn dispatch_gemv_multim<T: Nf4Activation>(
    m: i32,
    n: i32,
    k: i32,
    block_size: i32,
    w_ptr: *const c_void,
    absmax: *const c_void,
    y: *const c_void,
    out: *mut c_void,
    stream: *mut c_void,
) -> i32 {
    use baracuda_kernels_sys as sys;
    // SAFETY (all arms): the arguments are forwarded verbatim; the caller upholds
    // the launcher's pointer requirements per this function's contract.
    match T::KIND {
        ElementKind::F16 => match m {
            1 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m1_f16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            2 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m2_f16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            4 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m4_f16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            8 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m8_f16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            _ => 3,
        },
        ElementKind::Bf16 => match m {
            1 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m1_bf16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            2 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m2_bf16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            4 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m4_bf16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            8 => unsafe {
                sys::baracuda_kernels_nf4_gemv_m8_bf16_run(n, k, block_size, w_ptr, absmax, y, out, stream)
            },
            _ => 3,
        },
        _ => 3,
    }
}
/// Stand-in used when the `bnb_nf4` feature is disabled.
///
/// Ignores every argument and reports status `3` (presumably "unsupported" in
/// `map_status`'s encoding; confirm there).
///
/// # Safety
/// No pointer is dereferenced; the function is `unsafe` only so that its
/// signature matches the feature-enabled variant.
#[cfg(not(feature = "bnb_nf4"))]
#[inline]
unsafe fn dispatch_gemv_multim<T: Nf4Activation>(
    _m: i32,
    _n: i32,
    _k: i32,
    _block_size: i32,
    _w_ptr: *const c_void,
    _absmax: *const c_void,
    _y: *const c_void,
    _out: *mut c_void,
    _stream: *mut c_void,
) -> i32 {
    3
}
/// The 16 NF4 code values, in ascending order from -1.0 to 1.0.
///
/// Entry `i` is the normalized value (before scaling by a block's `absmax`)
/// that 4-bit code `i` stands for; code 7 is exactly 0.0.
pub const NF4_CODEBOOK: [f32; 16] = [
    -1.0,
    -0.6961928009986877,
    -0.5250730514526367,
    -0.39491748809814453,
    -0.28444138169288635,
    -0.18477343022823334,
    -0.09105003625154495,
    0.0,
    0.07958029955625534,
    0.16093020141124725,
    0.24611230194568634,
    0.33791524171829224,
    0.44070982933044434,
    0.5626170039176941,
    0.7229568362236023,
    1.0,
];

/// Returns the 4-bit NF4 code whose codebook entry is nearest to `x / absmax`.
///
/// * `absmax == 0.0` yields code 7 (the zero entry), whatever `x` is.
/// * The normalized value is clamped to `[-1, 1]` before the search, so
///   out-of-range inputs saturate to code 0 or 15. Without the clamp, a very
///   large or infinite ratio makes every distance round to the same value and
///   the search would stay on code 0 even for positive inputs.
/// * Ties go to the lower code, because only a strictly smaller distance
///   replaces the current best.
/// * A NaN ratio compares false against every distance and yields code 0.
pub fn nf4_quantize_value(x: f32, absmax: f32) -> u8 {
    if absmax == 0.0 {
        return 7;
    }
    // Saturate first: for finite in-range ratios this is a no-op, and NaN stays NaN.
    let xn = (x / absmax).clamp(-1.0, 1.0);
    let mut best_code = 0u8;
    let mut best_dist = (xn - NF4_CODEBOOK[0]).abs();
    for (code, &level) in (1u8..).zip(&NF4_CODEBOOK[1..]) {
        let dist = (xn - level).abs();
        if dist < best_dist {
            best_dist = dist;
            best_code = code;
        }
    }
    best_code
}
/// Host-side reference packer: quantizes a row-major `[n, k]` f32 weight to NF4.
///
/// Returns `(packed, absmax)`:
/// * `absmax[row * (k / block_size) + b]` is the largest absolute value in block
///   `b` of `row` (NaN entries are ignored when taking the maximum);
/// * `packed[(row / 2) * k + col]` holds the code of element `(row, col)` in its
///   low nibble when `row` is even and in its high nibble when `row` is odd.
///
/// # Panics
/// Panics if `block_size` is zero, `n` is odd, `k` is not a multiple of
/// `block_size`, or `weight_fp.len() != n * k`.
pub fn nf4_pack_weight(
    weight_fp: &[f32],
    n: usize,
    k: usize,
    block_size: usize,
) -> (alloc::vec::Vec<u8>, alloc::vec::Vec<f32>) {
    use alloc::vec;
    use alloc::vec::Vec;
    // State the precondition explicitly rather than tripping a bare
    // remainder-by-zero panic in the divisibility check below.
    assert!(block_size > 0, "block_size must be non-zero");
    assert!(n % 2 == 0, "N must be even");
    assert!(k % block_size == 0, "K must be a multiple of block_size");
    assert_eq!(weight_fp.len(), n * k);

    let blocks_per_row = k / block_size;
    let mut absmax: Vec<f32> = vec![0.0; n * blocks_per_row];
    let mut packed: Vec<u8> = vec![0; (n / 2) * k];
    if k == 0 {
        // Nothing to pack, and `chunks_exact(0)` would panic.
        return (packed, absmax);
    }

    for (row, row_vals) in weight_fp.chunks_exact(k).enumerate() {
        // Even rows fill the low nibble, odd rows the high nibble of the same bytes.
        let shift = (row & 1) * 4;
        let byte_row = &mut packed[(row / 2) * k..][..k];
        let scale_row = &mut absmax[row * blocks_per_row..][..blocks_per_row];
        let blocks = row_vals
            .chunks_exact(block_size)
            .zip(byte_row.chunks_exact_mut(block_size));
        for ((vals, bytes), scale) in blocks.zip(scale_row.iter_mut()) {
            // `f32::max` returns the non-NaN operand, so NaNs never become the scale.
            let block_max = vals.iter().map(|v| v.abs()).fold(0.0f32, f32::max);
            *scale = block_max;
            for (byte, &v) in bytes.iter_mut().zip(vals) {
                // The buffer starts zeroed and each nibble is written exactly once,
                // so OR-ing the code in is equivalent to mask-and-set.
                *byte |= (nf4_quantize_value(v, block_max) & 0x0F) << shift;
            }
        }
    }
    (packed, absmax)
}
extern crate alloc;