1use crate::{
2 backend::BackendStorage, CpuStorage, DType, Device, Result, Shape, Storage, Tensor, D,
3};
4use k_quants::*;
5use std::borrow::Cow;
6
7#[cfg(target_feature = "avx2")]
8pub mod avx;
9mod dummy_cuda;
10mod dummy_metal;
11pub mod ggml_file;
12pub mod gguf_file;
13pub mod imatrix_file;
14pub mod k_quants;
15#[cfg(feature = "metal")]
16pub mod metal;
17#[cfg(not(target_arch = "wasm32"))]
18pub mod tokenizer;
19#[cfg(not(feature = "metal"))]
20mod metal {
21 pub use super::dummy_metal::*;
22}
23#[cfg(feature = "cuda")]
24pub mod cuda;
25#[cfg(feature = "cuda")]
26pub mod fast_mmq;
27#[cfg(feature = "cuda")]
28pub mod fast_mmvq;
29#[cfg(not(feature = "cuda"))]
30mod cuda {
31 pub use super::dummy_cuda::*;
32}
33
34#[cfg(target_feature = "neon")]
35pub mod neon;
36#[cfg(target_feature = "simd128")]
37pub mod simd128;
38pub mod utils;
39use half::{bf16, f16};
40
41pub use k_quants::GgmlType;
42
/// Reinterprets a byte slice as a slice of `T` without copying.
///
/// An empty input always yields an empty slice, regardless of where its (possibly dangling)
/// pointer happens to point.
///
/// # Panics
///
/// Panics if `T` is zero-sized, if `data.len()` is not a multiple of `size_of::<T>()`, or if
/// the start of a non-empty `data` is not aligned for `T`.
fn as_t_slice<T>(data: &[u8]) -> &[T] {
    let size = std::mem::size_of::<T>();
    // A zero-sized `T` would otherwise surface as an opaque remainder-by-zero panic below.
    assert_ne!(size, 0, "T must not be a zero-sized type");
    // Empty buffers (e.g. `Vec::new()`) carry a dangling pointer that is only byte-aligned;
    // there is nothing to reinterpret, so skip the alignment check for them.
    if data.is_empty() {
        return &[];
    }
    assert_eq!(
        data.len() % size,
        0,
        "Data length must be a multiple of T's size"
    );
    let ptr = data.as_ptr();
    assert_eq!(
        (ptr as usize) % std::mem::align_of::<T>(),
        0,
        "Data pointer must be aligned to T's alignment"
    );
    // SAFETY: `ptr` comes from a live slice, so it is non-null and valid for `data.len()` bytes
    // for the returned lifetime; alignment and an exact element count were checked above.
    // The bytes are assumed to form valid `T` values (callers pass plain numeric/block types).
    unsafe { std::slice::from_raw_parts(ptr as *const T, data.len() / size) }
}
61
/// A tensor whose data is kept in a quantized [`QStorage`] together with its logical shape.
pub struct QTensor {
    /// Backend-specific buffer holding the quantized data.
    storage: QStorage,
    /// Logical shape of the tensor (the element count used when dequantizing).
    shape: Shape,
}
66
67impl Device {
68 fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
69 match self {
70 Device::Cpu => {
71 let storage = dtype.cpu_zeros(elem_count);
72 Ok(QStorage::Cpu(storage))
73 }
74 Device::Metal(metal) => {
75 let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
76 Ok(QStorage::Metal(storage))
77 }
78 Device::Cuda(cuda) => {
79 let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
80 Ok(QStorage::Cuda(storage))
81 }
82 #[cfg(feature = "rocm")]
83 Device::Rocm(_) => crate::bail!("quantized tensors on rocm are not supported yet"),
84 #[cfg(feature = "vulkan")]
85 Device::Vulkan(d) => {
86 let storage = dtype.cpu_zeros(elem_count);
89 Ok(QStorage::Vulkan(storage, d.clone()))
90 }
91 }
92 }
93}
94
/// Backend-specific container for quantized tensor data.
pub enum QStorage {
    /// Quantized blocks held behind a type-erased [`QuantizedType`].
    Cpu(Box<dyn QuantizedType>),
    /// Quantized data managed by the metal backend (a dummy type when the feature is off).
    Metal(metal::QMetalStorage),
    /// Quantized data managed by the cuda backend (a dummy type when the feature is off).
    Cuda(cuda::QCudaStorage),
    /// Quantized blocks held behind a [`QuantizedType`], paired with the Vulkan device they
    /// are associated with.
    #[cfg(feature = "vulkan")]
    Vulkan(Box<dyn QuantizedType>, crate::VulkanDevice),
}
105
106impl QStorage {
107 pub fn from_data(data: Cow<'_, [u8]>, device: &Device, dtype: GgmlDType) -> Result<Self> {
108 match device {
109 Device::Cpu => Ok(Self::Cpu(dtype.from_data(data))),
110 Device::Metal(d) => match dtype {
111 GgmlDType::F32 => metal::load_quantized(d, as_t_slice::<f32>(&data)),
112 GgmlDType::F16 => metal::load_quantized(d, as_t_slice::<f16>(&data)),
113 GgmlDType::Q4_0 => metal::load_quantized(d, as_t_slice::<BlockQ4_0>(&data)),
114 GgmlDType::Q4_1 => metal::load_quantized(d, as_t_slice::<BlockQ4_1>(&data)),
115 GgmlDType::Q5_0 => metal::load_quantized(d, as_t_slice::<BlockQ5_0>(&data)),
116 GgmlDType::Q5_1 => metal::load_quantized(d, as_t_slice::<BlockQ5_1>(&data)),
117 GgmlDType::Q8_0 => metal::load_quantized(d, as_t_slice::<BlockQ8_0>(&data)),
118 GgmlDType::Q8_1 => metal::load_quantized(d, as_t_slice::<BlockQ8_1>(&data)),
119 GgmlDType::Q2K => metal::load_quantized(d, as_t_slice::<BlockQ2K>(&data)),
120 GgmlDType::Q3K => metal::load_quantized(d, as_t_slice::<BlockQ3K>(&data)),
121 GgmlDType::Q4K => metal::load_quantized(d, as_t_slice::<BlockQ4K>(&data)),
122 GgmlDType::Q5K => metal::load_quantized(d, as_t_slice::<BlockQ5K>(&data)),
123 GgmlDType::Q6K => metal::load_quantized(d, as_t_slice::<BlockQ6K>(&data)),
124 GgmlDType::Q8K => metal::load_quantized(d, as_t_slice::<BlockQ8K>(&data)),
125 GgmlDType::BF16 => metal::load_quantized(d, as_t_slice::<bf16>(&data)),
126 },
127 Device::Cuda(d) => match dtype {
128 GgmlDType::F32 => cuda::load_quantized(d, as_t_slice::<f32>(&data)),
129 GgmlDType::F16 => cuda::load_quantized(d, as_t_slice::<f16>(&data)),
130 GgmlDType::Q4_0 => cuda::load_quantized(d, as_t_slice::<BlockQ4_0>(&data)),
131 GgmlDType::Q4_1 => cuda::load_quantized(d, as_t_slice::<BlockQ4_1>(&data)),
132 GgmlDType::Q5_0 => cuda::load_quantized(d, as_t_slice::<BlockQ5_0>(&data)),
133 GgmlDType::Q5_1 => cuda::load_quantized(d, as_t_slice::<BlockQ5_1>(&data)),
134 GgmlDType::Q8_0 => cuda::load_quantized(d, as_t_slice::<BlockQ8_0>(&data)),
135 GgmlDType::Q8_1 => cuda::load_quantized(d, as_t_slice::<BlockQ8_1>(&data)),
136 GgmlDType::Q2K => cuda::load_quantized(d, as_t_slice::<BlockQ2K>(&data)),
137 GgmlDType::Q3K => cuda::load_quantized(d, as_t_slice::<BlockQ3K>(&data)),
138 GgmlDType::Q4K => cuda::load_quantized(d, as_t_slice::<BlockQ4K>(&data)),
139 GgmlDType::Q5K => cuda::load_quantized(d, as_t_slice::<BlockQ5K>(&data)),
140 GgmlDType::Q6K => cuda::load_quantized(d, as_t_slice::<BlockQ6K>(&data)),
141 GgmlDType::Q8K => cuda::load_quantized(d, as_t_slice::<BlockQ8K>(&data)),
142 GgmlDType::BF16 => cuda::load_quantized(d, as_t_slice::<bf16>(&data)),
143 },
144 #[cfg(feature = "rocm")]
145 Device::Rocm(_) => crate::bail!("quantized tensors on rocm are not supported yet"),
146 #[cfg(feature = "vulkan")]
147 Device::Vulkan(d) => Ok(Self::Vulkan(dtype.from_data(data), d.clone())),
148 }
149 }
150
151 fn block_size(&self) -> usize {
152 match self {
153 QStorage::Cpu(storage) => storage.block_size(),
154 QStorage::Metal(storage) => storage.dtype().block_size(),
155 QStorage::Cuda(storage) => storage.dtype().block_size(),
156 #[cfg(feature = "vulkan")]
157 QStorage::Vulkan(storage, _) => storage.block_size(),
158 }
159 }
160
161 fn dtype(&self) -> GgmlDType {
162 match self {
163 QStorage::Cpu(storage) => storage.dtype(),
164 QStorage::Metal(storage) => storage.dtype(),
165 QStorage::Cuda(storage) => storage.dtype(),
166 #[cfg(feature = "vulkan")]
167 QStorage::Vulkan(storage, _) => storage.dtype(),
168 }
169 }
170
171 fn device(&self) -> Device {
172 match self {
173 QStorage::Cpu(_storage) => Device::Cpu,
174 QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
175 QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
176 #[cfg(feature = "vulkan")]
177 QStorage::Vulkan(_storage, device) => Device::Vulkan(device.clone()),
178 }
179 }
180
181 fn size_in_bytes(&self) -> usize {
182 match self {
183 QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
184 QStorage::Metal(storage) => storage.storage_size_in_bytes(),
185 QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
186 #[cfg(feature = "vulkan")]
187 QStorage::Vulkan(storage, _) => storage.storage_size_in_bytes(),
188 }
189 }
190
191 fn quantize(&mut self, src: &Storage) -> Result<()> {
192 match (self, src) {
193 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
194 storage.from_float(src.as_slice::<f32>()?);
195 }
196 (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
197 (QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
198 _ => crate::bail!("Invalid quantize storage locations do not match"),
199 }
200 Ok(())
201 }
202
203 fn quantize_imatrix(
204 &mut self,
205 src: &Storage,
206 imatrix_weights: &[f32],
207 n_per_row: usize,
208 ) -> Result<()> {
209 match (self, src) {
210 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
211 storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
212 }
213 (QStorage::Metal(storage), Storage::Metal(src)) => {
214 storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
215 }
216 (QStorage::Cuda(storage), Storage::Cuda(src)) => {
217 storage.quantize_imatrix(src, imatrix_weights, n_per_row)?
218 }
219 _ => crate::bail!("Invalid quantize storage locations do not match"),
220 }
221 Ok(())
222 }
223
224 fn quantize_onto(&mut self, src: &Storage) -> Result<()> {
225 match (self, src) {
226 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
227 storage.from_float(src.as_slice::<f32>()?);
228 }
229 (QStorage::Metal(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
230 (QStorage::Cuda(storage), Storage::Cpu(src)) => storage.quantize_onto(src)?,
231 _ => crate::bail!("Invalid quantize source storage locations: not on cpu"),
232 }
233 Ok(())
234 }
235
236 fn quantize_imatrix_onto(
237 &mut self,
238 src: &Storage,
239 imatrix_weights: &[f32],
240 n_per_row: usize,
241 ) -> Result<()> {
242 match (self, src) {
243 (QStorage::Cpu(storage), Storage::Cpu(src)) => {
244 storage.from_float_imatrix(src.as_slice::<f32>()?, imatrix_weights, n_per_row);
245 }
246 (QStorage::Metal(storage), Storage::Cpu(src)) => {
247 storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
248 }
249 (QStorage::Cuda(storage), Storage::Cpu(src)) => {
250 storage.quantize_imatrix_onto(src, imatrix_weights, n_per_row)?
251 }
252 _ => crate::bail!("Invalid quantize storage locations do not match"),
253 }
254 Ok(())
255 }
256
257 fn dequantize(&self, elem_count: usize) -> Result<Storage> {
258 match self {
259 QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
260 QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
261 QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
262 #[cfg(feature = "vulkan")]
263 QStorage::Vulkan(storage, device) => {
264 let cpu = storage.dequantize(elem_count)?;
266 Ok(Storage::Vulkan(device.upload_f32(cpu.as_slice::<f32>()?)?))
267 }
268 }
269 }
270
271 fn data(&self) -> Result<Cow<'_, [u8]>> {
272 match self {
273 QStorage::Cpu(storage) => {
274 let data_ptr = storage.as_ptr();
275 let size_in_bytes = storage.storage_size_in_bytes();
276 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
277 Ok(Cow::from(data))
278 }
279 QStorage::Cuda(storage) => Ok(Cow::from(storage.data()?)),
280 QStorage::Metal(storage) => Ok(Cow::from(storage.data()?)),
281 #[cfg(feature = "vulkan")]
282 QStorage::Vulkan(storage, _) => {
283 let data_ptr = storage.as_ptr();
284 let size_in_bytes = storage.storage_size_in_bytes();
285 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
286 Ok(Cow::from(data))
287 }
288 }
289 }
290
291 pub fn device_ptr(&self) -> Result<*const u8> {
292 match self {
293 QStorage::Cuda(storage) => storage.device_ptr(),
294 #[cfg(feature = "vulkan")]
295 QStorage::Vulkan(..) => crate::bail!("not implemented"),
296 QStorage::Metal(_) | QStorage::Cpu(_) => {
297 crate::bail!("not implemented");
298 }
299 }
300 }
301
302 #[cfg(feature = "cuda")]
303 pub fn device_ptr_with_guard<'a>(
304 &'a self,
305 stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
306 ) -> Result<(
307 *const u8,
308 crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
309 )> {
310 match self {
311 QStorage::Cuda(storage) => storage.device_ptr_with_guard(stream),
312 QStorage::Metal(_) | QStorage::Cpu(_) => {
313 crate::bail!("not implemented");
314 }
315 }
316 }
317}
318
/// Element/block formats supported for quantized tensors.
///
/// The numeric ids used when (de)serializing are defined by `from_u32` / `to_u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GgmlDType {
    // Plain floating point formats (block size 1).
    F32,
    F16,
    BF16,
    // Block formats backed by `BlockQ4_0` .. `BlockQ8_1`.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // Block formats backed by `BlockQ2K` .. `BlockQ8K`; all share the `QK_K` block size.
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
}
337
338impl GgmlDType {
339 pub(crate) fn from_u32(u: u32) -> Result<Self> {
340 let dtype = match u {
341 0 => Self::F32,
342 1 => Self::F16,
343 2 => Self::Q4_0,
344 3 => Self::Q4_1,
345 6 => Self::Q5_0,
346 7 => Self::Q5_1,
347 8 => Self::Q8_0,
348 9 => Self::Q8_1,
349 10 => Self::Q2K,
350 11 => Self::Q3K,
351 12 => Self::Q4K,
352 13 => Self::Q5K,
353 14 => Self::Q6K,
354 15 => Self::Q8K,
355 30 => Self::BF16,
357 _ => crate::bail!("unknown dtype for tensor {u}"),
358 };
359 Ok(dtype)
360 }
361
362 pub(crate) fn to_u32(self) -> u32 {
363 match self {
364 Self::F32 => 0,
365 Self::F16 => 1,
366 Self::Q4_0 => 2,
367 Self::Q4_1 => 3,
368 Self::Q5_0 => 6,
369 Self::Q5_1 => 7,
370 Self::Q8_0 => 8,
371 Self::Q8_1 => 9,
372 Self::Q2K => 10,
373 Self::Q3K => 11,
374 Self::Q4K => 12,
375 Self::Q5K => 13,
376 Self::Q6K => 14,
377 Self::Q8K => 15,
378 Self::BF16 => 30,
380 }
381 }
382
383 pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
385 match self {
386 Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
387 Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
388 Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
389 Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
390 Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
391 Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
392 Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
393 Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
394 Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
395 Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
396 Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
397 Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
398 Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
399 Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
400 Self::BF16 => Box::new(vec![bf16::zeros(); elem_count]),
401 }
402 }
403
404 pub fn from_data(&self, data: Cow<'_, [u8]>) -> Box<dyn QuantizedType> {
405 match self {
406 Self::F32 => Box::new(as_t_slice::<f32>(&data).to_vec()),
407 Self::F16 => Box::new(as_t_slice::<f16>(&data).to_vec()),
408 Self::Q4_0 => Box::new(as_t_slice::<BlockQ4_0>(&data).to_vec()),
409 Self::Q4_1 => Box::new(as_t_slice::<BlockQ4_1>(&data).to_vec()),
410 Self::Q5_0 => Box::new(as_t_slice::<BlockQ5_0>(&data).to_vec()),
411 Self::Q5_1 => Box::new(as_t_slice::<BlockQ5_1>(&data).to_vec()),
412 Self::Q8_0 => Box::new(as_t_slice::<BlockQ8_0>(&data).to_vec()),
413 Self::Q8_1 => Box::new(as_t_slice::<BlockQ8_1>(&data).to_vec()),
414 Self::Q2K => Box::new(as_t_slice::<BlockQ2K>(&data).to_vec()),
415 Self::Q3K => Box::new(as_t_slice::<BlockQ3K>(&data).to_vec()),
416 Self::Q4K => Box::new(as_t_slice::<BlockQ4K>(&data).to_vec()),
417 Self::Q5K => Box::new(as_t_slice::<BlockQ5K>(&data).to_vec()),
418 Self::Q6K => Box::new(as_t_slice::<BlockQ6K>(&data).to_vec()),
419 Self::Q8K => Box::new(as_t_slice::<BlockQ8K>(&data).to_vec()),
420 Self::BF16 => Box::new(as_t_slice::<bf16>(&data).to_vec()),
421 }
422 }
423
424 pub fn type_size(&self) -> usize {
426 use k_quants::*;
427 match self {
428 Self::F32 => 4,
429 Self::F16 | Self::BF16 => 2,
430 Self::Q4_0 => std::mem::size_of::<BlockQ4_0>(),
431 Self::Q4_1 => std::mem::size_of::<BlockQ4_1>(),
432 Self::Q5_0 => std::mem::size_of::<BlockQ5_0>(),
433 Self::Q5_1 => std::mem::size_of::<BlockQ5_1>(),
434 Self::Q8_0 => std::mem::size_of::<BlockQ8_0>(),
436 Self::Q8_1 => std::mem::size_of::<BlockQ8_1>(),
437 Self::Q2K => std::mem::size_of::<BlockQ2K>(),
438 Self::Q3K => std::mem::size_of::<BlockQ3K>(),
439 Self::Q4K => std::mem::size_of::<BlockQ4K>(),
440 Self::Q5K => std::mem::size_of::<BlockQ5K>(),
441 Self::Q6K => std::mem::size_of::<BlockQ6K>(),
442 Self::Q8K => std::mem::size_of::<BlockQ8K>(),
443 }
444 }
445
446 pub fn block_size(&self) -> usize {
448 match self {
449 Self::F32 => 1,
450 Self::F16 | Self::BF16 => 1,
451 Self::Q4_0 => k_quants::QK4_0,
452 Self::Q4_1 => k_quants::QK4_1,
453 Self::Q5_0 => k_quants::QK5_0,
454 Self::Q5_1 => k_quants::QK5_1,
455 Self::Q8_0 => k_quants::QK8_0,
456 Self::Q8_1 => k_quants::QK8_1,
457 Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K,
458 }
459 }
460}
461
/// Type-erased interface over host-side quantized storage (implemented for `Vec<T>` of
/// ggml block types).
pub trait QuantizedType: Send + Sync {
    /// The ggml dtype of the stored data.
    fn dtype(&self) -> GgmlDType;
    /// Matrix product of the `f32` `lhs` with this data, written into `dst`.
    ///
    /// As called from `QTensor::cpu_fwd`, `mkn` is `(rows of lhs, k, n)` for a weight of
    /// shape `(n, k)`.
    fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
    /// Same as [`QuantizedType::matmul_t`] but with `f16` input and output.
    fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()>;
    /// Expands the data into `elem_count` values of regular CPU storage.
    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
    /// Size of the stored data in bytes.
    fn storage_size_in_bytes(&self) -> usize;
    /// Pointer to the first byte of the stored data.
    fn as_ptr(&self) -> *const u8;
    /// Number of elements per quantization block.
    fn block_size(&self) -> usize;
    /// Overwrites the stored data with the quantized form of `xs`.
    #[allow(clippy::wrong_self_convention)]
    fn from_float(&mut self, xs: &[f32]);
    /// Like [`QuantizedType::from_float`], additionally passing importance-matrix weights
    /// and the row length to the quantizer.
    #[allow(clippy::wrong_self_convention)]
    fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize);
    /// Size of the stored data in bytes (the `Vec<T>` impl computes it exactly like
    /// `storage_size_in_bytes`).
    fn size(&self) -> usize;
}
477
478impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
479 fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
480 k_quants::matmul(mkn, lhs, self.as_slice(), dst)
481 }
482 fn matmul_t_f16(&self, mkn: (usize, usize, usize), lhs: &[f16], dst: &mut [f16]) -> Result<()> {
483 k_quants::matmul_f16(mkn, lhs, self.as_slice(), dst)
484 }
485
486 fn size(&self) -> usize {
487 self.len() * core::mem::size_of::<T>()
488 }
489
490 fn from_float(&mut self, xs: &[f32]) {
491 T::from_float(xs, self)
492 }
493
494 fn from_float_imatrix(&mut self, xs: &[f32], imatrix_weights: &[f32], n_per_row: usize) {
495 T::from_float_imatrix(xs, self, imatrix_weights, n_per_row)
496 }
497
498 fn dtype(&self) -> GgmlDType {
499 T::DTYPE
500 }
501
502 fn block_size(&self) -> usize {
503 T::BLCK_SIZE
504 }
505
506 fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
507 let mut ys = vec![0.0f32; elem_count];
508 T::to_float(self.as_slice(), &mut ys);
509 Ok(CpuStorage::F32(ys))
510 }
511
512 fn storage_size_in_bytes(&self) -> usize {
513 self.len() * std::mem::size_of::<T>()
514 }
515
516 fn as_ptr(&self) -> *const u8 {
517 self.as_ptr() as *const u8
518 }
519}
520
521impl std::fmt::Debug for QTensor {
522 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
523 write!(f, "QTensor[{:?}; {:?}]", self.shape, self.dtype())
524 }
525}
526
527fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
528 let dims = shape.dims();
529 if dims.is_empty() {
530 crate::bail!("scalar tensor cannot be quantized {shape:?}")
531 }
532 if !dims[dims.len() - 1].is_multiple_of(block_size) {
533 crate::bail!(
534 "quantized tensor must have their last dim divisible by block size {shape:?} {}",
535 block_size
536 )
537 }
538 Ok(())
539}
540
541impl QTensor {
542 pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
543 let shape = shape.into();
544 check_shape(&shape, storage.block_size())?;
545 Ok(Self { storage, shape })
546 }
547
548 pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
549 let shape = src.shape();
550 let block_size = dtype.block_size();
551 check_shape(shape, block_size)?;
552 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
553 let elem_count = shape.elem_count();
554 if !elem_count.is_multiple_of(block_size) {
555 crate::bail!(
556 "tensor size ({shape:?}) is not divisible by block size {}",
557 block_size
558 )
559 }
560 let mut storage = src.device().qzeros(elem_count, dtype)?;
561 storage.quantize(&src.storage())?;
562 Ok(Self {
563 storage,
564 shape: shape.clone(),
565 })
566 }
567
568 pub fn quantize_imatrix(
569 src: &Tensor,
570 imatrix_weights: &[f32],
571 dtype: GgmlDType,
572 ) -> Result<Self> {
573 let n_per_row = src.dim(D::Minus1)?;
576 if imatrix_weights.len() != n_per_row {
577 crate::bail!(
578 "imatrix weights must have the same length {} as the last dim of src {}",
579 imatrix_weights.len(),
580 src.dim(D::Minus1)?
581 );
582 }
583
584 let shape = src.shape();
585 let block_size = dtype.block_size();
586 check_shape(shape, block_size)?;
587 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
588 let elem_count = shape.elem_count();
589 if !elem_count.is_multiple_of(block_size) {
590 crate::bail!(
591 "tensor size ({shape:?}) is not divisible by block size {}",
592 block_size
593 );
594 }
595 let mut storage = src.device().qzeros(elem_count, dtype)?;
596 storage.quantize_imatrix(&src.storage(), imatrix_weights, n_per_row)?;
597 Ok(Self {
598 storage,
599 shape: shape.clone(),
600 })
601 }
602
603 pub fn quantize_imatrix_onto(
605 src: &Tensor,
606 imatrix_weights: &[f32],
607 dtype: GgmlDType,
608 dev: &Device,
609 ) -> Result<Self> {
610 if !src.device().is_cpu() {
611 crate::bail!(
612 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
613 src.device()
614 )
615 }
616 let n_per_row = src.dim(D::Minus1)?;
619 if imatrix_weights.len() != n_per_row {
620 crate::bail!(
621 "imatrix weights must have the same length {} as the last dim of src {}",
622 imatrix_weights.len(),
623 src.dim(D::Minus1)?
624 );
625 }
626 let shape = src.shape();
627 let block_size = dtype.block_size();
628 check_shape(shape, block_size)?;
629 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
630 let elem_count = shape.elem_count();
631 if !elem_count.is_multiple_of(block_size) {
632 crate::bail!(
633 "tensor size ({shape:?}) is not divisible by block size {}",
634 block_size
635 )
636 }
637 let mut storage = dev.qzeros(elem_count, dtype)?;
639 storage.quantize_imatrix_onto(&src.storage(), imatrix_weights, n_per_row)?;
640 Ok(Self {
641 storage,
642 shape: shape.clone(),
643 })
644 }
645
646 pub fn quantize_onto(src: &Tensor, dtype: GgmlDType, dev: &Device) -> Result<Self> {
648 if !src.device().is_cpu() {
649 crate::bail!(
650 "`quantize_onto` expects a `src` to be on the cpu, got {:?}.",
651 src.device()
652 )
653 }
654 let shape = src.shape();
655 let block_size = dtype.block_size();
656 check_shape(shape, block_size)?;
657 let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
658 let elem_count = shape.elem_count();
659 if !elem_count.is_multiple_of(block_size) {
660 crate::bail!(
661 "tensor size ({shape:?}) is not divisible by block size {}",
662 block_size
663 )
664 }
665 let mut storage = dev.qzeros(elem_count, dtype)?;
667 storage.quantize_onto(&src.storage())?;
668 Ok(Self {
669 storage,
670 shape: shape.clone(),
671 })
672 }
673
    /// The ggml dtype of the underlying quantized storage.
    pub fn dtype(&self) -> GgmlDType {
        self.storage.dtype()
    }
