1use crate::{
2 backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
#[cfg(target_feature = "avx2")]
pub mod avx;
// Stand-in backends used when the corresponding feature is disabled; they are
// re-exported under the real module names by the shims below so the rest of
// this file can reference `metal::...`/`cuda::...` unconditionally.
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod imatrix_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(target_arch = "wasm32"))]
pub mod tokenizer;
// Without the `metal` feature, expose the dummy implementation under the same
// path (private, so the stub API never leaks out of this crate).
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
// Same shim for CUDA.
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}

#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
35use half::{bf16, f16};
36
37pub use k_quants::GgmlType;
38
/// Reinterprets a raw byte buffer as a slice of `T` (typically one of the
/// `BlockQ*` quantized block types), after checking length and alignment.
///
/// Panics if `data.len()` is not a multiple of `size_of::<T>()` or if the
/// buffer start is not aligned for `T`.
///
/// NOTE(review): taking the `Cow` *by value* is unsound when the cow is
/// `Cow::Owned` — the returned slice then borrows the owned buffer, which is
/// dropped when `data` goes out of scope at the end of this function, leaving
/// a dangling reference. TODO confirm every caller passes `Cow::Borrowed`
/// data, or change the parameter to `&[u8]` (which requires touching the
/// `from_data` call sites).
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
    let size = std::mem::size_of::<T>();
    assert_eq!(
        data.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = data.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: length and alignment were asserted above; validity of the bit
    // pattern as `T` is the caller's responsibility (lifetime caveat above).
    unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
}
54
/// A tensor whose elements are stored in a ggml-style quantized format on
/// some device (CPU, Metal or CUDA).
pub struct QTensor {
    // Device-specific quantized buffer.
    storage: QStorage,
    // Logical (unquantized) shape; its last dim is a multiple of the block size.
    shape: Shape,
}
59
60impl Device {
61 fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
62 match self {
63 Device::Cpu => {
64 let storage = dtype.cpu_zeros(elem_count);
65 Ok(QStorage::Cpu(storage))
66 }
67 Device::Metal(metal) => {
68 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
69 Ok(QStorage::Metal(storage))
70 }
71 Device::Cuda(cuda) => {
72 let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
73 Ok(QStorage::Cuda(storage))
74 }
75 }
76 }
77}
78
/// Backing storage for a [`QTensor`], one variant per supported device.
pub enum QStorage {
    // Type-erased CPU buffer of quantized blocks (see `QuantizedType`).
    Cpu(Box<dyn QuantizedType>),
    Metal(metal::QMetalStorage),
    Cuda(cuda::QCudaStorage),
}
84
impl QStorage {
    /// Builds a quantized storage on `device` by reinterpreting the raw bytes
    /// of `data` as blocks of the type selected by `dtype`.
    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
        match device {
            Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
            Device::Metal(d) => match dtype {
                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
            },
            Device::Cuda(d) => match dtype {
                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
            },
        }
    }

    /// Number of unquantized elements per quantized block.
    fn block_size(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.block_size(),
            QStorage::Metal(storage) => storage.dtype().block_size(),
            QStorage::Cuda(storage) => storage.dtype().block_size(),
        }
    }

    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType {
        match self {
            QStorage::Cpu(storage) => storage.dtype(),
            QStorage::Metal(storage) => storage.dtype(),
            QStorage::Cuda(storage) => storage.dtype(),
        }
    }

    /// The device this storage lives on.
    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_storage) => Device::Cpu,
            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    /// Size of the quantized buffer in bytes.
    fn size_in_bytes(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
        }
    }

    /// Quantizes `src` (an f32 storage on the *same* device) into `self`.
    /// Errors if the device of `src` does not match the device of `self`.
    fn quantize(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Like [`Self::quantize`] but weighting the quantization error with an
    /// importance matrix (`imatrix_weights`, one weight per column over rows
    /// of length `n_per_row`).
    fn quantize_imatrix(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Quantizes a *CPU* `src` into `self`, which may live on another device
    /// (the device backends upload the result).
    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
        }
        Ok(())
    }

    /// Imatrix-weighted variant of [`Self::quantize_onto`]; `src` must be on
    /// the CPU.
    fn quantize_imatrix_onto(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Dequantizes `elem_count` elements back to f32 storage on the same device.
    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
        match self {
            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
        }
    }

    /// Returns the raw quantized bytes (borrowed for CPU, copied back from
    /// the device for Metal/CUDA).
    fn data(&self) -> Result<Cow<'_, [u8]>> {
        match self {
            QStorage::Cpu(storage) => {
                let data_ptr = storage.as_ptr();
                let size_in_bytes = storage.storage_size_in_bytes();
                // SAFETY: pointer and byte length both come from the same
                // live CPU buffer; the borrow is tied to `&self`.
                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                Ok(Cow::from(data))
            }
            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
        }
    }

    /// Raw device pointer to the quantized buffer; CUDA only.
    pub fn device_ptr(&self) -> Result<*const u8> {
        match self {
            QStorage::Cuda(storage) => storage.device_ptr(),
            QStorage::Metal(_) | QStorage::Cpu(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}
254
/// The ggml/gguf data types supported for quantized tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    // Unquantized float formats (block size 1).
    F32,
    F16,
    BF16,
    // Legacy 32-element block quants.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // K-quants (block size `QK_K`).
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
273
impl GgmlDType {
    /// Decodes a dtype from its numeric id as stored in ggml/gguf files
    /// (see the inverse mapping in [`Self::to_u32`]).
    pub(crate) fn from_u32(u: u32) -> Result<Self> {
        let dtype = match u {
            0 => Self::F32,
            1 => Self::F16,
            2 => Self::Q4_0,
            3 => Self::Q4_1,
            6 => Self::Q5_0,
            7 => Self::Q5_1,
            8 => Self::Q8_0,
            9 => Self::Q8_1,
            10 => Self::Q2K,
            11 => Self::Q3K,
            12 => Self::Q4K,
            13 => Self::Q5K,
            14 => Self::Q6K,
            15 => Self::Q8K,
            30 => Self::BF16,
            _ => crate::bail!("unknown dtype for tensor {u}"),
        };
        Ok(dtype)
    }

    /// Encodes this dtype as the numeric id used in ggml/gguf files
    /// (inverse of [`Self::from_u32`]).
    pub(crate) fn to_u32(self) -> u32 {
        match self {
            Self::F32 => 0,
            Self::F16 => 1,
            Self::Q4_0 => 2,
            Self::Q4_1 => 3,
            Self::Q5_0 => 6,
            Self::Q5_1 => 7,
            Self::Q8_0 => 8,
            Self::Q8_1 => 9,
            Self::Q2K => 10,
            Self::Q3K => 11,
            Self::Q4K => 12,
            Self::Q5K => 13,
            Self::Q6K => 14,
            Self::Q8K => 15,
            Self::BF16 => 30,
        }
    }

    /// Allocates a zeroed CPU buffer able to hold `elem_count` unquantized
    /// elements of this dtype.
    ///
    /// NOTE(review): `elem_count / BLCK_SIZE` truncates, so a count that is
    /// not a multiple of the block size silently under-allocates — callers
    /// (e.g. `QTensor::quantize`) check divisibility beforehand.
    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
        match self {
            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
            Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
        }
    }

    /// Builds a type-erased CPU buffer by reinterpreting raw bytes as blocks
    /// of this dtype; the bytes are copied into an owned `Vec`.
    pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
        match self {
            Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
            Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
            Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
            Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
            Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
            Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
            Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
            Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
            Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
            Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
            Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
            Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
            Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
            Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
            Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
        }
    }

    /// Size in bytes of one block of this dtype (one element for float types).
    pub fn type_size(&self) -> usize {
        use k_quants::*;
        match self {
            Self::F32 => 4,
            Self::F16 | Self::BF16 => 2,
            Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
            Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
            Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
            Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
            Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
            Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
            Self::Q2K => std::mem::size_of::<BlockQ2K>(),
            Self::Q3K => std::mem::size_of::<BlockQ3K>(),
            Self::Q4K => std::mem::size_of::<BlockQ4K>(),
            Self::Q5K => std::mem::size_of::<BlockQ5K>(),
            Self::Q6K => std::mem::size_of::<BlockQ6K>(),
            Self::Q8K => std::mem::size_of::<BlockQ8K>(),
        }
    }

    /// Number of unquantized elements represented by one block.
    pub fn block_size(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::F16 | Self::BF16 => 1,
            Self::Q4_0 => k_quants::QK4_0,
            Self::Q4_1 => k_quants::QK4_1,
            Self::Q5_0 => k_quants::QK5_0,
            Self::Q5_1 => k_quants::QK5_1,
            Self::Q8_0 => k_quants::QK8_0,
            Self::Q8_1 => k_quants::QK8_1,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
        }
    }
}
397
/// Object-safe interface over a CPU buffer of quantized blocks; implemented
/// below for `Vec<T>` where `T: k_quants::GgmlType`.
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType;
    /// f32 matmul against this buffer as the (transposed) rhs; `mkn`
    /// dimensions are interpreted by `k_quants::matmul`.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// f16 variant of `matmul_t`.
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Dequantizes `elem_count` elements into an f32 CPU storage.
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the buffer in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Raw pointer to the start of the buffer.
    fn as_ptr(&self) -> *const u8;
    /// Unquantized elements per block.
    fn block_size(&self) -> usize;
    /// Quantizes the f32 slice `xs` into this buffer in place.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Imatrix-weighted quantization of `xs` into this buffer.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size of the buffer in bytes (see also `storage_size_in_bytes`).
    fn size(&self) -> usize;
}
413
414impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
415 fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
416 k_quants::matmul(mkn, lhs, self.as_slice(), dst)
417 }
418 fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
419 k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
420 }
421
422 fn size(&self) -> usize {
423 self.len() * core::mem::size_of::<T>()
424 }
425
426 fn from_float(&mut self, xs: &[f32]) {
427 T::from_float(xs, self)
428 }
429
430 fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
431 T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
432 }
433
434 fn dtype(&self) -> GgmlDType {
435 T::DTYPE
436 }
437
438 fn block_size(&self) -> usize {
439 T::BLCK_SIZE
440 }
441
442 fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
443 let mut ys = vec![0.0f32; elem_count];
444 T::to_float(self.as_slice(), &mut ys);
445 Ok(CpuStorage::F32(ys))
446 }
447
448 fn storage_size_in_bytes(&self) -> usize {
449 self.len() * std::mem::size_of::<T>()
450 }
451
452 fn as_ptr(&self) -> *const u8 {
453 self.as_ptr() as *const u8
454 }
455}
456
457impl std::fmt::Debug for QTensor {
458 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
459 write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
460 }
461}
462
463fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
464 let dims = shape.dims();
465 if dims.is_empty() {
466 crate::bail!("scalar tensor cannot be quantized {shape:?}")
467 }
468 if !dims[dims.len() - 1].is_multiple_of(block_size) {
469 crate::bail!(
470 "quantized tensor must have their last dim divisible by block size {shape:?} {}",
471 block_size
472 )
473 }
474 Ok(())
475}
476
impl QTensor {
    /// Wraps an existing quantized storage with a logical shape, checking the
    /// last dim is divisible by the storage's block size.
    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
        let shape = shape.into();
        check_shape(&shape, storage.block_size())?;
        Ok(Self { storage, shape })
    }

    /// Quantizes `src` (cast to f32 and flattened first) into `dtype`,
    /// keeping the result on `src`'s device.
    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        // NOTE(review): `check_shape` already guarantees the last dim (and
        // hence the total element count) is a multiple of `block_size`, so
        // this check is belt-and-braces.
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Like [`Self::quantize`] but weighting the quantization error with an
    /// importance matrix; `imatrix_weights` must have one entry per element
    /// of the last dimension of `src`.
    pub fn quantize_imatrix(
        src: &Tensor,
        imatrix_weights: &[f32],
        dtype: GgmlDType,
    ) -> Result<Self> {
        let n_per_row = src.dim(D::Minus1)?;
        if imatrix_weights.len() != n_per_row {
            crate::bail!(
                "imatrix weights must have the same length {} as the last dim of src {}",
                imatrix_weights.len(),
                src.dim(D::Minus1)?
            );
        }

        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        // See note in `quantize`: redundant with `check_shape`.
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            );
        }
        let mut storage = src.device().qzeros(elem_count, dtype)?;
        storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Imatrix-weighted quantization of a CPU `src` onto the target device
    /// `dev`; `src` must live on the CPU.
    pub fn quantize_imatrix_onto(
        src: &Tensor,
        imatrix_weights: &[f32],
        dtype: GgmlDType,
        dev: &Device,
    ) -> Result<Self> {
        if !src.device().is_cpu() {
            // Error text intentionally matches the one in `quantize_onto`.
            crate::bail!(
                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
                src.device()
            )
        }
        let n_per_row = src.dim(D::Minus1)?;
        if imatrix_weights.len() != n_per_row {
            crate::bail!(
                "imatrix weights must have the same length {} as the last dim of src {}",
                imatrix_weights.len(),
                src.dim(D::Minus1)?
            );
        }
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        // See note in `quantize`: redundant with `check_shape`.
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        let mut storage = dev.qzeros(elem_count, dtype)?;
        storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// Quantizes a CPU `src` onto the target device `dev` (the backend
    /// uploads the quantized result); `src` must live on the CPU.
    pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
        if !src.device().is_cpu() {
            crate::bail!(
                "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
                src.device()
            )
        }
        let shape = src.shape();
        let block_size = dtype.block_size();
        check_shape(shape, block_size)?;
        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
        let elem_count = shape.elem_count();
        // See note in `quantize`: redundant with `check_shape`.
        if !elem_count.is_multiple_of(block_size) {
            crate::bail!(
                "tensor size ({shape:?}) is not divisible by block size {}",
                block_size
            )
        }
        let mut storage = dev.qzeros(elem_count, dtype)?;
        storage.quantize_onto(&src.storage())?;
        Ok(Self {
            storage,
            shape: shape.clone(),
        })
    }

    /// The ggml dtype of the underlying storage.
    pub fn dtype(&self) -> GgmlDType {
        self.storage.dtype()
    }

    /// The device the quantized data lives on.
    pub fn device(&self) -> Device {
        self.storage.device()
    }

    /// Number of dimensions of the logical shape.
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }

    /// The logical (unquantized) shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }

    /// Dequantizes to an f32 tensor, then moves it to `device`.
    pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
        let storage = self.storage.dequantize(self.shape.elem_count())?;
        let none = crate::op::BackpropOp::none();
        crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
    }

    /// Dequantizes to an f16 tensor. CUDA has a direct f16 path; other
    /// backends dequantize to f32 and cast.
    pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
        match &self.storage {
            QStorage::Cuda(s) => {
                let s = s.dequantize_f16(self.shape.elem_count())?;
                let none = crate::op::BackpropOp::none();
                crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
                    .to_device(device)
            }
            _ => {
                let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
                Ok(s)
            }
        }
    }

    /// Size of the quantized buffer in bytes.
    pub fn storage_size_in_bytes(&self) -> usize {
        self.storage.size_in_bytes()
    }

    /// Raw quantized bytes (borrowed on CPU, copied back from the device).
    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
        self.storage.data()
    }

    /// Expert-indexed MoE forward pass; implemented for CUDA storage with
    /// CUDA inputs only, panics otherwise.
    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
        match &self.storage {
            QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
                (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
                    let (storage, out_shape) = s.indexed_moe_forward(
                        self.shape(),
                        x_storage,
                        x.layout(),
                        ids_storage,
                        ids.layout(),
                    )?;
                    Ok(crate::tensor::from_storage(
                        Storage::Cuda(storage),
                        out_shape,
                        crate::op::BackpropOp::none(),
                        false,
                    ))
                }
                _ => {
                    panic!("Non-cuda indexed_moe_forward is not implemented!");
                }
            },
            _ => {
                panic!("indexed_moe_forward is not implemented in this platform!");
            }
        }
    }

    /// Raw device pointer to the quantized buffer; CUDA only.
    pub fn device_ptr(&self) -> Result<*const u8> {
        match &self.storage {
            QStorage::Cuda(storage) => storage.device_ptr(),
            QStorage::Metal(_) | QStorage::Cpu(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}
694
/// A matmul weight that is either kept quantized or eagerly dequantized
/// (see the `CANDLE_DEQUANTIZE_ALL*` flags below).
#[derive(Clone, Debug)]
pub enum QMatMul {
    // Weight kept quantized; forward goes through the custom op.
    QTensor(std::sync::Arc<QTensor>),
    // Weight dequantized to f32 up front.
    Tensor(Tensor),
    // Weight dequantized to f16 up front.
    TensorF16(Tensor),
}
701
// Env-var flags controlling `QMatMul::from_arc`: when set (non-empty and not
// "0"), weights are eagerly dequantized to f32 (`CANDLE_DEQUANTIZE_ALL`) or
// f16 (`CANDLE_DEQUANTIZE_ALL_F16`). The env var is read once per thread.
thread_local! {
    static DEQUANTIZE_ALL: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}

thread_local! {
    static DEQUANTIZE_ALL_F16: bool = {
        match std::env::var("CANDLE_DEQUANTIZE_ALL_F16") {
            Ok(s) => {
                !s.is_empty() && s != "0"
            },
            Err(_) => false,
        }
    }
}
723
impl QMatMul {
    /// Builds a `QMatMul` from a shared quantized tensor. Float dtypes (and
    /// the `CANDLE_DEQUANTIZE_ALL*` env flags) cause eager dequantization;
    /// otherwise the tensor stays quantized.
    pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
        let dequantize = match qtensor.dtype() {
            GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
            _ => DEQUANTIZE_ALL.with(|b| *b),
        };
        let t = if dequantize {
            let tensor = qtensor.dequantize(&qtensor.device())?;
            Self::Tensor(tensor)
        } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
            let tensor = qtensor.dequantize_f16(&qtensor.device())?;
            Self::TensorF16(tensor)
        } else {
            Self::QTensor(qtensor)
        };
        Ok(t)
    }

    /// Convenience wrapper around [`Self::from_arc`] taking ownership.
    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    }

    /// Returns the weight as an f16 tensor, dequantizing/casting as needed.
    pub fn dequantize_f16(&self) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => t.dequantize_f16(&t.device()),
            Self::Tensor(t) => t.to_dtype(DType::F16),
            Self::TensorF16(t) => Ok(t.clone()),
        }
    }

    /// Matmul through an f16 copy of the weight, casting the result back to
    /// the input dtype. The weight is broadcast over leading batch dims and
    /// transposed, mirroring `Module::forward` below.
    pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
        let w = self.dequantize_f16()?;
        let in_dtype = xs.dtype();
        let w = match *xs.dims() {
            [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
            [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
            _ => w.t()?,
        };
        xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
    }

    /// Expert-indexed MoE forward; only supported for the quantized variant.
    ///
    /// NOTE(review): panics instead of returning an error for the
    /// dequantized variants despite the `Result` return type — consider
    /// `crate::bail!` here.
    pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => t.indexed_moe_forward(x, ids),
            _ => {
                panic!("Not implemented!")
            }
        }
    }
}
774
/// Custom op implementing `xs @ self^T` against the quantized weight, used by
/// `QMatMul::forward` for the `QTensor` variant.
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    /// CPU path: dispatches on the activation dtype (f32/f16) and calls the
    /// quantized matmul kernels; the weight must be CPU storage of shape (n, k).
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // Weight is (n, k); the input's last dim must match k.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape = input shape with the trailing k replaced by n.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
        };
        match storage.dtype() {
            DType::F32 => {
                let slice = storage.as_slice::<f32>()?;
                // Restrict to the logical window described by the layout.
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![0f32; dst_shape.elem_count()];
                self_storage.matmul_t(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
            }
            DType::F16 => {
                let slice = storage.as_slice::<f16>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
                self_storage.matmul_t_f16(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
            }
            _ => crate::bail!("Expected f32/f16"),
        }
    }

    /// Metal path: delegates to the Metal storage's forward kernel.
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    /// CUDA path: delegates to the CUDA storage's forward kernel.
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}
859
impl crate::Module for QMatMul {
    /// Computes `xs @ w^T`. Quantized weights go through the custom op;
    /// dequantized weights are broadcast over up to two leading batch dims
    /// and transposed (same scheme as `forward_via_f16`); the f16 variant
    /// casts the input to f16 and the result back to the input dtype.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
            Self::Tensor(w) => {
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.matmul(&w)
            }
            Self::TensorF16(w) => {
                let in_dtype = xs.dtype();
                let w = match *xs.dims() {
                    [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
                    [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
                    _ => w.t()?,
                };
                xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
            }
        }
    }
}