Skip to main content

candle_core/quantized/
mod.rs

1use crate::{
2    backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(target_arch = "wasm32"))]
18pub mod tokenizer;
// Without the `metal` feature, `metal` resolves to the `dummy_metal` stub module so that
// `metal::` paths used unconditionally in this file (e.g. `QStorage::Metal`) still resolve.
// NOTE(review): the stub's contents live in `dummy_metal.rs`, not visible here.
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
// The real CUDA backend, only compiled with the `cuda` feature.
#[cfg(feature = "cuda")]
pub mod cuda;
// Without the `cuda` feature, `cuda` resolves to the `dummy_cuda` stub module instead.
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}
29
30#[cfg(target_feature = "neon")]
31pub mod neon;
32#[cfg(target_feature = "simd128")]
33pub mod simd128;
34pub mod utils;
35use half::{bf16, f16};
36
37pub use k_quants::GgmlType;
38
/// Reinterprets a borrowed byte buffer as a slice of `T` without copying.
///
/// The returned slice points directly into the bytes held by `data`, so `data` must be
/// `Cow::Borrowed`. An owned buffer would be dropped when this function returns, which would
/// leave the returned slice dangling. Callers holding an owned buffer should keep it alive and
/// pass `Cow::Borrowed(&*buffer)` instead.
///
/// # Panics
///
/// Panics if:
/// - `data` is `Cow::Owned`;
/// - its length is not a multiple of `size_of::<T>()`;
/// - its start address is not aligned to `align_of::<T>()`.
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
    let bytes = match data {
        Cow::Borrowed(bytes) => bytes,
        // A view into an owned `Vec` would be a use-after-free as soon as the `Vec` is dropped
        // at the end of this function, so refuse it rather than return a dangling slice.
        Cow::Owned(_) => panic!(
            "as_t_slice requires borrowed data: an owned buffer would be freed before the returned slice is used"
        ),
    };
    let size = std::mem::size_of::<T>();
    assert_eq!(
        bytes.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = bytes.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: `bytes` is borrowed for the same lifetime as the returned slice. The length check
    // guarantees that `bytes.len() / size` whole elements fit, and the pointer is aligned for `T`.
    // NOTE(review): this also assumes every bit pattern is a valid `T`, which should hold for
    // the float and block types this file instantiates it with. Confirm when adding new types.
    unsafe { std::slice::from_raw_parts(ptr as *const T, bytes.len() / size) }
}
54
/// A quantized tensor: backend-specific quantized storage plus the logical shape of the data
/// it dequantizes to.
pub struct QTensor {
    // Quantized data on the CPU, Metal or CUDA backend.
    storage: QStorage,
    // Logical shape. The constructors run `check_shape`, which rejects scalars and requires the
    // last dim to be a multiple of the dtype's block size.
    shape: Shape,
}
59
60impl Device {
61    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
62        match self {
63            Device::Cpu => {
64                let storage = dtype.cpu_zeros(elem_count);
65                Ok(QStorage::Cpu(storage))
66            }
67            Device::Metal(metal) => {
68                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
69                Ok(QStorage::Metal(storage))
70            }
71            Device::Cuda(cuda) => {
72                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
73                Ok(QStorage::Cuda(storage))
74            }
75        }
76    }
77}
78
/// Quantized tensor data, held on one of the supported backends.
pub enum QStorage {
    /// Host memory. Within this file the boxed value is always a `Vec` of elements/blocks
    /// (see `GgmlDType::cpu_zeros` and `GgmlDType::from_data`).
    Cpu(Box<dyn QuantizedType>),
    /// Metal storage (the `dummy_metal` stub type when the `metal` feature is off).
    Metal(metal::QMetalStorage),
    /// CUDA storage (the `dummy_cuda` stub type when the `cuda` feature is off).
    Cuda(cuda::QCudaStorage),
}
84
85impl QStorage {
86    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
87        match device {
88            Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
89            Device::Metal(d) => match dtype {
90                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
91                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
92                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
93                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
94                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
95                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
96                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
97                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
98                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
99                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
100                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
101                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
102                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
103                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
104                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
105            },
106            Device::Cuda(d) => match dtype {
107                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
108                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
109                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
110                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
111                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
112                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
113                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
114                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
115                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
116                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
117                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
118                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
119                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
120                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
121                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
122            },
123        }
124    }
125
126    fn block_size(&self) -> usize {
127        match self {
128            QStorage::Cpu(storage) => storage.block_size(),
129            QStorage::Metal(storage) => storage.dtype().block_size(),
130            QStorage::Cuda(storage) => storage.dtype().block_size(),
131        }
132    }
133
134    fn dtype(&self) -> GgmlDType {
135        match self {
136            QStorage::Cpu(storage) => storage.dtype(),
137            QStorage::Metal(storage) => storage.dtype(),
138            QStorage::Cuda(storage) => storage.dtype(),
139        }
140    }
141
142    fn device(&self) -> Device {
143        match self {
144            QStorage::Cpu(_storage) => Device::Cpu,
145            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
146            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
147        }
148    }
149
150    fn size_in_bytes(&self) -> usize {
151        match self {
152            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
153            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
154            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
155        }
156    }
157
158    fn quantize(&mut self, src: &Storage) -> Result<()> {
159        match (self, src) {
160            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
161                storage.from_float(src.as_slice::<f32>()?);
162            }
163            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
164            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
165            _ => crate::bail!("Invalid quantize storage locations do not match"),
166        }
167        Ok(())
168    }
169
170    fn quantize_imatrix(
171        &mut self,
172        src: &Storage,
173        imatrix_weights: &[f32],
174        n_per_row: usize,
175    ) -> Result<()> {
176        match (self, src) {
177            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
178                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
179            }
180            (QStorage::Metal(storage), Storage::Metal(src)) => {
181                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
182            }
183            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
184                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
185            }
186            _ => crate::bail!("Invalid quantize storage locations do not match"),
187        }
188        Ok(())
189    }
190
191    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
192        match (self, src) {
193            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
194                storage.from_float(src.as_slice::<f32>()?);
195            }
196            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
197            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
198            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
199        }
200        Ok(())
201    }
202
203    fn quantize_imatrix_onto(
204        &mut self,
205        src: &Storage,
206        imatrix_weights: &[f32],
207        n_per_row: usize,
208    ) -> Result<()> {
209        match (self, src) {
210            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
211                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
212            }
213            (QStorage::Metal(storage), Storage::Cpu(src)) => {
214                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
215            }
216            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
217                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
218            }
219            _ => crate::bail!("Invalid quantize storage locations do not match"),
220        }
221        Ok(())
222    }
223
224    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
225        match self {
226            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
227            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
228            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
229        }
230    }
231
232    fn data(&self) -> Result<Cow<'_, [u8]>> {
233        match self {
234            QStorage::Cpu(storage) => {
235                let data_ptr = storage.as_ptr();
236                let size_in_bytes = storage.storage_size_in_bytes();
237                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
238                Ok(Cow::from(data))
239            }
240            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
241            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
242        }
243    }
244
245    pub fn device_ptr(&self) -> Result<*const u8> {
246        match self {
247            QStorage::Cuda(storage) => storage.device_ptr(),
248            QStorage::Metal(_) | QStorage::Cpu(_) => {
249                crate::bail!("not implemented");
250            }
251        }
252    }
253}
254
/// Element/block formats for quantized tensors, numbered like GGML's type ids
/// (see `GgmlDType::from_u32` / `GgmlDType::to_u32`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    // Unquantized formats: `block_size()` is 1, i.e. one element per "block".
    F32,
    F16,
    BF16,
    // Block-quantized formats. Elements per block come from `block_size()`.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // "K" quant formats, all sharing the `k_quants::QK_K` block size.
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
273
274impl GgmlDType {
275    pub(crate) fn from_u32(u: u32) -> Result<Self> {
276        let dtype = match u {
277            0 => Self::F32,
278            1 => Self::F16,
279            2 => Self::Q4_0,
280            3 => Self::Q4_1,
281            6 => Self::Q5_0,
282            7 => Self::Q5_1,
283            8 => Self::Q8_0,
284            9 => Self::Q8_1,
285            10 => Self::Q2K,
286            11 => Self::Q3K,
287            12 => Self::Q4K,
288            13 => Self::Q5K,
289            14 => Self::Q6K,
290            15 => Self::Q8K,
291            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
292            30 => Self::BF16,
293            _ => crate::bail!("unknown dtype for tensor {u}"),
294        };
295        Ok(dtype)
296    }
297
298    pub(crate) fn to_u32(self) -> u32 {
299        match self {
300            Self::F32 => 0,
301            Self::F16 => 1,
302            Self::Q4_0 => 2,
303            Self::Q4_1 => 3,
304            Self::Q5_0 => 6,
305            Self::Q5_1 => 7,
306            Self::Q8_0 => 8,
307            Self::Q8_1 => 9,
308            Self::Q2K => 10,
309            Self::Q3K => 11,
310            Self::Q4K => 12,
311            Self::Q5K => 13,
312            Self::Q6K => 14,
313            Self::Q8K => 15,
314            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
315            Self::BF16 => 30,
316        }
317    }
318
319    /// The block dtype
320    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
321        match self {
322            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
323            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
324            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
325            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
326            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
327            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
328            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
329            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
330            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
331            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
332            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
333            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
334            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
335            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
336            Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
337        }
338    }
339
340    pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
341        match self {
342            Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
343            Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
344            Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
345            Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
346            Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
347            Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
348            Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
349            Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
350            Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
351            Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
352            Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
353            Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
354            Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
355            Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
356            Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
357        }
358    }
359
360    /// The type size for blocks in bytes.
361    pub fn type_size(&self) -> usize {
362        use k_quants::*;
363        match self {
364            Self::F32 => 4,
365            Self::F16 | Self::BF16 => 2,
366            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
367            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
368            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
369            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
370            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
371            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
372            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
373            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
374            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
375            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
376            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
377            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
378            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
379        }
380    }
381
382    /// The block size, i.e. the number of elements stored in each block.
383    pub fn block_size(&self) -> usize {
384        match self {
385            Self::F32 => 1,
386            Self::F16 | Self::BF16 => 1,
387            Self::Q4_0 => k_quants::QK4_0,
388            Self::Q4_1 => k_quants::QK4_1,
389            Self::Q5_0 => k_quants::QK5_0,
390            Self::Q5_1 => k_quants::QK5_1,
391            Self::Q8_0 => k_quants::QK8_0,
392            Self::Q8_1 => k_quants::QK8_1,
393            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
394        }
395    }
396}
397
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
/// Object-safe view of a CPU buffer of quantized elements/blocks. In this file it is
/// implemented for `Vec<T>` of `GgmlType` blocks.
pub trait QuantizedType: Send + Sync {
    /// The GGML dtype of the stored elements/blocks.
    fn dtype(&self) -> GgmlDType;
    /// f32 matmul against this (transposed) weight. As called from `QTensor::cpu_fwd`,
    /// `mkn = (m, k, n)`, `lhs` holds `m * k` values, `self` is the `(n, k)` weight and `dst`
    /// receives `m * n` values.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// f16 counterpart of `matmul_t`.
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Dequantizes into `elem_count` values (`CpuStorage::F32` for the `Vec<T>` impl).
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the element/block data in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the element/block data (`storage_size_in_bytes` bytes).
    fn as_ptr(&self) -> *const u8;
    /// Number of logical elements per block.
    fn block_size(&self) -> usize;
    // `from_*` taking `&mut self` trips clippy's naming lint. Here it means "fill self from".
    /// Quantizes `xs` into the existing blocks of `self`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Like `from_float`, additionally passing importance-matrix weights for rows of
    /// `n_per_row` values.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size in bytes of the data (for `Vec<T>` this equals `storage_size_in_bytes`).
    fn size(&self) -> usize;
}
413
414impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
415    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
416        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
417    }
418    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
419        k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
420    }
421
422    fn size(&self) -> usize {
423        self.len() * core::mem::size_of::<T>()
424    }
425
426    fn from_float(&mut self, xs: &[f32]) {
427        T::from_float(xs, self)
428    }
429
430    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
431        T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
432    }
433
434    fn dtype(&self) -> GgmlDType {
435        T::DTYPE
436    }
437
438    fn block_size(&self) -> usize {
439        T::BLCK_SIZE
440    }
441
442    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
443        let mut ys = vec![0.0f32; elem_count];
444        T::to_float(self.as_slice(), &mut ys);
445        Ok(CpuStorage::F32(ys))
446    }
447
448    fn storage_size_in_bytes(&self) -> usize {
449        self.len() * std::mem::size_of::<T>()
450    }
451
452    fn as_ptr(&self) -> *const u8 {
453        self.as_ptr() as *const u8
454    }
455}
456
457impl std::fmt::Debug for QTensor {
458    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
459        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
460    }
461}
462
463fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
464    let dims = shape.dims();
465    if dims.is_empty() {
466        crate::bail!("scalar tensor cannot be quantized {shape:?}")
467    }
468    if !dims[dims.len() - 1].is_multiple_of(block_size) {
469        crate::bail!(
470            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
471            block_size
472        )
473    }
474    Ok(())
475}
476
477impl QTensor {
478    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
479        let shape = shape.into();
480        check_shape(&shape, storage.block_size())?;
481        Ok(Self { storage, shape })
482    }
483
484    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
485        let shape = src.shape();
486        let block_size = dtype.block_size();
487        check_shape(shape, block_size)?;
488        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
489        let elem_count = shape.elem_count();
490        if !elem_count.is_multiple_of(block_size) {
491            crate::bail!(
492                "tensor size ({shape:?}) is not divisible by block size {}",
493                block_size
494            )
495        }
496        let mut storage = src.device().qzeros(elem_count, dtype)?;
497        storage.quantize(&src.storage())?;
498        Ok(Self {
499            storage,
500            shape: shape.clone(),
501        })
502    }
503
504    pub fn quantize_imatrix(
505        src: &Tensor,
506        imatrix_weights: &[f32],
507        dtype: GgmlDType,
508    ) -> Result<Self> {
509        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
510        // Size of imatrix == last dim of tensor
511        let n_per_row = src.dim(D::Minus1)?;
512        if imatrix_weights.len() != n_per_row {
513            crate::bail!(
514                "imatrix weights must have the same length {} as the last dim of src {}",
515                imatrix_weights.len(),
516                src.dim(D::Minus1)?
517            );
518        }
519
520        let shape = src.shape();
521        let block_size = dtype.block_size();
522        check_shape(shape, block_size)?;
523        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
524        let elem_count = shape.elem_count();
525        if !elem_count.is_multiple_of(block_size) {
526            crate::bail!(
527                "tensor size ({shape:?}) is not divisible by block size {}",
528                block_size
529            );
530        }
531        let mut storage = src.device().qzeros(elem_count, dtype)?;
532        storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
533        Ok(Self {
534            storage,
535            shape: shape.clone(),
536        })
537    }
538
539    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
540    pub fn quantize_imatrix_onto(
541        src: &Tensor,
542        imatrix_weights: &[f32],
543        dtype: GgmlDType,
544        dev: &Device,
545    ) -> Result<Self> {
546        if !src.device().is_cpu() {
547            crate::bail!(
548                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
549                src.device()
550            )
551        }
552        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
553        // Size of imatrix == last dim of tensor
554        let n_per_row = src.dim(D::Minus1)?;
555        if imatrix_weights.len() != n_per_row {
556            crate::bail!(
557                "imatrix weights must have the same length {} as the last dim of src {}",
558                imatrix_weights.len(),
559                src.dim(D::Minus1)?
560            );
561        }
562        let shape = src.shape();
563        let block_size = dtype.block_size();
564        check_shape(shape, block_size)?;
565        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
566        let elem_count = shape.elem_count();
567        if !elem_count.is_multiple_of(block_size) {
568            crate::bail!(
569                "tensor size ({shape:?}) is not divisible by block size {}",
570                block_size
571            )
572        }
573        // storage is on the `dev`, src is on `cpu`
574        let mut storage = dev.qzeros(elem_count, dtype)?;
575        storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
576        Ok(Self {
577            storage,
578            shape: shape.clone(),
579        })
580    }
581
582    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
583    pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
584        if !src.device().is_cpu() {
585            crate::bail!(
586                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
587                src.device()
588            )
589        }
590        let shape = src.shape();
591        let block_size = dtype.block_size();
592        check_shape(shape, block_size)?;
593        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
594        let elem_count = shape.elem_count();
595        if !elem_count.is_multiple_of(block_size) {
596            crate::bail!(
597                "tensor size ({shape:?}) is not divisible by block size {}",
598                block_size
599            )
600        }
601        // storage is on the `dev`, src is on `cpu`
602        let mut storage = dev.qzeros(elem_count, dtype)?;
603        storage.quantize_onto(&src.storage())?;
604        Ok(Self {
605            storage,
606            shape: shape.clone(),
607        })
608    }
609
610    pub fn dtype(&self) -> GgmlDType {
611        self.storage.dtype()
612    }
613
614    pub fn device(&self) -> Device {
615        self.storage.device()
616    }
617
618    pub fn rank(&self) -> usize {
619        self.shape.rank()
620    }
621
622    pub fn shape(&self) -> &Shape {
623        &self.shape
624    }
625
626    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
627        let storage = self.storage.dequantize(self.shape.elem_count())?;
628        let none = crate::op::BackpropOp::none();
629        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
630    }
631
632    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
633        // In the CUDA case, we have a specialized kernel as this can be useful for volta
634        // architectures. https://github.com/huggingface/candle/issues/2136
635        match &self.storage {
636            QStorage::Cuda(s) => {
637                let s = s.dequantize_f16(self.shape.elem_count())?;
638                let none = crate::op::BackpropOp::none();
639                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
640                    .to_device(device)
641            }
642            _ => {
643                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
644                Ok(s)
645            }
646        }
647    }
648
649    pub fn storage_size_in_bytes(&self) -> usize {
650        self.storage.size_in_bytes()
651    }
652
653    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
654        self.storage.data()
655    }
656
657    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
658        match &self.storage {
659            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
660                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
661                    let (storage, out_shape) = s.indexed_moe_forward(
662                        self.shape(),
663                        x_storage,
664                        x.layout(),
665                        ids_storage,
666                        ids.layout(),
667                    )?;
668                    Ok(crate::tensor::from_storage(
669                        Storage::Cuda(storage),
670                        out_shape,
671                        crate::op::BackpropOp::none(),
672                        false,
673                    ))
674                }
675                _ => {
676                    panic!("Non-cuda indexed_moe_forward is not implemented!");
677                }
678            },
679            _ => {
680                panic!("indexed_moe_forward is not implemented in this platform!");
681            }
682        }
683    }
684
685    pub fn device_ptr(&self) -> Result<*const u8> {
686        match &self.storage {
687            QStorage::Cuda(storage) => storage.device_ptr(),
688            QStorage::Metal(_) | QStorage::Cpu(_) => {
689                crate::bail!("not implemented");
690            }
691        }
692    }
693}
694
/// A matmul weight that is either kept quantized or held as a dequantized tensor.
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Quantized weight, applied through the `qmatmul` custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Dequantized weight, applied with a regular `matmul` in the input's dtype.
    Tensor(Tensor),
    /// Dequantized f16 weight. Inputs are cast to f16 for the matmul and the result is cast
    /// back to the input dtype.
    TensorF16(Tensor),
}
701
/// Returns `true` when the environment variable `name` is set to anything other than the empty
/// string or `"0"`. Unset or non-UTF-8 values count as disabled.
fn env_flag_enabled(name: &str) -> bool {
    std::env::var(name).is_ok_and(|value| !value.is_empty() && value != "0")
}

