Skip to main content

hanzo_quant/blockwise_fp8/
mod.rs

1use std::sync::{atomic::AtomicUsize, Arc};
2
3use hanzo_ml::{quantized::GgmlDType, DType, Device, Result, Tensor};
4use hanzo_nn::Linear;
5
6mod ops;
7pub use ops::{fp8_blockwise_dequantize, fp8_blockwise_quantize};
8#[cfg(feature = "cuda")]
9#[allow(unused_imports)]
10pub(crate) use ops::{fp8_blockwise_matmul, fp8_indexed_moe_gemm};
11
12#[cfg(feature = "cuda")]
13mod ffi;
14
15use crate::{
16    generate_isq, generate_isq_imatrix, has_missing_required_tensors,
17    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
18    make_dummy_or_error, AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
19    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
20    QuantizedConfig, QuantizedSerde, Shard, ShardedVarBuilder, UnquantLinear,
21};
22
23#[derive(Debug)]
24pub struct BlockwiseFP8Linear {
25    weight: Tensor,
26    weight_scale_inv: Tensor,
27    bias: Option<Tensor>,
28    dequant_dtype: DType,
29    weight_block_size: Vec<usize>,
30}
31
32impl QuantMethod for BlockwiseFP8Linear {
33    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
34    where
35        Self: Sized,
36    {
37        match method {
38            QuantMethodConfig::Gguf { .. }
39            | QuantMethodConfig::GptqAwq { .. }
40            | QuantMethodConfig::Hqq { .. }
41            | QuantMethodConfig::Dummy
42            | QuantMethodConfig::Unquantized(_)
43            | QuantMethodConfig::Bnb { .. }
44            | QuantMethodConfig::FP8 { .. }
45            | QuantMethodConfig::PerTensorFP8 { .. }
46            | QuantMethodConfig::Afq { .. }
47            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
48            QuantMethodConfig::BlockwiseFP8 {
49                weight,
50                weight_scale_inv,
51                bias,
52                dequant_dtype,
53                weight_block_size,
54            } => Ok(Self {
55                weight,
56                weight_scale_inv,
57                bias,
58                dequant_dtype,
59                weight_block_size,
60            }),
61        }
62    }
63    fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
64        ops::fp8_blockwise_dequantize(
65            &self.weight,
66            &self.weight_scale_inv,
67            self.weight_block_size.to_vec(),
68            self.dequant_dtype,
69        )
70    }
71
72    fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
73        // Try to use native FP8 GEMM kernel on CUDA
74        #[cfg(feature = "cuda")]
75        {
76            if matches!(x.device(), hanzo_ml::Device::Cuda(_)) && ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
77                // Handle batched inputs by flattening to 2D
78                let orig_dims = x.dims().to_vec();
79                let x_2d = if orig_dims.len() > 2 {
80                    // Flatten all but last dim: [batch, seq, features] -> [batch*seq, features]
81                    let features = orig_dims[orig_dims.len() - 1];
82                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
83                    x.reshape((batch_size, features))?
84                } else {
85                    x.clone()
86                };
87
88                // Use native FP8 GEMM kernel
89                let result = ops::fp8_blockwise_matmul(
90                    &x_2d,
91                    &self.weight,
92                    &self.weight_scale_inv,
93                    &self.weight_block_size,
94                )?;
95
96                // Reshape back to original batch dimensions
97                let result = if orig_dims.len() > 2 {
98                    let out_features = result.dim(1)?;
99                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
100                    new_dims.push(out_features);
101                    result.reshape(new_dims)?
102                } else {
103                    result
104                };
105
106                // Apply bias if present
107                if let Some(ref bias) = self.bias {
108                    return result.broadcast_add(bias);
109                }
110                return Ok(result);
111            }
112        }
113
114        // Fallback: dequantize and use unquantized matmul
115        let weight = self.dequantize_w()?;
116        // Dispatch to unquant. This uses some cublaslt for bias & on cuda always, so it is better
117        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
118            weight,
119            self.bias.clone(),
120        )))?;
121        unquant.forward(x)
122    }
123
124    /// Compute matmul of `self` and `a`. `self` should contain the weights.
125    ///
126    /// If `a` is (n_tokens, 1, cols), `self` weights are (n_experts, rows, cols),
127    /// then the indices are (n_tokens, n_experts_per_tok).
128    fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
129        // Try to use native FP8 indexed MoE GEMM kernel on CUDA
130        #[cfg(feature = "cuda")]
131        {
132            if matches!(x.device(), hanzo_ml::Device::Cuda(_)) && ffi::HAVE_BLOCKWISE_GEMM_KERNELS {
133                // Use native FP8 indexed MoE GEMM kernel (expects U32 indices)
134                let result = ops::fp8_indexed_moe_gemm(
135                    x,
136                    &self.weight,
137                    &self.weight_scale_inv,
138                    indices,
139                    &self.weight_block_size,
140                )?;
141                // Apply bias if present (broadcast over tokens and topk)
142                if let Some(ref bias) = self.bias {
143                    return result.broadcast_add(bias);
144                }
145                return Ok(result);
146            }
147        }
148
149        // Fallback: dequantize weights and compute manually
150        let weight = self.dequantize_w()?;
151
152        // Expected shapes:
153        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
154        // - indices: (n_tokens, n_experts_per_tok)
155        // - weight: (n_experts, out_features, in_features)
156
157        let (n_tokens, n_experts_per_tok) = indices.dims2()?;
158        let (_n_experts, out_features, _in_features) = weight.dims3()?;
159
160        // Flatten indices to select expert weights
161        let flat_indices = indices.flatten_all()?;
162
163        // Select weights for each (token, expert) pair
164        // weight_selected: (n_tokens * n_experts_per_tok, out_features, in_features)
165        let weight_selected = weight.index_select(&flat_indices, 0)?;
166
167        // Reshape x for batched matmul
168        let x_expanded = if x.dims().len() == 3 && x.dim(1)? == 1 {
169            // x is (n_tokens, 1, hidden_dim) - broadcast to (n_tokens * n_experts_per_tok, 1, hidden_dim)
170            x.squeeze(1)?
171                .unsqueeze(1)?
172                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
173                .contiguous()?
174        } else if x.dims().len() == 3 {
175            // x is (n_tokens, n_experts_per_tok, hidden_dim)
176            x.reshape((n_tokens * n_experts_per_tok, 1, x.dim(2)?))?
177        } else {
178            // x is (n_tokens, hidden_dim)
179            x.unsqueeze(1)?
180                .broadcast_as((n_tokens * n_experts_per_tok, 1, x.dim(1)?))?
181                .contiguous()?
182        };
183
184        // Batched matmul: (batch, 1, k) @ (batch, k, n).T = (batch, 1, n)
185        // weight_selected is (batch, n, k), so we need to transpose last two dims
186        let weight_t = weight_selected.transpose(1, 2)?;
187        let result = x_expanded.matmul(&weight_t)?;
188
189        // Reshape result to (n_tokens, n_experts_per_tok, out_features)
190        let result = result.reshape((n_tokens, n_experts_per_tok, out_features))?;
191
192        // Apply bias if present
193        if let Some(ref bias) = self.bias {
194            result.broadcast_add(bias)
195        } else {
196            Ok(result)
197        }
198    }
199
200    fn quantized_act_type(&self) -> Option<DType> {
201        None
202    }
203
204    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
205        hanzo_ml::bail!("BlockwiseFP8Linear does not support add_delta_w")
206    }
207
208    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
209        (DType::F8E4M3, self.weight.device().clone())
210    }
211
212    fn apply_isq(
213        self: Arc<Self>,
214        dtype: Option<IsqType>,
215        device: Device,
216        n_quantized: &AtomicUsize,
217        imatrix_weight: Option<Vec<f32>>,
218        guard: QuantizeOntoGuard,
219    ) -> Result<Arc<dyn QuantMethod>> {
220        let weight = ops::fp8_blockwise_dequantize(
221            &self.weight,
222            &self.weight_scale_inv,
223            self.weight_block_size.to_vec(),
224            self.dequant_dtype,
225        )?;
226        match dtype {
227            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
228            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
229                let _acquired_quantize_guard = guard.acquire(&device);
230                if imatrix_weight.is_some() {
231                    // TODO just warn?
232                    hanzo_ml::bail!("HQQ does not support imatrix.");
233                }
234
235                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
236                let bits = match dtype.unwrap() {
237                    IsqType::HQQ8 => HqqBits::Eight,
238                    IsqType::HQQ4 => HqqBits::Four,
239                    // IsqType::HQQ3 => HqqBits::Three,
240                    // IsqType::HQQ2 => HqqBits::Two,
241                    // IsqType::HQQ1 => HqqBits::One,
242                    _ => unreachable!(),
243                };
244                let cfg = HqqConfig {
245                    bits,
246                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
247                    axis: HqqAxis::Zero,
248                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
249                    round_zeros: false,
250                    channel_wise: true,
251                };
252                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
253                if let Some(bias) = &self.bias {
254                    let bias = bias
255                        .to_device(&device)?
256                        .to_dtype(res.dtype_and_device().0)?;
257                    Ok(Arc::new(res.with_bias(bias)))
258                } else {
259                    Ok(Arc::new(res))
260                }
261            }
262            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
263                let _acquired_quantize_guard = guard.acquire(&device);
264                if imatrix_weight.is_some() {
265                    // TODO just warn?
266                    hanzo_ml::bail!("AFQ does not support imatrix.");
267                }
268
269                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
270                let bits = match dtype.unwrap() {
271                    IsqType::AFQ8 => AfqBits::Eight,
272                    IsqType::AFQ6 => AfqBits::Six,
273                    IsqType::AFQ4 => AfqBits::Four,
274                    IsqType::AFQ3 => AfqBits::Three,
275                    IsqType::AFQ2 => AfqBits::Two,
276                    _ => unreachable!(),
277                };
278
279                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
280                    weight: weight.to_device(&device)?,
281                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
282                    bits,
283                    group_size: AfqGroupSize::default(),
284                })?))
285            }
286            Some(
287                IsqType::Q2K
288                | IsqType::Q3K
289                | IsqType::Q4K
290                | IsqType::Q4_0
291                | IsqType::Q4_1
292                | IsqType::Q5K
293                | IsqType::Q5_0
294                | IsqType::Q5_1
295                | IsqType::Q6K
296                | IsqType::Q8K
297                | IsqType::Q8_0
298                | IsqType::Q8_1,
299            ) => {
300                let dtype: GgmlDType = dtype.unwrap().try_into()?;
301                let res = if let Some(imatrix_weight) = imatrix_weight {
302                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
303                } else {
304                    generate_isq!(weight, device, dtype, n_quantized, guard)
305                };
306                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
307                    q_weight: res,
308                    b: self
309                        .bias
310                        .as_ref()
311                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
312                })?))
313            }
314            Some(IsqType::F8E4M3) => {
315                let _acquired_quantize_guard = guard.acquire(&device);
316                if imatrix_weight.is_some() {
317                    // TODO just warn?
318                    hanzo_ml::bail!("F8E4M3 does not support imatrix.");
319                }
320
321                let w = weight.to_device(&device)?;
322                let b = if let Some(b) = &self.bias {
323                    Some(b.to_device(&device)?)
324                } else {
325                    None
326                };
327                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
328                    lin: Linear::new(w, b),
329                    dtype: DType::F8E4M3,
330                })?))
331            }
332            Some(IsqType::F8Q8) => {
333                let _acquired_quantize_guard = guard.acquire(&device);
334                if imatrix_weight.is_some() {
335                    hanzo_ml::bail!("F8Q8 does not support imatrix.");
336                }
337
338                let w = weight.to_device(&device)?;
339                let b = if let Some(b) = &self.bias {
340                    Some(b.to_device(&device)?)
341                } else {
342                    None
343                };
344                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
345            }
346            Some(IsqType::MXFP4) => {
347                let _acquired_quantize_guard = guard.acquire(&device);
348                if imatrix_weight.is_some() {
349                    hanzo_ml::bail!("MXFP4 does not support imatrix.");
350                }
351
352                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
353                let w = weight.to_device(&device)?;
354                let b = self
355                    .bias
356                    .as_ref()
357                    .map(|b| b.to_device(&device))
358                    .transpose()?;
359                crate::MXFP4Layer::quantize(&w, b, &device)
360            }
361            None => {
362                let _acquired_quantize_guard = guard.acquire(&device);
363                // Ignore imatrix altogether
364
365                let w = weight.to_device(&device)?;
366                let b = if let Some(b) = &self.bias {
367                    Some(b.to_device(&device)?)
368                } else {
369                    None
370                };
371                Ok(Arc::new(UnquantLinear::new(
372                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
373                )?))
374            }
375        }
376    }
377}
378
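// Serialization structure (NOTE: `isq_serde_supported` below returns `false`,
// so this type does not currently produce this layout):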
380//
381// -----------------------
382// UQFF version, u32, little endian
383// -----------------------
384// ISQ type (3 for fp8), u8, little endian
385// -----------------------
386// Whether bias data is included, u8 boolean
387// -----------------------
388// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
389// -----------------------
390// Dequant W scalar, f32, little endian
391// -----------------------
392// Dequant X scalar, f32, little endian
393// -----------------------
394// Quant scalar, f32, little endian
395// -----------------------
396// Quantization type, u32, little endian
397// -----------------------
398// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
399// -----------------------
400
401impl QuantizedSerde for BlockwiseFP8Linear {
402    fn isq_serde_supported(&self) -> bool {
403        false
404    }
405    fn name(&self) -> &'static str {
406        "blockwise-fp8-linear"
407    }
408}
409
410/// Create a BlockwiseFP8Linear for MoE with 3D weights [num_experts, N, K].
411/// This is used by FusedExperts to enable gather_forward with native FP8 GEMM.
412pub fn blockwise_fp8_moe(
413    weight: Tensor,
414    weight_scale_inv: Tensor,
415    weight_block_size: Vec<usize>,
416    dequant_dtype: DType,
417) -> Result<Arc<dyn QuantMethod>> {
418    Ok(Arc::new(BlockwiseFP8Linear {
419        weight,
420        weight_scale_inv,
421        bias: None,
422        dequant_dtype,
423        weight_block_size,
424    }))
425}
426
427pub fn blockwise_fp8_linear_b(
428    in_dim: usize,
429    out_dim: usize,
430    config: &QuantizedConfig,
431    bias: bool,
432    hints: Shard,
433    vb: ShardedVarBuilder,
434) -> Result<Arc<dyn QuantMethod>> {
435    let QuantizedConfig::Fp8 { weight_block_size } = config else {
436        hanzo_ml::bail!("Unexpected quantization config.")
437    };
438
439    // Handle the case where we actually have an unquantized layer
440    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
441        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
442    }
443
444    if has_missing_required_tensors(&vb, &["weight", "weight_scale_inv"]) {
445        return make_dummy_or_error("blockwise_fp8_linear", &vb, &["weight", "weight_scale_inv"]);
446    }
447
448    // Blockwise FP8 requires weight_block_size to be set
449    let Some(weight_block_size) = weight_block_size else {
450        hanzo_ml::bail!("Blockwise FP8 requires weight_block_size to be set. Use per-tensor FP8 for models without block sizes.")
451    };
452    if weight_block_size.len() != 2 {
453        hanzo_ml::bail!("Expected weight_block_size to have length 2, got {weight_block_size:?}")
454    }
455    let weight = vb.get_with_hints_dtype((out_dim, in_dim), "weight", hints, DType::F8E4M3)?;
456    let weight_scale_inv = vb.get_with_hints_dtype(
457        (
458            out_dim.div_ceil(weight_block_size[0]),
459            in_dim.div_ceil(weight_block_size[1]),
460        ),
461        "weight_scale_inv",
462        hints,
463        DType::F32,
464    )?;
465    let bias = if bias {
466        Some(vb.get((out_dim,), "bias")?)
467    } else {
468        None
469    };
470
471    Ok(Arc::new(BlockwiseFP8Linear {
472        weight,
473        weight_block_size: weight_block_size.clone(),
474        weight_scale_inv,
475        bias,
476        dequant_dtype: vb.dtype(),
477    }))
478}