677
    /// The device the underlying quantized storage is associated with.
    pub fn device(&self) -> Device {
        self.storage.device()
    }
681
    /// Number of dimensions of the tensor's shape.
    pub fn rank(&self) -> usize {
        self.shape.rank()
    }
685
    /// The logical shape of the tensor.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
689
690 pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
691 let storage = self.storage.dequantize(self.shape.elem_count())?;
692 let none = crate::op::BackpropOp::none();
693 crate::tensor::from_storage(storage, self.shape.clone(), none, false).to_device(device)
694 }
695
696 pub fn dequantize_f16(&self, device: &Device) -> Result<Tensor> {
697 match &self.storage {
700 QStorage::Cuda(s) => {
701 let s = s.dequantize_f16(self.shape.elem_count())?;
702 let none = crate::op::BackpropOp::none();
703 crate::tensor::from_storage(Storage::Cuda(s), self.shape.clone(), none, false)
704 .to_device(device)
705 }
706 _ => {
707 let s = self.dequantize(device)?.to_dtype(crate::DType::F16)?;
708 Ok(s)
709 }
710 }
711 }
712
    /// Size in bytes of the quantized data, as reported by the storage.
    pub fn storage_size_in_bytes(&self) -> usize {
        self.storage.size_in_bytes()
    }
