use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
use oxicuda_ptx::prelude::*;
use crate::error::DnnError;
const QUANT_BLOCK_SIZE: u32 = 256;
const MIN_DAMP: f32 = 1e-6;
#[derive(Debug, Clone)]
pub enum WeightQuantMethod {
Gptq(GptqConfig),
Awq(AwqConfig),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantMethodTag {
Gptq,
Awq,
}
#[derive(Debug, Clone)]
pub struct GptqConfig {
pub bits: u32,
pub group_size: usize,
pub block_size: usize,
pub damp_percent: f32,
pub symmetric: bool,
pub act_order: bool,
pub true_sequential: bool,
}
impl Default for GptqConfig {
fn default() -> Self {
Self {
bits: 4,
group_size: 128,
block_size: 128,
damp_percent: 0.01,
symmetric: true,
act_order: false,
true_sequential: true,
}
}
}
impl GptqConfig {
pub fn validate(&self) -> Result<(), DnnError> {
if !matches!(self.bits, 2 | 3 | 4 | 8) {
return Err(DnnError::InvalidArgument(format!(
"GPTQ bits must be 2, 3, 4, or 8; got {}",
self.bits
)));
}
if self.group_size == 0 {
return Err(DnnError::InvalidArgument(
"GPTQ group_size must be non-zero".into(),
));
}
if self.block_size == 0 {
return Err(DnnError::InvalidArgument(
"GPTQ block_size must be non-zero".into(),
));
}
if self.damp_percent < 0.0 {
return Err(DnnError::InvalidArgument(format!(
"GPTQ damp_percent must be non-negative; got {}",
self.damp_percent
)));
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct AwqConfig {
pub bits: u32,
pub group_size: usize,
pub zero_point: bool,
pub search_alpha_min: f32,
pub search_alpha_max: f32,
pub search_steps: usize,
}
impl Default for AwqConfig {
fn default() -> Self {
Self {
bits: 4,
group_size: 128,
zero_point: true,
search_alpha_min: 0.0,
search_alpha_max: 1.0,
search_steps: 20,
}
}
}
impl AwqConfig {
pub fn validate(&self) -> Result<(), DnnError> {
if !matches!(self.bits, 4 | 8) {
return Err(DnnError::InvalidArgument(format!(
"AWQ bits must be 4 or 8; got {}",
self.bits
)));
}
if self.group_size == 0 {
return Err(DnnError::InvalidArgument(
"AWQ group_size must be non-zero".into(),
));
}
if self.search_alpha_min > self.search_alpha_max {
return Err(DnnError::InvalidArgument(format!(
"AWQ search_alpha_min ({}) must be <= search_alpha_max ({})",
self.search_alpha_min, self.search_alpha_max
)));
}
if self.search_steps == 0 {
return Err(DnnError::InvalidArgument(
"AWQ search_steps must be non-zero".into(),
));
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct QuantizedWeight {
pub method: QuantMethodTag,
pub bits: u32,
pub group_size: usize,
pub rows: usize,
pub cols: usize,
pub packed_weight_elements: usize,
pub scale_elements: usize,
pub zero_point_elements: usize,
pub has_zero_point: bool,
}
#[derive(Debug)]
pub struct WeightQuantPlan {
method: WeightQuantMethod,
weight_rows: usize,
weight_cols: usize,
}
impl WeightQuantPlan {
pub fn new(method: WeightQuantMethod, rows: usize, cols: usize) -> Result<Self, DnnError> {
if rows == 0 || cols == 0 {
return Err(DnnError::InvalidDimension(
"weight dimensions must be non-zero".into(),
));
}
let group_size = match &method {
WeightQuantMethod::Gptq(cfg) => {
cfg.validate()?;
cfg.group_size
}
WeightQuantMethod::Awq(cfg) => {
cfg.validate()?;
cfg.group_size
}
};
if cols % group_size != 0 {
return Err(DnnError::InvalidArgument(format!(
"cols ({cols}) must be divisible by group_size ({group_size})"
)));
}
Ok(Self {
method,
weight_rows: rows,
weight_cols: cols,
})
}
fn bits(&self) -> u32 {
match &self.method {
WeightQuantMethod::Gptq(cfg) => cfg.bits,
WeightQuantMethod::Awq(cfg) => cfg.bits,
}
}
fn group_size(&self) -> usize {
match &self.method {
WeightQuantMethod::Gptq(cfg) => cfg.group_size,
WeightQuantMethod::Awq(cfg) => cfg.group_size,
}
}
pub fn elements_per_u32(&self) -> u32 {
32 / self.bits()
}
fn groups_per_row(&self) -> usize {
self.weight_cols / self.group_size()
}
pub fn generate_gptq_quantize_ptx(&self) -> Result<String, DnnError> {
let cfg = match &self.method {
WeightQuantMethod::Gptq(c) => c,
WeightQuantMethod::Awq(_) => {
return Err(DnnError::InvalidArgument(
"generate_gptq_quantize_ptx called on AWQ plan".into(),
));
}
};
let bits = cfg.bits;
let symmetric = cfg.symmetric;
let qmax = (1u32 << bits) - 1;
let qmax_f = qmax as f64;
let half_range = if symmetric {
(1i64 << (bits - 1)) as f64
} else {
0.0
};
let kernel_name = "dnn_gptq_quantize_f32";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm75)
.max_threads_per_block(QUANT_BLOCK_SIZE)
.param("weight_ptr", PtxType::U64)
.param("hessian_diag_ptr", PtxType::U64)
.param("out_packed_ptr", PtxType::U64)
.param("scale_ptr", PtxType::U64)
.param("zero_ptr", PtxType::U64)
.param("rows", PtxType::U32)
.param("cols", PtxType::U32)
.param("group_size", PtxType::U32)
.param("block_size", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let num_rows = b.load_param_u32("rows");
b.if_lt_u32(tid.clone(), num_rows, move |b| {
let weight_ptr = b.load_param_u64("weight_ptr");
let hess_ptr = b.load_param_u64("hessian_diag_ptr");
let out_ptr = b.load_param_u64("out_packed_ptr");
let scale_ptr = b.load_param_u64("scale_ptr");
let cols_reg = b.load_param_u32("cols");
let group_size_reg = b.load_param_u32("group_size");
let row_off = b.mul_lo_u32(tid.clone(), cols_reg.clone());
let col = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {col}, 0;"));
let loop_lbl = b.fresh_label("gptq_col");
let end_lbl = b.fresh_label("gptq_end");
b.label(&loop_lbl);
let p_done = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {p_done}, {col}, {cols_reg};"));
b.branch_if(p_done, &end_lbl);
let elem_idx = b.add_u32(row_off.clone(), col.clone());
let w_addr = b.byte_offset_addr(weight_ptr.clone(), elem_idx.clone(), 4u32);
let w_val = b.load_global_f32(w_addr);
let h_addr = b.byte_offset_addr(hess_ptr.clone(), col.clone(), 4u32);
let h_diag = b.load_global_f32(h_addr);
let group_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {group_idx}, {col}, {group_size_reg};"));
let gpr = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {gpr}, {cols_reg}, {group_size_reg};"));
let scale_row_off = b.mul_lo_u32(tid.clone(), gpr);
let scale_idx = b.add_u32(scale_row_off, group_idx);
let s_addr = b.byte_offset_addr(scale_ptr.clone(), scale_idx, 4u32);
let damp_val = b.alloc_reg(PtxType::F32);
let damp_imm = MIN_DAMP;
b.raw_ptx(&format!(
"mov.f32 {damp_val}, 0f{:08X};",
damp_imm.to_bits()
));
let safe_h = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("max.f32 {safe_h}, {h_diag}, {damp_val};"));
let scaled_w = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("div.rn.f32 {scaled_w}, {w_val}, {safe_h};"));
let rounded = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rni.f32.f32 {rounded}, {scaled_w};"));
let zero_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {zero_f}, 0f00000000;"));
let qmax_f_reg = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"mov.f32 {qmax_f_reg}, 0f{:08X};",
(qmax_f as f32).to_bits()
));
if symmetric {
let half_reg = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"mov.f32 {half_reg}, 0f{:08X};",
(half_range as f32).to_bits()
));
let shifted = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("add.rn.f32 {shifted}, {rounded}, {half_reg};"));
b.raw_ptx(&format!("max.f32 {rounded}, {shifted}, {zero_f};"));
}
let clamped = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("min.f32 {clamped}, {rounded}, {qmax_f_reg};"));
b.raw_ptx(&format!("max.f32 {clamped}, {clamped}, {zero_f};"));
let q_int = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("cvt.rzi.u32.f32 {q_int}, {clamped};"));
let scale_val = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("div.rn.f32 {scale_val}, {safe_h}, {qmax_f_reg};"));
b.store_global_f32(s_addr, scale_val);
let epu32 = 32 / bits;
let pack_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {pack_idx}, {elem_idx}, {epu32};"));
let out_addr = b.byte_offset_addr(out_ptr.clone(), pack_idx, 4u32);
let sub_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {sub_idx}, {elem_idx}, {epu32};"));
let bit_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {bit_off}, {sub_idx}, {bits};"));
let shifted_q = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {shifted_q}, {q_int}, {bit_off};"));
b.raw_ptx(&format!(
"atom.global.or.b32 {q_int}, [{out_addr}], {shifted_q};"
));
b.raw_ptx(&format!("add.u32 {col}, {col}, 1;"));
b.branch(&loop_lbl);
b.label(&end_lbl);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(format!("GPTQ quantize: {e}")))?;
Ok(ptx)
}
pub fn generate_gptq_dequantize_ptx(&self) -> Result<String, DnnError> {
let bits = self.bits();
let mask = (1u32 << bits) - 1;
let epu32 = 32 / bits;
let symmetric = match &self.method {
WeightQuantMethod::Gptq(c) => c.symmetric,
WeightQuantMethod::Awq(_) => {
return Err(DnnError::InvalidArgument(
"generate_gptq_dequantize_ptx called on AWQ plan".into(),
));
}
};
let kernel_name = "dnn_gptq_dequantize_f32";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm75)
.max_threads_per_block(QUANT_BLOCK_SIZE)
.param("packed_ptr", PtxType::U64)
.param("scale_ptr", PtxType::U64)
.param("zero_ptr", PtxType::U64)
.param("out_ptr", PtxType::U64)
.param("total_elements", PtxType::U32)
.param("cols", PtxType::U32)
.param("group_size", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let total = b.load_param_u32("total_elements");
b.if_lt_u32(tid.clone(), total, move |b| {
let packed_ptr = b.load_param_u64("packed_ptr");
let scale_ptr = b.load_param_u64("scale_ptr");
let out_ptr = b.load_param_u64("out_ptr");
let cols_reg = b.load_param_u32("cols");
let gs_reg = b.load_param_u32("group_size");
let pack_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {pack_idx}, {tid}, {epu32};"));
let pack_addr = b.byte_offset_addr(packed_ptr, pack_idx, 4u32);
let packed_word = b.load_global_u32(pack_addr);
let sub_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {sub_idx}, {tid}, {epu32};"));
let bit_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {bit_off}, {sub_idx}, {bits};"));
let shifted = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shr.b32 {shifted}, {packed_word}, {bit_off};"));
let q_val = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("and.b32 {q_val}, {shifted}, {mask};"));
let q_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {q_f}, {q_val};"));
let col_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {col_idx}, {tid}, {cols_reg};"));
let group_in_row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {group_in_row}, {col_idx}, {gs_reg};"));
let row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {row}, {tid}, {cols_reg};"));
let gpr = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {gpr}, {cols_reg}, {gs_reg};"));
let s_off = b.mul_lo_u32(row, gpr);
let s_idx = b.add_u32(s_off, group_in_row);
let s_idx_clone = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {s_idx_clone}, {s_idx};"));
let s_addr = b.byte_offset_addr(scale_ptr, s_idx, 4u32);
let scale = b.load_global_f32(s_addr);
let result = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {result}, {q_f}, {scale};"));
if !symmetric {
let zero_ptr_r = b.load_param_u64("zero_ptr");
let z_addr = b.byte_offset_addr(zero_ptr_r, s_idx_clone, 4u32);
let zero = b.load_global_f32(z_addr);
let shifted_q = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("sub.rn.f32 {shifted_q}, {q_f}, {zero};"));
b.raw_ptx(&format!("mul.rn.f32 {result}, {shifted_q}, {scale};"));
}
let out_addr = b.byte_offset_addr(out_ptr, tid, 4u32);
b.store_global_f32(out_addr, result);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(format!("GPTQ dequantize: {e}")))?;
Ok(ptx)
}
pub fn generate_awq_scale_search_ptx(&self) -> Result<String, DnnError> {
let cfg = match &self.method {
WeightQuantMethod::Awq(c) => c,
WeightQuantMethod::Gptq(_) => {
return Err(DnnError::InvalidArgument(
"generate_awq_scale_search_ptx called on GPTQ plan".into(),
));
}
};
let alpha_min_bits = cfg.search_alpha_min.to_bits();
let alpha_max_bits = cfg.search_alpha_max.to_bits();
let steps = cfg.search_steps as u32;
let kernel_name = "dnn_awq_scale_search_f32";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm75)
.max_threads_per_block(QUANT_BLOCK_SIZE)
.param("act_stats_ptr", PtxType::U64)
.param("weight_ptr", PtxType::U64)
.param("scale_out_ptr", PtxType::U64)
.param("best_alpha_ptr", PtxType::U64)
.param("num_channels", PtxType::U32)
.param("cols", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let nch = b.load_param_u32("num_channels");
b.if_lt_u32(tid.clone(), nch, move |b| {
let act_ptr = b.load_param_u64("act_stats_ptr");
let scale_out = b.load_param_u64("scale_out_ptr");
let a_addr = b.byte_offset_addr(act_ptr, tid.clone(), 4u32);
let act_val = b.load_global_f32(a_addr);
let alpha_min = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {alpha_min}, 0f{alpha_min_bits:08X};"));
let alpha_max = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {alpha_max}, 0f{alpha_max_bits:08X};"));
let alpha_range = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"sub.rn.f32 {alpha_range}, {alpha_max}, {alpha_min};"
));
let steps_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"mov.f32 {steps_f}, 0f{:08X};",
(steps as f32).to_bits()
));
let step_size = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"div.rn.f32 {step_size}, {alpha_range}, {steps_f};"
));
let abs_act = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("abs.f32 {abs_act}, {act_val};"));
let eps = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {eps}, 0f{:08X};", 1e-7_f32.to_bits()));
let safe_act = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("max.f32 {safe_act}, {abs_act}, {eps};"));
let rsq = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("rsqrt.approx.f32 {rsq}, {safe_act};"));
let scale = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {scale}, {safe_act}, {rsq};"));
let s_addr = b.byte_offset_addr(scale_out, tid, 4u32);
b.store_global_f32(s_addr, scale);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(format!("AWQ scale search: {e}")))?;
Ok(ptx)
}
pub fn generate_awq_quantize_ptx(&self) -> Result<String, DnnError> {
let cfg = match &self.method {
WeightQuantMethod::Awq(c) => c,
WeightQuantMethod::Gptq(_) => {
return Err(DnnError::InvalidArgument(
"generate_awq_quantize_ptx called on GPTQ plan".into(),
));
}
};
let bits = cfg.bits;
let qmax = (1u32 << bits) - 1;
let epu32 = 32 / bits;
let has_zp = cfg.zero_point;
let kernel_name = "dnn_awq_quantize_f32";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm75)
.max_threads_per_block(QUANT_BLOCK_SIZE)
.param("weight_ptr", PtxType::U64)
.param("channel_scales_ptr", PtxType::U64)
.param("out_packed_ptr", PtxType::U64)
.param("scale_ptr", PtxType::U64)
.param("zero_ptr", PtxType::U64)
.param("total_elements", PtxType::U32)
.param("cols", PtxType::U32)
.param("group_size", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let total = b.load_param_u32("total_elements");
b.if_lt_u32(tid.clone(), total, move |b| {
let weight_ptr = b.load_param_u64("weight_ptr");
let ch_scale_ptr = b.load_param_u64("channel_scales_ptr");
let out_ptr = b.load_param_u64("out_packed_ptr");
let scale_ptr = b.load_param_u64("scale_ptr");
let cols_reg = b.load_param_u32("cols");
let gs_reg = b.load_param_u32("group_size");
let w_addr = b.byte_offset_addr(weight_ptr, tid.clone(), 4u32);
let w_val = b.load_global_f32(w_addr);
let ch_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {ch_idx}, {tid}, {cols_reg};"));
let ch_idx_copy = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {ch_idx_copy}, {ch_idx};"));
let cs_addr = b.byte_offset_addr(ch_scale_ptr, ch_idx, 4u32);
let ch_scale = b.load_global_f32(cs_addr);
let w_scaled = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {w_scaled}, {w_val}, {ch_scale};"));
let group_in_row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {group_in_row}, {ch_idx_copy}, {gs_reg};"));
let row = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {row}, {tid}, {cols_reg};"));
let gpr = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {gpr}, {cols_reg}, {gs_reg};"));
let s_off = b.mul_lo_u32(row, gpr);
let s_idx = b.add_u32(s_off, group_in_row);
let s_addr = b.byte_offset_addr(scale_ptr.clone(), s_idx.clone(), 4u32);
let g_scale = b.load_global_f32(s_addr);
let eps = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {eps}, 0f{:08X};", 1e-12_f32.to_bits()));
let safe_scale = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("max.f32 {safe_scale}, {g_scale}, {eps};"));
let normed = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("div.rn.f32 {normed}, {w_scaled}, {safe_scale};"));
let rounded = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rni.f32.f32 {rounded}, {normed};"));
let zero_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {zero_f}, 0f00000000;"));
let qmax_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!(
"mov.f32 {qmax_f}, 0f{:08X};",
(qmax as f32).to_bits()
));
let clamped = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("max.f32 {clamped}, {rounded}, {zero_f};"));
b.raw_ptx(&format!("min.f32 {clamped}, {clamped}, {qmax_f};"));
let q_int = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("cvt.rzi.u32.f32 {q_int}, {clamped};"));
let pack_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {pack_idx}, {tid}, {epu32};"));
let out_addr = b.byte_offset_addr(out_ptr, pack_idx, 4u32);
let sub_idx = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("rem.u32 {sub_idx}, {tid}, {epu32};"));
let bit_off = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {bit_off}, {sub_idx}, {bits};"));
let shifted_q = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shl.b32 {shifted_q}, {q_int}, {bit_off};"));
let _dummy = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!(
"atom.global.or.b32 {_dummy}, [{out_addr}], {shifted_q};"
));
if has_zp {
let zero_ptr = b.load_param_u64("zero_ptr");
let z_addr = b.byte_offset_addr(zero_ptr, s_idx, 4u32);
b.store_global_f32(z_addr, zero_f);
}
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(format!("AWQ quantize: {e}")))?;
Ok(ptx)
}
pub fn generate_fused_dequant_gemv_ptx(&self) -> Result<String, DnnError> {
let bits = self.bits();
let mask = (1u32 << bits) - 1;
let epu32 = 32 / bits;
let kernel_name = "dnn_fused_dequant_gemv_f32";
let ptx = KernelBuilder::new(kernel_name)
.target(SmVersion::Sm75)
.max_threads_per_block(QUANT_BLOCK_SIZE)
.param("packed_w_ptr", PtxType::U64)
.param("scale_ptr", PtxType::U64)
.param("zero_ptr", PtxType::U64)
.param("x_ptr", PtxType::U64)
.param("y_ptr", PtxType::U64)
.param("rows", PtxType::U32)
.param("cols", PtxType::U32)
.param("group_size", PtxType::U32)
.body(move |b| {
let tid = b.global_thread_id_x();
let num_rows = b.load_param_u32("rows");
b.if_lt_u32(tid.clone(), num_rows, move |b| {
let packed_ptr = b.load_param_u64("packed_w_ptr");
let scale_ptr = b.load_param_u64("scale_ptr");
let x_ptr = b.load_param_u64("x_ptr");
let y_ptr = b.load_param_u64("y_ptr");
let cols_reg = b.load_param_u32("cols");
let gs_reg = b.load_param_u32("group_size");
let acc = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mov.f32 {acc}, 0f00000000;"));
let packed_cols = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {packed_cols}, {cols_reg}, {epu32};"));
let row_pack_off = b.mul_lo_u32(tid.clone(), packed_cols.clone());
let gpr = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {gpr}, {cols_reg}, {gs_reg};"));
let scale_row_off = b.mul_lo_u32(tid.clone(), gpr.clone());
let pw = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mov.u32 {pw}, 0;"));
let loop_lbl = b.fresh_label("gemv_pw");
let end_lbl = b.fresh_label("gemv_end");
b.label(&loop_lbl);
let p_done = b.alloc_reg(PtxType::Pred);
b.raw_ptx(&format!("setp.ge.u32 {p_done}, {pw}, {packed_cols};"));
b.branch_if(p_done, &end_lbl);
let w_idx = b.add_u32(row_pack_off.clone(), pw.clone());
let w_addr = b.byte_offset_addr(packed_ptr.clone(), w_idx, 4u32);
let packed_word = b.load_global_u32(w_addr);
let col_base = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("mul.lo.u32 {col_base}, {pw}, {epu32};"));
for sub in 0..epu32 {
let bit_shift = sub * bits;
let q_val = b.alloc_reg(PtxType::U32);
if bit_shift == 0 {
b.raw_ptx(&format!("and.b32 {q_val}, {packed_word}, {mask};"));
} else {
let tmp = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("shr.b32 {tmp}, {packed_word}, {bit_shift};"));
b.raw_ptx(&format!("and.b32 {q_val}, {tmp}, {mask};"));
}
let q_f = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("cvt.rn.f32.u32 {q_f}, {q_val};"));
let col = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("add.u32 {col}, {col_base}, {sub};"));
let grp = b.alloc_reg(PtxType::U32);
b.raw_ptx(&format!("div.u32 {grp}, {col}, {gs_reg};"));
let s_idx = b.add_u32(scale_row_off.clone(), grp);
let s_addr = b.byte_offset_addr(scale_ptr.clone(), s_idx, 4u32);
let scale = b.load_global_f32(s_addr);
let dq = b.alloc_reg(PtxType::F32);
b.raw_ptx(&format!("mul.rn.f32 {dq}, {q_f}, {scale};"));
let x_addr = b.byte_offset_addr(x_ptr.clone(), col, 4u32);
let x_val = b.load_global_f32(x_addr);
b.raw_ptx(&format!("fma.rn.f32 {acc}, {dq}, {x_val}, {acc};"));
}
b.raw_ptx(&format!("add.u32 {pw}, {pw}, 1;"));
b.branch(&loop_lbl);
b.label(&end_lbl);
let y_addr = b.byte_offset_addr(y_ptr, tid, 4u32);
b.store_global_f32(y_addr, acc);
});
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(format!("fused dequant GEMV: {e}")))?;
Ok(ptx)
}
pub fn quantized_weight_info(&self) -> QuantizedWeight {
let bits = self.bits();
let group_size = self.group_size();
let total_elements = self.weight_rows * self.weight_cols;
let epu32 = 32 / bits;
let packed_weight_elements = total_elements.div_ceil(epu32 as usize);
let groups_per_row = self.groups_per_row();
let scale_elements = self.weight_rows * groups_per_row;
let (has_zero_point, method_tag) = match &self.method {
WeightQuantMethod::Gptq(c) => (!c.symmetric, QuantMethodTag::Gptq),
WeightQuantMethod::Awq(c) => (c.zero_point, QuantMethodTag::Awq),
};
let zero_point_elements = if has_zero_point { scale_elements } else { 0 };
QuantizedWeight {
method: method_tag,
bits,
group_size,
rows: self.weight_rows,
cols: self.weight_cols,
packed_weight_elements,
scale_elements,
zero_point_elements,
has_zero_point,
}
}
pub fn workspace_bytes(&self) -> usize {
match &self.method {
WeightQuantMethod::Gptq(cfg) => {
let hess_bytes = self.weight_cols * 4;
let reorder_bytes = if cfg.act_order {
self.weight_cols * 4 } else {
0
};
let residual_bytes = self.weight_rows * cfg.block_size * 4;
hess_bytes + reorder_bytes + residual_bytes
}
WeightQuantMethod::Awq(cfg) => {
let act_stats = self.weight_cols * 4;
let ch_scales = self.weight_cols * 4;
let alpha_buf = cfg.search_steps * 4;
act_stats + ch_scales + alpha_buf
}
}
}
pub fn shared_memory_bytes(&self) -> usize {
match &self.method {
WeightQuantMethod::Gptq(cfg) => {
cfg.block_size * 4
}
WeightQuantMethod::Awq(cfg) => {
cfg.group_size * 4
}
}
}
pub fn launch_params(&self) -> (usize, usize) {
let block = QUANT_BLOCK_SIZE as usize;
let total_threads = match &self.method {
WeightQuantMethod::Gptq(_) => self.weight_rows,
WeightQuantMethod::Awq(_) => self.weight_rows * self.weight_cols,
};
let grid = total_threads.div_ceil(block);
(grid, block)
}
}
pub struct GptqState {
pub rows: usize,
pub cols: usize,
pub group_size: usize,
pub bits: u32,
pub columns_processed: usize,
pub quantization_error: f64,
}
impl GptqState {
pub fn new(rows: usize, cols: usize, config: &GptqConfig) -> Self {
Self {
rows,
cols,
group_size: config.group_size,
bits: config.bits,
columns_processed: 0,
quantization_error: 0.0,
}
}
pub fn num_groups(&self) -> usize {
self.cols.div_ceil(self.group_size)
}
pub fn num_column_blocks(&self) -> usize {
self.cols.div_ceil(self.group_size)
}
pub fn elements_per_packed_word(&self) -> u32 {
32 / self.bits
}
pub fn is_complete(&self) -> bool {
self.columns_processed >= self.cols
}
}
pub struct AwqChannelScales {
pub num_channels: usize,
pub scales: Vec<f32>,
pub best_alpha: f32,
}
impl AwqChannelScales {
pub fn new(num_channels: usize) -> Self {
Self {
num_channels,
scales: vec![1.0; num_channels],
best_alpha: 0.5,
}
}
pub fn set_scale(&mut self, channel: usize, scale: f32) -> Result<(), DnnError> {
if channel >= self.num_channels {
return Err(DnnError::InvalidArgument(format!(
"channel index {channel} out of range (num_channels = {})",
self.num_channels
)));
}
self.scales[channel] = scale;
Ok(())
}
pub fn apply_alpha(&mut self, alpha: f32) {
self.best_alpha = alpha;
for s in &mut self.scales {
*s = s.powf(alpha);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gptq_config_valid() {
let cfg = GptqConfig {
bits: 4,
group_size: 128,
block_size: 128,
damp_percent: 0.01,
symmetric: true,
act_order: false,
true_sequential: true,
};
assert!(cfg.validate().is_ok());
}
#[test]
fn test_gptq_config_invalid_bits() {
let cfg = GptqConfig {
bits: 5,
..GptqConfig::default()
};
let err = cfg.validate().unwrap_err();
assert!(
err.to_string().contains("bits must be 2, 3, 4, or 8"),
"unexpected error: {err}"
);
}
#[test]
fn test_gptq_config_invalid_group_size() {
let cfg = GptqConfig {
group_size: 0,
..GptqConfig::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn test_awq_config_valid() {
let cfg = AwqConfig {
bits: 4,
group_size: 128,
zero_point: true,
search_alpha_min: 0.0,
search_alpha_max: 1.0,
search_steps: 20,
};
assert!(cfg.validate().is_ok());
}
#[test]
fn test_awq_config_invalid_bits() {
let cfg = AwqConfig {
bits: 3,
..AwqConfig::default()
};
assert!(cfg.validate().is_err());
}
#[test]
fn test_awq_config_invalid_alpha_range() {
let cfg = AwqConfig {
search_alpha_min: 1.0,
search_alpha_max: 0.0,
..AwqConfig::default()
};
let err = cfg.validate().unwrap_err();
assert!(err.to_string().contains("alpha_min"), "unexpected: {err}");
}
#[test]
fn test_quantized_weight_info_4bit_gptq() {
let plan = WeightQuantPlan::new(
WeightQuantMethod::Gptq(GptqConfig {
bits: 4,
group_size: 128,
symmetric: true,
..GptqConfig::default()
}),
256,
1024,
)
.expect("plan creation should succeed");
let info = plan.quantized_weight_info();
assert_eq!(info.bits, 4);
assert_eq!(info.rows, 256);
assert_eq!(info.cols, 1024);
assert_eq!(info.packed_weight_elements, 32768);
assert_eq!(info.scale_elements, 2048);
assert!(!info.has_zero_point);
assert_eq!(info.zero_point_elements, 0);
assert_eq!(info.method, QuantMethodTag::Gptq);
}
#[test]
fn test_quantized_weight_info_8bit_awq() {
let plan = WeightQuantPlan::new(
WeightQuantMethod::Awq(AwqConfig {
bits: 8,
group_size: 128,
zero_point: true,
..AwqConfig::default()
}),
128,
512,
)
.expect("plan creation should succeed");
let info = plan.quantized_weight_info();
assert_eq!(info.bits, 8);
assert_eq!(info.packed_weight_elements, 16384);
assert_eq!(info.scale_elements, 512);
assert!(info.has_zero_point);
assert_eq!(info.zero_point_elements, 512);
assert_eq!(info.method, QuantMethodTag::Awq);
}
#[test]
fn test_gptq_state_init() {
let cfg = GptqConfig::default();
let state = GptqState::new(256, 1024, &cfg);
assert_eq!(state.rows, 256);
assert_eq!(state.cols, 1024);
assert_eq!(state.columns_processed, 0);
assert!((state.quantization_error - 0.0).abs() < f64::EPSILON);
assert!(!state.is_complete());
}
#[test]
fn test_gptq_state_groups_and_blocks() {
let cfg = GptqConfig {
group_size: 64,
..GptqConfig::default()
};
let state = GptqState::new(128, 512, &cfg);
assert_eq!(state.num_groups(), 8); assert_eq!(state.num_column_blocks(), 8);
}
#[test]
fn test_gptq_state_complete() {
let cfg = GptqConfig::default();
let mut state = GptqState::new(64, 256, &cfg);
state.columns_processed = 256;
assert!(state.is_complete());
}
#[test]
fn test_gptq_elements_per_packed_word() {
for (bits, expected) in [(2, 16), (3, 10), (4, 8), (8, 4)] {
let cfg = GptqConfig {
bits,
..GptqConfig::default()
};
let state = GptqState::new(64, 128, &cfg);
assert_eq!(state.elements_per_packed_word(), expected, "bits={bits}");
}
}
#[test]
fn test_awq_channel_scales_new() {
let scales = AwqChannelScales::new(128);
assert_eq!(scales.num_channels, 128);
assert_eq!(scales.scales.len(), 128);
assert!((scales.scales[0] - 1.0).abs() < f32::EPSILON);
}
#[test]
fn test_awq_channel_scales_set() {
let mut scales = AwqChannelScales::new(4);
assert!(scales.set_scale(2, 3.19).is_ok());
assert!((scales.scales[2] - 3.19).abs() < 1e-6);
}
#[test]
fn test_awq_channel_scales_set_oob() {
let mut scales = AwqChannelScales::new(4);
assert!(scales.set_scale(4, 1.0).is_err());
}
#[test]
fn test_awq_channel_scales_apply_alpha() {
let mut scales = AwqChannelScales::new(3);
scales.scales = vec![4.0, 9.0, 16.0];
scales.apply_alpha(0.5);
assert!((scales.best_alpha - 0.5).abs() < f32::EPSILON);
assert!((scales.scales[0] - 2.0).abs() < 1e-5);
assert!((scales.scales[1] - 3.0).abs() < 1e-5);
assert!((scales.scales[2] - 4.0).abs() < 1e-5);
}
#[test]
fn test_gptq_quantize_ptx_generation() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Gptq(GptqConfig::default()), 128, 256)
.expect("plan");
let ptx = plan.generate_gptq_quantize_ptx().expect("ptx");
assert!(ptx.contains(".entry dnn_gptq_quantize_f32"));
assert!(ptx.contains(".version"));
}
#[test]
fn test_gptq_dequantize_ptx_generation() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Gptq(GptqConfig::default()), 64, 128)
.expect("plan");
let ptx = plan.generate_gptq_dequantize_ptx().expect("ptx");
assert!(ptx.contains(".entry dnn_gptq_dequantize_f32"));
}
#[test]
fn test_awq_scale_search_ptx_generation() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Awq(AwqConfig::default()), 64, 128)
.expect("plan");
let ptx = plan.generate_awq_scale_search_ptx().expect("ptx");
assert!(ptx.contains(".entry dnn_awq_scale_search_f32"));
}
#[test]
fn test_awq_quantize_ptx_generation() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Awq(AwqConfig::default()), 64, 128)
.expect("plan");
let ptx = plan.generate_awq_quantize_ptx().expect("ptx");
assert!(ptx.contains(".entry dnn_awq_quantize_f32"));
}
#[test]
fn test_fused_dequant_gemv_ptx_generation() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Gptq(GptqConfig::default()), 64, 128)
.expect("plan");
let ptx = plan.generate_fused_dequant_gemv_ptx().expect("ptx");
assert!(ptx.contains(".entry dnn_fused_dequant_gemv_f32"));
}
#[test]
fn test_workspace_bytes_gptq() {
let plan = WeightQuantPlan::new(
WeightQuantMethod::Gptq(GptqConfig {
block_size: 128,
act_order: true,
..GptqConfig::default()
}),
256,
1024,
)
.expect("plan");
let ws = plan.workspace_bytes();
assert_eq!(ws, 1024 * 4 + 1024 * 4 + 256 * 128 * 4);
}
#[test]
fn test_shared_memory_bytes() {
let gptq_plan = WeightQuantPlan::new(
WeightQuantMethod::Gptq(GptqConfig {
block_size: 128,
..GptqConfig::default()
}),
64,
128,
)
.expect("plan");
assert_eq!(gptq_plan.shared_memory_bytes(), 128 * 4);
let awq_plan = WeightQuantPlan::new(
WeightQuantMethod::Awq(AwqConfig {
group_size: 64,
..AwqConfig::default()
}),
64,
128,
)
.expect("plan");
assert_eq!(awq_plan.shared_memory_bytes(), 64 * 4);
}
#[test]
fn test_launch_params() {
let plan = WeightQuantPlan::new(WeightQuantMethod::Gptq(GptqConfig::default()), 512, 1024)
.expect("plan");
let (grid, block) = plan.launch_params();
assert_eq!(block, 256);
assert_eq!(grid, 2);
}
#[test]
fn test_group_size_must_divide_cols() {
let result = WeightQuantPlan::new(
WeightQuantMethod::Gptq(GptqConfig {
group_size: 100,
..GptqConfig::default()
}),
64,
256,
);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("divisible"), "unexpected: {err}");
}
}