Skip to main content

hanzo_quant/mxfp4/
mod.rs

1use std::{
2    borrow::Cow,
3    io::Cursor,
4    sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor};
9
10use crate::{
11    utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
12    IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
13    QuantizedSerdeType, ShardedVarBuilder,
14};
15
16#[cfg(feature = "cuda")]
17pub(crate) mod ffi;
18#[cfg(feature = "metal")]
19pub(crate) mod metal_ops;
20#[cfg(feature = "cuda")]
21pub(crate) mod ops;
22
/// Number of consecutive weight elements that share one E8M0 scale byte.
pub const MXFP4_BLOCK_SIZE: usize = 32;
25
/// Width in bits of one packed FP4 value (two values fit in a byte).
pub(crate) const N_BITS: usize = 4;
27
28#[derive(Debug)]
29pub struct MXFP4Layer {
30    /// Packed FP4 weights: [N, K/2] or [num_experts, N, K/2]
31    /// Each byte contains 2 FP4 values (low nibble = k, high nibble = k+1)
32    #[allow(dead_code)]
33    blocks: Tensor,
34    /// E8M0 scales: [N, K/32] or [num_experts, N, K/32]
35    /// Each byte is an 8-bit exponent with bias 127
36    scales: Tensor,
37    /// Optional bias: [N] or [num_experts, N]
38    #[allow(dead_code)]
39    bias: Option<Tensor>,
40}
41
42impl QuantMethod for MXFP4Layer {
43    fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
44    where
45        Self: Sized,
46    {
47        match method {
48            QuantMethodConfig::Gguf { .. }
49            | QuantMethodConfig::GptqAwq { .. }
50            | QuantMethodConfig::Hqq { .. }
51            | QuantMethodConfig::Dummy
52            | QuantMethodConfig::FP8 { .. }
53            | QuantMethodConfig::Bnb { .. }
54            | QuantMethodConfig::BlockwiseFP8 { .. }
55            | QuantMethodConfig::PerTensorFP8 { .. }
56            | QuantMethodConfig::Unquantized(_)
57            | QuantMethodConfig::Afq { .. } => unreachable!(),
58            QuantMethodConfig::MXFP4 {
59                blocks,
60                scales,
61                bias,
62            } => Ok(Self {
63                blocks,
64                scales,
65                bias,
66            }),
67        }
68    }
69
70    fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
71        #[cfg(feature = "metal")]
72        if self.blocks.device().is_metal() {
73            use crate::afq::ops;
74            use crate::{AfqBits, AfqGroupSize};
75            return ops::afq_dequantize_op(
76                &self.blocks,
77                &self.scales,
78                &self.scales.clone(),
79                AfqGroupSize::Low,
80                AfqBits::Mxfp4,
81            );
82        }
83        // CPU fallback
84        self.dequantize_weights()
85    }
86
87    #[allow(unused_variables)]
88    fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
89        #[cfg(feature = "cuda")]
90        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
91            let orig_dims = x.dims().to_vec();
92            let x_2d = if orig_dims.len() > 2 {
93                let features = orig_dims[orig_dims.len() - 1];
94                let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
95                x.reshape((batch_size, features))?
96            } else {
97                x.clone()
98            };
99
100            let result = ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;
101
102            if orig_dims.len() > 2 {
103                let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
104                new_dims.push(result.dim(1)?);
105                return result.reshape(new_dims);
106            }
107            return Ok(result);
108        }
109
110        #[cfg(feature = "metal")]
111        {
112            if x.device().is_metal() {
113                let orig_dims = x.dims().to_vec();
114                let x_2d = if orig_dims.len() > 2 {
115                    let features = orig_dims[orig_dims.len() - 1];
116                    let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
117                    x.reshape((batch_size, features))?
118                } else {
119                    x.clone()
120                };
121
122                let result =
123                    metal_ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;
124
125                if orig_dims.len() > 2 {
126                    let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
127                    new_dims.push(result.dim(1)?);
128                    return result.reshape(new_dims);
129                }
130                return Ok(result);
131            }
132        }
133
134        self.forward_dequantize(x)
135    }
136
137    #[allow(unused_variables)]
138    fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
139        #[cfg(feature = "cuda")]
140        if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
141            return ops::mxfp4_indexed_moe_gemm(
142                x,
143                &self.blocks,
144                &self.scales,
145                self.bias.as_ref(),
146                indices,
147            );
148        }
149
150        #[cfg(feature = "metal")]
151        {
152            if x.device().is_metal() {
153                return metal_ops::mxfp4_indexed_moe_gemm(
154                    x,
155                    &self.blocks,
156                    &self.scales,
157                    self.bias.as_ref(),
158                    indices,
159                );
160            }
161        }
162
163        self.gather_forward_dequantize(x, indices)
164    }
165
166    fn quantized_act_type(&self) -> Option<DType> {
167        None
168    }
169
170    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
171        hanzo_ml::bail!("MXFP4Layer does not support add_delta_w")
172    }
173
174    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
175        (DType::BF16, self.scales.device().clone())
176    }
177
178    fn apply_isq(
179        self: Arc<Self>,
180        _dtype: Option<IsqType>,
181        _device: Device,
182        _n_quantized: &AtomicUsize,
183        _imatrix_weight: Option<Vec<f32>>,
184        _guard: QuantizeOntoGuard,
185    ) -> Result<Arc<dyn QuantMethod>> {
186        hanzo_ml::bail!("MXFP4Layer does not support ISQ")
187    }
188}
189
190impl MXFP4Layer {
191    /// Check if the device supports MXFP4 operations
192    fn device_supported(_device: &Device) -> bool {
193        #[cfg(feature = "cuda")]
194        if matches!(_device, Device::Cuda(_)) {
195            return ffi::HAVE_MXFP4_GEMM_KERNELS;
196        }
197        #[cfg(feature = "metal")]
198        if _device.is_metal() {
199            return true;
200        }
201        false
202    }
203
204    /// Quantize an unquantized weight tensor to MXFP4 format.
205    /// weight shape: `[N, K]`, bias shape: `[N]` (optional)
206    pub fn quantize(
207        weight: &Tensor,
208        bias: Option<Tensor>,
209        device: &Device,
210    ) -> Result<Arc<dyn QuantMethod>> {
211        let weight_f32 = weight.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
212        let dims = weight_f32.dims2()?;
213        let (n, k) = (dims.0, dims.1);
214
215        if k % MXFP4_BLOCK_SIZE != 0 {
216            hanzo_ml::bail!(
217                "MXFP4 quantization requires K ({k}) divisible by block size ({MXFP4_BLOCK_SIZE})"
218            );
219        }
220
221        let weight_data: Vec<f32> = weight_f32.flatten_all()?.to_vec1()?;
222        let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
223        let k_half = k / 2;
224
225        // Parallelize quantization across rows with rayon
226        use rayon::prelude::*;
227        let row_results: Vec<(Vec<u8>, Vec<u8>)> = (0..n)
228            .into_par_iter()
229            .map(|row| {
230                let row_offset = row * k;
231                let mut row_packed = vec![0u8; k_half];
232                let mut row_scales = vec![0u8; num_blocks_per_row];
233
234                for (blk, row_scale) in row_scales.iter_mut().enumerate() {
235                    let blk_start = row_offset + blk * MXFP4_BLOCK_SIZE;
236                    let block = &weight_data[blk_start..blk_start + MXFP4_BLOCK_SIZE];
237
238                    let max_abs = block.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
239
240                    let scale = if max_abs == 0.0 {
241                        127u8
242                    } else {
243                        let raw = (max_abs / 6.0).log2().floor() as i32 + 127;
244                        raw.clamp(0, 254) as u8
245                    };
246                    *row_scale = scale;
247
248                    let scale_factor = 2.0f32.powi(scale as i32 - 127);
249                    let inv_scale = if scale_factor == 0.0 {
250                        0.0
251                    } else {
252                        1.0 / scale_factor
253                    };
254
255                    for (elem, &val) in block.iter().enumerate() {
256                        let nibble = Self::quantize_to_fp4(val * inv_scale);
257                        let k_idx = blk * MXFP4_BLOCK_SIZE + elem;
258                        let byte_idx = k_idx / 2;
259                        if k_idx.is_multiple_of(2) {
260                            row_packed[byte_idx] |= nibble;
261                        } else {
262                            row_packed[byte_idx] |= nibble << 4;
263                        }
264                    }
265                }
266                (row_packed, row_scales)
267            })
268            .collect();
269
270        let mut packed = Vec::with_capacity(n * k_half);
271        let mut scales = Vec::with_capacity(n * num_blocks_per_row);
272        for (row_packed, row_scales) in row_results {
273            packed.extend_from_slice(&row_packed);
274            scales.extend_from_slice(&row_scales);
275        }
276
277        let blocks = Tensor::from_vec(packed, (n, k / 2), &Device::Cpu)?
278            .to_dtype(DType::U8)?
279            .to_device(device)?;
280        let scales = Tensor::from_vec(scales, (n, num_blocks_per_row), &Device::Cpu)?
281            .to_dtype(DType::U8)?
282            .to_device(device)?;
283        let bias = bias.map(|b| b.to_device(device)).transpose()?;
284
285        Ok(Arc::new(Self {
286            blocks,
287            scales,
288            bias,
289        }))
290    }
291
292    /// Quantize a single scaled value to the nearest FP4 E2M1 nibble (0..15).
293    fn quantize_to_fp4(val: f32) -> u8 {
294        // FP4 E2M1 positive values: 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
295        // Negative values are the same with sign bit set (indices 8..15)
296        let sign = val < 0.0;
297        let abs_val = val.abs();
298
299        // Decision boundaries (midpoints between consecutive FP4 values)
300        let nibble = if abs_val < 0.25 {
301            0 // 0.0
302        } else if abs_val < 0.75 {
303            1 // 0.5
304        } else if abs_val < 1.25 {
305            2 // 1.0
306        } else if abs_val < 1.75 {
307            3 // 1.5
308        } else if abs_val < 2.5 {
309            4 // 2.0
310        } else if abs_val < 3.5 {
311            5 // 3.0
312        } else if abs_val < 5.0 {
313            6 // 4.0
314        } else {
315            7 // 6.0
316        };
317
318        if sign {
319            nibble | 0x08
320        } else {
321            nibble
322        }
323    }
324
325    pub fn linear_b(
326        in_dim: usize,
327        out_dim: usize,
328        config: &QuantizedConfig,
329        bias: bool,
330        vb: ShardedVarBuilder,
331    ) -> Result<Arc<dyn QuantMethod>> {
332        if !Self::device_supported(vb.device()) {
333            hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
334        }
335
336        let QuantizedConfig::MXFP4 {} = config else {
337            hanzo_ml::bail!("Unexpected quantization config.")
338        };
339
340        let blocks = vb.get_with_hints_dtype(
341            (out_dim, in_dim / 2),
342            "blocks",
343            Default::default(),
344            DType::U8,
345        )?;
346        let scales = vb.get_with_hints_dtype(
347            (out_dim, in_dim / MXFP4_BLOCK_SIZE),
348            "scales",
349            Default::default(),
350            DType::U8,
351        )?;
352
353        let bias = if bias {
354            Some(vb.get((out_dim,), "bias")?)
355        } else {
356            None
357        };
358
359        Ok(Arc::new(Self {
360            blocks,
361            scales,
362            bias,
363        }))
364    }
365
366    pub fn packed_linear_b(
367        num_local_experts: usize,
368        in_dim: usize,
369        out_dim: usize,
370        config: &QuantizedConfig,
371        bias: bool,
372        vb: ShardedVarBuilder,
373    ) -> Result<Arc<dyn QuantMethod>> {
374        if !Self::device_supported(vb.device()) {
375            hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
376        }
377
378        let QuantizedConfig::MXFP4 {} = config else {
379            hanzo_ml::bail!("Unexpected quantization config.")
380        };
381
382        let blocks = vb.get_with_hints_dtype(
383            (num_local_experts, out_dim, in_dim / 2),
384            "blocks",
385            Default::default(),
386            DType::U8,
387        )?;
388        let scales = vb.get_with_hints_dtype(
389            (num_local_experts, out_dim, in_dim / MXFP4_BLOCK_SIZE),
390            "scales",
391            Default::default(),
392            DType::U8,
393        )?;
394
395        let bias = if bias {
396            Some(vb.get((num_local_experts, out_dim), "bias")?)
397        } else {
398            None
399        };
400
401        Ok(Arc::new(Self {
402            blocks,
403            scales,
404            bias,
405        }))
406    }
407
408    /// Load GPT-OSS style MXFP4 experts (combined gate_up_proj format).
409    ///
410    /// GPT-OSS stores tensors as:
411    /// - `{name}_blocks`: [num_experts, out_dim, num_blocks, 16] where 16 bytes = 32 FP4 values
412    /// - `{name}_scales`: [num_experts, out_dim, num_blocks]
413    /// - `{name}_bias`: [num_experts, out_dim]
414    ///
415    /// This function loads and reshapes the 4D blocks tensor to 3D [num_experts, out_dim, in_dim/2].
416    pub fn packed_gptoss_linear(
417        num_local_experts: usize,
418        in_dim: usize,
419        out_dim: usize,
420        bias: bool,
421        name: &str,
422        vb: ShardedVarBuilder,
423    ) -> Result<Arc<dyn QuantMethod>> {
424        if !Self::device_supported(vb.device()) {
425            hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
426        }
427
428        let num_blocks = in_dim / MXFP4_BLOCK_SIZE;
429
430        let blocks_4d = vb.get_with_hints_dtype(
431            (num_local_experts, out_dim, num_blocks, 16),
432            &format!("{name}_blocks"),
433            Default::default(),
434            DType::U8,
435        )?;
436
437        let blocks = blocks_4d.reshape((num_local_experts, out_dim, num_blocks * 16))?;
438
439        let scales = vb.get_with_hints_dtype(
440            (num_local_experts, out_dim, num_blocks),
441            &format!("{name}_scales"),
442            Default::default(),
443            DType::U8,
444        )?;
445
446        let bias = if bias {
447            Some(vb.get((num_local_experts, out_dim), &format!("{name}_bias"))?)
448        } else {
449            None
450        };
451
452        Ok(Arc::new(Self {
453            blocks,
454            scales,
455            bias,
456        }))
457    }
458
459    /// Combined FP4 × E8M0 dequant table: `DEQUANT_LUT[scale][nibble]`.
460    /// For each of the 256 possible E8M0 scale values, stores the 16 possible
461    /// dequantized values (FP4_LUT[nibble] * 2^(scale - 127)).
462    /// This turns dequantization into a single table lookup per element.
463    const DEQUANT_LUT: [[f32; 16]; 256] = {
464        let mut lut = [[0.0f32; 16]; 256];
465        let fp4: [f32; 16] = [
466            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
467        ];
468        let mut s = 0u32;
469        while s < 256 {
470            let scale_factor = f32::from_bits(s << 23);
471            let mut n = 0;
472            while n < 16 {
473                lut[s as usize][n] = fp4[n] * scale_factor;
474                n += 1;
475            }
476            s += 1;
477        }
478        lut
479    };
480
481    /// Dequantize MXFP4 weights to f32
482    /// blocks: [num_experts, N, K/2] packed bytes
483    /// scales: [num_experts, N, K/32] E8M0 scales
484    /// Returns: [num_experts, N, K] f32 weights
485    fn dequantize_weights(&self) -> Result<Tensor> {
486        let blocks_dims = self.blocks.dims();
487
488        let (num_experts, n, k_half) = if blocks_dims.len() == 3 {
489            (blocks_dims[0], blocks_dims[1], blocks_dims[2])
490        } else {
491            (1, blocks_dims[0], blocks_dims[1])
492        };
493        let k = k_half * 2;
494        let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
495
496        let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
497        let scales_cpu = self.scales.to_device(&Device::Cpu)?;
498
499        let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
500        let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
501
502        let mut weights = vec![0f32; num_experts * n * k];
503        let half_block = MXFP4_BLOCK_SIZE / 2; // 16 packed bytes per block
504
505        for expert in 0..num_experts {
506            for row in 0..n {
507                let blocks_row = expert * n * k_half + row * k_half;
508                let scales_row = expert * n * num_blocks_per_row + row * num_blocks_per_row;
509                let weights_row = expert * n * k + row * k;
510
511                for blk in 0..num_blocks_per_row {
512                    let scale = scales_data[scales_row + blk] as usize;
513                    let dequant = &Self::DEQUANT_LUT[scale];
514                    let blk_bytes = &blocks_data[blocks_row + blk * half_block..];
515                    let w_out = &mut weights[weights_row + blk * MXFP4_BLOCK_SIZE..];
516
517                    for byte_i in 0..half_block {
518                        let packed = blk_bytes[byte_i];
519                        w_out[byte_i * 2] = dequant[(packed & 0x0F) as usize];
520                        w_out[byte_i * 2 + 1] = dequant[((packed >> 4) & 0x0F) as usize];
521                    }
522                }
523            }
524        }
525
526        let shape = if blocks_dims.len() == 3 {
527            vec![num_experts, n, k]
528        } else {
529            vec![n, k]
530        };
531
532        Tensor::from_vec(weights, shape.as_slice(), &Device::Cpu)?
533            .to_device(self.blocks.device())?
534            .to_dtype(DType::BF16)
535    }
536
537    /// CPU forward pass: blocked dequant + matmul to avoid full weight allocation.
538    /// Processes MXFP4_BLOCK_SIZE (32) input columns at a time, dequantizing only
539    /// the needed weight slice before accumulating partial results.
540    fn forward_dequantize(&self, x: &Tensor) -> Result<Tensor> {
541        let orig_dims = x.dims().to_vec();
542
543        let x_2d = if orig_dims.len() > 2 {
544            let features = orig_dims[orig_dims.len() - 1];
545            let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
546            x.reshape((batch_size, features))?
547        } else {
548            x.clone()
549        };
550
551        let x_f32 = x_2d.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
552        let (m, k) = x_f32.dims2()?;
553
554        let blocks_dims = self.blocks.dims();
555        let n = if blocks_dims.len() == 3 {
556            blocks_dims[1]
557        } else {
558            blocks_dims[0]
559        };
560        let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
561        let half_block = MXFP4_BLOCK_SIZE / 2;
562
563        let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
564        let scales_cpu = self.scales.to_device(&Device::Cpu)?;
565        let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
566        let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
567        let x_data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;
568
569        // output: [m, n], accumulate x @ W^T in blocks of 32 columns
570        let mut output = vec![0f32; m * n];
571        let k_half = k / 2;
572
573        for blk in 0..num_blocks_per_row {
574            let col_start = blk * MXFP4_BLOCK_SIZE;
575
576            for row in 0..n {
577                let scale = scales_data[row * num_blocks_per_row + blk] as usize;
578                let dequant = &Self::DEQUANT_LUT[scale];
579                let blk_bytes = &blocks_data[row * k_half + blk * half_block..];
580
581                // Dequantize this block of 32 weights for this output row
582                let mut w_block = [0f32; MXFP4_BLOCK_SIZE];
583                for byte_i in 0..half_block {
584                    let packed = blk_bytes[byte_i];
585                    w_block[byte_i * 2] = dequant[(packed & 0x0F) as usize];
586                    w_block[byte_i * 2 + 1] = dequant[((packed >> 4) & 0x0F) as usize];
587                }
588
589                // Accumulate dot product for all tokens against this weight block
590                for token in 0..m {
591                    let x_row = &x_data[token * k + col_start..];
592                    let mut acc = 0f32;
593                    for i in 0..MXFP4_BLOCK_SIZE {
594                        acc += x_row[i] * w_block[i];
595                    }
596                    output[token * n + row] += acc;
597                }
598            }
599        }
600
601        let mut result = Tensor::from_vec(output, (m, n), &Device::Cpu)?
602            .to_device(x.device())?
603            .to_dtype(x.dtype())?;
604
605        if let Some(bias) = &self.bias {
606            result = result.broadcast_add(bias)?;
607        }
608
609        if orig_dims.len() > 2 {
610            let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
611            new_dims.push(result.dim(1)?);
612            result = result.reshape(new_dims)?;
613        }
614
615        Ok(result)
616    }
617
618    /// CPU MoE forward: blocked dequant per (token, expert) pair.
619    /// Avoids dequantizing all experts, only touches the needed weight blocks.
620    fn gather_forward_dequantize(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
621        let x_dims = x.dims();
622        let indices_dims = indices.dims();
623
624        let (num_tokens, topk, k, x_has_topk) = if x_dims.len() == 2 {
625            (x_dims[0], indices_dims[1], x_dims[1], false)
626        } else {
627            (x_dims[0], x_dims[1], x_dims[2], true)
628        };
629
630        let blocks_dims = self.blocks.dims();
631        let n = blocks_dims[1];
632        let k_half = k / 2;
633        let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
634        let half_block = MXFP4_BLOCK_SIZE / 2;
635
636        let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
637        let scales_cpu = self.scales.to_device(&Device::Cpu)?;
638        let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
639        let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
640
641        let x_f32 = x.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
642        let x_data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;
643
644        let indices_cpu = indices.to_device(&Device::Cpu)?.to_dtype(DType::U32)?;
645        let indices_data: Vec<u32> = indices_cpu.flatten_all()?.to_vec1()?;
646
647        let bias_data: Option<Vec<f32>> = self
648            .bias
649            .as_ref()
650            .map(|b| {
651                b.to_dtype(DType::F32)?
652                    .to_device(&Device::Cpu)?
653                    .flatten_all()?
654                    .to_vec1()
655            })
656            .transpose()?;
657
658        // output: [num_tokens * topk, n]
659        let mut output = vec![0f32; num_tokens * topk * n];
660
661        for token_idx in 0..num_tokens {
662            for slot_idx in 0..topk {
663                let expert_idx = indices_data[token_idx * topk + slot_idx] as usize;
664                let out_row = token_idx * topk + slot_idx;
665
666                // Get input row
667                let x_offset = if x_has_topk {
668                    (token_idx * topk + slot_idx) * k
669                } else {
670                    token_idx * k
671                };
672
673                // Blocked dequant + matmul for this (token, expert) pair
674                let expert_blocks_base = expert_idx * n * k_half;
675                let expert_scales_base = expert_idx * n * num_blocks_per_row;
676
677                for blk in 0..num_blocks_per_row {
678                    let col_start = blk * MXFP4_BLOCK_SIZE;
679
680                    // Load input block
681                    let x_blk =
682                        &x_data[x_offset + col_start..x_offset + col_start + MXFP4_BLOCK_SIZE];
683
684                    for row in 0..n {
685                        let scale = scales_data[expert_scales_base + row * num_blocks_per_row + blk]
686                            as usize;
687                        let dequant = &Self::DEQUANT_LUT[scale];
688                        let blk_bytes =
689                            &blocks_data[expert_blocks_base + row * k_half + blk * half_block..];
690
691                        let mut dot = 0f32;
692                        for byte_i in 0..half_block {
693                            let packed = blk_bytes[byte_i];
694                            let w0 = dequant[(packed & 0x0F) as usize];
695                            let w1 = dequant[((packed >> 4) & 0x0F) as usize];
696                            dot += x_blk[byte_i * 2] * w0 + x_blk[byte_i * 2 + 1] * w1;
697                        }
698                        output[out_row * n + row] += dot;
699                    }
700                }
701
702                // Add bias
703                if let Some(ref bias) = bias_data {
704                    let bias_offset = expert_idx * n;
705                    for row in 0..n {
706                        output[out_row * n + row] += bias[bias_offset + row];
707                    }
708                }
709            }
710        }
711
712        let result = Tensor::from_vec(output, (num_tokens * topk, n), &Device::Cpu)?
713            .to_device(x.device())?
714            .to_dtype(x.dtype())?;
715        result.reshape((num_tokens, topk, n))
716    }
717}
718
719// UQFF binary layout for MXFP4Layer:
720// -----------------------
721// [u32 LE] UQFF version
722// [u8]     QuantizedSerdeType::Mxfp4 (6)
723// [u8]     has_bias (0 or 1)
724// -----------------------
725// Blocks tensor data via serialize_tensor
726// -----------------------
727// Scales tensor data via serialize_tensor
728// -----------------------
729// [OPTIONAL] Bias tensor data via serialize_tensor
730// -----------------------
731
732impl QuantizedSerde for MXFP4Layer {
733    fn name(&self) -> &'static str {
734        "mxfp4-layer"
735    }
736    fn isq_serde_supported(&self) -> bool {
737        true
738    }
739    fn serialize(&self) -> Result<Cow<'_, [u8]>> {
740        self.serialize_with_bias(self.bias.clone())
741    }
742    fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
743        let mut buffer = Vec::new();
744
745        buffer.extend(&UQFF_VERSION.to_le_bytes());
746        buffer.push(QuantizedSerdeType::Mxfp4 as u8);
747        buffer.push(bias.is_some() as u8);
748
749        serialize_tensor(&mut buffer, &self.blocks)?;
750        serialize_tensor(&mut buffer, &self.scales)?;
751
752        if let Some(bias) = &bias {
753            serialize_tensor(&mut buffer, bias)?;
754        }
755
756        Ok(Cow::from(buffer))
757    }
758
759    fn deserialize(
760        data: Cow<[u8]>,
761        device: &Device,
762        _comm: &Arc<crate::Comm>,
763        guard: QuantizeOntoGuard,
764    ) -> Result<Arc<dyn QuantMethod>>
765    where
766        Self: Sized,
767    {
768        let (layer, _bias) = Self::deserialize_ext_bias(data, device, guard)?;
769        Ok(layer)
770    }
771
772    fn deserialize_ext_bias(
773        data: Cow<[u8]>,
774        device: &Device,
775        guard: QuantizeOntoGuard,
776    ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
777    where
778        Self: Sized,
779    {
780        let mut buffer = Cursor::new(data.to_vec());
781
782        let version = buffer.read_u32::<LittleEndian>()?;
783        if let Err(e) = version_is_compatible(version) {
784            return Err(hanzo_ml::Error::wrap(e));
785        }
786
787        let isq_type = buffer.read_u8()? as usize;
788        if isq_type != QuantizedSerdeType::Mxfp4 as usize {
789            hanzo_ml::bail!(
790                "ISQ type ({isq_type}) doesn't match expected type {}",
791                QuantizedSerdeType::Mxfp4 as usize
792            );
793        }
794
795        let has_bias = buffer.read_u8()? != 0;
796
797        let _acquired_load_guard = guard.acquire(device);
798        let blocks = deserialize_tensor(&mut buffer, device)?;
799        let scales = deserialize_tensor(&mut buffer, device)?;
800
801        let bias = if has_bias {
802            Some(deserialize_tensor(&mut buffer, device)?)
803        } else {
804            None
805        };
806
807        let ext_bias = bias.clone();
808
809        Ok((
810            Arc::new(Self {
811                blocks,
812                scales,
813                bias,
814            }),
815            ext_bias,
816        ))
817    }
818}