1use crate::{
2 backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(feature = "metal"))]
18mod metal {
19 pub use super::dummy_metal::*;
20}
21#[cfg(feature = "cuda")]
22pub mod cuda;
23#[cfg(not(feature = "cuda"))]
24mod cuda {
25 pub use super::dummy_cuda::*;
26}
27
28#[cfg(target_feature = "neon")]
29pub mod neon;
30#[cfg(target_feature = "simd128")]
31pub mod simd128;
32pub mod utils;
33use half::{bf16, f16};
34
35pub use k_quants::GgmlType;
36
/// Reinterprets a borrowed byte buffer as a slice of `T`.
///
/// The buffer length must be an exact multiple of `size_of::<T>()` and the
/// data pointer must satisfy `align_of::<T>()`; both are asserted.
///
/// # Panics
/// Panics if `data` is `Cow::Owned`: the owned buffer would be dropped when
/// this function returns, leaving the returned slice dangling. The previous
/// implementation silently produced such a dangling slice (undefined
/// behavior); failing loudly is strictly safer. Also panics on length or
/// alignment mismatch.
fn as_t_slice<T>(data: Cow<'_, [u8]>) -> &[T] {
    let bytes: &[u8] = match data {
        // The borrowed reference outlives this call, so the cast is sound.
        Cow::Borrowed(b) => b,
        // An owned buffer lives only until the end of this call; returning a
        // slice into it would be a use-after-free.
        Cow::Owned(_) => panic!("as_t_slice requires borrowed data"),
    };
    let size = std::mem::size_of::<T>();
    assert_eq!(
        bytes.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = bytes.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: `ptr` is valid for `bytes.len()` bytes for the borrow's
    // lifetime, the length is an exact multiple of `size_of::<T>()`, and the
    // alignment has been checked above.
    unsafe { std::slice::from_raw_parts(ptr as *const T, bytes.len() / size) }
}
52
/// A tensor stored in one of the ggml quantized formats.
///
/// The shape describes the logical (dequantized) element layout; the last
/// dimension is always a multiple of the dtype's block size (enforced by
/// `check_shape` in the constructors).
pub struct QTensor {
    // Backend-specific quantized payload (cpu/metal/cuda).
    storage: QStorage,
    // Logical shape of the dequantized tensor.
    shape: Shape,
}
57
58impl Device {
59 fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
60 match self {
61 Device::Cpu => {
62 let storage = dtype.cpu_zeros(elem_count);
63 Ok(QStorage::Cpu(storage))
64 }
65 Device::Metal(metal) => {
66 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
67 Ok(QStorage::Metal(storage))
68 }
69 Device::Cuda(cuda) => {
70 let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
71 Ok(QStorage::Cuda(storage))
72 }
73 }
74 }
75}
76
/// Backend-specific container for quantized tensor data.
pub enum QStorage {
    /// Host memory, type-erased behind the `QuantizedType` trait.
    Cpu(Box<dyn QuantizedType>),
    /// Metal device buffer.
    Metal(metal::QMetalStorage),
    /// Cuda device buffer.
    Cuda(cuda::QCudaStorage),
}
82
impl QStorage {
    /// Builds device storage by reinterpreting raw ggml bytes as blocks of
    /// `dtype` and, for metal/cuda, uploading them to the device.
    ///
    /// NOTE(review): `as_t_slice` consumes the `Cow` and returns a slice tied
    /// to its lifetime; this is only sound for `Cow::Borrowed` inputs — an
    /// owned `Cow` would be dropped while the slice is still in use. TODO
    /// confirm callers never pass owned data.
    pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
        match device {
            // On cpu the bytes are copied into a typed Vec (see
            // `GgmlDType::from_data`).
            Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
            Device::Metal(d) => match dtype {
                GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(data)),
            },
            Device::Cuda(d) => match dtype {
                GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(data)),
                GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(data)),
                GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(data)),
                GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(data)),
                GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(data)),
                GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(data)),
                GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(data)),
                GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(data)),
                GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(data)),
                GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(data)),
                GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(data)),
                GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(data)),
                GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(data)),
                GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(data)),
                GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(data)),
            },
        }
    }

    /// Number of elements per quantization block of the stored dtype.
    fn block_size(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.block_size(),
            QStorage::Metal(storage) => storage.dtype().block_size(),
            QStorage::Cuda(storage) => storage.dtype().block_size(),
        }
    }

    /// The ggml dtype of the stored data.
    fn dtype(&self) -> GgmlDType {
        match self {
            QStorage::Cpu(storage) => storage.dtype(),
            QStorage::Metal(storage) => storage.dtype(),
            QStorage::Cuda(storage) => storage.dtype(),
        }
    }

    /// The device holding this storage (a fresh `Device` handle).
    fn device(&self) -> Device {
        match self {
            QStorage::Cpu(_storage) => Device::Cpu,
            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
            QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
        }
    }

    /// Total size of the quantized buffer in bytes.
    fn size_in_bytes(&self) -> usize {
        match self {
            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
            QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
        }
    }

    /// Quantizes the f32 `src` into this storage in place; `src` must live on
    /// the same device as `self`.
    fn quantize(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
            (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Importance-matrix weighted variant of [`Self::quantize`];
    /// `imatrix_weights` holds one weight per column (`n_per_row` columns).
    fn quantize_imatrix(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Metal(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cuda(src)) => {
                storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Quantizes a cpu-resident `src` into this storage, which may live on
    /// any device (the "onto" family uploads as part of quantization).
    fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float(src.as_slice::<f32>()?);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
            _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
        }
        Ok(())
    }

    /// Importance-matrix weighted variant of [`Self::quantize_onto`].
    fn quantize_imatrix_onto(
        &mut self,
        src: &Storage,
        imatrix_weights: &[f32],
        n_per_row: usize,
    ) -> Result<()> {
        match (self, src) {
            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
            }
            (QStorage::Metal(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            (QStorage::Cuda(storage), Storage::Cpu(src)) => {
                storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
            }
            _ => crate::bail!("Invalid quantize storage locations do not match"),
        }
        Ok(())
    }

    /// Dequantizes into an f32 storage with `elem_count` elements on the same
    /// device as `self`.
    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
        match self {
            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
            QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
        }
    }

    /// Returns the raw quantized bytes: borrowed for cpu, an owned buffer
    /// from the device backends.
    fn data(&self) -> Result<Cow<'_, [u8]>> {
        match self {
            QStorage::Cpu(storage) => {
                let data_ptr = storage.as_ptr();
                let size_in_bytes = storage.storage_size_in_bytes();
                // SAFETY: pointer and length both come from the same live
                // storage, and the returned slice borrows `self`, so it
                // cannot outlive the underlying buffer.
                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                Ok(Cow::from(data))
            }
            QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
            QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
        }
    }

    /// Raw device pointer to the quantized data; only implemented for cuda.
    pub fn device_ptr(&self) -> Result<*const u8> {
        match self {
            QStorage::Cuda(storage) => storage.device_ptr(),
            QStorage::Metal(_) | QStorage::Cpu(_) => {
                crate::bail!("not implemented");
            }
        }
    }
}
252
/// The set of ggml storage dtypes supported for quantized tensors.
///
/// `F32`/`F16`/`BF16` are plain (unquantized) element types; the `Q*`
/// variants are ggml block-quantized formats whose block layouts live in
/// `k_quants` (`BlockQ4_0`, `BlockQ2K`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    F32,
    F16,
    BF16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
