Skip to main content

hanzo_quant/pertensor_fp8/
mod.rs

1use std::{
2    borrow::Cow,
3    sync::{atomic::AtomicUsize, Arc},
4};
5
6use hanzo_ml::{quantized::GgmlDType, DType, Device, Result, Tensor};
7use hanzo_nn::Linear;
8
9mod ops;
10
11use crate::{
12    generate_isq, generate_isq_imatrix, has_missing_required_tensors,
13    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
14    make_dummy_or_error,
15    utils::{serialize_tensor, UQFF_VERSION},
16    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits, HqqConfig, HqqLayer,
17    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
18    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
19};
20
21/// Per-tensor FP8 Linear layer with static activation scaling.
22///
23/// This is used for models that have per-tensor FP8 quantization (weight_block_size = null)
24/// with static activation scales. Each linear layer has:
25/// - `<layer>.weight` (FP8 E4M3)
26/// - `<layer>.weight_scale_inv` (F32 scalar) - dequantization scale for weights
27/// - `<layer>.activation_scale` (F32 scalar) - quantization scale for activations
28#[derive(Debug)]
29pub struct PerTensorFP8Linear {
30    weight: Tensor,
31    #[allow(dead_code)]
32    weight_scale_inv: Tensor,
33    #[allow(dead_code)]
34    activation_scale: Option<Tensor>,
35    bias: Option<Tensor>,
36    #[allow(dead_code)]
37    dequant_dtype: DType,
38}
39
40impl QuantMethod for PerTensorFP8Linear {
41    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
42    where
43        Self: Sized,
44    {
45        match method {
46            QuantMethodConfig::PerTensorFP8 {
47                weight,
48                weight_scale_inv,
49                activation_scale,
50                bias,
51                dequant_dtype,
52            } => {
53                // Dequantize immediately since Hanzo FP8 is storage-only (no ops)
54                let dequant_weight =
55                    ops::fp8_pertensor_dequantize(&weight, &weight_scale_inv, dequant_dtype)?;
56                Ok(Self {
57                    weight: dequant_weight,
58                    weight_scale_inv,
59                    activation_scale,
60                    bias,
61                    dequant_dtype,
62                })
63            }
64            _ => unreachable!(),
65        }
66    }
67
68    fn dequantize_w(&self) -> Result<Tensor> {
69        // Weight is already dequantized on load
70        Ok(self.weight.clone())
71    }
72
73    fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
74        // Weight is already dequantized, use standard matmul
75        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
76            self.weight.clone(),
77            self.bias.clone(),
78        )))?;
79        unquant.forward(x)
80    }
81
82    fn quantized_act_type(&self) -> Option<DType> {
83        None
84    }
85
86    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
87        hanzo_ml::bail!("PerTensorFP8Linear does not support add_delta_w")
88    }
89
90    fn dtype_and_device(&self) -> (DType, Device) {
91        (DType::F8E4M3, self.weight.device().clone())
92    }
93
94    fn apply_isq(
95        self: Arc<Self>,
96        dtype: Option<IsqType>,
97        device: Device,
98        n_quantized: &AtomicUsize,
99        imatrix_weight: Option<Vec<f32>>,
100        guard: QuantizeOntoGuard,
101    ) -> Result<Arc<dyn QuantMethod>> {
102        let weight = self.dequantize_w()?;
103        match dtype {
104            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
105                let _acquired_quantize_guard = guard.acquire(&device);
106                if imatrix_weight.is_some() {
107                    hanzo_ml::bail!("HQQ does not support imatrix.");
108                }
109
110                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
111                let bits = match dtype.unwrap() {
112                    IsqType::HQQ8 => HqqBits::Eight,
113                    IsqType::HQQ4 => HqqBits::Four,
114                    _ => unreachable!(),
115                };
116                let cfg = HqqConfig {
117                    bits,
118                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
119                    axis: HqqAxis::Zero,
120                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
121                    round_zeros: false,
122                    channel_wise: true,
123                };
124                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
125                if let Some(bias) = &self.bias {
126                    let bias = bias
127                        .to_device(&device)?
128                        .to_dtype(res.dtype_and_device().0)?;
129                    Ok(Arc::new(res.with_bias(bias)))
130                } else {
131                    Ok(Arc::new(res))
132                }
133            }
134            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
135                let _acquired_quantize_guard = guard.acquire(&device);
136                if imatrix_weight.is_some() {
137                    hanzo_ml::bail!("AFQ does not support imatrix.");
138                }
139
140                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
141                let bits = match dtype.unwrap() {
142                    IsqType::AFQ8 => AfqBits::Eight,
143                    IsqType::AFQ6 => AfqBits::Six,
144                    IsqType::AFQ4 => AfqBits::Four,
145                    IsqType::AFQ3 => AfqBits::Three,
146                    IsqType::AFQ2 => AfqBits::Two,
147                    _ => unreachable!(),
148                };
149
150                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
151                    weight: weight.to_device(&device)?,
152                    bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
153                    bits,
154                    group_size: AfqGroupSize::default(),
155                })?))
156            }
157            Some(
158                IsqType::Q2K
159                | IsqType::Q3K
160                | IsqType::Q4K
161                | IsqType::Q4_0
162                | IsqType::Q4_1
163                | IsqType::Q5K
164                | IsqType::Q5_0
165                | IsqType::Q5_1
166                | IsqType::Q6K
167                | IsqType::Q8K
168                | IsqType::Q8_0
169                | IsqType::Q8_1,
170            ) => {
171                let dtype: GgmlDType = dtype.unwrap().try_into()?;
172                let res = if let Some(imatrix_weight) = imatrix_weight {
173                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
174                } else {
175                    generate_isq!(weight, device, dtype, n_quantized, guard)
176                };
177                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
178                    q_weight: res,
179                    b: self
180                        .bias
181                        .as_ref()
182                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
183                })?))
184            }
185            Some(IsqType::F8E4M3) => {
186                let _acquired_quantize_guard = guard.acquire(&device);
187                if imatrix_weight.is_some() {
188                    hanzo_ml::bail!("F8E4M3 does not support imatrix.");
189                }
190
191                let w = weight.to_device(&device)?;
192                let b = if let Some(b) = &self.bias {
193                    Some(b.to_device(&device)?)
194                } else {
195                    None
196                };
197                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
198                    lin: Linear::new(w, b),
199                    dtype: DType::F8E4M3,
200                })?))
201            }
202            Some(IsqType::F8Q8) => {
203                let _acquired_quantize_guard = guard.acquire(&device);
204                if imatrix_weight.is_some() {
205                    hanzo_ml::bail!("F8Q8 does not support imatrix.");
206                }
207
208                let w = weight.to_device(&device)?;
209                let b = if let Some(b) = &self.bias {
210                    Some(b.to_device(&device)?)
211                } else {
212                    None
213                };
214                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
215            }
216            Some(IsqType::MXFP4) => {
217                let _acquired_quantize_guard = guard.acquire(&device);
218                if imatrix_weight.is_some() {
219                    hanzo_ml::bail!("MXFP4 does not support imatrix.");
220                }
221
222                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
223                let w = weight.to_device(&device)?;
224                let b = self
225                    .bias
226                    .as_ref()
227                    .map(|b| b.to_device(&device))
228                    .transpose()?;
229                crate::MXFP4Layer::quantize(&w, b, &device)
230            }
231            None => {
232                let _acquired_quantize_guard = guard.acquire(&device);
233
234                let w = weight.to_device(&device)?;
235                let b = if let Some(b) = &self.bias {
236                    Some(b.to_device(&device)?)
237                } else {
238                    None
239                };
240                Ok(Arc::new(UnquantLinear::new(
241                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
242                )?))
243            }
244        }
245    }
246}
247
248// Serialization structure (same as UnquantLinear):
249//
250// -----------------------
251// UQFF version, u32, little endian
252// -----------------------
253// ISQ type (1 for unquantized), u8, little endian
254// -----------------------
255// Whether bias data is included, u8 boolean
256// -----------------------
257// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
258// -----------------------
259// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
260// -----------------------
261
262impl QuantizedSerde for PerTensorFP8Linear {
263    fn isq_serde_supported(&self) -> bool {
264        true
265    }
266    fn name(&self) -> &'static str {
267        "pertensor-fp8-linear"
268    }
269    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
270        self.serialize_with_bias(self.bias.clone())
271    }
272    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
273        // Serialize as unquantized since weights are already dequantized
274        let mut buffer = Vec::new();
275
276        // Version is always first!
277        buffer.extend(&UQFF_VERSION.to_le_bytes());
278
279        // ISQ type for unquant is 1 (same as UnquantLinear)
280        buffer.push(QuantizedSerdeType::Unquant as u8);
281
282        // Has bias
283        buffer.push(bias.is_some() as u8);
284
285        // Weight (already dequantized)
286        serialize_tensor(&mut buffer, &self.weight)?;
287
288        if let Some(bias) = &bias {
289            // Bias
290            serialize_tensor(&mut buffer, bias)?;
291        }
292
293        Ok(Cow::from(buffer))
294    }
295}
296
297/// Load a per-tensor FP8 linear layer from the VarBuilder.
298///
299/// This handles models with per-tensor FP8 quantization where:
300/// - `weight_block_size` is null (per-tensor, not blockwise)
301/// - Each layer has: weight (FP8), weight_scale_inv (F32), activation_scale (F32)
302pub fn pertensor_fp8_linear_b(
303    in_dim: usize,
304    out_dim: usize,
305    _config: &QuantizedConfig,
306    bias: bool,
307    _hints: Shard,
308    vb: ShardedVarBuilder,
309) -> Result<Arc<dyn QuantMethod>> {
310    // Handle the case where we actually have unquantized weights
311    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
312        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
313    }
314
315    if has_missing_required_tensors(&vb, &["weight", "weight_scale_inv"]) {
316        return make_dummy_or_error("pertensor_fp8_linear", &vb, &["weight", "weight_scale_inv"]);
317    }
318
319    // Load FP8 weight tensor
320    let weight = vb.get_with_hints_dtype(
321        (out_dim, in_dim),
322        "weight",
323        Default::default(),
324        DType::F8E4M3,
325    )?;
326
327    // Load per-tensor weight scale (scalar)
328    let weight_scale_inv =
329        vb.get_with_hints_dtype((), "weight_scale_inv", Default::default(), DType::F32)?;
330
331    // Load activation scale if present (optional - some models may not have it)
332    let activation_scale = if vb.contains_tensor("activation_scale") {
333        Some(vb.get_with_hints_dtype((), "activation_scale", Default::default(), DType::F32)?)
334    } else {
335        None
336    };
337
338    let bias = if bias && vb.contains_tensor("bias") {
339        Some(vb.get((out_dim,), "bias")?)
340    } else {
341        None
342    };
343
344    // Determine the output dtype for dequantization.
345    // We can't use vb.dtype() as that returns F8E4M3 (the storage type).
346    // Use the bias dtype if available, otherwise default to BF16.
347    let dequant_dtype = bias.as_ref().map(|b| b.dtype()).unwrap_or(DType::BF16);
348
349    // Use new() which handles dequantization (Hanzo FP8 is storage-only)
350    Ok(Arc::new(PerTensorFP8Linear::new(
351        QuantMethodConfig::PerTensorFP8 {
352            weight,
353            weight_scale_inv,
354            activation_scale,
355            bias,
356            dequant_dtype,
357        },
358    )?))
359}