thread_local! {
    // Per-thread snapshots of the env flags, read on first access in each thread.
    // `CANDLE_DEQUANTIZE_ALL`: `QMatMul::from_arc` dequantizes every weight.
    static DEQUANTIZE_ALL: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL");
    // `CANDLE_DEQUANTIZE_ALL_F16`: `QMatMul::from_arc` dequantizes quantized weights to f16.
    static DEQUANTIZE_ALL_F16: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL_F16");
}
723
724impl QMatMul {
725    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
726        let dequantize = match qtensor.dtype() {
727            GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
728            _ => DEQUANTIZE_ALL.with(|b| *b),
729        };
730        let t = if dequantize {
731            let tensor = qtensor.dequantize(&qtensor.device())?;
732            Self::Tensor(tensor)
733        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
734            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
735            Self::TensorF16(tensor)
736        } else {
737            Self::QTensor(qtensor)
738        };
739        Ok(t)
740    }
741
742    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
743        Self::from_arc(std::sync::Arc::new(qtensor))
744    }
745
746    pub fn dequantize_f16(&self) -> Result<Tensor> {
747        match self {
748            Self::QTensor(t) => t.dequantize_f16(&t.device()),
749            Self::Tensor(t) => t.to_dtype(DType::F16),
750            Self::TensorF16(t) => Ok(t.clone()),
751        }
752    }
753
754    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
755        let w = self.dequantize_f16()?;
756        let in_dtype = xs.dtype();
757        let w = match *xs.dims() {
758            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
759            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
760            _ => w.t()?,
761        };
762        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
763    }
764
765    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
766        match self {
767            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
768            _ => {
769                panic!("Not implemented!")
770            }
771        }
772    }
773}
774
775impl crate::CustomOp1 for QTensor {
776    fn name(&self) -> &'static str {
777        "qmatmul"
778    }
779
780    fn cpu_fwd(
781        &self,
782        storage: &crate::CpuStorage,
783        layout: &crate::Layout,
784    ) -> Result<(crate::CpuStorage, Shape)> {
785        if !layout.is_contiguous() {
786            crate::bail!("input tensor is not contiguous {layout:?}")
787        }
788        let src_shape = layout.shape();
789        // self is transposed so n is first then k.
790        let (n, k) = self.shape.dims2()?;
791        if src_shape.rank() < 2 {
792            crate::bail!("input tensor has only one dimension {layout:?}")
793        }
794        let mut dst_shape = src_shape.dims().to_vec();
795        let last_k = dst_shape.pop().unwrap();
796        if last_k != k {
797            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
798        }
799        dst_shape.push(n);
800        let dst_shape = Shape::from(dst_shape);
801        #[allow(clippy::infallible_destructuring_match)]
802        let self_storage = match &self.storage {
803            QStorage::Cpu(storage) => storage,
804            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
805        };
806        match storage.dtype() {
807            DType::F32 => {
808                let slice = storage.as_slice::<f32>()?;
809                let slice =
810                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
811                let mut dst_storage = vec![0f32; dst_shape.elem_count()];
812                self_storage.matmul_t(
813                    (dst_shape.elem_count() / n, k, n),
814                    slice,
815                    &mut dst_storage,
816                )?;
817                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
818            }
819            DType::F16 => {
820                let slice = storage.as_slice::<f16>()?;
821                let slice =
822                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
823                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
824                self_storage.matmul_t_f16(
825                    (dst_shape.elem_count() / n, k, n),
826                    slice,
827                    &mut dst_storage,
828                )?;
829                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
830            }
831            _ => crate::bail!("Expected f32/f16"),
832        }
833    }
834
835    fn metal_fwd(
836        &self,
837        storage: &crate::MetalStorage,
838        layout: &crate::Layout,
839    ) -> Result<(crate::MetalStorage, Shape)> {
840        let self_storage = match &self.storage {
841            QStorage::Metal(metal) => metal,
842            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
843        };
844        self_storage.fwd(&self.shape, storage, layout)
845    }
846
847    fn cuda_fwd(
848        &self,
849        storage: &crate::CudaStorage,
850        layout: &crate::Layout,
851    ) -> Result<(crate::CudaStorage, Shape)> {
852        let self_storage = match &self.storage {
853            QStorage::Cuda(cuda) => cuda,
854            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
855        };
856        self_storage.fwd(&self.shape, storage, layout)
857    }
858}
859
860impl crate::Module for QMatMul {
861    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
862        match self {
863            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
864            Self::Tensor(w) => {
865                let w = match *xs.dims() {
866                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
867                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
868                    _ => w.t()?,
869                };
870                xs.matmul(&w)
871            }
872            Self::TensorF16(w) => {
873                let in_dtype = xs.dtype();
874                let w = match *xs.dims() {
875                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
876                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
877                    _ => w.t()?,
878                };
879                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
880            }
881        }
882    }
883}