// mistralrs_quant/unquantized/mod.rs

use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

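/// A linear layer stored in full precision: the `[out_features, in_features]`
/// weight, an optional bias, and optional activation statistics (`stats`) used
/// to build an imatrix for calibration-aware quantization.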
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        // Lazily initialize the cuBLASLt wrapper for this device, if available.
        maybe_init_cublas_lt_wrapper(a.device().clone());

        // Try custom GEMV for single-token decode (batch_size=1)
        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

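        // Broadcast the [out_features, in_features] weight across any leading
        // batch dims of `a` so the batched matmul below sees matching ranks.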
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
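            // Broadcast the bias to the output shape: the leading dims of `a`
            // with the last dim replaced by `out_features`.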
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    // Try to use cuBLASLt; otherwise fall back to a regular matmul.
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else {
            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        // cuBLAS batch_matmul requires 3D tensors, fall back to regular matmul for 2D.
                        if a.rank() >= 3 && w.rank() >= 3 {
                            cublaslt
                                .batch_matmul(a, &w, None, None, None, None, None)?
                                .t()
                        } else {
                            a.matmul(&w.t()?)
                        }
                    } else {
                        a.matmul(&w.t()?)
                    }
                }
                DeviceLocation::Metal { .. } => a.matmul(&w.t()?),
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        a.to_dtype(DType::F32)?
                            .matmul(&w.t()?.to_dtype(DType::F32)?)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        a.matmul(&w.t()?)
                    }
                }
            }
        }
    }

    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        // Weights are [num_experts, out_features, in_features]
        // For Metal path:
        //   - a: (b_size, seq_len, 1, 1, hidden_dim) - 5D
        //   - indices: (b_size, seq_len, num_experts_per_tok) - 3D
        // For CUDA path:
        //   - a: (num_tokens, 1, hidden_dim) - 3D
        //   - indices: (num_tokens, num_experts_per_tok) - 2D

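        // Worked example: with b_size=2, seq_len=3, 2 experts per token,
        // hidden_dim=8, and out_features=16, the Metal path maps
        // a (2, 3, 1, 1, 8) and indices (2, 3, 2) to an output of (2, 3, 2, 16).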
        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            // Metal path: 5D input (b_size, seq_len, 1, 1, hidden_dim)
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                // Flatten indices to select experts
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // Select expert weights: [b*s*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Reshape input: [b*s, hidden_dim]
                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // For each token, we need to compute with each selected expert
                // Broadcast a to match: [b*s, 1, hidden_dim] -> [b*s, k, hidden_dim]
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Matmul: [b*s*k, hidden_dim] @ [b*s*k, hidden_dim, out_features] -> [b*s*k, out_features]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape back to [b, s, k, out_features]
                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            // CUDA path: 3D input (num_tokens, 1, hidden_dim)
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                // Flatten indices
                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                // Select expert weights: [n*k, out_features, in_features]
                let selected_w = w.index_select(&flat_indices, 0)?;

                // Broadcast input: [n, 1, hidden] -> [n, k, hidden] -> [n*k, hidden]
                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                // Matmul: [n*k, hidden] @ [n*k, hidden, out] -> [n*k, out]
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                // Reshape to [n, k, out]
                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

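    /// In-situ quantization (ISQ): requantize this full-precision layer into
    /// the requested format. GGUF-style targets can use an optional imatrix;
    /// HQQ, AFQ, FP8, and MXFP4 bail if one is supplied, and `None` returns an
    /// unquantized copy on the target device.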
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            /*Some(IsqType::HQQ1 | IsqType::HQQ2 | IsqType::HQQ3 | */
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    // IsqType::HQQ3 => HqqBits::Three,
                    // IsqType::HQQ2 => HqqBits::Two,
                    // IsqType::HQQ1 => HqqBits::One,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    // TODO just warn?
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            Some(IsqType::MXFP4) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("MXFP4 does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let w = self.w.to_device(&device)?;
                let b = self.b.as_ref().map(|b| b.to_device(&device)).transpose()?;
                crate::MXFP4Layer::quantize(&w, b, &device)
            }
            Some(IsqType::F8Q8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8Q8 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);
                // Ignore imatrix altogether

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

// Serialization structure:
//
// -----------------------
// UQFF version, u32, little endian
// -----------------------
// ISQ type (1 for unquantized), u8
// -----------------------
// Whether bias data is included, u8 boolean
// -----------------------
// Weight tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
// [OPTIONAL] Bias tensor data generated by `serialize_tensor`. Refer to its docs for layout.
// -----------------------
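//
// Illustrative header for a biased layer (version bytes shown as placeholder `vv`):
//   vv vv vv vv | 01 | 01 | <weight tensor bytes> | <bias tensor bytes>
//   (little-endian u32 UQFF version, unquant type tag, has-bias flag)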
impl QuantizedSerde for UnquantLinear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version is always first!
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // ISQ type for unquant is 1
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Has bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}
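
// Minimal usage sketch (illustrative only; assumes a CPU device and small shapes,
// and is not part of the crate's public API).
#[allow(dead_code)]
fn unquant_linear_example() -> Result<()> {
    // A 4x8 weight, i.e. out_features = 4, in_features = 8, no bias.
    let w = Tensor::zeros((4, 8), DType::F32, &Device::Cpu)?;
    let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, None)))?;
    // A (batch, seq, hidden) input; `forward` broadcasts the weight across the batch dim.
    let x = Tensor::zeros((1, 3, 8), DType::F32, &Device::Cpu)?;
    let y = layer.forward(&x)?;
    assert_eq!(y.dims(), &[1, 3, 4]); // batched matmul against the 4x8 weight
    Ok(())
}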