mistralrs_quant/pertensor_fp8/mod.rs

use std::{
    borrow::Cow,
    sync::{atomic::AtomicUsize, Arc},
};

use candle_core::{quantized::GgmlDType, DType, Device, Result, Tensor};
use candle_nn::Linear;

mod ops;

use crate::{
    generate_isq, generate_isq_imatrix,
    hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
    utils::{serialize_tensor, UQFF_VERSION},
    AfqBits, AfqGroupSize, AfqLayer, DummyLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits,
    HqqConfig, HqqLayer, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard,
    QuantizedConfig, QuantizedSerde, QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};
/// Linear layer loaded from a per-tensor FP8 checkpoint. The F8E4M3 weight is
/// dequantized once at load time and held in `dequant_dtype`, so the forward
/// pass runs as an ordinary (unquantized) matmul.
#[derive(Debug)]
pub struct PerTensorFP8Linear {
    weight: Tensor,
    #[allow(dead_code)]
    weight_scale_inv: Tensor,
    #[allow(dead_code)]
    activation_scale: Option<Tensor>,
    bias: Option<Tensor>,
    #[allow(dead_code)]
    dequant_dtype: DType,
}

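// Dequantization sketch. This is an assumption about what
// `ops::fp8_pertensor_dequantize` (defined in `ops.rs`, not shown here)
// computes, based on the usual per-tensor FP8 scheme: a single scalar scale
// covers the whole tensor, so roughly
//
//     w_dequant = w_fp8.to_dtype(dequant_dtype) * weight_scale_inv
//
// in contrast to blockwise FP8, where a grid of per-block scales is applied.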
impl QuantMethod for PerTensorFP8Linear {
    fn new(method: QuantMethodConfig) -> candle_core::Result<Self>
    where
        Self: Sized,
    {
        match method {
            QuantMethodConfig::PerTensorFP8 {
                weight,
                weight_scale_inv,
                activation_scale,
                bias,
                dequant_dtype,
            } => {
                // Dequantize eagerly; the scale tensors are retained but
                // unused afterwards (hence the `#[allow(dead_code)]` fields).
                let dequant_weight =
                    ops::fp8_pertensor_dequantize(&weight, &weight_scale_inv, dequant_dtype)?;
                Ok(Self {
                    weight: dequant_weight,
                    weight_scale_inv,
                    activation_scale,
                    bias,
                    dequant_dtype,
                })
            }
            _ => unreachable!(),
        }
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        // The weight was already dequantized in `new`, so a clone suffices.
        Ok(self.weight.clone())
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Delegate to an unquantized linear layer over the dequantized weight.
        let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
            self.weight.clone(),
            self.bias.clone(),
        )))?;
        unquant.forward(x)
    }
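    // Shape sketch (assuming standard linear semantics from `candle_nn`): for
    // `x` of shape (batch, in_dim) and `weight` of shape (out_dim, in_dim),
    // `forward` computes x.matmul(weight.t()) + bias → (batch, out_dim).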

    fn quantized_act_type(&self) -> Option<DType> {
        None
    }

    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        candle_core::bail!("PerTensorFP8Linear does not support add_delta_w")
    }

    fn dtype_and_device(&self) -> (DType, Device) {
        // Reports the on-disk FP8 dtype even though the weight is held
        // dequantized in memory.
        (DType::F8E4M3, self.weight.device().clone())
    }

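    // ISQ dispatch summary (read off the match below): HQQ4/HQQ8 → HqqLayer,
    // AFQ2–AFQ8 → AfqLayer, the GGUF Q*-types → GgufMatMul, F8E4M3 →
    // FP8Linear, F8Q8 → F8Q8Linear, MXFP4 → MXFP4Layer, and None →
    // UnquantLinear. Every branch starts from the already-dequantized weight.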
    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<IsqType>,
        device: Device,
        n_quantized: &AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.dequantize_w()?;
        match dtype {
            Some(IsqType::HQQ4 | IsqType::HQQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("HQQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::HQQ8 => HqqBits::Eight,
                    IsqType::HQQ4 => HqqBits::Four,
                    _ => unreachable!(),
                };
                let cfg = HqqConfig {
                    bits,
                    group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
                    axis: HqqAxis::Zero,
                    optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
                    round_zeros: false,
                    channel_wise: true,
                };
                let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
                if let Some(bias) = &self.bias {
                    let bias = bias
                        .to_device(&device)?
                        .to_dtype(res.dtype_and_device().0)?;
                    Ok(Arc::new(res.with_bias(bias)))
                } else {
                    Ok(Arc::new(res))
                }
            }
            Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("AFQ does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let bits = match dtype.unwrap() {
                    IsqType::AFQ8 => AfqBits::Eight,
                    IsqType::AFQ6 => AfqBits::Six,
                    IsqType::AFQ4 => AfqBits::Four,
                    IsqType::AFQ3 => AfqBits::Three,
                    IsqType::AFQ2 => AfqBits::Two,
                    _ => unreachable!(),
                };

                Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
                    weight: weight.to_device(&device)?,
                    // Propagate device-transfer errors rather than unwrapping.
                    bias: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_device(&device))
                        .transpose()?,
                    bits,
                    group_size: AfqGroupSize::default(),
                })?))
            }
            Some(
                IsqType::Q2K
                | IsqType::Q3K
                | IsqType::Q4K
                | IsqType::Q4_0
                | IsqType::Q4_1
                | IsqType::Q5K
                | IsqType::Q5_0
                | IsqType::Q5_1
                | IsqType::Q6K
                | IsqType::Q8K
                | IsqType::Q8_0
                | IsqType::Q8_1,
            ) => {
                let dtype: GgmlDType = dtype.unwrap().try_into()?;
                let res = if let Some(imatrix_weight) = imatrix_weight {
                    generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
                } else {
                    generate_isq!(weight, device, dtype, n_quantized, guard)
                };
                Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: res,
                    // Convert the bias to F32 for the GGUF path, propagating
                    // errors rather than unwrapping.
                    b: self
                        .bias
                        .as_ref()
                        .map(|b| b.to_dtype(DType::F32)?.to_device(&device))
                        .transpose()?,
                })?))
            }
            Some(IsqType::F8E4M3) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8E4M3 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
                    lin: Linear::new(w, b),
                    dtype: DType::F8E4M3,
                })?))
            }
            Some(IsqType::F8Q8) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("F8Q8 does not support imatrix.");
                }

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
            }
            Some(IsqType::MXFP4) => {
                let _acquired_quantize_guard = guard.acquire(&device);
                if imatrix_weight.is_some() {
                    candle_core::bail!("MXFP4 does not support imatrix.");
                }

                n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                let w = weight.to_device(&device)?;
                let b = self
                    .bias
                    .as_ref()
                    .map(|b| b.to_device(&device))
                    .transpose()?;
                crate::MXFP4Layer::quantize(&w, b, &device)
            }
            None => {
                let _acquired_quantize_guard = guard.acquire(&device);

                let w = weight.to_device(&device)?;
                let b = if let Some(b) = &self.bias {
                    Some(b.to_device(&device)?)
                } else {
                    None
                };
                Ok(Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(w, b)),
                )?))
            }
        }
    }
}

impl QuantizedSerde for PerTensorFP8Linear {
    fn isq_serde_supported(&self) -> bool {
        true
    }
    fn name(&self) -> &'static str {
        "pertensor-fp8-linear"
    }
    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
        self.serialize_with_bias(self.bias.clone())
    }
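    // UQFF byte layout produced below: `UQFF_VERSION` (little-endian), a
    // `QuantizedSerdeType` tag byte, a has-bias flag byte, the weight tensor,
    // then the bias tensor if present (tensor encoding per `serialize_tensor`).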
    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
        let mut buffer = Vec::new();

        buffer.extend(&UQFF_VERSION.to_le_bytes());

        // The weight is held dequantized, so it round-trips through the
        // unquantized serde format.
        buffer.push(QuantizedSerdeType::Unquant as u8);

        buffer.push(bias.is_some() as u8);

        serialize_tensor(&mut buffer, &self.weight)?;

        if let Some(bias) = &bias {
            serialize_tensor(&mut buffer, bias)?;
        }

        Ok(Cow::from(buffer))
    }
}

/// Build a per-tensor FP8 linear layer from `vb`, falling back to a standard
/// linear layer when no FP8 scale is present and to a dummy layer when the
/// weight itself is missing.
pub fn pertensor_fp8_linear_b(
    in_dim: usize,
    out_dim: usize,
    _config: &QuantizedConfig,
    bias: bool,
    _hints: Shard,
    vb: ShardedVarBuilder,
) -> Result<Arc<dyn QuantMethod>> {
    // A weight without `weight_scale_inv` is not FP8-quantized; load it as a
    // plain linear layer.
    if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
        return crate::linear_b(in_dim, out_dim, bias, &None, vb);
    }

    if !vb.contains_tensor("weight") {
        let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
        return Ok(Arc::new(layer) as Arc<dyn QuantMethod>);
    }

    let weight = vb.get_with_hints_dtype(
        (out_dim, in_dim),
        "weight",
        Default::default(),
        DType::F8E4M3,
    )?;

    // Per-tensor scheme: the scales are scalars (shape `()`).
    let weight_scale_inv =
        vb.get_with_hints_dtype((), "weight_scale_inv", Default::default(), DType::F32)?;

    let activation_scale = if vb.contains_tensor("activation_scale") {
        Some(vb.get_with_hints_dtype((), "activation_scale", Default::default(), DType::F32)?)
    } else {
        None
    };

    let bias = if bias && vb.contains_tensor("bias") {
        Some(vb.get((out_dim,), "bias")?)
    } else {
        None
    };

    // Infer the dequantization dtype from the bias when available; otherwise
    // default to BF16.
    let dequant_dtype = bias.as_ref().map(|b| b.dtype()).unwrap_or(DType::BF16);

    Ok(Arc::new(PerTensorFP8Linear::new(
        QuantMethodConfig::PerTensorFP8 {
            weight,
            weight_scale_inv,
            activation_scale,
            bias,
            dequant_dtype,
        },
    )?))
}
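
// A minimal usage sketch, not part of the original module: it constructs a
// `PerTensorFP8Linear` directly from the `QuantMethodConfig::PerTensorFP8`
// variant handled above. It assumes the candle fork in use supports casting
// to F8E4M3 via `to_dtype`, and that a unit `weight_scale_inv` makes
// dequantization the identity; treat it as illustrative rather than a
// maintained test.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn forward_shape() -> Result<()> {
        let dev = Device::Cpu;
        // Hypothetical dimensions: in_dim = 8, out_dim = 4.
        let weight = Tensor::zeros((4, 8), DType::F32, &dev)?.to_dtype(DType::F8E4M3)?;
        let weight_scale_inv = Tensor::new(1.0f32, &dev)?;
        let layer = PerTensorFP8Linear::new(QuantMethodConfig::PerTensorFP8 {
            weight,
            weight_scale_inv,
            activation_scale: None,
            bias: None,
            dequant_dtype: DType::BF16,
        })?;
        let x = Tensor::zeros((1, 8), DType::BF16, &dev)?;
        assert_eq!(layer.forward(&x)?.dims(), &[1, 4]);
        Ok(())
    }
}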