1use std::{
2 borrow::Cow,
3 fmt::Debug,
4 num::NonZeroUsize,
5 sync::{atomic::AtomicUsize, Arc, Condvar, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9use candle_core::{
10 quantized::{GgmlDType, QMatMul, QTensor},
11 DType, Device, Result, Tensor,
12};
13use pertensor_fp8::pertensor_fp8_linear_b;
14
15#[cfg(feature = "metal")]
16mod metal_kernels;
17
18mod afq;
19mod bitsandbytes;
20mod blockwise_fp8;
21pub mod cublaslt;
22pub mod distributed;
23mod dummy;
24pub mod f8q8;
25mod fp8;
26pub mod gemv;
27mod gguf;
28mod gptq;
29mod hqq;
30mod imatrix;
31mod lora;
32mod mxfp4;
33mod pending_layer;
34mod pertensor_fp8;
35pub mod rotary;
36pub mod safetensors;
37mod scalar_fp8;
38mod unquantized;
39mod utils;
40mod vector_fp8;
41
42use gptq::gptq_linear;
43use lora::merge_lora_weights;
44use regex::Regex;
45pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
46
47pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
48pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
49pub use blockwise_fp8::{
50 blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
51};
52pub use distributed::{
53 layers::{
54 compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
55 ReplicatedLayer, RowParallelLayer,
56 },
57 socket::{Client, Server},
58 BarrierLike, Comm, Id, RingConfig, SumAllReduce,
59};
60pub use dummy::DummyLayer;
61pub use f8q8::F8Q8Linear;
62pub use fp8::FP8Linear;
63#[cfg(feature = "cuda")]
64pub use gemv::gemv;
65pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
66pub use gguf::GgufMatMul;
67pub use gptq::GptqLayer;
68pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
69pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
70pub use lora::{
71 clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
72 LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
73};
74pub use mxfp4::MXFP4Layer;
75pub use pending_layer::PendingIsqLayer;
76pub use pertensor_fp8::PerTensorFP8Linear;
77pub use unquantized::UnquantLinear;
78pub use utils::flash_attn_sinks_metal;
79pub use utils::flash_attn_sinks_varlen_metal;
80#[cfg(feature = "cuda")]
81pub use utils::gptoss_swiglu_fused;
82#[cfg(feature = "cuda")]
83pub use utils::gptoss_swiglu_interleaved;
84pub use utils::isq::apply_immediate_isq;
85pub use utils::softmax_with_sinks;
86pub use utils::{fused_glu, GluActivationType};
87pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
88pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
89
90use candle_nn::{Conv1d, Conv2d, Linear, Module};
91use serde::{Deserialize, Deserializer, Serialize};
92
/// Counting gate bounding how many ISQ (in-situ quantization) jobs may be
/// outstanding at once. `acquire` blocks while the count is at `max`;
/// `release` decrements and wakes one waiter.
pub struct IsqBackpressure {
    // Currently outstanding jobs; guarded by the mutex for cross-thread use.
    count: Mutex<usize>,
    // Signalled on `release` so blocked `acquire` callers can re-check.
    cvar: Condvar,
    // Upper bound on outstanding jobs.
    max: usize,
}
103
104impl IsqBackpressure {
105 pub fn new(max: usize) -> Self {
106 Self {
107 count: Mutex::new(0),
108 cvar: Condvar::new(),
109 max,
110 }
111 }
112
113 pub fn acquire(&self) {
115 let mut count = self.count.lock().expect("ISQ backpressure lock poisoned");
116 while *count >= self.max {
117 count = self
118 .cvar
119 .wait(count)
120 .expect("ISQ backpressure lock poisoned");
121 }
122 *count += 1;
123 }
124
125 pub fn release(&self) {
127 let mut count = self.count.lock().expect("ISQ backpressure lock poisoned");
128 *count = count.saturating_sub(1);
129 self.cvar.notify_one();
130 }
131}
132
133impl Debug for IsqBackpressure {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 let count = self.count.lock().map(|c| *c).unwrap_or(0);
136 f.debug_struct("IsqBackpressure")
137 .field("outstanding", &count)
138 .field("max", &self.max)
139 .finish()
140 }
141}
142
/// Thread-local configuration for "immediate ISQ": quantizing weights as they
/// are loaded rather than in a separate pass.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    /// Serializes quantize-onto operations (see `QuantizeOntoGuard`).
    pub guard: QuantizeOntoGuard,
    /// Default quantization type, if any.
    pub ty: Option<IsqType>,
    /// Tensor-name regexes selecting which weights receive `ty`.
    pub predicates: Vec<Regex>,
    /// Per-tensor overrides; consulted before `predicates` in
    /// `resolve_immediate_isq`.
    pub overrides: Vec<ImmediateIsqOverride>,
    /// Optional dedicated rayon pool used for quantization work.
    pub pool: Option<Arc<rayon::ThreadPool>>,
    /// Bounds how many quantization jobs may be outstanding at once.
    pub backpressure: Arc<IsqBackpressure>,
}
156
/// A per-tensor override for immediate ISQ: tensors matching `predicate`
/// use `ty` (falling back to the global type) and optionally a target device.
#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    /// Matched against the full tensor name (e.g. `model.layers.0.q_proj.weight`).
    pub predicate: Regex,
    /// Quant type for matching tensors; `None` falls back to the global type.
    pub ty: Option<IsqType>,
    /// Optional device to place the quantized tensor on.
    pub device: Option<Device>,
}
163
/// Result of resolving the immediate-ISQ configuration for one tensor.
#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    /// Quantization type to apply.
    pub ty: IsqType,
    /// Target device, when an override specified one.
    pub device: Option<Device>,
}
169
thread_local! {
    // Per-thread immediate-ISQ configuration; `None` means immediate ISQ is
    // disabled on this thread. Thread-local so concurrent engines don't clash.
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
}
173
174pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
175 let (pool, _) = create_isq_thread_pool(isq);
176 set_immediate_isq_with_pool(isq, predicates, Vec::new(), pool);
177}
178
179pub fn set_immediate_isq_with_pool(
180 isq: Option<IsqType>,
181 predicates: Vec<Regex>,
182 overrides: Vec<ImmediateIsqOverride>,
183 pool: rayon::ThreadPool,
184) {
185 let max_outstanding = pool.current_num_threads() + 1;
188 ENGINE_IMMEDIATE_ISQ.with(|cell| {
189 *cell.borrow_mut() = Some(ImmediateIsqParams {
190 guard: QuantizeOntoGuard::new(),
191 ty: isq,
192 predicates,
193 overrides,
194 backpressure: Arc::new(IsqBackpressure::new(max_outstanding)),
195 pool: Some(Arc::new(pool)),
196 });
197 });
198}
199
200pub fn create_isq_thread_pool(ty: Option<IsqType>) -> (rayon::ThreadPool, usize) {
207 let num_threads = if std::env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
208 1
209 } else if let Some(ty) = ty {
210 ty.get_max_isq_cpu_threads()
211 .map(usize::from)
212 .unwrap_or_else(rayon::current_num_threads)
213 } else {
214 rayon::current_num_threads()
215 };
216
217 let pool = rayon::ThreadPoolBuilder::new()
218 .num_threads(num_threads)
219 .build()
220 .expect("Failed to create ISQ thread pool");
221 (pool, num_threads)
222}
223
224pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
225 ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
226}
227
228pub fn clear_immediate_isq() {
229 ENGINE_IMMEDIATE_ISQ.with(|cell| {
230 *cell.borrow_mut() = None;
231 });
232}
233
234pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
235 immediate_isq_match(vb).is_some()
236}
237
238pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
239 let immediate_isq = get_immediate_isq()?;
240 let prefix = format!("{}.weight", vb.prefix());
242 resolve_immediate_isq(&immediate_isq, &prefix)
243}
244
245fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
246 if let Some(override_hit) = params
247 .overrides
248 .iter()
249 .find(|override_pred| override_pred.predicate.is_match(prefix))
250 {
251 if let Some(ty) = override_hit.ty.or(params.ty) {
252 return Some(ImmediateIsqMatch {
253 ty,
254 device: override_hit.device.clone(),
255 });
256 }
257 return None;
258 }
259
260 if let Some(ty) = params.ty {
261 if params
262 .predicates
263 .iter()
264 .any(|predicate| predicate.is_match(prefix))
265 {
266 return Some(ImmediateIsqMatch { ty, device: None });
267 }
268 }
269
270 None
271}
272
/// Parsed `quantization_config` from a model's config, tagged by
/// `quant_method`. Deserialization is hand-written (see the `Deserialize`
/// impl below) to tolerate the loose field layouts found in the wild.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    /// GPTQ or AWQ checkpoint; `is_awq` distinguishes the two.
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    /// FP8; blockwise when `weight_block_size` is present, per-tensor otherwise.
    Fp8 {
        weight_block_size: Option<Vec<usize>>,
    },
    /// bitsandbytes; 4-bit when `bnb_4bit_quant_type` is set, else 8-bit.
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    /// AFQ with explicit bit width and group size.
    Afq {
        bits: usize,
        group_size: usize,
    },
    /// MXFP4; carries no parameters.
    MXFP4 {},
}
294
/// Untagged superset of every `QuantizedConfig` field, used as an intermediate
/// during deserialization so the variant can be chosen from `quant_method`.
#[derive(Deserialize)]
struct RawConfig {
    // All fields optional; per-variant requirements are enforced afterwards.
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}
305
306impl<'de> Deserialize<'de> for QuantizedConfig {
308 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
309 where
310 D: Deserializer<'de>,
311 {
312 let raw = RawConfig::deserialize(deserializer)?;
313
314 match &raw.quant_method {
315 Some(m) if m == "gptq" || m == "awq" => {
316 let bits = raw
317 .bits
318 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
319 let group_size = raw
320 .group_size
321 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
322 Ok(QuantizedConfig::GptqAwq {
323 bits,
324 group_size,
325 checkpoint_format: raw.checkpoint_format,
326 is_awq: m == "awq",
327 })
328 }
329 Some(m) if m == "fp8" => {
330 Ok(QuantizedConfig::Fp8 {
332 weight_block_size: raw.weight_block_size,
333 })
334 }
335 Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
336 bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
337 }),
338 Some(m) if m == "afq" => {
339 let bits = raw
340 .bits
341 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
342 let group_size = raw
343 .group_size
344 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
345 Ok(QuantizedConfig::Afq { bits, group_size })
346 }
347 Some(m) if m == "mxfp4" => {
348 Ok(QuantizedConfig::MXFP4 { })
349 }
350 None => {
351 let bits = raw
352 .bits
353 .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
354 let group_size = raw
355 .group_size
356 .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
357 Ok(QuantizedConfig::Afq { bits, group_size })
358 }
359 Some(unknown_method) => {
360 Err(serde::de::Error::custom(format!(
361 "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
362 )))
363 },
364 }
365 }
366}
367
368impl QuantizedConfig {
369 pub fn name(&self) -> &'static str {
370 match self {
371 Self::GptqAwq { .. } => "gptq",
372 Self::Fp8 { .. } => "fp8",
373 Self::Bitsandbytes { .. } => "bitsandbytes",
374 Self::Afq { .. } => "afq",
375 Self::MXFP4 { .. } => "mxfp4",
376 }
377 }
378
379 pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
380 match self {
381 Self::GptqAwq { bits, .. } => format!("{bits} bits"),
382 Self::Fp8 { .. } => "8 bits".to_string(),
383 Self::Bitsandbytes {
384 bnb_4bit_quant_type: Some(_),
385 } => "4 bits".to_string(),
386 Self::Bitsandbytes {
387 bnb_4bit_quant_type: None,
388 } => "8 bits".to_string(),
389 Self::Afq { bits, .. } => format!("{bits} bits"),
390 Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
391 }
392 }
393
394 pub fn pack_factor(&self, dtype: DType) -> usize {
395 match self {
396 Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
397 2 => IsqType::Q2K.pack_factor(dtype),
398 3 => IsqType::Q3K.pack_factor(dtype),
399 4 => IsqType::Q4K.pack_factor(dtype),
400 5 => IsqType::Q5K.pack_factor(dtype),
401 6 => IsqType::Q6K.pack_factor(dtype),
402 8 => IsqType::Q8_0.pack_factor(dtype),
403 40 => 4, other => panic!("Unexpected bits in `pack_factor` {other}"),
405 },
406 Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
407 Self::Bitsandbytes {
408 bnb_4bit_quant_type: Some(_),
409 }
410 | Self::Bitsandbytes {
411 bnb_4bit_quant_type: None,
412 } => IsqType::Q4K.pack_factor(dtype),
413 Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
414 }
415 }
416}
417
/// Construction payload for `QuantMethod::new` — one variant per backend.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    /// GPTQ/AWQ packed weights; `use_exllama`/`is_marlin` select kernel paths
    /// (semantics live in the `gptq` module).
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    /// GGML-quantized weight plus optional bias.
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    /// Plain dense linear layer.
    Unquantized(Linear),
    /// HQQ quantization parameters applied to a dense tensor.
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    /// Placeholder layer with no weights (see `DummyLayer`).
    Dummy,
    /// FP8 linear with a dequantization target dtype.
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    /// bitsandbytes weight plus its quantization parameters.
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    /// Blockwise FP8: weight plus per-block inverse scales.
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    /// Per-tensor FP8: one weight scale, optional activation scale.
    PerTensorFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        activation_scale: Option<Tensor>,
        bias: Option<Tensor>,
        dequant_dtype: DType,
    },
    /// AFQ weight with bit width and group size.
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    /// MXFP4: packed blocks plus block scales.
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
484
/// Zero-sized namespace for matmul entry points with backend-aware dtype handling.
pub struct MatMul;
488
489impl MatMul {
490 pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
492 #[cfg(feature = "accelerate")]
493 {
494 let original_dtype = a.dtype();
495 a.to_dtype(DType::F32)?
496 .matmul(&b.to_dtype(DType::F32)?)?
497 .to_dtype(original_dtype)
498 }
499 #[cfg(not(feature = "accelerate"))]
500 {
501 if a.device().is_cpu() {
502 let original_dtype = a.dtype();
503 a.to_dtype(DType::F16)?
504 .matmul(&b.to_dtype(DType::F16)?)?
505 .to_dtype(original_dtype)
506 } else {
507 a.matmul(b)
508 }
509 }
510 }
511
512 pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
515 self.matmul(a, b)? / scale
517 }
518
519 pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
522 self.matmul(a, b)? * scale
524 }
525
526 pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
528 matmul.forward(x)
529 }
530
531 pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
533 matmul.forward(x)
534 }
535}
536
/// Zero-sized namespace for convolution forwards with CPU dtype upcasting.
pub struct Convolution;
540
541impl Convolution {
542 pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
543 if x.device().is_cpu() {
544 let original_dtype = x.dtype();
545 Conv1d::new(
546 layer.weight().to_dtype(DType::F32)?,
547 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
548 *layer.config(),
549 )
550 .forward(&x.to_dtype(DType::F32)?)?
551 .to_dtype(original_dtype)
552 } else {
553 layer.forward(x)
554 }
555 }
556
557 pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
558 if x.device().is_cpu() {
559 let original_dtype = x.dtype();
560 Conv2d::new(
561 layer.weight().to_dtype(DType::F32)?,
562 layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
563 *layer.config(),
564 )
565 .forward(&x.to_dtype(DType::F32)?)?
566 .to_dtype(original_dtype)
567 } else {
568 layer.forward(x)
569 }
570 }
571}
572
/// Quantization formats selectable for in-situ quantization (ISQ):
/// GGML block formats (`Q*`), HQQ, FP8 (`F8E4M3`), AFQ, `F8Q8` and MXFP4.
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
    F8Q8,
    MXFP4,
}
602
/// Abstract ISQ bit width; mapped to a concrete `IsqType` per device by
/// `resolve`/`expand` below. Note there is no 5-bit AFQ variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IsqBits {
    Two,
    Three,
    Four,
    Five,
    Six,
    Eight,
}
621
622impl IsqBits {
623 pub fn resolve(self, device: &Device) -> IsqType {
625 match (self, device.is_metal()) {
626 (Self::Two, true) => IsqType::AFQ2,
627 (Self::Two, false) => IsqType::Q2K,
628 (Self::Three, true) => IsqType::AFQ3,
629 (Self::Three, false) => IsqType::Q3K,
630 (Self::Four, true) => IsqType::AFQ4,
631 (Self::Four, false) => IsqType::Q4K,
632 (Self::Five, _) => IsqType::Q5K,
633 (Self::Six, true) => IsqType::AFQ6,
634 (Self::Six, false) => IsqType::Q6K,
635 (Self::Eight, true) => IsqType::AFQ8,
636 (Self::Eight, false) => IsqType::Q8_0,
637 }
638 }
639
640 pub fn expand(self) -> Vec<IsqType> {
643 #[cfg(feature = "metal")]
644 match self {
645 Self::Two => vec![IsqType::AFQ2, IsqType::Q2K],
646 Self::Three => vec![IsqType::AFQ3, IsqType::Q3K],
647 Self::Four => vec![IsqType::AFQ4, IsqType::Q4K],
648 Self::Five => vec![IsqType::Q5K],
649 Self::Six => vec![IsqType::AFQ6, IsqType::Q6K],
650 Self::Eight => vec![IsqType::AFQ8, IsqType::Q8_0],
651 }
652 #[cfg(not(feature = "metal"))]
653 match self {
654 Self::Two => vec![IsqType::Q2K, IsqType::AFQ2],
655 Self::Three => vec![IsqType::Q3K, IsqType::AFQ3],
656 Self::Four => vec![IsqType::Q4K, IsqType::AFQ4],
657 Self::Five => vec![IsqType::Q5K],
658 Self::Six => vec![IsqType::Q6K, IsqType::AFQ6],
659 Self::Eight => vec![IsqType::Q8_0, IsqType::AFQ8],
660 }
661 }
662}
663
664impl TryFrom<&str> for IsqBits {
665 type Error = ();
666 fn try_from(s: &str) -> std::result::Result<Self, ()> {
667 match s {
668 "2" => Ok(Self::Two),
669 "3" => Ok(Self::Three),
670 "4" => Ok(Self::Four),
671 "5" => Ok(Self::Five),
672 "6" => Ok(Self::Six),
673 "8" => Ok(Self::Eight),
674 _ => Err(()),
675 }
676 }
677}
678
679impl std::fmt::Display for IsqType {
680 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681 match self {
682 Self::Q4_0 => write!(f, "q4_0"),
683 Self::Q4_1 => write!(f, "q4_1"),
684 Self::Q5_0 => write!(f, "q5_0"),
685 Self::Q5_1 => write!(f, "q5_1"),
686 Self::Q8_0 => write!(f, "q8_0"),
687 Self::Q8_1 => write!(f, "q8_1"),
688 Self::Q2K => write!(f, "q2k"),
689 Self::Q3K => write!(f, "q3k"),
690 Self::Q4K => write!(f, "q4k"),
691 Self::Q5K => write!(f, "q5k"),
692 Self::Q6K => write!(f, "q6k"),
693 Self::Q8K => write!(f, "q8k"),
694 Self::HQQ8 => write!(f, "hqq8"),
695 Self::HQQ4 => write!(f, "hqq4"),
696 Self::F8E4M3 => write!(f, "fp8"),
697 Self::AFQ8 => write!(f, "afq8"),
698 Self::AFQ6 => write!(f, "afq6"),
699 Self::AFQ4 => write!(f, "afq4"),
700 Self::AFQ3 => write!(f, "afq3"),
701 Self::AFQ2 => write!(f, "afq2"),
702 Self::F8Q8 => write!(f, "f8q8"),
703 Self::MXFP4 => write!(f, "mxfp4"),
704 }
705 }
706}
707
708impl IsqType {
709 pub fn pack_factor(&self, dtype: DType) -> usize {
712 match self {
713 Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
714 .div_ceil(GgmlDType::Q4_0.type_size()),
715 Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
716 .div_ceil(GgmlDType::Q4_1.type_size()),
717 Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
718 .div_ceil(GgmlDType::Q5_0.type_size()),
719 Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
720 .div_ceil(GgmlDType::Q5_1.type_size()),
721 Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
722 .div_ceil(GgmlDType::Q8_0.type_size()),
723 Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
724 .div_ceil(GgmlDType::Q8_1.type_size()),
725 Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
726 .div_ceil(GgmlDType::Q2K.type_size()),
727 Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
728 .div_ceil(GgmlDType::Q3K.type_size()),
729 Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
730 .div_ceil(GgmlDType::Q4K.type_size()),
731 Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
732 .div_ceil(GgmlDType::Q5K.type_size()),
733 Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
734 .div_ceil(GgmlDType::Q6K.type_size()),
735 Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
736 .div_ceil(GgmlDType::Q8K.type_size()),
737 Self::F8Q8 => (dtype.size_in_bytes() * 32).div_ceil(33),
739 Self::HQQ4 => 4,
741 Self::HQQ8 => 2,
742 Self::F8E4M3 => 2,
743 Self::MXFP4 => 3,
746 }
747 }
748
749 pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
750 match self {
751 IsqType::HQQ4
753 | IsqType::HQQ8
754 | IsqType::AFQ2
755 | IsqType::AFQ3
756 | IsqType::AFQ4
757 | IsqType::AFQ6
758 | IsqType::AFQ8
759 | IsqType::MXFP4 => {
760 Some(1.try_into().unwrap())
762 }
763 IsqType::F8E4M3 | IsqType::F8Q8 => None,
764 IsqType::Q2K
765 | IsqType::Q3K
766 | IsqType::Q4K
767 | IsqType::Q4_0
768 | IsqType::Q4_1
769 | IsqType::Q5K
770 | IsqType::Q5_0
771 | IsqType::Q5_1
772 | IsqType::Q6K
773 | IsqType::Q8K
774 | IsqType::Q8_0
775 | IsqType::Q8_1 => None,
776 }
777 }
778}
779
780impl TryFrom<IsqType> for GgmlDType {
781 type Error = candle_core::Error;
782
783 fn try_from(value: IsqType) -> Result<Self> {
784 let tp = match value {
785 IsqType::Q2K => Self::Q2K,
786 IsqType::Q3K => Self::Q3K,
787 IsqType::Q4K => Self::Q4K,
788 IsqType::Q4_0 => Self::Q4_0,
789 IsqType::Q4_1 => Self::Q4_1,
790 IsqType::Q5K => Self::Q5K,
791 IsqType::Q5_0 => Self::Q5_0,
792 IsqType::Q5_1 => Self::Q5_1,
793 IsqType::Q6K => Self::Q6K,
794 IsqType::Q8K => Self::Q8K,
795 IsqType::Q8_0 => Self::Q8_0,
796 IsqType::Q8_1 => Self::Q8_1,
797 _ => candle_core::bail!("Expected valid GGML ISQ type."),
798 };
799 #[cfg(feature = "cuda")]
800 {
801 if !matches!(
802 tp,
803 GgmlDType::Q4_0
804 | GgmlDType::Q4_1
805 | GgmlDType::Q5_0
806 | GgmlDType::Q5_1
807 | GgmlDType::Q8_0
808 | GgmlDType::Q2K
809 | GgmlDType::Q3K
810 | GgmlDType::Q4K
811 | GgmlDType::Q5K
812 | GgmlDType::Q6K
813 ) {
814 candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
815 }
816 }
817 Ok(tp)
818 }
819}
820
821impl TryFrom<GgmlDType> for IsqType {
822 type Error = candle_core::Error;
823
824 fn try_from(value: GgmlDType) -> Result<Self> {
825 match value {
826 GgmlDType::Q2K => Ok(Self::Q2K),
827 GgmlDType::Q3K => Ok(Self::Q3K),
828 GgmlDType::Q4K => Ok(Self::Q4K),
829 GgmlDType::Q5K => Ok(Self::Q5K),
830 GgmlDType::Q6K => Ok(Self::Q6K),
831 GgmlDType::Q4_0 => Ok(Self::Q4_0),
832 GgmlDType::Q4_1 => Ok(Self::Q4_1),
833 GgmlDType::Q5_0 => Ok(Self::Q5_0),
834 GgmlDType::Q5_1 => Ok(Self::Q5_1),
835 GgmlDType::Q8_0 => Ok(Self::Q8_0),
836 GgmlDType::Q8_1 => Ok(Self::Q8_1),
837 GgmlDType::Q8K => Ok(Self::Q8K),
838 GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
839 candle_core::bail!("Expected valid GGML ISQ type.")
840 }
841 }
842 }
843}
844
/// On-disk tag identifying which backend serialized a layer; the explicit
/// discriminants are round-tripped via `TryFrom<usize>` below, so they must
/// never be renumbered.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
    F8Q8 = 5,
    Mxfp4 = 6,
}
855
856impl TryFrom<usize> for QuantizedSerdeType {
857 type Error = candle_core::Error;
858 fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
859 match value {
860 0 => Ok(Self::Gguf),
861 1 => Ok(Self::Unquant),
862 2 => Ok(Self::Hqq),
863 3 => Ok(Self::Fp8),
864 4 => Ok(Self::Afq),
865 5 => Ok(Self::F8Q8),
866 6 => Ok(Self::Mxfp4),
867 other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
868 }
869 }
870}
871
/// (De)serialization hooks for quantized layers (the UQFF path).
///
/// Every method except `name` has an unsupported-by-default implementation;
/// backends opt in by overriding `isq_serde_supported` and the relevant
/// serialize/deserialize methods.
pub trait QuantizedSerde {
    /// Backend name used in error messages and logs.
    fn name(&self) -> &'static str;
    /// Whether this backend supports serialization at all.
    fn isq_serde_supported(&self) -> bool {
        false
    }
    /// Serialize this layer to bytes; unsupported by default.
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    /// Deserialize a layer from bytes onto `device`; unsupported by default.
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    /// Like `deserialize` but also returns a bias stored externally to the
    /// quantized payload; unsupported by default.
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// Serialize with an externally supplied bias; unsupported by default.
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
906
/// Shared lock serializing quantize-onto operations; clones share one mutex.
/// See `acquire` for the cfg-dependent locking behavior.
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    /// The shared mutex; the `()` payload exists only for mutual exclusion.
    pub inner: Arc<Mutex<()>>,
}
913
/// RAII token returned by `QuantizeOntoGuard::acquire`: holds the real mutex
/// guard on non-CUDA builds, or nothing (`Fake`) on CUDA builds.
pub enum QuantizeOntoDropGuard<'a> {
    /// Actually holds the lock; released on drop.
    Real(MutexGuard<'a, ()>),
    /// No lock held (CUDA builds skip serialization).
    Fake,
}
919
impl Default for QuantizeOntoGuard {
    /// Equivalent to [`QuantizeOntoGuard::new`].
    fn default() -> Self {
        Self::new()
    }
}
925
impl QuantizeOntoGuard {
    /// Create a guard backed by a fresh mutex.
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize-onto lock.
    ///
    /// On CUDA builds this is a no-op and returns `Fake` — no serialization
    /// is performed. On other builds the mutex is taken (panicking if it was
    /// poisoned); on Metal, the device's pending command buffer is flushed
    /// via `wait_until_completed` before locking.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
958
/// Tensor-parallel role of a layer, reported by `QuantMethod::is_distributed`.
pub enum DistributedKind {
    /// Output dimension split across ranks.
    ColumnParallel,
    /// Input dimension split across ranks.
    RowParallel,
    /// Full copy on every rank.
    Replicated,
}
964
/// A quantized (or unquantized) linear-layer backend.
///
/// Implementations hold weights in backend-specific form and expose a uniform
/// forward / re-quantization interface.
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    /// Build a layer from the backend-specific construction payload.
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Dequantize the stored weight back to a dense tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// `forward`, casting the input to `quantized_act_type()` (when set)
    /// beforehand and back to the original dtype afterwards.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute the layer output for input `a`.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// `gather_forward` with the same autocast behavior as `forward_autocast`.
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Indexed variant of `forward`; errors by default for backends without
    /// gather support.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// Dtype activations should be cast to before `forward`, if any.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Dtype and device of the stored weight.
    fn dtype_and_device(&self) -> (DType, Device);