716
    /// The raw quantized bytes of the tensor, as returned by the storage's `data()`.
    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
        self.storage.data()
    }
720
721 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
722 match &self.storage {
723 QStorage::Cuda(s) => match (&*x.storage(), &*ids.storage()) {
724 (Storage::Cuda(x_storage), Storage::Cuda(ids_storage)) => {
725 let (storage, out_shape) = s.indexed_moe_forward(
726 self.shape(),
727 x_storage,
728 x.layout(),
729 ids_storage,
730 ids.layout(),
731 )?;
732 Ok(crate::tensor::from_storage(
733 Storage::Cuda(storage),
734 out_shape,
735 crate::op::BackpropOp::none(),
736 false,
737 ))
738 }
739 _ => {
740 panic!("Non-cuda indexed_moe_forward is not implemented!");
741 }
742 },
743 _ => {
744 use crate::Module; use std::collections::HashMap;
751 use std::sync::Arc;
752 let device = x.device();
753 let out_dtype = x.dtype();
754 let (e_cnt, n, k) = self.shape().dims3()?;
755 let (t, topk) = ids.dims2()?;
756 let s = x.dim(1)?; let x_exp = if s == topk {
758 x.clone()
759 } else {
760 x.broadcast_as((t, topk, k))?
761 };
762 let x_flat = x_exp
763 .reshape((t * topk, k))?
764 .to_dtype(crate::DType::F32)?
765 .contiguous()?;
766 let ids_flat = ids.reshape((t * topk,))?.to_dtype(crate::DType::U32)?;
767 let ids_vec = ids_flat.to_vec1::<u32>()?;
768 let mut groups: HashMap<u32, Vec<u32>> = HashMap::new();
769 for (slot, eid) in ids_vec.iter().enumerate() {
770 groups.entry(*eid).or_default().push(slot as u32);
771 }
772 let dtype = self.storage.dtype();
773 let all_bytes = self.data()?;
774 let expert_bytes = all_bytes.len() / e_cnt;
775 let mut out_flat = Tensor::zeros((t * topk, n), crate::DType::F32, device)?;
776 for (eid, slots) in groups.into_iter() {
777 let off = eid as usize * expert_bytes;
778 let qs = QStorage::from_data(
779 std::borrow::Cow::Borrowed(&all_bytes[off..off + expert_bytes]),
780 device,
781 dtype,
782 )?;
783 let shape: crate::Shape = (n, k).into();
784 let w_e = QTensor { storage: qs, shape };
785 let qm = QMatMul::from_arc(Arc::new(w_e))?;
786 let m = slots.len();
787 let idx = Tensor::from_vec(slots, (m,), device)?;
788 let x_e = x_flat.index_select(&idx, 0)?; let y_e = qm.forward(&x_e)?.to_dtype(crate::DType::F32)?; out_flat = out_flat.index_add(&idx, &y_e, 0)?;
791 }
792 out_flat.reshape((t, topk, n))?.to_dtype(out_dtype)
793 }
794 }
795 }
796
797 pub fn device_ptr(&self) -> Result<*const u8> {
798 match &self.storage {
799 QStorage::Cuda(storage) => storage.device_ptr(),
800 #[cfg(feature = "vulkan")]
801 QStorage::Vulkan(..) => crate::bail!("not implemented"),
802 QStorage::Metal(_) | QStorage::Cpu(_) => {
803 crate::bail!("not implemented");
804 }
805 }
806 }
807
    /// Returns the raw pointer to the quantized data together with a guard tied to
    /// `stream`; forwards to `QStorage::device_ptr_with_guard`.
    #[cfg(feature = "cuda")]
    pub fn device_ptr_with_guard<'a>(
        &'a self,
        stream: &'a crate::cuda_backend::cudarc::driver::CudaStream,
    ) -> Result<(
        *const u8,
        crate::cuda_backend::cudarc::driver::SyncOnDrop<'a>,
    )> {
        self.storage.device_ptr_with_guard(stream)
    }
