use std::{
    borrow::Cow,
    io::Cursor,
    sync::{atomic::AtomicUsize, Arc},
};

use byteorder::{LittleEndian, ReadBytesExt};
use candle_core::{quantized::GgmlDType, DType, Device, DeviceLocation, Result, Shape, Tensor, D};
use candle_nn::Linear;

use crate::{
    cublaslt::{maybe_init_cublas_lt_wrapper, CUBLASLT_CONTROLLER},
    generate_isq, generate_isq_imatrix,
    hqq::{HqqAxis, HqqBits, HqqConfig, HqqLayer, ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, ImatrixLayerStats, IsqType,
    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedSerde, QuantizedSerdeType,
};

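/// Unquantized linear layer: a plain weight matrix with an optional bias, plus
/// optional imatrix statistics collected during calibration.
///
/// Illustrative sketch only (`w`, `b`, and `x` are assumed tensors, not part of this file):
///
/// ```ignore
/// let layer = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(w, Some(b))))?;
/// let y = layer.forward(&x)?;
/// ```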
#[derive(Debug)]
pub struct UnquantLinear {
    w: Tensor,
    b: Option<Tensor>,
    stats: Option<ImatrixLayerStats>,
}

impl QuantMethod for UnquantLinear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::Gguf { .. }
            | QuantMethodConfig::GptqAwq { .. }
            | QuantMethodConfig::Hqq { .. }
            | QuantMethodConfig::Dummy
            | QuantMethodConfig::FP8 { .. }
            | QuantMethodConfig::Bnb { .. }
            | QuantMethodConfig::BlockwiseFP8 { .. }
            | QuantMethodConfig::PerTensorFP8 { .. }
            | QuantMethodConfig::Afq { .. }
            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
            QuantMethodConfig::Unquantized(l) => Ok(Self {
                w: l.weight().clone(),
                b: l.bias().cloned(),
                stats: None,
            }),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        Ok(self.w.clone())
    }

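    /// Multiply `a` by the (transposed) weight, adding the bias if present.
    ///
    /// Dispatch, as implemented below: an optional gemv fast path on CUDA (behind the
    /// `cuda` feature), cuBLASLt batched matmul when a controller is available for the
    /// device, and an f32 round-trip on CPU when the `accelerate` feature is enabled.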
    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        maybe_init_cublas_lt_wrapper(a.device().clone());

        #[cfg(feature = "cuda")]
        if crate::gemv::should_use_gemv(a, &self.w) {
            return crate::gemv::gemv(a, &self.w, self.b.as_ref());
        }

        // Broadcast the weight over any leading batch dimensions of the input.
        let w = match *a.dims() {
            [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?,
            [bsize, _, _] => self.w.broadcast_left(bsize)?,
            _ => self.w.clone(),
        };

        if let Some(stats) = &self.stats {
            stats.process(a)?;
        }

        if let Some(b) = self.b.as_ref() {
            // Broadcast the bias to the output shape of the matmul.
            let mut tgt_shape = a.dims().to_vec();
            tgt_shape[a.dims().len() - 1] = w.dim(D::Minus2)?;
            let b = b.broadcast_as(Shape::from_dims(&tgt_shape))?;

            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        cublaslt
                            .batch_matmul(
                                a,
                                &w,
                                Some(&b.t()?.contiguous()?),
                                None,
                                Some(1.0),
                                None,
                                None,
                            )?
                            .t()
                    } else {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
                DeviceLocation::Metal { .. } => {
                    let matmul_result = a.matmul(&w.t()?)?;
                    matmul_result.broadcast_add(&b)
                }
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        let a_f32 = a.to_dtype(DType::F32)?;
                        let w_f32 = w.t()?.to_dtype(DType::F32)?;
                        let b_f32 = b.to_dtype(DType::F32)?;
                        let matmul_result = a_f32.matmul(&w_f32)?;
                        matmul_result
                            .broadcast_add(&b_f32)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        let matmul_result = a.matmul(&w.t()?)?;
                        matmul_result.broadcast_add(&b)
                    }
                }
            }
        } else {
            match a.device().location() {
                DeviceLocation::Cuda { .. } => {
                    if let (Device::Cuda(_), Some(cublaslt)) =
                        (a.device(), CUBLASLT_CONTROLLER.get_for_device(a.device()))
                    {
                        if a.rank() >= 3 && w.rank() >= 3 {
                            cublaslt
                                .batch_matmul(a, &w, None, None, None, None, None)?
                                .t()
                        } else {
                            a.matmul(&w.t()?)
                        }
                    } else {
                        a.matmul(&w.t()?)
                    }
                }
                DeviceLocation::Metal { .. } => a.matmul(&w.t()?),
                DeviceLocation::Cpu => {
                    #[cfg(feature = "accelerate")]
                    {
                        let original_dtype = a.dtype();
                        a.to_dtype(DType::F32)?
                            .matmul(&w.t()?.to_dtype(DType::F32)?)?
                            .to_dtype(original_dtype)
                    }
                    #[cfg(not(feature = "accelerate"))]
                    {
                        a.matmul(&w.t()?)
                    }
                }
            }
        }
    }

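    /// Expert-gathered matmul for MoE layers: `w` is indexed as
    /// `(num_experts, out_features, in_features)` and `indices` selects
    /// `num_experts_per_tok` experts per token; each token is multiplied by the
    /// weights of its selected experts.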
    fn gather_forward(&self, a: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let w = &self.w;
        let (_num_experts, out_features, _in_features) = w.dims3()?;

        match a.dims() {
            // (b_size, seq_len, 1, 1, hidden_dim) activations with (b, s, num_experts_per_tok) indices.
            &[b_size, seq_len, 1, 1, hidden_dim] => {
                let (_b, _s, num_experts_per_tok) = indices.dims3()?;
                let flat_indices = indices.reshape((b_size * seq_len * num_experts_per_tok,))?;

                // Select the expert weights for every (token, expert) pair.
                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_flat = a.reshape((b_size * seq_len, hidden_dim))?;

                // Repeat each token's activation once per selected expert.
                let a_expanded = a_flat
                    .unsqueeze(1)?
                    .broadcast_as((b_size * seq_len, num_experts_per_tok, hidden_dim))?
                    .reshape((b_size * seq_len * num_experts_per_tok, hidden_dim))?;

                // Batched vector-matrix product against each selected expert's weight.
                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((b_size, seq_len, num_experts_per_tok, out_features))
            }
            // (num_tokens, 1, hidden_dim) activations with (num_tokens, num_experts_per_tok) indices.
            &[num_tokens, 1, hidden_dim] => {
                let (_, num_experts_per_tok) = indices.dims2()?;

                let flat_indices = indices.reshape((num_tokens * num_experts_per_tok,))?;

                let selected_w = w.index_select(&flat_indices, 0)?;

                let a_expanded = a
                    .broadcast_as((num_tokens, num_experts_per_tok, hidden_dim))?
                    .reshape((num_tokens * num_experts_per_tok, hidden_dim))?;

                let result = a_expanded
                    .unsqueeze(1)?
                    .matmul(&selected_w.transpose(1, 2)?)?
                    .squeeze(1)?;

                result.reshape((num_tokens, num_experts_per_tok, out_features))
            }
            dims => {
                candle_core::bail!(
                    "UnquantLinear::gather_forward: unsupported input shape {:?}",
                    dims
                );
            }
        }
    }

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        Ok(Arc::new(Self {
            w: (&self.w + delta)?,
            b: self.b.clone(),
            stats: self.stats.clone(),
        }))
    }

    fn dtype_and_device(&self) -> (DType, candle_core::Device) {
        (self.w.dtype(), self.w.device().clone())
    }

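    /// In-situ quantization: convert this unquantized layer into the requested
    /// quantized representation (HQQ, AFQ, GGUF k-quants, FP8, MXFP4, F8Q8), moving the
    /// weight and bias to `device`. An imatrix is only honored for the GGUF types;
    /// `None` keeps the layer unquantized and just relocates it.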
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&self.w.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.b {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: self.w.to_device(&device)?,
                    bias: self.b.as_ref().map(|b| b.to_device(&device).unwrap()),
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(self.w, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(self.w, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    b: self
                        .b
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            Some(IsqType::MXFP4) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("MXFP4 does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let w = self.w.to_device(&device)?;
                let b = self.b.as_ref().map(|b| b.to_device(&device)).transpose()?;
                crate::MXFP4Layer::quantize(&w, b, &device)
            }
            Some(IsqType::F8Q8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8Q8 does not support imatrix.");
                }

                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
            }
            None => {
                // No ISQ target: keep the layer unquantized, just move it to the target device.
                let _acquired_quantize_guard = guard.acquire(&device);
                let w = self.w.to_device(&device)?;
                let b = if let Some(b) = &self.b {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        Some((self.w.clone(), self.b.clone()))
    }

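    /// Imatrix calibration: `begin_track_stats` attaches an `ImatrixLayerStats` so that
    /// `forward` accumulates activation statistics; `end_track_stats` computes the
    /// importance matrix and clears the accumulator.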
    fn begin_track_stats(&mut self) -> Result<()> {
        self.stats = Some(ImatrixLayerStats::new(&self.w, self.w.device())?);
        Ok(())
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        if let Some(stats) = &self.stats {
            let imatrix = stats.compute_imatrix()?;
            stats.clear()?;
            Ok(imatrix)
        } else {
            candle_core::bail!("`{}` does not support tracking stats.", self.name())
        }
    }
}

impl QuantizedSerde for UnquantLinear {
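    // UQFF wire layout written by `serialize_with_bias` and read back by the two
    // deserializers below (integers little endian):
    //   u32  UQFF version
    //   u8   quantization type tag (QuantizedSerdeType::Unquant)
    //   u8   bias-present flag
    //   ...  weight tensor, then the bias tensor if the flag is set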
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "unquant-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.b.clone())
    }
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        // Version
        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // Quantization type tag
        buffer.push(QuantizedSerdeType::Unquant as u8);

        // Whether there is a bias
        buffer.push(bias.is_some() as u8);

        // Weight
        serialize_tensor(&mut buffer, &self.w)?;

        if let Some(bias) = &bias {
            // Bias
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }

    fn deserialize(
        data: Cow<[u8]>,
        device: &Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        Ok(Arc::new(Self { w, b, stats: None }))
    }
    fn deserialize_ext_bias(
        data: Cow<[u8]>,
        device: &Device,
        guard: QuantizeOntoGuard,
    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
    where
        Self: Sized,
    {
        let mut buffer = Cursor::new(data);

        let version = buffer.read_u32::<LittleEndian>()?;
        if let Err(e) = version_is_compatible(version) {
            return Err(candle_core::Error::wrap(e));
        }

        let isq_type = buffer.read_u8()? as usize;
        if isq_type != QuantizedSerdeType::Unquant as usize {
            candle_core::bail!(
                "ISQ type ({isq_type}) doesn't match expected type {}",
                QuantizedSerdeType::Unquant as usize
            );
        }

        let has_bias = buffer.read_u8()? != 0;

        let _acquired_load_guard = guard.acquire(device);
        let w = deserialize_tensor(&mut buffer, device)?;

        let b = if has_bias {
            Some(deserialize_tensor(&mut buffer, device)?)
        } else {
            None
        };

        // Return the bias separately rather than storing it on the layer.
        Ok((
            Arc::new(Self {
                w,
                b: None,
                stats: None,
            }),
            b,
        ))
    }
}