//! Quantized tensor support — candle_core_temp/quantized/mod.rs
1use crate::{Device, Result, Shape, Tensor};
2
3#[cfg(target_feature = "avx")]
4pub mod avx;
5pub mod ggml_file;
6pub mod gguf_file;
7pub mod k_quants;
8#[cfg(target_feature = "neon")]
9pub mod neon;
10#[cfg(target_feature = "simd128")]
11pub mod simd128;
12pub mod utils;
13
14pub use k_quants::GgmlType;
15
/// A tensor stored in a quantized ggml block format together with its
/// logical (unquantized) shape.
pub struct QTensor {
    // Type-erased quantized storage; concretely a `Vec` of some ggml block
    // type (see the `QuantizedType` impl for `Vec<T>` below).
    data: Box<dyn QuantizedType>,
    // Logical shape of the dequantized tensor; its element count drives
    // dequantization and the matmul dimension checks.
    shape: Shape,
}
20
/// Element types for quantized tensors, mirroring ggml's type numbering
/// (see `from_u32`/`to_u32` for the on-disk ids; ids 4 and 5 are unused).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // K-quant variants; these all share the same block size (see `blck_size`).
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
38
39impl GgmlDType {
40    pub(crate) fn from_u32(u: u32) -> Result<Self> {
41        let dtype = match u {
42            0 => Self::F32,
43            1 => Self::F16,
44            2 => Self::Q4_0,
45            3 => Self::Q4_1,
46            6 => Self::Q5_0,
47            7 => Self::Q5_1,
48            8 => Self::Q8_0,
49            9 => Self::Q8_1,
50            10 => Self::Q2K,
51            11 => Self::Q3K,
52            12 => Self::Q4K,
53            13 => Self::Q5K,
54            14 => Self::Q6K,
55            15 => Self::Q8K,
56            _ => crate::bail!("unknown dtype for tensor {u}"),
57        };
58        Ok(dtype)
59    }
60
61    pub(crate) fn to_u32(self) -> u32 {
62        match self {
63            Self::F32 => 0,
64            Self::F16 => 1,
65            Self::Q4_0 => 2,
66            Self::Q4_1 => 3,
67            Self::Q5_0 => 6,
68            Self::Q5_1 => 7,
69            Self::Q8_0 => 8,
70            Self::Q8_1 => 9,
71            Self::Q2K => 10,
72            Self::Q3K => 11,
73            Self::Q4K => 12,
74            Self::Q5K => 13,
75            Self::Q6K => 14,
76            Self::Q8K => 15,
77        }
78    }
79
80    /// The type size for blocks in bytes.
81    pub fn type_size(&self) -> usize {
82        use k_quants::*;
83        match self {
84            Self::F32 => 4,
85            Self::F16 => 2,
86            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
87            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
88            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
89            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
90            // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L932
91            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
92            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
93            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
94            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
95            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
96            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
97            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
98            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
99        }
100    }
101
102    /// The block size, i.e. the number of elements stored in each block.
103    pub fn blck_size(&self) -> usize {
104        match self {
105            Self::F32 => 1,
106            Self::F16 => 1,
107            Self::Q4_0 => k_quants::QK4_0,
108            Self::Q4_1 => k_quants::QK4_1,
109            Self::Q5_0 => k_quants::QK5_0,
110            Self::Q5_1 => k_quants::QK5_1,
111            Self::Q8_0 => k_quants::QK8_0,
112            Self::Q8_1 => k_quants::QK8_1,
113            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
114        }
115    }
116}
117
118// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
// A version of GgmlType without `vec_dot` so that it can be dyn boxed.
/// Object-safe view over a buffer of quantized blocks.
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType;
    /// Matmul against this buffer used as a transposed rhs; `mkn` carries the
    /// `(m, k, n)` dimensions (see `cpu_fwd` below for how it is invoked).
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// Dequantizes the full buffer into `ys` as f32 values.
    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
    /// Size in bytes of the underlying quantized storage.
    fn storage_size_in_bytes(&self) -> usize;
    /// Raw pointer to the start of the quantized data.
    fn as_ptr(&self) -> *const u8;
}
126
127impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
128    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
129        k_quants::matmul(mkn, lhs, self.as_slice(), dst)
130    }
131
132    fn dtype(&self) -> GgmlDType {
133        T::DTYPE
134    }
135
136    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
137        T::to_float(self.as_slice(), ys)
138    }
139
140    fn storage_size_in_bytes(&self) -> usize {
141        self.len() * std::mem::size_of::<T>()
142    }
143
144    fn as_ptr(&self) -> *const u8 {
145        self.as_ptr() as *const u8
146    }
147}
148
impl std::fmt::Debug for QTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Compact debug form: `QTensor[<shape>; <dtype>]`.
        write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
    }
}
154
155fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
156    let dims = shape.dims();
157    if dims.is_empty() {
158        crate::bail!("scalar tensor cannot be quantized {shape:?}")
159    }
160    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
161        crate::bail!(
162            "quantized tensor must have their last dim divisible by block size {shape:?} {}",
163            T::BLCK_SIZE
164        )
165    }
166    Ok(())
167}
168
169impl QTensor {
170    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
171        data: Vec<T>,
172        shape: S,
173    ) -> Result<Self> {
174        let shape = shape.into();
175        check_shape::<T>(&shape)?;
176        Ok(Self {
177            data: Box::new(data),
178            shape,
179        })
180    }
181
182    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
183        let shape = src.shape();
184        check_shape::<T>(shape)?;
185        let src = src
186            .to_dtype(crate::DType::F32)?
187            .flatten_all()?
188            .to_vec1::<f32>()?;
189        if src.len() % T::BLCK_SIZE != 0 {
190            crate::bail!(
191                "tensor size ({shape:?}) is not divisible by block size {}",
192                T::BLCK_SIZE
193            )
194        }
195        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
196        T::from_float(&src, &mut data)?;
197        Ok(Self {
198            data: Box::new(data),
199            shape: shape.clone(),
200        })
201    }
202
203    pub fn dtype(&self) -> GgmlDType {
204        self.data.dtype()
205    }
206
207    pub fn rank(&self) -> usize {
208        self.shape.rank()
209    }
210
211    pub fn shape(&self) -> &Shape {
212        &self.shape
213    }
214
215    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
216        let mut f32_data = vec![0f32; self.shape.elem_count()];
217        self.data.to_float(&mut f32_data)?;
218        Tensor::from_vec(f32_data, &self.shape, device)
219    }
220
221    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
222        self.data.matmul_t(mkn, lhs, dst)
223    }
224
225    pub fn storage_size_in_bytes(&self) -> usize {
226        self.data.storage_size_in_bytes()
227    }
228
229    pub fn as_ptr(&self) -> *const u8 {
230        self.data.as_ptr()
231    }
232}
233
/// A matmul weight that is either kept quantized or dequantized up-front.
#[derive(Clone, Debug)]
pub enum QMatMul {
    // Quantized weight; `Arc` keeps clones cheap.
    QTensor(std::sync::Arc<QTensor>),
    // Dequantized f32 weight, used for F32/F16 dtypes or when
    // CANDLE_DEQUANTIZE_ALL is set (see `from_arc`).
    Tensor(Tensor),
}
239
thread_local! {
    // Set the CANDLE_DEQUANTIZE_ALL env var to a non-empty value other than
    // "0" to make QMatMul::from_arc dequantize every weight eagerly.
    static DEQUANTIZE_ALL: bool = std::env::var("CANDLE_DEQUANTIZE_ALL")
        .map(|s| !s.is_empty() && s != "0")
        .unwrap_or(false);
}
250
251impl QMatMul {
252    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
253        let dequantize = match qtensor.dtype() {
254            GgmlDType::F32 | GgmlDType::F16 => true,
255            _ => DEQUANTIZE_ALL.with(|b| *b),
256        };
257        let t = if dequantize {
258            let tensor = qtensor.dequantize(&Device::Cpu)?;
259            Self::Tensor(tensor)
260        } else {
261            Self::QTensor(qtensor)
262        };
263        Ok(t)
264    }
265
266    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
267        Self::from_arc(std::sync::Arc::new(qtensor))
268    }
269}
270
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    /// CPU forward pass: multiplies the contiguous f32 input by this 2D
    /// quantized weight, which is stored transposed with shape `(n, k)`.
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // self is transposed so n is first then k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape: same leading dims as the input, last dim k replaced by n.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        // Use only the slice of storage covered by this layout as the lhs.
        let storage = storage.as_slice::<f32>()?;
        let storage =
            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
        let mut dst_storage = vec![0f32; dst_shape.elem_count()];
        // m = product of all leading dims of the input (rows of the flattened
        // lhs), k = shared dim, n = output dim.
        self.matmul_t(
            (dst_shape.elem_count() / n, k, n),
            storage,
            &mut dst_storage,
        )?;
        Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
    }
}
309
310impl QMatMul {
311    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
312        match self {
313            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
314            Self::Tensor(w) => {
315                let w = match *xs.dims() {
316                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
317                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
318                    _ => w.t()?,
319                };
320                xs.matmul(&w)
321            }
322        }
323    }
324}