1use std::sync::{atomic::AtomicUsize, Arc};
2
3use hanzo_ml::{quantized::GgmlDType, DType, Device, Result, Tensor};
4use hanzo_nn::Linear;
5
6mod ops;
7pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
8#[cfg(feature = "cuda")]
9#[allow(unused_imports)]
10pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};
11
12#[cfg(feature = "cuda")]
13mod ffi;
14
15use crate::{
16 generate_isq, generate_isq_imatrix, has_missing_required_tensors,
17 hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
18 make_dummy_or_error, AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
19 HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
20 QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
21};
22
23#[derive(Debug)]
24pub struct BlockwiseFP8Linear {
25 weight: Tensor,
26 weight_scale_inv: Tensor,
27 bias: Option<Tensor>,
28 dequant_dtype: DType,
29 weight_block_size: Vec<usize>,
30}
31
32impl QuantMethod for BlockwiseFP8Linear {
33 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
34 where
35 Self: Sized,
36 {
37 match method {
38 QuantMethodConfig::Gguf { .. }
39 | QuantMethodConfig::GptqAwq { .. }
40 | QuantMethodConfig::Hqq { .. }
41 | QuantMethodConfig::Dummy
42 | QuantMethodConfig::Unquantized(_)
43 | QuantMethodConfig::Bnb { .. }
44 | QuantMethodConfig::FP8 { .. }
45 | QuantMethodConfig::PerTensorFP8 { .. }
46 | QuantMethodConfig::Afq { .. }
47 | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
48 QuantMethodConfig::BlockwiseFP8 {
49 weight,
50 weight_scale_inv,
51 bias,
52 dequant_dtype,
53 weight_block_size,
54 } => Ok(Self {
55 weight,
56 weight_scale_inv,
57 bias,
58 dequant_dtype,
59 weight_block_size,
60 }),
61 }
62 }
63 fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
64 ops::fp8_blockwise_dequantize(
65 &self.weight,
66 &self.weight_scale_inv,
67 self.weight_block_size.to_vec(),
68 self.dequant_dtype,
69 )
70 }
71
72 fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
73 #[cfg(feature = "cuda")]
75 {
76 if matches!(x.device(), hanzo_ml::Device::Cuda(_)) && ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
77 let orig_dims = x.dims().to_vec();
79 let x_2d = if orig_dims.len() > 2 {
80 let features = orig_dims[orig_dims.len() - 1];
82 let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
83 x.reshape((batch_size, features))?
84 } else {
85 x.clone()
86 };
87
88 let result = ops::fp8_blockwise_matmul(
90 &x_2d,
91 &self.weight,
92 &self.weight_scale_inv,
93 &self.weight_block_size,
94 )?;
95
96 let result = if orig_dims.len() > 2 {
98 let out_features = result.dim(1)?;
99 let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
100 new_dims.push(out_features);
101 result.reshape(new_dims)?
102 } else {
103 result
104 };
105
106 if let Some(ref bias) = self.bias {
108 return result.broadcast_add(bias);
109 }
110 return Ok(result);
111 }
112 }
113
114 let weight = self.dequantize_w()?;
116 let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
118 weight,
119 self.bias.clone(),
120 )))?;
121 unquant.forward(x)
122 }
123
124 fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
129 #[cfg(feature = "cuda")]
131 {
132 if matches!(x.device(), hanzo_ml::Device::Cuda(_)) && ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
133 let result = ops::fp8_indexed_moe_gemm(
135 x,
136 &self.weight,
137 &self.weight_scale_inv,
138 indices,
139 &self.weight_block_size,
140 )?;
141 if let Some(ref bias) = self.bias {
143 return result.broadcast_add(bias);
144 }
145 return Ok(result);
146 }
147 }
148
149 let weight = self.dequantize_w()?;
151
152 let (n_tokens, n_experts_per_tok) = indices.dims2()?;
158 let (_n_experts, out_features, _in_features) = weight.dims3()?;
159
160 let flat_indices = indices.flatten_all()?;
162
163 let weight_selected = weight.index_select(&flat_indices, 0)?;
166
167 let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
169 x.squeeze(1)?
171 .unsqueeze(1)?
172 .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
173 .contiguous()?
174 } else if x.dims().len() == 3 {
175 x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
177 } else {
178 x.unsqueeze(1)?
180 .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
181 .contiguous()?
182 };
183
184 let weight_t = weight_selected.transpose(1, 2)?;
187 let result = x_expanded.matmul(&weight_t)?;
188
189 let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;
191
192 if let Some(ref bias) = self.bias {
194 result.broadcast_add(bias)
195 } else {
196 Ok(result)
197 }
198 }
199
200 fn quantized_act_type(&self) -> Option<DType> {
201 None
202 }
203
204 fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
205 hanzo_ml::bail!("BlockwiseFP8Linear does not support add_delta_w")
206 }
207
208 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
209 (DType::F8E4M3, self.weight.device().clone())
210 }
211
212 fn apply_isq(
213 self: Arc<Self>,
214 dtype: Option<IsqType>,
215 device: Device,
216 n_quantized: &AtomicUsize,
217 imatrix_weight: Option<Vec<f32>>,
218 guard: QuantizeOntoGuard,
219 ) -> Result<Arc<dyn QuantMethod>> {
220 let weight = ops::fp8_blockwise_dequantize(
221 &self.weight,
222 &self.weight_scale_inv,
223 self.weight_block_size.to_vec(),
224 self.dequant_dtype,
225 )?;
226 match dtype {
227 Some(IsqType::HQQ4 | IsqType::HQQ8) => {
229 let _acquired_quantize_guard = guard.acquire(&device);
230 if imatrix_weight.is_some() {
231 hanzo_ml::bail!("HQQ does not support imatrix.");
233 }
234
235 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
236 let bits = match dtype.unwrap() {
237 IsqType::HQQ8 => HqqBits::Eight,
238 IsqType::HQQ4 => HqqBits::Four,
239 _ => unreachable!(),
243 };
244 let cfg = HqqConfig {
245 bits,
246 group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
247 axis: HqqAxis::Zero,
248 optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
249 round_zeros: false,
250 channel_wise: true,
251 };
252 let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
253 if let Some(bias) = &self.bias {
254 let bias = bias
255 .to_device(&device)?
256 .to_dtype(res.dtype_and_device().0)?;
257 Ok(Arc::new(res.with_bias(bias)))
258 } else {
259 Ok(Arc::new(res))
260 }
261 }
262 Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
263 let _acquired_quantize_guard = guard.acquire(&device);
264 if imatrix_weight.is_some() {
265 hanzo_ml::bail!("AFQ does not support imatrix.");
267 }
268
269 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270 let bits = match dtype.unwrap() {
271 IsqType::AFQ8 => AfqBits::Eight,
272 IsqType::AFQ6 => AfqBits::Six,
273 IsqType::AFQ4 => AfqBits::Four,
274 IsqType::AFQ3 => AfqBits::Three,
275 IsqType::AFQ2 => AfqBits::Two,
276 _ => unreachable!(),
277 };
278
279 Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
280 weight: weight.to_device(&device)?,
281 bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
282 bits,
283 group_size: AfqGroupSize::default(),
284 })?))
285 }
286 Some(
287 IsqType::Q2K
288 | IsqType::Q3K
289 | IsqType::Q4K
290 | IsqType::Q4_0
291 | IsqType::Q4_1
292 | IsqType::Q5K
293 | IsqType::Q5_0
294 | IsqType::Q5_1
295 | IsqType::Q6K
296 | IsqType::Q8K
297 | IsqType::Q8_0
298 | IsqType::Q8_1,
299 ) => {
300 let dtype: GgmlDType = dtype.unwrap().try_into()?;
301 let res = if let Some(imatrix_weight) = imatrix_weight {
302 generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
303 } else {
304 generate_isq!(weight, device, dtype, n_quantized, guard)
305 };
306 Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
307 q_weight: res,
308 b: self
309 .bias
310 .as_ref()
311 .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
312 })?))
313 }
314 Some(IsqType::F8E4M3) => {
315 let _acquired_quantize_guard = guard.acquire(&device);
316 if imatrix_weight.is_some() {
317 hanzo_ml::bail!("F8E4M3 does not support imatrix.");
319 }
320
321 let w = weight.to_device(&device)?;
322 let b = if let Some(b) = &self.bias {
323 Some(b.to_device(&device)?)
324 } else {
325 None
326 };
327 Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
328 lin: Linear::new(w, b),
329 dtype: DType::F8E4M3,
330 })?))
331 }
332 Some(IsqType::F8Q8) => {
333 let _acquired_quantize_guard = guard.acquire(&device);
334 if imatrix_weight.is_some() {
335 hanzo_ml::bail!("F8Q8 does not support imatrix.");
336 }
337
338 let w = weight.to_device(&device)?;
339 let b = if let Some(b) = &self.bias {
340 Some(b.to_device(&device)?)
341 } else {
342 None
343 };
344 Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
345 }
346 Some(IsqType::MXFP4) => {
347 let _acquired_quantize_guard = guard.acquire(&device);
348 if imatrix_weight.is_some() {
349 hanzo_ml::bail!("MXFP4 does not support imatrix.");
350 }
351
352 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
353 let w = weight.to_device(&device)?;
354 let b = self
355 .bias
356 .as_ref()
357 .map(|b| b.to_device(&device))
358 .transpose()?;
359 crate::MXFP4Layer::quantize(&w, b, &device)
360 }
361 None => {
362 let _acquired_quantize_guard = guard.acquire(&device);
363 let w = weight.to_device(&device)?;
366 let b = if let Some(b) = &self.bias {
367 Some(b.to_device(&device)?)
368 } else {
369 None
370 };
371 Ok(Arc::new(UnquantLinear::new(
372 QuantMethodConfig::Unquantized(Linear::new(w, b)),
373 )?))
374 }
375 }
376 }
377}
378
379impl QuantizedSerde for BlockwiseFP8Linear {
402 fn isq_serde_supported(&self) -> bool {
403 false
404 }
405 fn name(&self) -> &'static str {
406 "blockwise-fp8-linear"
407 }
408}
409
410pub fn blockwise_fp8_moe(
413 weight: Tensor,
414 weight_scale_inv: Tensor,
415 weight_block_size: Vec<usize>,
416 dequant_dtype: DType,
417) -> Result<Arc<dyn QuantMethod>> {
418 Ok(Arc::new(BlockwiseFP8Linear {
419 weight,
420 weight_scale_inv,
421 bias: None,
422 dequant_dtype,
423 weight_block_size,
424 }))
425}
426
427pub fn blockwise_fp8_linear_b(
428 in_dim: usize,
429 out_dim: usize,
430 config: &QuantizedConfig,
431 bias: bool,
432 hints: Shard,
433 vb: ShardedVarBuilder,
434) -> Result<Arc<dyn QuantMethod>> {
435 let QuantizedConfig::Fp8 { weight_block_size } = config else {
436 hanzo_ml::bail!("Unexpected quantization config.")
437 };
438
439 if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
441 return crate::linear_b(in_dim, out_dim, bias, &None, vb);
442 }
443
444 if has_missing_required_tensors(&vb, &["weight", "weight_scale_inv"]) {
445 return make_dummy_or_error("blockwise_fp8_linear", &vb, &["weight", "weight_scale_inv"]);
446 }
447
448 let Some(weight_block_size) = weight_block_size else {
450 hanzo_ml::bail!("Blockwise FP8 requires weight_block_size to be set. Use per-tensor FP8 for models without block sizes.")
451 };
452 if weight_block_size.len() != 2 {
453 hanzo_ml::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
454 }
455 let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
456 let weight_scale_inv = vb.get_with_hints_dtype(
457 (
458 out_dim.div_ceil(weight_block_size[0]),
459 in_dim.div_ceil(weight_block_size[1]),
460 ),
461 "weight_scale_inv",
462 hints,
463 DType::F32,
464 )?;
465 let bias = if bias {
466 Some(vb.get((out_dim,), "bias")?)
467 } else {
468 None
469 };
470
471 Ok(Arc::new(BlockwiseFP8Linear {
472 weight,
473 weight_block_size: weight_block_size.clone(),
474 weight_scale_inv,
475 bias,
476 dequant_dtype: vb.dtype(),
477 }))
478}