1#[cfg(not(feature = "cuda"))]
2mod cpu;
3#[cfg(feature = "cuda")]
4pub(crate) mod cuda;
5#[cfg(feature = "cuda")]
6pub mod fast_mmq;
7#[cfg(feature = "cuda")]
8pub mod fast_mmvq;
9#[cfg(feature = "cuda")]
10mod ffi;
11
12use std::{
13 borrow::Cow,
14 io::{Cursor, Read},
15 sync::{atomic::AtomicUsize, Arc},
16};
17
18use byteorder::{LittleEndian, ReadBytesExt};
19use hanzo_ml::{
20 quantized::{ggml_file::qtensor_from_ggml, GgmlDType, QMatMul, QTensor},
21 DType, Device, Result, Tensor,
22};
23use hanzo_nn::Module;
24
25use crate::{
26 generate_isq, generate_isq_imatrix,
27 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
28 IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
29};
30
31#[derive(Debug)]
32pub struct GgufMatMul {
33 pub(crate) w: QMatMul,
34 pub(crate) b: Option<Tensor>,
35}
36
37impl GgufMatMul {
38 fn add_bias(&self, x: Tensor) -> Result<Tensor> {
39 if let Some(ref b) = self.b {
40 x.broadcast_add(b)
41 } else {
42 Ok(x)
43 }
44 }
45
46 #[cfg(feature = "cuda")]
47 fn uses_fast_mmvq(&self) -> bool {
48 matches!(
49 &self.w,
50 QMatMul::QTensor(q) if q.device().is_cuda() && fast_mmvq::supports(q.dtype())
51 )
52 }
53
54 #[cfg(feature = "cuda")]
55 fn try_fast_forward(&self, a: &Tensor) -> Result<Option<Tensor>> {
56 if !self.uses_fast_mmvq() || !matches!(a.dtype(), DType::BF16 | DType::F16 | DType::F32) {
57 return Ok(None);
58 }
59
60 let flat_batch = a.dims()[..a.dims().len().saturating_sub(1)]
61 .iter()
62 .product::<usize>();
63
64 let QMatMul::QTensor(q) = &self.w else {
65 unreachable!("uses_fast_mmvq() requires QTensor weights")
66 };
67
68 if (1..=fast_mmvq::MMVQ_MAX_BATCH).contains(&flat_batch) {
70 return Ok(Some(fast_mmvq::plain(q, a)?));
71 }
72
73 if flat_batch > fast_mmvq::MMVQ_MAX_BATCH {
75 return Ok(Some(fast_mmq::plain(q, a)?));
76 }
77
78 Ok(None)
79 }
80}
81
82impl QuantMethod for GgufMatMul {
83 fn new(method: QuantMethodConfig) -> Result<Self>
84 where
85 Self: Sized,
86 {
87 match method {
88 QuantMethodConfig::Gguf { q_weight, b } => Ok(Self {
89 w: QMatMul::from_arc(q_weight)?,
90 b,
91 }),
92 QuantMethodConfig::GptqAwq { .. }
93 | QuantMethodConfig::Unquantized(_)
94 | QuantMethodConfig::Hqq { .. }
95 | QuantMethodConfig::Dummy
96 | QuantMethodConfig::FP8 { .. }
97 | QuantMethodConfig::Bnb { .. }
98 | QuantMethodConfig::BlockwiseFP8 { .. }
99 | QuantMethodConfig::PerTensorFP8 { .. }
100 | QuantMethodConfig::Afq { .. }
101 | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
102 }
103 }
104
105 fn dequantize_w(&self) -> Result<Tensor> {
106 self.w.dequantize_f16()?.to_dtype(DType::F32)
107 }
108
109 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
110 #[cfg(feature = "cuda")]
111 {
112 if let Some(out) = self.try_fast_forward(a)? {
113 return self.add_bias(out);
114 }
115 }
116
117 let original_dtype = a.dtype();
119 let a_f32 = if original_dtype == DType::F32 {
120 a.clone()
121 } else {
122 a.to_dtype(DType::F32)?
123 };
124 let x = self.w.forward(&a_f32)?;
125 let x = if original_dtype == DType::F32 {
126 x
127 } else {
128 x.to_dtype(original_dtype)?
129 };
130 self.add_bias(x)
131 }
132
133 fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
138 #[cfg(feature = "cuda")]
144 let res = cuda::qmatmul_indexed_moe_forward(&self.w, x, indices)?;
145
146 #[cfg(not(feature = "cuda"))]
148 let res = cpu::cpu_indexed_moe_forward(&self.w, x, indices)?;
149
150 if let Some(ref b) = self.b {
151 res.broadcast_add(b)
152 } else {
153 Ok(res)
154 }
155 }
156
157 #[cfg(feature = "cuda")]
158 fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
159 match &self.w {
160 hanzo_ml::quantized::QMatMul::QTensor(qt) => Some(qt),
161 _ => None,
162 }
163 }
164
165 fn quantized_act_type(&self) -> Option<DType> {
166 #[cfg(feature = "cuda")]
167 {
168 if self.uses_fast_mmvq() {
169 return None;
170 }
171 }
172 Some(DType::F32)
173 }
174
175 fn has_bias(&self) -> bool {
176 self.b.is_some()
177 }
178
179 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
180 match self {
181 Self {
182 w: QMatMul::Tensor(w),
183 b,
184 } => Ok(Arc::new(Self {
185 w: QMatMul::Tensor((w + delta)?),
186 b: b.clone(),
187 })),
188 Self {
189 w: QMatMul::TensorF16(w),
190 b,
191 } => Ok(Arc::new(Self {
192 w: QMatMul::TensorF16((w + delta)?),
193 b: b.clone(),
194 })),
195 Self {
196 w: QMatMul::QTensor(w),
197 b,
198 } => {
199 let (w, dtype) = (w.dequantize(&w.device())?, w.dtype());
200 let w = QMatMul::QTensor(std::sync::Arc::new(
201 hanzo_ml::quantized::QTensor::quantize(&(w + delta)?, dtype)?,
202 ));
203 Ok(Arc::new(Self { w, b: b.clone() }))
204 }
205 #[cfg(feature = "vulkan")]
206 Self {
207 w: QMatMul::VulkanQuant { qtensor, .. },
208 b,
209 } => {
210 let (wd, dtype) = (qtensor.dequantize(&qtensor.device())?, qtensor.dtype());
211 let w = QMatMul::from_qtensor(hanzo_ml::quantized::QTensor::quantize(
212 &(wd + delta)?,
213 dtype,
214 )?)?;
215 Ok(Arc::new(Self { w, b: b.clone() }))
216 }
217 }
218 }
219
220 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
221 match &self.w {
222 QMatMul::QTensor(q) => (DType::F32, q.device()),
223 #[cfg(feature = "vulkan")]
224 QMatMul::VulkanQuant { qtensor, .. } => (DType::F32, qtensor.device()),
225 QMatMul::Tensor(t) | QMatMul::TensorF16(t) => (t.dtype(), t.device().clone()),
226 }
227 }
228
229 fn apply_isq(
230 self: Arc<Self>,
231 dtype: Option<IsqType>,
232 device: Device,
233 n_quantized: &AtomicUsize,
234 imatrix_weight: Option<Vec<f32>>,
235 guard: QuantizeOntoGuard,
236 ) -> Result<Arc<dyn QuantMethod>> {
237 if let Some(dtype) = dtype {
238 if dtype == IsqType::F8Q8 {
240 let t = match &self.w {
241 QMatMul::QTensor(q) => q.dequantize(&q.device())?,
242 #[cfg(feature = "vulkan")]
243 QMatMul::VulkanQuant { qtensor, .. } => {
244 qtensor.dequantize(&qtensor.device())?
245 }
246 QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
247 };
248 let t = t.to_device(&device)?;
249 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
250 return Ok(Arc::new(crate::F8Q8Linear::from_weight(
251 &t,
252 self.b.clone(),
253 )?));
254 }
255 let t = match &self.w {
256 QMatMul::QTensor(q) => q.dequantize(&q.device())?,
257 #[cfg(feature = "vulkan")]
258 QMatMul::VulkanQuant { qtensor, .. } => qtensor.dequantize(&qtensor.device())?,
259 QMatMul::TensorF16(t) | QMatMul::Tensor(t) => t.clone(),
260 };
261 let dtype = dtype.try_into()?;
262 let res = if let Some(imatrix_weight) = imatrix_weight {
263 generate_isq_imatrix!(t, imatrix_weight, device, dtype, n_quantized, guard)
264 } else {
265 generate_isq!(t, device, dtype, n_quantized, guard)
266 };
267 Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
268 q_weight: res,
269 b: self.b.clone(),
270 })?))
271 } else {
272 let w = match &self.w {
273 QMatMul::QTensor(q) => QMatMul::QTensor(Arc::new(QTensor::quantize(
274 &q.dequantize(&device)?,
275 q.dtype(),
276 )?)),
277 #[cfg(feature = "vulkan")]
278 QMatMul::VulkanQuant { qtensor, .. } => QMatMul::from_qtensor(QTensor::quantize(
279 &qtensor.dequantize(&device)?,
280 qtensor.dtype(),
281 )?)?,
282 QMatMul::Tensor(t) => QMatMul::Tensor(t.to_device(&device)?),
283 QMatMul::TensorF16(t) => QMatMul::TensorF16(t.to_device(&device)?),
284 };
285 let b = if let Some(b) = &self.b {
286 Some(b.to_device(&device)?)
287 } else {
288 None
289 };
290 Ok(Arc::new(GgufMatMul { w, b }))
291 }
292 }
293}
294
295impl QuantizedSerde for GgufMatMul {
322 fn isq_serde_supported(&self) -> bool {
323 true
324 }
325 fn name(&self) -> &'static str {
326 "gguf"
327 }
328 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
329 self.serialize_with_bias(self.b.clone())
330 }
331 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
332 #[cfg(feature = "vulkan")]
334 let qw_opt = match &self.w {
335 QMatMul::QTensor(qw) => Some(qw),
336 QMatMul::VulkanQuant { qtensor, .. } => Some(qtensor),
337 _ => None,
338 };
339 #[cfg(not(feature = "vulkan"))]
340 let qw_opt = match &self.w {
341 QMatMul::QTensor(qw) => Some(qw),
342 _ => None,
343 };
344 let mut buffer = if let Some(qw) = qw_opt {
345 {
346 let w = qw.data()?.to_vec();
347 let w_shape = qw.shape().dims();
348 let dtype: u32 = match qw.dtype() {
349 GgmlDType::F32 => 0,
350 GgmlDType::F16 => 1,
351 GgmlDType::Q4_0 => 2,
352 GgmlDType::Q4_1 => 3,
353 GgmlDType::Q5_0 => 6,
354 GgmlDType::Q5_1 => 7,
355 GgmlDType::Q8_0 => 8,
356 GgmlDType::Q8_1 => 9,
357 GgmlDType::Q2K => 10,
358 GgmlDType::Q3K => 11,
359 GgmlDType::Q4K => 12,
360 GgmlDType::Q5K => 13,
361 GgmlDType::Q6K => 14,
362 GgmlDType::Q8K => 15,
363 GgmlDType::BF16 => 30,
365 };
366
367 let mut buffer = Vec::new();
368
369 buffer.extend(&UQFF_VERSION.to_le_bytes());
371
372 buffer.push(QuantizedSerdeType::Gguf as u8);
374
375 buffer.extend(&(w.len() as u32).to_le_bytes());
377
378 buffer.push(bias.is_some() as u8);
380
381 buffer.extend(&dtype.to_le_bytes());
383
384 buffer.extend((w_shape.len() as u32).to_le_bytes());
386 for dim in w_shape {
387 buffer.extend((*dim as u32).to_le_bytes());
388 }
389
390 buffer.extend(&w);
392
393 buffer
394 }
395 } else {
396 hanzo_ml::bail!("Cannot serialize non-quantized")
397 };
398
399 if let Some(b) = bias.as_ref() {
400 serialize_tensor(&mut buffer, b)?;
401 }
402
403 Ok(Cow::from(buffer))
404 }
405
406 fn deserialize(
407 data: Cow<[u8]>,
408 device: &Device,
409 _comm: &Arc<crate::Comm>,
410 guard: QuantizeOntoGuard,
411 ) -> Result<Arc<dyn QuantMethod>> {
412 let mut buffer = Cursor::new(data);
413
414 let version = buffer.read_u32::<LittleEndian>()?;
415 if let Err(e) = version_is_compatible(version) {
416 return Err(hanzo_ml::Error::wrap(e));
417 }
418
419 let isq_type = buffer.read_u8()? as usize;
420 if isq_type != QuantizedSerdeType::Gguf as usize {
421 hanzo_ml::bail!(
422 "ISQ type ({isq_type}) doesn't match expected type {}",
423 QuantizedSerdeType::Gguf as usize
424 );
425 }
426
427 let data_len = buffer.read_u32::<LittleEndian>()? as usize;
428
429 let has_bias = buffer.read_u8()? != 0;
430
431 let dtype = buffer.read_u32::<LittleEndian>()?;
433 let dtype = match dtype {
434 0 => GgmlDType::F32,
435 1 => GgmlDType::F16,
436 2 => GgmlDType::Q4_0,
437 3 => GgmlDType::Q4_1,
438 6 => GgmlDType::Q5_0,
439 7 => GgmlDType::Q5_1,
440 8 => GgmlDType::Q8_0,
441 9 => GgmlDType::Q8_1,
442 10 => GgmlDType::Q2K,
443 11 => GgmlDType::Q3K,
444 12 => GgmlDType::Q4K,
445 13 => GgmlDType::Q5K,
446 14 => GgmlDType::Q6K,
447 15 => GgmlDType::Q8K,
448 30 => GgmlDType::BF16,
450 _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
451 };
452
453 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
454
455 let mut dims = Vec::with_capacity(n_dims);
456 for _ in 0..n_dims {
457 dims.push(buffer.read_u32::<LittleEndian>()? as usize)
458 }
459
460 let mut tensor_data = vec![0; data_len];
461 buffer.read_exact(&mut tensor_data)?;
462
463 let _acquired_load_guard = guard.acquire(device);
464 let b = if has_bias {
466 Some(deserialize_tensor(&mut buffer, device)?)
467 } else {
468 None
469 };
470
471 let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
472 Ok(Arc::new(Self {
475 w: QMatMul::from_arc(w.into())?,
476 b,
477 }))
478 }
479 fn deserialize_ext_bias(
480 data: Cow<[u8]>,
481 device: &Device,
482 guard: QuantizeOntoGuard,
483 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)> {
484 let mut buffer = Cursor::new(data);
485
486 let version = buffer.read_u32::<LittleEndian>()?;
487 if let Err(e) = version_is_compatible(version) {
488 return Err(hanzo_ml::Error::wrap(e));
489 }
490
491 let isq_type = buffer.read_u8()? as usize;
492 if isq_type != QuantizedSerdeType::Gguf as usize {
493 hanzo_ml::bail!(
494 "ISQ type ({isq_type}) doesn't match expected type {}",
495 QuantizedSerdeType::Gguf as usize
496 );
497 }
498
499 let data_len = buffer.read_u32::<LittleEndian>()? as usize;
500
501 let has_bias = buffer.read_u8()? != 0;
502
503 let dtype = buffer.read_u32::<LittleEndian>()?;
505 let dtype = match dtype {
506 0 => GgmlDType::F32,
507 1 => GgmlDType::F16,
508 2 => GgmlDType::Q4_0,
509 3 => GgmlDType::Q4_1,
510 6 => GgmlDType::Q5_0,
511 7 => GgmlDType::Q5_1,
512 8 => GgmlDType::Q8_0,
513 9 => GgmlDType::Q8_1,
514 10 => GgmlDType::Q2K,
515 11 => GgmlDType::Q3K,
516 12 => GgmlDType::Q4K,
517 13 => GgmlDType::Q5K,
518 14 => GgmlDType::Q6K,
519 15 => GgmlDType::Q8K,
520 30 => GgmlDType::BF16,
522 _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
523 };
524
525 let n_dims = buffer.read_u32::<LittleEndian>()? as usize;
526
527 let mut dims = Vec::with_capacity(n_dims);
528 for _ in 0..n_dims {
529 dims.push(buffer.read_u32::<LittleEndian>()? as usize)
530 }
531
532 let mut tensor_data = vec![0; data_len];
533 buffer.read_exact(&mut tensor_data)?;
534
535 let _acquired_load_guard = guard.acquire(device);
536 let b = if has_bias {
538 Some(deserialize_tensor(&mut buffer, device)?)
539 } else {
540 None
541 };
542
543 let w = qtensor_from_ggml(dtype, &tensor_data, dims, device)?;
544 Ok((
545 Arc::new(Self {
546 w: QMatMul::from_arc(w.into())?,
547 b: None,
548 }),
549 b,
550 ))
551 }
552}
553
554impl GgufMatMul {
555 pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
556 let mut buffer = Cursor::new(data);
557
558 let version = buffer.read_u32::<LittleEndian>()?;
559 if let Err(e) = version_is_compatible(version) {
560 return Err(hanzo_ml::Error::wrap(e));
561 }
562
563 let isq_type = buffer.read_u8()? as usize;
564 if isq_type != QuantizedSerdeType::Gguf as usize {
565 hanzo_ml::bail!(
566 "ISQ type ({isq_type}) doesn't match expected type {}",
567 QuantizedSerdeType::Gguf as usize
568 );
569 }
570
571 let _ = buffer.read_u32::<LittleEndian>()? as usize;
572
573 let _ = buffer.read_u8()? != 0;
574
575 let dtype = buffer.read_u32::<LittleEndian>()?;
576 let dtype = match dtype {
577 0 => GgmlDType::F32,
578 1 => GgmlDType::F16,
579 2 => GgmlDType::Q4_0,
580 3 => GgmlDType::Q4_1,
581 6 => GgmlDType::Q5_0,
582 7 => GgmlDType::Q5_1,
583 8 => GgmlDType::Q8_0,
584 9 => GgmlDType::Q8_1,
585 10 => GgmlDType::Q2K,
586 11 => GgmlDType::Q3K,
587 12 => GgmlDType::Q4K,
588 13 => GgmlDType::Q5K,
589 14 => GgmlDType::Q6K,
590 15 => GgmlDType::Q8K,
591 30 => GgmlDType::BF16,
593 _ => hanzo_ml::bail!("unknown dtype for quantized weight tensor {dtype}"),
594 };
595
596 IsqType::try_from(dtype)
597 }
598}