Skip to main content

hanzo_quant/afq/
mod.rs

1use std::{
2    borrow::Cow,
3    io::Cursor,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor};
9
10use crate::{
11    utils::{
12        deserialize_tensor, fake_deserialize_tensor, serialize_tensor, version_is_compatible,
13        UQFF_VERSION,
14    },
15    Comm, IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig,
16    QuantizedSerde, QuantizedSerdeType, ShardedVarBuilder,
17};
18
19pub(crate) mod ops;
20
21#[cfg(feature = "cuda")]
22pub(crate) mod ffi;
23
/// Quantization bit width of an AFQ layer.
///
/// Each variant's `u8` discriminant is written verbatim into UQFF records (one byte after the
/// `biases` tensor), so these values are part of the on-disk format and must not change.
/// `Mxfp4` uses the value `40`, which is a format tag rather than a literal bit count.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum AfqBits {
    /// 2 bits per weight.
    Two = 2,
    /// 3 bits per weight.
    Three = 3,
    /// 4 bits per weight.
    Four = 4,
    /// 6 bits per weight.
    Six = 6,
    /// 8 bits per weight.
    Eight = 8,
    /// MXFP4 (tag value `40`; not usable as an ISQ type).
    Mxfp4 = 40,
}
34
35impl TryFrom<usize> for AfqBits {
36    type Error = hanzo_ml::Error;
37    fn try_from(value: usize) -> Result<Self> {
38        match value {
39            2 => Ok(Self::Two),
40            3 => Ok(Self::Three),
41            4 => Ok(Self::Four),
42            6 => Ok(Self::Six),
43            8 => Ok(Self::Eight),
44            40 => Ok(Self::Mxfp4),
45            x => hanzo_ml::bail!("Invalid AFQ bits {x}."),
46        }
47    }
48}
49
50impl TryFrom<u8> for AfqBits {
51    type Error = hanzo_ml::Error;
52    fn try_from(value: u8) -> Result<Self> {
53        Self::try_from(value as usize)
54    }
55}
56
/// Number of consecutive input weights that share one scale/bias pair in an AFQ layer.
///
/// The discriminant is serialized as a single `u8` in UQFF records, so every variant's value
/// must fit in a byte and must stay stable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum AfqGroupSize {
    /// Groups of 32 weights.
    Low = 32,
    /// Groups of 64 weights (the default).
    Med = 64,
    /// Groups of 128 weights.
    High = 128,
}

