Skip to main content

hanzo_quant/fp8/
mod.rs

1use std::{
2    borrow::Cow,
3    io::Cursor,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor, D};
9use hanzo_nn::{Linear, Module};
10use quantize::QuantizationResult;
11
12mod quantize;
13
14use crate::{
15    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
16    utils::{
17        deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
18        UQFF_VERSION,
19    },
20    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
21};
22
/// Linear layer whose weight is stored quantized (via `FP8Linear::quantize`), together with
/// the scale tensors needed either to run it through cuBLASLt's FP8 batched matmul or to
/// dequantize it on devices without a cuBLASLt handle.
#[derive(Debug)]
pub struct FP8Linear {
    /// Linear layer holding the quantized weight and the bias, which is carried over from the
    /// source layer without quantization.
    lin: Linear,
    /// Scale used to dequantize the weight; serialized as a single `f32` scalar.
    dequant_w_scale: Tensor,
    /// Fallback activation dequant scale. `forward_raw` only uses it when the input is
    /// already `F8E4M3`; `new` initializes it to a copy of the weight's dequant scale.
    dequant_x_scale: Tensor,
    /// Quantization scale returned by `quantize`; passed to the cuBLASLt FP8 matmul.
    /// NOTE(review): presumably the output scale of the matmul — confirm against
    /// `batch_matmul_f8`.
    quant_scale: Tensor,
    /// Quantized type (the dtype the weight was quantized to)
    dtype: DType,
}
32
33impl QuantMethod for FP8Linear {
34    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
35    where
36        Self: Sized,
37    {
38        match method {
39            QuantMethodConfig::Gguf { .. }
40            | QuantMethodConfig::GptqAwq { .. }
41            | QuantMethodConfig::Hqq { .. }
42            | QuantMethodConfig::Dummy
43            | QuantMethodConfig::Unquantized(_)
44            | QuantMethodConfig::Bnb { .. }
45            | QuantMethodConfig::BlockwiseFP8 { .. }
46            | QuantMethodConfig::PerTensorFP8 { .. }
47            | QuantMethodConfig::Afq { .. }
48            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
49            QuantMethodConfig::FP8 { lin, dtype } => {
50                let QuantizationResult {
51                    qw,
52                    quantize_scale,
53                    dequantize_scale,
54                } = Self::quantize(lin.weight(), dtype)?;
55                Ok(Self {
56                    lin: Linear::new(qw, lin.bias().cloned()),
57                    dequant_x_scale: dequantize_scale.clone(), // This is probably wrong!
58                    dequant_w_scale: dequantize_scale,
59                    quant_scale: quantize_scale,
60                    dtype,
61                })
62            }
63        }
64    }
65    fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
66        Ok(self.dequantize(DType::F32)?.weight().clone())
67    }
68
69    fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
70        // Batch matrix multiplication
71        maybe_init_cublas_lt_wrapper(x.device().clone());
72
73        match CUBLASLT_CONTROLLER.get_for_device(x.device()) {
74            Some(handle) => {
75                let n_dims = x.dims().len();
76                if n_dims < 3 {
77                    hanzo_ml::bail!(
78                        "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions"
79                    );
80                }
81                // Set up target shape
82                let mut tgt_shape = x.dims().to_vec();
83                *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;
84
85                // Flatten for correct dims
86                let mut x = x.flatten_to(D::Minus(3))?;
87
88                // Prepare the b tensor. If it is not quantized, quantize it
89                let mut dequant_x_scale = self.dequant_x_scale.clone();
90                if !matches!(x.dtype(), DType::F8E4M3) {
91                    let QuantizationResult {
92                        qw,
93                        quantize_scale: _,
94                        dequantize_scale,
95                    } = Self::quantize(&x, DType::F8E4M3)?;
96                    x = qw;
97                    dequant_x_scale = dequantize_scale;
98                }
99
100                // Handle bias
101                let beta = match self.lin.bias().is_some() {
102                    true => Some(1.0),
103                    false => None,
104                };
105
106                // Naming
107                let a = self.lin.weight().unsqueeze(0)?;
108                let b = x;
109
110                handle
111                    .batch_matmul_f8(
112                        &a,
113                        &b,
114                        &self.dequant_w_scale,
115                        &dequant_x_scale,
116                        &self.quant_scale,
117                        self.lin.bias(),
118                        None,
119                        beta,
120                        None,
121                        None,
122                    )?
123                    .reshape(tgt_shape)
124            }
125            None => {
126                // Dequantize matmul
127                let dequant_x = x.clone();
128                let lin = self.dequantize(x.dtype())?;
129                lin.forward(&dequant_x)
130            }
131        }
132    }
133
134    fn quantized_act_type(&self) -> Option<DType> {
135        None
136    }
137
138    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
139        let dequant = self.dequantize(delta.dtype())?;
140        let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
141        Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
142            lin: new,
143            dtype: self.dtype,
144        })?))
145    }
146
147    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
148        (DType::F8E4M3, self.lin.weight().device().clone())
149    }
150
151    fn apply_isq(
152        self: Arc<Self>,
153        dtype: Option<IsqType>,
154        device: Device,
155        _n_quantized: &AtomicUsize,
156        _imatrix_weight: Option<Vec<f32>>,
157        guard: QuantizeOntoGuard,
158    ) -> Result<Arc<dyn QuantMethod>> {
159        match dtype {
160            Some(IsqType::F8Q8) => {
161                let _acquired_quantize_guard = guard.acquire(&device);
162                let dequant = self.dequantize(DType::F32)?;
163                let w = dequant.weight().to_device(&device)?;
164                let b = dequant.bias().map(|b| b.to_device(&device)).transpose()?;
165                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
166            }
167            _ => todo!(),
168        }
169    }
170}
171
// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (`QuantizedSerdeType::Fp8`, i.e. 3 for fp8), u8
// -----------------------
// Whether bias data is included, u8 boolean (0 = absent, nonzero = present)
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// Dequant W scalar, f32, little endian
// -----------------------
// Dequant X scalar, f32, little endian
// -----------------------
// Quant scalar, f32, little endian
// -----------------------
// Quantized dtype, written by `write_dtype` (u32, little endian)
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
193
194impl QuantizedSerde for FP8Linear {
195    fn isq_serde_supported(&self) -> bool {
196        true
197    }
198    fn name(&self) -> &'static str {
199        "fp8-linear"
200    }
201    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
202        self.serialize_with_bias(self.lin.bias().cloned())
203    }
204    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
205        let mut buffer = Vec::new();
206
207        // Version is always first!
208        buffer.extend(&UQFF_VERSION.to_le_bytes());
209
210        // ISQ type for fp8 is 3
211        buffer.push(QuantizedSerdeType::Fp8 as u8);
212
213        // Has bias
214        buffer.push(bias.is_some() as u8);
215
216        // Weight
217        serialize_tensor(&mut buffer, self.lin.weight())?;
218
219        // Dequant a scale
220        buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
221        // Dequant b scale
222        buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
223        // Quant scale
224        buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());
225
226        // DType
227        write_dtype(self.dtype, &mut buffer);
228
229        if let Some(bias) = &bias {
230            // Bias
231            serialize_tensor(&mut buffer, bias)?;
232        }
233
234        Ok(Cow::from(buffer))
235    }
236
237    fn deserialize(
238        data: Cow<[u8]>,
239        device: &Device,
240        _comm: &Arc<crate::Comm>,
241        guard: QuantizeOntoGuard,
242    ) -> Result<Arc<dyn QuantMethod>>
243    where
244        Self: Sized,
245    {
246        let mut buffer = Cursor::new(data.to_vec());
247
248        let version = buffer.read_u32::<LittleEndian>()?;
249        if let Err(e) = version_is_compatible(version) {
250            return Err(hanzo_ml::Error::wrap(e));
251        }
252
253        let isq_type = buffer.read_u8()? as usize;
254        if isq_type != QuantizedSerdeType::Fp8 as usize {
255            hanzo_ml::bail!(
256                "ISQ type ({isq_type}) doesn't match expected type {}",
257                QuantizedSerdeType::Fp8 as usize
258            );
259        }
260
261        let has_bias = buffer.read_u8()? != 0;
262
263        let w = deserialize_tensor(&mut buffer, device)?;
264
265        let _acquired_load_guard = guard.acquire(device);
266        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
267        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
268        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
269
270        // DType
271        let dtype = read_dtype(&mut buffer)?;
272
273        let b = if has_bias {
274            Some(deserialize_tensor(&mut buffer, device)?)
275        } else {
276            None
277        };
278
279        Ok(Arc::new(Self {
280            lin: Linear::new(w, b),
281            dequant_w_scale,
282            dequant_x_scale,
283            quant_scale,
284            dtype,
285        }))
286    }
287    fn deserialize_ext_bias(
288        data: Cow<[u8]>,
289        device: &Device,
290        guard: QuantizeOntoGuard,
291    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
292    where
293        Self: Sized,
294    {
295        let mut buffer = Cursor::new(data.to_vec());
296
297        let version = buffer.read_u32::<LittleEndian>()?;
298        if let Err(e) = version_is_compatible(version) {
299            return Err(hanzo_ml::Error::wrap(e));
300        }
301
302        let isq_type = buffer.read_u8()? as usize;
303        if isq_type != QuantizedSerdeType::Fp8 as usize {
304            hanzo_ml::bail!(
305                "ISQ type ({isq_type}) doesn't match expected type {}",
306                QuantizedSerdeType::Fp8 as usize
307            );
308        }
309
310        let has_bias = buffer.read_u8()? != 0;
311
312        let _acquired_load_guard = guard.acquire(device);
313        let w = deserialize_tensor(&mut buffer, device)?;
314
315        let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
316        let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
317        let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
318
319        // DType
320        let dtype = read_dtype(&mut buffer)?;
321
322        let b = if has_bias {
323            Some(deserialize_tensor(&mut buffer, device)?)
324        } else {
325            None
326        };
327
328        Ok((
329            Arc::new(Self {
330                lin: Linear::new(w, None),
331                dequant_w_scale,
332                dequant_x_scale,
333                quant_scale,
334                dtype,
335            }),
336            b,
337        ))
338    }
339}