1use crate::{
2 backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::{borrow::Cow, sync::OnceLock};
6
// x86 AVX2 kernels, only compiled when the target enables `avx2`.
#[cfg(target_feature = "avx2")]
pub mod avx;
// Stand-in backends used when the real `cuda` / `metal` features are disabled
// (re-exported through the `mod cuda` / `mod metal` shims below).
mod dummy_cuda;
mod dummy_metal;
pub mod ggml_file;
pub mod gguf_file;
pub mod imatrix_file;
pub mod k_quants;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(not(target_arch = "wasm32"))]
pub mod tokenizer;
// Without the `metal` feature, `metal::*` resolves to the dummy backend so the
// rest of this file can name `metal::QMetalStorage` / `metal::load_quantized`
// unconditionally.
#[cfg(not(feature = "metal"))]
mod metal {
    pub use super::dummy_metal::*;
}
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "cuda")]
pub mod fast_mmq;
#[cfg(feature = "cuda")]
pub mod fast_mmvq;
// Same shim as for metal: `cuda::*` falls back to the dummy backend.
#[cfg(not(feature = "cuda"))]
mod cuda {
    pub use super::dummy_cuda::*;
}

// ARM NEON and wasm SIMD kernels, gated on the corresponding target features.
#[cfg(target_feature = "neon")]
pub mod neon;
#[cfg(target_feature = "simd128")]
pub mod simd128;
pub mod utils;
use half::{bf16, f16};