818}
819
/// A matmul layer whose weights may be kept quantized or stored dequantized.
#[derive(Clone, Debug)]
pub enum QMatMul {
    /// Weights stay quantized; `forward` applies the `QTensor` custom op.
    QTensor(std::sync::Arc<QTensor>),
    /// Weights dequantized up front; `forward` uses a regular matmul.
    Tensor(Tensor),
    /// Weights dequantized to `f16`; `forward` casts the input to `f16` and the result back.
    TensorF16(Tensor),
    /// Q8_0 weights on a Vulkan device, built by `from_arc`.
    #[cfg(feature = "vulkan")]
    VulkanQuant {
        /// The original quantized tensor.
        qtensor: std::sync::Arc<QTensor>,
        /// Device-side weights produced by the device's `quantize_q8`.
        wq: std::sync::Arc<crate::VulkanStorage>,
        /// First dimension of the weight (number of output features).
        n: usize,
        /// Second dimension of the weight (number of input features).
        k: usize,
    },
}
836
thread_local! {
    /// True when `CANDLE_DEQUANTIZE_ALL` is set to anything other than the empty string
    /// or `"0"`; read once per thread on first use.
    static DEQUANTIZE_ALL: bool = std::env::var("CANDLE_DEQUANTIZE_ALL")
        .map(|value| !(value.is_empty() || value == "0"))
        .unwrap_or(false);
}
847
thread_local! {
    /// True when `CANDLE_DEQUANTIZE_ALL_F16` is set to anything other than the empty
    /// string or `"0"`; read once per thread on first use.
    static DEQUANTIZE_ALL_F16: bool = std::env::var("CANDLE_DEQUANTIZE_ALL_F16")
        .map(|value| !(value.is_empty() || value == "0"))
        .unwrap_or(false);
}
858
859impl QMatMul {
860 pub fn from_arc(qtensor: std::sync::Arc<QTensor>) -> Result<Self> {
861 #[cfg(feature = "vulkan")]
864 if qtensor.dtype() == GgmlDType::Q8_0 {
865 if let Device::Vulkan(d) = qtensor.device() {
866 if let Ok((n, k)) = qtensor.shape().dims2() {
867 if k % 32 == 0 {
868 let wf = qtensor
869 .dequantize(&Device::Cpu)?
870 .flatten_all()?
871 .to_vec1::<f32>()?;
872 let wq = d.quantize_q8(&wf, n, k)?;
873 return Ok(Self::VulkanQuant {
874 qtensor,
875 wq: std::sync::Arc::new(wq),
876 n,
877 k,
878 });
879 }
880 }
881 }
882 }
883 let dequantize = match qtensor.dtype() {
884 GgmlDType::F32 | GgmlDType::F16 | GgmlDType::BF16 => true,
885 _ => DEQUANTIZE_ALL.with(|b| *b) || qtensor.device().is_vulkan(),
888 };
889 let t = if dequantize {
890 let tensor = qtensor.dequantize(&qtensor.device())?;
891 Self::Tensor(tensor)
892 } else if DEQUANTIZE_ALL_F16.with(|b| *b) {
893 let tensor = qtensor.dequantize_f16(&qtensor.device())?;
894 Self::TensorF16(tensor)
895 } else {
896 Self::QTensor(qtensor)
897 };
898 Ok(t)
899 }
900
    /// Convenience wrapper: puts `qtensor` in an `Arc` and calls [`QMatMul::from_arc`].
    pub fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        Self::from_arc(std::sync::Arc::new(qtensor))
    }
