candle_core/quantized/mod.rs

//! Code for GGML and GGUF files
use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
use k_quants::*;
use std::borrow::Cow;

#[cfg(target_feature = "avx")]
pub mod avx;
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}

#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::f16;

pub use k_quants::GgmlType;

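/// A quantized tensor: raw storage in one of the GGML block formats together
/// with the logical shape of the tensor.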
pub struct QTensor {
    storage: QStorage,
    shape: Shape,
}

impl Device {
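    /// Allocate zero-initialized quantized storage for `elem_count` elements
    /// of the given dtype on this device.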
    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
        match self {
            Device::Cpu => {
                let storage = dtype.cpu_zeros(elem_count);
                Ok(QStorage::Cpu(storage))
            }
            Device::Metal(metal) => {
                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
                Ok(QStorage::Metal(storage))
            }
            Device::Cuda(cuda) => {
                let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
                Ok(QStorage::Cuda(storage))
            }
        }
    }
}

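/// Device-specific storage for the raw quantized blocks.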
pub enum QStorage {
    Cpu(Box<dyn QuantizedType>),
    Metal(metal::QMetalStorage),
    Cuda(cuda::QCudaStorage),
}

impl QStorage {
    fn block_size(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.block_size(),
            QStorage::Metal(storage) => storage.dtype().block_size(),
            QStorage::Cuda(storage) => storage.dtype().block_size(),
        }
    }

    fn dtype(&self) -> GgmlDType {
        match self {
            QStorage::Cpu(storage) => storage.dtype(),
            QStorage::Metal(storage) => storage.dtype(),
            QStorage::Cuda(storage) => storage.dtype(),
        }
    }

    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_storage) => Device::Cpu,
            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    fn size_in_bytes(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
        }
    }

    fn quantize(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?)?;
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("invalid quantize: storage locations do not match"),
        }
        Ok(())
    }

    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
        match self {
            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
        }
    }

    fn data(&self) -> Result<Cow<[u8]>> {
        match self {
            QStorage::Cpu(storage) => {
                let data_ptr = storage.as_ptr();
                let size_in_bytes = storage.storage_size_in_bytes();
                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                Ok(Cow::from(data))
            }
            QStorage::Metal(_) | QStorage::Cuda(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}

impl GgmlDType {
    pub(crate) fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
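            // 4 and 5 are unused: they were ggml's Q4_2/Q4_3 types, since
            // removed upstream, hence the gap in the numbering.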
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            15 => Self::Q8K,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    pub(crate) fn to_u32(self) -> u32 {
        match self {
            Self::F32 => 0,
            Self::F16 => 1,
            Self::Q4_0 => 2,
            Self::Q4_1 => 3,
            Self::Q5_0 => 6,
            Self::Q5_1 => 7,
            Self::Q8_0 => 8,
            Self::Q8_1 => 9,
            Self::Q2K => 10,
            Self::Q3K => 11,
            Self::Q4K => 12,
            Self::Q5K => 13,
            Self::Q6K => 14,
            Self::Q8K => 15,
        }
    }

    /// Create a zero-initialized CPU storage of `elem_count` elements for this dtype.
    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
        match self {
            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
        }
    }

    /// The type size for blocks in bytes.
    pub fn type_size(&self) -> usize {
        use k_quants::*;
        match self {
            Self::F32 => 4,
            Self::F16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
        }
    }

    /// The block size, i.e. the number of elements stored in each block.
    pub fn block_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 => 1,
            Self::Q4_0 => k_quants::QK4_0,
            Self::Q4_1 => k_quants::QK4_1,
            Self::Q5_0 => k_quants::QK5_0,
            Self::Q5_1 => k_quants::QK5_1,
            Self::Q8_0 => k_quants::QK8_0,
            Self::Q8_1 => k_quants::QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
        }
    }
}
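
// `type_size` and `block_size` together determine the quantized footprint: a
// tensor of `n` elements stored as `dtype` takes
// `n / dtype.block_size() * dtype.type_size()` bytes. A minimal sketch
// (figures assume the usual GGML layout, where a Q4_0 block packs 32
// elements into 18 bytes):
//
//     let dtype = GgmlDType::Q4_0;
//     let bytes = 4096 / dtype.block_size() * dtype.type_size(); // 128 * 18 = 2304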

// An object-safe version of `GgmlType` (without `vec_dot`) so that it can be
// boxed as a `dyn QuantizedType`.
pub trait QuantizedType: Send + Sync {
    fn dtype(&self) -> GgmlDType;
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    fn storage_size_in_bytes(&self) -> usize;
    fn as_ptr(&self) -> *const u8;
    fn block_size(&self) -> usize;
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
    fn size(&self) -> usize;
}

impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
    }

    fn size(&self) -> usize {
        self.len() * core::mem::size_of::<T>()
    }

    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
        T::from_float(xs, self)
    }

    fn dtype(&self) -> GgmlDType {
        T::DTYPE
    }

    fn block_size(&self) -> usize {
        T::BLCK_SIZE
    }

    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
        let mut ys = vec![0.0f32; elem_count];
        T::to_float(self.as_slice(), &mut ys)?;
        Ok(CpuStorage::F32(ys))
    }

    fn storage_size_in_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }

    fn as_ptr(&self) -> *const u8 {
        self.as_ptr() as *const u8
    }
}

impl std::fmt::Debug for QTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
    }
}

fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
    let dims = shape.dims();
    if dims.is_empty() {
        crate::bail!("scalar tensor cannot be quantized {shape:?}")
    }
    if dims[dims.len() - 1] % block_size != 0 {
        crate::bail!(
            "quantized tensors must have their last dim divisible by the block size {shape:?} {}",
            block_size
        )
    }
    Ok(())
}

impl QTensor {
    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
        let shape = shape.into();
        check_shape(&shape, storage.block_size())?;
        Ok(Self { storage, shape })
    }

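    /// Quantize `src` into a fresh `QTensor` stored as `dtype`.
    ///
    /// A minimal round-trip sketch (assuming a CPU tensor whose last dim is a
    /// multiple of the block size, 32 for `Q4_0`):
    ///
    /// ```ignore
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_core::quantized::{GgmlDType, QTensor};
    ///
    /// let xs = Tensor::zeros((8, 256), DType::F32, &Device::Cpu)?;
    /// let q = QTensor::quantize(&xs, GgmlDType::Q4_0)?;
    /// let ys = q.dequantize(&Device::Cpu)?; // f32 again, same shape as `xs`
    /// ```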
    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        if elem_count % block_size != 0 {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    pub fn dtype(&self) -> GgmlDType {
        self.storage.dtype()
    }

    pub fn device(&self) -> Device {
        self.storage.device()
    }

    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let storage = self.storage.dequantize(self.shape.elem_count())?;
        let none = crate::op::BackpropOp::none();
        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
    }

    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
        // In the CUDA case, we have a specialized kernel as this can be useful
        // on Volta architectures. https://github.com/huggingface/candle/issues/2136
        match &self.storage {
            QStorage::Cuda(s) => {
                let s = s.dequantize_f16(self.shape.elem_count())?;
                let none = crate::op::BackpropOp::none();
                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
                    .to_device(device)
            }
            _ => {
                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
                Ok(s)
            }
        }
    }

    pub fn storage_size_in_bytes(&self) -> usize {
        self.storage.size_in_bytes()
    }

    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
        self.storage.data()
    }
}

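/// A matrix-multiplication layer wrapping a quantized weight of shape
/// `(out_features, in_features)`.
///
/// A minimal sketch of the intended use (assuming a CPU device):
///
/// ```ignore
/// use candle_core::{DType, Device, Module, Tensor};
/// use candle_core::quantized::{GgmlDType, QMatMul, QTensor};
///
/// let w = Tensor::zeros((64, 256), DType::F32, &Device::Cpu)?;
/// let q = QTensor::quantize(&w, GgmlDType::Q4_0)?;
/// let mm = QMatMul::from_qtensor(q)?;
/// let xs = Tensor::zeros((1, 256), DType::F32, &Device::Cpu)?;
/// let ys = mm.forward(&xs)?; // shape (1, 64)
/// ```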
#[derive(Clone, Debug)]
pub enum QMatMul {
    QTensor(std::sync::Arc<QTensor>),
    Tensor(Tensor),
    TensorF16(Tensor),
}

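// `CANDLE_DEQUANTIZE_ALL` and `CANDLE_DEQUANTIZE_ALL_F16` force `QMatMul` to
// eagerly dequantize its weights to f32 (resp. f16) at construction time; any
// non-empty value other than "0" enables them. Each flag is read once per
// thread.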
thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

thread_local! {
    static DEQUANTIZE_ALL_F16: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

impl QMatMul {
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
            GgmlDType::F32 | GgmlDType::F16 => true,
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&qtensor.device())?;
            Self::Tensor(tensor)
        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
            Self::TensorF16(tensor)
        } else {
            Self::QTensor(qtensor)
        };
        Ok(t)
    }

    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    }

    pub fn dequantize_f16(&self) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => t.dequantize_f16(&t.device()),
            Self::Tensor(t) => t.to_dtype(DType::F16),
            Self::TensorF16(t) => Ok(t.clone()),
        }
    }

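    /// Run the matmul through an f16 dequantized weight: the input is cast to
    /// f16, multiplied, and the result cast back to the input dtype.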
    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize_f16()?;
        let in_dtype = xs.dtype();
        let w = match *xs.dims() {
            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
            _ => w.t()?,
        };
        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
    }
}

impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor must have at least two dimensions {layout:?}")
        }
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().context("empty dst_shape")?;
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("invalid storage, expected cpu"),
        };
        let slice = storage.as_slice::<f32>()?;
        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
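        // mkn = (m, k, n): m is the number of rows in the flattened input,
        // k the shared inner dim, n the output features (the weight is
        // applied transposed).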
        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }

    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on a non-metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on a non-cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}

impl crate::Module for QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
            Self::TensorF16(w) => {
                let in_dtype = xs.dtype();
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
            }
        }
    }
}