1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor};
9
10use crate::{
11 utils::{
12 deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
13 UQFF_VERSION,
14 },
15 Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
16 QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
17};
18
19pub(crate) mod ops;
20
21#[cfg(feature = "cuda")]
22pub(crate) mod ffi;
23
/// Bit-width selector for AFQ quantization.
///
/// The discriminant of each variant is the numeric code accepted by the `TryFrom`
/// conversions and the single byte emitted when a layer is serialized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum AfqBits {
    /// Code `2`.
    Two = 2,
    /// Code `3`.
    Three = 3,
    /// Code `4`.
    Four = 4,
    /// Code `6`.
    Six = 6,
    /// Code `8`.
    Eight = 8,
    /// Code `40`; marks the MXFP4 format rather than a plain bit count.
    Mxfp4 = 40,
}
34
35impl TryFrom<usize> for AfqBits {
36 type Error = hanzo_ml::Error;
37 fn try_from(value: usize) -> Result<Self> {
38 match value {
39 2 => Ok(Self::Two),
40 3 => Ok(Self::Three),
41 4 => Ok(Self::Four),
42 6 => Ok(Self::Six),
43 8 => Ok(Self::Eight),
44 40 => Ok(Self::Mxfp4),
45 x => hanzo_ml::bail!("Invalid AFQ bits {x}."),
46 }
47 }
48}
49
50impl TryFrom<u8> for AfqBits {
51 type Error = hanzo_ml::Error;
52 fn try_from(value: u8) -> Result<Self> {
53 Self::try_from(value as usize)
54 }
55}
56
/// Group-size selector for AFQ quantization.
///
/// The discriminant of each variant is the numeric group size accepted by the `TryFrom`
/// conversions and the single byte emitted when a layer is serialized.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum AfqGroupSize {
    /// Group size `32`.
    Low = 32,
    /// Group size `64`; this is the default.
    #[default]
    Med = 64,
    /// Group size `128`.
    High = 128,
}
65
66impl TryFrom<usize> for AfqGroupSize {
67 type Error = hanzo_ml::Error;
68 fn try_from(value: usize) -> Result<Self> {
69 match value {
70 32 => Ok(Self::Low),
71 64 => Ok(Self::Med),
72 128 => Ok(Self::High),
73 x => hanzo_ml::bail!("Invalid AFQ group size {x}."),
74 }
75 }
76}
77
78impl TryFrom<u8> for AfqGroupSize {
79 type Error = hanzo_ml::Error;
80 fn try_from(value: u8) -> Result<Self> {
81 Self::try_from(value as usize)
82 }
83}
84
/// A layer whose weight is held in AFQ-quantized form.
///
/// Instances are created by quantizing a dense weight (`QuantMethod::new`), by loading
/// pre-quantized tensors (`afq_linear_b` / `afq_packed_linear_b`), or by deserialization.
#[derive(Debug)]
pub struct AfqLayer {
    /// Quantized weight; the loading constructors read it with dtype `U32`.
    w_q: Tensor,
    /// Quantization scales (first companion tensor returned by `ops::afq_quantize_op`).
    scales: Tensor,
    /// Quantization biases (second companion tensor returned by `ops::afq_quantize_op`).
    /// Not to be confused with the optional `bias` field.
    biases: Tensor,
    /// Optional bias tensor carried with the layer — presumably the linear layer's
    /// additive bias; TODO confirm where it is applied.
    bias: Option<Tensor>,
    /// Bit-width code used for `w_q`.
    bits: AfqBits,
    /// Group size used for `scales` / `biases`.
    group_size: AfqGroupSize,
}
94
/// Borrowed view of an [`AfqLayer`]'s tensors and parameters, as handed out by
/// `QuantMethod::afq_inner`.
#[derive(Clone)]
pub struct AfqInner<'a> {
    /// Quantized weight tensor.
    pub w_q: &'a Tensor,
    /// Quantization scales.
    pub scales: &'a Tensor,
    /// Quantization biases (distinct from `bias`).
    pub biases: &'a Tensor,
    /// The layer's optional bias tensor, if it has one.
    pub bias: Option<&'a Tensor>,
    /// Bit-width code of the quantized weight.
    pub bits: AfqBits,
    /// Group size of the quantization.
    pub group_size: AfqGroupSize,
}
105
106impl QuantMethod for AfqLayer {
107 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
108 where
109 Self: Sized,
110 {
111 match method {
112 QuantMethodConfig::Gguf { .. }
113 | QuantMethodConfig::GptqAwq { .. }
114 | QuantMethodConfig::Hqq { .. }
115 | QuantMethodConfig::Dummy
116 | QuantMethodConfig::FP8 { .. }
117 | QuantMethodConfig::Bnb { .. }
118 | QuantMethodConfig::BlockwiseFP8 { .. }
119 | QuantMethodConfig::PerTensorFP8 { .. }
120 | QuantMethodConfig::Unquantized(_)
121 | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
122 QuantMethodConfig::Afq {
123 weight,
124 bias,
125 bits,
126 group_size,
127 } => {
128 let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;
129
130 Ok(Self {
131 w_q,
132 scales,
133 biases,
134 bias,
135 bits,
136 group_size,
137 })
138 }
139 }
140 }
141
142 fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
143 ops::afq_dequantize_op(
144 &self.w_q,
145 &self.scales,
146 &self.biases,
147 self.group_size,
148 self.bits,
149 )
150 }
151
152 fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
153 ops::afq_mm_op(
154 x,
155 &self.w_q,
156 &self.scales,
157 &self.biases,
158 None,
159 None,
160 self.group_size,
161 self.bits,
162 true,
163 )
164 }
165
166 fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
167 ops::afq_mm_op(
168 x,
169 &self.w_q,
170 &self.scales,
171 &self.biases,
172 None,
173 Some(indices),
174 self.group_size,
175 self.bits,
176 true,
177 )
178 }
179
180 fn quantized_act_type(&self) -> Option<DType> {
181 None
182 }
183
184 fn afq_inner(&self) -> Option<crate::AfqInner<'_>> {
185 Some(crate::AfqInner {
186 w_q: &self.w_q,
187 scales: &self.scales,
188 biases: &self.biases,
189 bias: self.bias.as_ref(),
190 bits: self.bits,
191 group_size: self.group_size,
192 })
193 }
194
195 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
196 let dequant = self.dequantize_w()?;
197 Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
198 weight: (dequant + delta)?,
199 bias: self.bias.clone(),
200 bits: self.bits,
201 group_size: self.group_size,
202 })?))
203 }
204
205 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
206 (self.scales.dtype(), self.scales.device().clone())
207 }
208
209 fn apply_isq(
210 self: Arc<Self>,
211 dtype: Option<IsqType>,
212 device: Device,
213 _n_quantized: &AtomicUsize,
214 _imatrix_weight: Option<Vec<f32>>,
215 guard: QuantizeOntoGuard,
216 ) -> Result<Arc<dyn QuantMethod>> {
217 match dtype {
218 Some(IsqType::F8Q8) => {
219 let _acquired_quantize_guard = guard.acquire(&device);
220 let w = self.dequantize_w()?.to_device(&device)?;
221 let b = self
222 .bias
223 .as_ref()
224 .map(|b| b.to_device(&device))
225 .transpose()?;
226 Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
227 }
228 _ => todo!(),
229 }
230 }
231}
232
233impl AfqLayer {
234 pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
235 let mut buffer = Cursor::new(data.to_vec());
236
237 let version = buffer.read_u32::<LittleEndian>()?;
238 if let Err(e) = version_is_compatible(version) {
239 return Err(hanzo_ml::Error::wrap(e));
240 }
241
242 let isq_type = buffer.read_u8()? as usize;
243 if isq_type != QuantizedSerdeType::Afq as usize {
244 hanzo_ml::bail!(
245 "ISQ type ({isq_type}) doesn't match expected type {}",
246 QuantizedSerdeType::Afq as usize
247 );
248 }
249
250 let has_bias = buffer.read_u8()? != 0;
251
252 fake_deserialize_tensor(&mut buffer)?;
254 fake_deserialize_tensor(&mut buffer)?;
255 fake_deserialize_tensor(&mut buffer)?;
256
257 let bits: AfqBits = buffer.read_u8()?.try_into()?;
259 let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
260
261 if has_bias {
262 fake_deserialize_tensor(&mut buffer)?
263 }
264
265 match bits {
266 AfqBits::Two => Ok(IsqType::AFQ2),
267 AfqBits::Three => Ok(IsqType::AFQ3),
268 AfqBits::Four => Ok(IsqType::AFQ4),
269 AfqBits::Six => Ok(IsqType::AFQ6),
270 AfqBits::Eight => Ok(IsqType::AFQ8),
271 AfqBits::Mxfp4 => hanzo_ml::bail!("mxfp4 is not supported as an ISQ type"),
272 }
273 }
274
275 pub fn afq_linear_b(
276 in_dim: usize,
277 out_dim: usize,
278 config: &QuantizedConfig,
279 bias: bool,
280 vb: ShardedVarBuilder,
281 ) -> Result<Arc<dyn QuantMethod>> {
282 let QuantizedConfig::Afq { bits, group_size } = config else {
283 hanzo_ml::bail!("Unexpected quantization config.")
284 };
285
286 let w_q = vb.get_with_hints_dtype(
287 (out_dim, in_dim * bits / 32),
288 "weight",
289 Default::default(),
290 DType::U32,
291 )?;
292 let scales =
293 vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
294 let biases =
295 vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;
296
297 let bias = if bias {
298 Some(vb.get((out_dim,), "bias")?)
299 } else {
300 None
301 };
302
303 Ok(Arc::new(Self {
304 w_q,
305 scales,
306 bias,
307 biases,
308 bits: AfqBits::try_from(*bits)?,
309 group_size: AfqGroupSize::try_from(*group_size)?,
310 }))
311 }
312
313 pub fn afq_packed_linear_b(
314 num_local_experts: usize,
315 in_dim: usize,
316 out_dim: usize,
317 config: &QuantizedConfig,
318 bias: bool,
319 vb: ShardedVarBuilder,
320 ) -> Result<Arc<dyn QuantMethod>> {
321 let QuantizedConfig::Afq { bits, group_size } = config else {
322 hanzo_ml::bail!("Unexpected quantization config.")
323 };
324
325 let w_q = vb.get_with_hints_dtype(
326 (num_local_experts, out_dim, in_dim * bits / 32),
327 "weight",
328 Default::default(),
329 DType::U32,
330 )?;
331 let scales = vb.get_with_hints(
332 (num_local_experts, out_dim, in_dim / group_size),
333 "scales",
334 Default::default(),
335 )?;
336 let biases = vb.get_with_hints(
337 (num_local_experts, out_dim, in_dim / group_size),
338 "biases",
339 Default::default(),
340 )?;
341
342 let bias = if bias {
343 Some(vb.get((num_local_experts, out_dim), "bias")?)
344 } else {
345 None
346 };
347
348 Ok(Arc::new(Self {
349 w_q,
350 scales,
351 bias,
352 biases,
353 bits: AfqBits::try_from(*bits)?,
354 group_size: AfqGroupSize::try_from(*group_size)?,
355 }))
356 }
357}
358
359impl QuantizedSerde for AfqLayer {
360 fn name(&self) -> &'static str {
361 "afq-layer"
362 }
363 fn isq_serde_supported(&self) -> bool {
364 true
365 }
366 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
367 self.serialize_with_bias(self.bias.clone())
368 }
369 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
370 let mut buffer = Vec::new();
371
372 buffer.extend(&UQFF_VERSION.to_le_bytes());
374
375 buffer.push(QuantizedSerdeType::Afq as u8);
377
378 buffer.push(bias.is_some() as u8);
380
381 serialize_tensor(&mut buffer, &self.w_q)?;
383 serialize_tensor(&mut buffer, &self.scales)?;
384 serialize_tensor(&mut buffer, &self.biases)?;
385
386 buffer.push(self.bits as u8);
388 buffer.push(self.group_size as u8);
389
390 if let Some(bias) = &bias {
391 serialize_tensor(&mut buffer, bias)?;
393 }
394
395 Ok(Cow::from(buffer))
396 }
397 fn deserialize(
398 data: Cow<[u8]>,
399 device: &Device,
400 _comm: &Arc<Comm>,
401 guard: QuantizeOntoGuard,
402 ) -> Result<Arc<dyn QuantMethod>>
403 where
404 Self: Sized,
405 {
406 let mut buffer = Cursor::new(data);
407
408 let version = buffer.read_u32::<LittleEndian>()?;
409 if let Err(e) = version_is_compatible(version) {
410 return Err(hanzo_ml::Error::wrap(e));
411 }
412
413 let isq_type = buffer.read_u8()? as usize;
414 if isq_type != QuantizedSerdeType::Afq as usize {
415 hanzo_ml::bail!(
416 "ISQ type ({isq_type}) doesn't match expected type {}",
417 QuantizedSerdeType::Afq as usize
418 );
419 }
420
421 let has_bias = buffer.read_u8()? != 0;
422
423 let _acquired_load_guard = guard.acquire(device);
424 let w_q = deserialize_tensor(&mut buffer, device)?;
426 let scales = deserialize_tensor(&mut buffer, device)?;
427 let biases = deserialize_tensor(&mut buffer, device)?;
428
429 let bits: AfqBits = buffer.read_u8()?.try_into()?;
431 let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
432
433 let b = if has_bias {
434 Some(deserialize_tensor(&mut buffer, device)?)
435 } else {
436 None
437 };
438
439 Ok(Arc::new(Self {
440 w_q,
441 scales,
442 bias: b,
443 biases,
444 bits,
445 group_size,
446 }))
447 }
448 fn deserialize_ext_bias(
449 data: Cow<[u8]>,
450 device: &Device,
451 guard: QuantizeOntoGuard,
452 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
453 where
454 Self: Sized,
455 {
456 let mut buffer = Cursor::new(data);
457
458 let version = buffer.read_u32::<LittleEndian>()?;
459 if let Err(e) = version_is_compatible(version) {
460 return Err(hanzo_ml::Error::wrap(e));
461 }
462
463 let isq_type = buffer.read_u8()? as usize;
464 if isq_type != QuantizedSerdeType::Afq as usize {
465 hanzo_ml::bail!(
466 "ISQ type ({isq_type}) doesn't match expected type {}",
467 QuantizedSerdeType::Afq as usize
468 );
469 }
470
471 let has_bias = buffer.read_u8()? != 0;
472
473 let _acquired_load_guard = guard.acquire(device);
474 let w_q = deserialize_tensor(&mut buffer, device)?;
476 let scales = deserialize_tensor(&mut buffer, device)?;
477 let biases = deserialize_tensor(&mut buffer, device)?;
478
479 let bits: AfqBits = buffer.read_u8()?.try_into()?;
481 let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
482
483 let b = if has_bias {
484 Some(deserialize_tensor(&mut buffer, device)?)
485 } else {
486 None
487 };
488
489 Ok((
490 Arc::new(Self {
491 w_q,
492 scales,
493 bias: None,
494 biases,
495 bits,
496 group_size,
497 }),
498 b,
499 ))
500 }
501}