Skip to main content

candle_core/quantized/
mod.rs

1use crate::{
2    backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::{borrow::Cow, sync::OnceLock};
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(target_arch = "wasm32"))]
18pub mod tokenizer;
19#[cfg(not(feature = "metal"))]
20mod metal {
21    pub use super::dummy_metal::*;
22}
23#[cfg(feature = "cuda")]
24pub mod cuda;
25#[cfg(feature = "cuda")]
26pub mod fast_mmq;
27#[cfg(feature = "cuda")]
28pub mod fast_mmvq;
29#[cfg(not(feature = "cuda"))]
30mod cuda {
31    pub use super::dummy_cuda::*;
32}
33
34#[cfg(target_feature = "neon")]
35pub mod neon;
36#[cfg(target_feature = "simd128")]
37pub mod simd128;
38pub mod utils;
39use half::{bf16, f16};
40
41pub use k_quants::GgmlType;
42
/// Reinterprets a byte buffer as a slice of `T` without copying.
///
/// Used to view raw quantized data as the scalar (`f32`, `f16`, `bf16`) or `Block*` element
/// types of this module. An empty input always yields an empty slice.
///
/// # Panics
/// Panics if `data.len()` is not a multiple of `size_of::<T>()`, or if non-empty `data` does not
/// start at an address aligned to `align_of::<T>()`.
fn as_t_slice<T>(data: &[u8]) -> &[T] {
    // An empty byte slice may carry a dangling pointer that is only aligned for `u8` (this is
    // what an empty `Vec<u8>` yields), so it must not be subjected to `T`'s alignment check.
    if data.is_empty() {
        return &[];
    }
    let size = std::mem::size_of::<T>();
    assert_eq!(
        data.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = data.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: `ptr` is non-null and aligned for `T` (asserted above), the region covers exactly
    // `data.len() / size` values of `T` (length asserted to be a multiple of `size`), and the
    // returned slice borrows `data`, so it cannot outlive the bytes. NOTE(review): soundness
    // additionally assumes every bit pattern is a valid `T`, which is expected to hold for the
    // plain-data scalar and block types this module instantiates it with — confirm for new `T`s.
    unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), data.len() / size) }
}
58
59pub struct QTensor {
60    storage: QStorage,
61    shape: Shape,
62    /// Lazily initialized storage for repacked quantized data. Currently raw bits, could be `QStorage` in the future.
63    /// Not always used.
64    #[allow(dead_code)]
65    repacked_qs: OnceLock<Option<Vec<u8>>>,
66}
67
68impl Device {
69    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
70        match self {
71            Device::Cpu => {
72                let storage = dtype.cpu_zeros(elem_count);
73                Ok(QStorage::Cpu(storage))
74            }
75            Device::Metal(metal) => {
76                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
77                Ok(QStorage::Metal(storage))
78            }
79            Device::Cuda(cuda) => {
80                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
81                Ok(QStorage::Cuda(storage))
82            }
83        }
84    }
85}
86
87pub enum QStorage {
88    Cpu(Box<dyn QuantizedType>),
89    Metal(metal::QMetalStorage),
90    Cuda(cuda::QCudaStorage),
91}
92
93impl QStorage {
94    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
95        let data: &[u8] = &data;
96        match device {
97            Device::Cpu => Ok(Self::Cpu(dtype.from_data(Cow::Borrowed(data)))),
98            Device::Metal(d) => match dtype {
99                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
100                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
101                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
102                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
103                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
104                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
105                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
106                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
107                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
108                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
109                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
110                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
111                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
112                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
113                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
114            },
115            Device::Cuda(d) => match dtype {
116                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
117                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
118                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
119                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
120                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
121                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
122                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
123                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
124                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
125                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
126                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
127                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
128                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
129                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
130                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
131            },
132        }
133    }
134
135    fn block_size(&self) -> usize {
136        match self {
137            QStorage::Cpu(storage) => storage.block_size(),
138            QStorage::Metal(storage) => storage.dtype().block_size(),
139            QStorage::Cuda(storage) => storage.dtype().block_size(),
140        }
141    }
142
143    fn dtype(&self) -> GgmlDType {
144        match self {
145            QStorage::Cpu(storage) => storage.dtype(),
146            QStorage::Metal(storage) => storage.dtype(),
147            QStorage::Cuda(storage) => storage.dtype(),
148        }
149    }
150
151    fn device(&self) -> Device {
152        match self {
153            QStorage::Cpu(_storage) => Device::Cpu,
154            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
155            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
156        }
157    }
158
159    fn size_in_bytes(&self) -> usize {
160        match self {
161            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
162            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
163            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
164        }
165    }
166
167    fn quantize(&mut self, src: &Storage) -> Result<()> {
168        match (self, src) {
169            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
170                storage.from_float(src.as_slice::<f32>()?);
171            }
172            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
173            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
174            _ => crate::bail!("Invalid quantize storage locations do not match"),
175        }
176        Ok(())
177    }
178
179    fn quantize_imatrix(
180        &mut self,
181        src: &Storage,
182        imatrix_weights: &[f32],
183        n_per_row: usize,
184    ) -> Result<()> {
185        match (self, src) {
186            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
187                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
188            }
189            (QStorage::Metal(storage), Storage::Metal(src)) => {
190                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
191            }
192            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
193                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
194            }
195            _ => crate::bail!("Invalid quantize storage locations do not match"),
196        }
197        Ok(())
198    }
199
200    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
201        match (self, src) {
202            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
203                storage.from_float(src.as_slice::<f32>()?);
204            }
205            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
206            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
207            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
208        }
209        Ok(())
210    }
211
212    fn quantize_imatrix_onto(
213        &mut self,
214        src: &Storage,
215        imatrix_weights: &[f32],
216        n_per_row: usize,
217    ) -> Result<()> {
218        match (self, src) {
219            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
220                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
221            }
222            (QStorage::Metal(storage), Storage::Cpu(src)) => {
223                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
224            }
225            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
226                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
227            }
228            _ => crate::bail!("Invalid quantize storage locations do not match"),
229        }
230        Ok(())
231    }
232
233    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
234        match self {
235            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
236            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
237            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
238        }
239    }
240
241    fn data(&self) -> Result<Cow<'_, [u8]>> {
242        match self {
243            QStorage::Cpu(storage) => {
244                let data_ptr = storage.as_ptr();
245                let size_in_bytes = storage.storage_size_in_bytes();
246                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
247                Ok(Cow::from(data))
248            }
249            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
250            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
251        }
252    }
253
254    pub fn device_ptr(&self) -> Result<*const u8> {
255        match self {
256            QStorage::Cuda(storage) => storage.device_ptr(),
257            QStorage::Metal(_) | QStorage::Cpu(_) => {
258                crate::bail!("not implemented");
259            }
260        }
261    }
262
263    #[cfg(feature = "cuda")]
264    pub fn device_ptr_with_guard<'a>(
265        &'a self,
266        stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
267    ) -> Result<(
268        *const u8,
269        crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
270    )> {
271        match self {
272            QStorage::Cuda(storage) => storage.device_ptr_with_guard(stream),
273            QStorage::Metal(_) | QStorage::Cpu(_) => {
274                crate::bail!("not implemented");
275            }
276        }
277    }
278}
279
/// Element/block formats understood by the quantized backends.
///
/// The first three are plain floating-point element types (block size 1); the rest are ggml
/// block-quantization formats whose layout is described by `block_size` and `type_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    /// 32-bit float.
    F32,
    /// IEEE half-precision float.
    F16,
    /// bfloat16.
    BF16,
    /// ggml `Q4_0` blocks.
    Q4_0,
    /// ggml `Q4_1` blocks.
    Q4_1,
    /// ggml `Q5_0` blocks.
    Q5_0,
    /// ggml `Q5_1` blocks.
    Q5_1,
    /// ggml `Q8_0` blocks.
    Q8_0,
    /// ggml `Q8_1` blocks.
    Q8_1,
    /// ggml k-quant `Q2_K` super-blocks.
    Q2K,
    /// ggml k-quant `Q3_K` super-blocks.
    Q3K,
    /// ggml k-quant `Q4_K` super-blocks.
    Q4K,
    /// ggml k-quant `Q5_K` super-blocks.
    Q5K,
    /// ggml k-quant `Q6_K` super-blocks.
    Q6K,
    /// ggml k-quant `Q8_K` super-blocks.
    Q8K,
}
298
299impl GgmlDType {
300    pub(crate) fn from_u32(u: u32) -> Result<Self> {
301        let dtype = match u {
302            0 => Self::F32,
303            1 => Self::F16,
304            2 => Self::Q4_0,
305            3 => Self::Q4_1,
306            6 => Self::Q5_0,
307            7 => Self::Q5_1,
308            8 => Self::Q8_0,
309            9 => Self::Q8_1,
310            10 => Self::Q2K,
311            11 => Self::Q3K,
312            12 => Self::Q4K,
313            13 => Self::Q5K,
314            14 => Self::Q6K,
315            15 => Self::Q8K,
316            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
317            30 => Self::BF16,
318            _ => crate::bail!("unknown dtype for tensor {u}"),
319        };
320        Ok(dtype)
321    }
322
323    pub(crate) fn to_u32(self) -> u32 {
324        match self {
325            Self::F32 => 0,
326            Self::F16 => 1,
327            Self::Q4_0 => 2,
328            Self::Q4_1 => 3,
329            Self::Q5_0 => 6,
330            Self::Q5_1 => 7,
331            Self::Q8_0 => 8,
332            Self::Q8_1 => 9,
333            Self::Q2K => 10,
334            Self::Q3K => 11,
335            Self::Q4K => 12,
336            Self::Q5K => 13,
337            Self::Q6K => 14,
338            Self::Q8K => 15,
339            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
340            Self::BF16 => 30,
341        }
342    }
343
344    /// The block dtype
345    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
346        match self {
347            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
348            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
349            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
350            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
351            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
352            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
353            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
354            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
355            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
356            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
357            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
358            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
359            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
360            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
361            Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
362        }
363    }
364
365    pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
366        let data: &[u8] = &data;
367        match self {
368            Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
369            Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
370            Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
371            Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
372            Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
373            Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
374            Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
375            Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
376            Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
377            Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
378            Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
379            Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
380            Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
381            Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
382            Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
383        }
384    }
385
386    /// The type size for blocks in bytes.
387    pub fn type_size(&self) -> usize {
388        use k_quants::*;
389        match self {
390            Self::F32 => 4,
391            Self::F16 | Self::BF16 => 2,
392            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
393            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
394            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
395            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
396            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
397            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
398            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
399            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
400            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
401            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
402            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
403            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
404            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
405        }
406    }
407
408    /// The block size, i.e. the number of elements stored in each block.
409    pub fn block_size(&self) -> usize {
410        match self {
411            Self::F32 => 1,
412            Self::F16 | Self::BF16 => 1,
413            Self::Q4_0 => k_quants::QK4_0,
414            Self::Q4_1 => k_quants::QK4_1,
415            Self::Q5_0 => k_quants::QK5_0,
416            Self::Q5_1 => k_quants::QK5_1,
417            Self::Q8_0 => k_quants::QK8_0,
418            Self::Q8_1 => k_quants::QK8_1,
419            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
420        }
421    }
422}
423
424// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
425pub trait QuantizedType: Send + Sync {
426    fn dtype(&self) -> GgmlDType;
427    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
428    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
429    fn embedding(&self, ids: &[u32], rows: usize, hidden: usize) -> Result<CpuStorage>;
430    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
431    fn storage_size_in_bytes(&self) -> usize;
432    fn as_ptr(&self) -> *const u8;
433    fn block_size(&self) -> usize;
434    #[allow(clippy::wrong_self_convention)]
435    fn from_float(&mut self, xs: &[f32]);
436    #[allow(clippy::wrong_self_convention)]
437    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
438    fn size(&self) -> usize;
439}
440
441impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
442    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
443        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
444    }
445    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
446        k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
447    }
448
449    fn embedding(&self, ids: &[u32], rows: usize, hidden: usize) -> Result<CpuStorage> {
450        if !hidden.is_multiple_of(T::BLCK_SIZE) {
451            crate::bail!(
452                "quantized embedding hidden size {hidden} is not divisible by block size {}",
453                T::BLCK_SIZE
454            )
455        }
456        let row_blocks = hidden / T::BLCK_SIZE;
457        if self.len() != rows * row_blocks {
458            crate::bail!(
459                "quantized tensor has {} blocks, expected {}",
460                self.len(),
461                rows * row_blocks
462            )
463        }
464        let mut out = vec![0f32; ids.len() * hidden];
465        for (out_row, &row_id) in ids.iter().enumerate() {
466            let row = row_id as usize;
467            if row >= rows {
468                crate::bail!("embedding id {row} is out of range for {rows} rows")
469            }
470            let src = &self[row * row_blocks..(row + 1) * row_blocks];
471            let dst = &mut out[out_row * hidden..(out_row + 1) * hidden];
472            T::to_float(src, dst);
473        }
474        Ok(CpuStorage::F32(out))
475    }
476
477    fn size(&self) -> usize {
478        self.len() * core::mem::size_of::<T>()
479    }
480
481    fn from_float(&mut self, xs: &[f32]) {
482        T::from_float(xs, self)
483    }
484
485    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
486        T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
487    }
488
489    fn dtype(&self) -> GgmlDType {
490        T::DTYPE
491    }
492
493    fn block_size(&self) -> usize {
494        T::BLCK_SIZE
495    }
496
497    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
498        let mut ys = vec![0.0f32; elem_count];
499        T::to_float(self.as_slice(), &mut ys);
500        Ok(CpuStorage::F32(ys))
501    }
502
503    fn storage_size_in_bytes(&self) -> usize {
504        self.len() * std::mem::size_of::<T>()
505    }
506
507    fn as_ptr(&self) -> *const u8 {
508        self.as_ptr() as *const u8
509    }
510}
511
512impl std::fmt::Debug for QTensor {
513    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
514        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
515    }
516}
517
518fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
519    let dims = shape.dims();
520    if dims.is_empty() {
521        crate::bail!("scalar tensor cannot be quantized {shape:?}")
522    }
523    if !dims[dims.len() - 1].is_multiple_of(block_size) {
524        crate::bail!(
525            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
526            block_size
527        )
528    }
529    Ok(())
530}
531
532impl QTensor {
533    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
534        let shape = shape.into();
535        check_shape(&shape, storage.block_size())?;
536        Ok(Self {
537            storage,
538            shape,
539            repacked_qs: OnceLock::new(),
540        })
541    }
542
543    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
544        let shape = src.shape();
545        let block_size = dtype.block_size();
546        check_shape(shape, block_size)?;
547        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
548        let elem_count = shape.elem_count();
549        if !elem_count.is_multiple_of(block_size) {
550            crate::bail!(
551                "tensor size ({shape:?}) is not divisible by block size {}",
552                block_size
553            )
554        }
555        let mut storage = src.device().qzeros(elem_count, dtype)?;
556        storage.quantize(&src.storage())?;
557        Ok(Self {
558            storage,
559            shape: shape.clone(),
560            repacked_qs: OnceLock::new(),
561        })
562    }
563
564    pub fn quantize_imatrix(
565        src: &Tensor,
566        imatrix_weights: &[f32],
567        dtype: GgmlDType,
568    ) -> Result<Self> {
569        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
570        // Size of imatrix == last dim of tensor
571        let n_per_row = src.dim(D::Minus1)?;
572        if imatrix_weights.len() != n_per_row {
573            crate::bail!(
574                "imatrix weights must have the same length {} as the last dim of src {}",
575                imatrix_weights.len(),
576                src.dim(D::Minus1)?
577            );
578        }
579
580        let shape = src.shape();
581        let block_size = dtype.block_size();
582        check_shape(shape, block_size)?;
583        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
584        let elem_count = shape.elem_count();
585        if !elem_count.is_multiple_of(block_size) {
586            crate::bail!(
587                "tensor size ({shape:?}) is not divisible by block size {}",
588                block_size
589            );
590        }
591        let mut storage = src.device().qzeros(elem_count, dtype)?;
592        storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
593        Ok(Self {
594            storage,
595            shape: shape.clone(),
596            repacked_qs: OnceLock::new(),
597        })
598    }
599
600    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
601    pub fn quantize_imatrix_onto(
602        src: &Tensor,
603        imatrix_weights: &[f32],
604        dtype: GgmlDType,
605        dev: &Device,
606    ) -> Result<Self> {
607        if !src.device().is_cpu() {
608            crate::bail!(
609                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
610                src.device()
611            )
612        }
613        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
614        // Size of imatrix == last dim of tensor
615        let n_per_row = src.dim(D::Minus1)?;
616        if imatrix_weights.len() != n_per_row {
617            crate::bail!(
618                "imatrix weights must have the same length {} as the last dim of src {}",
619                imatrix_weights.len(),
620                src.dim(D::Minus1)?
621            );
622        }
623        let shape = src.shape();
624        let block_size = dtype.block_size();
625        check_shape(shape, block_size)?;
626        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
627        let elem_count = shape.elem_count();
628        if !elem_count.is_multiple_of(block_size) {
629            crate::bail!(
630                "tensor size ({shape:?}) is not divisible by block size {}",
631                block_size
632            )
633        }
634        // storage is on the `dev`, src is on `cpu`
635        let mut storage = dev.qzeros(elem_count, dtype)?;
636        storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
637        Ok(Self {
638            storage,
639            shape: shape.clone(),
640            repacked_qs: OnceLock::new(),
641        })
642    }
643
644    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
645    pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
646        if !src.device().is_cpu() {
647            crate::bail!(
648                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
649                src.device()
650            )
651        }
652        let shape = src.shape();
653        let block_size = dtype.block_size();
654        check_shape(shape, block_size)?;
655        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
656        let elem_count = shape.elem_count();
657        if !elem_count.is_multiple_of(block_size) {
658            crate::bail!(
659                "tensor size ({shape:?}) is not divisible by block size {}",
660                block_size
661            )
662        }
663        // storage is on the `dev`, src is on `cpu`
664        let mut storage = dev.qzeros(elem_count, dtype)?;
665        storage.quantize_onto(&src.storage())?;
666        Ok(Self {
667            storage,
668            shape: shape.clone(),
669            repacked_qs: OnceLock::new(),
670        })
671    }
672
673    pub fn dtype(&self) -> GgmlDType {
674        self.storage.dtype()
675    }
676
677    pub fn device(&self) -> Device {
678        self.storage.device()
679    }
680
681    pub fn rank(&self) -> usize {
682        self.shape.rank()
683    }
684
685    pub fn shape(&self) -> &Shape {
686        &self.shape
687    }
688
689    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
690        let storage = self.storage.dequantize(self.shape.elem_count())?;
691        let none = crate::op::BackpropOp::none();
692        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
693    }
694
695    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
696        // In the CUDA case, we have a specialized kernel as this can be useful for volta
697        // architectures. https://github.com/huggingface/candle/issues/2136
698        match &self.storage {
699            QStorage::Cuda(s) => {
700                let s = s.dequantize_f16(self.shape.elem_count())?;
701                let none = crate::op::BackpropOp::none();
702                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
703                    .to_device(device)
704            }
705            _ => {
706                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
707                Ok(s)
708            }
709        }
710    }
711
712    pub fn embedding(&self, ids: &Tensor) -> Result<Tensor> {
713        let (rows, hidden) = self.shape.dims2()?;
714        if !hidden.is_multiple_of(self.dtype().block_size()) {
715            crate::bail!(
716                "quantized embedding hidden size {hidden} is not divisible by block size {}",
717                self.dtype().block_size()
718            )
719        }
720        let mut out_shape = ids.dims().to_vec();
721        out_shape.push(hidden);
722        let device = self.device();
723        let ids = ids
724            .to_device(&device)?
725            .to_dtype(DType::U32)?
726            .flatten_all()?
727            .contiguous()?;
728        let storage = match &self.storage {
729            QStorage::Cpu(storage) => {
730                let ids = ids.to_vec1::<u32>()?;
731                Storage::Cpu(storage.embedding(&ids, rows, hidden)?)
732            }
733            QStorage::Metal(storage) => match &*ids.storage() {
734                Storage::Metal(ids_storage) => {
735                    Storage::Metal(storage.embedding(rows, hidden, ids_storage, ids.layout())?)
736                }
737                _ => unreachable!("ids were moved to the QTensor device"),
738            },
739            QStorage::Cuda(storage) => match &*ids.storage() {
740                Storage::Cuda(ids_storage) => {
741                    Storage::Cuda(storage.embedding(rows, hidden, ids_storage, ids.layout())?)
742                }
743                _ => unreachable!("ids were moved to the QTensor device"),
744            },
745        };
746        let none = crate::op::BackpropOp::none();
747        Ok(crate::tensor::from_storage(storage, out_shape, none, false))
748    }
749
750    pub fn storage_size_in_bytes(&self) -> usize {
751        self.storage.size_in_bytes()
752    }
753
754    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
755        self.storage.data()
756    }
757
758    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
759        match &self.storage {
760            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
761                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
762                    let (storage, out_shape) = s.indexed_moe_forward(
763                        self.shape(),
764                        x_storage,
765                        x.layout(),
766                        ids_storage,
767                        ids.layout(),
768                    )?;
769                    Ok(crate::tensor::from_storage(
770                        Storage::Cuda(storage),
771                        out_shape,
772                        crate::op::BackpropOp::none(),
773                        false,
774                    ))
775                }
776                _ => {
777                    panic!("Non-cuda indexed_moe_forward is not implemented!");
778                }
779            },
780            _ => {
781                panic!("indexed_moe_forward is not implemented in this platform!");
782            }
783        }
784    }
785
786    pub fn device_ptr(&self) -> Result<*const u8> {
787        match &self.storage {
788            QStorage::Cuda(storage) => storage.device_ptr(),
789            QStorage::Metal(_) | QStorage::Cpu(_) => {
790                crate::bail!("not implemented");
791            }
792        }
793    }
794
795    #[cfg(feature = "cuda")]
796    pub fn device_ptr_with_guard<'a>(
797        &'a self,
798        stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
799    ) -> Result<(
800        *const u8,
801        crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
802    )> {
803        self.storage.device_ptr_with_guard(stream)
804    }
805}
806
/// Weight operand for a matrix multiplication `xs @ w^T`, held in one of three forms.
///
/// `QMatMul::from_arc` chooses the variant based on the weight dtype and the
/// `CANDLE_DEQUANTIZE_ALL` / `CANDLE_DEQUANTIZE_ALL_F16` environment switches.
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Weight kept in quantized form. `forward` runs it as a custom op via
    /// `apply_op1_no_bwd`.
    QTensor(std::sync::Arc<QTensor>),
    /// Weight produced by `QTensor::dequantize`, used with a regular `matmul`.
    Tensor(Tensor),
    /// Weight produced by `QTensor::dequantize_f16`. `forward` casts the input to f16 for the
    /// matmul and casts the result back to the input's dtype.
    TensorF16(Tensor),
}
813
/// Reads the environment variable `name` as an on/off switch.
///
/// The switch is on when the variable holds any non-empty value other than `"0"`. An unset
/// variable, or one whose value is not valid Unicode, counts as off.
fn env_flag_enabled(name: &str) -> bool {
    std::env::var(name).is_ok_and(|value| !value.is_empty() && value != "0")
}