pub use k_quants::GgmlType;
42
/// Reinterprets a byte buffer as a slice of `T` without copying.
///
/// Callers in this module only instantiate `T` with float types (`f32`, `f16`,
/// `bf16`) and the GGML block structs; the reinterpretation assumes every byte
/// pattern is a valid `T` for those types.
///
/// # Panics
/// Panics if `T` is zero-sized, if `data.len()` is not a multiple of
/// `size_of::<T>()`, or if a non-empty `data` is not aligned for `T`.
fn as_t_slice<T>(data: &[u8]) -> &[T] {
    let size = std::mem::size_of::<T>();
    // Guard explicitly instead of hitting a remainder-by-zero panic below.
    assert!(size != 0, "T must not be a zero-sized type");
    assert_eq!(
        data.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    // An empty buffer (e.g. from an empty `Vec<u8>`) may carry a dangling pointer
    // that is only aligned for `u8`; there is nothing to reinterpret anyway, so do
    // not let the alignment check reject it.
    if data.is_empty() {
        return &[];
    }
    let ptr = data.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: `ptr` comes from a live `&[u8]`, so it is non-null and the
    // `data.len()` bytes behind it are valid for reads for the lifetime of `data`.
    // It is aligned for `T` (checked above) and covers exactly
    // `data.len() / size` whole `T`s (length checked above). `T` is assumed to be
    // a plain-data type for which any byte pattern is valid (see doc comment).
    unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), data.len() / size) }
}
58
/// A quantized tensor: backend-specific quantized storage plus the logical shape
/// of the tensor it represents (the shape `dequantize` produces).
pub struct QTensor {
    storage: QStorage,
    shape: Shape,
    // Lazily-built, repacked copy of the quantized bytes, cached across calls.
    // It is only read by the aarch64 + `dotprod` Q4K path in `cpu_fwd`, so on other
    // targets the field is never read and would trigger `dead_code`.
    #[allow(dead_code)]
    repacked_qs: OnceLock<Option<Vec<u8>>>,
}
67
68impl Device {
69 fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
70 match self {
71 Device::Cpu => {
72 let storage = dtype.cpu_zeros(elem_count);
73 Ok(QStorage::Cpu(storage))
74 }
75 Device::Metal(metal) => {
76 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
77 Ok(QStorage::Metal(storage))
78 }
79 Device::Cuda(cuda) => {
80 let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
81 Ok(QStorage::Cuda(storage))
82 }
83 }
84 }
85}
86
/// Backend-specific storage for quantized data.
pub enum QStorage {
    /// Host storage behind the `QuantizedType` trait (implemented in this file for
    /// `Vec<T>` of GGML element/block types).
    Cpu(Box<dyn QuantizedType>),
    /// Metal storage; resolves to the dummy backend without the `metal` feature.
    Metal(metal::QMetalStorage),
    /// CUDA storage; resolves to the dummy backend without the `cuda` feature.
    Cuda(cuda::QCudaStorage),
}
92
93impl QStorage {
94 pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
95 let data: &[u8] = &data;
96 match device {
97 Device::Cpu => Ok(Self::Cpu(dtype.from_data(Cow::Borrowed(data)))),
98 Device::Metal(d) => match dtype {
99 GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
100 GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
101 GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
102 GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
103 GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
104 GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
105 GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
106 GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
107 GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
108 GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
109 GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
110 GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
111 GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
112 GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
113 GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
114 },
115 Device::Cuda(d) => match dtype {
116 GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
117 GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
118 GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
119 GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
120 GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
121 GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
122 GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
123 GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
124 GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
125 GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
126 GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
127 GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
128 GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
129 GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
130 GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
131 },
132 }
133 }
134
135 fn block_size(&self) -> usize {
136 match self {
137 QStorage::Cpu(storage) => storage.block_size(),
138 QStorage::Metal(storage) => storage.dtype().block_size(),
139 QStorage::Cuda(storage) => storage.dtype().block_size(),
140 }
141 }
142
143 fn dtype(&self) -> GgmlDType {
144 match self {
145 QStorage::Cpu(storage) => storage.dtype(),
146 QStorage::Metal(storage) => storage.dtype(),
147 QStorage::Cuda(storage) => storage.dtype(),
148 }
149 }
150
151 fn device(&self) -> Device {
152 match self {
153 QStorage::Cpu(_storage) => Device::Cpu,
154 QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
155 QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
156 }
157 }
158
159 fn size_in_bytes(&self) -> usize {
160 match self {
161 QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
162 QStorage::Metal(storage) => storage.storage_size_in_bytes(),
163 QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
164 }
165 }
166
167 fn quantize(&mut self, src: &Storage) -> Result<()> {
168 match (self, src) {
169 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
170 storage.from_float(src.as_slice::<f32>()?);
171 }
172 (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
173 (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
174 _ => crate::bail!("Invalid quantize storage locations do not match"),
175 }
176 Ok(())
177 }
178
179 fn quantize_imatrix(
180 &mut self,
181 src: &Storage,
182 imatrix_weights: &[f32],
183 n_per_row: usize,
184 ) -> Result<()> {
185 match (self, src) {
186 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
187 storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
188 }
189 (QStorage::Metal(storage), Storage::Metal(src)) => {
190 storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
191 }
192 (QStorage::Cuda(storage), Storage::Cuda(src)) => {
193 storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
194 }
195 _ => crate::bail!("Invalid quantize storage locations do not match"),
196 }
197 Ok(())
198 }
199
200 fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
201 match (self, src) {
202 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
203 storage.from_float(src.as_slice::<f32>()?);
204 }
205 (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
206 (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
207 _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
208 }
209 Ok(())
210 }
211
212 fn quantize_imatrix_onto(
213 &mut self,
214 src: &Storage,
215 imatrix_weights: &[f32],
216 n_per_row: usize,
217 ) -> Result<()> {
218 match (self, src) {
219 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
220 storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
221 }
222 (QStorage::Metal(storage), Storage::Cpu(src)) => {
223 storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
224 }
225 (QStorage::Cuda(storage), Storage::Cpu(src)) => {
226 storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
227 }
228 _ => crate::bail!("Invalid quantize storage locations do not match"),
229 }
230 Ok(())
231 }
232
233 fn dequantize(&self, elem_count: usize) -> Result<Storage> {
234 match self {
235 QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
236 QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
237 QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
238 }
239 }
240
241 fn data(&self) -> Result<Cow<'_, [u8]>> {
242 match self {
243 QStorage::Cpu(storage) => {
244 let data_ptr = storage.as_ptr();
245 let size_in_bytes = storage.storage_size_in_bytes();
246 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
247 Ok(Cow::from(data))
248 }
249 QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
250 QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
251 }
252 }
253
254 pub fn device_ptr(&self) -> Result<*const u8> {
255 match self {
256 QStorage::Cuda(storage) => storage.device_ptr(),
257 QStorage::Metal(_) | QStorage::Cpu(_) => {
258 crate::bail!("not implemented");
259 }
260 }
261 }
262
263 #[cfg(feature = "cuda")]
264 pub fn device_ptr_with_guard<'a>(
265 &'a self,
266 stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
267 ) -> Result<(
268 *const u8,
269 crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
270 )> {
271 match self {
272 QStorage::Cuda(storage) => storage.device_ptr_with_guard(stream),
273 QStorage::Metal(_) | QStorage::Cpu(_) => {
274 crate::bail!("not implemented");
275 }
276 }
277 }
278}
279
/// GGML tensor element types supported by this module.
///
/// `from_u32` / `to_u32` map each variant to its GGML type id; `Q*` variants are
/// block-quantized formats, `F32`/`F16`/`BF16` are unquantized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    BF16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
298
299impl GgmlDType {
300 pub(crate) fn from_u32(u: u32) -> Result<Self> {
301 let dtype = match u {
302 0 => Self::F32,
303 1 => Self::F16,
304 2 => Self::Q4_0,
305 3 => Self::Q4_1,
306 6 => Self::Q5_0,
307 7 => Self::Q5_1,
308 8 => Self::Q8_0,
309 9 => Self::Q8_1,
310 10 => Self::Q2K,
311 11 => Self::Q3K,
312 12 => Self::Q4K,
313 13 => Self::Q5K,
314 14 => Self::Q6K,
315 15 => Self::Q8K,
316 30 => Self::BF16,
318 _ => crate::bail!("unknown dtype for tensor {u}"),
319 };
320 Ok(dtype)
321 }
322
323 pub(crate) fn to_u32(self) -> u32 {
324 match self {
325 Self::F32 => 0,
326 Self::F16 => 1,
327 Self::Q4_0 => 2,
328 Self::Q4_1 => 3,
329 Self::Q5_0 => 6,
330 Self::Q5_1 => 7,
331 Self::Q8_0 => 8,
332 Self::Q8_1 => 9,
333 Self::Q2K => 10,
334 Self::Q3K => 11,
335 Self::Q4K => 12,
336 Self::Q5K => 13,
337 Self::Q6K => 14,
338 Self::Q8K => 15,
339 Self::BF16 => 30,
341 }
342 }
343
344 pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
346 match self {
347 Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
348 Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
349 Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
350 Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
351 Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
352 Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
353 Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
354 Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
355 Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
356 Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
357 Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
358 Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
359 Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
360 Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
361 Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
362 }
363 }
364
365 pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
366 let data: &[u8] = &data;
367 match self {
368 Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
369 Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
370 Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
371 Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
372 Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
373 Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
374 Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
375 Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
376 Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
377 Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
378 Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
379 Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
380 Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
381 Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
382 Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
383 }
384 }
385
386 pub fn type_size(&self) -> usize {
388 use k_quants::*;
389 match self {
390 Self::F32 => 4,
391 Self::F16 | Self::BF16 => 2,
392 Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
393 Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
394 Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
395 Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
396 Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
398 Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
399 Self::Q2K => std::mem::size_of::<BlockQ2K>(),
400 Self::Q3K => std::mem::size_of::<BlockQ3K>(),
401 Self::Q4K => std::mem::size_of::<BlockQ4K>(),
402 Self::Q5K => std::mem::size_of::<BlockQ5K>(),
403 Self::Q6K => std::mem::size_of::<BlockQ6K>(),
404 Self::Q8K => std::mem::size_of::<BlockQ8K>(),
405 }
406 }
407
408 pub fn block_size(&self) -> usize {
410 match self {
411 Self::F32 => 1,
412 Self::F16 | Self::BF16 => 1,
413 Self::Q4_0 => k_quants::QK4_0,
414 Self::Q4_1 => k_quants::QK4_1,
415 Self::Q5_0 => k_quants::QK5_0,
416 Self::Q5_1 => k_quants::QK5_1,
417 Self::Q8_0 => k_quants::QK8_0,
418 Self::Q8_1 => k_quants::QK8_1,
419 Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
420 }
421 }
422}
423
/// Object-safe interface over CPU quantized storage (implemented in this file for
/// `Vec<T>` of GGML element/block types).
pub trait QuantizedType: Send + Sync {
    /// The GGML dtype of the stored items.
    fn dtype(&self) -> GgmlDType;
    /// f32 matmul against the transposed quantized matrix. With `mkn = (m, k, n)`,
    /// `cpu_fwd` passes an `m * k` `lhs`, an `m * n` `dst`, and `self` holding the
    /// `(n, k)` weight.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// Same as `matmul_t` but with f16 activations and output.
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Dequantizes the rows selected by `ids` out of a `(rows, hidden)` matrix into
    /// f32 storage of `ids.len() * hidden` values (per the `Vec<T>` impl).
    fn embedding(&self, ids: &[u32], rows: usize, hidden: usize) -> Result<CpuStorage>;
    /// Dequantizes into f32 storage of `elem_count` values (per the `Vec<T>` impl).
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the backing buffer in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the backing buffer.
    fn as_ptr(&self) -> *const u8;
    /// Number of logical elements per stored item.
    fn block_size(&self) -> usize;
    // `from_*` names conventionally take no `self`; here they quantize into `self`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Like `from_float`, using `imatrix_weights` (validated by
    /// `QTensor::quantize_imatrix` to have `n_per_row` entries).
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size in bytes; the `Vec<T>` impl returns the same value as
    /// `storage_size_in_bytes`.
    fn size(&self) -> usize;
}
440
441impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
442 fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
443 k_quants::matmul(mkn, lhs, self.as_slice(), dst)
444 }
445 fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
446 k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
447 }
448
449 fn embedding(&self, ids: &[u32], rows: usize, hidden: usize) -> Result<CpuStorage> {
450 if !hidden.is_multiple_of(T::BLCK_SIZE) {
451 crate::bail!(
452 "quantized embedding hidden size {hidden} is not divisible by block size {}",
453 T::BLCK_SIZE
454 )
455 }
456 let row_blocks = hidden / T::BLCK_SIZE;
457 if self.len() != rows * row_blocks {
458 crate::bail!(
459 "quantized tensor has {} blocks, expected {}",
460 self.len(),
461 rows * row_blocks
462 )
463 }
464 let mut out = vec![0f32; ids.len() * hidden];
465 for (out_row, &row_id) in ids.iter().enumerate() {
466 let row = row_id as usize;
467 if row >= rows {
468 crate::bail!("embedding id {row} is out of range for {rows} rows")
469 }
470 let src = &self[row * row_blocks..(row + 1) * row_blocks];
471 let dst = &mut out[out_row * hidden..(out_row + 1) * hidden];
472 T::to_float(src, dst);
473 }
474 Ok(CpuStorage::F32(out))
475 }
476
477 fn size(&self) -> usize {
478 self.len() * core::mem::size_of::<T>()
479 }
480
481 fn from_float(&mut self, xs: &[f32]) {
482 T::from_float(xs, self)
483 }
484
485 fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
486 T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
487 }
488
489 fn dtype(&self) -> GgmlDType {
490 T::DTYPE
491 }
492
493 fn block_size(&self) -> usize {
494 T::BLCK_SIZE
495 }
496
497 fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
498 let mut ys = vec![0.0f32; elem_count];
499 T::to_float(self.as_slice(), &mut ys);
500 Ok(CpuStorage::F32(ys))
501 }
502
503 fn storage_size_in_bytes(&self) -> usize {
504 self.len() * std::mem::size_of::<T>()
505 }
506
507 fn as_ptr(&self) -> *const u8 {
508 self.as_ptr() as *const u8
509 }
510}
511
512impl std::fmt::Debug for QTensor {
513 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
514 write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
515 }
516}
517
518fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
519 let dims = shape.dims();
520 if dims.is_empty() {
521 crate::bail!("scalar tensor cannot be quantized {shape:?}")
522 }
523 if !dims[dims.len() - 1].is_multiple_of(block_size) {
524 crate::bail!(
525 "quantized tensor must have their last dim divisible by block size {shape:?} {}",
526 block_size
527 )
528 }
529 Ok(())
530}
531
532impl QTensor {
533 pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
534 let shape = shape.into();
535 check_shape(&shape, storage.block_size())?;
536 Ok(Self {
537 storage,
538 shape,
539 repacked_qs: OnceLock::new(),
540 })
541 }
542
543 pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
544 let shape = src.shape();
545 let block_size = dtype.block_size();
546 check_shape(shape, block_size)?;
547 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
548 let elem_count = shape.elem_count();
549 if !elem_count.is_multiple_of(block_size) {
550 crate::bail!(
551 "tensor size ({shape:?}) is not divisible by block size {}",
552 block_size
553 )
554 }
555 let mut storage = src.device().qzeros(elem_count, dtype)?;
556 storage.quantize(&src.storage())?;
557 Ok(Self {
558 storage,
559 shape: shape.clone(),
560 repacked_qs: OnceLock::new(),
561 })
562 }
563
564 pub fn quantize_imatrix(
565 src: &Tensor,
566 imatrix_weights: &[f32],
567 dtype: GgmlDType,
568 ) -> Result<Self> {
569 let n_per_row = src.dim(D::Minus1)?;
572 if imatrix_weights.len() != n_per_row {
573 crate::bail!(
574 "imatrix weights must have the same length {} as the last dim of src {}",
575 imatrix_weights.len(),
576 src.dim(D::Minus1)?
577 );
578 }
579
580 let shape = src.shape();
581 let block_size = dtype.block_size();
582 check_shape(shape, block_size)?;
583 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
584 let elem_count = shape.elem_count();
585 if !elem_count.is_multiple_of(block_size) {
586 crate::bail!(
587 "tensor size ({shape:?}) is not divisible by block size {}",
588 block_size
589 );
590 }
591 let mut storage = src.device().qzeros(elem_count, dtype)?;
592 storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
593 Ok(Self {
594 storage,
595 shape: shape.clone(),
596 repacked_qs: OnceLock::new(),
597 })
598 }
599
600 pub fn quantize_imatrix_onto(
602 src: &Tensor,
603 imatrix_weights: &[f32],
604 dtype: GgmlDType,
605 dev: &Device,
606 ) -> Result<Self> {
607 if !src.device().is_cpu() {
608 crate::bail!(
609 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
610 src.device()
611 )
612 }
613 let n_per_row = src.dim(D::Minus1)?;
616 if imatrix_weights.len() != n_per_row {
617 crate::bail!(
618 "imatrix weights must have the same length {} as the last dim of src {}",
619 imatrix_weights.len(),
620 src.dim(D::Minus1)?
621 );
622 }
623 let shape = src.shape();
624 let block_size = dtype.block_size();
625 check_shape(shape, block_size)?;
626 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
627 let elem_count = shape.elem_count();
628 if !elem_count.is_multiple_of(block_size) {
629 crate::bail!(
630 "tensor size ({shape:?}) is not divisible by block size {}",
631 block_size
632 )
633 }
634 let mut storage = dev.qzeros(elem_count, dtype)?;
636 storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
637 Ok(Self {
638 storage,
639 shape: shape.clone(),
640 repacked_qs: OnceLock::new(),
641 })
642 }
643
644 pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
646 if !src.device().is_cpu() {
647 crate::bail!(
648 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
649 src.device()
650 )
651 }
652 let shape = src.shape();
653 let block_size = dtype.block_size();
654 check_shape(shape, block_size)?;
655 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
656 let elem_count = shape.elem_count();
657 if !elem_count.is_multiple_of(block_size) {
658 crate::bail!(
659 "tensor size ({shape:?}) is not divisible by block size {}",
660 block_size
661 )
662 }
663 let mut storage = dev.qzeros(elem_count, dtype)?;
665 storage.quantize_onto(&src.storage())?;
666 Ok(Self {
667 storage,
668 shape: shape.clone(),
669 repacked_qs: OnceLock::new(),
670 })
671 }
672
673 pub fn dtype(&self) -> GgmlDType {
674 self.storage.dtype()
675 }
676
677 pub fn device(&self) -> Device {
678 self.storage.device()
679 }
680
681 pub fn rank(&self) -> usize {
682 self.shape.rank()
683 }
684
685 pub fn shape(&self) -> &Shape {
686 &self.shape
687 }
688
689 pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
690 let storage = self.storage.dequantize(self.shape.elem_count())?;
691 let none = crate::op::BackpropOp::none();
692 crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
693 }
694
695 pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
696 match &self.storage {
699 QStorage::Cuda(s) => {
700 let s = s.dequantize_f16(self.shape.elem_count())?;
701 let none = crate::op::BackpropOp::none();
702 crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
703 .to_device(device)
704 }
705 _ => {
706 let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
707 Ok(s)
708 }
709 }
710 }
711
712 pub fn embedding(&self, ids: &Tensor) -> Result<Tensor> {
713 let (rows, hidden) = self.shape.dims2()?;
714 if !hidden.is_multiple_of(self.dtype().block_size()) {
715 crate::bail!(
716 "quantized embedding hidden size {hidden} is not divisible by block size {}",
717 self.dtype().block_size()
718 )
719 }
720 let mut out_shape = ids.dims().to_vec();
721 out_shape.push(hidden);
722 let device = self.device();
723 let ids = ids
724 .to_device(&device)?
725 .to_dtype(DType::U32)?
726 .flatten_all()?
727 .contiguous()?;
728 let storage = match &self.storage {
729 QStorage::Cpu(storage) => {
730 let ids = ids.to_vec1::<u32>()?;
731 Storage::Cpu(storage.embedding(&ids, rows, hidden)?)
732 }
733 QStorage::Metal(storage) => match &*ids.storage() {
734 Storage::Metal(ids_storage) => {
735 Storage::Metal(storage.embedding(rows, hidden, ids_storage, ids.layout())?)
736 }
737 _ => unreachable!("ids were moved to the QTensor device"),
738 },
739 QStorage::Cuda(storage) => match &*ids.storage() {
740 Storage::Cuda(ids_storage) => {
741 Storage::Cuda(storage.embedding(rows, hidden, ids_storage, ids.layout())?)
742 }
743 _ => unreachable!("ids were moved to the QTensor device"),
744 },
745 };
746 let none = crate::op::BackpropOp::none();
747 Ok(crate::tensor::from_storage(storage, out_shape, none, false))
748 }
749
750 pub fn storage_size_in_bytes(&self) -> usize {
751 self.storage.size_in_bytes()
752 }
753
754 pub fn data(&self) -> Result<Cow<'_, [u8]>> {
755 self.storage.data()
756 }
757
758 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
759 match &self.storage {
760 QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
761 (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
762 let (storage, out_shape) = s.indexed_moe_forward(
763 self.shape(),
764 x_storage,
765 x.layout(),
766 ids_storage,
767 ids.layout(),
768 )?;
769 Ok(crate::tensor::from_storage(
770 Storage::Cuda(storage),
771 out_shape,
772 crate::op::BackpropOp::none(),
773 false,
774 ))
775 }
776 _ => {
777 panic!("Non-cuda indexed_moe_forward is not implemented!");
778 }
779 },
780 _ => {
781 panic!("indexed_moe_forward is not implemented in this platform!");
782 }
783 }
784 }
785
786 pub fn device_ptr(&self) -> Result<*const u8> {
787 match &self.storage {
788 QStorage::Cuda(storage) => storage.device_ptr(),
789 QStorage::Metal(_) | QStorage::Cpu(_) => {
790 crate::bail!("not implemented");
791 }
792 }
793 }
794
795 #[cfg(feature = "cuda")]
796 pub fn device_ptr_with_guard<'a>(
797 &'a self,
798 stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
799 ) -> Result<(
800 *const u8,
801 crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
802 )> {
803 self.storage.device_ptr_with_guard(stream)
804 }
805}
806
/// A matmul weight that is either kept quantized or eagerly dequantized (see
/// `QMatMul::from_arc` for how the variant is chosen).
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Quantized weight, multiplied through the `qmatmul` custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Weight dequantized with `QTensor::dequantize`.
    Tensor(Tensor),
    /// Weight dequantized to f16; inputs are cast to f16 for the matmul.
    TensorF16(Tensor),
}
813
/// Returns `true` when the environment variable `name` is set to a non-empty
/// value other than `"0"`; unset or unreadable variables count as disabled.
fn env_flag_enabled(name: &str) -> bool {
    std::env::var(name).is_ok_and(|value| !value.is_empty() && value != "0")
}