904
905 pub fn dequantize_f16(&self) -> Result<Tensor> {
906 match self {
907 Self::QTensor(t) => t.dequantize_f16(&t.device()),
908 Self::Tensor(t) => t.to_dtype(DType::F16),
909 Self::TensorF16(t) => Ok(t.clone()),
910 #[cfg(feature = "vulkan")]
911 Self::VulkanQuant { qtensor, .. } => qtensor.dequantize_f16(&qtensor.device()),
912 }
913 }
914
915 pub fn forward_via_f16(&self, xs: &Tensor) -> Result<Tensor> {
916 let w = self.dequantize_f16()?;
917 let in_dtype = xs.dtype();
918 let w = match *xs.dims() {
919 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
920 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
921 _ => w.t()?,
922 };
923 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
924 }
925
926 pub fn indexed_moe_forward(&self, x: &Tensor, ids: &Tensor) -> Result<Tensor> {
927 match self {
928 Self::QTensor(t) => t.indexed_moe_forward(x, ids),
929 #[cfg(feature = "vulkan")]
930 Self::VulkanQuant { qtensor, .. } => qtensor.indexed_moe_forward(x, ids),
931 _ => {
932 panic!("Not implemented!")
933 }
934 }
935 }
936}
937
938impl crate::CustomOp1 for QTensor {
    /// Name under which this custom op is reported.
    fn name(&self) -> &'static str {
        "qmatmul"
    }