thread_local! {
    /// Whether `CANDLE_DEQUANTIZE_ALL` is switched on. `QMatMul::from_arc` consults it to
    /// dequantize every weight into a plain tensor. The variable is read lazily, once per
    /// thread, on first access.
    static DEQUANTIZE_ALL: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL");

    /// Whether `CANDLE_DEQUANTIZE_ALL_F16` is switched on. `QMatMul::from_arc` consults it to
    /// dequantize weights into f16 tensors. The variable is read lazily, once per thread, on
    /// first access.
    static DEQUANTIZE_ALL_F16: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL_F16");
}
835
836impl QMatMul {
837    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
838        let dequantize = match qtensor.dtype() {
839            GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
840            _ => DEQUANTIZE_ALL.with(|b| *b),
841        };
842        let t = if dequantize {
843            let tensor = qtensor.dequantize(&qtensor.device())?;
844            Self::Tensor(tensor)
845        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
846            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
847            Self::TensorF16(tensor)
848        } else {
849            Self::QTensor(qtensor)
850        };
851        Ok(t)
852    }
853
854    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
855        Self::from_arc(std::sync::Arc::new(qtensor))
856    }
857
858    pub fn dequantize_f16(&self) -> Result<Tensor> {
859        match self {
860            Self::QTensor(t) => t.dequantize_f16(&t.device()),
861            Self::Tensor(t) => t.to_dtype(DType::F16),
862            Self::TensorF16(t) => Ok(t.clone()),
863        }
864    }
865
866    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
867        let w = self.dequantize_f16()?;
868        let in_dtype = xs.dtype();
869        let w = match *xs.dims() {
870            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
871            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
872            _ => w.t()?,
873        };
874        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
875    }
876
877    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
878        match self {
879            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
880            _ => {
881                panic!("Not implemented!")
882            }
883        }
884    }
885
886    pub fn embedding(&self, ids: &Tensor) -> Result<Tensor> {
887        match self {
888            Self::QTensor(t) => t.embedding(ids),
889            Self::Tensor(w) | Self::TensorF16(w) => {
890                let mut final_dims = ids.dims().to_vec();
891                final_dims.push(w.dim(D::Minus1)?);
892                let ids = ids.to_device(w.device())?.flatten_all()?;
893                w.index_select(&ids, 0)?.reshape(final_dims)
894            }
895        }
896    }
897}
898
/// Makes `QTensor` usable as a unary custom op. The op's input is the activation tensor and
/// `self` supplies the quantized weight. `QMatMul::forward` applies it through
/// `apply_op1_no_bwd`, so no backward pass is involved.
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    /// CPU forward: computes `input @ self^T`, where `self` has shape `(n, k)`.
    ///
    /// Requirements on the input:
    /// * it must be contiguous;
    /// * it must have rank >= 2;
    /// * its last dim must equal `k`.
    ///
    /// The output shape is the input shape with the last dim replaced by `n`. f32 inputs give
    /// f32 outputs and f16 inputs give f16 outputs; any other input dtype is an error, as is a
    /// non-CPU `self`.
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape = input dims with the trailing `k` replaced by `n`.
        // The rank >= 2 check above guarantees `pop` cannot return `None`.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
        };
        // The matmul kernels below receive (m, k, n) with m = dst elem count / n, i.e. the
        // product of the input's leading dims.
        // NOTE(review): this divides by `n`, so a weight with n == 0 would panic here.
        // Confirm such weights cannot reach this op.
        match storage.dtype() {
            DType::F32 => {
                // Restrict the input buffer to this layout's contiguous window.
                let slice = storage.as_slice::<f32>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![0f32; dst_shape.elem_count()];

                // Try the 8-column BlockQ4Kx8 repacked path.
                // It is compiled only for aarch64 with `dotprod` and used only for Q4K weights
                // whose `n` is a multiple of 8. The repacked bytes are built on the first call,
                // cached in the `repacked_qs` OnceLock, and reused afterwards.
                #[cfg(all(target_arch = "aarch64", target_feature = "dotprod"))]
                if self_storage.dtype() == GgmlDType::Q4K && n.is_multiple_of(8) {
                    use zerocopy::{FromBytes, IntoBytes};

                    let total_blocks =
                        self_storage.storage_size_in_bytes() / std::mem::size_of::<BlockQ4K>();
                    let repacked = self.repacked_qs.get_or_init(|| {
                        // SAFETY (review): relies on `self_storage.as_ptr()` pointing at
                        // `storage_size_in_bytes()` bytes of contiguous `BlockQ4K` data that is
                        // aligned for `BlockQ4K`. Unlike `as_t_slice`, nothing here checks the
                        // alignment. TODO confirm the CPU storage allocation guarantees it.
                        let blocks = unsafe {
                            std::slice::from_raw_parts(
                                self_storage.as_ptr() as *const BlockQ4K,
                                total_blocks,
                            )
                        };
                        let packed = k_quants::pack_to_q4kx8(blocks, n);
                        // This closure always yields `Some`, so once the conditions above hold
                        // the branch below is always taken.
                        Some(packed.as_bytes().to_vec())
                    });
                    if let Some(repacked_bytes) = repacked {
                        // Reinterpret the cached bytes as `[BlockQ4Kx8]`; the cast fails
                        // with an error rather than UB if size/alignment do not fit.
                        // NOTE(review): a `Vec<u8>` only promises byte alignment. Confirm
                        // BlockQ4Kx8's alignment requirement is met in practice.
                        let block_x8: &[BlockQ4Kx8] =
                            <[BlockQ4Kx8]>::ref_from_bytes(repacked_bytes).map_err(|_| {
                                crate::Error::Msg(
                                    "repacked_qs alignment invariant violated".to_string(),
                                )
                            })?;

                        k_quants::matmul_q4k_x8(
                            (dst_shape.elem_count() / n, k, n),
                            slice,
                            block_x8,
                            &mut dst_storage,
                        )?;
                        return Ok((crate::CpuStorage::F32(dst_storage), dst_shape));
                    }
                }

                // Generic path: the storage's own transposed matmul for any quantized dtype.
                self_storage.matmul_t(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
            }
            DType::F16 => {
                // Same windowing as the f32 branch, with f16 input and output buffers.
                let slice = storage.as_slice::<f16>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
                self_storage.matmul_t_f16(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
            }
            _ => crate::bail!("Expected f32/f16"),
        }
    }

    /// Metal forward: delegates to the Metal quantized storage's `fwd`.
    /// Panics (via `unreachable!`) if this `QTensor` is not Metal-backed.
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    /// CUDA forward: delegates to the CUDA quantized storage's `fwd`.
    /// Panics (via `unreachable!`) if this `QTensor` is not CUDA-backed.
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}
1019
1020impl crate::Module for QMatMul {
1021    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1022        match self {
1023            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
1024            Self::Tensor(w) => {
1025                let w = match *xs.dims() {
1026                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1027                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1028                    _ => w.t()?,
1029                };
1030                xs.matmul(&w)
1031            }
1032            Self::TensorF16(w) => {
1033                let in_dtype = xs.dtype();
1034                let w = match *xs.dims() {
1035                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1036                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1037                    _ => w.t()?,
1038                };
1039                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
1040            }
1041        }
1042    }
1043}