thread_local! {
    /// When enabled, `QMatMul::from_arc` dequantizes every quantized weight.
    /// Read lazily, once per thread.
    static DEQUANTIZE_ALL: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL");
}

thread_local! {
    /// When enabled (and `DEQUANTIZE_ALL` is not), `QMatMul::from_arc`
    /// dequantizes quantized weights to f16. Read lazily, once per thread.
    static DEQUANTIZE_ALL_F16: bool = env_flag_enabled("CANDLE_DEQUANTIZE_ALL_F16");
}
835
836impl QMatMul {
837 pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
838 let dequantize = match qtensor.dtype() {
839 GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
840 _ => DEQUANTIZE_ALL.with(|b| *b),
841 };
842 let t = if dequantize {
843 let tensor = qtensor.dequantize(&qtensor.device())?;
844 Self::Tensor(tensor)
845 } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
846 let tensor = qtensor.dequantize_f16(&qtensor.device())?;
847 Self::TensorF16(tensor)
848 } else {
849 Self::QTensor(qtensor)
850 };
851 Ok(t)
852 }
853
854 pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
855 Self::from_arc(std::sync::Arc::new(qtensor))
856 }
857
858 pub fn dequantize_f16(&self) -> Result<Tensor> {
859 match self {
860 Self::QTensor(t) => t.dequantize_f16(&t.device()),
861 Self::Tensor(t) => t.to_dtype(DType::F16),
862 Self::TensorF16(t) => Ok(t.clone()),
863 }
864 }
865
866 pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
867 let w = self.dequantize_f16()?;
868 let in_dtype = xs.dtype();
869 let w = match *xs.dims() {
870 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
871 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
872 _ => w.t()?,
873 };
874 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
875 }
876
877 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
878 match self {
879 Self::QTensor(t) => t.indexed_moe_forward(x, ids),
880 _ => {
881 panic!("Not implemented!")
882 }
883 }
884 }
885
886 pub fn embedding(&self, ids: &Tensor) -> Result<Tensor> {
887 match self {
888 Self::QTensor(t) => t.embedding(ids),
889 Self::Tensor(w) | Self::TensorF16(w) => {
890 let mut final_dims = ids.dims().to_vec();
891 final_dims.push(w.dim(D::Minus1)?);
892 let ids = ids.to_device(w.device())?.flatten_all()?;
893 w.index_select(&ids, 0)?.reshape(final_dims)
894 }
895 }
896 }
897}
898
/// Exposes `QTensor` as the `qmatmul` custom op: `input @ selfᵀ`, where `self` is
/// an `(n, k)` quantized weight and the input's last dimension is `k`.
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    /// CPU path: contiguous f32/f16 input of shape `[.., k]` -> output `[.., n]`.
    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape: input dims with the trailing `k` replaced by `n`.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
        };
        // NOTE(review): `dst_shape.elem_count() / n` (the row count m) divides by
        // zero if the weight has n == 0 rows.
        match storage.dtype() {
            DType::F32 => {
                let slice = storage.as_slice::<f32>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![0f32; dst_shape.elem_count()];

                // Fast path (aarch64 + dotprod only): Q4K weights with n divisible by
                // 8 are repacked once into the `BlockQ4Kx8` layout, cached in
                // `repacked_qs`, and multiplied with `matmul_q4k_x8`.
                #[cfg(all(target_arch = "aarch64", target_feature = "dotprod"))]
                if self_storage.dtype() == GgmlDType::Q4K && n.is_multiple_of(8) {
                    use zerocopy::{FromBytes, IntoBytes};

                    let total_blocks =
                        self_storage.storage_size_in_bytes() / std::mem::size_of::<BlockQ4K>();
                    let repacked = self.repacked_qs.get_or_init(|| {
                        // Views the CPU storage bytes as `BlockQ4K` blocks; relies on
                        // the storage being a `Vec<BlockQ4K>` since dtype is Q4K.
                        let blocks = unsafe {
                            std::slice::from_raw_parts(
                                self_storage.as_ptr() as *const BlockQ4K,
                                total_blocks,
                            )
                        };
                        let packed = k_quants::pack_to_q4kx8(blocks, n);
                        Some(packed.as_bytes().to_vec())
                    });
                    if let Some(repacked_bytes) = repacked {
                        // NOTE(review): a `Vec<u8>` only guarantees byte alignment; if
                        // it is not aligned for `BlockQ4Kx8` this returns an error
                        // instead of falling back to the generic `matmul_t` path.
                        let block_x8: &[BlockQ4Kx8] =
                            <[BlockQ4Kx8]>::ref_from_bytes(repacked_bytes).map_err(|_| {
                                crate::Error::Msg(
                                    "repacked_qs alignment invariant violated".to_string(),
                                )
                            })?;

                        k_quants::matmul_q4k_x8(
                            (dst_shape.elem_count() / n, k, n),
                            slice,
                            block_x8,
                            &mut dst_storage,
                        )?;
                        return Ok((crate::CpuStorage::F32(dst_storage), dst_shape));
                    }
                }

                // Generic path: mkn = (rows of the input, k, n).
                self_storage.matmul_t(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
            }
            DType::F16 => {
                let slice = storage.as_slice::<f16>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
                self_storage.matmul_t_f16(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
            }
            _ => crate::bail!("Expected f32/f16"),
        }
    }

    /// Metal path; delegates to the Metal storage. Must only be called when the
    /// quantized weight itself lives on Metal.
    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    /// CUDA path; delegates to the CUDA storage. Must only be called when the
    /// quantized weight itself lives on CUDA.
    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}
1019
1020impl crate::Module for QMatMul {
1021 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1022 match self {
1023 Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
1024 Self::Tensor(w) => {
1025 let w = match *xs.dims() {
1026 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1027 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1028 _ => w.t()?,
1029 };
1030 xs.matmul(&w)
1031 }
1032 Self::TensorF16(w) => {
1033 let in_dtype = xs.dtype();
1034 let w = match *xs.dims() {
1035 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1036 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1037 _ => w.t()?,
1038 };
1039 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
1040 }
1041 }
1042 }
1043}