Skip to main content

mistralrs_quant/
lib.rs

1use std::{
2    borrow::Cow,
3    fmt::Debug,
4    num::NonZeroUsize,
5    sync::{atomic::AtomicUsize, Arc, Condvar, Mutex, MutexGuard},
6};
7
8use blockwise_fp8::blockwise_fp8_linear_b;
9use candle_core::{
10    quantized::{GgmlDType, QMatMul, QTensor},
11    DType, Device, Result, Tensor,
12};
13use pertensor_fp8::pertensor_fp8_linear_b;
14
15#[cfg(feature = "metal")]
16mod metal_kernels;
17
18mod afq;
19mod bitsandbytes;
20mod blockwise_fp8;
21pub mod cublaslt;
22pub mod distributed;
23mod dummy;
24pub mod f8q8;
25mod fp8;
26pub mod gemv;
27mod gguf;
28mod gptq;
29mod hqq;
30mod imatrix;
31mod lora;
32mod mxfp4;
33mod pending_layer;
34mod pertensor_fp8;
35pub mod rotary;
36pub mod safetensors;
37mod scalar_fp8;
38mod unquantized;
39mod utils;
40mod vector_fp8;
41
42use gptq::gptq_linear;
43use lora::merge_lora_weights;
44use regex::Regex;
45pub use safetensors::{Shard, ShardedSafeTensors, ShardedVarBuilder};
46
47pub use afq::{AfqBits, AfqGroupSize, AfqLayer};
48pub use bitsandbytes::{BnbLinear, BnbQuantParams, BnbQuantType};
49pub use blockwise_fp8::{
50    blockwise_fp8_moe, fp8_blockwise_dequantize, fp8_blockwise_quantize, BlockwiseFP8Linear,
51};
52pub use distributed::{
53    layers::{
54        compute_kv_shard, compute_n_kv_groups, ColumnParallelLayer, FusedExperts, PackedExperts,
55        ReplicatedLayer, RowParallelLayer,
56    },
57    socket::{Client, Server},
58    BarrierLike, Comm, Id, RingConfig, SumAllReduce,
59};
60pub use dummy::DummyLayer;
61pub use f8q8::F8Q8Linear;
62pub use fp8::FP8Linear;
63#[cfg(feature = "cuda")]
64pub use gemv::gemv;
65pub use gemv::{should_use_gemv, GEMV_CONTROLLER};
66pub use gguf::GgufMatMul;
67pub use gptq::GptqLayer;
68pub use hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer};
69pub use imatrix::{CollectedImatrixData, ImatrixLayerStats};
70pub use lora::{
71    clear_applied_loras, get_applied_loras, linear_no_bias_static_lora, push_applied_lora,
72    LoraAdapter, LoraConfig, StaticLoraConfig, MULTI_LORA_DELIMITER,
73};
74pub use mxfp4::MXFP4Layer;
75pub use pending_layer::PendingIsqLayer;
76pub use pertensor_fp8::PerTensorFP8Linear;
77pub use unquantized::UnquantLinear;
78pub use utils::flash_attn_sinks_metal;
79pub use utils::flash_attn_sinks_varlen_metal;
80#[cfg(feature = "cuda")]
81pub use utils::gptoss_swiglu_fused;
82#[cfg(feature = "cuda")]
83pub use utils::gptoss_swiglu_interleaved;
84pub use utils::isq::apply_immediate_isq;
85pub use utils::softmax_with_sinks;
86pub use utils::{fused_glu, GluActivationType};
87pub use utils::{log, BitWiseOp, CumSumOp, LeftshiftOp, NonZeroOp, SortOp, UQFF_QUANT_TYPE_OFFSET};
88pub use vector_fp8::{fp8_vector_dequantize, fp8_vector_quantize};
89
90use candle_nn::{Conv1d, Conv2d, Linear, Module};
91use serde::{Deserialize, Deserializer, Serialize};
92
/// Limits outstanding async ISQ jobs to prevent unbounded memory growth.
///
/// Without backpressure, MoE models (e.g. Gemma4 with 128 experts × 30 layers)
/// queue BF16 tensor data in the rayon pool faster than the pool can quantize,
/// causing OOM on memory-constrained systems like macOS Metal with unified memory.
///
/// Internally this is a counter protected by a mutex plus a condition variable:
/// [`IsqBackpressure::acquire`] blocks while the counter is at `max`, and
/// [`IsqBackpressure::release`] decrements it and wakes one waiter.
pub struct IsqBackpressure {
    /// Number of slots currently held (acquired but not yet released).
    count: Mutex<usize>,
    /// Signalled by `release` so that a blocked `acquire` can re-check the count.
    cvar: Condvar,
    /// Maximum number of simultaneously held slots; always at least 1.
    max: usize,
}

impl IsqBackpressure {
    /// Create a limiter that allows at most `max` outstanding slots.
    ///
    /// A `max` of `0` is raised to `1`: with a capacity of zero no slot could
    /// ever be acquired and the first call to `acquire` would block forever.
    pub fn new(max: usize) -> Self {
        Self {
            count: Mutex::new(0),
            cvar: Condvar::new(),
            max: max.max(1),
        }
    }

    /// Block until a slot is available, then increment the outstanding count.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking holder.
    pub fn acquire(&self) {
        let guard = self.count.lock().expect("ISQ backpressure lock poisoned");
        // `wait_while` re-checks the predicate after every wakeup, so spurious
        // wakeups cannot let a caller through while the limiter is full.
        let mut count = self
            .cvar
            .wait_while(guard, |outstanding| *outstanding >= self.max)
            .expect("ISQ backpressure lock poisoned");
        *count += 1;
    }

    /// Decrement the outstanding count and wake a blocked loader thread.
    ///
    /// The decrement saturates at zero, so an unmatched `release` is a no-op
    /// on the counter rather than an underflow.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking holder.
    pub fn release(&self) {
        let mut count = self.count.lock().expect("ISQ backpressure lock poisoned");
        *count = count.saturating_sub(1);
        self.cvar.notify_one();
    }
}