impl Default for AfqGroupSize {
    /// Defaults to [`AfqGroupSize::Med`] (64).
    fn default() -> Self {
        Self::Med
    }
}
65
66impl TryFrom<usize> for AfqGroupSize {
67    type Error = hanzo_ml::Error;
68    fn try_from(value: usize) -> Result<Self> {
69        match value {
70            32 => Ok(Self::Low),
71            64 => Ok(Self::Med),
72            128 => Ok(Self::High),
73            x => hanzo_ml::bail!("Invalid AFQ group size {x}."),
74        }
75    }
76}
77
78impl TryFrom<u8> for AfqGroupSize {
79    type Error = hanzo_ml::Error;
80    fn try_from(value: u8) -> Result<Self> {
81        Self::try_from(value as usize)
82    }
83}
84
/// Linear layer whose weight is stored in AFQ-quantized form.
///
/// Built either by quantizing a dense weight (`QuantMethod::new`), by loading pre-quantized
/// tensors (`afq_linear_b` / `afq_packed_linear_b`), or by UQFF deserialization.
#[derive(Debug)]
pub struct AfqLayer {
    /// Packed quantized weight (loaded as `DType::U32` by the `afq_*_linear_b` constructors).
    w_q: Tensor,
    /// Per-group scales; one column per `group_size` input features.
    scales: Tensor,
    /// Per-group quantization biases (same shape as `scales`). Not the linear-layer bias.
    biases: Tensor,
    /// Optional linear-layer bias.
    bias: Option<Tensor>,
    /// Quantization bit width.
    bits: AfqBits,
    /// Number of input features sharing one scale/bias pair.
    group_size: AfqGroupSize,
}
94
/// View of an AfqLayer's storage tensors, used by fused QKV/gate-up paths.
///
/// Borrows from the layer, so it cannot outlive it; the fields mirror `AfqLayer` one-to-one.
#[derive(Clone)]
pub struct AfqInner<'a> {
    /// Packed quantized weight.
    pub w_q: &'a Tensor,
    /// Per-group scales.
    pub scales: &'a Tensor,
    /// Per-group quantization biases (not the linear-layer bias).
    pub biases: &'a Tensor,
    /// Optional linear-layer bias.
    pub bias: Option<&'a Tensor>,
    /// Quantization bit width.
    pub bits: AfqBits,
    /// Number of input features sharing one scale/bias pair.
    pub group_size: AfqGroupSize,
}
105
106impl QuantMethod for AfqLayer {
107    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
108    where
109        Self: Sized,
110    {
111        match method {
112            QuantMethodConfig::Gguf { .. }
113            | QuantMethodConfig::GptqAwq { .. }
114            | QuantMethodConfig::Hqq { .. }
115            | QuantMethodConfig::Dummy
116            | QuantMethodConfig::FP8 { .. }
117            | QuantMethodConfig::Bnb { .. }
118            | QuantMethodConfig::BlockwiseFP8 { .. }
119            | QuantMethodConfig::PerTensorFP8 { .. }
120            | QuantMethodConfig::Unquantized(_)
121            | QuantMethodConfig::MXFP4 { .. } => unreachable!(),
122            QuantMethodConfig::Afq {
123                weight,
124                bias,
125                bits,
126                group_size,
127            } => {
128                let (w_q, scales, biases) = ops::afq_quantize_op(&weight, group_size, bits)?;
129
130                Ok(Self {
131                    w_q,
132                    scales,
133                    biases,
134                    bias,
135                    bits,
136                    group_size,
137                })
138            }
139        }
140    }
141
142    fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
143        ops::afq_dequantize_op(
144            &self.w_q,
145            &self.scales,
146            &self.biases,
147            self.group_size,
148            self.bits,
149        )
150    }
151
152    fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
153        ops::afq_mm_op(
154            x,
155            &self.w_q,
156            &self.scales,
157            &self.biases,
158            None,
159            None,
160            self.group_size,
161            self.bits,
162            true,
163        )
164    }
165
166    fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
167        ops::afq_mm_op(
168            x,
169            &self.w_q,
170            &self.scales,
171            &self.biases,
172            None,
173            Some(indices),
174            self.group_size,
175            self.bits,
176            true,
177        )
178    }
179
180    fn quantized_act_type(&self) -> Option<DType> {
181        None
182    }
183
184    fn afq_inner(&self) -> Option<crate::AfqInner<'_>> {
185        Some(crate::AfqInner {
186            w_q: &self.w_q,
187            scales: &self.scales,
188            biases: &self.biases,
189            bias: self.bias.as_ref(),
190            bits: self.bits,
191            group_size: self.group_size,
192        })
193    }
194
195    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
196        let dequant = self.dequantize_w()?;
197        Ok(Arc::new(Self::new(QuantMethodConfig::Afq {
198            weight: (dequant + delta)?,
199            bias: self.bias.clone(),
200            bits: self.bits,
201            group_size: self.group_size,
202        })?))
203    }
204
205    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
206        (self.scales.dtype(), self.scales.device().clone())
207    }
208
209    fn apply_isq(
210        self: Arc<Self>,
211        dtype: Option<IsqType>,
212        device: Device,
213        _n_quantized: &AtomicUsize,
214        _imatrix_weight: Option<Vec<f32>>,
215        guard: QuantizeOntoGuard,
216    ) -> Result<Arc<dyn QuantMethod>> {
217        match dtype {
218            Some(IsqType::F8Q8) => {
219                let _acquired_quantize_guard = guard.acquire(&device);
220                let w = self.dequantize_w()?.to_device(&device)?;
221                let b = self
222                    .bias
223                    .as_ref()
224                    .map(|b| b.to_device(&device))
225                    .transpose()?;
226                Ok(Arc::new(crate::F8Q8Linear::from_weight(&w, b)?))
227            }
228            _ => todo!(),
229        }
230    }
231}
232
233impl AfqLayer {
234    pub fn get_isq_type_from_uqff(data: Cow<[u8]>) -> Result<IsqType> {
235        let mut buffer = Cursor::new(data.to_vec());
236
237        let version = buffer.read_u32::<LittleEndian>()?;
238        if let Err(e) = version_is_compatible(version) {
239            return Err(hanzo_ml::Error::wrap(e));
240        }
241
242        let isq_type = buffer.read_u8()? as usize;
243        if isq_type != QuantizedSerdeType::Afq as usize {
244            hanzo_ml::bail!(
245                "ISQ type ({isq_type}) doesn't match expected type {}",
246                QuantizedSerdeType::Afq as usize
247            );
248        }
249
250        let has_bias = buffer.read_u8()? != 0;
251
252        // Weight, scales, biases
253        fake_deserialize_tensor(&mut buffer)?;
254        fake_deserialize_tensor(&mut buffer)?;
255        fake_deserialize_tensor(&mut buffer)?;
256
257        // Bits and group size
258        let bits: AfqBits = buffer.read_u8()?.try_into()?;
259        let _group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
260
261        if has_bias {
262            fake_deserialize_tensor(&mut buffer)?
263        }
264
265        match bits {
266            AfqBits::Two => Ok(IsqType::AFQ2),
267            AfqBits::Three => Ok(IsqType::AFQ3),
268            AfqBits::Four => Ok(IsqType::AFQ4),
269            AfqBits::Six => Ok(IsqType::AFQ6),
270            AfqBits::Eight => Ok(IsqType::AFQ8),
271            AfqBits::Mxfp4 => hanzo_ml::bail!("mxfp4 is not supported as an ISQ type"),
272        }
273    }
274
275    pub fn afq_linear_b(
276        in_dim: usize,
277        out_dim: usize,
278        config: &QuantizedConfig,
279        bias: bool,
280        vb: ShardedVarBuilder,
281    ) -> Result<Arc<dyn QuantMethod>> {
282        let QuantizedConfig::Afq { bits, group_size } = config else {
283            hanzo_ml::bail!("Unexpected quantization config.")
284        };
285
286        let w_q = vb.get_with_hints_dtype(
287            (out_dim, in_dim * bits / 32),
288            "weight",
289            Default::default(),
290            DType::U32,
291        )?;
292        let scales =
293            vb.get_with_hints((out_dim, in_dim / group_size), "scales", Default::default())?;
294        let biases =
295            vb.get_with_hints((out_dim, in_dim / group_size), "biases", Default::default())?;
296
297        let bias = if bias {
298            Some(vb.get((out_dim,), "bias")?)
299        } else {
300            None
301        };
302
303        Ok(Arc::new(Self {
304            w_q,
305            scales,
306            bias,
307            biases,
308            bits: AfqBits::try_from(*bits)?,
309            group_size: AfqGroupSize::try_from(*group_size)?,
310        }))
311    }
312
313    pub fn afq_packed_linear_b(
314        num_local_experts: usize,
315        in_dim: usize,
316        out_dim: usize,
317        config: &QuantizedConfig,
318        bias: bool,
319        vb: ShardedVarBuilder,
320    ) -> Result<Arc<dyn QuantMethod>> {
321        let QuantizedConfig::Afq { bits, group_size } = config else {
322            hanzo_ml::bail!("Unexpected quantization config.")
323        };
324
325        let w_q = vb.get_with_hints_dtype(
326            (num_local_experts, out_dim, in_dim * bits / 32),
327            "weight",
328            Default::default(),
329            DType::U32,
330        )?;
331        let scales = vb.get_with_hints(
332            (num_local_experts, out_dim, in_dim / group_size),
333            "scales",
334            Default::default(),
335        )?;
336        let biases = vb.get_with_hints(
337            (num_local_experts, out_dim, in_dim / group_size),
338            "biases",
339            Default::default(),
340        )?;
341
342        let bias = if bias {
343            Some(vb.get((num_local_experts, out_dim), "bias")?)
344        } else {
345            None
346        };
347
348        Ok(Arc::new(Self {
349            w_q,
350            scales,
351            bias,
352            biases,
353            bits: AfqBits::try_from(*bits)?,
354            group_size: AfqGroupSize::try_from(*group_size)?,
355        }))
356    }
357}
358
359impl QuantizedSerde for AfqLayer {
360    fn name(&self) -> &'static str {
361        "afq-layer"
362    }
363    fn isq_serde_supported(&self) -> bool {
364        true
365    }
366    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
367        self.serialize_with_bias(self.bias.clone())
368    }
369    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
370        let mut buffer = Vec::new();
371
372        // Version is always first!
373        buffer.extend(&UQFF_VERSION.to_le_bytes());
374
375        // ISQ type for afq is 4
376        buffer.push(QuantizedSerdeType::Afq as u8);
377
378        // Has bias
379        buffer.push(bias.is_some() as u8);
380
381        // Weight, scales, biases
382        serialize_tensor(&mut buffer, &self.w_q)?;
383        serialize_tensor(&mut buffer, &self.scales)?;
384        serialize_tensor(&mut buffer, &self.biases)?;
385
386        // Bits and group size
387        buffer.push(self.bits as u8);
388        buffer.push(self.group_size as u8);
389
390        if let Some(bias) = &bias {
391            // Bias
392            serialize_tensor(&mut buffer, bias)?;
393        }
394
395        Ok(Cow::from(buffer))
396    }
397    fn deserialize(
398        data: Cow<[u8]>,
399        device: &Device,
400        _comm: &Arc<Comm>,
401        guard: QuantizeOntoGuard,
402    ) -> Result<Arc<dyn QuantMethod>>
403    where
404        Self: Sized,
405    {
406        let mut buffer = Cursor::new(data);
407
408        let version = buffer.read_u32::<LittleEndian>()?;
409        if let Err(e) = version_is_compatible(version) {
410            return Err(hanzo_ml::Error::wrap(e));
411        }
412
413        let isq_type = buffer.read_u8()? as usize;
414        if isq_type != QuantizedSerdeType::Afq as usize {
415            hanzo_ml::bail!(
416                "ISQ type ({isq_type}) doesn't match expected type {}",
417                QuantizedSerdeType::Afq as usize
418            );
419        }
420
421        let has_bias = buffer.read_u8()? != 0;
422
423        let _acquired_load_guard = guard.acquire(device);
424        // Weight, scales, biases
425        let w_q = deserialize_tensor(&mut buffer, device)?;
426        let scales = deserialize_tensor(&mut buffer, device)?;
427        let biases = deserialize_tensor(&mut buffer, device)?;
428
429        // Bits and group size
430        let bits: AfqBits = buffer.read_u8()?.try_into()?;
431        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
432
433        let b = if has_bias {
434            Some(deserialize_tensor(&mut buffer, device)?)
435        } else {
436            None
437        };
438
439        Ok(Arc::new(Self {
440            w_q,
441            scales,
442            bias: b,
443            biases,
444            bits,
445            group_size,
446        }))
447    }
448    fn deserialize_ext_bias(
449        data: Cow<[u8]>,
450        device: &Device,
451        guard: QuantizeOntoGuard,
452    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
453    where
454        Self: Sized,
455    {
456        let mut buffer = Cursor::new(data);
457
458        let version = buffer.read_u32::<LittleEndian>()?;
459        if let Err(e) = version_is_compatible(version) {
460            return Err(hanzo_ml::Error::wrap(e));
461        }
462
463        let isq_type = buffer.read_u8()? as usize;
464        if isq_type != QuantizedSerdeType::Afq as usize {
465            hanzo_ml::bail!(
466                "ISQ type ({isq_type}) doesn't match expected type {}",
467                QuantizedSerdeType::Afq as usize
468            );
469        }
470
471        let has_bias = buffer.read_u8()? != 0;
472
473        let _acquired_load_guard = guard.acquire(device);
474        // Weight, scales, biases
475        let w_q = deserialize_tensor(&mut buffer, device)?;
476        let scales = deserialize_tensor(&mut buffer, device)?;
477        let biases = deserialize_tensor(&mut buffer, device)?;
478
479        // Bits and group size
480        let bits: AfqBits = buffer.read_u8()?.try_into()?;
481        let group_size: AfqGroupSize = buffer.read_u8()?.try_into()?;
482
483        let b = if has_bias {
484            Some(deserialize_tensor(&mut buffer, device)?)
485        } else {
486            None
487        };
488
489        Ok((
490            Arc::new(Self {
491                w_q,
492                scales,
493                bias: None,
494                biases,
495                bits,
496                group_size,
497            }),
498            b,
499        ))
500    }
501}