Skip to main content

mistralrs_quant/
lib.rs

1use std::{
2    borrow::Cow,
3    fmt::Debug,
4    num::NonZeroUsize,
5    sync::{atomic::AtomicUsize, Arc, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9use candle_core::{
10    quantized::{GgmlDType, QMatMul, QTensor},
11    DType, Device, Result, Tensor,
12};
13use pertensor_fp8::pertensor_fp8_linear_b;
14
15#[cfg(feature = "metal")]
16mod metal_kernels;
17
18mod afq;
19mod bitsandbytes;
20mod blockwise_fp8;
21pub mod cublaslt;
22pub mod distributed;
23mod dummy;
24mod fp8;
25pub mod gemv;
26mod gguf;
27mod gptq;
28mod hqq;
29mod imatrix;
30mod lora;
31mod mxfp4;
32mod pertensor_fp8;
33pub mod rotary;
34pub mod safetensors;
35mod scalar_fp8;
36mod unquantized;
37mod utils;
38mod vector_fp8;
39
40use gptq::gptq_linear;
41use lora::merge_lora_weights;
42use regex::Regex;
43pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
44
45pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
46pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
47pub use blockwise_fp8::{
48    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
49};
50pub use distributed::{
51    layers::{
52        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
53        ReplicatedLayer, RowParallelLayer,
54    },
55    socket::{Client, Server},
56    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
57};
58pub use dummy::DummyLayer;
59pub use fp8::FP8Linear;
60#[cfg(feature = "cuda")]
61pub use gemv::gemv;
62pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
63pub use gguf::GgufMatMul;
64pub use gptq::GptqLayer;
65pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
66pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
67pub use lora::{
68    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
69    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
70};
71pub use mxfp4::MXFP4Layer;
72pub use pertensor_fp8::PerTensorFP8Linear;
73pub use unquantized::UnquantLinear;
74#[cfg(feature = "cuda")]
75pub use utils::gptoss_swiglu_fused;
76#[cfg(feature = "cuda")]
77pub use utils::gptoss_swiglu_interleaved;
78pub use utils::isq::apply_immediate_isq;
79#[cfg(feature = "cuda")]
80pub use utils::softmax_with_sinks;
81pub use utils::{fused_glu, GluActivationType};
82pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
83pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
84
85use candle_nn::{Conv1d, Conv2d, Linear, Module};
86use serde::{Deserialize, Deserializer, Serialize};
87
/// Engine-wide immediate-ISQ configuration.
///
/// Stored per thread by [`set_immediate_isq`] / [`set_immediate_isq_with_overrides`] and
/// consulted by [`immediate_isq_match`] when a layer is constructed.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    /// Guard passed along to quantization; a fresh one is created by every setter call.
    pub guard: QuantizeOntoGuard,
    /// Default ISQ type for layers matched by `predicates`, and the fallback for overrides
    /// that do not set their own `ty`.
    pub ty: Option<IsqType>,
    /// Regexes tested against `"{prefix}.weight"`; any match selects `ty`.
    pub predicates: Vec<Regex>,
    /// Per-layer overrides; the first matching override takes precedence over `predicates`.
    pub overrides: Vec<ImmediateIsqOverride>,
}

/// A per-layer override of the immediate-ISQ settings.
#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    /// Regex tested against `"{prefix}.weight"`.
    pub predicate: Regex,
    /// ISQ type for matching layers; `None` falls back to [`ImmediateIsqParams::ty`].
    pub ty: Option<IsqType>,
    /// Optional device, forwarded as [`ImmediateIsqMatch::device`] (presumably the target
    /// device for the quantized layer — confirm in `apply_immediate_isq`).
    pub device: Option<Device>,
}

