// candle_core/quantized/mod.rs

1use crate::{
2    backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(feature = "metal"))]
18mod metal {
19    pub use super::dummy_metal::*;
20}
21#[cfg(feature = "cuda")]
22pub mod cuda;
23#[cfg(not(feature = "cuda"))]
24mod cuda {
25    pub use super::dummy_cuda::*;
26}
27
28#[cfg(target_feature = "neon")]
29pub mod neon;
30#[cfg(target_feature = "simd128")]
31pub mod simd128;
32pub mod utils;
33use half::{bf16, f16};
34
35pub use k_quants::GgmlType;
36
/// Reinterprets a raw byte buffer as a slice of `T`, checking that the length
/// is an exact multiple of `size_of::<T>()` and that the buffer start is
/// aligned for `T`. Panics when either check fails.
///
/// NOTE(review): `data` is taken by value and dropped when this function
/// returns, while the returned slice borrows the underlying bytes. For
/// `Cow::Borrowed` the bytes outlive the call, but for `Cow::Owned` the
/// backing `Vec` is freed here and the returned slice dangles (UB). Confirm
/// every caller passes borrowed data, or change the parameter to `&[u8]`.
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
    let size = std::mem::size_of::<T>();
    assert_eq!(
        data.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = data.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: length and alignment were asserted above; whether the bit
    // pattern is valid for `T` is the caller's responsibility (see the
    // lifetime caveat in the doc comment).
    unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
}
52
/// A quantized tensor: backend-specific quantized storage together with the
/// logical shape of the dequantized data.
pub struct QTensor {
    // Quantized blocks on cpu/metal/cuda.
    storage: QStorage,
    // Logical (dequantized) shape; its last dim is a multiple of the block
    // size (enforced by `check_shape`).
    shape: Shape,
}
57
58impl Device {
59    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
60        match self {
61            Device::Cpu => {
62                let storage = dtype.cpu_zeros(elem_count);
63                Ok(QStorage::Cpu(storage))
64            }
65            Device::Metal(metal) => {
66                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
67                Ok(QStorage::Metal(storage))
68            }
69            Device::Cuda(cuda) => {
70                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
71                Ok(QStorage::Cuda(storage))
72            }
73        }
74    }
75}
76
/// Backend-specific storage for quantized data.
pub enum QStorage {
    /// Host memory: a type-erased vector of quantized blocks.
    Cpu(Box<dyn QuantizedType>),
    /// Metal buffer storage.
    Metal(metal::QMetalStorage),
    /// CUDA buffer storage.
    Cuda(cuda::QCudaStorage),
}
82
impl QStorage {
    /// Builds quantized storage on `device`, reinterpreting the raw bytes in
    /// `data` as blocks of the type selected by `dtype`.
    ///
    /// NOTE(review): the metal/cuda arms move `data` into `as_t_slice`, which
    /// drops the `Cow` before `load_quantized` reads the returned slice; this
    /// is only sound when `data` is `Cow::Borrowed` — confirm the callers.
    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
        match device {
            Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
            Device::Metal(d) => match dtype {
                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
            },
            Device::Cuda(d) => match dtype {
                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
            },
        }
    }

    /// Number of elements per quantized block for the stored dtype.
    fn block_size(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.block_size(),
            QStorage::Metal(storage) => storage.dtype().block_size(),
            QStorage::Cuda(storage) => storage.dtype().block_size(),
        }
    }

    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType {
        match self {
            QStorage::Cpu(storage) => storage.dtype(),
            QStorage::Metal(storage) => storage.dtype(),
            QStorage::Cuda(storage) => storage.dtype(),
        }
    }

    /// The device owning this storage.
    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_storage) => Device::Cpu,
            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    /// Size of the quantized representation in bytes.
    fn size_in_bytes(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
        }
    }

    /// Quantizes `src` into this storage in place; `src` must be f32 data on
    /// the same backend as `self`, otherwise an error is returned.
    fn quantize(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Importance-matrix-weighted variant of [`Self::quantize`]; `n_per_row`
    /// is the length of the tensor's last dimension.
    fn quantize_imatrix(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Quantizes cpu-resident `src` into this (possibly device-resident)
    /// storage; errors when `src` is not on the cpu.
    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
        }
        Ok(())
    }

    /// Importance-matrix-weighted variant of [`Self::quantize_onto`].
    fn quantize_imatrix_onto(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Dequantizes `elem_count` elements into dense storage on the same
    /// backend.
    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
        match self {
            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
        }
    }

    /// Raw bytes of the quantized representation; borrowed directly from the
    /// cpu storage, otherwise wrapping whatever the backend's `data()`
    /// returns.
    fn data(&self) -> Result<Cow<'_, [u8]>> {
        match self {
            QStorage::Cpu(storage) => {
                let data_ptr = storage.as_ptr();
                let size_in_bytes = storage.storage_size_in_bytes();
                // SAFETY: pointer and length both come from the same live cpu
                // storage, which outlives the returned borrow of `self`.
                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                Ok(Cow::from(data))
            }
            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
        }
    }

    /// Device pointer to the underlying buffer; only implemented for cuda.
    pub fn device_ptr(&self) -> Result<*const u8> {
        match self {
            QStorage::Cuda(storage) => storage.device_ptr(),
            QStorage::Metal(_) | QStorage::Cpu(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}
252
/// Element dtypes following the ggml numbering: plain float formats plus the
/// quantized block formats (`Q*_*` legacy blocks and `Q*K` k-quants).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    BF16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
271
impl GgmlDType {
    /// Decodes a ggml on-disk dtype tag; note that tags 4 and 5 are not
    /// supported, and unknown tags produce an error.
    pub(crate) fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            15 => Self::Q8K,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            30 => Self::BF16,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    /// Encodes this dtype back to the ggml tag; inverse of [`Self::from_u32`].
    pub(crate) fn to_u32(self) -> u32 {
        match self {
            Self::F32 => 0,
            Self::F16 => 1,
            Self::Q4_0 => 2,
            Self::Q4_1 => 3,
            Self::Q5_0 => 6,
            Self::Q5_1 => 7,
            Self::Q8_0 => 8,
            Self::Q8_1 => 9,
            Self::Q2K => 10,
            Self::Q3K => 11,
            Self::Q4K => 12,
            Self::Q5K => 13,
            Self::Q6K => 14,
            Self::Q8K => 15,
            // https://github.com/ggerganov/ggml/blob/29d87fc6676e7ed0cdfdec0804b06001d9c2bb44/include/ggml.h#L389
            Self::BF16 => 30,
        }
    }

    /// Allocates zero-filled boxed cpu storage able to hold `elem_count`
    /// elements, i.e. `elem_count / block_size` blocks for quantized dtypes.
    ///
    /// NOTE(review): the division truncates, so `elem_count` must be a
    /// multiple of the block size; the `quantize*` entry points in this file
    /// validate that before calling in here.
    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
        match self {
            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
            Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
        }
    }

    /// Copies `data`, reinterpreted as blocks of this dtype, into owned boxed
    /// cpu storage.
    ///
    /// NOTE(review): `as_t_slice` drops the `Cow` before `to_vec` reads the
    /// slice — only sound for `Cow::Borrowed` input (see `as_t_slice`).
    pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
        match self {
            Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
            Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
            Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
            Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
            Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
            Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
            Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
            Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
            Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
            Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
            Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
            Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
            Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
            Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
            Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
        }
    }

    /// The type size for blocks in bytes.
    pub fn type_size(&self) -> usize {
        use k_quants::*;
        match self {
            Self::F32 => 4,
            Self::F16 | Self::BF16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
        }
    }

    /// The block size, i.e. the number of elements stored in each block.
    pub fn block_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 | Self::BF16 => 1,
            Self::Q4_0 => k_quants::QK4_0,
            Self::Q4_1 => k_quants::QK4_1,
            Self::Q5_0 => k_quants::QK5_0,
            Self::Q5_1 => k_quants::QK5_1,
            Self::Q8_0 => k_quants::QK8_0,
            Self::Q8_1 => k_quants::QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
        }
    }
}
395
/// A version of `GgmlType` without `vec_dot` so that it can be dyn boxed:
/// the type-erased interface backing `QStorage::Cpu`.
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType;
    /// Matmul against a transposed quantized rhs: `dst = lhs . self^T` with
    /// `mkn = (m, k, n)`.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// f16 variant of [`Self::matmul_t`].
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Dequantizes `elem_count` elements into f32 cpu storage.
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the quantized data in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the quantized data.
    fn as_ptr(&self) -> *const u8;
    /// Number of elements per quantized block.
    fn block_size(&self) -> usize;
    /// Quantizes the f32 slice `xs` into `self`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Importance-matrix-weighted quantization of `xs` into `self`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size in bytes (computed the same way as `storage_size_in_bytes` in the
    /// blanket `Vec<T>` impl below).
    fn size(&self) -> usize;
}
411
412impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
413    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
414        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
415    }
416    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
417        k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
418    }
419
420    fn size(&self) -> usize {
421        self.len() * core::mem::size_of::<T>()
422    }
423
424    fn from_float(&mut self, xs: &[f32]) {
425        T::from_float(xs, self)
426    }
427
428    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
429        T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
430    }
431
432    fn dtype(&self) -> GgmlDType {
433        T::DTYPE
434    }
435
436    fn block_size(&self) -> usize {
437        T::BLCK_SIZE
438    }
439
440    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
441        let mut ys = vec![0.0f32; elem_count];
442        T::to_float(self.as_slice(), &mut ys);
443        Ok(CpuStorage::F32(ys))
444    }
445
446    fn storage_size_in_bytes(&self) -> usize {
447        self.len() * std::mem::size_of::<T>()
448    }
449
450    fn as_ptr(&self) -> *const u8 {
451        self.as_ptr() as *const u8
452    }
453}
454
impl std::fmt::Debug for QTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Compact rendering: `QTensor[<shape>; <dtype>]`.
        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
    }
}
460
461fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
462    let dims = shape.dims();
463    if dims.is_empty() {
464        crate::bail!("scalar tensor cannot be quantized {shape:?}")
465    }
466    if !dims[dims.len() - 1].is_multiple_of(block_size) {
467        crate::bail!(
468            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
469            block_size
470        )
471    }
472    Ok(())
473}
474
impl QTensor {
    /// Wraps already-quantized storage with its logical shape, validating the
    /// shape against the storage's block size.
    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
        let shape = shape.into();
        check_shape(&shape, storage.block_size())?;
        Ok(Self { storage, shape })
    }

    /// Quantizes `src` to `dtype`, keeping the result on `src`'s device.
    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        // The quantization kernels consume flat f32 data.
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Quantizes `src` to `dtype` using importance-matrix weights (one per
    /// element of the last dim), keeping the result on `src`'s device.
    pub fn quantize_imatrix(
        src: &Tensor,
        imatrix_weights: &[f32],
        dtype: GgmlDType,
    ) -> Result<Self> {
        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
        // Size of imatrix == last dim of tensor
        let n_per_row = src.dim(D::Minus1)?;
        if imatrix_weights.len() != n_per_row {
            crate::bail!(
                "imatrix weights must have the same length {} as the last dim of src {}",
                imatrix_weights.len(),
                src.dim(D::Minus1)?
            );
        }

        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            );
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`, using
    /// importance-matrix weights; errors when `src` is not on the cpu.
    pub fn quantize_imatrix_onto(
        src: &Tensor,
        imatrix_weights: &[f32],
        dtype: GgmlDType,
        dev: &Device,
    ) -> Result<Self> {
        if !src.device().is_cpu() {
            crate::bail!(
                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
                src.device()
            )
        }
        // (n_per_row/QK_K-1)*QK_K+(QK_K/32-1)*32+32=n_per_row
        // Size of imatrix == last dim of tensor
        let n_per_row = src.dim(D::Minus1)?;
        if imatrix_weights.len() != n_per_row {
            crate::bail!(
                "imatrix weights must have the same length {} as the last dim of src {}",
                imatrix_weights.len(),
                src.dim(D::Minus1)?
            );
        }
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        // storage is on the `dev`, src is on `cpu`
        let mut storage = dev.qzeros(elem_count, dtype)?;
        storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Quantize `src` (currently on the CPU) to a QTensor on `dev`
    pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
        if !src.device().is_cpu() {
            crate::bail!(
                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
                src.device()
            )
        }
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        // storage is on the `dev`, src is on `cpu`
        let mut storage = dev.qzeros(elem_count, dtype)?;
        storage.quantize_onto(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// The quantized dtype.
    pub fn dtype(&self) -> GgmlDType {
        self.storage.dtype()
    }

    /// The device holding the quantized data.
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// Number of dimensions of the logical shape.
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    /// The logical (dequantized) shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Dequantizes into a dense tensor and moves it to `device`.
    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let storage = self.storage.dequantize(self.shape.elem_count())?;
        let none = crate::op::BackpropOp::none();
        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
    }

    /// Dequantizes to an f16 tensor on `device`.
    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
        // In the CUDA case, we have a specialized kernel as this can be useful for volta
        // architectures. https://github.com/huggingface/candle/issues/2136
        match &self.storage {
            QStorage::Cuda(s) => {
                let s = s.dequantize_f16(self.shape.elem_count())?;
                let none = crate::op::BackpropOp::none();
                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
                    .to_device(device)
            }
            _ => {
                // Fallback: regular dequantize followed by a cast to f16.
                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
                Ok(s)
            }
        }
    }

    /// Size of the quantized representation in bytes.
    pub fn storage_size_in_bytes(&self) -> usize {
        self.storage.size_in_bytes()
    }

    /// Raw bytes of the quantized representation.
    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
        self.storage.data()
    }

    /// Indexed mixture-of-experts forward pass, cuda only.
    ///
    /// NOTE(review): panics (rather than returning `Err`) when this tensor or
    /// either input is not cuda-resident.
    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
        match &self.storage {
            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
                    let (storage, out_shape) = s.indexed_moe_forward(
                        self.shape(),
                        x_storage,
                        x.layout(),
                        ids_storage,
                        ids.layout(),
                    )?;
                    Ok(crate::tensor::from_storage(
                        Storage::Cuda(storage),
                        out_shape,
                        crate::op::BackpropOp::none(),
                        false,
                    ))
                }
                _ => {
                    panic!("Non-cuda indexed_moe_forward is not implemented!");
                }
            },
            _ => {
                panic!("indexed_moe_forward is not implemented in this platform!");
            }
        }
    }

    /// Device pointer to the underlying buffer; cuda only.
    pub fn device_ptr(&self) -> Result<*const u8> {
        match &self.storage {
            QStorage::Cuda(storage) => storage.device_ptr(),
            QStorage::Metal(_) | QStorage::Cpu(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}
692
/// A matmul weight that is either kept in quantized form or eagerly
/// dequantized (see `QMatMul::from_arc`).
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Weight kept quantized; the matmul goes through the custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Weight dequantized to a dense tensor.
    Tensor(Tensor),
    /// Weight dequantized to f16.
    TensorF16(Tensor),
}
699
thread_local! {
    /// True when the `CANDLE_DEQUANTIZE_ALL` environment variable is set to a
    /// non-empty value other than "0"; forces eager dequantization in
    /// `QMatMul::from_arc`. Evaluated lazily, once per thread.
    static DEQUANTIZE_ALL: bool =
        std::env::var("CANDLE_DEQUANTIZE_ALL").is_ok_and(|s| !s.is_empty() && s != "0");
}
710
thread_local! {
    /// True when the `CANDLE_DEQUANTIZE_ALL_F16` environment variable is set
    /// to a non-empty value other than "0"; forces eager f16 dequantization
    /// in `QMatMul::from_arc`. Evaluated lazily, once per thread.
    static DEQUANTIZE_ALL_F16: bool =
        std::env::var("CANDLE_DEQUANTIZE_ALL_F16").is_ok_and(|s| !s.is_empty() && s != "0");
}
721
722impl QMatMul {
723    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
724        let dequantize = match qtensor.dtype() {
725            GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
726            _ => DEQUANTIZE_ALL.with(|b| *b),
727        };
728        let t = if dequantize {
729            let tensor = qtensor.dequantize(&qtensor.device())?;
730            Self::Tensor(tensor)
731        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
732            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
733            Self::TensorF16(tensor)
734        } else {
735            Self::QTensor(qtensor)
736        };
737        Ok(t)
738    }
739
740    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
741        Self::from_arc(std::sync::Arc::new(qtensor))
742    }
743
744    pub fn dequantize_f16(&self) -> Result<Tensor> {
745        match self {
746            Self::QTensor(t) => t.dequantize_f16(&t.device()),
747            Self::Tensor(t) => t.to_dtype(DType::F16),
748            Self::TensorF16(t) => Ok(t.clone()),
749        }
750    }
751
752    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
753        let w = self.dequantize_f16()?;
754        let in_dtype = xs.dtype();
755        let w = match *xs.dims() {
756            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
757            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
758            _ => w.t()?,
759        };
760        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
761    }
762
763    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
764        match self {
765            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
766            _ => {
767                panic!("Not implemented!")
768            }
769        }
770    }
771}
772
/// Custom op computing `xs . self^T` with `self` (the weight) kept in
/// quantized form; `self.shape` is `(n, k)`, i.e. already transposed.
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    // Cpu path: validates layout/shapes then dispatches to the type-erased
    // `matmul_t` / `matmul_t_f16` on the boxed cpu storage.
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // The output shape is the input shape with its last dim (which must
        // equal k) replaced by n.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
        };
        match storage.dtype() {
            DType::F32 => {
                let slice = storage.as_slice::<f32>()?;
                // Restrict to the view described by the (contiguous) layout.
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![0f32; dst_shape.elem_count()];
                // mkn = (m, k, n) with m the number of output rows.
                self_storage.matmul_t(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
            }
            DType::F16 => {
                let slice = storage.as_slice::<f16>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
                self_storage.matmul_t_f16(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
            }
            _ => crate::bail!("Expected f32/f16"),
        }
    }

    // Metal path: delegates to the metal storage's forward kernel.
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    // Cuda path: delegates to the cuda storage's forward kernel.
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}
857
impl crate::Module for QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            // Quantized weights: run the custom qmatmul op (no backward op).
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            // Dense weights: broadcast over the batch dims of `xs`, transpose
            // and run a regular matmul.
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
            // f16 weights: cast the input to f16 for the matmul, then cast
            // the result back to the input's original dtype.
            Self::TensorF16(w) => {
                let in_dtype = xs.dtype();
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
            }
        }
    }
}