use std::{
    borrow::Cow,
    fmt::Debug,
    num::NonZeroUsize,
    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
};

use blockwise_fp8::blockwise_fp8_linear_b;
use candle_core::{
    quantized::{GgmlDType, QMatMul, QTensor},
    DType, Device, Result, Tensor,
};
use pertensor_fp8::pertensor_fp8_linear_b;

#[cfg(feature = "metal")]
mod metal_kernels;

mod afq;
mod bitsandbytes;
mod blockwise_fp8;
pub mod cublaslt;
pub mod distributed;
mod dummy;
mod fp8;
pub mod gemv;
mod gguf;
mod gptq;
mod hqq;
mod imatrix;
mod lora;
mod mxfp4;
mod pertensor_fp8;
pub mod rotary;
pub mod safetensors;
mod scalar_fp8;
mod unquantized;
mod utils;
mod vector_fp8;

use gptq::gptq_linear;
use lora::merge_lora_weights;
use regex::Regex;
pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};

pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
pub use blockwise_fp8::{
    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
};
pub use distributed::{
    layers::{
        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
        ReplicatedLayer, RowParallelLayer,
    },
    socket::{Client, Server},
    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
};
pub use dummy::DummyLayer;
pub use fp8::FP8Linear;
#[cfg(feature = "cuda")]
pub use gemv::gemv;
pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
pub use gguf::GgufMatMul;
pub use gptq::GptqLayer;
pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
pub use lora::{
    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
};
pub use mxfp4::MXFP4Layer;
pub use pertensor_fp8::PerTensorFP8Linear;
pub use unquantized::UnquantLinear;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_fused;
#[cfg(feature = "cuda")]
pub use utils::gptoss_swiglu_interleaved;
pub use utils::isq::apply_immediate_isq;
#[cfg(feature = "cuda")]
pub use utils::softmax_with_sinks;
pub use utils::{fused_glu, GluActivationType};
pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};

use candle_nn::{Conv1d, Conv2d, Linear, Module};
use serde::{Deserialize, Deserializer, Serialize};

#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    pub guard: QuantizeOntoGuard,
    pub ty: Option<IsqType>,
    pub predicates: Vec<Regex>,
    pub overrides: Vec<ImmediateIsqOverride>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    pub predicate: Regex,
    pub ty: Option<IsqType>,
    pub device: Option<Device>,
}

#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    pub ty: IsqType,
    pub device: Option<Device>,
}

thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> =
        const { std::cell::RefCell::new(None) };
}

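/// Install thread-local "immediate ISQ" parameters: layers whose fully
/// qualified weight name matches one of `predicates` are quantized to `isq`
/// as they are loaded.
///
/// A minimal usage sketch (the regex is illustrative, not taken from a real
/// model config):
///
/// ```ignore
/// use regex::Regex;
///
/// set_immediate_isq(
///     Some(IsqType::Q4K),
///     vec![Regex::new(r"q_proj\.weight$").unwrap()],
/// );
/// // ... construct layers via `linear_b` etc.; they consult this state
/// // through `get_immediate_isq` ...
/// clear_immediate_isq();
/// ```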
pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
}

pub fn set_immediate_isq_with_overrides(
    isq: Option<IsqType>,
    predicates: Vec<Regex>,
    overrides: Vec<ImmediateIsqOverride>,
) {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = Some(ImmediateIsqParams {
            guard: QuantizeOntoGuard::new(),
            ty: isq,
            predicates,
            overrides,
        });
    });
}

pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
}

pub fn clear_immediate_isq() {
    ENGINE_IMMEDIATE_ISQ.with(|cell| {
        *cell.borrow_mut() = None;
    });
}

pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
    immediate_isq_match(vb).is_some()
}

pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
    let immediate_isq = get_immediate_isq()?;
    // Match against the fully qualified weight name, e.g. `model.layers.0.q_proj.weight`.
    let prefix = format!("{}.weight", vb.prefix());
    resolve_immediate_isq(&immediate_isq, &prefix)
}

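/// Resolve the effective ISQ type (and optional target device) for a weight
/// prefix. The first matching override wins, falling back to the global `ty`
/// when the override specifies none; otherwise the global `ty` applies if any
/// base predicate matches.
///
/// For example (illustrative): with a global `Q8_0` and an override mapping
/// `lm_head\.weight$` to `Q4K`, the prefix `lm_head.weight` resolves to `Q4K`
/// while `model.layers.0.mlp.gate_proj.weight` resolves to `Q8_0`.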
fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
    if let Some(override_hit) = params
        .overrides
        .iter()
        .find(|override_pred| override_pred.predicate.is_match(prefix))
    {
        if let Some(ty) = override_hit.ty.or(params.ty) {
            return Some(ImmediateIsqMatch {
                ty,
                device: override_hit.device.clone(),
            });
        }
        return None;
    }

    if let Some(ty) = params.ty {
        if params
            .predicates
            .iter()
            .any(|predicate| predicate.is_match(prefix))
        {
            return Some(ImmediateIsqMatch { ty, device: None });
        }
    }

    None
}

#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    Fp8 {
        weight_block_size: Option<Vec<usize>>,
    },
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    Afq {
        bits: usize,
        group_size: usize,
    },
    MXFP4 {},
}

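// A model's HF-style `quantization_config` JSON is first parsed into the
// permissive `RawConfig` below, then manually converted into `QuantizedConfig`.
// Illustrative inputs (values are examples, not from any specific model):
//
//   {"quant_method":"gptq","bits":4,"group_size":128}    -> GptqAwq { is_awq: false, .. }
//   {"quant_method":"awq","bits":4,"group_size":128}     -> GptqAwq { is_awq: true, .. }
//   {"quant_method":"fp8","weight_block_size":[128,128]} -> Fp8 { .. }
//   {"bits":4,"group_size":64}                           -> Afq (assumed when no method is given)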
#[derive(Deserialize)]
struct RawConfig {
    quant_method: Option<String>,
    bits: Option<usize>,
    group_size: Option<usize>,
    checkpoint_format: Option<String>,
    weight_block_size: Option<Vec<usize>>,
    bnb_4bit_quant_type: Option<String>,
}

impl<'de> Deserialize<'de> for QuantizedConfig {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = RawConfig::deserialize(deserializer)?;

        match &raw.quant_method {
            Some(m) if m == "gptq" || m == "awq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::GptqAwq {
                    bits,
                    group_size,
                    checkpoint_format: raw.checkpoint_format,
                    is_awq: m == "awq",
                })
            }
            Some(m) if m == "fp8" => Ok(QuantizedConfig::Fp8 {
                weight_block_size: raw.weight_block_size,
            }),
            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
            }),
            Some(m) if m == "afq" => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(m) if m == "mxfp4" => Ok(QuantizedConfig::MXFP4 {}),
            None => {
                let bits = raw
                    .bits
                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
                let group_size = raw
                    .group_size
                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
                Ok(QuantizedConfig::Afq { bits, group_size })
            }
            Some(unknown_method) => Err(serde::de::Error::custom(format!(
                "Unknown quantization method: {unknown_method}. Expected one of: gptq, awq, fp8, bitsandbytes, afq, mxfp4, or not specified"
            ))),
        }
    }
}

impl QuantizedConfig {
    pub fn name(&self) -> &'static str {
        match self {
            Self::GptqAwq { .. } => "gptq",
            Self::Fp8 { .. } => "fp8",
            Self::Bitsandbytes { .. } => "bitsandbytes",
            Self::Afq { .. } => "afq",
            Self::MXFP4 { .. } => "mxfp4",
        }
    }

    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
        match self {
            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
            Self::Fp8 { .. } => "8 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            } => "4 bits".to_string(),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => "8 bits".to_string(),
            Self::Afq { bits, .. } => format!("{bits} bits"),
            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
        }
    }

    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
                2 => IsqType::Q2K.pack_factor(dtype),
                3 => IsqType::Q3K.pack_factor(dtype),
                4 => IsqType::Q4K.pack_factor(dtype),
                5 => IsqType::Q5K.pack_factor(dtype),
                6 => IsqType::Q6K.pack_factor(dtype),
                8 => IsqType::Q8_0.pack_factor(dtype),
                other => panic!("Unexpected bits in `pack_factor` {other}"),
            },
            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
            Self::Bitsandbytes {
                bnb_4bit_quant_type: Some(_),
            }
            | Self::Bitsandbytes {
                bnb_4bit_quant_type: None,
            } => IsqType::Q4K.pack_factor(dtype),
            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
        }
    }
}

#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    Unquantized(Linear),
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    Dummy,
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    PerTensorFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        activation_scale: Option<Tensor>,
        bias: Option<Tensor>,
        dequant_dtype: DType,
    },
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}

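/// Matmul helper: the computation is routed through F32 when the `accelerate`
/// feature is enabled, or through F16 on CPU otherwise (to hit the faster GEMM
/// kernels), with the result cast back to the input dtype; on other devices it
/// is a plain [`Tensor::matmul`]. A minimal sketch (shapes illustrative):
///
/// ```ignore
/// let a = Tensor::zeros((2, 3), DType::BF16, &Device::Cpu)?;
/// let b = Tensor::zeros((3, 4), DType::BF16, &Device::Cpu)?;
/// let c = MatMul.matmul(&a, &b)?; // computed in F16/F32, returned as BF16
/// ```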
pub struct MatMul;

impl MatMul {
    /// Compute matmul of `a` and `b`, autocasting the dtype on CPU as described above.
    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
        #[cfg(feature = "accelerate")]
        {
            let original_dtype = a.dtype();
            a.to_dtype(DType::F32)?
                .matmul(&b.to_dtype(DType::F32)?)?
                .to_dtype(original_dtype)
        }
        #[cfg(not(feature = "accelerate"))]
        {
            if a.device().is_cpu() {
                let original_dtype = a.dtype();
                a.to_dtype(DType::F16)?
                    .matmul(&b.to_dtype(DType::F16)?)?
                    .to_dtype(original_dtype)
            } else {
                a.matmul(b)
            }
        }
    }

    /// Compute `matmul(a, b) / scale`.
    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? / scale
    }

    /// Compute `matmul(a, b) * scale`.
    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
        self.matmul(a, b)? * scale
    }

    /// Forward `x` through a [`QMatMul`].
    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
        matmul.forward(x)
    }

    /// Forward `x` through any [`QuantMethod`].
    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
        matmul.forward(x)
    }
}

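/// Convolution helper mirroring [`MatMul`]: on CPU, the weights, bias, and
/// input are cast to F32 before the forward pass (sidestepping slow or
/// unsupported half-precision conv paths), and the output is cast back to the
/// input dtype. On other devices the layer is forwarded unchanged.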
pub struct Convolution;

impl Convolution {
    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv1d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }

    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
        if x.device().is_cpu() {
            let original_dtype = x.dtype();
            Conv2d::new(
                layer.weight().to_dtype(DType::F32)?,
                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
                *layer.config(),
            )
            .forward(&x.to_dtype(DType::F32)?)?
            .to_dtype(original_dtype)
        } else {
            layer.forward(x)
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}

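// `IsqType::pack_factor` below estimates the compression ratio versus `dtype`:
// unquantized bytes per block, divided (rounding up) by quantized bytes per
// block. Worked example: a Q4K block packs 256 elements into 144 bytes, so for
// F16 input (2 bytes/element) the factor is (2 * 256).div_ceil(144) = 4.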
impl IsqType {
    pub fn pack_factor(&self, dtype: DType) -> usize {
        match self {
            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
                .div_ceil(GgmlDType::Q4_0.type_size()),
            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
                .div_ceil(GgmlDType::Q4_1.type_size()),
            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
                .div_ceil(GgmlDType::Q5_0.type_size()),
            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
                .div_ceil(GgmlDType::Q5_1.type_size()),
            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
                .div_ceil(GgmlDType::Q8_0.type_size()),
            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
                .div_ceil(GgmlDType::Q8_1.type_size()),
            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
                .div_ceil(GgmlDType::Q2K.type_size()),
            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
                .div_ceil(GgmlDType::Q3K.type_size()),
            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
                .div_ceil(GgmlDType::Q4K.type_size()),
            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
                .div_ceil(GgmlDType::Q5K.type_size()),
            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
                .div_ceil(GgmlDType::Q6K.type_size()),
            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
                .div_ceil(GgmlDType::Q8K.type_size()),
            Self::HQQ4 => 4,
            Self::HQQ8 => 2,
            Self::F8E4M3 => 2,
        }
    }

    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
        match self {
            IsqType::HQQ4
            | IsqType::HQQ8
            | IsqType::AFQ2
            | IsqType::AFQ3
            | IsqType::AFQ4
            | IsqType::AFQ6
            | IsqType::AFQ8 => Some(1.try_into().unwrap()),
            IsqType::F8E4M3 => None,
            IsqType::Q2K
            | IsqType::Q3K
            | IsqType::Q4K
            | IsqType::Q4_0
            | IsqType::Q4_1
            | IsqType::Q5K
            | IsqType::Q5_0
            | IsqType::Q5_1
            | IsqType::Q6K
            | IsqType::Q8K
            | IsqType::Q8_0
            | IsqType::Q8_1 => None,
        }
    }
}

impl TryFrom<IsqType> for GgmlDType {
    type Error = candle_core::Error;

    fn try_from(value: IsqType) -> Result<Self> {
        let tp = match value {
            IsqType::Q2K => Self::Q2K,
            IsqType::Q3K => Self::Q3K,
            IsqType::Q4K => Self::Q4K,
            IsqType::Q4_0 => Self::Q4_0,
            IsqType::Q4_1 => Self::Q4_1,
            IsqType::Q5K => Self::Q5K,
            IsqType::Q5_0 => Self::Q5_0,
            IsqType::Q5_1 => Self::Q5_1,
            IsqType::Q6K => Self::Q6K,
            IsqType::Q8K => Self::Q8K,
            IsqType::Q8_0 => Self::Q8_0,
            IsqType::Q8_1 => Self::Q8_1,
            _ => candle_core::bail!("Expected valid GGML ISQ type."),
        };
        #[cfg(feature = "cuda")]
        {
            if !matches!(
                tp,
                GgmlDType::Q4_0
                    | GgmlDType::Q4_1
                    | GgmlDType::Q5_0
                    | GgmlDType::Q5_1
                    | GgmlDType::Q8_0
                    | GgmlDType::Q2K
                    | GgmlDType::Q3K
                    | GgmlDType::Q4K
                    | GgmlDType::Q5K
                    | GgmlDType::Q6K
            ) {
                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`")
            }
        }
        Ok(tp)
    }
}

impl TryFrom<GgmlDType> for IsqType {
    type Error = candle_core::Error;

    fn try_from(value: GgmlDType) -> Result<Self> {
        match value {
            GgmlDType::Q2K => Ok(Self::Q2K),
            GgmlDType::Q3K => Ok(Self::Q3K),
            GgmlDType::Q4K => Ok(Self::Q4K),
            GgmlDType::Q5K => Ok(Self::Q5K),
            GgmlDType::Q6K => Ok(Self::Q6K),
            GgmlDType::Q4_0 => Ok(Self::Q4_0),
            GgmlDType::Q4_1 => Ok(Self::Q4_1),
            GgmlDType::Q5_0 => Ok(Self::Q5_0),
            GgmlDType::Q5_1 => Ok(Self::Q5_1),
            GgmlDType::Q8_0 => Ok(Self::Q8_0),
            GgmlDType::Q8_1 => Ok(Self::Q8_1),
            GgmlDType::Q8K => Ok(Self::Q8K),
            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
                candle_core::bail!("Expected valid GGML ISQ type.")
            }
        }
    }
}

#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
}

impl TryFrom<usize> for QuantizedSerdeType {
    type Error = candle_core::Error;
    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Gguf),
            1 => Ok(Self::Unquant),
            2 => Ok(Self::Hqq),
            3 => Ok(Self::Fp8),
            4 => Ok(Self::Afq),
            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
        }
    }
}

pub trait QuantizedSerde {
    fn name(&self) -> &'static str;
    fn isq_serde_supported(&self) -> bool {
        false
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}

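/// Gates "quantize onto" operations so that only one tensor at a time is
/// quantized and moved to its target device; cloning shares the underlying
/// mutex. The lock is skipped on CUDA, and on Metal the command buffer is
/// flushed before the lock is taken (see [`QuantizeOntoGuard::acquire`]).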
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Real (Metal/CPU) or fake (CUDA) drop guard returned by
/// [`QuantizeOntoGuard::acquire`].
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        // No locking on CUDA.
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            // On Metal, flush the command buffer before taking the lock so pending
            // work does not interleave with the quantization critical section.
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}

pub enum DistributedKind {
    ColumnParallel,
    RowParallel,
    Replicated,
}

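/// The core abstraction for a (possibly quantized) linear layer. Every backend
/// in this crate (GGUF, GPTQ/AWQ, HQQ, the FP8 variants, AFQ, MXFP4,
/// bitsandbytes, unquantized) implements it. Callers typically use
/// `forward_autocast`, which casts activations to `quantized_act_type()` when
/// one is required. Sketch (construction elided; `x` shape illustrative):
///
/// ```ignore
/// let y = layer.forward_autocast(&x)?; // x: [batch, seq, in_dim] -> y: [batch, seq, out_dim]
/// ```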
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    fn dequantize_w(&self) -> Result<Tensor>;

    /// Forward with a dtype autocast around [`QuantMethod::forward`]: the
    /// activation is cast to `quantized_act_type()` if one is required, and
    /// the output is cast back to the original dtype.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Gathered (indexed) forward with the same autocast behavior as
    /// [`QuantMethod::forward_autocast`].
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If the activations must be cast to a specific dtype before `forward`, return it here.
    fn quantized_act_type(&self) -> Option<DType>;

    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight (e.g. a merged LoRA update) to the underlying weights.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Quantize this layer in-situ to `dtype` on `device`, incrementing
    /// `n_quantized` and holding `guard` for the critical section.
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}

impl Module for dyn QuantMethod {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Name the trait explicitly: `Self::forward` would be ambiguous here,
        // since both `QuantMethod` and `Module` define a `forward` method.
        QuantMethod::forward(self, xs)
    }
}

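/// Build a bias-free linear layer, dispatching on the optional
/// [`QuantizedConfig`] (GPTQ/AWQ, FP8, bitsandbytes, AFQ, MXFP4). Without a
/// config, this falls back to an unquantized layer, or a [`DummyLayer`] when
/// the weight tensor is absent. If immediate ISQ matches this prefix, weights
/// are loaded on CPU and quantized onto the target device by `apply_immediate_isq`.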
pub fn linear_no_bias(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    // When immediate ISQ applies, load on CPU first; `apply_immediate_isq` then
    // quantizes onto the resolved target device.
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { weight_block_size } => {
                if weight_block_size.is_some() {
                    blockwise_fp8_linear_b(
                        in_dim,
                        out_dim,
                        quant_conf,
                        false,
                        Default::default(),
                        vb,
                    )?
                } else {
                    pertensor_fp8_linear_b(
                        in_dim,
                        out_dim,
                        quant_conf,
                        false,
                        Default::default(),
                        vb,
                    )?
                }
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
            }
        }
    } else if !vb.contains_tensor("weight") {
        // No weight tensor present: substitute a dummy layer.
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, None),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

pub fn linear(
    in_dim: usize,
    out_dim: usize,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    let base_vb = vb.clone();
    let vb = if should_apply_immediate_isq(&vb) {
        vb.set_device(Device::Cpu)
    } else {
        vb
    };

    let layer = if let Some(quant_conf) = &config {
        match quant_conf {
            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
            QuantizedConfig::Fp8 { weight_block_size } => {
                if weight_block_size.is_some() {
                    blockwise_fp8_linear_b(
                        in_dim,
                        out_dim,
                        quant_conf,
                        true,
                        Default::default(),
                        vb,
                    )?
                } else {
                    pertensor_fp8_linear_b(
                        in_dim,
                        out_dim,
                        quant_conf,
                        true,
                        Default::default(),
                        vb,
                    )?
                }
            }
            QuantizedConfig::Bitsandbytes { .. } => {
                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
            }
            QuantizedConfig::Afq { .. } => {
                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
            QuantizedConfig::MXFP4 {} => {
                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
            }
        }
    } else if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    } else {
        let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
        let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
        let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;

        let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
            Linear::new(weight, Some(bias)),
        ))?;
        Arc::new(layer) as Arc<dyn QuantMethod>
    };
    apply_immediate_isq(layer, base_vb)
}

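/// Convenience wrapper selecting [`linear`] or [`linear_no_bias`]. Illustrative
/// usage (`cfg` and `vb` are hypothetical):
///
/// ```ignore
/// let q_proj = linear_b(
///     hidden_size,
///     num_heads * head_dim,
///     /*bias=*/ false,
///     &cfg.quantization_config,
///     vb.pp("q_proj"),
/// )?;
/// let y = q_proj.forward_autocast(&x)?;
/// ```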
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    config: &Option<QuantizedConfig>,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    if bias {
        linear(in_dim, out_dim, config, vb)
    } else {
        linear_no_bias(in_dim, out_dim, config, vb)
    }
}