/// The resolved immediate-ISQ decision for a single layer.
#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    /// ISQ type to quantize the layer to.
    pub ty: IsqType,
    /// Device taken from the matching override; `None` when matched via the plain predicates.
    pub device: Option<Device>,
}
108
109thread_local! {
110    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
111}
112
113pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
114    set_immediate_isq_with_overrides(isq, predicates, Vec::new());
115}
116
117pub fn set_immediate_isq_with_overrides(
118    isq: Option<IsqType>,
119    predicates: Vec<Regex>,
120    overrides: Vec<ImmediateIsqOverride>,
121) {
122    ENGINE_IMMEDIATE_ISQ.with(|cell| {
123        *cell.borrow_mut() = Some(ImmediateIsqParams {
124            guard: QuantizeOntoGuard::new(),
125            ty: isq,
126            predicates,
127            overrides,
128        });
129    });
130}
131
132pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
133    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
134}
135
136pub fn clear_immediate_isq() {
137    ENGINE_IMMEDIATE_ISQ.with(|cell| {
138        *cell.borrow_mut() = None;
139    });
140}
141
142pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
143    immediate_isq_match(vb).is_some()
144}
145
146pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
147    let immediate_isq = get_immediate_isq()?;
148    // Add a .weight to match the ISQ regexes!
149    let prefix = format!("{}.weight", vb.prefix());
150    resolve_immediate_isq(&immediate_isq, &prefix)
151}
152
153fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
154    if let Some(override_hit) = params
155        .overrides
156        .iter()
157        .find(|override_pred| override_pred.predicate.is_match(prefix))
158    {
159        if let Some(ty) = override_hit.ty.or(params.ty) {
160            return Some(ImmediateIsqMatch {
161                ty,
162                device: override_hit.device.clone(),
163            });
164        }
165        return None;
166    }
167
168    if let Some(ty) = params.ty {
169        if params
170            .predicates
171            .iter()
172            .any(|predicate| predicate.is_match(prefix))
173        {
174            return Some(ImmediateIsqMatch { ty, device: None });
175        }
176    }
177
178    None
179}
180
/// Quantization settings parsed from a model's configuration (presumably its
/// `quantization_config` section — confirm at the call sites).
///
/// `Deserialize` is implemented by hand in this file; among other things it maps a missing
/// `quant_method` to [`QuantizedConfig::Afq`]. `Serialize` is derived: the variant name is
/// written, lowercased, into a `quant_method` tag (e.g. `"gptqawq"`, `"mxfp4"`).
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    /// GPTQ or AWQ weights; `is_awq` distinguishes the two.
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        is_awq: bool,
    },
    /// FP8 weights. `weight_block_size: None` selects per-tensor FP8, `Some` selects blockwise FP8.
    Fp8 {
        weight_block_size: Option<Vec<usize>>,
    },
    /// bitsandbytes weights: `Some(_)` quant type means 4-bit, `None` means 8-bit.
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    /// AFQ weights with the given bit width and group size.
    Afq {
        bits: usize,
        group_size: usize,
    },
    /// MXFP4 weights; this variant carries no parameters.
    MXFP4 {},
}
202
203// Common fields for all variants
204#[derive(Deserialize)]
205struct RawConfig {
206    quant_method: Option<String>,
207    bits: Option<usize>,
208    group_size: Option<usize>,
209    checkpoint_format: Option<String>,
210    weight_block_size: Option<Vec<usize>>,
211    bnb_4bit_quant_type: Option<String>,
212}
213
214// Custom deserializer implementation
215impl<'de> Deserialize<'de> for QuantizedConfig {
216    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
217    where
218        D: Deserializer<'de>,
219    {
220        let raw = RawConfig::deserialize(deserializer)?;
221
222        match &raw.quant_method {
223            Some(m) if m == "gptq" || m == "awq" => {
224                let bits = raw
225                    .bits
226                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
227                let group_size = raw
228                    .group_size
229                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
230                Ok(QuantizedConfig::GptqAwq {
231                    bits,
232                    group_size,
233                    checkpoint_format: raw.checkpoint_format,
234                    is_awq: m == "awq",
235                })
236            }
237            Some(m) if m == "fp8" => {
238                // weight_block_size is optional - None means per-tensor quantization
239                Ok(QuantizedConfig::Fp8 {
240                    weight_block_size: raw.weight_block_size,
241                })
242            }
243            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
244                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
245            }),
246            Some(m) if m == "afq" => {
247                let bits = raw
248                    .bits
249                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
250                let group_size = raw
251                    .group_size
252                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
253                Ok(QuantizedConfig::Afq { bits, group_size })
254            }
255            Some(m) if m == "mxfp4" => {
256                Ok(QuantizedConfig::MXFP4 {  })
257            }
258            None => {
259                let bits = raw
260                    .bits
261                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
262                let group_size = raw
263                    .group_size
264                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
265                Ok(QuantizedConfig::Afq { bits, group_size })
266            }
267            Some(unknown_method) => {
268                Err(serde::de::Error::custom(format!(
269                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
270                )))
271            },
272        }
273    }
274}
275
276impl QuantizedConfig {
277    pub fn name(&self) -> &'static str {
278        match self {
279            Self::GptqAwq { .. } => "gptq",
280            Self::Fp8 { .. } => "fp8",
281            Self::Bitsandbytes { .. } => "bitsandbytes",
282            Self::Afq { .. } => "afq",
283            Self::MXFP4 { .. } => "mxfp4",
284        }
285    }
286
287    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
288        match self {
289            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
290            Self::Fp8 { .. } => "8 bits".to_string(),
291            Self::Bitsandbytes {
292                bnb_4bit_quant_type: Some(_),
293            } => "4 bits".to_string(),
294            Self::Bitsandbytes {
295                bnb_4bit_quant_type: None,
296            } => "8 bits".to_string(),
297            Self::Afq { bits, .. } => format!("{bits} bits"),
298            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
299        }
300    }
301
302    pub fn pack_factor(&self, dtype: DType) -> usize {
303        match self {
304            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
305                2 => IsqType::Q2K.pack_factor(dtype),
306                3 => IsqType::Q3K.pack_factor(dtype),
307                4 => IsqType::Q4K.pack_factor(dtype),
308                5 => IsqType::Q5K.pack_factor(dtype),
309                6 => IsqType::Q6K.pack_factor(dtype),
310                8 => IsqType::Q8_0.pack_factor(dtype),
311                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
312                other => panic!("Unexpected bits in `pack_factor` {other}"),
313            },
314            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
315            Self::Bitsandbytes {
316                bnb_4bit_quant_type: Some(_),
317            }
318            | Self::Bitsandbytes {
319                bnb_4bit_quant_type: None,
320            } => IsqType::Q4K.pack_factor(dtype),
321            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
322        }
323    }
324}
325
/// Construction parameters for the [`QuantMethod`] implementations, consumed by
/// [`QuantMethod::new`]. Each variant carries the tensors and settings for one backend.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    /// GPTQ/AWQ packed weights plus their quantization metadata (`is_awq` selects AWQ).
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    /// A GGUF/GGML quantized tensor; `b` is presumably the optional bias — confirm in `gguf`.
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    /// A plain, unquantized `candle_nn::Linear`.
    Unquantized(Linear),
    /// HQQ settings for `tensor` (presumably the unquantized weight to be quantized — confirm).
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    /// Placeholder for layers that have no tensors (used by `linear`/`linear_no_bias`).
    Dummy,
    /// FP8 layer built from `lin`; the role of `dtype` is implementation-defined — TODO confirm.
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    /// bitsandbytes weights with their quantization parameters and type.
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    /// Blockwise FP8 weights with inverse scales and the block shape `weight_block_size`.
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    /// Per-tensor FP8 weights with an inverse weight scale and an optional activation scale.
    PerTensorFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        activation_scale: Option<Tensor>,
        bias: Option<Tensor>,
        dequant_dtype: DType,
    },
    /// AFQ weights with their bit width and group size.
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    /// MXFP4 weights stored as `blocks` plus `scales`.
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
392
393/// Device/configurable intelligent matrix multiplication
394/// - Handles limitation of `accelerate` which requires f32
395pub struct MatMul;
396
397impl MatMul {
398    /// Compute matrix-matrix product.
399    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
400        #[cfg(feature = "accelerate")]
401        {
402            let original_dtype = a.dtype();
403            a.to_dtype(DType::F32)?
404                .matmul(&b.to_dtype(DType::F32)?)?
405                .to_dtype(original_dtype)
406        }
407        #[cfg(not(feature = "accelerate"))]
408        {
409            if a.device().is_cpu() {
410                let original_dtype = a.dtype();
411                a.to_dtype(DType::F16)?
412                    .matmul(&b.to_dtype(DType::F16)?)?
413                    .to_dtype(original_dtype)
414            } else {
415                a.matmul(b)
416            }
417        }
418    }
419
420    /// Compute matrix-matrix product.
421    /// The result will be divided by the `scale` parameter in an affine division.
422    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
423        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
424        self.matmul(a, b)? / scale
425    }
426
427    /// Compute matrix-matrix product.
428    /// The result will be divided by the `scale` parameter in an affine multiplication.
429    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
430        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
431        self.matmul(a, b)? * scale
432    }
433
434    /// Compute quantized matrix-matrix product.
435    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
436        matmul.forward(x)
437    }
438
439    /// Compute quantized matrix-matrix product.
440    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
441        matmul.forward(x)
442    }
443}
444
445/// Device/configurable intelligent convolution
446/// - Handles limitation of cpu which requires f32
447pub struct Convolution;
448
449impl Convolution {
450    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
451        if x.device().is_cpu() {
452            let original_dtype = x.dtype();
453            Conv1d::new(
454                layer.weight().to_dtype(DType::F32)?,
455                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
456                *layer.config(),
457            )
458            .forward(&x.to_dtype(DType::F32)?)?
459            .to_dtype(original_dtype)
460        } else {
461            layer.forward(x)
462        }
463    }
464
465    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
466        if x.device().is_cpu() {
467            let original_dtype = x.dtype();
468            Conv2d::new(
469                layer.weight().to_dtype(DType::F32)?,
470                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
471                *layer.config(),
472            )
473            .forward(&x.to_dtype(DType::F32)?)?
474            .to_dtype(original_dtype)
475        } else {
476            layer.forward(x)
477        }
478    }
479}
480
/// Quantization type that can be applied via ISQ.
///
/// The GGML-style variants convert to `GgmlDType` via `TryFrom`; the HQQ, F8E4M3 and AFQ
/// variants have no GGML equivalent.
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    // GGML quantization types.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    // HQQ quantization (the lower bit widths below are currently disabled).
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    // FP8 (E4M3).
    F8E4M3,
    // AFQ quantization at 8, 6, 4, 3 and 2 bits.
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
}
507
508impl IsqType {
509    /// Factor by which the weight size is reduced over the given dtype.
510    /// original size / pack factor = quantized size
511    pub fn pack_factor(&self, dtype: DType) -> usize {
512        match self {
513            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
514                .div_ceil(GgmlDType::Q4_0.type_size()),
515            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
516                .div_ceil(GgmlDType::Q4_1.type_size()),
517            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
518                .div_ceil(GgmlDType::Q5_0.type_size()),
519            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
520                .div_ceil(GgmlDType::Q5_1.type_size()),
521            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
522                .div_ceil(GgmlDType::Q8_0.type_size()),
523            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
524                .div_ceil(GgmlDType::Q8_1.type_size()),
525            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
526                .div_ceil(GgmlDType::Q2K.type_size()),
527            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
528                .div_ceil(GgmlDType::Q3K.type_size()),
529            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
530                .div_ceil(GgmlDType::Q4K.type_size()),
531            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
532                .div_ceil(GgmlDType::Q5K.type_size()),
533            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
534                .div_ceil(GgmlDType::Q6K.type_size()),
535            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
536                .div_ceil(GgmlDType::Q8K.type_size()),
537            // Estimates
538            Self::HQQ4 => 4,
539            Self::HQQ8 => 2,
540            Self::F8E4M3 => 2,
541        }
542    }
543
544    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
545        match self {
546            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
547            IsqType::HQQ4
548            | IsqType::HQQ8
549            | IsqType::AFQ2
550            | IsqType::AFQ3
551            | IsqType::AFQ4
552            | IsqType::AFQ6
553            | IsqType::AFQ8 => {
554                // Use 1 because our HQQ quantizes on the GPU
555                Some(1.try_into().unwrap())
556            }
557            IsqType::F8E4M3 => None,
558            IsqType::Q2K
559            | IsqType::Q3K
560            | IsqType::Q4K
561            | IsqType::Q4_0
562            | IsqType::Q4_1
563            | IsqType::Q5K
564            | IsqType::Q5_0
565            | IsqType::Q5_1
566            | IsqType::Q6K
567            | IsqType::Q8K
568            | IsqType::Q8_0
569            | IsqType::Q8_1 => None,
570        }
571    }
572}
573
574impl TryFrom<IsqType> for GgmlDType {
575    type Error = candle_core::Error;
576
577    fn try_from(value: IsqType) -> Result<Self> {
578        let tp = match value {
579            IsqType::Q2K => Self::Q2K,
580            IsqType::Q3K => Self::Q3K,
581            IsqType::Q4K => Self::Q4K,
582            IsqType::Q4_0 => Self::Q4_0,
583            IsqType::Q4_1 => Self::Q4_1,
584            IsqType::Q5K => Self::Q5K,
585            IsqType::Q5_0 => Self::Q5_0,
586            IsqType::Q5_1 => Self::Q5_1,
587            IsqType::Q6K => Self::Q6K,
588            IsqType::Q8K => Self::Q8K,
589            IsqType::Q8_0 => Self::Q8_0,
590            IsqType::Q8_1 => Self::Q8_1,
591            _ => candle_core::bail!("Expected valid GGML ISQ type."),
592        };
593        #[cfg(feature = "cuda")]
594        {
595            if !matches!(
596                tp,
597                GgmlDType::Q4_0
598                    | GgmlDType::Q4_1
599                    | GgmlDType::Q5_0
600                    | GgmlDType::Q5_1
601                    | GgmlDType::Q8_0
602                    | GgmlDType::Q2K
603                    | GgmlDType::Q3K
604                    | GgmlDType::Q4K
605                    | GgmlDType::Q5K
606                    | GgmlDType::Q6K
607            ) {
608                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
609            }
610        }
611        Ok(tp)
612    }
613}
614
615impl TryFrom<GgmlDType> for IsqType {
616    type Error = candle_core::Error;
617
618    fn try_from(value: GgmlDType) -> Result<Self> {
619        match value {
620            GgmlDType::Q2K => Ok(Self::Q2K),
621            GgmlDType::Q3K => Ok(Self::Q3K),
622            GgmlDType::Q4K => Ok(Self::Q4K),
623            GgmlDType::Q5K => Ok(Self::Q5K),
624            GgmlDType::Q6K => Ok(Self::Q6K),
625            GgmlDType::Q4_0 => Ok(Self::Q4_0),
626            GgmlDType::Q4_1 => Ok(Self::Q4_1),
627            GgmlDType::Q5_0 => Ok(Self::Q5_0),
628            GgmlDType::Q5_1 => Ok(Self::Q5_1),
629            GgmlDType::Q8_0 => Ok(Self::Q8_0),
630            GgmlDType::Q8_1 => Ok(Self::Q8_1),
631            GgmlDType::Q8K => Ok(Self::Q8K),
632            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
633                candle_core::bail!("Expected valid GGML ISQ type.")
634            }
635        }
636    }
637}
638
639#[derive(Debug, Clone, Copy)]
640pub enum QuantizedSerdeType {
641    Gguf = 0,
642    Unquant = 1,
643    Hqq = 2,
644    Fp8 = 3,
645    Afq = 4,
646}
647
648impl TryFrom<usize> for QuantizedSerdeType {
649    type Error = candle_core::Error;
650    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
651        match value {
652            0 => Ok(Self::Gguf),
653            1 => Ok(Self::Unquant),
654            2 => Ok(Self::Hqq),
655            3 => Ok(Self::Fp8),
656            4 => Ok(Self::Afq),
657            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
658        }
659    }
660}
661
/// (De)serialization support for quantized layers.
///
/// Every method except `name` has a default: `isq_serde_supported` returns `false`, and the
/// others bail with a "not supported" error.
pub trait QuantizedSerde {
    /// Name of the implementation (used in error messages, e.g. by `QuantMethod`'s defaults).
    fn name(&self) -> &'static str;
    /// Whether this layer supports ISQ (de)serialization; `false` by default.
    fn isq_serde_supported(&self) -> bool {
        false
    }
    /// Serialize the layer to bytes.
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    /// Rebuild a layer from serialized bytes. `device` is presumably the target device and
    /// `comm` the distributed communicator — confirm per implementation.
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    /// Like `deserialize`, but returns the layer together with an optional bias tensor.
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    ///
    /// Serialize the layer together with the given bias.
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
696
/// Used to gate access to quantizing onto the host device
///
/// Clones share a single mutex, because `inner` is an `Arc<Mutex<()>>`.
// NOTE(review): the reason for `allow(unused)` is not visible here (`inner` is used by
// `acquire`); confirm whether it is still needed.
#[derive(Clone, Debug)]
#[allow(unused)]
pub struct QuantizeOntoGuard {
    pub inner: Arc<Mutex<()>>,
}

/// Guard returned by [`QuantizeOntoGuard::acquire`]; the critical section ends when it is dropped.
///
/// `Real` holds the mutex lock and is returned whenever the crate is built without the `cuda`
/// feature (CPU and Metal devices alike). `Fake` holds nothing and is returned with `cuda`.
pub enum QuantizeOntoDropGuard<'a> {
    Real(MutexGuard<'a, ()>),
    Fake,
}

// Same as `QuantizeOntoGuard::new()`.
impl Default for QuantizeOntoGuard {
    fn default() -> Self {
        Self::new()
    }
}

impl QuantizeOntoGuard {
    /// Create a guard around a fresh, unshared mutex.
    pub fn new() -> Self {
        QuantizeOntoGuard {
            inner: Arc::new(Mutex::new(())),
        }
    }

