
Module quantize

Quantization op family — Category P.

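The core Phase 8 quantize / dequantize ops are split across two parallel milestones (later milestones add dynamic_range and quantized_linear in 8.3, and gguf in 8.4):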

  • Milestone 8.1 (sibling): per-tensor and per-channel quantize / dequantize, plus fake_quantize. Owns crates/baracuda-kernels-sys/kernels/quantize/per_tensor.cu, per_channel.cu, and fake_quantize.cu, along with the Rust plans for those ops in this quantize/ module.

  • Milestone 8.2 (this work): per-token and per-group quantize / dequantize, plus their STE backwards. Used for LLM activation quantization (W8A8, per-row) and weight quantization (GPTQ-style INT4, per-group with g = 128). Owns crates/baracuda-kernels-sys/kernels/quantize/per_token.cu and per_group.cu, along with the plans for those ops in this module.
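The difference between per-token and per-group quantization is mostly which scale and zero point each element reads. The sketch below is a hypothetical CPU reference, not this module's plan API (the function names and signatures are illustrative). It assumes a row-major [rows, k] input, contiguous groups of `group` elements along k, the affine rule q = clamp(round(x / scale) + zero_point, qmin, qmax), and i32 outputs standing in for s8 / u8 storage. Per-token (W8A8 per-row) is the special case group == k; the GPTQ-style weight layout corresponds to group = 128.

```rust
/// Hypothetical CPU reference for per-group affine quantization; not this crate's API.
/// `x` is a row-major [rows, k] matrix. Each row is cut into `k / group` contiguous
/// groups, and group `j` of row `r` uses `scale[r * (k / group) + j]` and the matching
/// zero point. Per-token quantization is the special case `group == k`.
fn quantize_per_group(
    x: &[f32],
    rows: usize,
    k: usize,
    group: usize,
    scale: &[f32],
    zero_point: &[i32],
    qmin: i32,
    qmax: i32,
) -> Vec<i32> {
    assert!(group > 0 && k % group == 0, "k must be a multiple of the group size");
    let groups_per_row = k / group;
    assert_eq!(x.len(), rows * k);
    assert_eq!(scale.len(), rows * groups_per_row);
    assert_eq!(zero_point.len(), scale.len());
    x.iter()
        .enumerate()
        .map(|(i, &v)| {
            // Row index times groups per row, plus the group this column falls into.
            let g = (i / k) * groups_per_row + (i % k) / group;
            // f32::round rounds half away from zero; the real kernels may round differently.
            let q = (v / scale[g]).round() as i32 + zero_point[g];
            q.clamp(qmin, qmax)
        })
        .collect()
}

/// Matching dequantize: x_hat = (q - zero_point) * scale, with the same group indexing.
fn dequantize_per_group(
    q: &[i32],
    k: usize,
    group: usize,
    scale: &[f32],
    zero_point: &[i32],
) -> Vec<f32> {
    assert!(group > 0 && k % group == 0, "k must be a multiple of the group size");
    let groups_per_row = k / group;
    q.iter()
        .enumerate()
        .map(|(i, &v)| {
            let g = (i / k) * groups_per_row + (i % k) / group;
            (v - zero_point[g]) as f32 * scale[g]
        })
        .collect()
}
```

With group = 2 and k = 4, each row carries two scales; collapsing to one group per row gives the per-token layout. A u8 output is handled by passing qmin = 0, qmax = 255 and a nonzero zero point.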

Both milestones edit three shared locations append-only: this file, crate::lib’s re-exports, and baracuda-kernels-sys/src/lib.rs. Neither rewrites an existing entry.

Trailblazer dtype coverage: input FP ∈ {f32, f64, f16, bf16}; output int ∈ {s8, u8}. Sub-byte packed types (s4 / u4) are deferred to a later milestone.

The backward convention is the Straight-Through Estimator (STE): dx = (dy / scale) * 1[qmin < q < qmax]. The in-range mask is recomputed inside the backward kernel from the saved input, so callers must retain the input tensor for autograd (which they would do anyway).
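A minimal CPU sketch of the per-tensor STE backward stated above, assuming an affine forward q = clamp(round(x / scale) + zero_point, qmin, qmax) with a scalar scale and zero point; the function name and signature are hypothetical. Because the mask uses strict inequalities, testing the clamped q gives the same answer as testing the unclamped value: a saturating input lands exactly on qmin or qmax and is masked out, as is an input that rounds exactly onto a bound.

```rust
/// Hypothetical CPU reference for the quantize backward under the Straight-Through
/// Estimator: dx = (dy / scale) * 1[qmin < q < qmax], with q recomputed from the
/// saved forward input `x` rather than stored.
fn quantize_per_tensor_backward_ste(
    dy: &[f32],
    x: &[f32],
    scale: f32,
    zero_point: i32,
    qmin: i32,
    qmax: i32,
) -> Vec<f32> {
    assert_eq!(dy.len(), x.len());
    dy.iter()
        .zip(x)
        .map(|(&g, &v)| {
            // Recompute the forward value; this must mirror the forward rounding mode.
            let q = ((v / scale).round() as i32 + zero_point).clamp(qmin, qmax);
            // Elements whose q sits on qmin or qmax (every saturated one included) get zero gradient.
            if qmin < q && q < qmax { g / scale } else { 0.0 }
        })
        .collect()
}
```

Recomputing the mask trades a little arithmetic in the backward kernel for not having to store a per-element mask tensor from the forward pass.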

Re-exports

pub use dequantize_per_channel::DequantizePerChannelArgs;
pub use dequantize_per_channel::DequantizePerChannelDescriptor;
pub use dequantize_per_channel::DequantizePerChannelPlan;
pub use dequantize_per_channel_backward::DequantizePerChannelBackwardArgs;
pub use dequantize_per_channel_backward::DequantizePerChannelBackwardDescriptor;
pub use dequantize_per_channel_backward::DequantizePerChannelBackwardPlan;
pub use dequantize_per_tensor::DequantizePerTensorArgs;
pub use dequantize_per_tensor::DequantizePerTensorDescriptor;
pub use dequantize_per_tensor::DequantizePerTensorPlan;
pub use dequantize_per_tensor_backward::DequantizePerTensorBackwardArgs;
pub use dequantize_per_tensor_backward::DequantizePerTensorBackwardDescriptor;
pub use dequantize_per_tensor_backward::DequantizePerTensorBackwardPlan;
pub use fake_quantize::FakeQuantizeArgs;
pub use fake_quantize::FakeQuantizeDescriptor;
pub use fake_quantize::FakeQuantizePlan;
pub use fake_quantize_backward::FakeQuantizeBackwardArgs;
pub use fake_quantize_backward::FakeQuantizeBackwardDescriptor;
pub use fake_quantize_backward::FakeQuantizeBackwardPlan;
pub use per_channel::QuantizePerChannelArgs;
pub use per_channel::QuantizePerChannelDescriptor;
pub use per_channel::QuantizePerChannelPlan;
pub use per_channel_backward::QuantizePerChannelBackwardArgs;
pub use per_channel_backward::QuantizePerChannelBackwardDescriptor;
pub use per_channel_backward::QuantizePerChannelBackwardPlan;
pub use per_tensor::QuantizePerTensorArgs;
pub use per_tensor::QuantizePerTensorDescriptor;
pub use per_tensor::QuantizePerTensorPlan;
pub use per_tensor_backward::QuantizePerTensorBackwardArgs;
pub use per_tensor_backward::QuantizePerTensorBackwardDescriptor;
pub use per_tensor_backward::QuantizePerTensorBackwardPlan;
pub use per_token::QuantizePerTokenArgs;
pub use per_token::QuantizePerTokenDescriptor;
pub use per_token::QuantizePerTokenPlan;
pub use per_token_backward::QuantizePerTokenBackwardArgs;
pub use per_token_backward::QuantizePerTokenBackwardDescriptor;
pub use per_token_backward::QuantizePerTokenBackwardPlan;
pub use dequantize_per_token::DequantizePerTokenArgs;
pub use dequantize_per_token::DequantizePerTokenDescriptor;
pub use dequantize_per_token::DequantizePerTokenPlan;
pub use dequantize_per_token_backward::DequantizePerTokenBackwardArgs;
pub use dequantize_per_token_backward::DequantizePerTokenBackwardDescriptor;
pub use dequantize_per_token_backward::DequantizePerTokenBackwardPlan;
pub use per_group::QuantizePerGroupArgs;
pub use per_group::QuantizePerGroupDescriptor;
pub use per_group::QuantizePerGroupPlan;
pub use per_group_backward::QuantizePerGroupBackwardArgs;
pub use per_group_backward::QuantizePerGroupBackwardDescriptor;
pub use per_group_backward::QuantizePerGroupBackwardPlan;
pub use dequantize_per_group::DequantizePerGroupArgs;
pub use dequantize_per_group::DequantizePerGroupDescriptor;
pub use dequantize_per_group::DequantizePerGroupPlan;
pub use dequantize_per_group_backward::DequantizePerGroupBackwardArgs;
pub use dequantize_per_group_backward::DequantizePerGroupBackwardDescriptor;
pub use dequantize_per_group_backward::DequantizePerGroupBackwardPlan;
pub use dynamic_range::DynamicRangeMode;
pub use dynamic_range::DynamicRangeQuantizeArgs;
pub use dynamic_range::DynamicRangeQuantizeDescriptor;
pub use dynamic_range::DynamicRangeQuantizePlan;
pub use dynamic_range::DynamicRangeScope;
pub use quantized_linear::QuantizedLinearArgs;
pub use quantized_linear::QuantizedLinearDescriptor;
pub use quantized_linear::QuantizedLinearPlan;
pub use smoothquant::SmoothQuantLinearArgs;
pub use smoothquant::SmoothQuantLinearDescriptor;
pub use smoothquant::SmoothQuantLinearPlan;
pub use gguf::BlockQ2K;
pub use gguf::BlockQ3K;
pub use gguf::BlockQ4_0;
pub use gguf::BlockQ4_1;
pub use gguf::BlockQ4K;
pub use gguf::BlockQ5_0;
pub use gguf::BlockQ5_1;
pub use gguf::BlockQ5K;
pub use gguf::BlockQ6K;
pub use gguf::BlockQ8_0;
pub use gguf::BlockQ8K;
pub use gguf::GgufDequantizeArgs;
pub use gguf::GgufDequantizeDescriptor;
pub use gguf::GgufDequantizePlan;
pub use gguf::GgufMmvqArgs;
pub use gguf::GgufMmvqDescriptor;
pub use gguf::GgufMmvqPlan;
pub use gguf::GgufMmvqBatchedActivation;
pub use gguf::GgufMmvqBatchedArgs;
pub use gguf::GgufMmvqBatchedDescriptor;
pub use gguf::GgufMmvqBatchedFormat;
pub use gguf::GgufMmvqBatchedPlan;
pub use gguf::GgufMmvqMultiMArgs;
pub use gguf::GgufMmvqMultiMDescriptor;
pub use gguf::GgufMmvqMultiMPlan;
pub use nf4::Nf4Activation;
pub use nf4::Nf4DequantizeArgs;
pub use nf4::Nf4DequantizePlan;
pub use nf4::Nf4Descriptor;
pub use nf4::Nf4MmvqArgs;
pub use nf4::Nf4MmvqMultiMArgs;
pub use nf4::Nf4MmvqMultiMDescriptor;
pub use nf4::Nf4MmvqMultiMPlan;
pub use nf4::Nf4MmvqPlan;
pub use nf4::NF4_CODEBOOK;

Modules

dequantize_per_channel
dequantize_per_channel forward plan.
dequantize_per_channel_backward
dequantize_per_channel backward plan — dq[i] = dy[i] * scale[c].
dequantize_per_group
dequantize_per_group forward plan.
dequantize_per_group_backward
dequantize_per_group backward plan — straight-through.
dequantize_per_tensor
dequantize_per_tensor forward plan.
dequantize_per_tensor_backward
dequantize_per_tensor backward plan — dq = dy * scale.
dequantize_per_token
dequantize_per_token forward plan.
dequantize_per_token_backward
dequantize_per_token backward plan — straight-through (dq = dy * scale[n]).
dynamic_range
dynamic_range_quantize — compose op (Phase 8 Milestone 8.3).
fake_quantize
fake_quantize forward plan — per-tensor, FP roundtrip.
fake_quantize_backward
fake_quantize backward plan via STE.
gguf
GGUF block-format quantization plans — Phase 8 Milestone 8.4.
nf4
NF4 (NormalFloat 4-bit) dequant + GEMV plans — Phase 53.
per_channel
quantize_per_channel forward plan.
per_channel_backward
quantize_per_channel backward plan via STE.
per_group
quantize_per_group forward plan.
per_group_backward
quantize_per_group backward plan (Straight-Through Estimator).
per_tensor
quantize_per_tensor forward plan — Category P FW trailblazer.
per_tensor_backward
quantize_per_tensor backward plan (Straight-Through Estimator).
per_token
quantize_per_token forward plan.
per_token_backward
quantize_per_token backward plan (Straight-Through Estimator).
quantized_linear
quantized_linear: fused W8A8 quantized matmul (Phase 8 Milestone 8.3).
smoothquant
SmoothQuantLinearPlan: Phase 45 composition of existing kernels (no new CUDA code).
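The dequantize backward formulas listed above are all the same straight-through product dq = dy * scale, and they differ only in which scale each element reads. The fake_quantize forward is a quantize followed immediately by a dequantize, kept in floating point. The sketch below is a hypothetical CPU reference for both, not the plans' API. It assumes a contiguous [outer, channels, inner] layout for per-channel indexing (so channel c = (i / inner) % channels) and a per-tensor scale and zero point for fake_quantize. The per-token case dq = dy * scale[n] is the per-channel one with outer = 1, channels = rows, and inner = k.

```rust
/// Hypothetical reference for dequantize_per_channel backward: dq[i] = dy[i] * scale[c],
/// assuming a contiguous [outer, channels, inner] tensor so that c = (i / inner) % channels.
fn dequantize_per_channel_backward(dy: &[f32], scale: &[f32], inner: usize) -> Vec<f32> {
    let channels = scale.len();
    assert!(inner > 0 && channels > 0);
    assert_eq!(dy.len() % (channels * inner), 0);
    dy.iter()
        .enumerate()
        .map(|(i, &g)| g * scale[(i / inner) % channels])
        .collect()
}

/// Hypothetical reference for per-tensor fake_quantize: quantize, then dequantize,
/// staying in floating point, so the output has the same dtype as `x`.
fn fake_quantize_per_tensor(x: &[f32], scale: f32, zero_point: i32, qmin: i32, qmax: i32) -> Vec<f32> {
    x.iter()
        .map(|&v| {
            // f32::round rounds half away from zero; the real kernel's mode may differ.
            let q = ((v / scale).round() as i32 + zero_point).clamp(qmin, qmax);
            (q - zero_point) as f32 * scale
        })
        .collect()
}
```

Values outside the representable range come back clamped (500.0 with scale 0.5 becomes 127 * 0.5 = 63.5), which is the error that fake quantization exposes to training.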

Functions

default_q_range
Default qmin / qmax for an output integer dtype. Currently implemented for the two trailblazer output dtypes: baracuda_kernels_types::S8 ([-128, 127]) and baracuda_kernels_types::U8 ([0, 255]).
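A minimal sketch of the mapping default_q_range describes, using a hypothetical QDtype enum in place of baracuda_kernels_types::S8 / U8 (this page does not show how those types are defined). For a b-bit integer the full range is [-2^(b-1), 2^(b-1) - 1] when signed and [0, 2^b - 1] when unsigned, which yields the two ranges quoted above for b = 8.

```rust
/// Hypothetical stand-in for the output dtypes baracuda_kernels_types::S8 / U8.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum QDtype {
    S8,
    U8,
}

/// Full representable (qmin, qmax) for a quantized output dtype.
fn default_q_range(dtype: QDtype) -> (i32, i32) {
    match dtype {
        QDtype::S8 => (i8::MIN as i32, i8::MAX as i32), // [-128, 127]
        QDtype::U8 => (u8::MIN as i32, u8::MAX as i32), // [0, 255]
    }
}

/// The same ranges derived from a bit width, for a b-bit signed or unsigned integer.
fn q_range_from_bits(bits: u32, signed: bool) -> (i32, i32) {
    assert!((1..=16).contains(&bits), "sketch only covers 1..=16 bits");
    if signed {
        (-(1i32 << (bits - 1)), (1i32 << (bits - 1)) - 1)
    } else {
        (0, (1i32 << bits) - 1)
    }
}
```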