Skip to main content

hanzo_ml/quantized/
mod.rs

1use crate::{
2    backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(target_arch = "wasm32"))]
18pub mod tokenizer;
19#[cfg(not(feature = "metal"))]
20mod metal {
21    pub use super::dummy_metal::*;
22}
23#[cfg(feature = "cuda")]
24pub mod cuda;
25#[cfg(feature = "cuda")]
26pub mod fast_mmq;
27#[cfg(feature = "cuda")]
28pub mod fast_mmvq;
29#[cfg(not(feature = "cuda"))]
30mod cuda {
31    pub use super::dummy_cuda::*;
32}
33
34#[cfg(target_feature = "neon")]
35pub mod neon;
36#[cfg(target_feature = "simd128")]
37pub mod simd128;
38pub mod utils;
39use half::{bf16, f16};
40
41pub use k_quants::GgmlType;
42
/// Reinterprets a byte slice as a slice of `T` without copying.
///
/// `data` is only borrowed, so the returned slice is valid for exactly as long as the caller's
/// buffer. (An earlier version took a `Cow` by value and dropped it on return, which left the
/// slice dangling for `Cow::Owned` inputs.)
///
/// # Panics
/// Panics if `data.len()` is not a multiple of `size_of::<T>()`, or if `data` is not aligned to
/// `align_of::<T>()`. The length check runs first.
fn as_t_slice<T>(data: &[u8]) -> &[T] {
    let elem_size = std::mem::size_of::<T>();
    let byte_len = data.len();
    assert_eq!(
        byte_len % elem_size,
        0,
        "Data length must be a multiple of T's size"
    );
    let base = data.as_ptr();
    let misalignment = (base as usize) % std::mem::align_of::<T>();
    assert_eq!(
        misalignment,
        0,
        "Data pointer must be aligned to T's alignment"
    );
    let elem_count = byte_len / elem_size;
    // SAFETY: `base` points to `byte_len` initialized bytes that remain borrowed for the
    // returned lifetime. It is aligned for `T`, and `elem_count * size_of::<T>() == byte_len`
    // (both checked above). This also requires every bit pattern to be a valid `T`.
    // NOTE(review): callers here use f32/f16/bf16 and the k_quants `Block*` types, which are
    // presumed plain-old-data; confirm before adding new `T`s.
    unsafe { std::slice::from_raw_parts(base.cast::<T>(), elem_count) }
}
61
/// A quantized tensor: backend-specific quantized storage paired with its logical shape.
pub struct QTensor {
    // The quantized blocks (CPU, Metal, CUDA, or, with the `vulkan` feature, a host box tied to
    // a Vulkan device).
    storage: QStorage,
    // Logical shape in elements. `QTensor::new` and the `quantize*` constructors check that its
    // last dim is a multiple of the storage's block size.
    shape: Shape,
}
66
67impl Device {
68    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
69        match self {
70            Device::Cpu => {
71                let storage = dtype.cpu_zeros(elem_count);
72                Ok(QStorage::Cpu(storage))
73            }
74            Device::Metal(metal) => {
75                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
76                Ok(QStorage::Metal(storage))
77            }
78            Device::Cuda(cuda) => {
79                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
80                Ok(QStorage::Cuda(storage))
81            }
82            #[cfg(feature = "rocm")]
83            Device::Rocm(_) => crate::bail!("quantized tensors on rocm are not supported yet"),
84            #[cfg(feature = "vulkan")]
85            Device::Vulkan(d) => {
86                // Keep quantized blocks in a CPU box + the device (same as from_data); QMatMul's
87                // VulkanQuant path uploads/dequantizes them to the GPU on use.
88                let storage = dtype.cpu_zeros(elem_count);
89                Ok(QStorage::Vulkan(storage, d.clone()))
90            }
91        }
92    }
93}
94
/// Backend-specific storage for the blocks of a quantized tensor.
pub enum QStorage {
    /// Blocks held in host memory behind the object-safe `QuantizedType` trait
    /// (in this file, a `Vec` of the dtype's block type).
    Cpu(Box<dyn QuantizedType>),
    /// Metal-backed blocks (a stub type from `dummy_metal` when the `metal` feature is off).
    Metal(metal::QMetalStorage),
    /// CUDA-backed blocks (a stub type from `dummy_cuda` when the `cuda` feature is off).
    Cuda(cuda::QCudaStorage),
    // Vulkan keeps the quantized blocks in a CPU box and dequantizes to an f32 Vulkan tensor on
    // demand (QMatMul forces dequantize for Vulkan). Lets GGUF-quantized models run on the GPU;
    // a direct quantized matmul (the Q8 kernel) is the bandwidth-win follow-up.
    #[cfg(feature = "vulkan")]
    Vulkan(Box<dyn QuantizedType>, crate::VulkanDevice),
}
105
106impl QStorage {
107    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
108        match device {
109            Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
110            Device::Metal(d) => match dtype {
111                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(&data)),
112                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(&data)),
113                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(&data)),
114                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(&data)),
115                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(&data)),
116                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(&data)),
117                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(&data)),
118                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(&data)),
119                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(&data)),
120                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(&data)),
121                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(&data)),
122                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(&data)),
123                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(&data)),
124                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(&data)),
125                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(&data)),
126            },
127            Device::Cuda(d) => match dtype {
128                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(&data)),
129                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(&data)),
130                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(&data)),
131                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(&data)),
132                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(&data)),
133                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(&data)),
134                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(&data)),
135                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(&data)),
136                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(&data)),
137                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(&data)),
138                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(&data)),
139                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(&data)),
140                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(&data)),
141                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(&data)),
142                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(&data)),
143            },
144            #[cfg(feature = "rocm")]
145            Device::Rocm(_) => crate::bail!("quantized tensors on rocm are not supported yet"),
146            #[cfg(feature = "vulkan")]
147            Device::Vulkan(d) => Ok(Self::Vulkan(dtype.from_data(data), d.clone())),
148        }
149    }
150
151    fn block_size(&self) -> usize {
152        match self {
153            QStorage::Cpu(storage) => storage.block_size(),
154            QStorage::Metal(storage) => storage.dtype().block_size(),
155            QStorage::Cuda(storage) => storage.dtype().block_size(),
156            #[cfg(feature = "vulkan")]
157            QStorage::Vulkan(storage, _) => storage.block_size(),
158        }
159    }
160
161    fn dtype(&self) -> GgmlDType {
162        match self {
163            QStorage::Cpu(storage) => storage.dtype(),
164            QStorage::Metal(storage) => storage.dtype(),
165            QStorage::Cuda(storage) => storage.dtype(),
166            #[cfg(feature = "vulkan")]
167            QStorage::Vulkan(storage, _) => storage.dtype(),
168        }
169    }
170
171    fn device(&self) -> Device {
172        match self {
173            QStorage::Cpu(_storage) => Device::Cpu,
174            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
175            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
176            #[cfg(feature = "vulkan")]
177            QStorage::Vulkan(_storage, device) => Device::Vulkan(device.clone()),
178        }
179    }
180
181    fn size_in_bytes(&self) -> usize {
182        match self {
183            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
184            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
185            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
186            #[cfg(feature = "vulkan")]
187            QStorage::Vulkan(storage, _) => storage.storage_size_in_bytes(),
188        }
189    }
190
191    fn quantize(&mut self, src: &Storage) -> Result<()> {
192        match (self, src) {
193            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
194                storage.from_float(src.as_slice::<f32>()?);
195            }
196            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
197            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
198            _ => crate::bail!("Invalid quantize storage locations do not match"),
199        }
200        Ok(())
201    }
202
203    fn quantize_imatrix(
204        &mut self,
205        src: &Storage,
206        imatrix_weights: &[f32],
207        n_per_row: usize,
208    ) -> Result<()> {
209        match (self, src) {
210            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
211                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
212            }
213            (QStorage::Metal(storage), Storage::Metal(src)) => {
214                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
215            }
216            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
217                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
218            }
219            _ => crate::bail!("Invalid quantize storage locations do not match"),
220        }
221        Ok(())
222    }
223
224    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
225        match (self, src) {
226            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
227                storage.from_float(src.as_slice::<f32>()?);
228            }
229            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
230            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
231            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
232        }
233        Ok(())
234    }
235
236    fn quantize_imatrix_onto(
237        &mut self,
238        src: &Storage,
239        imatrix_weights: &[f32],
240        n_per_row: usize,
241    ) -> Result<()> {
242        match (self, src) {
243            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
244                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
245            }
246            (QStorage::Metal(storage), Storage::Cpu(src)) => {
247                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
248            }
249            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
250                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
251            }
252            _ => crate::bail!("Invalid quantize storage locations do not match"),
253        }
254        Ok(())
255    }
256
257    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
258        match self {
259            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
260            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
261            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
262            #[cfg(feature = "vulkan")]
263            QStorage::Vulkan(storage, device) => {
264                // Dequantize on the CPU, then upload the f32 weights to the GPU.
265                let cpu = storage.dequantize(elem_count)?;
266                Ok(Storage::Vulkan(device.upload_f32(cpu.as_slice::<f32>()?)?))
267            }
268        }
269    }
270
271    fn data(&self) -> Result<Cow<'_, [u8]>> {
272        match self {
273            QStorage::Cpu(storage) => {
274                let data_ptr = storage.as_ptr();
275                let size_in_bytes = storage.storage_size_in_bytes();
276                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
277                Ok(Cow::from(data))
278            }
279            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
280            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
281            #[cfg(feature = "vulkan")]
282            QStorage::Vulkan(storage, _) => {
283                let data_ptr = storage.as_ptr();
284                let size_in_bytes = storage.storage_size_in_bytes();
285                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
286                Ok(Cow::from(data))
287            }
288        }
289    }
290
291    pub fn device_ptr(&self) -> Result<*const u8> {
292        match self {
293            QStorage::Cuda(storage) => storage.device_ptr(),
294            #[cfg(feature = "vulkan")]
295            QStorage::Vulkan(..) => crate::bail!("not implemented"),
296            QStorage::Metal(_) | QStorage::Cpu(_) => {
297                crate::bail!("not implemented");
298            }
299        }
300    }
301
302    #[cfg(feature = "cuda")]
303    pub fn device_ptr_with_guard<'a>(
304        &'a self,
305        stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
306    ) -> Result<(
307        *const u8,
308        crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
309    )> {
310        match self {
311            QStorage::Cuda(storage) => storage.device_ptr_with_guard(stream),
312            QStorage::Metal(_) | QStorage::Cpu(_) => {
313                crate::bail!("not implemented");
314            }
315        }
316    }
317}
318
/// Element/block encodings supported for quantized tensors, mirroring ggml's tensor types.
///
/// Variant order is not significant for serialization; see `from_u32`/`to_u32` for the
/// numeric ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    /// Unquantized `f32`, block size 1.
    F32,
    /// Unquantized `f16`, block size 1.
    F16,
    /// Unquantized `bf16`, block size 1.
    BF16,
    /// Stored as `BlockQ4_0` blocks of `k_quants::QK4_0` elements.
    Q4_0,
    /// Stored as `BlockQ4_1` blocks of `k_quants::QK4_1` elements.
    Q4_1,
    /// Stored as `BlockQ5_0` blocks of `k_quants::QK5_0` elements.
    Q5_0,
    /// Stored as `BlockQ5_1` blocks of `k_quants::QK5_1` elements.
    Q5_1,
    /// Stored as `BlockQ8_0` blocks of `k_quants::QK8_0` elements.
    Q8_0,
    /// Stored as `BlockQ8_1` blocks of `k_quants::QK8_1` elements.
    Q8_1,
    /// Stored as `BlockQ2K` blocks of `k_quants::QK_K` elements.
    Q2K,
    /// Stored as `BlockQ3K` blocks of `k_quants::QK_K` elements.
    Q3K,
    /// Stored as `BlockQ4K` blocks of `k_quants::QK_K` elements.
    Q4K,
    /// Stored as `BlockQ5K` blocks of `k_quants::QK_K` elements.
    Q5K,
    /// Stored as `BlockQ6K` blocks of `k_quants::QK_K` elements.
    Q6K,
    /// Stored as `BlockQ8K` blocks of `k_quants::QK_K` elements.
    Q8K,
}
337
338impl GgmlDType {
339    pub(crate) fn from_u32(u: u32) -> Result<Self> {
340        let dtype = match u {
341            0 => Self::F32,
342            1 => Self::F16,
343            2 => Self::Q4_0,
344            3 => Self::Q4_1,
345            6 => Self::Q5_0,
346            7 => Self::Q5_1,
347            8 => Self::Q8_0,
348            9 => Self::Q8_1,
349            10 => Self::Q2K,
350            11 => Self::Q3K,
351            12 => Self::Q4K,
352            13 => Self::Q5K,
353            14 => Self::Q6K,
354            15 => Self::Q8K,
355            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
356            30 => Self::BF16,
357            _ => crate::bail!("unknown dtype for tensor {u}"),
358        };
359        Ok(dtype)
360    }
361
362    pub(crate) fn to_u32(self) -> u32 {
363        match self {
364            Self::F32 => 0,
365            Self::F16 => 1,
366            Self::Q4_0 => 2,
367            Self::Q4_1 => 3,
368            Self::Q5_0 => 6,
369            Self::Q5_1 => 7,
370            Self::Q8_0 => 8,
371            Self::Q8_1 => 9,
372            Self::Q2K => 10,
373            Self::Q3K => 11,
374            Self::Q4K => 12,
375            Self::Q5K => 13,
376            Self::Q6K => 14,
377            Self::Q8K => 15,
378            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
379            Self::BF16 => 30,
380        }
381    }
382
383    /// The block dtype
384    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
385        match self {
386            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
387            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
388            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
389            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
390            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
391            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
392            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
393            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
394            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
395            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
396            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
397            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
398            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
399            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
400            Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
401        }
402    }
403
404    pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
405        match self {
406            Self::F32 => Box::new(as_t_slice::<f32>(&data).to_vec()),
407            Self::F16 => Box::new(as_t_slice::<f16>(&data).to_vec()),
408            Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(&data).to_vec()),
409            Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(&data).to_vec()),
410            Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(&data).to_vec()),
411            Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(&data).to_vec()),
412            Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(&data).to_vec()),
413            Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(&data).to_vec()),
414            Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(&data).to_vec()),
415            Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(&data).to_vec()),
416            Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(&data).to_vec()),
417            Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(&data).to_vec()),
418            Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(&data).to_vec()),
419            Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(&data).to_vec()),
420            Self::BF16 => Box::new(as_t_slice::<bf16>(&data).to_vec()),
421        }
422    }
423
424    /// The type size for blocks in bytes.
425    pub fn type_size(&self) -> usize {
426        use k_quants::*;
427        match self {
428            Self::F32 => 4,
429            Self::F16 | Self::BF16 => 2,
430            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
431            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
432            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
433            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
434            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
435            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
436            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
437            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
438            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
439            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
440            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
441            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
442            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
443        }
444    }
445
446    /// The block size, i.e. the number of elements stored in each block.
447    pub fn block_size(&self) -> usize {
448        match self {
449            Self::F32 => 1,
450            Self::F16 | Self::BF16 => 1,
451            Self::Q4_0 => k_quants::QK4_0,
452            Self::Q4_1 => k_quants::QK4_1,
453            Self::Q5_0 => k_quants::QK5_0,
454            Self::Q5_1 => k_quants::QK5_1,
455            Self::Q8_0 => k_quants::QK8_0,
456            Self::Q8_1 => k_quants::QK8_1,
457            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
458        }
459    }
460}
461
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
/// Object-safe view of a host buffer of quantized blocks, as stored in `QStorage::Cpu`
/// (and `QStorage::Vulkan`). In this file it is implemented for `Vec<T>` of a block type.
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType;
    /// Quantized matmul of f32 `lhs` against these blocks, written into `dst`.
    /// The `Vec<T>` impl forwards to `k_quants::matmul`. `mkn` presumably is `(m, k, n)`, and
    /// the `_t` suffix suggests the blocks act as a transposed rhs; confirm in k_quants.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// f16 counterpart of `matmul_t` (forwards to `k_quants::matmul_f16` in the `Vec<T>` impl).
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Dequantizes into f32 storage of `elem_count` elements (`CpuStorage::F32` in the
    /// `Vec<T>` impl).
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the stored blocks in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the stored blocks; valid for `storage_size_in_bytes` bytes.
    fn as_ptr(&self) -> *const u8;
    /// Number of elements per block.
    fn block_size(&self) -> usize;
    // `from_*` methods take `&mut self` (they fill the existing buffer), which clippy's naming
    // lint flags.
    /// Quantizes the f32 values `xs` into this buffer in place.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Like `from_float`, weighting each of the `n_per_row` columns by `imatrix_weights`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size in bytes; the `Vec<T>` impl returns the same value as `storage_size_in_bytes`.
    fn size(&self) -> usize;
}
477
478impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
479    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
480        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
481    }
482    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
483        k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
484    }
485
486    fn size(&self) -> usize {
487        self.len() * core::mem::size_of::<T>()
488    }
489
490    fn from_float(&mut self, xs: &[f32]) {
491        T::from_float(xs, self)
492    }
493
494    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
495        T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
496    }
497
498    fn dtype(&self) -> GgmlDType {
499        T::DTYPE
500    }
501
502    fn block_size(&self) -> usize {
503        T::BLCK_SIZE
504    }
505
506    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
507        let mut ys = vec![0.0f32; elem_count];
508        T::to_float(self.as_slice(), &mut ys);
509        Ok(CpuStorage::F32(ys))
510    }
511
512    fn storage_size_in_bytes(&self) -> usize {
513        self.len() * std::mem::size_of::<T>()
514    }
515
516    fn as_ptr(&self) -> *const u8 {
517        self.as_ptr() as *const u8
518    }
519}
520
521impl std::fmt::Debug for QTensor {
522    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
523        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
524    }
525}
526
527fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
528    let dims = shape.dims();
529    if dims.is_empty() {
530        crate::bail!("scalar tensor cannot be quantized {shape:?}")
531    }
532    if !dims[dims.len() - 1].is_multiple_of(block_size) {
533        crate::bail!(
534            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
535            block_size
536        )
537    }
538    Ok(())
539}
540
541impl QTensor {
542    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
543        let shape = shape.into();
544        check_shape(&shape, storage.block_size())?;
545        Ok(Self { storage, shape })
546    }
547
548    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
549        let shape = src.shape();
550        let block_size = dtype.block_size();
551        check_shape(shape, block_size)?;
552        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
553        let elem_count = shape.elem_count();
554        if !elem_count.is_multiple_of(block_size) {
555            crate::bail!(
556                "tensor size ({shape:?}) is not divisible by block size {}",
557                block_size
558            )
559        }
560        let mut storage = src.device().qzeros(elem_count, dtype)?;
561        storage.quantize(&src.storage())?;
562        Ok(Self {
563            storage,
564            shape: shape.clone(),
565        })
566    }
567
568    pub fn quantize_imatrix(
569        src: &Tensor,
570        imatrix_weights: &[f32],
571        dtype: GgmlDType,
572    ) -> Result<Self> {
573        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
574        // Size of imatrix == last dim of tensor
575        let n_per_row = src.dim(D::Minus1)?;
576        if imatrix_weights.len() != n_per_row {
577            crate::bail!(
578                "imatrix weights must have the same length {} as the last dim of src {}",
579                imatrix_weights.len(),
580                src.dim(D::Minus1)?
581            );
582        }
583
584        let shape = src.shape();
585        let block_size = dtype.block_size();
586        check_shape(shape, block_size)?;
587        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
588        let elem_count = shape.elem_count();
589        if !elem_count.is_multiple_of(block_size) {
590            crate::bail!(
591                "tensor size ({shape:?}) is not divisible by block size {}",
592                block_size
593            );
594        }
595        let mut storage = src.device().qzeros(elem_count, dtype)?;
596        storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
597        Ok(Self {
598            storage,
599            shape: shape.clone(),
600        })
601    }
602
603    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
604    pub fn quantize_imatrix_onto(
605        src: &Tensor,
606        imatrix_weights: &[f32],
607        dtype: GgmlDType,
608        dev: &Device,
609    ) -> Result<Self> {
610        if !src.device().is_cpu() {
611            crate::bail!(
612                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
613                src.device()
614            )
615        }
616        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
617        // Size of imatrix == last dim of tensor
618        let n_per_row = src.dim(D::Minus1)?;
619        if imatrix_weights.len() != n_per_row {
620            crate::bail!(
621                "imatrix weights must have the same length {} as the last dim of src {}",
622                imatrix_weights.len(),
623                src.dim(D::Minus1)?
624            );
625        }
626        let shape = src.shape();
627        let block_size = dtype.block_size();
628        check_shape(shape, block_size)?;
629        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
630        let elem_count = shape.elem_count();
631        if !elem_count.is_multiple_of(block_size) {
632            crate::bail!(
633                "tensor size ({shape:?}) is not divisible by block size {}",
634                block_size
635            )
636        }
637        // storage is on the `dev`, src is on `cpu`
638        let mut storage = dev.qzeros(elem_count, dtype)?;
639        storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
640        Ok(Self {
641            storage,
642            shape: shape.clone(),
643        })
644    }
645
646    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
647    pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
648        if !src.device().is_cpu() {
649            crate::bail!(
650                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
651                src.device()
652            )
653        }
654        let shape = src.shape();
655        let block_size = dtype.block_size();
656        check_shape(shape, block_size)?;
657        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
658        let elem_count = shape.elem_count();
659        if !elem_count.is_multiple_of(block_size) {
660            crate::bail!(
661                "tensor size ({shape:?}) is not divisible by block size {}",
662                block_size
663            )
664        }
665        // storage is on the `dev`, src is on `cpu`
666        let mut storage = dev.qzeros(elem_count, dtype)?;
667        storage.quantize_onto(&src.storage())?;
668        Ok(Self {
669            storage,
670            shape: shape.clone(),
671        })
672    }
673
674    pub fn dtype(&self) -> GgmlDType {
675        self.storage.dtype()
676    }
677
678    pub fn device(&self) -> Device {
679        self.storage.device()
680    }
681
682    pub fn rank(&self) -> usize {
683        self.shape.rank()
684    }
685
686    pub fn shape(&self) -> &Shape {
687        &self.shape
688    }
689
690    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
691        let storage = self.storage.dequantize(self.shape.elem_count())?;
692        let none = crate::op::BackpropOp::none();
693        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
694    }
695
696    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
697        // In the CUDA case, we have a specialized kernel as this can be useful for volta
698        // architectures. https://github.com/hanzoai/ml/issues/2136
699        match &self.storage {
700            QStorage::Cuda(s) => {
701                let s = s.dequantize_f16(self.shape.elem_count())?;
702                let none = crate::op::BackpropOp::none();
703                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
704                    .to_device(device)
705            }
706            _ => {
707                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
708                Ok(s)
709            }
710        }
711    }
712
713    pub fn storage_size_in_bytes(&self) -> usize {
714        self.storage.size_in_bytes()
715    }
716
717    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
718        self.storage.data()
719    }
720
721    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
722        match &self.storage {
723            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
724                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
725                    let (storage, out_shape) = s.indexed_moe_forward(
726                        self.shape(),
727                        x_storage,
728                        x.layout(),
729                        ids_storage,
730                        ids.layout(),
731                    )?;
732                    Ok(crate::tensor::from_storage(
733                        Storage::Cuda(storage),
734                        out_shape,
735                        crate::op::BackpropOp::none(),
736                        false,
737                    ))
738                }
739                _ => {
740                    panic!("Non-cuda indexed_moe_forward is not implemented!");
741                }
742            },
743            _ => {
744                // CPU / non-CUDA fallback: per-expert quantized matmul. The packed expert bank
745                // [E, n, k] is sliced into equal, contiguous per-expert quantized blocks; for
746                // each expert that is actually selected we run hanzo-ml's native quantized matmul
747                // on just the tokens routed to it. Nothing is dequantized, so quantized MoE runs
748                // on any backend (CPU, Metal, ...) at a cost proportional to the active experts.
749                use crate::Module; // brings QMatMul::forward into scope
750                use std::collections::HashMap;
751                use std::sync::Arc;
752                let device = x.device();
753                let out_dtype = x.dtype();
754                let (e_cnt, n, k) = self.shape().dims3()?;
755                let (t, topk) = ids.dims2()?;
756                let s = x.dim(1)?; // 1 (gate/up: shared input) or topk (down: per-slot)
757                let x_exp = if s == topk {
758                    x.clone()
759                } else {
760                    x.broadcast_as((t, topk, k))?
761                };
762                let x_flat = x_exp
763                    .reshape((t * topk, k))?
764                    .to_dtype(crate::DType::F32)?
765                    .contiguous()?;
766                let ids_flat = ids.reshape((t * topk,))?.to_dtype(crate::DType::U32)?;
767                let ids_vec = ids_flat.to_vec1::<u32>()?;
768                let mut groups: HashMap<u32, Vec<u32>> = HashMap::new();
769                for (slot, eid) in ids_vec.iter().enumerate() {
770                    groups.entry(*eid).or_default().push(slot as u32);
771                }
772                let dtype = self.storage.dtype();
773                let all_bytes = self.data()?;
774                let expert_bytes = all_bytes.len() / e_cnt;
775                let mut out_flat = Tensor::zeros((t * topk, n), crate::DType::F32, device)?;
776                for (eid, slots) in groups.into_iter() {
777                    let off = eid as usize * expert_bytes;
778                    let qs = QStorage::from_data(
779                        std::borrow::Cow::Borrowed(&all_bytes[off..off + expert_bytes]),
780                        device,
781                        dtype,
782                    )?;
783                    let shape: crate::Shape = (n, k).into();
784                    let w_e = QTensor { storage: qs, shape };
785                    let qm = QMatMul::from_arc(Arc::new(w_e))?;
786                    let m = slots.len();
787                    let idx = Tensor::from_vec(slots, (m,), device)?;
788                    let x_e = x_flat.index_select(&idx, 0)?; // [m, k]
789                    let y_e = qm.forward(&x_e)?.to_dtype(crate::DType::F32)?; // [m, n]
790                    out_flat = out_flat.index_add(&idx, &y_e, 0)?;
791                }
792                out_flat.reshape((t, topk, n))?.to_dtype(out_dtype)
793            }
794        }
795    }
796
797    pub fn device_ptr(&self) -> Result<*const u8> {
798        match &self.storage {
799            QStorage::Cuda(storage) => storage.device_ptr(),
800            #[cfg(feature = "vulkan")]
801            QStorage::Vulkan(..) => crate::bail!("not implemented"),
802            QStorage::Metal(_) | QStorage::Cpu(_) => {
803                crate::bail!("not implemented");
804            }
805        }
806    }
807
808    #[cfg(feature = "cuda")]
809    pub fn device_ptr_with_guard<'a>(
810        &'a self,
811        stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
812    ) -> Result<(
813        *const u8,
814        crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
815    )> {
816        self.storage.device_ptr_with_guard(stream)
817    }
818}
819
/// Right-hand operand of a matmul whose `forward` computes `xs @ w^T` for an `(n, k)` weight `w`.
///
/// Constructed by `QMatMul::from_arc`, which decides once per weight whether to keep it quantized
/// or to dequantize it up front.
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Quantized weight; `forward` applies it through the `qmatmul` custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Weight dequantized to a dense tensor via `QTensor::dequantize`.
    Tensor(Tensor),
    /// Weight dequantized via `QTensor::dequantize_f16`; `forward` casts the input to f16 and the
    /// result back to the input dtype.
    TensorF16(Tensor),
    // Native Vulkan quantized weight: Q8_0 blocks resident in VRAM (~1.125 B/elem). Decode (1 row)
    // runs the GPU Q8 matvec directly (bandwidth-optimal); prefill dequantizes to a temporary f32.
    // `qtensor` is the original (CPU-side) for the prefill dequant; `n`/`k` are the weight dims.
    #[cfg(feature = "vulkan")]
    VulkanQuant {
        qtensor: std::sync::Arc<QTensor>,
        wq: std::sync::Arc<crate::VulkanStorage>,
        n: usize,
        k: usize,
    },
}
836
thread_local! {
    /// `CANDLE_DEQUANTIZE_ALL`: when enabled, `QMatMul::from_arc` dequantizes every weight.
    /// Read once per thread, on first access.
    static DEQUANTIZE_ALL: bool = env_flag_is_set("CANDLE_DEQUANTIZE_ALL");

    /// `CANDLE_DEQUANTIZE_ALL_F16`: when enabled, quantized weights are dequantized to f16.
    /// Read once per thread, on first access.
    static DEQUANTIZE_ALL_F16: bool = env_flag_is_set("CANDLE_DEQUANTIZE_ALL_F16");
}

