Skip to main content

hanzo_quant/f8q8/
mod.rs

1use std::{
2    borrow::Cow,
3    io::Cursor,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use float8::F8E4M3;
9use half::f16;
10use hanzo_ml::{DType, Device, Result, Shape, Tensor};
11use hanzo_nn::{Linear, Module};
12
13use crate::{
14    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
15    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
16};
17
/// Number of quantized values held by a single block (shared by `BlockF8Q8` and `BlockQ8_0`).
pub(crate) const QK8_0: usize = 32;

// SIMD dot-product kernels; each module is compiled only when its target
// feature is enabled at build time.
#[cfg(target_feature = "avx")]
mod avx;
#[cfg(target_feature = "neon")]
mod neon;
#[cfg(target_feature = "simd128")]
mod simd128;
26
27#[derive(Debug, Clone, PartialEq)]
28#[repr(C)]
29pub struct BlockF8Q8 {
30    d: F8E4M3,
31    pub(crate) qs: [i8; QK8_0],
32}
33const _: () = assert!(std::mem::size_of::<BlockF8Q8>() == 33);
34
35impl BlockF8Q8 {
36    pub fn dq_d(&self) -> f32 {
37        self.d.to_f32() / F8E4M3::MAX.to_f32()
38    }
39
40    fn zeros() -> Self {
41        BlockF8Q8 {
42            d: F8E4M3::ZERO,
43            qs: [0i8; QK8_0],
44        }
45    }
46}
47
48// Our own BlockQ8_0 with accessible fields for vec_dot kernels.
49// hanzo_ml's BlockQ8_0 has pub(crate) fields we can't access.
50#[derive(Debug, Clone, PartialEq)]
51#[repr(C)]
52pub struct BlockQ8_0 {
53    pub(crate) d: f16,
54    pub(crate) qs: [i8; QK8_0],
55}
56const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
57
58// ---- GgmlType-like functions ----
59
60fn to_float(xs: &[BlockF8Q8], ys: &mut [f32]) -> Result<()> {
61    let k = ys.len();
62    if !k.is_multiple_of(QK8_0) {
63        hanzo_ml::bail!("dequantize_row_f8q8: {k} is not divisible by {QK8_0}");
64    }
65
66    let nb = k / QK8_0;
67
68    for i in 0..nb {
69        let d = xs[i].dq_d();
70
71        for j in 0..QK8_0 {
72            ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
73        }
74    }
75    Ok(())
76}
77
78fn from_float(xs: &[f32], ys: &mut [BlockF8Q8]) -> Result<()> {
79    let k = xs.len();
80    if !k.is_multiple_of(QK8_0) {
81        hanzo_ml::bail!("{k} is not divisible by {QK8_0}");
82    }
83    let nb = k / QK8_0;
84    if ys.len() != nb {
85        hanzo_ml::bail!("size mismatch {} {} {}", xs.len(), ys.len(), QK8_0)
86    }
87    for (i, ys) in ys.iter_mut().enumerate() {
88        let mut amax = 0f32;
89        let xs = &xs[i * QK8_0..(i + 1) * QK8_0];
90        for &x in xs.iter() {
91            amax = amax.max(x.abs())
92        }
93        let d = amax / ((1 << 7) - 1) as f32;
94        let id = if d != 0f32 { 1. / d } else { 0. };
95        ys.d = F8E4M3::from_f32(d * F8E4M3::MAX.to_f32());
96        for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
97            *y = f32::round(x * id) as i8
98        }
99    }
100    Ok(())
101}
102
103#[allow(dead_code)]
104#[allow(unreachable_code)]
105fn vec_dot(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
106    #[cfg(target_feature = "avx")]
107    return avx::vec_dot_f8q8_q8_0(n, xs, ys);
108
109    #[cfg(target_feature = "neon")]
110    return neon::vec_dot_f8q8_q8_0(n, xs, ys);
111
112    #[cfg(target_feature = "simd128")]
113    return simd128::vec_dot_f8q8_q8_0(n, xs, ys);
114
115    vec_dot_unopt(n, xs, ys)
116}
117
118#[allow(dead_code)]
119fn vec_dot_unopt(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
120    let qk = QK8_0;
121    if !n.is_multiple_of(QK8_0) {
122        hanzo_ml::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
123    }
124
125    let mut sumf = 0f32;
126    for (xs, ys) in xs.iter().zip(ys.iter()) {
127        let sum_i = xs
128            .qs
129            .iter()
130            .zip(ys.qs.iter())
131            .map(|(&x, &y)| x as i32 * y as i32)
132            .sum::<i32>();
133        sumf += sum_i as f32 * xs.dq_d() * f16::to_f32(ys.d)
134    }
135    Ok(sumf)
136}
137
/// i8mm-accelerated product of two F8Q8 rows with two Q8_0 rows.
///
/// Only the NEON kernel is implemented. On any other target this returns an
/// error. The `[f32; 4]` result is presumably the 2x2 tile of dot products between
/// (`xs_0`, `xs_1`) and (`ys_0`, `ys_1`); confirm against `neon::i8mm_f8q8_q8_0`.
#[allow(dead_code)]
#[allow(unreachable_code)]
#[allow(unused)]
#[cfg(feature = "arm-nightly-feat")]
fn matmul_i8mm(
    n: usize,
    xs_0: &[BlockF8Q8],
    xs_1: &[BlockF8Q8],
    ys_0: &[BlockQ8_0],
    ys_1: &[BlockQ8_0],
) -> Result<[f32; 4]> {
    #[cfg(target_feature = "neon")]
    {
        return neon::i8mm_f8q8_q8_0(n, xs_0, xs_1, ys_0, ys_1);
    }

    // No i8mm-capable kernel for this target.
    hanzo_ml::bail!("Unsupported block type for i8mm");
}
154
155// ---- F8Q8Linear ----
156
157#[derive(Debug)]
158pub struct F8Q8Linear {
159    data: Vec<BlockF8Q8>,
160    shape: Shape,
161    bias: Option<Tensor>,
162}
163
164impl F8Q8Linear {
165    pub fn from_weight(weight: &Tensor, bias: Option<Tensor>) -> Result<Self> {
166        let shape = weight.shape().clone();
167        let weight_f32 = weight.to_dtype(DType::F32)?.flatten_all()?;
168        let mut weight_data: Vec<f32> = weight_f32.to_vec1()?;
169
170        // Pad to multiple of QK8_0
171        let elem_count = weight_data.len();
172        let padded_count = elem_count.div_ceil(QK8_0) * QK8_0;
173        weight_data.resize(padded_count, 0.0);
174
175        let num_blocks = padded_count / QK8_0;
176        let mut blocks = vec![BlockF8Q8::zeros(); num_blocks];
177        from_float(&weight_data, &mut blocks)?;
178
179        Ok(Self {
180            data: blocks,
181            shape,
182            bias,
183        })
184    }
185
186    fn dequantize(&self, dtype: DType) -> Result<Tensor> {
187        let num_blocks = self.data.len();
188        let total_floats = num_blocks * QK8_0;
189        let mut output = vec![0f32; total_floats];
190        to_float(&self.data, &mut output)?;
191
192        // Trim padding and reshape
193        let n = self.shape.elem_count();
194        let output = &output[..n];
195        Tensor::from_slice(output, &self.shape, &Device::Cpu)?.to_dtype(dtype)
196    }
197}
198
199impl QuantMethod for F8Q8Linear {
200    fn new(method: QuantMethodConfig) -> Result<Self>
201    where
202        Self: Sized,
203    {
204        let _ = method;
205        hanzo_ml::bail!("F8Q8Linear should be constructed via from_weight")
206    }
207
208    fn dequantize_w(&self) -> Result<Tensor> {
209        self.dequantize(DType::F32)
210    }
211
212    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
213        let dequant_w = self.dequantize(a.dtype())?;
214        let lin = Linear::new(dequant_w, self.bias.clone());
215        lin.forward(a)
216    }
217
218    fn quantized_act_type(&self) -> Option<DType> {
219        None
220    }
221
222    fn dtype_and_device(&self) -> (DType, Device) {
223        (DType::F32, Device::Cpu)
224    }
225
226    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
227        let dequant = self.dequantize(delta.dtype())?;
228        let new_w = (dequant + delta)?;
229        Ok(Arc::new(Self::from_weight(&new_w, self.bias.clone())?))
230    }
231
232    fn apply_isq(
233        self: Arc<Self>,
234        dtype: Option<IsqType>,
235        device: Device,
236        n_quantized: &AtomicUsize,
237        _imatrix_weight: Option<Vec<f32>>,
238        guard: QuantizeOntoGuard,
239    ) -> Result<Arc<dyn QuantMethod>> {
240        match dtype {
241            Some(IsqType::F8Q8) | None => {
242                // Already F8Q8 or no-op, just return self
243                Ok(self)
244            }
245            Some(other) => {
246                // Dequantize and re-quantize to requested type
247                let w = self.dequantize(DType::F32)?;
248                let b = self.bias.clone();
249                let unquant =
250                    crate::UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, b)))?;
251                Arc::new(unquant).apply_isq(Some(other), device, n_quantized, None, guard)
252            }
253        }
254    }
255}
256
257// ---- Serialization ----
258//
259// Layout:
260// | UQFF_VERSION (u32) | type=5 (u8) | has_bias (u8) | num_blocks (u32) |
261// | shape_ndims (u32) | shape_dims[] (u32 each) |
262// | raw BlockF8Q8 data (33 * num_blocks bytes) |
263// | [optional bias via serialize_tensor] |
264
265impl QuantizedSerde for F8Q8Linear {
266    fn name(&self) -> &'static str {
267        "f8q8-linear"
268    }
269
270    fn isq_serde_supported(&self) -> bool {
271        true
272    }
273
274    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
275        self.serialize_with_bias(self.bias.clone())
276    }
277
278    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
279        let mut buffer = Vec::new();
280
281        // Version
282        buffer.extend(&UQFF_VERSION.to_le_bytes());
283
284        // ISQ type
285        buffer.push(QuantizedSerdeType::F8Q8 as u8);
286
287        // Has bias
288        buffer.push(bias.is_some() as u8);
289
290        // Num blocks
291        buffer.extend(&(self.data.len() as u32).to_le_bytes());
292
293        // Shape
294        let dims = self.shape.dims();
295        buffer.extend(&(dims.len() as u32).to_le_bytes());
296        for &dim in dims {
297            buffer.extend(&(dim as u32).to_le_bytes());
298        }
299
300        // Raw block data
301        let block_bytes: &[u8] = unsafe {
302            std::slice::from_raw_parts(
303                self.data.as_ptr() as *const u8,
304                self.data.len() * std::mem::size_of::<BlockF8Q8>(),
305            )
306        };
307        buffer.extend(block_bytes);
308
309        // Optional bias
310        if let Some(ref b) = bias {
311            serialize_tensor(&mut buffer, b)?;
312        }
313
314        Ok(Cow::from(buffer))
315    }
316
317    fn deserialize(
318        data: Cow<[u8]>,
319        device: &Device,
320        _comm: &Arc<crate::Comm>,
321        guard: QuantizeOntoGuard,
322    ) -> Result<Arc<dyn QuantMethod>>
323    where
324        Self: Sized,
325    {
326        let mut buffer = Cursor::new(data.to_vec());
327
328        let version = buffer.read_u32::<LittleEndian>()?;
329        if let Err(e) = version_is_compatible(version) {
330            return Err(hanzo_ml::Error::wrap(e));
331        }
332
333        let isq_type = buffer.read_u8()? as usize;
334        if isq_type != QuantizedSerdeType::F8Q8 as usize {
335            hanzo_ml::bail!(
336                "ISQ type ({isq_type}) doesn't match expected type {}",
337                QuantizedSerdeType::F8Q8 as usize
338            );
339        }
340
341        let has_bias = buffer.read_u8()? != 0;
342
343        let num_blocks = buffer.read_u32::<LittleEndian>()? as usize;
344
345        // Shape
346        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
347        let mut dims = Vec::with_capacity(n_dims);
348        for _ in 0..n_dims {
349            dims.push(buffer.read_u32::<LittleEndian>()? as usize);
350        }
351        let shape = Shape::from_dims(&dims);
352
353        // Raw block data
354        let block_byte_count = num_blocks * std::mem::size_of::<BlockF8Q8>();
355        let mut raw_data = vec![0u8; block_byte_count];
356        std::io::Read::read_exact(&mut buffer, &mut raw_data)?;
357
358        // Safety: BlockF8Q8 is #[repr(C)] and 33 bytes
359        let blocks: Vec<BlockF8Q8> = unsafe {
360            let mut blocks = Vec::with_capacity(num_blocks);
361            std::ptr::copy_nonoverlapping(
362                raw_data.as_ptr(),
363                blocks.as_mut_ptr() as *mut u8,
364                block_byte_count,
365            );
366            blocks.set_len(num_blocks);
367            blocks
368        };
369
370        let _acquired_load_guard = guard.acquire(device);
371
372        let bias = if has_bias {
373            Some(deserialize_tensor(&mut buffer, device)?)
374        } else {
375            None
376        };
377
378        Ok(Arc::new(F8Q8Linear {
379            data: blocks,
380            shape,
381            bias,
382        }))
383    }
384
385    fn deserialize_ext_bias(
386        data: Cow<[u8]>,
387        device: &Device,
388        guard: QuantizeOntoGuard,
389    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
390    where
391        Self: Sized,
392    {
393        let mut buffer = Cursor::new(data.to_vec());
394
395        let version = buffer.read_u32::<LittleEndian>()?;
396        if let Err(e) = version_is_compatible(version) {
397            return Err(hanzo_ml::Error::wrap(e));
398        }
399
400        let isq_type = buffer.read_u8()? as usize;
401        if isq_type != QuantizedSerdeType::F8Q8 as usize {
402            hanzo_ml::bail!(
403                "ISQ type ({isq_type}) doesn't match expected type {}",
404                QuantizedSerdeType::F8Q8 as usize
405            );
406        }
407
408        let has_bias = buffer.read_u8()? != 0;
409
410        let num_blocks = buffer.read_u32::<LittleEndian>()? as usize;
411
412        // Shape
413        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
414        let mut dims = Vec::with_capacity(n_dims);
415        for _ in 0..n_dims {
416            dims.push(buffer.read_u32::<LittleEndian>()? as usize);
417        }
418        let shape = Shape::from_dims(&dims);
419
420        // Raw block data
421        let block_byte_count = num_blocks * std::mem::size_of::<BlockF8Q8>();
422        let mut raw_data = vec![0u8; block_byte_count];
423        std::io::Read::read_exact(&mut buffer, &mut raw_data)?;
424
425        let blocks: Vec<BlockF8Q8> = unsafe {
426            let mut blocks = Vec::with_capacity(num_blocks);
427            std::ptr::copy_nonoverlapping(
428                raw_data.as_ptr(),
429                blocks.as_mut_ptr() as *mut u8,
430                block_byte_count,
431            );
432            blocks.set_len(num_blocks);
433            blocks
434        };
435
436        let _acquired_load_guard = guard.acquire(device);
437
438        let bias = if has_bias {
439            Some(deserialize_tensor(&mut buffer, device)?)
440        } else {
441            None
442        };
443
444        Ok((
445            Arc::new(F8Q8Linear {
446                data: blocks,
447                shape,
448                bias: None,
449            }),
450            bias,
451        ))
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest absolute element-wise difference between two slices (zipped pairwise).
    fn max_abs_diff(expected: &[f32], actual: &[f32]) -> f32 {
        expected
            .iter()
            .zip(actual)
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max)
    }

    /// Quantizes a weight of `dims` built from `data` and returns the dequantized tensor.
    fn quantize_and_back(data: &[f32], dims: (usize, usize)) -> Tensor {
        let weight = Tensor::from_slice(data, dims, &Device::Cpu).unwrap();
        F8Q8Linear::from_weight(&weight, None)
            .unwrap()
            .dequantize(DType::F32)
            .unwrap()
    }

    #[test]
    fn test_f8q8_roundtrip() {
        let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
        let dequant = quantize_and_back(&data, (16, 16));
        let dequant_data: Vec<f32> = dequant.flatten_all().unwrap().to_vec1().unwrap();

        let max_err = max_abs_diff(&data, &dequant_data);
        assert!(
            max_err < 0.1,
            "F8Q8 roundtrip max error {max_err} exceeds threshold"
        );
    }

    #[test]
    fn test_f8q8_non_divisible_shape() {
        let data: Vec<f32> = (0..10000).map(|i| (i as f32 - 5000.0) / 5000.0).collect();
        let dequant = quantize_and_back(&data, (100, 100));
        assert_eq!(dequant.dims(), &[100, 100]);

        let dequant_data: Vec<f32> = dequant.flatten_all().unwrap().to_vec1().unwrap();
        let max_err = max_abs_diff(&data, &dequant_data);
        assert!(
            max_err < 0.1,
            "F8Q8 non-divisible shape roundtrip max error {max_err} exceeds threshold"
        );
    }

    #[test]
    fn test_f8q8_block_size() {
        assert_eq!(std::mem::size_of::<BlockF8Q8>(), 33);
        assert_eq!(std::mem::size_of::<BlockQ8_0>(), 34);
    }
}