impl Debug for IsqBackpressure {
    /// Print the current outstanding count and the configured maximum.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must never panic: if the lock is poisoned, still report
        // the counter value stored inside it instead of a made-up zero.
        let count = self
            .count
            .lock()
            .map(|c| *c)
            .unwrap_or_else(|poisoned| *poisoned.into_inner());
        f.debug_struct("IsqBackpressure")
            .field("outstanding", &count)
            .field("max", &self.max)
            .finish()
    }
}
142
/// Settings for immediate ISQ, stored per thread in `ENGINE_IMMEDIATE_ISQ`
/// and handed out (cloned) by `get_immediate_isq`.
#[derive(Clone, Debug)]
pub struct ImmediateIsqParams {
    /// Guard created together with these settings (see `QuantizeOntoGuard`).
    pub guard: QuantizeOntoGuard,
    /// Default ISQ type. Used when a regular predicate matches, and as the
    /// fallback for an override that does not name its own type.
    pub ty: Option<IsqType>,
    /// Regexes tested against `"<prefix>.weight"`; any match selects `ty`.
    pub predicates: Vec<Regex>,
    /// Per-pattern overrides; these are consulted before `predicates`.
    pub overrides: Vec<ImmediateIsqOverride>,
    /// Thread pool for parallel immediate ISQ on discrete GPUs.
    /// When `Some`, `apply_immediate_isq` will spawn quantization tasks
    /// on this pool and return `PendingIsqLayer` wrappers.
    pub pool: Option<Arc<rayon::ThreadPool>>,
    /// Backpressure to limit outstanding async ISQ jobs.
    pub backpressure: Arc<IsqBackpressure>,
}
156
/// A single immediate-ISQ override rule.
///
/// When `predicate` matches a weight name, this rule decides the outcome and
/// the regular predicates are not consulted.
#[derive(Clone, Debug)]
pub struct ImmediateIsqOverride {
    /// Pattern tested against the `"<prefix>.weight"` name.
    pub predicate: Regex,
    /// ISQ type for matching weights; `None` falls back to the global type.
    pub ty: Option<IsqType>,
    /// Device copied into the resulting `ImmediateIsqMatch`.
    pub device: Option<Device>,
}
163
/// Result of resolving the immediate-ISQ settings for one weight name.
#[derive(Clone, Debug)]
pub struct ImmediateIsqMatch {
    /// ISQ type selected for the weight.
    pub ty: IsqType,
    /// Device taken from the matching override; `None` when the match came
    /// from a regular predicate.
    pub device: Option<Device>,
}
169
// Per-thread slot holding the active immediate-ISQ settings. It starts out as
// `None`, is filled by `set_immediate_isq_with_pool`, read by
// `get_immediate_isq` and reset by `clear_immediate_isq`. Because it is
// thread-local, settings installed on one thread are not visible on another.
thread_local! {
    static ENGINE_IMMEDIATE_ISQ: std::cell::RefCell<Option<ImmediateIsqParams>> = const { std::cell::RefCell::new(None) } ;
}
173
174pub fn set_immediate_isq(isq: Option<IsqType>, predicates: Vec<Regex>) {
175    let (pool, _) = create_isq_thread_pool(isq);
176    set_immediate_isq_with_pool(isq, predicates, Vec::new(), pool);
177}
178
179pub fn set_immediate_isq_with_pool(
180    isq: Option<IsqType>,
181    predicates: Vec<Regex>,
182    overrides: Vec<ImmediateIsqOverride>,
183    pool: rayon::ThreadPool,
184) {
185    // Allow pool threads + 1 outstanding jobs: enough for pipeline overlap
186    // (load next tensor while pool quantizes current) without unbounded growth.
187    let max_outstanding = pool.current_num_threads() + 1;
188    ENGINE_IMMEDIATE_ISQ.with(|cell| {
189        *cell.borrow_mut() = Some(ImmediateIsqParams {
190            guard: QuantizeOntoGuard::new(),
191            ty: isq,
192            predicates,
193            overrides,
194            backpressure: Arc::new(IsqBackpressure::new(max_outstanding)),
195            pool: Some(Arc::new(pool)),
196        });
197    });
198}
199
200/// Create a rayon thread pool for parallel immediate ISQ.
201/// Returns `(pool, num_threads)` so callers can log the thread count.
202///
203/// Thread count is based on the quantization type:
204/// - GGML types (Q2K-Q8K) and F8E4M3: `rayon::current_num_threads()` (CPU quantization)
205/// - HQQ/AFQ: 1 thread (GPU quantization, serialized by `QuantizeOntoGuard`)
206pub fn create_isq_thread_pool(ty: Option<IsqType>) -> (rayon::ThreadPool, usize) {
207    let num_threads = if std::env::var("MISTRALRS_ISQ_SINGLETHREAD").is_ok() {
208        1
209    } else if let Some(ty) = ty {
210        ty.get_max_isq_cpu_threads()
211            .map(usize::from)
212            .unwrap_or_else(rayon::current_num_threads)
213    } else {
214        rayon::current_num_threads()
215    };
216
217    let pool = rayon::ThreadPoolBuilder::new()
218        .num_threads(num_threads)
219        .build()
220        .expect("Failed to create ISQ thread pool");
221    (pool, num_threads)
222}
223
224pub fn get_immediate_isq() -> Option<ImmediateIsqParams> {
225    ENGINE_IMMEDIATE_ISQ.with(|cell| cell.borrow().clone())
226}
227
228pub fn clear_immediate_isq() {
229    ENGINE_IMMEDIATE_ISQ.with(|cell| {
230        *cell.borrow_mut() = None;
231    });
232}
233
234pub fn should_apply_immediate_isq(vb: &ShardedVarBuilder) -> bool {
235    immediate_isq_match(vb).is_some()
236}
237
238pub fn immediate_isq_match(vb: &ShardedVarBuilder) -> Option<ImmediateIsqMatch> {
239    let immediate_isq = get_immediate_isq()?;
240    // Add a .weight to match the ISQ regexes!
241    let prefix = format!("{}.weight", vb.prefix());
242    resolve_immediate_isq(&immediate_isq, &prefix)
243}
244
245fn resolve_immediate_isq(params: &ImmediateIsqParams, prefix: &str) -> Option<ImmediateIsqMatch> {
246    if let Some(override_hit) = params
247        .overrides
248        .iter()
249        .find(|override_pred| override_pred.predicate.is_match(prefix))
250    {
251        if let Some(ty) = override_hit.ty.or(params.ty) {
252            return Some(ImmediateIsqMatch {
253                ty,
254                device: override_hit.device.clone(),
255            });
256        }
257        return None;
258    }
259
260    if let Some(ty) = params.ty {
261        if params
262            .predicates
263            .iter()
264            .any(|predicate| predicate.is_match(prefix))
265        {
266            return Some(ImmediateIsqMatch { ty, device: None });
267        }
268    }
269
270    None
271}
272
/// Quantization settings of a pre-quantized checkpoint.
///
/// Serialization is derived: the variant is written as a `quant_method` tag
/// using the lowercased variant name. Deserialization is NOT derived; it is
/// implemented by hand further down and dispatches on the `quant_method` string.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "quant_method", rename_all = "lowercase")]
pub enum QuantizedConfig {
    /// GPTQ or AWQ; both methods share this variant and are told apart by `is_awq`.
    GptqAwq {
        bits: usize,
        group_size: usize,
        checkpoint_format: Option<String>,
        /// `true` when the config's `quant_method` was `"awq"`.
        is_awq: bool,
    },
    /// FP8. A `weight_block_size` of `None` means per-tensor quantization.
    Fp8 {
        weight_block_size: Option<Vec<usize>>,
    },
    /// Bitsandbytes. `get_bits_name` reports 4 bits when
    /// `bnb_4bit_quant_type` is present and 8 bits otherwise.
    Bitsandbytes {
        bnb_4bit_quant_type: Option<String>,
    },
    /// AFQ; also the variant produced when the config has no `quant_method`.
    Afq {
        bits: usize,
        group_size: usize,
    },
    /// MXFP4; carries no parameters.
    MXFP4 {},
}
294
// Common fields for all variants.
//
// Flat, all-optional view of the quantization config. It is deserialized
// first, and the custom `Deserialize` impl for `QuantizedConfig` then picks
// the variant from `quant_method` and checks the fields that variant needs.
#[derive(Deserialize)]
struct RawConfig {
    // Selects the variant; a missing value is treated as AFQ.
    quant_method: Option<String>,
    // Required for gptq/awq/afq (and when `quant_method` is missing).
    bits: Option<usize>,
    // Required for gptq/awq/afq (and when `quant_method` is missing).
    group_size: Option<usize>,
    // Passed through to `GptqAwq`.
    checkpoint_format: Option<String>,
    // Passed through to `Fp8`.
    weight_block_size: Option<Vec<usize>>,
    // Passed through to `Bitsandbytes`.
    bnb_4bit_quant_type: Option<String>,
}
305
306// Custom deserializer implementation
307impl<'de> Deserialize<'de> for QuantizedConfig {
308    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
309    where
310        D: Deserializer<'de>,
311    {
312        let raw = RawConfig::deserialize(deserializer)?;
313
314        match &raw.quant_method {
315            Some(m) if m == "gptq" || m == "awq" => {
316                let bits = raw
317                    .bits
318                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
319                let group_size = raw
320                    .group_size
321                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
322                Ok(QuantizedConfig::GptqAwq {
323                    bits,
324                    group_size,
325                    checkpoint_format: raw.checkpoint_format,
326                    is_awq: m == "awq",
327                })
328            }
329            Some(m) if m == "fp8" => {
330                // weight_block_size is optional - None means per-tensor quantization
331                Ok(QuantizedConfig::Fp8 {
332                    weight_block_size: raw.weight_block_size,
333                })
334            }
335            Some(m) if m == "bitsandbytes" => Ok(QuantizedConfig::Bitsandbytes {
336                bnb_4bit_quant_type: raw.bnb_4bit_quant_type,
337            }),
338            Some(m) if m == "afq" => {
339                let bits = raw
340                    .bits
341                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
342                let group_size = raw
343                    .group_size
344                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
345                Ok(QuantizedConfig::Afq { bits, group_size })
346            }
347            Some(m) if m == "mxfp4" => {
348                Ok(QuantizedConfig::MXFP4 {  })
349            }
350            None => {
351                let bits = raw
352                    .bits
353                    .ok_or_else(|| serde::de::Error::missing_field("bits"))?;
354                let group_size = raw
355                    .group_size
356                    .ok_or_else(|| serde::de::Error::missing_field("group_size"))?;
357                Ok(QuantizedConfig::Afq { bits, group_size })
358            }
359            Some(unknown_method) => {
360                Err(serde::de::Error::custom(format!(
361                    "Unknown quantization method: {unknown_method}. Expected one of: gptq, fp8, bitsandbytes, afq, or not specified"
362                )))
363            },
364        }
365    }
366}
367
368impl QuantizedConfig {
369    pub fn name(&self) -> &'static str {
370        match self {
371            Self::GptqAwq { .. } => "gptq",
372            Self::Fp8 { .. } => "fp8",
373            Self::Bitsandbytes { .. } => "bitsandbytes",
374            Self::Afq { .. } => "afq",
375            Self::MXFP4 { .. } => "mxfp4",
376        }
377    }
378
379    pub fn get_bits_name(&self, _vb: &ShardedVarBuilder) -> String {
380        match self {
381            Self::GptqAwq { bits, .. } => format!("{bits} bits"),
382            Self::Fp8 { .. } => "8 bits".to_string(),
383            Self::Bitsandbytes {
384                bnb_4bit_quant_type: Some(_),
385            } => "4 bits".to_string(),
386            Self::Bitsandbytes {
387                bnb_4bit_quant_type: None,
388            } => "8 bits".to_string(),
389            Self::Afq { bits, .. } => format!("{bits} bits"),
390            Self::MXFP4 {} => format!("{} bits", mxfp4::N_BITS),
391        }
392    }
393
394    pub fn pack_factor(&self, dtype: DType) -> usize {
395        match self {
396            Self::GptqAwq { bits, .. } | Self::Afq { bits, .. } => match bits {
397                2 => IsqType::Q2K.pack_factor(dtype),
398                3 => IsqType::Q3K.pack_factor(dtype),
399                4 => IsqType::Q4K.pack_factor(dtype),
400                5 => IsqType::Q5K.pack_factor(dtype),
401                6 => IsqType::Q6K.pack_factor(dtype),
402                8 => IsqType::Q8_0.pack_factor(dtype),
403                40 => 4, // mxfp4: 2 FP4 values per byte = factor of 4
404                other => panic!("Unexpected bits in `pack_factor` {other}"),
405            },
406            Self::Fp8 { .. } => IsqType::Q8_0.pack_factor(dtype),
407            Self::Bitsandbytes {
408                bnb_4bit_quant_type: Some(_),
409            }
410            | Self::Bitsandbytes {
411                bnb_4bit_quant_type: None,
412            } => IsqType::Q4K.pack_factor(dtype),
413            Self::MXFP4 {} => IsqType::Q4_0.pack_factor(dtype),
414        }
415    }
416}
417
/// Inputs for building a quantized (or unquantized) linear layer, one variant
/// per supported method.
///
/// NOTE(review): the code that consumes these variants is outside this
/// section, so the field descriptions below are inferred from field names and
/// types only — confirm against the individual layer constructors.
#[derive(Debug, Clone)]
pub enum QuantMethodConfig {
    /// GPTQ / AWQ tensors (`is_awq` tells the two apart).
    GptqAwq {
        bits: i32,
        use_exllama: bool,
        q_weight: Tensor,
        qzeros: Option<Tensor>,
        scales: Tensor,
        g_idx: Option<Tensor>,
        bias: Option<Tensor>,
        workspace: Option<Tensor>,
        is_marlin: bool,
        is_awq: bool,
    },
    /// A GGUF/GGML quantized tensor with an optional bias `b`.
    Gguf {
        q_weight: Arc<QTensor>,
        b: Option<Tensor>,
    },
    /// A plain `candle_nn::Linear` layer.
    Unquantized(Linear),
    /// HQQ settings together with the source `tensor`.
    Hqq {
        tensor: Tensor,
        bits: HqqBits,
        group_size: NonZeroUsize,
        axis: HqqAxis,
        optimization_steps: Option<usize>,
        round_zeros: Option<bool>,
        channel_wise: Option<bool>,
        bias: Option<Tensor>,
    },
    /// Carries no data.
    Dummy,
    /// FP8 built from a linear layer and a dtype.
    FP8 {
        lin: Linear,
        dtype: DType,
    },
    /// Bitsandbytes weight with its quantization parameters.
    Bnb {
        weight: Tensor,
        bias: Option<Tensor>,
        params: BnbQuantParams,
        quant_ty: BnbQuantType,
    },
    /// FP8 weight with block-wise scales.
    BlockwiseFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        bias: Option<Tensor>,
        dequant_dtype: DType,
        weight_block_size: Vec<usize>,
    },
    /// FP8 weight with a per-tensor scale and an optional activation scale.
    PerTensorFP8 {
        weight: Tensor,
        weight_scale_inv: Tensor,
        activation_scale: Option<Tensor>,
        bias: Option<Tensor>,
        dequant_dtype: DType,
    },
    /// AFQ weight with its bit width and group size.
    Afq {
        weight: Tensor,
        bias: Option<Tensor>,
        bits: AfqBits,
        group_size: AfqGroupSize,
    },
    /// MXFP4 packed blocks and scales.
    MXFP4 {
        blocks: Tensor,
        scales: Tensor,
        bias: Option<Tensor>,
    },
}
484
485/// Device/configurable intelligent matrix multiplication
486/// - Handles limitation of `accelerate` which requires f32
487pub struct MatMul;
488
489impl MatMul {
490    /// Compute matrix-matrix product.
491    pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
492        #[cfg(feature = "accelerate")]
493        {
494            let original_dtype = a.dtype();
495            a.to_dtype(DType::F32)?
496                .matmul(&b.to_dtype(DType::F32)?)?
497                .to_dtype(original_dtype)
498        }
499        #[cfg(not(feature = "accelerate"))]
500        {
501            if a.device().is_cpu() {
502                let original_dtype = a.dtype();
503                a.to_dtype(DType::F16)?
504                    .matmul(&b.to_dtype(DType::F16)?)?
505                    .to_dtype(original_dtype)
506            } else {
507                a.matmul(b)
508            }
509        }
510    }
511
512    /// Compute matrix-matrix product.
513    /// The result will be divided by the `scale` parameter in an affine division.
514    pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
515        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
516        self.matmul(a, b)? / scale
517    }
518
519    /// Compute matrix-matrix product.
520    /// The result will be divided by the `scale` parameter in an affine multiplication.
521    pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
522        // TODO(EricLBuehler): Optimize this by using the gemm parameter?
523        self.matmul(a, b)? * scale
524    }
525
526    /// Compute quantized matrix-matrix product.
527    pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
528        matmul.forward(x)
529    }
530
531    /// Compute quantized matrix-matrix product.
532    pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
533        matmul.forward(x)
534    }
535}
536
537/// Device/configurable intelligent convolution
538/// - Handles limitation of cpu which requires f32
539pub struct Convolution;
540
541impl Convolution {
542    pub fn forward_1d(&self, layer: &Conv1d, x: &Tensor) -> Result<Tensor> {
543        if x.device().is_cpu() {
544            let original_dtype = x.dtype();
545            Conv1d::new(
546                layer.weight().to_dtype(DType::F32)?,
547                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
548                *layer.config(),
549            )
550            .forward(&x.to_dtype(DType::F32)?)?
551            .to_dtype(original_dtype)
552        } else {
553            layer.forward(x)
554        }
555    }
556
557    pub fn forward_2d(&self, layer: &Conv2d, x: &Tensor) -> Result<Tensor> {
558        if x.device().is_cpu() {
559            let original_dtype = x.dtype();
560            Conv2d::new(
561                layer.weight().to_dtype(DType::F32)?,
562                layer.bias().map(|b| b.to_dtype(DType::F32)).transpose()?,
563                *layer.config(),
564            )
565            .forward(&x.to_dtype(DType::F32)?)?
566            .to_dtype(original_dtype)
567        } else {
568            layer.forward(x)
569        }
570    }
571}
572
/// In-situ quantization type specifying the format to apply to model weights.
///
/// The `Q*` variants convert to a `GgmlDType` via `TryFrom`; every other
/// variant is rejected by that conversion. `Display` prints the lowercase
/// name (e.g. `q4k`, `afq4`), except that `F8E4M3` prints as `fp8`.
#[derive(Clone, Copy, Debug, PartialEq, Hash, Eq, Serialize, Deserialize)]
pub enum IsqType {
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    HQQ8,
    HQQ4,
    // HQQ3,
    // HQQ2,
    // HQQ1,
    F8E4M3,
    AFQ8,
    AFQ6,
    AFQ4,
    AFQ3,
    AFQ2,
    F8Q8,
    MXFP4,
}
602
/// Target bit width for automatic ISQ quantization.
///
/// On Metal, these select AFQ variants; on CUDA/CPU, they select Q*K variants.
/// Use `resolve` to pick the concrete `IsqType` for a device, or `expand` to
/// list every candidate type. Parsed from the strings `"2"`, `"3"`, `"4"`,
/// `"5"`, `"6"` and `"8"` via `TryFrom<&str>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IsqBits {
    /// 2-bit quantization (AFQ2 on Metal, Q2K otherwise).
    Two,
    /// 3-bit quantization (AFQ3 on Metal, Q3K otherwise).
    Three,
    /// 4-bit quantization (AFQ4 on Metal, Q4K otherwise).
    Four,
    /// 5-bit quantization (Q5K on all platforms).
    Five,
    /// 6-bit quantization (AFQ6 on Metal, Q6K otherwise).
    Six,
    /// 8-bit quantization (AFQ8 on Metal, Q8_0 otherwise).
    Eight,
}
621
622impl IsqBits {
623    /// Resolve to the platform-appropriate `IsqType` for the given device.
624    pub fn resolve(self, device: &Device) -> IsqType {
625        match (self, device.is_metal()) {
626            (Self::Two, true) => IsqType::AFQ2,
627            (Self::Two, false) => IsqType::Q2K,
628            (Self::Three, true) => IsqType::AFQ3,
629            (Self::Three, false) => IsqType::Q3K,
630            (Self::Four, true) => IsqType::AFQ4,
631            (Self::Four, false) => IsqType::Q4K,
632            (Self::Five, _) => IsqType::Q5K,
633            (Self::Six, true) => IsqType::AFQ6,
634            (Self::Six, false) => IsqType::Q6K,
635            (Self::Eight, true) => IsqType::AFQ8,
636            (Self::Eight, false) => IsqType::Q8_0,
637        }
638    }
639
640    /// Return all platform variants, with the current platform's preferred variant first.
641    /// On Metal, AFQ variants come first; on other platforms, GGUF/Q variants come first.
642    pub fn expand(self) -> Vec<IsqType> {
643        #[cfg(feature = "metal")]
644        match self {
645            Self::Two => vec![IsqType::AFQ2, IsqType::Q2K],
646            Self::Three => vec![IsqType::AFQ3, IsqType::Q3K],
647            Self::Four => vec![IsqType::AFQ4, IsqType::Q4K],
648            Self::Five => vec![IsqType::Q5K],
649            Self::Six => vec![IsqType::AFQ6, IsqType::Q6K],
650            Self::Eight => vec![IsqType::AFQ8, IsqType::Q8_0],
651        }
652        #[cfg(not(feature = "metal"))]
653        match self {
654            Self::Two => vec![IsqType::Q2K, IsqType::AFQ2],
655            Self::Three => vec![IsqType::Q3K, IsqType::AFQ3],
656            Self::Four => vec![IsqType::Q4K, IsqType::AFQ4],
657            Self::Five => vec![IsqType::Q5K],
658            Self::Six => vec![IsqType::Q6K, IsqType::AFQ6],
659            Self::Eight => vec![IsqType::Q8_0, IsqType::AFQ8],
660        }
661    }
662}
663
664impl TryFrom<&str> for IsqBits {
665    type Error = ();
666    fn try_from(s: &str) -> std::result::Result<Self, ()> {
667        match s {
668            "2" => Ok(Self::Two),
669            "3" => Ok(Self::Three),
670            "4" => Ok(Self::Four),
671            "5" => Ok(Self::Five),
672            "6" => Ok(Self::Six),
673            "8" => Ok(Self::Eight),
674            _ => Err(()),
675        }
676    }
677}
678
679impl std::fmt::Display for IsqType {
680    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
681        match self {
682            Self::Q4_0 => write!(f, "q4_0"),
683            Self::Q4_1 => write!(f, "q4_1"),
684            Self::Q5_0 => write!(f, "q5_0"),
685            Self::Q5_1 => write!(f, "q5_1"),
686            Self::Q8_0 => write!(f, "q8_0"),
687            Self::Q8_1 => write!(f, "q8_1"),
688            Self::Q2K => write!(f, "q2k"),
689            Self::Q3K => write!(f, "q3k"),
690            Self::Q4K => write!(f, "q4k"),
691            Self::Q5K => write!(f, "q5k"),
692            Self::Q6K => write!(f, "q6k"),
693            Self::Q8K => write!(f, "q8k"),
694            Self::HQQ8 => write!(f, "hqq8"),
695            Self::HQQ4 => write!(f, "hqq4"),
696            Self::F8E4M3 => write!(f, "fp8"),
697            Self::AFQ8 => write!(f, "afq8"),
698            Self::AFQ6 => write!(f, "afq6"),
699            Self::AFQ4 => write!(f, "afq4"),
700            Self::AFQ3 => write!(f, "afq3"),
701            Self::AFQ2 => write!(f, "afq2"),
702            Self::F8Q8 => write!(f, "f8q8"),
703            Self::MXFP4 => write!(f, "mxfp4"),
704        }
705    }
706}
707
708impl IsqType {
709    /// Factor by which the weight size is reduced over the given dtype.
710    /// original size / pack factor = quantized size
711    pub fn pack_factor(&self, dtype: DType) -> usize {
712        match self {
713            Self::Q4_0 | Self::AFQ4 => (dtype.size_in_bytes() * GgmlDType::Q4_0.block_size())
714                .div_ceil(GgmlDType::Q4_0.type_size()),
715            Self::Q4_1 => (dtype.size_in_bytes() * GgmlDType::Q4_1.block_size())
716                .div_ceil(GgmlDType::Q4_1.type_size()),
717            Self::Q5_0 => (dtype.size_in_bytes() * GgmlDType::Q5_0.block_size())
718                .div_ceil(GgmlDType::Q5_0.type_size()),
719            Self::Q5_1 => (dtype.size_in_bytes() * GgmlDType::Q5_1.block_size())
720                .div_ceil(GgmlDType::Q5_1.type_size()),
721            Self::Q8_0 | Self::AFQ8 => (dtype.size_in_bytes() * GgmlDType::Q8_0.block_size())
722                .div_ceil(GgmlDType::Q8_0.type_size()),
723            Self::Q8_1 => (dtype.size_in_bytes() * GgmlDType::Q8_1.block_size())
724                .div_ceil(GgmlDType::Q8_1.type_size()),
725            Self::Q2K | Self::AFQ2 => (dtype.size_in_bytes() * GgmlDType::Q2K.block_size())
726                .div_ceil(GgmlDType::Q2K.type_size()),
727            Self::Q3K | Self::AFQ3 => (dtype.size_in_bytes() * GgmlDType::Q3K.block_size())
728                .div_ceil(GgmlDType::Q3K.type_size()),
729            Self::Q4K => (dtype.size_in_bytes() * GgmlDType::Q4K.block_size())
730                .div_ceil(GgmlDType::Q4K.type_size()),
731            Self::Q5K => (dtype.size_in_bytes() * GgmlDType::Q5K.block_size())
732                .div_ceil(GgmlDType::Q5K.type_size()),
733            Self::Q6K | Self::AFQ6 => (dtype.size_in_bytes() * GgmlDType::Q6K.block_size())
734                .div_ceil(GgmlDType::Q6K.type_size()),
735            Self::Q8K => (dtype.size_in_bytes() * GgmlDType::Q8K.block_size())
736                .div_ceil(GgmlDType::Q8K.type_size()),
737            // F8Q8: 33 bytes per 32 values -> similar to Q8_0
738            Self::F8Q8 => (dtype.size_in_bytes() * 32).div_ceil(33),
739            // Estimates
740            Self::HQQ4 => 4,
741            Self::HQQ8 => 2,
742            Self::F8E4M3 => 2,
743            // MXFP4: 4 bits per value + 1 byte scale per 32 values
744            // For BF16 (2 bytes): (2*32)/(16+1) ≈ 3.76 → 3
745            Self::MXFP4 => 3,
746        }
747    }
748
749    pub fn get_max_isq_cpu_threads(&self) -> Option<NonZeroUsize> {
750        match self {
751            /*IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
752            IsqType::HQQ4
753            | IsqType::HQQ8
754            | IsqType::AFQ2
755            | IsqType::AFQ3
756            | IsqType::AFQ4
757            | IsqType::AFQ6
758            | IsqType::AFQ8
759            | IsqType::MXFP4 => {
760                // Use 1 because our HQQ quantizes on the GPU
761                Some(1.try_into().unwrap())
762            }
763            IsqType::F8E4M3 | IsqType::F8Q8 => None,
764            IsqType::Q2K
765            | IsqType::Q3K
766            | IsqType::Q4K
767            | IsqType::Q4_0
768            | IsqType::Q4_1
769            | IsqType::Q5K
770            | IsqType::Q5_0
771            | IsqType::Q5_1
772            | IsqType::Q6K
773            | IsqType::Q8K
774            | IsqType::Q8_0
775            | IsqType::Q8_1 => None,
776        }
777    }
778}
779
780impl TryFrom<IsqType> for GgmlDType {
781    type Error = candle_core::Error;
782
783    fn try_from(value: IsqType) -> Result<Self> {
784        let tp = match value {
785            IsqType::Q2K => Self::Q2K,
786            IsqType::Q3K => Self::Q3K,
787            IsqType::Q4K => Self::Q4K,
788            IsqType::Q4_0 => Self::Q4_0,
789            IsqType::Q4_1 => Self::Q4_1,
790            IsqType::Q5K => Self::Q5K,
791            IsqType::Q5_0 => Self::Q5_0,
792            IsqType::Q5_1 => Self::Q5_1,
793            IsqType::Q6K => Self::Q6K,
794            IsqType::Q8K => Self::Q8K,
795            IsqType::Q8_0 => Self::Q8_0,
796            IsqType::Q8_1 => Self::Q8_1,
797            _ => candle_core::bail!("Expected valid GGML ISQ type."),
798        };
799        #[cfg(feature = "cuda")]
800        {
801            if !matches!(
802                tp,
803                GgmlDType::Q4_0
804                    | GgmlDType::Q4_1
805                    | GgmlDType::Q5_0
806                    | GgmlDType::Q5_1
807                    | GgmlDType::Q8_0
808                    | GgmlDType::Q2K
809                    | GgmlDType::Q3K
810                    | GgmlDType::Q4K
811                    | GgmlDType::Q5K
812                    | GgmlDType::Q6K
813            ) {
814                candle_core::bail!("GGML ISQ type on CUDA must be one of `Q4_0`, `Q4_1`, `Q5_0`, `Q5_1`, `Q8_0`, `Q2K`, `Q3K`, `Q4K`, `Q5K`, `Q6K`, `HQQ8`, `HQQ4`")
815            }
816        }
817        Ok(tp)
818    }
819}
820
821impl TryFrom<GgmlDType> for IsqType {
822    type Error = candle_core::Error;
823
824    fn try_from(value: GgmlDType) -> Result<Self> {
825        match value {
826            GgmlDType::Q2K => Ok(Self::Q2K),
827            GgmlDType::Q3K => Ok(Self::Q3K),
828            GgmlDType::Q4K => Ok(Self::Q4K),
829            GgmlDType::Q5K => Ok(Self::Q5K),
830            GgmlDType::Q6K => Ok(Self::Q6K),
831            GgmlDType::Q4_0 => Ok(Self::Q4_0),
832            GgmlDType::Q4_1 => Ok(Self::Q4_1),
833            GgmlDType::Q5_0 => Ok(Self::Q5_0),
834            GgmlDType::Q5_1 => Ok(Self::Q5_1),
835            GgmlDType::Q8_0 => Ok(Self::Q8_0),
836            GgmlDType::Q8_1 => Ok(Self::Q8_1),
837            GgmlDType::Q8K => Ok(Self::Q8K),
838            GgmlDType::BF16 | GgmlDType::F32 | GgmlDType::F16 => {
839                candle_core::bail!("Expected valid GGML ISQ type.")
840            }
841        }
842    }
843}
844
/// Numeric tag identifying a quantized layer kind.
///
/// The explicit discriminants are the values accepted by the
/// `TryFrom<usize>` impl, so they must stay in sync with it.
/// NOTE(review): presumably these tags are persisted in serialized layers
/// and must therefore never be renumbered — confirm against the serializers.
#[derive(Debug, Clone, Copy)]
pub enum QuantizedSerdeType {
    Gguf = 0,
    Unquant = 1,
    Hqq = 2,
    Fp8 = 3,
    Afq = 4,
    F8Q8 = 5,
    Mxfp4 = 6,
}
855
856impl TryFrom<usize> for QuantizedSerdeType {
857    type Error = candle_core::Error;
858    fn try_from(value: usize) -> std::result::Result<Self, Self::Error> {
859        match value {
860            0 => Ok(Self::Gguf),
861            1 => Ok(Self::Unquant),
862            2 => Ok(Self::Hqq),
863            3 => Ok(Self::Fp8),
864            4 => Ok(Self::Afq),
865            5 => Ok(Self::F8Q8),
866            6 => Ok(Self::Mxfp4),
867            other => candle_core::bail!("QuantizedSerdeType {other} is invalid."),
868        }
869    }
870}
871
/// Optional byte-level (de)serialization support for quantized layers.
///
/// Only `name` is required. Every other method has a default that reports
/// the operation as unsupported, so implementors override just what they
/// support and advertise it through `isq_serde_supported`.
pub trait QuantizedSerde {
    /// Name of the implementing layer kind.
    fn name(&self) -> &'static str;
    /// Whether this implementation supports serialization; `false` by default.
    fn isq_serde_supported(&self) -> bool {
        false
    }
    /// Serialize the layer to bytes. The default returns an error.
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize` is not supported.")
    }
    /// Rebuild a layer from serialized bytes. The default returns an error.
    fn deserialize(
        _data: Cow<[u8]>,
        _device: &Device,
        _comm: &Arc<crate::Comm>,
        _guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize` is not supported.")
    }
    /// Rebuild a layer from serialized bytes, returning the bias separately
    /// from the layer. The default returns an error.
    fn deserialize_ext_bias(
        _data: Cow<[u8]>,
        _device: &Device,
        _guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        candle_core::bail!("`QuantizedSerde::deserialize_ext_bias` is not supported.")
    }
    /// NOT meant for external calling
    ///
    /// Serialize the layer using the given bias. The default returns an error.
    fn serialize_with_bias(&self, _bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        candle_core::bail!("`QuantizedSerde::serialize_with_bias` is not supported.")
    }
}
906
/// Shared lock used to gate access to quantizing onto the host device.
///
/// Cloning is cheap: every clone points at the same underlying mutex, so all
/// clones contend for one lock.
// NOTE(review): the `allow(unused)` presumably covers builds where the mutex
// is never locked (the `cuda` path of `acquire` ignores it) — confirm.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct QuantizeOntoGuard {
    /// The mutex shared by every clone of this guard.
    pub inner: Arc<Mutex<()>>,
}
913
/// Token returned by `QuantizeOntoGuard::acquire`; keep it alive for the
/// duration of the critical section.
///
/// Builds without the `cuda` feature (Metal included) receive `Real`; builds
/// with the `cuda` feature receive `Fake`.
pub enum QuantizeOntoDropGuard<'a> {
    /// Holds the mutex guard; the lock is released when this value is dropped.
    Real(MutexGuard<'a, ()>),
    /// Holds nothing and releases nothing.
    Fake,
}
919
920impl Default for QuantizeOntoGuard {
921    fn default() -> Self {
922        Self::new()
923    }
924}
925
926impl QuantizeOntoGuard {
927    pub fn new() -> Self {
928        QuantizeOntoGuard {
929            inner: Arc::new(Mutex::new(())),
930        }
931    }
932
933    /// Acquire the quantize drop guard to protect the critical section.
934    ///
935    /// On metal, this waits for outstanding work to finish to avoid "A command encoder is already encoding to this command buffer"
936    pub fn acquire(&self, device: &Device) -> QuantizeOntoDropGuard<'_> {
937        #[cfg(feature = "cuda")]
938        {
939            let _ = device;
940            QuantizeOntoDropGuard::Fake
941        }
942
943        #[cfg(not(feature = "cuda"))]
944        {
945            #[cfg(feature = "metal")]
946            if let Device::Metal(dev) = device {
947                // This is necessary to avoid the errors of "A command encoder is already encoding to this command buffer"
948                dev.wait_until_completed()
949                    .expect("Failed to flush command buffer.");
950            }
951            #[cfg(not(feature = "metal"))]
952            let _ = device;
953
954            QuantizeOntoDropGuard::Real(self.inner.lock().expect("QuantizeOntoGuard was poisoned!"))
955        }
956    }
957}
958
/// How a layer is laid out across devices, as reported by
/// `QuantMethod::is_distributed`.
///
/// The variant names mirror the `ColumnParallelLayer`, `RowParallelLayer` and
/// `ReplicatedLayer` types re-exported from `distributed::layers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistributedKind {
    /// Counterpart of `ColumnParallelLayer`.
    ColumnParallel,
    /// Counterpart of `RowParallelLayer`.
    RowParallel,
    /// Counterpart of `ReplicatedLayer`.
    Replicated,
}
964
965/// Quantized method for a quantized matmul.
966pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
967    fn new(method: QuantMethodConfig) -> Result<Self>
968    where
969        Self: Sized;
970
971    fn dequantize_w(&self) -> Result<Tensor>;
972
973    /// Compute matmul of `self` and `a`. `self` should contain the weights.
974    /// Automatically cast to required quantization activation type and back
975    fn forward_autocast(&self, a: &Tensor) -> Result<Tensor> {
976        let original_ty = a.dtype();
977        let a = if let Some(t) = self.quantized_act_type() {
978            a.to_dtype(t)?
979        } else {
980            a.clone()
981        };
982        self.forward(&a)?.to_dtype(original_ty)
983    }
984
985    /// Compute matmul of `self` and `a`. `self` should contain the weights.
986    fn forward(&self, a: &Tensor) -> Result<Tensor>;
987
988    /// Compute matmul of `self` and `a`. `self` should contain the weights.
989    /// Automatically cast to required quantization activation type and back.
990    ///
991    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
992    /// then the indices are (n_tokens, n_experts).
993    fn gather_forward_autocast(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
994        let original_ty = a.dtype();
995        let a = if let Some(t) = self.quantized_act_type() {
996            a.to_dtype(t)?
997        } else {
998            a.clone()
999        };
1000        self.gather_forward(&a, indices)?.to_dtype(original_ty)
1001    }
1002
1003    /// Compute matmul of `self` and `a`. `self` should contain the weights.
1004    ///
1005    /// If `a` is (n_tokens, n_experts, cols), `self` weights are (n_experts, rows, cols),
1006    /// then the indices are (n_tokens, n_experts).
1007    fn gather_forward(&self, _a: &Tensor, _indices: &Tensor) -> Result<Tensor> {
1008        candle_core::bail!(
1009            "{} does not support `gather_forward`. Please raise an issue.",
1010            self.name()
1011        )
1012    }
1013
1014    /// If a quantized method, return the activation dtype.
1015    fn quantized_act_type(&self) -> Option<DType>;
1016
1017    /// Weight dtype and device
1018    fn dtype_and_device(&self) -> (DType, Device);
1019
1020    /// Add a delta weight from LoRA to the weights. This should be prescaled with alpha.
1021    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>>;
1022
1023    /// If the quant is backed by a qmatmul.
1024    fn apply_isq(
1025        self: Arc<Self>,
1026        dtype: Option<IsqType>,
1027        device: Device,
1028        n_quantized: &AtomicUsize,
1029        imatrix_weight: Option<Vec<f32>>,
1030        guard: QuantizeOntoGuard,
1031    ) -> Result<Arc<dyn QuantMethod>>;
1032
1033    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
1034        None
1035    }
1036
1037    /// Begin tracking stats into an ImatrixLayerStats
1038    fn begin_track_stats(&mut self) -> Result<()> {
1039        candle_core::bail!("`{}` does not support tracking stats.", self.name())
1040    }
1041
1042    /// End tracking stats into an ImatrixLayerStats. Returns the computed imatrix.
1043    fn end_track_stats(&self) -> Result<Tensor> {
1044        candle_core::bail!("`{}` does not support tracking stats.", self.name())
1045    }
1046
1047    fn is_distributed(&self) -> Option<DistributedKind> {
1048        None
1049    }
1050}
1051
1052impl Module for dyn QuantMethod {
1053    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
1054        Self::forward(self, xs)
1055    }
1056}
1057
1058pub fn linear_no_bias(
1059    in_dim: usize,
1060    out_dim: usize,
1061    config: &Option<QuantizedConfig>,
1062    vb: ShardedVarBuilder,
1063) -> Result<Arc<dyn QuantMethod>> {
1064    let base_vb = vb.clone();
1065    let vb = if should_apply_immediate_isq(&vb) {
1066        vb.set_device(Device::Cpu)
1067    } else {
1068        vb
1069    };
1070
1071    let layer = if let Some(quant_conf) = &config {
1072        match quant_conf {
1073            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1074            QuantizedConfig::Fp8 { weight_block_size } => {
1075                if weight_block_size.is_some() {
1076                    blockwise_fp8_linear_b(
1077                        in_dim,
1078                        out_dim,
1079                        quant_conf,
1080                        false,
1081                        Default::default(),
1082                        vb,
1083                    )?
1084                } else {
1085                    pertensor_fp8_linear_b(
1086                        in_dim,
1087                        out_dim,
1088                        quant_conf,
1089                        false,
1090                        Default::default(),
1091                        vb,
1092                    )?
1093                }
1094            }
1095            QuantizedConfig::Bitsandbytes { .. } => {
1096                Arc::new(BnbLinear::linear_b(in_dim, out_dim, false, vb)?) as Arc<_>
1097            }
1098            QuantizedConfig::Afq { .. } => {
1099                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, false, vb)?
1100            }
1101            QuantizedConfig::MXFP4 {} => {
1102                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, false, vb)?
1103            }
1104        }
1105    } else {
1106        // Handle the case where the layer is dummy (no tensors)
1107        if !vb.contains_tensor("weight") {
1108            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
1109            Arc::new(layer) as Arc<dyn QuantMethod>
1110        } else {
1111            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1112            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1113
1114            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1115                Linear::new(weight, None),
1116            ))?;
1117            Arc::new(layer) as Arc<dyn QuantMethod>
1118        }
1119    };
1120    apply_immediate_isq(layer, base_vb)
1121}
1122
1123pub fn linear(
1124    in_dim: usize,
1125    out_dim: usize,
1126    config: &Option<QuantizedConfig>,
1127    vb: ShardedVarBuilder,
1128) -> Result<Arc<dyn QuantMethod>> {
1129    let base_vb = vb.clone();
1130    let vb = if should_apply_immediate_isq(&vb) {
1131        vb.set_device(Device::Cpu)
1132    } else {
1133        vb
1134    };
1135
1136    let layer = if let Some(quant_conf) = &config {
1137        match quant_conf {
1138            QuantizedConfig::GptqAwq { .. } => gptq_linear(in_dim, out_dim, quant_conf, vb)?,
1139            QuantizedConfig::Fp8 { weight_block_size } => {
1140                if weight_block_size.is_some() {
1141                    blockwise_fp8_linear_b(
1142                        in_dim,
1143                        out_dim,
1144                        quant_conf,
1145                        true,
1146                        Default::default(),
1147                        vb,
1148                    )?
1149                } else {
1150                    pertensor_fp8_linear_b(
1151                        in_dim,
1152                        out_dim,
1153                        quant_conf,
1154                        true,
1155                        Default::default(),
1156                        vb,
1157                    )?
1158                }
1159            }
1160            QuantizedConfig::Bitsandbytes { .. } => {
1161                Arc::new(BnbLinear::linear_b(in_dim, out_dim, true, vb)?) as Arc<_>
1162            }
1163            QuantizedConfig::Afq { .. } => {
1164                AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, true, vb)?
1165            }
1166            QuantizedConfig::MXFP4 {} => {
1167                MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, true, vb)?
1168            }
1169        }
1170    } else {
1171        // Handle the case where the layer is dummy (no tensors)
1172        if !(vb.contains_tensor("weight") && vb.contains_tensor("bias")) {
1173            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
1174            Arc::new(layer) as Arc<dyn QuantMethod>
1175        } else {
1176            let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
1177            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
1178            let bias = vb.get_with_hints((out_dim,), "bias", Default::default())?;
1179
1180            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
1181                Linear::new(weight, Some(bias)),
1182            ))?;
1183            Arc::new(layer) as Arc<dyn QuantMethod>
1184        }
1185    };
1186    apply_immediate_isq(layer, base_vb)
1187}
1188
1189pub fn linear_b(
1190    in_dim: usize,
1191    out_dim: usize,
1192    bias: bool,
1193    config: &Option<QuantizedConfig>,
1194    vb: ShardedVarBuilder,
1195) -> Result<Arc<dyn QuantMethod>> {
1196    if bias {
1197        linear(in_dim, out_dim, config, vb)
1198    } else {
1199        linear_no_bias(in_dim, out_dim, config, vb)
1200    }
1201}