    /// Acquire the quantize drop guard to protect the critical section.
    ///
    /// On metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer"
    ///
    /// - With the `cuda` feature: no locking; returns [`QuantizeOntoDropGuard::Fake`].
    /// - Otherwise: with the `metal` feature on a Metal device, the device is synchronized
    ///   first; then the shared mutex is locked and [`QuantizeOntoDropGuard::Real`] is returned.
    ///
    /// # Panics
    /// Panics if flushing the Metal command buffer fails or if the mutex is poisoned.
    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
        // Exactly one of the two cfg blocks below is compiled; it is the function's tail expression.
        #[cfg(feature = "cuda")]
        {
            let _ = device;
            QuantizeOntoDropGuard::Fake
        }

        #[cfg(not(feature = "cuda"))]
        {
            #[cfg(feature = "metal")]
            if let Device::Metal(dev) = device {
                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
                dev.wait_until_completed()
                    .expect("Failed to flush command buffer.");
            }
            #[cfg(not(feature = "metal"))]
            let _ = device;

            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
        }
    }
}
748
/// Kind of distribution a layer uses, as reported by [`QuantMethod::is_distributed`]
/// (cf. `ColumnParallelLayer`, `RowParallelLayer`, `ReplicatedLayer` in `distributed::layers`).
pub enum DistributedKind {
    /// Column-parallel layer (see `ColumnParallelLayer`).
    ColumnParallel,
    /// Row-parallel layer (see `RowParallelLayer`).
    RowParallel,
    /// Replicated layer (see `ReplicatedLayer`).
    Replicated,
}
754
/// Quantized method for a quantized matmul.
///
/// An implementor holds one layer's (possibly quantized) weights and multiplies inputs by them.
/// Defaults are provided for:
/// - the autocast wrappers;
/// - `gather_forward`, `begin_track_stats` and `end_track_stats` (these bail);
/// - `unquant_weight_bias` and `is_distributed` (these return `None`).
pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
    /// Construct the layer from a [`QuantMethodConfig`].
    fn new(method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized;

    /// Return the layer's weight as a dequantized tensor.
    fn dequantize_w(&self) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back
    ///
    /// If [`QuantMethod::quantized_act_type`] is `Some(t)`, `a` is cast to `t` first. The
    /// output is always cast back to `a`'s original dtype.
    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.forward(&a)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    fn forward(&self, a: &Tensor) -> Result<Tensor>;

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    /// Automatically cast to required quantization activation type and back.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    ///
    /// Casting follows the same rules as [`QuantMethod::forward_autocast`].
    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let original_ty = a.dtype();
        let a = if let Some(t) = self.quantized_act_type() {
            a.to_dtype(t)?
        } else {
            a.clone()
        };
        self.gather_forward(&a, indices)?.to_dtype(original_ty)
    }

    /// Compute matmul of `self` and `a`. `self` should contain the weights.
    ///
    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
    /// then the indices are (n_tokens, n_experts).
    ///
    /// The default implementation returns an error.
    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
        candle_core::bail!(
            "{} does not support `gather_forward`. Please raise an issue.",
            self.name()
        )
    }

    /// If a quantized method, return the activation dtype.
    ///
    /// With `None`, the autocast wrappers pass inputs through in their own dtype.
    fn quantized_act_type(&self) -> Option<DType>;

    /// Weight dtype and device
    fn dtype_and_device(&self) -> (DType, Device);

    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
    ///
    /// Returns the resulting layer.
    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;

    /// Apply ISQ to this layer, returning the resulting layer.
    ///
    /// Parameter roles are inferred from their names — confirm per implementation:
    /// - `dtype`: the target [`IsqType`] (`None` presumably meaning no re-quantization);
    /// - `device`: the target device;
    /// - `n_quantized`: presumably a counter of quantized layers;
    /// - `imatrix_weight`: optional importance-matrix weights;
    /// - `guard`: gates access via [`QuantizeOntoGuard::acquire`].
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>;

    /// The unquantized weight and optional bias, if this layer holds them; `None` by default.
    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        None
    }

    /// Begin tracking stats into an ImatrixLayerStats
    ///
    /// The default implementation returns an error.
    fn begin_track_stats(&mut self) -> Result<()> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
    ///
    /// The default implementation returns an error.
    fn end_track_stats(&self) -> Result<Tensor> {
        candle_core::bail!("`{}` does not support tracking stats.", self.name())
    }

    /// How this layer is distributed, if at all; `None` by default.
    fn is_distributed(&self) -> Option<DistributedKind> {
        None
    }
}
841
842impl Module for dyn QuantMethod {
843    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
844        Self::forward(self, xs)
845    }
846}
847
848pub fn linear_no_bias(
849    in_dim: usize,
850    out_dim: usize,
851    config: &Option<QuantizedConfig>,
852    vb: ShardedVarBuilder,
853) -> Result<Arc<dyn QuantMethod>> {
854    let base_vb = vb.clone();
855    let vb = if should_apply_immediate_isq(&vb) {
856        vb.set_device(Device::Cpu)
857    } else {
858        vb
859    };
860
861    let layer = if let Some(quant_conf) = &config {
862        match quant_conf {
863            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
864            QuantizedConfig::Fp8 { weight_block_size } => {
865                if weight_block_size.is_some() {
866                    blockwise_fp8_linear_b(
867                        in_dim,
868                        out_dim,
869                        quant_conf,
870                        false,
871                        Default::default(),
872                        vb,
873                    )?
874                } else {
875                    pertensor_fp8_linear_b(
876                        in_dim,
877                        out_dim,
878                        quant_conf,
879                        false,
880                        Default::default(),
881                        vb,
882                    )?
883                }
884            }
885            QuantizedConfig::Bitsandbytes { .. } => {
886                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
887            }
888            QuantizedConfig::Afq { .. } => {
889                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
890            }
891            QuantizedConfig::MXFP4 {} => {
892                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
893            }
894        }
895    } else {
896        // Handle the case where the layer is dummy (no tensors)
897        if !vb.contains_tensor("weight") {
898            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
899            Arc::new(layer) as Arc<dyn QuantMethod>
900        } else {
901            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
902            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
903
904            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
905                Linear::new(weight, None),
906            ))?;
907            Arc::new(layer) as Arc<dyn QuantMethod>
908        }
909    };
910    apply_immediate_isq(layer, base_vb)
911}
912
913pub fn linear(
914    in_dim: usize,
915    out_dim: usize,
916    config: &Option<QuantizedConfig>,
917    vb: ShardedVarBuilder,
918) -> Result<Arc<dyn QuantMethod>> {
919    let base_vb = vb.clone();
920    let vb = if should_apply_immediate_isq(&vb) {
921        vb.set_device(Device::Cpu)
922    } else {
923        vb
924    };
925
926    let layer = if let Some(quant_conf) = &config {
927        match quant_conf {
928            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
929            QuantizedConfig::Fp8 { weight_block_size } => {
930                if weight_block_size.is_some() {
931                    blockwise_fp8_linear_b(
932                        in_dim,
933                        out_dim,
934                        quant_conf,
935                        true,
936                        Default::default(),
937                        vb,
938                    )?
939                } else {
940                    pertensor_fp8_linear_b(
941                        in_dim,
942                        out_dim,
943                        quant_conf,
944                        true,
945                        Default::default(),
946                        vb,
947                    )?
948                }
949            }
950            QuantizedConfig::Bitsandbytes { .. } => {
951                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
952            }
953            QuantizedConfig::Afq { .. } => {
954                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
955            }
956            QuantizedConfig::MXFP4 {} => {
957                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
958            }
959        }
960    } else {
961        // Handle the case where the layer is dummy (no tensors)
962        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
963            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
964            Arc::new(layer) as Arc<dyn QuantMethod>
965        } else {
966            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
967            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
968            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
969
970            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
971                Linear::new(weight, Some(bias)),
972            ))?;
973            Arc::new(layer) as Arc<dyn QuantMethod>
974        }
975    };
976    apply_immediate_isq(layer, base_vb)
977}
978
979pub fn linear_b(
980    in_dim: usize,
981    out_dim: usize,
982    bias: bool,
983    config: &Option<QuantizedConfig>,
984    vb: ShardedVarBuilder,
985) -> Result<Arc<dyn QuantMethod>> {
986    if bias {
987        linear(in_dim, out_dim, config, vb)
988    } else {
989        linear_no_bias(in_dim, out_dim, config, vb)
990    }
991}