1use std::{
2 borrow::Cow,
3 io::Cursor,
4 sync::{atomic::AtomicUsize, Arc},
5};
6
7use byteorder::{LittleEndian, ReadBytesExt};
8use hanzo_ml::{DType, Device, Result, Tensor};
9
10use crate::{
11 utils::{deserialize_tensor, serialize_tensor, version_is_compatible, UQFF_VERSION},
12 IsqType, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
13 QuantizedSerdeType, ShardedVarBuilder,
14};
15
/// Number of FP4 elements that share one scale byte in an MXFP4 block.
pub const MXFP4_BLOCK_SIZE: usize = 1 << 5;

/// Bit width of one packed FP4 element code.
pub(crate) const N_BITS: usize = 4;

// Native kernel back-ends; each is compiled only when its feature is enabled.
#[cfg(feature = "cuda")]
pub(crate) mod ffi;
#[cfg(feature = "cuda")]
pub(crate) mod ops;
#[cfg(feature = "metal")]
pub(crate) mod metal_ops;
27
28#[derive(Debug)]
29pub struct MXFP4Layer {
30 #[allow(dead_code)]
33 blocks: Tensor,
34 scales: Tensor,
37 #[allow(dead_code)]
39 bias: Option<Tensor>,
40}
41
42impl QuantMethod for MXFP4Layer {
43 fn new(method: QuantMethodConfig) -> hanzo_ml::Result<Self>
44 where
45 Self: Sized,
46 {
47 match method {
48 QuantMethodConfig::Gguf { .. }
49 | QuantMethodConfig::GptqAwq { .. }
50 | QuantMethodConfig::Hqq { .. }
51 | QuantMethodConfig::Dummy
52 | QuantMethodConfig::FP8 { .. }
53 | QuantMethodConfig::Bnb { .. }
54 | QuantMethodConfig::BlockwiseFP8 { .. }
55 | QuantMethodConfig::PerTensorFP8 { .. }
56 | QuantMethodConfig::Unquantized(_)
57 | QuantMethodConfig::Afq { .. } => unreachable!(),
58 QuantMethodConfig::MXFP4 {
59 blocks,
60 scales,
61 bias,
62 } => Ok(Self {
63 blocks,
64 scales,
65 bias,
66 }),
67 }
68 }
69
70 fn dequantize_w(&self) -> Result<hanzo_ml::Tensor> {
71 #[cfg(feature = "metal")]
72 if self.blocks.device().is_metal() {
73 use crate::afq::ops;
74 use crate::{AfqBits, AfqGroupSize};
75 return ops::afq_dequantize_op(
76 &self.blocks,
77 &self.scales,
78 &self.scales.clone(),
79 AfqGroupSize::Low,
80 AfqBits::Mxfp4,
81 );
82 }
83 self.dequantize_weights()
85 }
86
87 #[allow(unused_variables)]
88 fn forward_raw(&self, x: &Tensor) -> Result<Tensor> {
89 #[cfg(feature = "cuda")]
90 if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
91 let orig_dims = x.dims().to_vec();
92 let x_2d = if orig_dims.len() > 2 {
93 let features = orig_dims[orig_dims.len() - 1];
94 let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
95 x.reshape((batch_size, features))?
96 } else {
97 x.clone()
98 };
99
100 let result = ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;
101
102 if orig_dims.len() > 2 {
103 let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
104 new_dims.push(result.dim(1)?);
105 return result.reshape(new_dims);
106 }
107 return Ok(result);
108 }
109
110 #[cfg(feature = "metal")]
111 {
112 if x.device().is_metal() {
113 let orig_dims = x.dims().to_vec();
114 let x_2d = if orig_dims.len() > 2 {
115 let features = orig_dims[orig_dims.len() - 1];
116 let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
117 x.reshape((batch_size, features))?
118 } else {
119 x.clone()
120 };
121
122 let result =
123 metal_ops::mxfp4_matmul(&x_2d, &self.blocks, &self.scales, self.bias.as_ref())?;
124
125 if orig_dims.len() > 2 {
126 let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
127 new_dims.push(result.dim(1)?);
128 return result.reshape(new_dims);
129 }
130 return Ok(result);
131 }
132 }
133
134 self.forward_dequantize(x)
135 }
136
137 #[allow(unused_variables)]
138 fn gather_forward_raw(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
139 #[cfg(feature = "cuda")]
140 if matches!(x.device(), Device::Cuda(_)) && ffi::HAVE_MXFP4_GEMM_KERNELS {
141 return ops::mxfp4_indexed_moe_gemm(
142 x,
143 &self.blocks,
144 &self.scales,
145 self.bias.as_ref(),
146 indices,
147 );
148 }
149
150 #[cfg(feature = "metal")]
151 {
152 if x.device().is_metal() {
153 return metal_ops::mxfp4_indexed_moe_gemm(
154 x,
155 &self.blocks,
156 &self.scales,
157 self.bias.as_ref(),
158 indices,
159 );
160 }
161 }
162
163 self.gather_forward_dequantize(x, indices)
164 }
165
166 fn quantized_act_type(&self) -> Option<DType> {
167 None
168 }
169
170 fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
171 hanzo_ml::bail!("MXFP4Layer does not support add_delta_w")
172 }
173
174 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
175 (DType::BF16, self.scales.device().clone())
176 }
177
178 fn apply_isq(
179 self: Arc<Self>,
180 _dtype: Option<IsqType>,
181 _device: Device,
182 _n_quantized: &AtomicUsize,
183 _imatrix_weight: Option<Vec<f32>>,
184 _guard: QuantizeOntoGuard,
185 ) -> Result<Arc<dyn QuantMethod>> {
186 hanzo_ml::bail!("MXFP4Layer does not support ISQ")
187 }
188}
189
190impl MXFP4Layer {
191 fn device_supported(_device: &Device) -> bool {
193 #[cfg(feature = "cuda")]
194 if matches!(_device, Device::Cuda(_)) {
195 return ffi::HAVE_MXFP4_GEMM_KERNELS;
196 }
197 #[cfg(feature = "metal")]
198 if _device.is_metal() {
199 return true;
200 }
201 false
202 }
203
204 pub fn quantize(
207 weight: &Tensor,
208 bias: Option<Tensor>,
209 device: &Device,
210 ) -> Result<Arc<dyn QuantMethod>> {
211 let weight_f32 = weight.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
212 let dims = weight_f32.dims2()?;
213 let (n, k) = (dims.0, dims.1);
214
215 if k % MXFP4_BLOCK_SIZE != 0 {
216 hanzo_ml::bail!(
217 "MXFP4 quantization requires K ({k}) divisible by block size ({MXFP4_BLOCK_SIZE})"
218 );
219 }
220
221 let weight_data: Vec<f32> = weight_f32.flatten_all()?.to_vec1()?;
222 let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
223 let k_half = k / 2;
224
225 use rayon::prelude::*;
227 let row_results: Vec<(Vec<u8>, Vec<u8>)> = (0..n)
228 .into_par_iter()
229 .map(|row| {
230 let row_offset = row * k;
231 let mut row_packed = vec![0u8; k_half];
232 let mut row_scales = vec![0u8; num_blocks_per_row];
233
234 for (blk, row_scale) in row_scales.iter_mut().enumerate() {
235 let blk_start = row_offset + blk * MXFP4_BLOCK_SIZE;
236 let block = &weight_data[blk_start..blk_start + MXFP4_BLOCK_SIZE];
237
238 let max_abs = block.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
239
240 let scale = if max_abs == 0.0 {
241 127u8
242 } else {
243 let raw = (max_abs / 6.0).log2().floor() as i32 + 127;
244 raw.clamp(0, 254) as u8
245 };
246 *row_scale = scale;
247
248 let scale_factor = 2.0f32.powi(scale as i32 - 127);
249 let inv_scale = if scale_factor == 0.0 {
250 0.0
251 } else {
252 1.0 / scale_factor
253 };
254
255 for (elem, &val) in block.iter().enumerate() {
256 let nibble = Self::quantize_to_fp4(val * inv_scale);
257 let k_idx = blk * MXFP4_BLOCK_SIZE + elem;
258 let byte_idx = k_idx / 2;
259 if k_idx.is_multiple_of(2) {
260 row_packed[byte_idx] |= nibble;
261 } else {
262 row_packed[byte_idx] |= nibble << 4;
263 }
264 }
265 }
266 (row_packed, row_scales)
267 })
268 .collect();
269
270 let mut packed = Vec::with_capacity(n * k_half);
271 let mut scales = Vec::with_capacity(n * num_blocks_per_row);
272 for (row_packed, row_scales) in row_results {
273 packed.extend_from_slice(&row_packed);
274 scales.extend_from_slice(&row_scales);
275 }
276
277 let blocks = Tensor::from_vec(packed, (n, k / 2), &Device::Cpu)?
278 .to_dtype(DType::U8)?
279 .to_device(device)?;
280 let scales = Tensor::from_vec(scales, (n, num_blocks_per_row), &Device::Cpu)?
281 .to_dtype(DType::U8)?
282 .to_device(device)?;
283 let bias = bias.map(|b| b.to_device(device)).transpose()?;
284
285 Ok(Arc::new(Self {
286 blocks,
287 scales,
288 bias,
289 }))
290 }
291
292 fn quantize_to_fp4(val: f32) -> u8 {
294 let sign = val < 0.0;
297 let abs_val = val.abs();
298
299 let nibble = if abs_val < 0.25 {
301 0 } else if abs_val < 0.75 {
303 1 } else if abs_val < 1.25 {
305 2 } else if abs_val < 1.75 {
307 3 } else if abs_val < 2.5 {
309 4 } else if abs_val < 3.5 {
311 5 } else if abs_val < 5.0 {
313 6 } else {
315 7 };
317
318 if sign {
319 nibble | 0x08
320 } else {
321 nibble
322 }
323 }
324
325 pub fn linear_b(
326 in_dim: usize,
327 out_dim: usize,
328 config: &QuantizedConfig,
329 bias: bool,
330 vb: ShardedVarBuilder,
331 ) -> Result<Arc<dyn QuantMethod>> {
332 if !Self::device_supported(vb.device()) {
333 hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
334 }
335
336 let QuantizedConfig::MXFP4 {} = config else {
337 hanzo_ml::bail!("Unexpected quantization config.")
338 };
339
340 let blocks = vb.get_with_hints_dtype(
341 (out_dim, in_dim / 2),
342 "blocks",
343 Default::default(),
344 DType::U8,
345 )?;
346 let scales = vb.get_with_hints_dtype(
347 (out_dim, in_dim / MXFP4_BLOCK_SIZE),
348 "scales",
349 Default::default(),
350 DType::U8,
351 )?;
352
353 let bias = if bias {
354 Some(vb.get((out_dim,), "bias")?)
355 } else {
356 None
357 };
358
359 Ok(Arc::new(Self {
360 blocks,
361 scales,
362 bias,
363 }))
364 }
365
366 pub fn packed_linear_b(
367 num_local_experts: usize,
368 in_dim: usize,
369 out_dim: usize,
370 config: &QuantizedConfig,
371 bias: bool,
372 vb: ShardedVarBuilder,
373 ) -> Result<Arc<dyn QuantMethod>> {
374 if !Self::device_supported(vb.device()) {
375 hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
376 }
377
378 let QuantizedConfig::MXFP4 {} = config else {
379 hanzo_ml::bail!("Unexpected quantization config.")
380 };
381
382 let blocks = vb.get_with_hints_dtype(
383 (num_local_experts, out_dim, in_dim / 2),
384 "blocks",
385 Default::default(),
386 DType::U8,
387 )?;
388 let scales = vb.get_with_hints_dtype(
389 (num_local_experts, out_dim, in_dim / MXFP4_BLOCK_SIZE),
390 "scales",
391 Default::default(),
392 DType::U8,
393 )?;
394
395 let bias = if bias {
396 Some(vb.get((num_local_experts, out_dim), "bias")?)
397 } else {
398 None
399 };
400
401 Ok(Arc::new(Self {
402 blocks,
403 scales,
404 bias,
405 }))
406 }
407
408 pub fn packed_gptoss_linear(
417 num_local_experts: usize,
418 in_dim: usize,
419 out_dim: usize,
420 bias: bool,
421 name: &str,
422 vb: ShardedVarBuilder,
423 ) -> Result<Arc<dyn QuantMethod>> {
424 if !Self::device_supported(vb.device()) {
425 hanzo_ml::bail!("MXFP4Layer requires CUDA or Metal device.");
426 }
427
428 let num_blocks = in_dim / MXFP4_BLOCK_SIZE;
429
430 let blocks_4d = vb.get_with_hints_dtype(
431 (num_local_experts, out_dim, num_blocks, 16),
432 &format!("{name}_blocks"),
433 Default::default(),
434 DType::U8,
435 )?;
436
437 let blocks = blocks_4d.reshape((num_local_experts, out_dim, num_blocks * 16))?;
438
439 let scales = vb.get_with_hints_dtype(
440 (num_local_experts, out_dim, num_blocks),
441 &format!("{name}_scales"),
442 Default::default(),
443 DType::U8,
444 )?;
445
446 let bias = if bias {
447 Some(vb.get((num_local_experts, out_dim), &format!("{name}_bias"))?)
448 } else {
449 None
450 };
451
452 Ok(Arc::new(Self {
453 blocks,
454 scales,
455 bias,
456 }))
457 }
458
459 const DEQUANT_LUT: [[f32; 16]; 256] = {
464 let mut lut = [[0.0f32; 16]; 256];
465 let fp4: [f32; 16] = [
466 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
467 ];
468 let mut s = 0u32;
469 while s < 256 {
470 let scale_factor = f32::from_bits(s << 23);
471 let mut n = 0;
472 while n < 16 {
473 lut[s as usize][n] = fp4[n] * scale_factor;
474 n += 1;
475 }
476 s += 1;
477 }
478 lut
479 };
480
481 fn dequantize_weights(&self) -> Result<Tensor> {
486 let blocks_dims = self.blocks.dims();
487
488 let (num_experts, n, k_half) = if blocks_dims.len() == 3 {
489 (blocks_dims[0], blocks_dims[1], blocks_dims[2])
490 } else {
491 (1, blocks_dims[0], blocks_dims[1])
492 };
493 let k = k_half * 2;
494 let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
495
496 let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
497 let scales_cpu = self.scales.to_device(&Device::Cpu)?;
498
499 let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
500 let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
501
502 let mut weights = vec![0f32; num_experts * n * k];
503 let half_block = MXFP4_BLOCK_SIZE / 2; for expert in 0..num_experts {
506 for row in 0..n {
507 let blocks_row = expert * n * k_half + row * k_half;
508 let scales_row = expert * n * num_blocks_per_row + row * num_blocks_per_row;
509 let weights_row = expert * n * k + row * k;
510
511 for blk in 0..num_blocks_per_row {
512 let scale = scales_data[scales_row + blk] as usize;
513 let dequant = &Self::DEQUANT_LUT[scale];
514 let blk_bytes = &blocks_data[blocks_row + blk * half_block..];
515 let w_out = &mut weights[weights_row + blk * MXFP4_BLOCK_SIZE..];
516
517 for byte_i in 0..half_block {
518 let packed = blk_bytes[byte_i];
519 w_out[byte_i * 2] = dequant[(packed & 0x0F) as usize];
520 w_out[byte_i * 2 + 1] = dequant[((packed >> 4) & 0x0F) as usize];
521 }
522 }
523 }
524 }
525
526 let shape = if blocks_dims.len() == 3 {
527 vec![num_experts, n, k]
528 } else {
529 vec![n, k]
530 };
531
532 Tensor::from_vec(weights, shape.as_slice(), &Device::Cpu)?
533 .to_device(self.blocks.device())?
534 .to_dtype(DType::BF16)
535 }
536
537 fn forward_dequantize(&self, x: &Tensor) -> Result<Tensor> {
541 let orig_dims = x.dims().to_vec();
542
543 let x_2d = if orig_dims.len() > 2 {
544 let features = orig_dims[orig_dims.len() - 1];
545 let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product();
546 x.reshape((batch_size, features))?
547 } else {
548 x.clone()
549 };
550
551 let x_f32 = x_2d.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
552 let (m, k) = x_f32.dims2()?;
553
554 let blocks_dims = self.blocks.dims();
555 let n = if blocks_dims.len() == 3 {
556 blocks_dims[1]
557 } else {
558 blocks_dims[0]
559 };
560 let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
561 let half_block = MXFP4_BLOCK_SIZE / 2;
562
563 let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
564 let scales_cpu = self.scales.to_device(&Device::Cpu)?;
565 let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
566 let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
567 let x_data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;
568
569 let mut output = vec![0f32; m * n];
571 let k_half = k / 2;
572
573 for blk in 0..num_blocks_per_row {
574 let col_start = blk * MXFP4_BLOCK_SIZE;
575
576 for row in 0..n {
577 let scale = scales_data[row * num_blocks_per_row + blk] as usize;
578 let dequant = &Self::DEQUANT_LUT[scale];
579 let blk_bytes = &blocks_data[row * k_half + blk * half_block..];
580
581 let mut w_block = [0f32; MXFP4_BLOCK_SIZE];
583 for byte_i in 0..half_block {
584 let packed = blk_bytes[byte_i];
585 w_block[byte_i * 2] = dequant[(packed & 0x0F) as usize];
586 w_block[byte_i * 2 + 1] = dequant[((packed >> 4) & 0x0F) as usize];
587 }
588
589 for token in 0..m {
591 let x_row = &x_data[token * k + col_start..];
592 let mut acc = 0f32;
593 for i in 0..MXFP4_BLOCK_SIZE {
594 acc += x_row[i] * w_block[i];
595 }
596 output[token * n + row] += acc;
597 }
598 }
599 }
600
601 let mut result = Tensor::from_vec(output, (m, n), &Device::Cpu)?
602 .to_device(x.device())?
603 .to_dtype(x.dtype())?;
604
605 if let Some(bias) = &self.bias {
606 result = result.broadcast_add(bias)?;
607 }
608
609 if orig_dims.len() > 2 {
610 let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec();
611 new_dims.push(result.dim(1)?);
612 result = result.reshape(new_dims)?;
613 }
614
615 Ok(result)
616 }
617
618 fn gather_forward_dequantize(&self, x: &Tensor, indices: &Tensor) -> Result<Tensor> {
621 let x_dims = x.dims();
622 let indices_dims = indices.dims();
623
624 let (num_tokens, topk, k, x_has_topk) = if x_dims.len() == 2 {
625 (x_dims[0], indices_dims[1], x_dims[1], false)
626 } else {
627 (x_dims[0], x_dims[1], x_dims[2], true)
628 };
629
630 let blocks_dims = self.blocks.dims();
631 let n = blocks_dims[1];
632 let k_half = k / 2;
633 let num_blocks_per_row = k / MXFP4_BLOCK_SIZE;
634 let half_block = MXFP4_BLOCK_SIZE / 2;
635
636 let blocks_cpu = self.blocks.to_device(&Device::Cpu)?;
637 let scales_cpu = self.scales.to_device(&Device::Cpu)?;
638 let blocks_data: Vec<u8> = blocks_cpu.flatten_all()?.to_vec1()?;
639 let scales_data: Vec<u8> = scales_cpu.flatten_all()?.to_vec1()?;
640
641 let x_f32 = x.to_dtype(DType::F32)?.to_device(&Device::Cpu)?;
642 let x_data: Vec<f32> = x_f32.flatten_all()?.to_vec1()?;
643
644 let indices_cpu = indices.to_device(&Device::Cpu)?.to_dtype(DType::U32)?;
645 let indices_data: Vec<u32> = indices_cpu.flatten_all()?.to_vec1()?;
646
647 let bias_data: Option<Vec<f32>> = self
648 .bias
649 .as_ref()
650 .map(|b| {
651 b.to_dtype(DType::F32)?
652 .to_device(&Device::Cpu)?
653 .flatten_all()?
654 .to_vec1()
655 })
656 .transpose()?;
657
658 let mut output = vec![0f32; num_tokens * topk * n];
660
661 for token_idx in 0..num_tokens {
662 for slot_idx in 0..topk {
663 let expert_idx = indices_data[token_idx * topk + slot_idx] as usize;
664 let out_row = token_idx * topk + slot_idx;
665
666 let x_offset = if x_has_topk {
668 (token_idx * topk + slot_idx) * k
669 } else {
670 token_idx * k
671 };
672
673 let expert_blocks_base = expert_idx * n * k_half;
675 let expert_scales_base = expert_idx * n * num_blocks_per_row;
676
677 for blk in 0..num_blocks_per_row {
678 let col_start = blk * MXFP4_BLOCK_SIZE;
679
680 let x_blk =
682 &x_data[x_offset + col_start..x_offset + col_start + MXFP4_BLOCK_SIZE];
683
684 for row in 0..n {
685 let scale = scales_data[expert_scales_base + row * num_blocks_per_row + blk]
686 as usize;
687 let dequant = &Self::DEQUANT_LUT[scale];
688 let blk_bytes =
689 &blocks_data[expert_blocks_base + row * k_half + blk * half_block..];
690
691 let mut dot = 0f32;
692 for byte_i in 0..half_block {
693 let packed = blk_bytes[byte_i];
694 let w0 = dequant[(packed & 0x0F) as usize];
695 let w1 = dequant[((packed >> 4) & 0x0F) as usize];
696 dot += x_blk[byte_i * 2] * w0 + x_blk[byte_i * 2 + 1] * w1;
697 }
698 output[out_row * n + row] += dot;
699 }
700 }
701
702 if let Some(ref bias) = bias_data {
704 let bias_offset = expert_idx * n;
705 for row in 0..n {
706 output[out_row * n + row] += bias[bias_offset + row];
707 }
708 }
709 }
710 }
711
712 let result = Tensor::from_vec(output, (num_tokens * topk, n), &Device::Cpu)?
713 .to_device(x.device())?
714 .to_dtype(x.dtype())?;
715 result.reshape((num_tokens, topk, n))
716 }
717}
718
719impl QuantizedSerde for MXFP4Layer {
733 fn name(&self) -> &'static str {
734 "mxfp4-layer"
735 }
736 fn isq_serde_supported(&self) -> bool {
737 true
738 }
739 fn serialize(&self) -> Result<Cow<'_, [u8]>> {
740 self.serialize_with_bias(self.bias.clone())
741 }
742 fn serialize_with_bias(&self, bias: Option<Tensor>) -> Result<Cow<'_, [u8]>> {
743 let mut buffer = Vec::new();
744
745 buffer.extend(&UQFF_VERSION.to_le_bytes());
746 buffer.push(QuantizedSerdeType::Mxfp4 as u8);
747 buffer.push(bias.is_some() as u8);
748
749 serialize_tensor(&mut buffer, &self.blocks)?;
750 serialize_tensor(&mut buffer, &self.scales)?;
751
752 if let Some(bias) = &bias {
753 serialize_tensor(&mut buffer, bias)?;
754 }
755
756 Ok(Cow::from(buffer))
757 }
758
759 fn deserialize(
760 data: Cow<[u8]>,
761 device: &Device,
762 _comm: &Arc<crate::Comm>,
763 guard: QuantizeOntoGuard,
764 ) -> Result<Arc<dyn QuantMethod>>
765 where
766 Self: Sized,
767 {
768 let (layer, _bias) = Self::deserialize_ext_bias(data, device, guard)?;
769 Ok(layer)
770 }
771
772 fn deserialize_ext_bias(
773 data: Cow<[u8]>,
774 device: &Device,
775 guard: QuantizeOntoGuard,
776 ) -> Result<(Arc<dyn QuantMethod>, Option<Tensor>)>
777 where
778 Self: Sized,
779 {
780 let mut buffer = Cursor::new(data.to_vec());
781
782 let version = buffer.read_u32::<LittleEndian>()?;
783 if let Err(e) = version_is_compatible(version) {
784 return Err(hanzo_ml::Error::wrap(e));
785 }
786
787 let isq_type = buffer.read_u8()? as usize;
788 if isq_type != QuantizedSerdeType::Mxfp4 as usize {
789 hanzo_ml::bail!(
790 "ISQ type ({isq_type}) doesn't match expected type {}",
791 QuantizedSerdeType::Mxfp4 as usize
792 );
793 }
794
795 let has_bias = buffer.read_u8()? != 0;
796
797 let _acquired_load_guard = guard.acquire(device);
798 let blocks = deserialize_tensor(&mut buffer, device)?;
799 let scales = deserialize_tensor(&mut buffer, device)?;
800
801 let bias = if has_bias {
802 Some(deserialize_tensor(&mut buffer, device)?)
803 } else {
804 None
805 };
806
807 let ext_bias = bias.clone();
808
809 Ok((
810 Arc::new(Self {
811 blocks,
812 scales,
813 bias,
814 }),
815 ext_bias,
816 ))
817 }
818}