271
272impl GgmlDType {
273 pub(crate) fn from_u32(u: u32) -> Result<Self> {
274 let dtype = match u {
275 0 => Self::F32,
276 1 => Self::F16,
277 2 => Self::Q4_0,
278 3 => Self::Q4_1,
279 6 => Self::Q5_0,
280 7 => Self::Q5_1,
281 8 => Self::Q8_0,
282 9 => Self::Q8_1,
283 10 => Self::Q2K,
284 11 => Self::Q3K,
285 12 => Self::Q4K,
286 13 => Self::Q5K,
287 14 => Self::Q6K,
288 15 => Self::Q8K,
289 30 => Self::BF16,
291 _ => crate::bail!("unknown dtype for tensor {u}"),
292 };
293 Ok(dtype)
294 }
295
296 pub(crate) fn to_u32(self) -> u32 {
297 match self {
298 Self::F32 => 0,
299 Self::F16 => 1,
300 Self::Q4_0 => 2,
301 Self::Q4_1 => 3,
302 Self::Q5_0 => 6,
303 Self::Q5_1 => 7,
304 Self::Q8_0 => 8,
305 Self::Q8_1 => 9,
306 Self::Q2K => 10,
307 Self::Q3K => 11,
308 Self::Q4K => 12,
309 Self::Q5K => 13,
310 Self::Q6K => 14,
311 Self::Q8K => 15,
312 Self::BF16 => 30,
314 }
315 }
316
317 pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
319 match self {
320 Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
321 Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
322 Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
323 Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
324 Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
325 Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
326 Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
327 Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
328 Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
329 Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
330 Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
331 Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
332 Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
333 Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
334 Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
335 }
336 }
337
338 pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
339 match self {
340 Self::F32 => Box::new(as_t_slice::<f32>(data).to_vec()),
341 Self::F16 => Box::new(as_t_slice::<f16>(data).to_vec()),
342 Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(data).to_vec()),
343 Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(data).to_vec()),
344 Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(data).to_vec()),
345 Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(data).to_vec()),
346 Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(data).to_vec()),
347 Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(data).to_vec()),
348 Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(data).to_vec()),
349 Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(data).to_vec()),
350 Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(data).to_vec()),
351 Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(data).to_vec()),
352 Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(data).to_vec()),
353 Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(data).to_vec()),
354 Self::BF16 => Box::new(as_t_slice::<bf16>(data).to_vec()),
355 }
356 }
357
358 pub fn type_size(&self) -> usize {
360 use k_quants::*;
361 match self {
362 Self::F32 => 4,
363 Self::F16 | Self::BF16 => 2,
364 Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
365 Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
366 Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
367 Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
368 Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
370 Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
371 Self::Q2K => std::mem::size_of::<BlockQ2K>(),
372 Self::Q3K => std::mem::size_of::<BlockQ3K>(),
373 Self::Q4K => std::mem::size_of::<BlockQ4K>(),
374 Self::Q5K => std::mem::size_of::<BlockQ5K>(),
375 Self::Q6K => std::mem::size_of::<BlockQ6K>(),
376 Self::Q8K => std::mem::size_of::<BlockQ8K>(),
377 }
378 }
379
380 pub fn block_size(&self) -> usize {
382 match self {
383 Self::F32 => 1,
384 Self::F16 | Self::BF16 => 1,
385 Self::Q4_0 => k_quants::QK4_0,
386 Self::Q4_1 => k_quants::QK4_1,
387 Self::Q5_0 => k_quants::QK5_0,
388 Self::Q5_1 => k_quants::QK5_1,
389 Self::Q8_0 => k_quants::QK8_0,
390 Self::Q8_1 => k_quants::QK8_1,
391 Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
392 }
393 }
394}
395
/// Type-erased interface over cpu-resident quantized data (a `Vec` of
/// `k_quants` blocks, or plain f32/f16/bf16 values).
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored blocks.
    fn dtype(&self) -> GgmlDType;
    /// Matrix multiply with `self` as the transposed rhs; `mkn = (m, k, n)`
    /// and `dst` receives `m * n` f32 values.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// f16 variant of `matmul_t`.
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Expands the quantized blocks into `elem_count` f32 values.
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Total size of the underlying buffer in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Raw pointer to the start of the underlying buffer.
    fn as_ptr(&self) -> *const u8;
    /// Number of elements per quantization block.
    fn block_size(&self) -> usize;
    /// Quantizes `xs` into `self` in place.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Importance-matrix weighted quantization of `xs` into `self`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Buffer size in bytes (same value as `storage_size_in_bytes`).
    fn size(&self) -> usize;
}
411
412impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
413 fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
414 k_quants::matmul(mkn, lhs, self.as_slice(), dst)
415 }
416 fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
417 k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
418 }
419
420 fn size(&self) -> usize {
421 self.len() * core::mem::size_of::<T>()
422 }
423
424 fn from_float(&mut self, xs: &[f32]) {
425 T::from_float(xs, self)
426 }
427
428 fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
429 T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
430 }
431
432 fn dtype(&self) -> GgmlDType {
433 T::DTYPE
434 }
435
436 fn block_size(&self) -> usize {
437 T::BLCK_SIZE
438 }
439
440 fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
441 let mut ys = vec![0.0f32; elem_count];
442 T::to_float(self.as_slice(), &mut ys);
443 Ok(CpuStorage::F32(ys))
444 }
445
446 fn storage_size_in_bytes(&self) -> usize {
447 self.len() * std::mem::size_of::<T>()
448 }
449
450 fn as_ptr(&self) -> *const u8 {
451 self.as_ptr() as *const u8
452 }
453}
454
455impl std::fmt::Debug for QTensor {
456 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
457 write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
458 }
459}
460
461fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
462 let dims = shape.dims();
463 if dims.is_empty() {
464 crate::bail!("scalar tensor cannot be quantized {shape:?}")
465 }
466 if !dims[dims.len() - 1].is_multiple_of(block_size) {
467 crate::bail!(
468 "quantized tensor must have their last dim divisible by block size {shape:?} {}",
469 block_size
470 )
471 }
472 Ok(())
473}
474
475impl QTensor {
476 pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
477 let shape = shape.into();
478 check_shape(&shape, storage.block_size())?;
479 Ok(Self { storage, shape })
480 }
481
482 pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
483 let shape = src.shape();
484 let block_size = dtype.block_size();
485 check_shape(shape, block_size)?;
486 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
487 let elem_count = shape.elem_count();
488 if !elem_count.is_multiple_of(block_size) {
489 crate::bail!(
490 "tensor size ({shape:?}) is not divisible by block size {}",
491 block_size
492 )
493 }
494 let mut storage = src.device().qzeros(elem_count, dtype)?;
495 storage.quantize(&src.storage())?;
496 Ok(Self {
497 storage,
498 shape: shape.clone(),
499 })
500 }
501
502 pub fn quantize_imatrix(
503 src: &Tensor,
504 imatrix_weights: &[f32],
505 dtype: GgmlDType,
506 ) -> Result<Self> {
507 let n_per_row = src.dim(D::Minus1)?;
510 if imatrix_weights.len() != n_per_row {
511 crate::bail!(
512 "imatrix weights must have the same length {} as the last dim of src {}",
513 imatrix_weights.len(),
514 src.dim(D::Minus1)?
515 );
516 }
517
518 let shape = src.shape();
519 let block_size = dtype.block_size();
520 check_shape(shape, block_size)?;
521 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
522 let elem_count = shape.elem_count();
523 if !elem_count.is_multiple_of(block_size) {
524 crate::bail!(
525 "tensor size ({shape:?}) is not divisible by block size {}",
526 block_size
527 );
528 }
529 let mut storage = src.device().qzeros(elem_count, dtype)?;
530 storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
531 Ok(Self {
532 storage,
533 shape: shape.clone(),
534 })
535 }
536
537 pub fn quantize_imatrix_onto(
539 src: &Tensor,
540 imatrix_weights: &[f32],
541 dtype: GgmlDType,
542 dev: &Device,
543 ) -> Result<Self> {
544 if !src.device().is_cpu() {
545 crate::bail!(
546 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
547 src.device()
548 )
549 }
550 let n_per_row = src.dim(D::Minus1)?;
553 if imatrix_weights.len() != n_per_row {
554 crate::bail!(
555 "imatrix weights must have the same length {} as the last dim of src {}",
556 imatrix_weights.len(),
557 src.dim(D::Minus1)?
558 );
559 }
560 let shape = src.shape();
561 let block_size = dtype.block_size();
562 check_shape(shape, block_size)?;
563 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
564 let elem_count = shape.elem_count();
565 if !elem_count.is_multiple_of(block_size) {
566 crate::bail!(
567 "tensor size ({shape:?}) is not divisible by block size {}",
568 block_size
569 )
570 }
571 let mut storage = dev.qzeros(elem_count, dtype)?;
573 storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
574 Ok(Self {
575 storage,
576 shape: shape.clone(),
577 })
578 }
579
580 pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
582 if !src.device().is_cpu() {
583 crate::bail!(
584 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
585 src.device()
586 )
587 }
588 let shape = src.shape();
589 let block_size = dtype.block_size();
590 check_shape(shape, block_size)?;
591 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
592 let elem_count = shape.elem_count();
593 if !elem_count.is_multiple_of(block_size) {
594 crate::bail!(
595 "tensor size ({shape:?}) is not divisible by block size {}",
596 block_size
597 )
598 }
599 let mut storage = dev.qzeros(elem_count, dtype)?;
601 storage.quantize_onto(&src.storage())?;
602 Ok(Self {
603 storage,
604 shape: shape.clone(),
605 })
606 }
607
608 pub fn dtype(&self) -> GgmlDType {
609 self.storage.dtype()
610 }
611
612 pub fn device(&self) -> Device {
613 self.storage.device()
614 }
615
616 pub fn rank(&self) -> usize {
617 self.shape.rank()
618 }
619
620 pub fn shape(&self) -> &Shape {
621 &self.shape
622 }
623
624 pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
625 let storage = self.storage.dequantize(self.shape.elem_count())?;
626 let none = crate::op::BackpropOp::none();
627 crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
628 }
629
630 pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
631 match &self.storage {
634 QStorage::Cuda(s) => {
635 let s = s.dequantize_f16(self.shape.elem_count())?;
636 let none = crate::op::BackpropOp::none();
637 crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
638 .to_device(device)
639 }
640 _ => {
641 let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
642 Ok(s)
643 }
644 }
645 }
646
647 pub fn storage_size_in_bytes(&self) -> usize {
648 self.storage.size_in_bytes()
649 }
650
651 pub fn data(&self) -> Result<Cow<'_, [u8]>> {
652 self.storage.data()
653 }
654
655 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
656 match &self.storage {
657 QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
658 (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
659 let (storage, out_shape) = s.indexed_moe_forward(
660 self.shape(),
661 x_storage,
662 x.layout(),
663 ids_storage,
664 ids.layout(),
665 )?;
666 Ok(crate::tensor::from_storage(
667 Storage::Cuda(storage),
668 out_shape,
669 crate::op::BackpropOp::none(),
670 false,
671 ))
672 }
673 _ => {
674 panic!("Non-cuda indexed_moe_forward is not implemented!");
675 }
676 },
677 _ => {
678 panic!("indexed_moe_forward is not implemented in this platform!");
679 }
680 }
681 }
682
683 pub fn device_ptr(&self) -> Result<*const u8> {
684 match &self.storage {
685 QStorage::Cuda(storage) => storage.device_ptr(),
686 QStorage::Metal(_) | QStorage::Cpu(_) => {
687 crate::bail!("not implemented");
688 }
689 }
690 }
691}
692
/// A matmul weight that is either kept quantized or eagerly dequantized,
/// depending on the dtype and the `CANDLE_DEQUANTIZE_ALL*` env switches
/// (see `QMatMul::from_arc`).
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Weight kept quantized; matmul runs through the `CustomOp1` impl.
    QTensor(std::sync::Arc<QTensor>),
    /// Fully dequantized weight (f32 dequantize path).
    Tensor(Tensor),
    /// Weight dequantized to f16.
    TensorF16(Tensor),
}
699
thread_local! {
    /// True when `CANDLE_DEQUANTIZE_ALL` is set to a non-empty value other
    /// than "0": `QMatMul` then eagerly dequantizes every weight.
    static DEQUANTIZE_ALL: bool =
        std::env::var("CANDLE_DEQUANTIZE_ALL").is_ok_and(|s| !s.is_empty() && s != "0");
}
710
thread_local! {
    /// True when `CANDLE_DEQUANTIZE_ALL_F16` is set to a non-empty value
    /// other than "0": `QMatMul` then eagerly dequantizes weights to f16.
    static DEQUANTIZE_ALL_F16: bool =
        std::env::var("CANDLE_DEQUANTIZE_ALL_F16").is_ok_and(|s| !s.is_empty() && s != "0");
}
721
722impl QMatMul {
723 pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
724 let dequantize = match qtensor.dtype() {
725 GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
726 _ => DEQUANTIZE_ALL.with(|b| *b),
727 };
728 let t = if dequantize {
729 let tensor = qtensor.dequantize(&qtensor.device())?;
730 Self::Tensor(tensor)
731 } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
732 let tensor = qtensor.dequantize_f16(&qtensor.device())?;
733 Self::TensorF16(tensor)
734 } else {
735 Self::QTensor(qtensor)
736 };
737 Ok(t)
738 }
739
740 pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
741 Self::from_arc(std::sync::Arc::new(qtensor))
742 }
743
744 pub fn dequantize_f16(&self) -> Result<Tensor> {
745 match self {
746 Self::QTensor(t) => t.dequantize_f16(&t.device()),
747 Self::Tensor(t) => t.to_dtype(DType::F16),
748 Self::TensorF16(t) => Ok(t.clone()),
749 }
750 }
751
752 pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
753 let w = self.dequantize_f16()?;
754 let in_dtype = xs.dtype();
755 let w = match *xs.dims() {
756 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
757 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
758 _ => w.t()?,
759 };
760 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
761 }
762
763 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
764 match self {
765 Self::QTensor(t) => t.indexed_moe_forward(x, ids),
766 _ => {
767 panic!("Not implemented!")
768 }
769 }
770 }
771}
772
/// Custom op computing `xs @ w^T` where `w` is this quantized tensor of
/// shape `(n, k)` and `xs` has shape `(..., k)`; the output has shape
/// `(..., n)`.
impl crate::CustomOp1 for QTensor {
    fn name(&self) -> &'static str {
        "qmatmul"
    }

    fn cpu_fwd(
        &self,
        storage: &crate::CpuStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CpuStorage, Shape)> {
        if !layout.is_contiguous() {
            crate::bail!("input tensor is not contiguous {layout:?}")
        }
        let src_shape = layout.shape();
        // The weight is used transposed: (n, k) with k the contraction dim.
        let (n, k) = self.shape.dims2()?;
        if src_shape.rank() < 2 {
            crate::bail!("input tensor has only one dimension {layout:?}")
        }
        // Output shape = input shape with the trailing k replaced by n.
        let mut dst_shape = src_shape.dims().to_vec();
        let last_k = dst_shape.pop().unwrap();
        if last_k != k {
            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
        }
        dst_shape.push(n);
        let dst_shape = Shape::from(dst_shape);
        #[allow(clippy::infallible_destructuring_match)]
        let self_storage = match &self.storage {
            QStorage::Cpu(storage) => storage,
            QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
        };
        match storage.dtype() {
            DType::F32 => {
                let slice = storage.as_slice::<f32>()?;
                // Restrict to this tensor's view of the (possibly shared) buffer.
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![0f32; dst_shape.elem_count()];
                // mkn = (m, k, n) with m the number of output rows.
                self_storage.matmul_t(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
            }
            DType::F16 => {
                let slice = storage.as_slice::<f16>()?;
                let slice =
                    &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
                let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
                self_storage.matmul_t_f16(
                    (dst_shape.elem_count() / n, k, n),
                    slice,
                    &mut dst_storage,
                )?;
                Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
            }
            _ => crate::bail!("Expected f32/f16"),
        }
    }

    fn metal_fwd(
        &self,
        storage: &crate::MetalStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::MetalStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Metal(metal) => metal,
            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }

    fn cuda_fwd(
        &self,
        storage: &crate::CudaStorage,
        layout: &crate::Layout,
    ) -> Result<(crate::CudaStorage, Shape)> {
        let self_storage = match &self.storage {
            QStorage::Cuda(cuda) => cuda,
            _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
        };
        self_storage.fwd(&self.shape, storage, layout)
    }
}
857
858impl crate::Module for QMatMul {
859 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
860 match self {
861 Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
862 Self::Tensor(w) => {
863 let w = match *xs.dims() {
864 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
865 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
866 _ => w.t()?,
867 };
868 xs.matmul(&w)
869 }
870 Self::TensorF16(w) => {
871 let in_dtype = xs.dtype();
872 let w = match *xs.dims() {
873 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
874 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
875 _ => w.t()?,
876 };
877 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
878 }
879 }
880 }
881}