1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use float8::F8E4M3;
9use half::f16;
10use hanzo_ml::{DType, Device, Result, Shape, Tensor};
11use hanzo_nn::{Linear, Module};
12
13use crate::{
14 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
15 IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
16};
17
18#[cfg(target_feature = "avx")]
19mod avx;
20#[cfg(target_feature = "neon")]
21mod neon;
22#[cfg(target_feature = "simd128")]
23mod simd128;
24
/// Number of quantized values held by one block; both `BlockF8Q8` and
/// `BlockQ8_0` size their `qs` array with it.
pub(crate) const QK8_0: usize = 32;
26
27#[derive(Debug, Clone, PartialEq)]
28#[repr(C)]
29pub struct BlockF8Q8 {
30 d: F8E4M3,
31 pub(crate) qs: [i8; QK8_0],
32}
33const _: () = assert!(std::mem::size_of::<BlockF8Q8>() == 33);
34
35impl BlockF8Q8 {
36 pub fn dq_d(&self) -> f32 {
37 self.d.to_f32() / F8E4M3::MAX.to_f32()
38 }
39
40 fn zeros() -> Self {
41 BlockF8Q8 {
42 d: F8E4M3::ZERO,
43 qs: [0i8; QK8_0],
44 }
45 }
46}
47
48#[derive(Debug, Clone, PartialEq)]
51#[repr(C)]
52pub struct BlockQ8_0 {
53 pub(crate) d: f16,
54 pub(crate) qs: [i8; QK8_0],
55}
56const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
57
58fn to_float(xs: &[BlockF8Q8], ys: &mut [f32]) -> Result<()> {
61 let k = ys.len();
62 if !k.is_multiple_of(QK8_0) {
63 hanzo_ml::bail!("dequantize_row_f8q8: {k} is not divisible by {QK8_0}");
64 }
65
66 let nb = k / QK8_0;
67
68 for i in 0..nb {
69 let d = xs[i].dq_d();
70
71 for j in 0..QK8_0 {
72 ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
73 }
74 }
75 Ok(())
76}
77
78fn from_float(xs: &[f32], ys: &mut [BlockF8Q8]) -> Result<()> {
79 let k = xs.len();
80 if !k.is_multiple_of(QK8_0) {
81 hanzo_ml::bail!("{k} is not divisible by {QK8_0}");
82 }
83 let nb = k / QK8_0;
84 if ys.len() != nb {
85 hanzo_ml::bail!("size mismatch {} {} {}", xs.len(), ys.len(), QK8_0)
86 }
87 for (i, ys) in ys.iter_mut().enumerate() {
88 let mut amax = 0f32;
89 let xs = &xs[i * QK8_0..(i + 1) * QK8_0];
90 for &x in xs.iter() {
91 amax = amax.max(x.abs())
92 }
93 let d = amax / ((1 << 7) - 1) as f32;
94 let id = if d != 0f32 { 1. / d } else { 0. };
95 ys.d = F8E4M3::from_f32(d * F8E4M3::MAX.to_f32());
96 for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
97 *y = f32::round(x * id) as i8
98 }
99 }
100 Ok(())
101}
102
103#[allow(dead_code)]
104#[allow(unreachable_code)]
105fn vec_dot(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
106 #[cfg(target_feature = "avx")]
107 return avx::vec_dot_f8q8_q8_0(n, xs, ys);
108
109 #[cfg(target_feature = "neon")]
110 return neon::vec_dot_f8q8_q8_0(n, xs, ys);
111
112 #[cfg(target_feature = "simd128")]
113 return simd128::vec_dot_f8q8_q8_0(n, xs, ys);
114
115 vec_dot_unopt(n, xs, ys)
116}
117
118#[allow(dead_code)]
119fn vec_dot_unopt(n: usize, xs: &[BlockF8Q8], ys: &[BlockQ8_0]) -> Result<f32> {
120 let qk = QK8_0;
121 if !n.is_multiple_of(QK8_0) {
122 hanzo_ml::bail!("vec_dot_f8q8_q8_0: {n} is not divisible by {qk}")
123 }
124
125 let mut sumf = 0f32;
126 for (xs, ys) in xs.iter().zip(ys.iter()) {
127 let sum_i = xs
128 .qs
129 .iter()
130 .zip(ys.qs.iter())
131 .map(|(&x, &y)| x as i32 * y as i32)
132 .sum::<i32>();
133 sumf += sum_i as f32 * xs.dq_d() * f16::to_f32(ys.d)
134 }
135 Ok(sumf)
136}
137
/// Forwards two rows of `BlockF8Q8` blocks and two rows of `BlockQ8_0` blocks
/// to `neon::i8mm_f8q8_q8_0`.
///
/// Only compiled with the `arm-nightly-feat` cargo feature. Without the `neon`
/// target feature the function always returns an error.
///
/// NOTE(review): the four returned values are presumably the 2x2 row-by-row
/// products; confirm against the `neon` module.
// The lints are silenced because, depending on the target, either the
// arguments or the trailing `bail!` end up unused/unreachable.
#[allow(dead_code, unreachable_code, unused)]
#[cfg(feature = "arm-nightly-feat")]
fn matmul_i8mm(
    n: usize,
    xs_0: &[BlockF8Q8],
    xs_1: &[BlockF8Q8],
    ys_0: &[BlockQ8_0],
    ys_1: &[BlockQ8_0],
) -> Result<[f32; 4]> {
    #[cfg(target_feature = "neon")]
    {
        return neon::i8mm_f8q8_q8_0(n, xs_0, xs_1, ys_0, ys_1);
    }

    hanzo_ml::bail!("Unsupported block type for i8mm");
}
154
155#[derive(Debug)]
158pub struct F8Q8Linear {
159 data: Vec<BlockF8Q8>,
160 shape: Shape,
161 bias: Option<Tensor>,
162}
163
164impl F8Q8Linear {
165 pub fn from_weight(weight: &Tensor, bias: Option<Tensor>) -> Result<Self> {
166 let shape = weight.shape().clone();
167 let weight_f32 = weight.to_dtype(DType::F32)?.flatten_all()?;
168 let mut weight_data: Vec<f32> = weight_f32.to_vec1()?;
169
170 let elem_count = weight_data.len();
172 let padded_count = elem_count.div_ceil(QK8_0) * QK8_0;
173 weight_data.resize(padded_count, 0.0);
174
175 let num_blocks = padded_count / QK8_0;
176 let mut blocks = vec![BlockF8Q8::zeros(); num_blocks];
177 from_float(&weight_data, &mut blocks)?;
178
179 Ok(Self {
180 data: blocks,
181 shape,
182 bias,
183 })
184 }
185
186 fn dequantize(&self, dtype: DType) -> Result<Tensor> {
187 let num_blocks = self.data.len();
188 let total_floats = num_blocks * QK8_0;
189 let mut output = vec![0f32; total_floats];
190 to_float(&self.data, &mut output)?;
191
192 let n = self.shape.elem_count();
194 let output = &output[..n];
195 Tensor::from_slice(output, &self.shape, &Device::Cpu)?.to_dtype(dtype)
196 }
197}
198
199impl QuantMethod for F8Q8Linear {
200 fn new(method: QuantMethodConfig) -> Result<Self>
201 where
202 Self: Sized,
203 {
204 let _ = method;
205 hanzo_ml::bail!("F8Q8Linear should be constructed via from_weight")
206 }
207
208 fn dequantize_w(&self) -> Result<Tensor> {
209 self.dequantize(DType::F32)
210 }
211
212 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
213 let dequant_w = self.dequantize(a.dtype())?;
214 let lin = Linear::new(dequant_w, self.bias.clone());
215 lin.forward(a)
216 }
217
218 fn quantized_act_type(&self) -> Option<DType> {
219 None
220 }
221
222 fn dtype_and_device(&self) -> (DType, Device) {
223 (DType::F32, Device::Cpu)
224 }
225
226 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
227 let dequant = self.dequantize(delta.dtype())?;
228 let new_w = (dequant + delta)?;
229 Ok(Arc::new(Self::from_weight(&new_w, self.bias.clone())?))
230 }
231
232 fn apply_isq(
233 self: Arc<Self>,
234 dtype: Option<IsqType>,
235 device: Device,
236 n_quantized: &AtomicUsize,
237 _imatrix_weight: Option<Vec<f32>>,
238 guard: QuantizeOntoGuard,
239 ) -> Result<Arc<dyn QuantMethod>> {
240 match dtype {
241 Some(IsqType::F8Q8) | None => {
242 Ok(self)
244 }
245 Some(other) => {
246 let w = self.dequantize(DType::F32)?;
248 let b = self.bias.clone();
249 let unquant =
250 crate::UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, b)))?;
251 Arc::new(unquant).apply_isq(Some(other), device, n_quantized, None, guard)
252 }
253 }
254 }
255}
256
257impl QuantizedSerde for F8Q8Linear {
266 fn name(&self) -> &'static str {
267 "f8q8-linear"
268 }
269
270 fn isq_serde_supported(&self) -> bool {
271 true
272 }
273
274 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
275 self.serialize_with_bias(self.bias.clone())
276 }
277
278 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
279 let mut buffer = Vec::new();
280
281 buffer.extend(&UQFF_VERSION.to_le_bytes());
283
284 buffer.push(QuantizedSerdeType::F8Q8 as u8);
286
287 buffer.push(bias.is_some() as u8);
289
290 buffer.extend(&(self.data.len() as u32).to_le_bytes());
292
293 let dims = self.shape.dims();
295 buffer.extend(&(dims.len() as u32).to_le_bytes());
296 for &dim in dims {
297 buffer.extend(&(dim as u32).to_le_bytes());
298 }
299
300 let block_bytes: &[u8] = unsafe {
302 std::slice::from_raw_parts(
303 self.data.as_ptr() as *const u8,
304 self.data.len() * std::mem::size_of::<BlockF8Q8>(),
305 )
306 };
307 buffer.extend(block_bytes);
308
309 if let Some(ref b) = bias {
311 serialize_tensor(&mut buffer, b)?;
312 }
313
314 Ok(Cow::from(buffer))
315 }
316
317 fn deserialize(
318 data: Cow<[u8]>,
319 device: &Device,
320 _comm: &Arc<crate::Comm>,
321 guard: QuantizeOntoGuard,
322 ) -> Result<Arc<dyn QuantMethod>>
323 where
324 Self: Sized,
325 {
326 let mut buffer = Cursor::new(data.to_vec());
327
328 let version = buffer.read_u32::<LittleEndian>()?;
329 if let Err(e) = version_is_compatible(version) {
330 return Err(hanzo_ml::Error::wrap(e));
331 }
332
333 let isq_type = buffer.read_u8()? as usize;
334 if isq_type != QuantizedSerdeType::F8Q8 as usize {
335 hanzo_ml::bail!(
336 "ISQ type ({isq_type}) doesn't match expected type {}",
337 QuantizedSerdeType::F8Q8 as usize
338 );
339 }
340
341 let has_bias = buffer.read_u8()? != 0;
342
343 let num_blocks = buffer.read_u32::<LittleEndian>()? as usize;
344
345 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
347 let mut dims = Vec::with_capacity(n_dims);
348 for _ in 0..n_dims {
349 dims.push(buffer.read_u32::<LittleEndian>()? as usize);
350 }
351 let shape = Shape::from_dims(&dims);
352
353 let block_byte_count = num_blocks * std::mem::size_of::<BlockF8Q8>();
355 let mut raw_data = vec![0u8; block_byte_count];
356 std::io::Read::read_exact(&mut buffer, &mut raw_data)?;
357
358 let blocks: Vec<BlockF8Q8> = unsafe {
360 let mut blocks = Vec::with_capacity(num_blocks);
361 std::ptr::copy_nonoverlapping(
362 raw_data.as_ptr(),
363 blocks.as_mut_ptr() as *mut u8,
364 block_byte_count,
365 );
366 blocks.set_len(num_blocks);
367 blocks
368 };
369
370 let _acquired_load_guard = guard.acquire(device);
371
372 let bias = if has_bias {
373 Some(deserialize_tensor(&mut buffer, device)?)
374 } else {
375 None
376 };
377
378 Ok(Arc::new(F8Q8Linear {
379 data: blocks,
380 shape,
381 bias,
382 }))
383 }
384
385 fn deserialize_ext_bias(
386 data: Cow<[u8]>,
387 device: &Device,
388 guard: QuantizeOntoGuard,
389 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
390 where
391 Self: Sized,
392 {
393 let mut buffer = Cursor::new(data.to_vec());
394
395 let version = buffer.read_u32::<LittleEndian>()?;
396 if let Err(e) = version_is_compatible(version) {
397 return Err(hanzo_ml::Error::wrap(e));
398 }
399
400 let isq_type = buffer.read_u8()? as usize;
401 if isq_type != QuantizedSerdeType::F8Q8 as usize {
402 hanzo_ml::bail!(
403 "ISQ type ({isq_type}) doesn't match expected type {}",
404 QuantizedSerdeType::F8Q8 as usize
405 );
406 }
407
408 let has_bias = buffer.read_u8()? != 0;
409
410 let num_blocks = buffer.read_u32::<LittleEndian>()? as usize;
411
412 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
414 let mut dims = Vec::with_capacity(n_dims);
415 for _ in 0..n_dims {
416 dims.push(buffer.read_u32::<LittleEndian>()? as usize);
417 }
418 let shape = Shape::from_dims(&dims);
419
420 let block_byte_count = num_blocks * std::mem::size_of::<BlockF8Q8>();
422 let mut raw_data = vec![0u8; block_byte_count];
423 std::io::Read::read_exact(&mut buffer, &mut raw_data)?;
424
425 let blocks: Vec<BlockF8Q8> = unsafe {
426 let mut blocks = Vec::with_capacity(num_blocks);
427 std::ptr::copy_nonoverlapping(
428 raw_data.as_ptr(),
429 blocks.as_mut_ptr() as *mut u8,
430 block_byte_count,
431 );
432 blocks.set_len(num_blocks);
433 blocks
434 };
435
436 let _acquired_load_guard = guard.acquire(device);
437
438 let bias = if has_bias {
439 Some(deserialize_tensor(&mut buffer, device)?)
440 } else {
441 None
442 };
443
444 Ok((
445 Arc::new(F8Q8Linear {
446 data: blocks,
447 shape,
448 bias: None,
449 }),
450 bias,
451 ))
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantizes `data` as a `dims`-shaped weight and dequantizes it back to `f32`.
    fn quantize_roundtrip(data: &[f32], dims: (usize, usize)) -> Tensor {
        let weight = Tensor::from_slice(data, dims, &Device::Cpu).unwrap();
        let linear = F8Q8Linear::from_weight(&weight, None).unwrap();
        linear.dequantize(DType::F32).unwrap()
    }

    /// Largest absolute element-wise difference between `expected` and `actual`.
    fn max_abs_err(expected: &[f32], actual: &Tensor) -> f32 {
        let actual: Vec<f32> = actual.flatten_all().unwrap().to_vec1().unwrap();
        expected
            .iter()
            .zip(actual.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max)
    }

    /// 256 values (a whole number of blocks) survive a round trip within 0.1.
    #[test]
    fn test_f8q8_roundtrip() {
        let data: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 128.0).collect();
        let dequant = quantize_roundtrip(&data, (16, 16));

        let max_err = max_abs_err(&data, &dequant);
        assert!(
            max_err < 0.1,
            "F8Q8 roundtrip max error {max_err} exceeds threshold"
        );
    }

    /// 10000 values are not a multiple of the block size: the padding must
    /// not leak into the shape or the values.
    #[test]
    fn test_f8q8_non_divisible_shape() {
        let data: Vec<f32> = (0..10000).map(|i| (i as f32 - 5000.0) / 5000.0).collect();
        let dequant = quantize_roundtrip(&data, (100, 100));
        assert_eq!(dequant.dims(), &[100, 100]);

        let max_err = max_abs_err(&data, &dequant);
        assert!(
            max_err < 0.1,
            "F8Q8 non-divisible shape roundtrip max error {max_err} exceeds threshold"
        );
    }

    /// The raw-byte (de)serialization depends on these exact sizes.
    #[test]
    fn test_f8q8_block_size() {
        assert_eq!(std::mem::size_of::<BlockF8Q8>(), 33);
        assert_eq!(std::mem::size_of::<BlockQ8_0>(), 34);
    }
}