1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor, D};
9use hanzo_nn::{Linear, Module};
10use quantize::QuantizationResult;
11
12mod quantize;
13
14use crate::{
15 cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
16 utils::{
17 deserialize_tensor, read_dtype, serialize_tensor, version_is_compatible, write_dtype,
18 UQFF_VERSION,
19 },
20 IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
21};
22
23#[derive(Debug)]
24pub struct FP8Linear {
25 lin: Linear,
26 dequant_w_scale: Tensor,
27 dequant_x_scale: Tensor,
28 quant_scale: Tensor,
29 dtype: DType,
31}
32
33impl QuantMethod for FP8Linear {
34 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
35 where
36 Self: Sized,
37 {
38 match method {
39 QuantMethodConfig::Gguf { .. }
40 | QuantMethodConfig::GptqAwq { .. }
41 | QuantMethodConfig::Hqq { .. }
42 | QuantMethodConfig::Dummy
43 | QuantMethodConfig::Unquantized(_)
44 | QuantMethodConfig::Bnb { .. }
45 | QuantMethodConfig::BlockwiseFP8 { .. }
46 | QuantMethodConfig::PerTensorFP8 { .. }
47 | QuantMethodConfig::Afq { .. }
48 | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
49 QuantMethodConfig::FP8 { lin, dtype } => {
50 let QuantizationResult {
51 qw,
52 quantize_scale,
53 dequantize_scale,
54 } = Self::quantize(lin.weight(), dtype)?;
55 Ok(Self {
56 lin: Linear::new(qw, lin.bias().cloned()),
57 dequant_x_scale: dequantize_scale.clone(), dequant_w_scale: dequantize_scale,
59 quant_scale: quantize_scale,
60 dtype,
61 })
62 }
63 }
64 }
65 fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
66 Ok(self.dequantize(DType::F32)?.weight().clone())
67 }
68
69 fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
70 maybe_init_cublas_lt_wrapper(x.device().clone());
72
73 match CUBLASLT_CONTROLLER.get_for_device(x.device()) {
74 Some(handle) => {
75 let n_dims = x.dims().len();
76 if n_dims < 3 {
77 hanzo_ml::bail!(
78 "FP8Linear `matmul` via cuBLASlt expects `x` to have at least 3 dimensions"
79 );
80 }
81 let mut tgt_shape = x.dims().to_vec();
83 *tgt_shape.last_mut().unwrap() = self.lin.weight().dim(0)?;
84
85 let mut x = x.flatten_to(D::Minus(3))?;
87
88 let mut dequant_x_scale = self.dequant_x_scale.clone();
90 if !matches!(x.dtype(), DType::F8E4M3) {
91 let QuantizationResult {
92 qw,
93 quantize_scale: _,
94 dequantize_scale,
95 } = Self::quantize(&x, DType::F8E4M3)?;
96 x = qw;
97 dequant_x_scale = dequantize_scale;
98 }
99
100 let beta = match self.lin.bias().is_some() {
102 true => Some(1.0),
103 false => None,
104 };
105
106 let a = self.lin.weight().unsqueeze(0)?;
108 let b = x;
109
110 handle
111 .batch_matmul_f8(
112 &a,
113 &b,
114 &self.dequant_w_scale,
115 &dequant_x_scale,
116 &self.quant_scale,
117 self.lin.bias(),
118 None,
119 beta,
120 None,
121 None,
122 )?
123 .reshape(tgt_shape)
124 }
125 None => {
126 let dequant_x = x.clone();
128 let lin = self.dequantize(x.dtype())?;
129 lin.forward(&dequant_x)
130 }
131 }
132 }
133
134 fn quantized_act_type(&self) -> Option<DType> {
135 None
136 }
137
138 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
139 let dequant = self.dequantize(delta.dtype())?;
140 let new = Linear::new((dequant.weight() + delta)?, dequant.bias().cloned());
141 Ok(Arc::new(Self::new(QuantMethodConfig::FP8 {
142 lin: new,
143 dtype: self.dtype,
144 })?))
145 }
146
147 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
148 (DType::F8E4M3, self.lin.weight().device().clone())
149 }
150
151 fn apply_isq(
152 self: Arc<Self>,
153 dtype: Option<IsqType>,
154 device: Device,
155 _n_quantized: &AtomicUsize,
156 _imatrix_weight: Option<Vec<f32>>,
157 guard: QuantizeOntoGuard,
158 ) -> Result<Arc<dyn QuantMethod>> {
159 match dtype {
160 Some(IsqType::F8Q8) => {
161 let _acquired_quantize_guard = guard.acquire(&device);
162 let dequant = self.dequantize(DType::F32)?;
163 let w = dequant.weight().to_device(&device)?;
164 let b = dequant.bias().map(|b| b.to_device(&device)).transpose()?;
165 Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
166 }
167 _ => todo!(),
168 }
169 }
170}
171
172impl QuantizedSerde for FP8Linear {
195 fn isq_serde_supported(&self) -> bool {
196 true
197 }
198 fn name(&self) -> &'static str {
199 "fp8-linear"
200 }
201 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
202 self.serialize_with_bias(self.lin.bias().cloned())
203 }
204 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
205 let mut buffer = Vec::new();
206
207 buffer.extend(&UQFF_VERSION.to_le_bytes());
209
210 buffer.push(QuantizedSerdeType::Fp8 as u8);
212
213 buffer.push(bias.is_some() as u8);
215
216 serialize_tensor(&mut buffer, self.lin.weight())?;
218
219 buffer.extend(self.dequant_w_scale.to_scalar::<f32>()?.to_le_bytes());
221 buffer.extend(self.dequant_x_scale.to_scalar::<f32>()?.to_le_bytes());
223 buffer.extend(self.quant_scale.to_scalar::<f32>()?.to_le_bytes());
225
226 write_dtype(self.dtype, &mut buffer);
228
229 if let Some(bias) = &bias {
230 serialize_tensor(&mut buffer, bias)?;
232 }
233
234 Ok(Cow::from(buffer))
235 }
236
237 fn deserialize(
238 data: Cow<[u8]>,
239 device: &Device,
240 _comm: &Arc<crate::Comm>,
241 guard: QuantizeOntoGuard,
242 ) -> Result<Arc<dyn QuantMethod>>
243 where
244 Self: Sized,
245 {
246 let mut buffer = Cursor::new(data.to_vec());
247
248 let version = buffer.read_u32::<LittleEndian>()?;
249 if let Err(e) = version_is_compatible(version) {
250 return Err(hanzo_ml::Error::wrap(e));
251 }
252
253 let isq_type = buffer.read_u8()? as usize;
254 if isq_type != QuantizedSerdeType::Fp8 as usize {
255 hanzo_ml::bail!(
256 "ISQ type ({isq_type}) doesn't match expected type {}",
257 QuantizedSerdeType::Fp8 as usize
258 );
259 }
260
261 let has_bias = buffer.read_u8()? != 0;
262
263 let w = deserialize_tensor(&mut buffer, device)?;
264
265 let _acquired_load_guard = guard.acquire(device);
266 let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
267 let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
268 let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
269
270 let dtype = read_dtype(&mut buffer)?;
272
273 let b = if has_bias {
274 Some(deserialize_tensor(&mut buffer, device)?)
275 } else {
276 None
277 };
278
279 Ok(Arc::new(Self {
280 lin: Linear::new(w, b),
281 dequant_w_scale,
282 dequant_x_scale,
283 quant_scale,
284 dtype,
285 }))
286 }
287 fn deserialize_ext_bias(
288 data: Cow<[u8]>,
289 device: &Device,
290 guard: QuantizeOntoGuard,
291 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
292 where
293 Self: Sized,
294 {
295 let mut buffer = Cursor::new(data.to_vec());
296
297 let version = buffer.read_u32::<LittleEndian>()?;
298 if let Err(e) = version_is_compatible(version) {
299 return Err(hanzo_ml::Error::wrap(e));
300 }
301
302 let isq_type = buffer.read_u8()? as usize;
303 if isq_type != QuantizedSerdeType::Fp8 as usize {
304 hanzo_ml::bail!(
305 "ISQ type ({isq_type}) doesn't match expected type {}",
306 QuantizedSerdeType::Fp8 as usize
307 );
308 }
309
310 let has_bias = buffer.read_u8()? != 0;
311
312 let _acquired_load_guard = guard.acquire(device);
313 let w = deserialize_tensor(&mut buffer, device)?;
314
315 let dequant_w_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
316 let dequant_x_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
317 let quant_scale = Tensor::new(buffer.read_f32::<LittleEndian>()?, device)?;
318
319 let dtype = read_dtype(&mut buffer)?;
321
322 let b = if has_bias {
323 Some(deserialize_tensor(&mut buffer, device)?)
324 } else {
325 None
326 };
327
328 Ok((
329 Arc::new(Self {
330 lin: Linear::new(w, None),
331 dequant_w_scale,
332 dequant_x_scale,
333 quant_scale,
334 dtype,
335 }),
336 b,
337 ))
338 }
339}