    /// Return a copy of this layer with `delta` added to the weight.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Re-quantize this layer to `dtype` on `device`. `n_quantized` is a
    /// shared progress counter; `imatrix_weight` optionally supplies
    /// importance-matrix weights; `guard` serializes device placement.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    /// Unquantized weight and bias, when the backend can provide them.
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Start collecting activation statistics; unsupported by default.
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Finish collecting activation statistics; unsupported by default.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// Which tensor-parallel role this layer plays, if any.
    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
1051
impl Module for dyn QuantMethod {
    /// Delegates to `QuantMethod::forward` so trait objects plug into
    /// `candle_nn` module combinators.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Self::forward(self, xs)
    }
}
1057
1058pub fn linear_no_bias(
1059 in_dim: usize,
1060 out_dim: usize,
1061 config: &Option<QuantizedConfig>,
1062 vb: ShardedVarBuilder,
1063) -> Result<Arc<dyn QuantMethod>> {
1064 let base_vb = vb.clone();
1065 let vb = if should_apply_immediate_isq(&vb) {
1066 vb.set_device(Device::Cpu)
1067 } else {
1068 vb
1069 };
1070
1071 let layer = if let Some(quant_conf) = &config {
1072 match quant_conf {
1073 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1074 QuantizedConfig::Fp8 { weight_block_size } => {
1075 if weight_block_size.is_some() {
1076 blockwise_fp8_linear_b(
1077 in_dim,
1078 out_dim,
1079 quant_conf,
1080 false,
1081 Default::default(),
1082 vb,
1083 )?
1084 } else {
1085 pertensor_fp8_linear_b(
1086 in_dim,
1087 out_dim,
1088 quant_conf,
1089 false,
1090 Default::default(),
1091 vb,
1092 )?
1093 }
1094 }
1095 QuantizedConfig::Bitsandbytes { .. } => {
1096 Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
1097 }
1098 QuantizedConfig::Afq { .. } => {
1099 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
1100 }
1101 QuantizedConfig::MXFP4 {} => {
1102 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
1103 }
1104 }
1105 } else {
1106 if !vb.contains_tensor("weight") {
1108 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
1109 Arc::new(layer) as Arc<dyn QuantMethod>
1110 } else {
1111 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1112 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1113
1114 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1115 Linear::new(weight, None),
1116 ))?;
1117 Arc::new(layer) as Arc<dyn QuantMethod>
1118 }
1119 };
1120 apply_immediate_isq(layer, base_vb)
1121}
1122
1123pub fn linear(
1124 in_dim: usize,
1125 out_dim: usize,
1126 config: &Option<QuantizedConfig>,
1127 vb: ShardedVarBuilder,
1128) -> Result<Arc<dyn QuantMethod>> {
1129 let base_vb = vb.clone();
1130 let vb = if should_apply_immediate_isq(&vb) {
1131 vb.set_device(Device::Cpu)
1132 } else {
1133 vb
1134 };
1135
1136 let layer = if let Some(quant_conf) = &config {
1137 match quant_conf {
1138 QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1139 QuantizedConfig::Fp8 { weight_block_size } => {
1140 if weight_block_size.is_some() {
1141 blockwise_fp8_linear_b(
1142 in_dim,
1143 out_dim,
1144 quant_conf,
1145 true,
1146 Default::default(),
1147 vb,
1148 )?
1149 } else {
1150 pertensor_fp8_linear_b(
1151 in_dim,
1152 out_dim,
1153 quant_conf,
1154 true,
1155 Default::default(),
1156 vb,
1157 )?
1158 }
1159 }
1160 QuantizedConfig::Bitsandbytes { .. } => {
1161 Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
1162 }
1163 QuantizedConfig::Afq { .. } => {
1164 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
1165 }
1166 QuantizedConfig::MXFP4 {} => {
1167 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
1168 }
1169 }
1170 } else {
1171 if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
1173 let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
1174 Arc::new(layer) as Arc<dyn QuantMethod>
1175 } else {
1176 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1177 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1178 let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
1179
1180 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1181 Linear::new(weight, Some(bias)),
1182 ))?;
1183 Arc::new(layer) as Arc<dyn QuantMethod>
1184 }
1185 };
1186 apply_immediate_isq(layer, base_vb)
1187}
1188
1189pub fn linear_b(
1190 in_dim: usize,
1191 out_dim: usize,
1192 bias: bool,
1193 config: &Option<QuantizedConfig>,
1194 vb: ShardedVarBuilder,
1195) -> Result<Arc<dyn QuantMethod>> {
1196 if bias {
1197 linear(in_dim, out_dim, config, vb)
1198 } else {
1199 linear_no_bias(in_dim, out_dim, config, vb)
1200 }
1201}