942
943 fn cpu_fwd(
944 &self,
945 storage: &crate::CpuStorage,
946 layout: &crate::Layout,
947 ) -> Result<(crate::CpuStorage, Shape)> {
948 if !layout.is_contiguous() {
949 crate::bail!("input tensor is not contiguous {layout:?}")
950 }
951 let src_shape = layout.shape();
952 let (n, k) = self.shape.dims2()?;
954 if src_shape.rank() < 2 {
955 crate::bail!("input tensor has only one dimension {layout:?}")
956 }
957 let mut dst_shape = src_shape.dims().to_vec();
958 let last_k = dst_shape.pop().unwrap();
959 if last_k != k {
960 crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
961 }
962 dst_shape.push(n);
963 let dst_shape = Shape::from(dst_shape);
964 #[allow(clippy::infallible_destructuring_match)]
965 let self_storage = match &self.storage {
966 QStorage::Cpu(storage) => storage,
967 #[cfg(feature = "vulkan")]
968 QStorage::Vulkan(..) => crate::bail!("Invalid storage"),
969 QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
970 };
971 match storage.dtype() {
972 DType::F32 => {
973 let slice = storage.as_slice::<f32>()?;
974 let slice =
975 &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
976 let mut dst_storage = vec![0f32; dst_shape.elem_count()];
977 self_storage.matmul_t(
978 (dst_shape.elem_count() / n, k, n),
979 slice,
980 &mut dst_storage,
981 )?;
982 Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
983 }
984 DType::F16 => {
985 let slice = storage.as_slice::<f16>()?;
986 let slice =
987 &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
988 let mut dst_storage = vec![f16::ZERO; dst_shape.elem_count()];
989 self_storage.matmul_t_f16(
990 (dst_shape.elem_count() / n, k, n),
991 slice,
992 &mut dst_storage,
993 )?;
994 Ok((crate::CpuStorage::F16(dst_storage), dst_shape))
995 }
996 _ => crate::bail!("Expected f32/f16"),
997 }
998 }
999
1000 fn metal_fwd(
1001 &self,
1002 storage: &crate::MetalStorage,
1003 layout: &crate::Layout,
1004 ) -> Result<(crate::MetalStorage, Shape)> {
1005 let self_storage = match &self.storage {
1006 QStorage::Metal(metal) => metal,
1007 _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
1008 };
1009 self_storage.fwd(&self.shape, storage, layout)
1010 }
1011
1012 fn cuda_fwd(
1013 &self,
1014 storage: &crate::CudaStorage,
1015 layout: &crate::Layout,
1016 ) -> Result<(crate::CudaStorage, Shape)> {
1017 let self_storage = match &self.storage {
1018 QStorage::Cuda(cuda) => cuda,
1019 _ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
1020 };
1021 self_storage.fwd(&self.shape, storage, layout)
1022 }
1023}
1024
1025impl crate::Module for QMatMul {
1026 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1027 match self {
1028 #[cfg(feature = "vulkan")]
1029 Self::VulkanQuant { qtensor, wq, n, k } => {
1030 let rows: usize = xs.elem_count() / *k;
1031 if rows == 1 {
1032 let xs = xs.contiguous()?;
1034 let d = match xs.device() {
1035 Device::Vulkan(d) => d,
1036 _ => crate::bail!("VulkanQuant input not on vulkan"),
1037 };
1038 let y = {
1039 let (store, _) = xs.storage_and_layout();
1040 let xv = match &*store {
1041 crate::Storage::Vulkan(v) => v,
1042 _ => crate::bail!("VulkanQuant expected vulkan storage"),
1043 };
1044 d.matvec_q8_gpu(wq, xv, *n, *k)?
1045 };
1046 let mut dims = xs.dims().to_vec();
1047 let last = dims.len() - 1;
1048 dims[last] = *n;
1049 Ok(crate::tensor::from_storage(
1050 crate::Storage::Vulkan(y),
1051 dims,
1052 crate::op::BackpropOp::none(),
1053 false,
1054 ))
1055 } else {
1056 let w = qtensor.dequantize(&xs.device())?;
1058 let w = match *xs.dims() {
1059 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1060 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1061 _ => w.t()?,
1062 };
1063 xs.matmul(&w)
1064 }
1065 }
1066 Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()),
1067 Self::Tensor(w) => {
1068 let w = match *xs.dims() {
1069 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1070 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1071 _ => w.t()?,
1072 };
1073 xs.matmul(&w)
1074 }
1075 Self::TensorF16(w) => {
1076 let in_dtype = xs.dtype();
1077 let w = match *xs.dims() {
1078 [b1, b2, _, _] => w.broadcast_left((b1, b2))?.t()?,
1079 [bsize, _, _] => w.broadcast_left(bsize)?.t()?,
1080 _ => w.t()?,
1081 };
1082 xs.to_dtype(DType::F16)?.matmul(&w)?.to_dtype(in_dtype)
1083 }
1084 }
1085 }
1086}