Skip to main content

hanzo_quant/gguf/
mod.rs

1#[cfg(not(feature = "cuda"))]
2mod cpu;
3#[cfg(feature = "cuda")]
4pub(crate) mod cuda;
5#[cfg(feature = "cuda")]
6pub mod fast_mmq;
7#[cfg(feature = "cuda")]
8pub mod fast_mmvq;
9#[cfg(feature = "cuda")]
10mod ffi;
11
12use std::{
13    borrow::Cow,
14    io::{Cursor, Read},
15    sync::{atomic::AtomicUsize, Arc},
16};
17
18use byteorder::{LittleEndian, ReadBytesExt};
19use hanzo_ml::{
20    quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
21    DType, Device, Result, Tensor,
22};
23use hanzo_nn::Module;
24
25use crate::{
26    generate_isq, generate_isq_imatrix,
27    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
28    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
29};
30
31#[derive(Debug)]
32pub struct GgufMatMul {
33    pub(crate) w: QMatMul,
34    pub(crate) b: Option<Tensor>,
35}
36
37impl GgufMatMul {
38    fn add_bias(&self, x: Tensor) -> Result<Tensor> {
39        if let Some(ref b) = self.b {
40            x.broadcast_add(b)
41        } else {
42            Ok(x)
43        }
44    }
45
46    #[cfg(feature = "cuda")]
47    fn uses_fast_mmvq(&self) -> bool {
48        matches!(
49            &self.w,
50            QMatMul::QTensor(q) if q.device().is_cuda() && fast_mmvq::supports(q.dtype())
51        )
52    }
53
54    #[cfg(feature = "cuda")]
55    fn try_fast_forward(&self, a: &Tensor) -> Result<Option<Tensor>> {
56        if !self.uses_fast_mmvq() || !matches!(a.dtype(), DType::BF16 | DType::F16 | DType::F32) {
57            return Ok(None);
58        }
59
60        let flat_batch = a.dims()[..a.dims().len().saturating_sub(1)]
61            .iter()
62            .product::<usize>();
63
64        let QMatMul::QTensor(q) = &self.w else {
65            unreachable!("uses_fast_mmvq() requires QTensor weights")
66        };
67
68        // Batch 1-8: use MMVQ (decode kernel)
69        if (1..=fast_mmvq::MMVQ_MAX_BATCH).contains(&flat_batch) {
70            return Ok(Some(fast_mmvq::plain(q, a)?));
71        }
72
73        // Batch > 8: use MMQ (prompt kernel)
74        if flat_batch > fast_mmvq::MMVQ_MAX_BATCH {
75            return Ok(Some(fast_mmq::plain(q, a)?));
76        }
77
78        Ok(None)
79    }
80}
81
82impl QuantMethod for GgufMatMul {
83    fn new(method: QuantMethodConfig) -> Result<Self>
84    where
85        Self: Sized,
86    {
87        match method {
88            QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
89                w: QMatMul::from_arc(q_weight)?,
90                b,
91            }),
92            QuantMethodConfig::GptqAwq { .. }
93            | QuantMethodConfig::Unquantized(_)
94            | QuantMethodConfig::Hqq { .. }
95            | QuantMethodConfig::Dummy
96            | QuantMethodConfig::FP8 { .. }
97            | QuantMethodConfig::Bnb { .. }
98            | QuantMethodConfig::BlockwiseFP8 { .. }
99            | QuantMethodConfig::PerTensorFP8 { .. }
100            | QuantMethodConfig::Afq { .. }
101            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
102        }
103    }
104
105    fn dequantize_w(&self) -> Result<Tensor> {
106        self.w.dequantize_f16()?.to_dtype(DType::F32)
107    }
108
109    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
110        #[cfg(feature = "cuda")]
111        {
112            if let Some(out) = self.try_fast_forward(a)? {
113                return self.add_bias(out);
114            }
115        }
116
117        // Fallback: Hanzo QMatMul requires F32
118        let original_dtype = a.dtype();
119        let a_f32 = if original_dtype == DType::F32 {
120            a.clone()
121        } else {
122            a.to_dtype(DType::F32)?
123        };
124        let x = self.w.forward(&a_f32)?;
125        let x = if original_dtype == DType::F32 {
126            x
127        } else {
128            x.to_dtype(original_dtype)?
129        };
130        self.add_bias(x)
131    }
132
133    /// Compute matmul of `self` and `a`. `self` should contain the weights.
134    ///
135    /// If `a` is (n_tokens, 1, cols), `self` weights are (n_experts, rows, cols),
136    /// then the indices are (n_tokens, n_experts_per_tok).
137    fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
138        // Use indexed_moe_forward for efficient indexed matmul
139        // Expected shapes:
140        // - x: (n_tokens, 1, hidden_dim) or (n_tokens, n_experts_per_tok, hidden_dim)
141        // - indices: (n_tokens, n_experts_per_tok)
142        // - weights (self): (n_experts, out_features, in_features)
143        #[cfg(feature = "cuda")]
144        let res = cuda::qmatmul_indexed_moe_forward(&self.w, x, indices)?;
145
146        // For CPU and Metal: use dequantize-then-matmul approach
147        #[cfg(not(feature = "cuda"))]
148        let res = cpu::cpu_indexed_moe_forward(&self.w, x, indices)?;
149
150        if let Some(ref b) = self.b {
151            res.broadcast_add(b)
152        } else {
153            Ok(res)
154        }
155    }
156
157    #[cfg(feature = "cuda")]
158    fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
159        match &self.w {
160            hanzo_ml::quantized::QMatMul::QTensor(qt) => Some(qt),
161            _ => None,
162        }
163    }
164
165    fn quantized_act_type(&self) -> Option<DType> {
166        #[cfg(feature = "cuda")]
167        {
168            if self.uses_fast_mmvq() {
169                return None;
170            }
171        }
172        Some(DType::F32)
173    }
174
175    fn has_bias(&self) -> bool {
176        self.b.is_some()
177    }
178
179    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
180        match self {
181            Self {
182                w: QMatMul::Tensor(w),
183                b,
184            } => Ok(Arc::new(Self {
185                w: QMatMul::Tensor((w + delta)?),
186                b: b.clone(),
187            })),
188            Self {
189                w: QMatMul::TensorF16(w),
190                b,
191            } => Ok(Arc::new(Self {
192                w: QMatMul::TensorF16((w + delta)?),
193                b: b.clone(),
194            })),
195            Self {
196                w: QMatMul::QTensor(w),
197                b,
198            } => {
199                let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
200                let w = QMatMul::QTensor(std::sync::Arc::new(
201                    hanzo_ml::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
202                ));
203                Ok(Arc::new(Self { w, b: b.clone() }))
204            }
205            #[cfg(feature = "vulkan")]
206            Self {
207                w: QMatMul::VulkanQuant { qtensor, .. },
208                b,
209            } => {
210                let (wd, dtype) = (qtensor.dequantize(&qtensor.device())?, qtensor.dtype());
211                let w = QMatMul::from_qtensor(hanzo_ml::quantized::QTensor::quantize(
212                    &(wd + delta)?,
213                    dtype,
214                )?)?;
215                Ok(Arc::new(Self { w, b: b.clone() }))
216            }
217        }
218    }
219
220    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
221        match &self.w {
222            QMatMul::QTensor(q) => (DType::F32, q.device()),
223            #[cfg(feature = "vulkan")]
224            QMatMul::VulkanQuant { qtensor, .. } => (DType::F32, qtensor.device()),
225            QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
226        }
227    }
228
229    fn apply_isq(
230        self: Arc<Self>,
231        dtype: Option<IsqType>,
232        device: Device,
233        n_quantized: &AtomicUsize,
234        imatrix_weight: Option<Vec<f32>>,
235        guard: QuantizeOntoGuard,
236    ) -> Result<Arc<dyn QuantMethod>> {
237        if let Some(dtype) = dtype {
238            // F8Q8 is not a GgmlDType, so intercept before try_into()
239            if dtype == IsqType::F8Q8 {
240                let t = match &self.w {
241                    QMatMul::QTensor(q) => q.dequantize(&q.device())?,
242                    #[cfg(feature = "vulkan")]
243                    QMatMul::VulkanQuant { qtensor, .. } => {
244                        qtensor.dequantize(&qtensor.device())?
245                    }
246                    QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
247                };
248                let t = t.to_device(&device)?;
249                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
250                return Ok(Arc::new(crate::F8Q8Linear::from_weight(
251                    &t,
252                    self.b.clone(),
253                )?));
254            }
255            let t = match &self.w {
256                QMatMul::QTensor(q) => q.dequantize(&q.device())?,
257                #[cfg(feature = "vulkan")]
258                QMatMul::VulkanQuant { qtensor, .. } => qtensor.dequantize(&qtensor.device())?,
259                QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
260            };
261            let dtype = dtype.try_into()?;
262            let res = if let Some(imatrix_weight) = imatrix_weight {
263                generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
264            } else {
265                generate_isq!(t, device, dtype, n_quantized, guard)
266            };
267            Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
268                q_weight: res,
269                b: self.b.clone(),
270            })?))
271        } else {
272            let w = match &self.w {
273                QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
274                    &q.dequantize(&device)?,
275                    q.dtype(),
276                )?)),
277                #[cfg(feature = "vulkan")]
278                QMatMul::VulkanQuant { qtensor, .. } => QMatMul::from_qtensor(QTensor::quantize(
279                    &qtensor.dequantize(&device)?,
280                    qtensor.dtype(),
281                )?)?,
282                QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
283                QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
284            };
285            let b = if let Some(b) = &self.b {
286                Some(b.to_device(&device)?)
287            } else {
288                None
289            };
290            Ok(Arc::new(GgufMatMul { w, b }))
291        }
292    }
293}
294
295// Serialization structure:
296//
297// -----------------------
298// UQFF version, u32, little endian
299// -----------------------
300// ISQ type (0 for GGUF), u8, little endian
301// -----------------------
302// Tensor data length in bytes, u32, little endian
303// -----------------------
304// Whether bias data is included, u8 boolean
305// -----------------------
306// Quantized dtype, u32, little endian
307// -----------------------
308// Num shape dims, u32, little endian
309// -----------------------
310// ...
311// Array (in original order): quantized weight shape dims, u32, little endian
312// ...
313// -----------------------
314// ...
315// Array: quantized weight data, u8s
316// ...
317// -----------------------
318// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
319// -----------------------
320
321impl QuantizedSerde for GgufMatMul {
322    fn isq_serde_supported(&self) -> bool {
323        true
324    }
325    fn name(&self) -> &'static str {
326        "gguf"
327    }
328    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
329        self.serialize_with_bias(self.b.clone())
330    }
331    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
332        // VulkanQuant carries the same quantized QTensor; serialize it identically.
333        #[cfg(feature = "vulkan")]
334        let qw_opt = match &self.w {
335            QMatMul::QTensor(qw) => Some(qw),
336            QMatMul::VulkanQuant { qtensor, .. } => Some(qtensor),
337            _ => None,
338        };
339        #[cfg(not(feature = "vulkan"))]
340        let qw_opt = match &self.w {
341            QMatMul::QTensor(qw) => Some(qw),
342            _ => None,
343        };
344        let mut buffer = if let Some(qw) = qw_opt {
345            {
346                let w = qw.data()?.to_vec();
347                let w_shape = qw.shape().dims();
348                let dtype: u32 = match qw.dtype() {
349                    GgmlDType::F32 => 0,
350                    GgmlDType::F16 => 1,
351                    GgmlDType::Q4_0 => 2,
352                    GgmlDType::Q4_1 => 3,
353                    GgmlDType::Q5_0 => 6,
354                    GgmlDType::Q5_1 => 7,
355                    GgmlDType::Q8_0 => 8,
356                    GgmlDType::Q8_1 => 9,
357                    GgmlDType::Q2K => 10,
358                    GgmlDType::Q3K => 11,
359                    GgmlDType::Q4K => 12,
360                    GgmlDType::Q5K => 13,
361                    GgmlDType::Q6K => 14,
362                    GgmlDType::Q8K => 15,
363                    // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
364                    GgmlDType::BF16 => 30,
365                };
366
367                let mut buffer = Vec::new();
368
369                // Version is always first!
370                buffer.extend(&UQFF_VERSION.to_le_bytes());
371
372                // ISQ type for GGUF is 0
373                buffer.push(QuantizedSerdeType::Gguf as u8);
374
375                // Length
376                buffer.extend(&(w.len() as u32).to_le_bytes());
377
378                // Has bias
379                buffer.push(bias.is_some() as u8);
380
381                // Dtype (u32)
382                buffer.extend(&dtype.to_le_bytes());
383
384                // Shape
385                buffer.extend((w_shape.len() as u32).to_le_bytes());
386                for dim in w_shape {
387                    buffer.extend((*dim as u32).to_le_bytes());
388                }
389
390                // Quantized W Vec<u8> (just append it)
391                buffer.extend(&w);
392
393                buffer
394            }
395        } else {
396            hanzo_ml::bail!("Cannot serialize non-quantized")
397        };
398
399        if let Some(b) = bias.as_ref() {
400            serialize_tensor(&mut buffer, b)?;
401        }
402
403        Ok(Cow::from(buffer))
404    }
405
406    fn deserialize(
407        data: Cow<[u8]>,
408        device: &Device,
409        _comm: &Arc<crate::Comm>,
410        guard: QuantizeOntoGuard,
411    ) -> Result<Arc<dyn QuantMethod>> {
412        let mut buffer = Cursor::new(data);
413
414        let version = buffer.read_u32::<LittleEndian>()?;
415        if let Err(e) = version_is_compatible(version) {
416            return Err(hanzo_ml::Error::wrap(e));
417        }
418
419        let isq_type = buffer.read_u8()? as usize;
420        if isq_type != QuantizedSerdeType::Gguf as usize {
421            hanzo_ml::bail!(
422                "ISQ type ({isq_type}) doesn't match expected type {}",
423                QuantizedSerdeType::Gguf as usize
424            );
425        }
426
427        let data_len = buffer.read_u32::<LittleEndian>()? as usize;
428
429        let has_bias = buffer.read_u8()? != 0;
430
431        // TODO: keep this in sync with get_isq_type_from_uqff!
432        let dtype = buffer.read_u32::<LittleEndian>()?;
433        let dtype = match dtype {
434            0 => GgmlDType::F32,
435            1 => GgmlDType::F16,
436            2 => GgmlDType::Q4_0,
437            3 => GgmlDType::Q4_1,
438            6 => GgmlDType::Q5_0,
439            7 => GgmlDType::Q5_1,
440            8 => GgmlDType::Q8_0,
441            9 => GgmlDType::Q8_1,
442            10 => GgmlDType::Q2K,
443            11 => GgmlDType::Q3K,
444            12 => GgmlDType::Q4K,
445            13 => GgmlDType::Q5K,
446            14 => GgmlDType::Q6K,
447            15 => GgmlDType::Q8K,
448            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
449            30 => GgmlDType::BF16,
450            _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
451        };
452
453        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
454
455        let mut dims = Vec::with_capacity(n_dims);
456        for _ in 0..n_dims {
457            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
458        }
459
460        let mut tensor_data = vec![0; data_len];
461        buffer.read_exact(&mut tensor_data)?;
462
463        let _acquired_load_guard = guard.acquire(device);
464        // If we have bias
465        let b = if has_bias {
466            Some(deserialize_tensor(&mut buffer, device)?)
467        } else {
468            None
469        };
470
471        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
472        // `from_arc` keeps Q8_0 weights quantized in VRAM on Vulkan (native Q8 decode matvec);
473        // on every other backend it is identical to `QMatMul::QTensor` for a quantized weight.
474        Ok(Arc::new(Self {
475            w: QMatMul::from_arc(w.into())?,
476            b,
477        }))
478    }
479    fn deserialize_ext_bias(
480        data: Cow<[u8]>,
481        device: &Device,
482        guard: QuantizeOntoGuard,
483    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
484        let mut buffer = Cursor::new(data);
485
486        let version = buffer.read_u32::<LittleEndian>()?;
487        if let Err(e) = version_is_compatible(version) {
488            return Err(hanzo_ml::Error::wrap(e));
489        }
490
491        let isq_type = buffer.read_u8()? as usize;
492        if isq_type != QuantizedSerdeType::Gguf as usize {
493            hanzo_ml::bail!(
494                "ISQ type ({isq_type}) doesn't match expected type {}",
495                QuantizedSerdeType::Gguf as usize
496            );
497        }
498
499        let data_len = buffer.read_u32::<LittleEndian>()? as usize;
500
501        let has_bias = buffer.read_u8()? != 0;
502
503        // TODO: keep this in sync with get_isq_type_from_uqff!
504        let dtype = buffer.read_u32::<LittleEndian>()?;
505        let dtype = match dtype {
506            0 => GgmlDType::F32,
507            1 => GgmlDType::F16,
508            2 => GgmlDType::Q4_0,
509            3 => GgmlDType::Q4_1,
510            6 => GgmlDType::Q5_0,
511            7 => GgmlDType::Q5_1,
512            8 => GgmlDType::Q8_0,
513            9 => GgmlDType::Q8_1,
514            10 => GgmlDType::Q2K,
515            11 => GgmlDType::Q3K,
516            12 => GgmlDType::Q4K,
517            13 => GgmlDType::Q5K,
518            14 => GgmlDType::Q6K,
519            15 => GgmlDType::Q8K,
520            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
521            30 => GgmlDType::BF16,
522            _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
523        };
524
525        let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
526
527        let mut dims = Vec::with_capacity(n_dims);
528        for _ in 0..n_dims {
529            dims.push(buffer.read_u32::<LittleEndian>()? as usize)
530        }
531
532        let mut tensor_data = vec![0; data_len];
533        buffer.read_exact(&mut tensor_data)?;
534
535        let _acquired_load_guard = guard.acquire(device);
536        // If we have bias
537        let b = if has_bias {
538            Some(deserialize_tensor(&mut buffer, device)?)
539        } else {
540            None
541        };
542
543        let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
544        Ok((
545            Arc::new(Self {
546                w: QMatMul::from_arc(w.into())?,
547                b: None,
548            }),
549            b,
550        ))
551    }
552}
553
554impl GgufMatMul {
555    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
556        let mut buffer = Cursor::new(data);
557
558        let version = buffer.read_u32::<LittleEndian>()?;
559        if let Err(e) = version_is_compatible(version) {
560            return Err(hanzo_ml::Error::wrap(e));
561        }
562
563        let isq_type = buffer.read_u8()? as usize;
564        if isq_type != QuantizedSerdeType::Gguf as usize {
565            hanzo_ml::bail!(
566                "ISQ type ({isq_type}) doesn't match expected type {}",
567                QuantizedSerdeType::Gguf as usize
568            );
569        }
570
571        let _ = buffer.read_u32::<LittleEndian>()? as usize;
572
573        let _ = buffer.read_u8()? != 0;
574
575        let dtype = buffer.read_u32::<LittleEndian>()?;
576        let dtype = match dtype {
577            0 => GgmlDType::F32,
578            1 => GgmlDType::F16,
579            2 => GgmlDType::Q4_0,
580            3 => GgmlDType::Q4_1,
581            6 => GgmlDType::Q5_0,
582            7 => GgmlDType::Q5_1,
583            8 => GgmlDType::Q8_0,
584            9 => GgmlDType::Q8_1,
585            10 => GgmlDType::Q2K,
586            11 => GgmlDType::Q3K,
587            12 => GgmlDType::Q4K,
588            13 => GgmlDType::Q5K,
589            14 => GgmlDType::Q6K,
590            15 => GgmlDType::Q8K,
591            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
592            30 => GgmlDType::BF16,
593            _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
594        };
595
596        IsqType::try_from(dtype)
597    }
598}