/// Interprets environment variable `name` as a boolean switch.
///
/// A set, non-empty value other than `"0"` counts as enabled. An unset or non-UTF-8 variable,
/// the empty string, and `"0"` all count as disabled.
fn env_flag_is_set(name: &str) -> bool {
    matches!(std::env::var(name), Ok(value) if !value.is_empty() && value != "0")
}
858
859impl QMatMul {
860    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
861        // Native Vulkan quantized path: keep Q8_0 weights quantized in VRAM and run the GPU matvec
862        // for decode, instead of dequantizing the whole model to f32 (4x the decode bandwidth).
863        #[cfg(feature = "vulkan")]
864        if qtensor.dtype() == GgmlDType::Q8_0 {
865            if let Device::Vulkan(d) = qtensor.device() {
866                if let Ok((n, k)) = qtensor.shape().dims2() {
867                    if k % 32 == 0 {
868                        let wf = qtensor
869                            .dequantize(&Device::Cpu)?
870                            .flatten_all()?
871                            .to_vec1::<f32>()?;
872                        let wq = d.quantize_q8(&wf, n, k)?;
873                        return Ok(Self::VulkanQuant {
874                            qtensor,
875                            wq: std::sync::Arc::new(wq),
876                            n,
877                            k,
878                        });
879                    }
880                }
881            }
882        }
883        let dequantize = match qtensor.dtype() {
884            GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
885            // The Vulkan backend has no native quantized matmul yet, so dequantize to f32 here
886            // (once, at construction) and run the regular f32 GPU matmul.
887            _ => DEQUANTIZE_ALL.with(|b| *b) || qtensor.device().is_vulkan(),
888        };
889        let t = if dequantize {
890            let tensor = qtensor.dequantize(&qtensor.device())?;
891            Self::Tensor(tensor)
892        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
893            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
894            Self::TensorF16(tensor)
895        } else {
896            Self::QTensor(qtensor)
897        };
898        Ok(t)
899    }
900
901    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
902        Self::from_arc(std::sync::Arc::new(qtensor))
903    }
904
905    pub fn dequantize_f16(&self) -> Result<Tensor> {
906        match self {
907            Self::QTensor(t) => t.dequantize_f16(&t.device()),
908            Self::Tensor(t) => t.to_dtype(DType::F16),
909            Self::TensorF16(t) => Ok(t.clone()),
910            #[cfg(feature = "vulkan")]
911            Self::VulkanQuant { qtensor, .. } => qtensor.dequantize_f16(&qtensor.device()),
912        }
913    }
914
915    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
916        let w = self.dequantize_f16()?;
917        let in_dtype = xs.dtype();
918        let w = match *xs.dims() {
919            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
920            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
921            _ => w.t()?,
922        };
923        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
924    }
925
926    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
927        match self {
928            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
929            #[cfg(feature = "vulkan")]
930            Self::VulkanQuant { qtensor, .. } => qtensor.indexed_moe_forward(x, ids),
931            _ => {
932                panic!("Not implemented!")
933            }
934        }
935    }
936}
937
938impl crate::CustomOp1 for QTensor {
939    fn name(&self) -> &'static str {
940        "qmatmul"
941    }
942
943    fn cpu_fwd(
944        &self,
945        storage: &crate::CpuStorage,
946        layout: &crate::Layout,
947    ) -> Result<(crate::CpuStorage, Shape)> {
948        if !layout.is_contiguous() {
949            crate::bail!("input tensor is not contiguous {layout:?}")
950        }
951        let src_shape = layout.shape();
952        // self is transposed so n is first then k.
953        let (n, k) = self.shape.dims2()?;
954        if src_shape.rank() < 2 {
955            crate::bail!("input tensor has only one dimension {layout:?}")
956        }
957        let mut dst_shape = src_shape.dims().to_vec();
958        let last_k = dst_shape.pop().unwrap();
959        if last_k != k {
960            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
961        }
962        dst_shape.push(n);
963        let dst_shape = Shape::from(dst_shape);
964        #[allow(clippy::infallible_destructuring_match)]
965        let self_storage = match &self.storage {
966            QStorage::Cpu(storage) => storage,
967            #[cfg(feature = "vulkan")]
968            QStorage::Vulkan(..) => crate::bail!("Invalid storage"),
969            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
970        };
971        match storage.dtype() {
972            DType::F32 => {
973                let slice = storage.as_slice::<f32>()?;
974                let slice =
975                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
976                let mut dst_storage = vec![0f32; dst_shape.elem_count()];
977                self_storage.matmul_t(
978                    (dst_shape.elem_count() / n, k, n),
979                    slice,
980                    &mut dst_storage,
981                )?;
982                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
983            }
984            DType::F16 => {
985                let slice = storage.as_slice::<f16>()?;
986                let slice =
987                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
988                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
989                self_storage.matmul_t_f16(
990                    (dst_shape.elem_count() / n, k, n),
991                    slice,
992                    &mut dst_storage,
993                )?;
994                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
995            }
996            _ => crate::bail!("Expected f32/f16"),
997        }
998    }
999
1000    fn metal_fwd(
1001        &self,
1002        storage: &crate::MetalStorage,
1003        layout: &crate::Layout,
1004    ) -> Result<(crate::MetalStorage, Shape)> {
1005        let self_storage = match &self.storage {
1006            QStorage::Metal(metal) => metal,
1007            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
1008        };
1009        self_storage.fwd(&self.shape, storage, layout)
1010    }
1011
1012    fn cuda_fwd(
1013        &self,
1014        storage: &crate::CudaStorage,
1015        layout: &crate::Layout,
1016    ) -> Result<(crate::CudaStorage, Shape)> {
1017        let self_storage = match &self.storage {
1018            QStorage::Cuda(cuda) => cuda,
1019            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
1020        };
1021        self_storage.fwd(&self.shape, storage, layout)
1022    }
1023}
1024
1025impl crate::Module for QMatMul {
1026    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1027        match self {
1028            #[cfg(feature = "vulkan")]
1029            Self::VulkanQuant { qtensor, wq, n, k } => {
1030                let rows: usize = xs.elem_count() / *k;
1031                if rows == 1 {
1032                    // Decode: weights stay quantized in VRAM, GPU Q8 matvec (no dequant, no copy).
1033                    let xs = xs.contiguous()?;
1034                    let d = match xs.device() {
1035                        Device::Vulkan(d) => d,
1036                        _ => crate::bail!("VulkanQuant input not on vulkan"),
1037                    };
1038                    let y = {
1039                        let (store, _) = xs.storage_and_layout();
1040                        let xv = match &*store {
1041                            crate::Storage::Vulkan(v) => v,
1042                            _ => crate::bail!("VulkanQuant expected vulkan storage"),
1043                        };
1044                        d.matvec_q8_gpu(wq, xv, *n, *k)?
1045                    };
1046                    let mut dims = xs.dims().to_vec();
1047                    let last = dims.len() - 1;
1048                    dims[last] = *n;
1049                    Ok(crate::tensor::from_storage(
1050                        crate::Storage::Vulkan(y),
1051                        dims,
1052                        crate::op::BackpropOp::none(),
1053                        false,
1054                    ))
1055                } else {
1056                    // Prefill: dequantize to a temporary f32 weight (reuses the NT matmul path).
1057                    let w = qtensor.dequantize(&xs.device())?;
1058                    let w = match *xs.dims() {
1059                        [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1060                        [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1061                        _ => w.t()?,
1062                    };
1063                    xs.matmul(&w)
1064                }
1065            }
1066            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
1067            Self::Tensor(w) => {
1068                let w = match *xs.dims() {
1069                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1070                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1071                    _ => w.t()?,
1072                };
1073                xs.matmul(&w)
1074            }
1075            Self::TensorF16(w) => {
1076                let in_dtype = xs.dtype();
1077                let w = match *xs.dims() {
1078                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1079                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1080                    _ => w.t()?,
1081                };
1082                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
1083            }
1084        }
1085    }
1086}