1use crate::{Context, CpuStorage, DType, Device, Result, Shape, Storage, Tensor};
3use k_quants::*;
4use std::borrow::Cow;
5
6#[cfg(target_feature = "avx")]
7pub mod avx;
8mod dummy_cuda;
9mod dummy_metal;
10pub mod ggml_file;
11pub mod gguf_file;
12pub mod k_quants;
13#[cfg(feature = "metal")]
14pub mod metal;
// Fallback used when the `metal` feature is disabled: re-export everything
// from `dummy_metal` under the `metal` path so that references such as
// `metal::QMetalStorage` elsewhere in this file still resolve.
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
19#[cfg(feature = "cuda")]
20pub mod cuda;
// Fallback used when the `cuda` feature is disabled: re-export everything
// from `dummy_cuda` under the `cuda` path so that references such as
// `cuda::QCudaStorage` elsewhere in this file still resolve.
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}
25
26#[cfg(target_feature = "neon")]
27pub mod neon;
28#[cfg(target_feature = "simd128")]
29pub mod simd128;
30pub mod utils;
31use half::f16;
32
33pub use k_quants::GgmlType;
34
/// A tensor whose data is kept in a quantized representation.
///
/// Pairs a backend-specific [`QStorage`] buffer with the logical [`Shape`] of
/// the (unquantized) tensor it encodes.
pub struct QTensor {
    /// Quantized payload, living on the CPU, Metal or CUDA backend.
    storage: QStorage,
    /// Logical shape; its last dimension is checked against the storage block
    /// size when the tensor is built through `new` or `quantize`.
    shape: Shape,
}
39
40impl Device {
41 fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
42 match self {
43 Device::Cpu => {
44 let storage = dtype.cpu_zeros(elem_count);
45 Ok(QStorage::Cpu(storage))
46 }
47 Device::Metal(metal) => {
48 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
49 Ok(QStorage::Metal(storage))
50 }
51 Device::Cuda(cuda) => {
52 let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
53 Ok(QStorage::Cuda(storage))
54 }
55 }
56 }
57}
58
/// Backend-specific container for quantized data, one variant per device kind.
pub enum QStorage {
    /// Host memory, held behind a type-erased [`QuantizedType`] implementation.
    Cpu(Box<dyn QuantizedType>),
    /// Metal storage (the `dummy_metal` stand-in when the `metal` feature is off).
    Metal(metal::QMetalStorage),
    /// CUDA storage (the `dummy_cuda` stand-in when the `cuda` feature is off).
    Cuda(cuda::QCudaStorage),
}
64
65impl QStorage {
66 fn block_size(&self) -> usize {
67 match self {
68 QStorage::Cpu(storage) => storage.block_size(),
69 QStorage::Metal(storage) => storage.dtype().block_size(),
70 QStorage::Cuda(storage) => storage.dtype().block_size(),
71 }
72 }
73
74 fn dtype(&self) -> GgmlDType {
75 match self {
76 QStorage::Cpu(storage) => storage.dtype(),
77 QStorage::Metal(storage) => storage.dtype(),
78 QStorage::Cuda(storage) => storage.dtype(),
79 }
80 }
81
82 fn device(&self) -> Device {
83 match self {
84 QStorage::Cpu(_storage) => Device::Cpu,
85 QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
86 QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
87 }
88 }
89
90 fn size_in_bytes(&self) -> usize {
91 match self {
92 QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
93 QStorage::Metal(storage) => storage.storage_size_in_bytes(),
94 QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
95 }
96 }
97
98 fn quantize(&mut self, src: &Storage) -> Result<()> {
99 match (self, src) {
100 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
101 storage.from_float(src.as_slice::<f32>()?)?;
102 }
103 (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
104 (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
105 _ => crate::bail!("Invalid dequantize storage locations do not match"),
106 }
107 Ok(())
108 }
109
110 fn dequantize(&self, elem_count: usize) -> Result<Storage> {
111 match self {
112 QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
113 QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
114 QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
115 }
116 }
117
118 fn data(&self) -> Result<Cow<[u8]>> {
119 match self {
120 QStorage::Cpu(storage) => {
121 let data_ptr = storage.as_ptr();
122 let size_in_bytes = storage.storage_size_in_bytes();
123 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
124 Ok(Cow::from(data))
125 }
126 QStorage::Metal(_) | QStorage::Cuda(_) => {
127 crate::bail!("not implemented");
128 }
129 }
130 }
131}
132
/// Element formats a quantized tensor can be stored in.
///
/// `F32` and `F16` are plain floating point formats (block size 1); the other
/// variants are block formats whose block sizes are given by `block_size`.
/// The numeric identifier of each variant is defined by `from_u32` / `to_u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // The `*K` variants all share the `k_quants::QK_K` block size.
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
150
151impl GgmlDType {
152 pub(crate) fn from_u32(u: u32) -> Result<Self> {
153 let dtype = match u {
154 0 => Self::F32,
155 1 => Self::F16,
156 2 => Self::Q4_0,
157 3 => Self::Q4_1,
158 6 => Self::Q5_0,
159 7 => Self::Q5_1,
160 8 => Self::Q8_0,
161 9 => Self::Q8_1,
162 10 => Self::Q2K,
163 11 => Self::Q3K,
164 12 => Self::Q4K,
165 13 => Self::Q5K,
166 14 => Self::Q6K,
167 15 => Self::Q8K,
168 _ => crate::bail!("unknown dtype for tensor {u}"),
169 };
170 Ok(dtype)
171 }
172
173 pub(crate) fn to_u32(self) -> u32 {
174 match self {
175 Self::F32 => 0,
176 Self::F16 => 1,
177 Self::Q4_0 => 2,
178 Self::Q4_1 => 3,
179 Self::Q5_0 => 6,
180 Self::Q5_1 => 7,
181 Self::Q8_0 => 8,
182 Self::Q8_1 => 9,
183 Self::Q2K => 10,
184 Self::Q3K => 11,
185 Self::Q4K => 12,
186 Self::Q5K => 13,
187 Self::Q6K => 14,
188 Self::Q8K => 15,
189 }
190 }
191
192 pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
194 match self {
195 Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
196 Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
197 Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
198 Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
199 Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
200 Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
201 Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
202 Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
203 Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
204 Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
205 Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
206 Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
207 Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
208 Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
209 }
210 }
211 pub fn type_size(&self) -> usize {
213 use k_quants::*;
214 match self {
215 Self::F32 => 4,
216 Self::F16 => 2,
217 Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
218 Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
219 Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
220 Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
221 Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
223 Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
224 Self::Q2K => std::mem::size_of::<BlockQ2K>(),
225 Self::Q3K => std::mem::size_of::<BlockQ3K>(),
226 Self::Q4K => std::mem::size_of::<BlockQ4K>(),
227 Self::Q5K => std::mem::size_of::<BlockQ5K>(),
228 Self::Q6K => std::mem::size_of::<BlockQ6K>(),
229 Self::Q8K => std::mem::size_of::<BlockQ8K>(),
230 }
231 }
232
233 pub fn block_size(&self) -> usize {
235 match self {
236 Self::F32 => 1,
237 Self::F16 => 1,
238 Self::Q4_0 => k_quants::QK4_0,
239 Self::Q4_1 => k_quants::QK4_1,
240 Self::Q5_0 => k_quants::QK5_0,
241 Self::Q5_1 => k_quants::QK5_1,
242 Self::Q8_0 => k_quants::QK8_0,
243 Self::Q8_1 => k_quants::QK8_1,
244 Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
245 }
246 }
247}
248
/// Type-erased interface over a CPU buffer of quantized values.
///
/// It is implemented in this file for `Vec<T>` where `T: k_quants::GgmlType`,
/// which lets `QStorage::Cpu` hold any block type behind a `Box<dyn ...>`.
pub trait QuantizedType: Send + Sync {
    /// Quantized format of the stored values.
    fn dtype(&self) -> GgmlDType;
    /// Multiplies `lhs` by the quantized data and writes the result to `dst`.
    ///
    /// `mkn` holds the three matmul dimensions; the CPU forward pass in this
    /// file passes `(rows of lhs, k, n)` for quantized weights of shape
    /// `(n, k)`.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// Converts the data back to regular CPU storage of `elem_count` elements.
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the underlying buffer in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the underlying buffer.
    fn as_ptr(&self) -> *const u8;
    /// Number of logical elements encoded by one stored unit.
    fn block_size(&self) -> usize;
    /// Quantizes `xs` into this buffer, overwriting its content.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
    /// Size of the buffer in bytes (the `Vec<T>` implementation returns the
    /// same value as `storage_size_in_bytes`).
    fn size(&self) -> usize;
}
261
262impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
263 fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
264 k_quants::matmul(mkn, lhs, self.as_slice(), dst)
265 }
266
267 fn size(&self) -> usize {
268 self.len() * core::mem::size_of::<T>()
269 }
270
271 fn from_float(&mut self, xs: &[f32]) -> Result<()> {
272 T::from_float(xs, self)
273 }
274
275 fn dtype(&self) -> GgmlDType {
276 T::DTYPE
277 }
278
279 fn block_size(&self) -> usize {
280 T::BLCK_SIZE
281 }
282
283 fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
284 let mut ys = vec![0.0f32; elem_count];
285 T::to_float(self.as_slice(), &mut ys)?;
286 Ok(CpuStorage::F32(ys))
287 }
288
289 fn storage_size_in_bytes(&self) -> usize {
290 self.len() * std::mem::size_of::<T>()
291 }
292
293 fn as_ptr(&self) -> *const u8 {
294 self.as_ptr() as *const u8
295 }
296}
297
298impl std::fmt::Debug for QTensor {
299 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
300 write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
301 }
302}
303
304fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
305 let dims = shape.dims();
306 if dims.is_empty() {
307 crate::bail!("scalar tensor cannot be quantized {shape:?}")
308 }
309 if dims[dims.len() - 1] % block_size != 0 {
310 crate::bail!(
311 "quantized tensor must have their last dim divisible by block size {shape:?} {}",
312 block_size
313 )
314 }
315 Ok(())
316}
317
318impl QTensor {
319 pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
320 let shape = shape.into();
321 check_shape(&shape, storage.block_size())?;
322 Ok(Self { storage, shape })
323 }
324
325 pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
326 let shape = src.shape();
327 let block_size = dtype.block_size();
328 check_shape(shape, block_size)?;
329 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
330 let elem_count = shape.elem_count();
331 if elem_count % block_size != 0 {
332 crate::bail!(
333 "tensor size ({shape:?}) is not divisible by block size {}",
334 block_size
335 )
336 }
337 let mut storage = src.device().qzeros(elem_count, dtype)?;
338 storage.quantize(&src.storage())?;
339 Ok(Self {
340 storage,
341 shape: shape.clone(),
342 })
343 }
344
345 pub fn dtype(&self) -> GgmlDType {
346 self.storage.dtype()
347 }
348
349 pub fn device(&self) -> Device {
350 self.storage.device()
351 }
352
353 pub fn rank(&self) -> usize {
354 self.shape.rank()
355 }
356
357 pub fn shape(&self) -> &Shape {
358 &self.shape
359 }
360
361 pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
362 let storage = self.storage.dequantize(self.shape.elem_count())?;
363 let none = crate::op::BackpropOp::none();
364 crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
365 }
366
367 pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
368 match &self.storage {
371 QStorage::Cuda(s) => {
372 let s = s.dequantize_f16(self.shape.elem_count())?;
373 let none = crate::op::BackpropOp::none();
374 crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
375 .to_device(device)
376 }
377 _ => {
378 let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
379 Ok(s)
380 }
381 }
382 }
383
384 pub fn storage_size_in_bytes(&self) -> usize {
385 self.storage.size_in_bytes()
386 }
387
388 pub fn data(&self) -> Result<Cow<'_, [u8]>> {
389 self.storage.data()
390 }
391}
392
/// Weights of a matrix multiplication, kept either quantized or already
/// dequantized.
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Quantized weights; the forward pass runs the `QTensor` custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Weights dequantized through `QTensor::dequantize`.
    Tensor(Tensor),
    /// Weights dequantized to `f16`; inputs are cast to `f16` for the matmul
    /// and the result is cast back to the input dtype.
    TensorF16(Tensor),
}
399
thread_local! {
    /// `true` when the `CANDLE_DEQUANTIZE_ALL` environment variable is set to
    /// a non-empty value other than `"0"`. Evaluated lazily, once per thread.
    static DEQUANTIZE_ALL: bool = std::env::var("CANDLE_DEQUANTIZE_ALL")
        .map_or(false, |value| !value.is_empty() && value != "0");
}
410
thread_local! {
    /// `true` when the `CANDLE_DEQUANTIZE_ALL_F16` environment variable is set
    /// to a non-empty value other than `"0"`. Evaluated lazily, once per thread.
    static DEQUANTIZE_ALL_F16: bool = std::env::var("CANDLE_DEQUANTIZE_ALL_F16")
        .map_or(false, |value| !value.is_empty() && value != "0");
}
421
422impl QMatMul {
423 pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
424 let dequantize = match qtensor.dtype() {
425 GgmlDType::F32 | GgmlDType::F16 => true,
426 _ => DEQUANTIZE_ALL.with(|b| *b),
427 };
428 let t = if dequantize {
429 let tensor = qtensor.dequantize(&qtensor.device())?;
430 Self::Tensor(tensor)
431 } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
432 let tensor = qtensor.dequantize_f16(&qtensor.device())?;
433 Self::TensorF16(tensor)
434 } else {
435 Self::QTensor(qtensor)
436 };
437 Ok(t)
438 }
439
440 pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
441 Self::from_arc(std::sync::Arc::new(qtensor))
442 }
443
444 pub fn dequantize_f16(&self) -> Result<Tensor> {
445 match self {
446 Self::QTensor(t) => t.dequantize_f16(&t.device()),
447 Self::Tensor(t) => t.to_dtype(DType::F16),
448 Self::TensorF16(t) => Ok(t.clone()),
449 }
450 }
451
452 pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
453 let w = self.dequantize_f16()?;
454 let in_dtype = xs.dtype();
455 let w = match *xs.dims() {
456 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
457 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
458 _ => w.t()?,
459 };
460 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
461 }
462}
463
464impl crate::CustomOp1 for QTensor {
465 fn name(&self) -> &'static str {
466 "qmatmul"
467 }
468
469 fn cpu_fwd(
470 &self,
471 storage: &crate::CpuStorage,
472 layout: &crate::Layout,
473 ) -> Result<(crate::CpuStorage, Shape)> {
474 if !layout.is_contiguous() {
475 crate::bail!("input tensor is not contiguous {layout:?}")
476 }
477 let src_shape = layout.shape();
478 let (n, k) = self.shape.dims2()?;
480 if src_shape.rank() < 2 {
481 crate::bail!("input tensor has only one dimension {layout:?}")
482 }
483 let mut dst_shape = src_shape.dims().to_vec();
484 let last_k = dst_shape.pop().context("empty dst_shape")?;
485 if last_k != k {
486 crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
487 }
488 dst_shape.push(n);
489 let dst_shape = Shape::from(dst_shape);
490 #[allow(clippy::infallible_destructuring_match)]
491 let self_storage = match &self.storage {
492 QStorage::Cpu(storage) => storage,
493 QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
494 };
495 let slice = storage.as_slice::<f32>()?;
496 let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
497 let mut dst_storage = vec![0f32; dst_shape.elem_count()];
498 self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
499 Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
500 }
501
502 fn metal_fwd(
503 &self,
504 storage: &crate::MetalStorage,
505 layout: &crate::Layout,
506 ) -> Result<(crate::MetalStorage, Shape)> {
507 let self_storage = match &self.storage {
508 QStorage::Metal(metal) => metal,
509 _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
510 };
511 self_storage.fwd(&self.shape, storage, layout)
512 }
513
514 fn cuda_fwd(
515 &self,
516 storage: &crate::CudaStorage,
517 layout: &crate::Layout,
518 ) -> Result<(crate::CudaStorage, Shape)> {
519 let self_storage = match &self.storage {
520 QStorage::Cuda(cuda) => cuda,
521 _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
522 };
523 self_storage.fwd(&self.shape, storage, layout)
524 }
525}
526
527impl crate::Module for QMatMul {
528 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
529 match self {
530 Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
531 Self::Tensor(w) => {
532 let w = match *xs.dims() {
533 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
534 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
535 _ => w.t()?,
536 };
537 xs.matmul(&w)
538 }
539 Self::TensorF16(w) => {
540 let in_dtype = xs.dtype();
541 let w = match *xs.dims() {
542 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
543 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
544 _ => w.t()?,
545 };
546 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
547 }
548 }
549 }
550}