hanzo_quant/pertensor_fp8/
mod.rs1use std::{
2 borrow::Cow,
3 sync::{atomic::AtomicUsize, Arc},
4};
5
6use hanzo_ml::{quantized::GgmlDType, DType, Device, Result, Tensor};
7use hanzo_nn::Linear;
8
9mod ops;
10
11use crate::{
12 generate_isq, generate_isq_imatrix, has_missing_required_tensors,
13 hqq::{ISQ_HQQ_DEFAULT_OPT_STEPS, ISQ_HQQ_GROUP_SIZE},
14 make_dummy_or_error,
15 utils::{serialize_tensor, UQFF_VERSION},
16 AfqBits, AfqGroupSize, AfqLayer, FP8Linear, GgufMatMul, HqqAxis, HqqBits, HqqConfig, HqqLayer,
17 IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
18 QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
19};
20
21#[derive(Debug)]
29pub struct PerTensorFP8Linear {
30 weight: Tensor,
31 #[allow(dead_code)]
32 weight_scale_inv: Tensor,
33 #[allow(dead_code)]
34 activation_scale: Option<Tensor>,
35 bias: Option<Tensor>,
36 #[allow(dead_code)]
37 dequant_dtype: DType,
38}
39
40impl QuantMethod for PerTensorFP8Linear {
41 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
42 where
43 Self: Sized,
44 {
45 match method {
46 QuantMethodConfig::PerTensorFP8 {
47 weight,
48 weight_scale_inv,
49 activation_scale,
50 bias,
51 dequant_dtype,
52 } => {
53 let dequant_weight =
55 ops::fp8_pertensor_dequantize(&weight, &weight_scale_inv, dequant_dtype)?;
56 Ok(Self {
57 weight: dequant_weight,
58 weight_scale_inv,
59 activation_scale,
60 bias,
61 dequant_dtype,
62 })
63 }
64 _ => unreachable!(),
65 }
66 }
67
68 fn dequantize_w(&self) -> Result<Tensor> {
69 Ok(self.weight.clone())
71 }
72
73 fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
74 let unquant = UnquantLinear::new(QuantMethodConfig::Unquantized(Linear::new(
76 self.weight.clone(),
77 self.bias.clone(),
78 )))?;
79 unquant.forward(x)
80 }
81
82 fn quantized_act_type(&self) -> Option<DType> {
83 None
84 }
85
86 fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
87 hanzo_ml::bail!("PerTensorFP8Linear does not support add_delta_w")
88 }
89
90 fn dtype_and_device(&self) -> (DType, Device) {
91 (DType::F8E4M3, self.weight.device().clone())
92 }
93
94 fn apply_isq(
95 self: Arc<Self>,
96 dtype: Option<IsqType>,
97 device: Device,
98 n_quantized: &AtomicUsize,
99 imatrix_weight: Option<Vec<f32>>,
100 guard: QuantizeOntoGuard,
101 ) -> Result<Arc<dyn QuantMethod>> {
102 let weight = self.dequantize_w()?;
103 match dtype {
104 Some(IsqType::HQQ4 | IsqType::HQQ8) => {
105 let _acquired_quantize_guard = guard.acquire(&device);
106 if imatrix_weight.is_some() {
107 hanzo_ml::bail!("HQQ does not support imatrix.");
108 }
109
110 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
111 let bits = match dtype.unwrap() {
112 IsqType::HQQ8 => HqqBits::Eight,
113 IsqType::HQQ4 => HqqBits::Four,
114 _ => unreachable!(),
115 };
116 let cfg = HqqConfig {
117 bits,
118 group_size: ISQ_HQQ_GROUP_SIZE.try_into()?,
119 axis: HqqAxis::Zero,
120 optimization_steps: ISQ_HQQ_DEFAULT_OPT_STEPS,
121 round_zeros: false,
122 channel_wise: true,
123 };
124 let res = HqqLayer::quantize(&weight.to_device(&device)?, &device, cfg)?;
125 if let Some(bias) = &self.bias {
126 let bias = bias
127 .to_device(&device)?
128 .to_dtype(res.dtype_and_device().0)?;
129 Ok(Arc::new(res.with_bias(bias)))
130 } else {
131 Ok(Arc::new(res))
132 }
133 }
134 Some(IsqType::AFQ2 | IsqType::AFQ3 | IsqType::AFQ4 | IsqType::AFQ6 | IsqType::AFQ8) => {
135 let _acquired_quantize_guard = guard.acquire(&device);
136 if imatrix_weight.is_some() {
137 hanzo_ml::bail!("AFQ does not support imatrix.");
138 }
139
140 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
141 let bits = match dtype.unwrap() {
142 IsqType::AFQ8 => AfqBits::Eight,
143 IsqType::AFQ6 => AfqBits::Six,
144 IsqType::AFQ4 => AfqBits::Four,
145 IsqType::AFQ3 => AfqBits::Three,
146 IsqType::AFQ2 => AfqBits::Two,
147 _ => unreachable!(),
148 };
149
150 Ok(Arc::new(AfqLayer::new(QuantMethodConfig::Afq {
151 weight: weight.to_device(&device)?,
152 bias: self.bias.as_ref().map(|b| b.to_device(&device).unwrap()),
153 bits,
154 group_size: AfqGroupSize::default(),
155 })?))
156 }
157 Some(
158 IsqType::Q2K
159 | IsqType::Q3K
160 | IsqType::Q4K
161 | IsqType::Q4_0
162 | IsqType::Q4_1
163 | IsqType::Q5K
164 | IsqType::Q5_0
165 | IsqType::Q5_1
166 | IsqType::Q6K
167 | IsqType::Q8K
168 | IsqType::Q8_0
169 | IsqType::Q8_1,
170 ) => {
171 let dtype: GgmlDType = dtype.unwrap().try_into()?;
172 let res = if let Some(imatrix_weight) = imatrix_weight {
173 generate_isq_imatrix!(weight, imatrix_weight, device, dtype, n_quantized, guard)
174 } else {
175 generate_isq!(weight, device, dtype, n_quantized, guard)
176 };
177 Ok(Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
178 q_weight: res,
179 b: self
180 .bias
181 .as_ref()
182 .map(|b| b.to_dtype(DType::F32).unwrap().to_device(&device).unwrap()),
183 })?))
184 }
185 Some(IsqType::F8E4M3) => {
186 let _acquired_quantize_guard = guard.acquire(&device);
187 if imatrix_weight.is_some() {
188 hanzo_ml::bail!("F8E4M3 does not support imatrix.");
189 }
190
191 let w = weight.to_device(&device)?;
192 let b = if let Some(b) = &self.bias {
193 Some(b.to_device(&device)?)
194 } else {
195 None
196 };
197 Ok(Arc::new(FP8Linear::new(QuantMethodConfig::FP8 {
198 lin: Linear::new(w, b),
199 dtype: DType::F8E4M3,
200 })?))
201 }
202 Some(IsqType::F8Q8) => {
203 let _acquired_quantize_guard = guard.acquire(&device);
204 if imatrix_weight.is_some() {
205 hanzo_ml::bail!("F8Q8 does not support imatrix.");
206 }
207
208 let w = weight.to_device(&device)?;
209 let b = if let Some(b) = &self.bias {
210 Some(b.to_device(&device)?)
211 } else {
212 None
213 };
214 Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
215 }
216 Some(IsqType::MXFP4) => {
217 let _acquired_quantize_guard = guard.acquire(&device);
218 if imatrix_weight.is_some() {
219 hanzo_ml::bail!("MXFP4 does not support imatrix.");
220 }
221
222 n_quantized.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
223 let w = weight.to_device(&device)?;
224 let b = self
225 .bias
226 .as_ref()
227 .map(|b| b.to_device(&device))
228 .transpose()?;
229 crate::MXFP4Layer::quantize(&w, b, &device)
230 }
231 None => {
232 let _acquired_quantize_guard = guard.acquire(&device);
233
234 let w = weight.to_device(&device)?;
235 let b = if let Some(b) = &self.bias {
236 Some(b.to_device(&device)?)
237 } else {
238 None
239 };
240 Ok(Arc::new(UnquantLinear::new(
241 QuantMethodConfig::Unquantized(Linear::new(w, b)),
242 )?))
243 }
244 }
245 }
246}
247
248impl QuantizedSerde for PerTensorFP8Linear {
263 fn isq_serde_supported(&self) -> bool {
264 true
265 }
266 fn name(&self) -> &'static str {
267 "pertensor-fp8-linear"
268 }
269 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
270 self.serialize_with_bias(self.bias.clone())
271 }
272 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
273 let mut buffer = Vec::new();
275
276 buffer.extend(&UQFF_VERSION.to_le_bytes());
278
279 buffer.push(QuantizedSerdeType::Unquant as u8);
281
282 buffer.push(bias.is_some() as u8);
284
285 serialize_tensor(&mut buffer, &self.weight)?;
287
288 if let Some(bias) = &bias {
289 serialize_tensor(&mut buffer, bias)?;
291 }
292
293 Ok(Cow::from(buffer))
294 }
295}
296
297pub fn pertensor_fp8_linear_b(
303 in_dim: usize,
304 out_dim: usize,
305 _config: &QuantizedConfig,
306 bias: bool,
307 _hints: Shard,
308 vb: ShardedVarBuilder,
309) -> Result<Arc<dyn QuantMethod>> {
310 if vb.contains_tensor("weight") && !vb.contains_tensor("weight_scale_inv") {
312 return crate::linear_b(in_dim, out_dim, bias, &None, vb);
313 }
314
315 if has_missing_required_tensors(&vb, &["weight", "weight_scale_inv"]) {
316 return make_dummy_or_error("pertensor_fp8_linear", &vb, &["weight", "weight_scale_inv"]);
317 }
318
319 let weight = vb.get_with_hints_dtype(
321 (out_dim, in_dim),
322 "weight",
323 Default::default(),
324 DType::F8E4M3,
325 )?;
326
327 let weight_scale_inv =
329 vb.get_with_hints_dtype((), "weight_scale_inv", Default::default(), DType::F32)?;
330
331 let activation_scale = if vb.contains_tensor("activation_scale") {
333 Some(vb.get_with_hints_dtype((), "activation_scale", Default::default(), DType::F32)?)
334 } else {
335 None
336 };
337
338 let bias = if bias && vb.contains_tensor("bias") {
339 Some(vb.get((out_dim,), "bias")?)
340 } else {
341 None
342 };
343
344 let dequant_dtype = bias.as_ref().map(|b| b.dtype()).unwrap_or(DType::BF16);
348
349 Ok(Arc::new(PerTensorFP8Linear::new(
351 QuantMethodConfig::PerTensorFP8 {
352 weight,
353 weight_scale_inv,
354 activation_scale,
355 bias,
356 dequant_dtype,
357 },
358 )?))
359}