use std::sync::{atomic::AtomicUsize, Arc};
use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;
mod ops;
pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
#[cfg(feature = "cuda")]
#[allow(unused_imports)]
pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};
#[cfg(feature = "cuda")]
mod ffi;
use crate::{
generate_isq, generate_isq_imatrix,
hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
};
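/// Linear layer whose weight is stored in blockwise-quantized FP8 (`F8E4M3`) form,
/// alongside a per-block scale tensor (`weight_scale_inv`) used for dequantization.
///
/// `weight_block_size` holds the `[rows, cols]` tile shape of each quantization block;
/// the weight is dequantized to `dequant_dtype` whenever a full-precision copy is needed.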
#[derive(Debug)]
pub struct BlockwiseFP8Linear {
weight: Tensor,
weight_scale_inv: Tensor,
bias: Option<Tensor>,
dequant_dtype: DType,
weight_block_size: Vec<usize>,
}
impl QuantMethod for BlockwiseFP8Linear {
fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
where
Self: Sized,
{
match method {
QuantMethodConfig::Gguf { .. }
| QuantMethodConfig::GptqAwq { .. }
| QuantMethodConfig::Hqq { .. }
| QuantMethodConfig::Dummy
| QuantMethodConfig::Unquantized(_)
| QuantMethodConfig::Bnb { .. }
| QuantMethodConfig::FP8 { .. }
| QuantMethodConfig::PerTensorFP8 { .. }
| QuantMethodConfig::Afq { .. }
| QuantMethodConfig::MXFP4 { .. } => unreachable!(),
QuantMethodConfig::BlockwiseFP8 {
weight,
weight_scale_inv,
bias,
dequant_dtype,
weight_block_size,
} => Ok(Self {
weight,
weight_scale_inv,
bias,
dequant_dtype,
weight_block_size,
}),
}
}
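    /// Dequantize the full weight to `dequant_dtype` by applying the per-block scales.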
fn dequantize_w(&self) -> Result<candle_core::Tensor> {
ops::fp8_blockwise_dequantize(
&self.weight,
&self.weight_scale_inv,
self.weight_block_size.to_vec(),
self.dequant_dtype,
)
}
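    /// On CUDA, when the blockwise GEMM kernels are available, run the FP8 matmul directly
    /// on the quantized weight; otherwise fall back to dequantizing and using an
    /// unquantized linear layer.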
fn forward(&self, x: &Tensor) -> Result<Tensor> {
#[cfg(feature = "cuda")]
{
if matches!(x.device(), candle_core::Device::Cuda(_))
&& ffi::HAVE_BLOCKWISE_GEMM_KERNELS
{
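                // Fused CUDA path: flatten any leading batch dims so the kernel sees a
                // 2D (tokens, in_features) input.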
let orig_dims = x.dims().to_vec();
let x_2d = if orig_dims.len() > 2 {
let features = orig_dims[orig_dims.len() - 1];
let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
x.reshape((batch_size, features))?
} else {
x.clone()
};
let result = ops::fp8_blockwise_matmul(
&x_2d,
&self.weight,
&self.weight_scale_inv,
&self.weight_block_size,
)?;
let result = if orig_dims.len() > 2 {
let out_features = result.dim(1)?;
let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
new_dims.push(out_features);
result.reshape(new_dims)?
} else {
result
};
if let Some(ref bias) = self.bias {
return result.broadcast_add(bias);
}
return Ok(result);
}
}
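        // Fallback path: dequantize the FP8 weight and run a standard unquantized matmul.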
let weight = self.dequantize_w()?;
let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
weight,
self.bias.clone(),
)))?;
unquant.forward(x)
}
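    /// Expert-indexed forward for MoE layers: `indices` has shape
    /// `(n_tokens, n_experts_per_tok)` and selects one expert weight matrix per entry.
    /// Uses the fused indexed FP8 GEMM on CUDA when available, otherwise dequantizes and
    /// gathers the selected expert weights for a batched matmul.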
fn gather_forward(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
#[cfg(feature = "cuda")]
{
if matches!(x.device(), candle_core::Device::Cuda(_))
&& ffi::HAVE_BLOCKWISE_GEMM_KERNELS
{
let result = ops::fp8_indexed_moe_gemm(
x,
&self.weight,
&self.weight_scale_inv,
indices,
&self.weight_block_size,
)?;
if let Some(ref bias) = self.bias {
return result.broadcast_add(bias);
}
return Ok(result);
}
}
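        // Fallback path: dequantize all expert weights and gather those selected by `indices`.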
let weight = self.dequantize_w()?;
let (n_tokens, n_experts_per_tok) = indices.dims2()?;
let (_n_experts, out_features, _in_features) = weight.dims3()?;
let flat_indices = indices.flatten_all()?;
let weight_selected = weight.index_select(&flat_indices, 0)?;
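        // Reshape/broadcast the activations so every (token, expert) pair owns a
        // (1, in_features) row for the batched matmul below.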
let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
x.squeeze(1)?
.unsqueeze(1)?
.broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
.contiguous()?
} else if x.dims().len() == 3 {
x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
} else {
x.unsqueeze(1)?
.broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
.contiguous()?
};
let weight_t = weight_selected.transpose(1, 2)?;
let result = x_expanded.matmul(&weight_t)?;
let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;
if let Some(ref bias) = self.bias {
result.broadcast_add(bias)
} else {
Ok(result)
}
}
fn quantized_act_type(&self) -> Option<DType> {
None
}
fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
candle_core::bail!("BlockwiseFP8Linear does not support add_delta_w")
}
fn dtype_and_device(&self) -> (DType, candle_core::Device) {
(DType::F8E4M3, self.weight.device().clone())
}
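    /// In-situ requantization: dequantize the blockwise FP8 weight, then requantize it into
    /// the requested `IsqType` (HQQ, AFQ, GGUF k-quants, FP8, F8Q8 or MXFP4). With `None`,
    /// the dequantized weight is returned as an unquantized linear layer.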
fn apply_isq(
self: Arc<Self>,
dtype: Option<IsqType>,
device: Device,
n_quantized: &AtomicUsize,
imatrix_weight: Option<Vec<f32>>,
guard: QuantizeOntoGuard,
) -> Result<Arc<dyn QuantMethod>> {
let weight = ops::fp8_blockwise_dequantize(
&self.weight,
&self.weight_scale_inv,
self.weight_block_size.to_vec(),
self.dequant_dtype,
)?;
match dtype {
Some(IsqType::HQQ4 | IsqType::HQQ8) => {
let _acquired_quantize_guard = guard.acquire(&device);
if imatrix_weight.is_some() {
candle_core::bail!("HQQ does not support imatrix.");
}
n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let bits = match dtype.unwrap() {
IsqType::HQQ8 => HqqBits::Eight,
IsqType::HQQ4 => HqqBits::Four,
_ => unreachable!(),
};
let cfg = HqqConfig {
bits,
group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
axis: HqqAxis::Zero,
optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
round_zeros: false,
channel_wise: true,
};
let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
if let Some(bias) = &self.bias {
let bias = bias
.to_device(&device)?
.to_dtype(res.dtype_and_device().0)?;
Ok(Arc::new(res.with_bias(bias)))
} else {
Ok(Arc::new(res))
}
}
Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
let _acquired_quantize_guard = guard.acquire(&device);
if imatrix_weight.is_some() {
candle_core::bail!("AFQ does not support imatrix.");
}
n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let bits = match dtype.unwrap() {
IsqType::AFQ8 => AfqBits::Eight,
IsqType::AFQ6 => AfqBits::Six,
IsqType::AFQ4 => AfqBits::Four,
IsqType::AFQ3 => AfqBits::Three,
IsqType::AFQ2 => AfqBits::Two,
_ => unreachable!(),
};
                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    bias: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_device(&device))
                        .transpose()?,
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
}
Some(
IsqType::Q2K
| IsqType::Q3K
| IsqType::Q4K
| IsqType::Q4_0
| IsqType::Q4_1
| IsqType::Q5K
| IsqType::Q5_0
| IsqType::Q5_1
| IsqType::Q6K
| IsqType::Q8K
| IsqType::Q8_0
| IsqType::Q8_1,
) => {
let dtype: GgmlDType = dtype.unwrap().try_into()?;
let res = if let Some(imatrix_weight) = imatrix_weight {
generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
} else {
generate_isq!(weight, device, dtype, n_quantized, guard)
};
Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
q_weight: res,
b: self
.bias
.as_ref()
.map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
})?))
}
Some(IsqType::F8E4M3) => {
let _acquired_quantize_guard = guard.acquire(&device);
if imatrix_weight.is_some() {
candle_core::bail!("F8E4M3 does not support imatrix.");
}
let w = weight.to_device(&device)?;
let b = if let Some(b) = &self.bias {
Some(b.to_device(&device)?)
} else {
None
};
Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
lin: Linear::new(w, b),
dtype: DType::F8E4M3,
})?))
}
Some(IsqType::F8Q8) => {
let _acquired_quantize_guard = guard.acquire(&device);
if imatrix_weight.is_some() {
candle_core::bail!("F8Q8 does not support imatrix.");
}
let w = weight.to_device(&device)?;
let b = if let Some(b) = &self.bias {
Some(b.to_device(&device)?)
} else {
None
};
Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
}
Some(IsqType::MXFP4) => {
let _acquired_quantize_guard = guard.acquire(&device);
if imatrix_weight.is_some() {
candle_core::bail!("MXFP4 does not support imatrix.");
}
n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let w = weight.to_device(&device)?;
let b = self
.bias
.as_ref()
.map(|b| b.to_device(&device))
.transpose()?;
crate::MXFP4Layer::quantize(&w, b, &device)
}
None => {
let _acquired_quantize_guard = guard.acquire(&device);
let w = weight.to_device(&device)?;
let b = if let Some(b) = &self.bias {
Some(b.to_device(&device)?)
} else {
None
};
Ok(Arc::new(UnquantLinear::new(
QuantMethodConfig::Unquantized(Linear::new(w, b)),
)?))
}
}
}
}
impl QuantizedSerde for BlockwiseFP8Linear {
fn isq_serde_supported(&self) -> bool {
false
}
fn name(&self) -> &'static str {
"blockwise-fp8-linear"
}
}
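/// Wrap pre-loaded (stacked expert) blockwise FP8 tensors into a `BlockwiseFP8Linear`
/// without a bias, for use as a MoE expert weight via `gather_forward`.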
pub fn blockwise_fp8_moe(
weight: Tensor,
weight_scale_inv: Tensor,
weight_block_size: Vec<usize>,
dequant_dtype: DType,
) -> Result<Arc<dyn QuantMethod>> {
Ok(Arc::new(BlockwiseFP8Linear {
weight,
weight_scale_inv,
bias: None,
dequant_dtype,
weight_block_size,
}))
}
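/// Load a blockwise FP8 linear layer from `vb`, falling back to an unquantized linear when
/// only `weight` is present (no `weight_scale_inv`) and to a `DummyLayer` when the tensors
/// are missing.
///
/// A rough usage sketch (doc-test ignored; assumes `in_dim`, `out_dim`, `shard`, `vb`, and
/// the activation `x` are already set up, and that the checkpoint uses 128x128 blocks):
///
/// ```ignore
/// let cfg = QuantizedConfig::Fp8 {
///     weight_block_size: Some(vec![128, 128]),
/// };
/// let lin = blockwise_fp8_linear_b(in_dim, out_dim, &cfg, /* bias */ false, shard, vb)?;
/// let y = lin.forward(&x)?;
/// ```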
pub fn blockwise_fp8_linear_b(
in_dim: usize,
out_dim: usize,
config: &QuantizedConfig,
bias: bool,
hints: Shard,
vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
let QuantizedConfig::Fp8 { weight_block_size } = config else {
candle_core::bail!("Unexpected quantization config.")
};
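    // A weight without a matching scale tensor is stored unquantized: load it as a
    // regular linear layer.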
if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
return crate::linear_b(in_dim, out_dim, bias, &None, vb);
}
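    // Neither tensor is present: return a dummy placeholder layer.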
if !(vb.contains_tensor("weight") && vb.contains_tensor("weight_scale_inv")) {
let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
}
let Some(weight_block_size) = weight_block_size else {
candle_core::bail!("Blockwise FP8 requires weight_block_size to be set. Use per-tensor FP8 for models without block sizes.")
};
if weight_block_size.len() != 2 {
candle_core::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
}
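    // The FP8 weight is stored as (out_dim, in_dim), with one f32 scale per
    // weight_block_size[0] x weight_block_size[1] tile.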
let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
let weight_scale_inv = vb.get_with_hints_dtype(
(
out_dim.div_ceil(weight_block_size[0]),
in_dim.div_ceil(weight_block_size[1]),
),
"weight_scale_inv",
hints,
DType::F32,
)?;
let bias = if bias {
Some(vb.get((out_dim,), "bias")?)
} else {
None
};
Ok(Arc::new(BlockwiseFP8Linear {
weight,
weight_block_size: weight_block_size.clone(),
weight_scale_inv,
bias,
dequant_dtype: vb.dtype(),
}))
}