Skip to main content

hanzo_quant/distributed/
layers.rs

1use std::sync::Arc;
2
3use hanzo_ml::{Context, Device, IndexOp, Result, Tensor, D};
4use hanzo_nn::Linear;
5
6use crate::{
7    blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
8    distributed,
9    gptq::gptq_linear,
10    lora::merge_lora_weights,
11    make_dummy_or_error,
12    pertensor_fp8::pertensor_fp8_linear_b,
13    should_apply_immediate_isq,
14    utils::isq::apply_immediate_isq,
15    AfqLayer, BnbLinear, DistributedKind, F8Q8Linear, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
16    QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
17    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
18};
19
20use super::{Comm, SumAllReduce};
21
22fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
23    Shard::Simple {
24        dim,
25        rank,
26        world_size,
27    }
28}
29
30/// This layer has a weight that is parallelized along the input dimension,
31/// returning the "full" output dimension.
32#[derive(Debug)]
33pub struct RowParallelLayer {
34    weight: Arc<dyn QuantMethod>,
35    bias: Option<Tensor>,
36    all_reduce: distributed::SumAllReduce,
37}
38
39impl RowParallelLayer {
40    #[allow(clippy::new_ret_no_self)]
41    pub fn new(
42        in_dim: usize,
43        out_dim: usize,
44        config: &Option<QuantizedConfig>,
45        bias: bool,
46        comm: &Arc<crate::Comm>,
47        vb: ShardedVarBuilder,
48    ) -> Result<Arc<dyn QuantMethod>> {
49        let rank = comm.rank();
50        let world_size = comm.world_size();
51        let shard = shard(1, rank, world_size);
52
53        let base_vb = vb.clone();
54        let vb = if should_apply_immediate_isq(&vb) {
55            vb.set_device(Device::Cpu)
56        } else {
57            vb
58        };
59
60        let weight = if let Some(quant_conf) = &config {
61            // GPTQ and BNB do not support tensor parallelism
62            if matches!(
63                quant_conf,
64                QuantizedConfig::GptqAwq { .. }
65                    | QuantizedConfig::Bitsandbytes { .. }
66                    | QuantizedConfig::Afq { .. }
67            ) && comm.world_size() != 1
68            {
69                hanzo_ml::bail!(
70                    "GPTQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
71                    comm.world_size()
72                );
73            }
74
75            match quant_conf {
76                QuantizedConfig::GptqAwq { .. } => {
77                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
78                }
79                QuantizedConfig::Fp8 { weight_block_size } => {
80                    // NOTE: no bias for fp8 as it might be parallelized
81                    if weight_block_size.is_some() {
82                        blockwise_fp8_linear_b(
83                            in_dim,
84                            out_dim,
85                            quant_conf,
86                            false,
87                            shard,
88                            vb.clone(),
89                        )?
90                    } else {
91                        pertensor_fp8_linear_b(
92                            in_dim,
93                            out_dim,
94                            quant_conf,
95                            false,
96                            shard,
97                            vb.clone(),
98                        )?
99                    }
100                }
101                QuantizedConfig::Bitsandbytes { .. } => {
102                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
103                }
104                QuantizedConfig::Afq { .. } => {
105                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
106                }
107                QuantizedConfig::MXFP4 {} => {
108                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
109                }
110            }
111        } else {
112            if !vb.contains_tensor("weight") {
113                make_dummy_or_error("row_parallel_linear", &vb, &["weight"])?
114            } else {
115                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
116                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
117
118                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
119                    Linear::new(weight, None),
120                ))?;
121                Arc::new(layer) as Arc<dyn QuantMethod>
122            }
123        };
124
125        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
126        let bias = if bias && vb.contains_tensor("bias") {
127            Some(vb.get((out_dim,), "bias")?)
128        } else {
129            None
130        };
131
132        let this_unquant = Arc::new(Self {
133            weight,
134            bias,
135            all_reduce: distributed::SumAllReduce::new(comm),
136        });
137        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
138        Ok(this)
139    }
140
141    #[allow(clippy::new_ret_no_self)]
142    pub fn new_matformer(
143        in_dim: usize,
144        out_dim: usize,
145        orig_intermediate_size: usize,
146        config: &Option<QuantizedConfig>,
147        bias: bool,
148        comm: &Arc<crate::Comm>,
149        vb: ShardedVarBuilder,
150    ) -> Result<Arc<dyn QuantMethod>> {
151        let rank = comm.rank();
152        let world_size = comm.world_size();
153        let shard = shard(1, rank, world_size);
154
155        let base_vb = vb.clone();
156        let vb = if should_apply_immediate_isq(&vb) {
157            vb.set_device(Device::Cpu)
158        } else {
159            vb
160        };
161
162        if config.is_some() {
163            hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
164        }
165
166        let weight = if !vb.contains_tensor("weight") {
167            make_dummy_or_error("row_parallel_matformer_linear", &vb, &["weight"])?
168        } else {
169            let weight = vb
170                .get_with_hints(
171                    (out_dim, orig_intermediate_size),
172                    "weight",
173                    Default::default(),
174                )?
175                .i((.., ..in_dim))?
176                .contiguous()?;
177
178            let weight = shard.apply_to(&weight)?;
179            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
180
181            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
182                Linear::new(weight, None),
183            ))?;
184            Arc::new(layer) as Arc<dyn QuantMethod>
185        };
186
187        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
188        let bias = if bias && vb.contains_tensor("bias") {
189            Some(vb.get((out_dim,), "bias")?)
190        } else {
191            None
192        };
193
194        let this_unquant = Arc::new(Self {
195            weight,
196            bias,
197            all_reduce: distributed::SumAllReduce::new(comm),
198        });
199        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
200        Ok(this)
201    }
202}
203
204impl QuantMethod for RowParallelLayer {
205    fn new(_method: QuantMethodConfig) -> Result<Self>
206    where
207        Self: Sized,
208    {
209        hanzo_ml::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
210    }
211
212    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
213        let mut xs = self.weight.forward_raw(a)?;
214        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
215        if let Some(bias) = &self.bias {
216            xs = xs.broadcast_add(bias)?;
217        }
218        Ok(xs)
219    }
220
221    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
222        let weight = self.weight.add_delta_w(delta)?;
223        Ok(Arc::new(Self {
224            weight,
225            bias: self.bias.clone(),
226            all_reduce: self.all_reduce.clone(),
227        }))
228    }
229
230    fn dequantize_w(&self) -> Result<Tensor> {
231        self.weight.dequantize_w()
232    }
233
234    fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
235        self.weight.dtype_and_device()
236    }
237
238    fn begin_track_stats(&mut self) -> Result<()> {
239        Arc::get_mut(&mut self.weight)
240            .context("Failed to get &mut to weight")?
241            .begin_track_stats()
242    }
243
244    fn end_track_stats(&self) -> Result<Tensor> {
245        self.weight.end_track_stats()
246    }
247
248    fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
249        self.weight.quantized_act_type()
250    }
251
252    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
253        self.weight.unquant_weight_bias()
254    }
255
256    fn has_bias(&self) -> bool {
257        self.bias.is_some() || self.weight.has_bias()
258    }
259
260    #[cfg(feature = "cuda")]
261    fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
262        self.weight.get_qtensor()
263    }
264
265    fn apply_isq(
266        self: Arc<Self>,
267        dtype: Option<crate::IsqType>,
268        device: hanzo_ml::Device,
269        n_quantized: &std::sync::atomic::AtomicUsize,
270        imatrix_weight: Option<Vec<f32>>,
271        guard: QuantizeOntoGuard,
272    ) -> Result<Arc<dyn QuantMethod>> {
273        let weight =
274            self.weight
275                .clone()
276                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
277        let bias = match &self.bias {
278            Some(b) => {
279                let (dtype, device) = weight.dtype_and_device();
280                Some(b.to_device(&device)?.to_dtype(dtype)?)
281            }
282            None => None,
283        };
284        Ok(Arc::new(Self {
285            weight,
286            bias,
287            all_reduce: self.all_reduce.clone(),
288        }))
289    }
290
291    fn is_distributed(&self) -> Option<DistributedKind> {
292        Some(DistributedKind::RowParallel)
293    }
294}
295
296impl QuantizedSerde for RowParallelLayer {
297    fn isq_serde_supported(&self) -> bool {
298        self.weight.isq_serde_supported()
299    }
300    fn name(&self) -> &'static str {
301        self.weight.name()
302    }
303    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
304        self.weight.serialize_with_bias(self.bias.clone())
305    }
306    fn deserialize(
307        data: std::borrow::Cow<[u8]>,
308        device: &hanzo_ml::Device,
309        comm: &Arc<crate::Comm>,
310        guard: QuantizeOntoGuard,
311    ) -> Result<Arc<dyn QuantMethod>>
312    where
313        Self: Sized,
314    {
315        // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
316        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
317        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
318            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
319            QuantizedSerdeType::Unquant => {
320                UnquantLinear::deserialize_ext_bias(data, device, guard)?
321            }
322            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
323            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
324            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
325            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
326            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
327        };
328        Ok(Arc::new(Self {
329            weight,
330            bias,
331            all_reduce: SumAllReduce::new(comm),
332        }))
333    }
334}
335
336#[derive(Debug)]
337/// This layer has a weight that is parallelized along the output dimension,
338/// taking the "full" input dimension.
339pub struct ColumnParallelLayer {
340    weight: Arc<dyn QuantMethod>,
341    bias: Option<Tensor>,
342}
343
344impl ColumnParallelLayer {
345    #[allow(clippy::new_ret_no_self)]
346    pub fn new_with_shard(
347        in_dim: usize,
348        out_dim: usize,
349        config: &Option<QuantizedConfig>,
350        bias: bool,
351        comm: &Arc<crate::Comm>,
352        shard: Shard,
353        vb: ShardedVarBuilder,
354    ) -> Result<Arc<dyn QuantMethod>> {
355        let base_vb = vb.clone();
356        let vb = if should_apply_immediate_isq(&vb) {
357            vb.set_device(Device::Cpu)
358        } else {
359            vb
360        };
361
362        let weight = if let Some(quant_conf) = &config {
363            // GPTQ and BNB do not support tensor parallelism
364            if matches!(
365                quant_conf,
366                QuantizedConfig::GptqAwq { .. }
367                    | QuantizedConfig::Bitsandbytes { .. }
368                    | QuantizedConfig::Afq { .. }
369            ) && comm.world_size() != 1
370            {
371                hanzo_ml::bail!(
372                    "GPTQ/AWQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
373                    comm.world_size()
374                );
375            }
376
377            match quant_conf {
378                QuantizedConfig::GptqAwq { .. } => {
379                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
380                }
381                QuantizedConfig::Fp8 { weight_block_size } => {
382                    // NOTE: no bias for fp8 as it might be parallelized
383                    if weight_block_size.is_some() {
384                        blockwise_fp8_linear_b(
385                            in_dim,
386                            out_dim,
387                            quant_conf,
388                            false,
389                            shard,
390                            vb.clone(),
391                        )?
392                    } else {
393                        pertensor_fp8_linear_b(
394                            in_dim,
395                            out_dim,
396                            quant_conf,
397                            false,
398                            shard,
399                            vb.clone(),
400                        )?
401                    }
402                }
403                QuantizedConfig::Bitsandbytes { .. } => {
404                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
405                }
406                QuantizedConfig::Afq { .. } => {
407                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
408                }
409                QuantizedConfig::MXFP4 {} => {
410                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
411                }
412            }
413        } else {
414            if !vb.contains_tensor("weight") {
415                make_dummy_or_error("column_parallel_linear", &vb, &["weight"])?
416            } else {
417                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
418                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
419
420                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
421                    Linear::new(weight, None),
422                ))?;
423                Arc::new(layer) as Arc<dyn QuantMethod>
424            }
425        };
426
427        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
428        let bias = if bias && vb.contains_tensor("bias") {
429            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
430        } else {
431            None
432        };
433
434        let this_unquant = Arc::new(Self { weight, bias });
435        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
436        Ok(this)
437    }
438
439    #[allow(clippy::new_ret_no_self)]
440    pub fn new(
441        in_dim: usize,
442        out_dim: usize,
443        config: &Option<QuantizedConfig>,
444        bias: bool,
445        comm: &Arc<crate::Comm>,
446        vb: ShardedVarBuilder,
447    ) -> Result<Arc<dyn QuantMethod>> {
448        let rank = comm.rank();
449        let world_size = comm.world_size();
450        let shard = shard(0, rank, world_size);
451
452        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
453    }
454
455    #[allow(clippy::new_ret_no_self)]
456    pub fn new_matformer(
457        in_dim: usize,
458        out_dim: usize,
459        orig_intermediate_size: usize,
460        config: &Option<QuantizedConfig>,
461        bias: bool,
462        comm: &Arc<crate::Comm>,
463        vb: ShardedVarBuilder,
464    ) -> Result<Arc<dyn QuantMethod>> {
465        let rank = comm.rank();
466        let world_size = comm.world_size();
467        let shard = shard(0, rank, world_size);
468
469        let base_vb = vb.clone();
470        let vb = if should_apply_immediate_isq(&vb) {
471            vb.set_device(Device::Cpu)
472        } else {
473            vb
474        };
475
476        if config.is_some() {
477            hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
478        }
479
480        let weight = if !vb.contains_tensor("weight") {
481            make_dummy_or_error("column_parallel_matformer_linear", &vb, &["weight"])?
482        } else {
483            let weight = vb
484                .get_with_hints(
485                    (orig_intermediate_size, in_dim),
486                    "weight",
487                    Default::default(),
488                )?
489                .i((..out_dim, ..))?
490                .contiguous()?;
491
492            let weight = shard.apply_to(&weight)?;
493            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
494
495            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
496                Linear::new(weight, None),
497            ))?;
498            Arc::new(layer) as Arc<dyn QuantMethod>
499        };
500
501        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
502        let bias = if bias && vb.contains_tensor("bias") {
503            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
504        } else {
505            None
506        };
507
508        let this_unquant = Arc::new(Self { weight, bias });
509        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
510        Ok(this)
511    }
512
513    pub fn new_merged(
514        in_dim: usize,
515        out_dim: usize,
516        chunks: usize,
517        config: &Option<QuantizedConfig>,
518        bias: bool,
519        comm: &Arc<crate::Comm>,
520        vb: ShardedVarBuilder,
521    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
522        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
523        for chunk_idx in 0..chunks {
524            let layer = ColumnParallelLayer::new_with_shard(
525                in_dim,
526                out_dim,
527                config,
528                bias,
529                comm,
530                shard(
531                    0,
532                    chunk_idx * comm.world_size() + comm.rank(),
533                    chunks * comm.world_size(),
534                ),
535                vb.clone(),
536            )?;
537            vec_layers.push(layer);
538        }
539        Ok(vec_layers)
540    }
541}
542
543impl QuantMethod for ColumnParallelLayer {
544    fn new(_method: QuantMethodConfig) -> Result<Self>
545    where
546        Self: Sized,
547    {
548        hanzo_ml::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
549    }
550
551    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
552        let mut xs = self.weight.forward_raw(a)?;
553        if let Some(bias) = &self.bias {
554            xs = xs.broadcast_add(bias)?;
555        }
556        Ok(xs)
557    }
558
559    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
560        let weight = self.weight.add_delta_w(delta)?;
561        Ok(Arc::new(Self {
562            weight,
563            bias: self.bias.clone(),
564        }))
565    }
566
567    fn dequantize_w(&self) -> Result<Tensor> {
568        self.weight.dequantize_w()
569    }
570
571    fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
572        self.weight.dtype_and_device()
573    }
574
575    fn begin_track_stats(&mut self) -> Result<()> {
576        Arc::get_mut(&mut self.weight)
577            .context("Failed to get &mut to weight")?
578            .begin_track_stats()
579    }
580
581    fn end_track_stats(&self) -> Result<Tensor> {
582        self.weight.end_track_stats()
583    }
584
585    fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
586        self.weight.quantized_act_type()
587    }
588
589    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
590        self.weight.unquant_weight_bias()
591    }
592
593    fn has_bias(&self) -> bool {
594        self.bias.is_some() || self.weight.has_bias()
595    }
596
597    #[cfg(feature = "cuda")]
598    fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
599        self.weight.get_qtensor()
600    }
601
602    fn apply_isq(
603        self: Arc<Self>,
604        dtype: Option<crate::IsqType>,
605        device: hanzo_ml::Device,
606        n_quantized: &std::sync::atomic::AtomicUsize,
607        imatrix_weight: Option<Vec<f32>>,
608        guard: QuantizeOntoGuard,
609    ) -> Result<Arc<dyn QuantMethod>> {
610        let weight =
611            self.weight
612                .clone()
613                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
614        let bias = match &self.bias {
615            Some(b) => {
616                let (dtype, device) = weight.dtype_and_device();
617                Some(b.to_device(&device)?.to_dtype(dtype)?)
618            }
619            None => None,
620        };
621        Ok(Arc::new(Self { weight, bias }))
622    }
623
624    fn is_distributed(&self) -> Option<DistributedKind> {
625        Some(DistributedKind::ColumnParallel)
626    }
627}
628
629impl QuantizedSerde for ColumnParallelLayer {
630    fn isq_serde_supported(&self) -> bool {
631        self.weight.isq_serde_supported()
632    }
633    fn name(&self) -> &'static str {
634        self.weight.name()
635    }
636    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
637        self.weight.serialize_with_bias(self.bias.clone())
638    }
639    fn deserialize(
640        data: std::borrow::Cow<[u8]>,
641        device: &hanzo_ml::Device,
642        _comm: &Arc<crate::Comm>,
643        guard: QuantizeOntoGuard,
644    ) -> Result<Arc<dyn QuantMethod>>
645    where
646        Self: Sized,
647    {
648        // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
649        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
650        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
651            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
652            QuantizedSerdeType::Unquant => {
653                UnquantLinear::deserialize_ext_bias(data, device, guard)?
654            }
655            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
656            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
657            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
658            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
659            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
660        };
661        Ok(Arc::new(Self { weight, bias }))
662    }
663}
664
665#[derive(Debug)]
666/// This layer has no parallelization
667pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
668
669impl ReplicatedLayer {
670    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
671        let dev = lin.weight().device().clone();
672        if let Some(crate::ImmediateIsqParams {
673            guard,
674            ty: Some(immediate_isq),
675            pool,
676            backpressure,
677            ..
678        }) = crate::get_immediate_isq()
679        {
680            // Global ISQ type is set — move to CPU for GGML quantization,
681            // then quantize onto the original device.
682            let lin = if !dev.is_cpu() {
683                Linear::new(lin.weight().to_device(&Device::Cpu)?, lin.bias().cloned())
684            } else {
685                lin
686            };
687            let layer: Arc<dyn QuantMethod> =
688                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
689            if let Some(pool) = &pool {
690                backpressure.acquire();
691                let backpressure = backpressure.clone();
692                let dev = dev.clone();
693                let (tx, rx) = crate::pending_layer::pending_isq_channel();
694                pool.spawn(move || {
695                    let result = layer.clone().apply_isq(
696                        Some(immediate_isq),
697                        dev,
698                        &std::sync::atomic::AtomicUsize::new(0),
699                        None,
700                        guard,
701                    );
702                    let _ = tx.send(result);
703                    backpressure.release();
704                });
705                Ok(Arc::new(crate::PendingIsqLayer::new(rx)))
706            } else {
707                layer.clone().apply_isq(
708                    Some(immediate_isq),
709                    dev,
710                    &std::sync::atomic::AtomicUsize::new(0),
711                    None,
712                    guard,
713                )
714            }
715        } else {
716            // No global ISQ — keep as unquantized on original device.
717            Ok(Arc::new(UnquantLinear::new(
718                QuantMethodConfig::Unquantized(lin),
719            )?))
720        }
721    }
722
723    #[allow(clippy::new_ret_no_self)]
724    pub fn new(
725        in_dim: usize,
726        out_dim: usize,
727        config: &Option<QuantizedConfig>,
728        bias: bool,
729        vb: ShardedVarBuilder,
730    ) -> Result<Arc<dyn QuantMethod>> {
731        let base_vb = vb.clone();
732        let vb = if should_apply_immediate_isq(&vb) {
733            vb.set_device(Device::Cpu)
734        } else {
735            vb
736        };
737
738        let layer = if let Some(quant_conf) = &config {
739            match quant_conf {
740                QuantizedConfig::GptqAwq { .. } => {
741                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
742                }
743                QuantizedConfig::Fp8 { weight_block_size } => {
744                    if weight_block_size.is_some() {
745                        blockwise_fp8_linear_b(
746                            in_dim,
747                            out_dim,
748                            quant_conf,
749                            bias,
750                            Default::default(),
751                            vb.clone(),
752                        )?
753                    } else {
754                        pertensor_fp8_linear_b(
755                            in_dim,
756                            out_dim,
757                            quant_conf,
758                            bias,
759                            Default::default(),
760                            vb.clone(),
761                        )?
762                    }
763                }
764                QuantizedConfig::Bitsandbytes { .. } => {
765                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
766                }
767                QuantizedConfig::Afq { .. } => {
768                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
769                }
770                QuantizedConfig::MXFP4 {} => {
771                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
772                }
773            }
774        } else {
775            if !vb.contains_tensor("weight") {
776                make_dummy_or_error("replicated_linear", &vb, &["weight"])?
777            } else {
778                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
779                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
780
781                let bias = if bias {
782                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
783                } else {
784                    None
785                };
786                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
787                    Linear::new(weight, bias),
788                ))?;
789                Arc::new(layer) as Arc<dyn QuantMethod>
790            }
791        };
792
793        let this_unquant = Arc::new(Self(layer));
794        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
795        Ok(this)
796    }
797
798    #[allow(clippy::new_ret_no_self)]
799    pub fn new_layers_matformer_indices(
800        in_dim: usize,
801        out_dim: usize,
802        kept_layers_indices: Option<&Tensor>,
803        orig_num_hidden_layers: usize,
804        config: &Option<QuantizedConfig>,
805        bias: bool,
806        vb: ShardedVarBuilder,
807    ) -> Result<Arc<dyn QuantMethod>> {
808        let base_vb = vb.clone();
809        let vb = if should_apply_immediate_isq(&vb) {
810            vb.set_device(Device::Cpu)
811        } else {
812            vb
813        };
814
815        let layer = if let Some(quant_conf) = &config {
816            if kept_layers_indices.is_some() {
817                hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
818            }
819
820            match quant_conf {
821                QuantizedConfig::GptqAwq { .. } => {
822                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
823                }
824                QuantizedConfig::Fp8 { weight_block_size } => {
825                    if weight_block_size.is_some() {
826                        blockwise_fp8_linear_b(
827                            in_dim,
828                            out_dim,
829                            quant_conf,
830                            bias,
831                            Default::default(),
832                            vb.clone(),
833                        )?
834                    } else {
835                        pertensor_fp8_linear_b(
836                            in_dim,
837                            out_dim,
838                            quant_conf,
839                            bias,
840                            Default::default(),
841                            vb.clone(),
842                        )?
843                    }
844                }
845                QuantizedConfig::Bitsandbytes { .. } => {
846                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
847                }
848                QuantizedConfig::Afq { .. } => {
849                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
850                }
851                QuantizedConfig::MXFP4 {} => {
852                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
853                }
854            }
855        } else {
856            if !vb.contains_tensor("weight") {
857                make_dummy_or_error("replicated_matformer_linear", &vb, &["weight"])?
858            } else {
859                let mut weight =
860                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
861
862                if let Some(kept_layers_indices) = &kept_layers_indices {
863                    let weight_reshaped = weight.reshape((
864                        orig_num_hidden_layers,
865                        weight.dim(0)? / orig_num_hidden_layers,
866                        weight.dim(1)?,
867                    ))?;
868
869                    weight = weight_reshaped
870                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
871                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
872                        .contiguous()?;
873                }
874
875                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
876
877                let bias = if bias {
878                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
879                } else {
880                    None
881                };
882                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
883                    Linear::new(weight, bias),
884                ))?;
885                Arc::new(layer) as Arc<dyn QuantMethod>
886            }
887        };
888
889        let this_unquant = Arc::new(Self(layer));
890        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
891        Ok(this)
892    }
893}
894
895impl QuantMethod for ReplicatedLayer {
896    fn new(_method: QuantMethodConfig) -> Result<Self>
897    where
898        Self: Sized,
899    {
900        hanzo_ml::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
901    }
902
903    fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
904        self.0.forward_raw(a)
905    }
906
907    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
908        self.0.add_delta_w(delta)
909    }
910
911    fn dequantize_w(&self) -> Result<Tensor> {
912        self.0.dequantize_w()
913    }
914
915    fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
916        self.0.dtype_and_device()
917    }
918
919    fn begin_track_stats(&mut self) -> Result<()> {
920        Arc::get_mut(&mut self.0)
921            .context("Failed to get &mut to weight")?
922            .begin_track_stats()
923    }
924
925    fn end_track_stats(&self) -> Result<Tensor> {
926        self.0.end_track_stats()
927    }
928
929    fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
930        self.0.quantized_act_type()
931    }
932
933    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
934        self.0.unquant_weight_bias()
935    }
936
937    fn has_bias(&self) -> bool {
938        self.0.has_bias()
939    }
940
941    #[cfg(feature = "cuda")]
942    fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
943        self.0.get_qtensor()
944    }
945
946    fn apply_isq(
947        self: Arc<Self>,
948        dtype: Option<crate::IsqType>,
949        device: hanzo_ml::Device,
950        n_quantized: &std::sync::atomic::AtomicUsize,
951        imatrix_weight: Option<Vec<f32>>,
952        guard: QuantizeOntoGuard,
953    ) -> Result<Arc<dyn QuantMethod>> {
954        self.0
955            .clone()
956            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
957    }
958
959    fn is_distributed(&self) -> Option<DistributedKind> {
960        Some(DistributedKind::Replicated)
961    }
962}
963
964impl QuantizedSerde for ReplicatedLayer {
965    fn isq_serde_supported(&self) -> bool {
966        self.0.isq_serde_supported()
967    }
968    fn name(&self) -> &'static str {
969        self.0.name()
970    }
971    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
972        self.0.serialize()
973    }
974    fn deserialize(
975        data: std::borrow::Cow<[u8]>,
976        device: &hanzo_ml::Device,
977        comm: &Arc<crate::Comm>,
978        guard: QuantizeOntoGuard,
979    ) -> Result<Arc<dyn QuantMethod>>
980    where
981        Self: Sized,
982    {
983        // NOTE(hanzoai): isq type is ALWAYS byte 4 (5th) of the tensor.
984        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
985        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
986            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
987            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
988            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
989            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
990            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
991            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(data, device, comm, guard)?,
992            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(data, device, comm, guard)?,
993        };
994        Ok(Arc::new(Self(deserialized)))
995    }
996}
997
/// Gate, up and down projections of a block of mixture-of-experts experts, built by
/// `PackedExperts::new`.
///
/// Depending on the weight layout, each vector holds either one layer per local expert
/// (FP8, unquantized and dummy layouts) or a single packed layer covering all experts
/// (AFQ and MXFP4 layouts).
#[derive(Debug)]
pub struct PackedExperts {
    // Gate projections (hidden_size -> intermediate_size, or this rank's shard of it).
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    // Up projections (hidden_size -> intermediate_size, or this rank's shard of it).
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    // Down projections (intermediate_size, or this rank's shard of it -> hidden_size).
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}
1004
1005impl PackedExperts {
1006    /// Note: we only support AFQ and unquantized here because they are the only ones that support indexed.
1007    #[allow(clippy::too_many_arguments)]
1008    pub fn new(
1009        num_local_experts: usize,
1010        hidden_size: usize,
1011        intermediate_size: usize,
1012        config: &Option<QuantizedConfig>,
1013        bias: bool,
1014        comm: &Arc<crate::Comm>,
1015        vb: ShardedVarBuilder,
1016    ) -> Result<Self> {
1017        if bias {
1018            hanzo_ml::bail!("PackedExperts does not support bias.");
1019        }
1020
1021        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
1022            // GPTQ and BNB do not support tensor parallelism
1023            if comm.world_size() != 1 {
1024                hanzo_ml::bail!(
1025                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
1026                    comm.world_size()
1027                );
1028            }
1029
1030            match quant_conf {
1031                QuantizedConfig::Afq { .. } => {
1032                    if !vb.contains_tensor("gate_up_proj")
1033                        || !vb.contains_tensor("gate_up_proj.weight")
1034                    {
1035                        hanzo_ml::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
1036                    }
1037
1038                    let base_vb = vb.clone();
1039
1040                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
1041                        vb.pp("gate_proj").set_device(Device::Cpu)
1042                    } else {
1043                        vb.pp("gate_proj")
1044                    };
1045                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
1046                        vb.pp("up_proj").set_device(Device::Cpu)
1047                    } else {
1048                        vb.pp("up_proj")
1049                    };
1050                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
1051                        vb.pp("down_proj").set_device(Device::Cpu)
1052                    } else {
1053                        vb.pp("down_proj")
1054                    };
1055                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
1056                        num_local_experts,
1057                        hidden_size,
1058                        intermediate_size,
1059                        quant_conf,
1060                        bias,
1061                        vb_gate_proj,
1062                    )?;
1063                    let mut up_proj = AfqLayer::afq_packed_linear_b(
1064                        num_local_experts,
1065                        hidden_size,
1066                        intermediate_size,
1067                        quant_conf,
1068                        bias,
1069                        vb_up_proj,
1070                    )?;
1071                    let mut down_proj = AfqLayer::afq_packed_linear_b(
1072                        num_local_experts,
1073                        intermediate_size,
1074                        hidden_size,
1075                        quant_conf,
1076                        bias,
1077                        vb_down_proj,
1078                    )?;
1079
1080                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
1081                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
1082                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
1083
1084                    (vec![gate_proj], vec![up_proj], vec![down_proj])
1085                }
1086                QuantizedConfig::Fp8 { weight_block_size } => {
1087                    // FP8 quantization for PackedExperts
1088                    // Keep weights as FP8 using BlockwiseFP8Linear to leverage native FP8 GEMM
1089                    let Some(weight_block_size) = weight_block_size else {
1090                        hanzo_ml::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
1091                    };
1092                    if weight_block_size.len() != 2 {
1093                        hanzo_ml::bail!(
1094                            "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1095                        );
1096                    }
1097
1098                    // Check if we have stacked format (gate_up_proj) or per-expert format
1099                    // Note: vb already has the "experts" prefix from the caller (experts.rs)
1100                    let is_stacked_format = vb.contains_tensor("gate_up_proj");
1101
1102                    if is_stacked_format {
1103                        // Stacked format: load FP8 tensors and split
1104                        let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");
1105
1106                        if has_fp8_scales {
1107                            // Load gate_up_proj FP8 tensor and scale
1108                            let gate_up_fp8 = vb.get_with_hints_dtype(
1109                                (num_local_experts, hidden_size, intermediate_size * 2),
1110                                "gate_up_proj",
1111                                Default::default(),
1112                                hanzo_ml::DType::F8E4M3,
1113                            )?;
1114                            let gate_up_scale = vb.get_with_hints_dtype(
1115                                (
1116                                    num_local_experts,
1117                                    hidden_size.div_ceil(weight_block_size[0]),
1118                                    (intermediate_size * 2).div_ceil(weight_block_size[1]),
1119                                ),
1120                                "gate_up_proj.weight_scale_inv",
1121                                Default::default(),
1122                                hanzo_ml::DType::F32,
1123                            )?;
1124
1125                            // Load down_proj FP8 tensor and scale
1126                            let down_fp8 = vb.get_with_hints_dtype(
1127                                (num_local_experts, intermediate_size, hidden_size),
1128                                "down_proj",
1129                                Default::default(),
1130                                hanzo_ml::DType::F8E4M3,
1131                            )?;
1132                            let down_scale = vb.get_with_hints_dtype(
1133                                (
1134                                    num_local_experts,
1135                                    intermediate_size.div_ceil(weight_block_size[0]),
1136                                    hidden_size.div_ceil(weight_block_size[1]),
1137                                ),
1138                                "down_proj.weight_scale_inv",
1139                                Default::default(),
1140                                hanzo_ml::DType::F32,
1141                            )?;
1142
1143                            // Split and create individual BlockwiseFP8Linear for each expert
1144                            let mut gs = Vec::new();
1145                            let mut us = Vec::new();
1146                            let mut ds = Vec::new();
1147
1148                            for i in 0..num_local_experts {
1149                                // Extract this expert's weights
1150                                let gate_up_expert =
1151                                    gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1152                                let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
1153                                let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1154                                let down_scale_expert = down_scale.i(i)?.contiguous()?;
1155
1156                                // Split gate_up into gate and up
1157                                let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
1158                                let up_expert = gate_up_expert.narrow(
1159                                    0,
1160                                    intermediate_size,
1161                                    intermediate_size,
1162                                )?;
1163
1164                                // Split scales
1165                                let gate_scale_expert = gate_up_scale_expert.narrow(
1166                                    1,
1167                                    0,
1168                                    intermediate_size.div_ceil(weight_block_size[1]),
1169                                )?;
1170                                let up_scale_expert = gate_up_scale_expert.narrow(
1171                                    1,
1172                                    intermediate_size.div_ceil(weight_block_size[1]),
1173                                    intermediate_size.div_ceil(weight_block_size[1]),
1174                                )?;
1175
1176                                // Create BlockwiseFP8Linear for each projection
1177                                use crate::blockwise_fp8::BlockwiseFP8Linear;
1178                                use crate::QuantMethodConfig;
1179
1180                                let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1181                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1182                                        weight: gate_expert,
1183                                        weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
1184                                        bias: None,
1185                                        dequant_dtype: vb.dtype(),
1186                                        weight_block_size: weight_block_size.clone(),
1187                                    })?,
1188                                );
1189                                let up_layer: Arc<dyn QuantMethod> = Arc::new(
1190                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1191                                        weight: up_expert,
1192                                        weight_scale_inv: up_scale_expert.transpose(0, 1)?,
1193                                        bias: None,
1194                                        dequant_dtype: vb.dtype(),
1195                                        weight_block_size: weight_block_size.clone(),
1196                                    })?,
1197                                );
1198                                let down_layer: Arc<dyn QuantMethod> = Arc::new(
1199                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1200                                        weight: down_expert,
1201                                        weight_scale_inv: down_scale_expert.transpose(0, 1)?,
1202                                        bias: None,
1203                                        dequant_dtype: vb.dtype(),
1204                                        weight_block_size: weight_block_size.clone(),
1205                                    })?,
1206                                );
1207
1208                                gs.push(gate_layer);
1209                                us.push(up_layer);
1210                                ds.push(down_layer);
1211                            }
1212
1213                            (gs, us, ds)
1214                        } else {
1215                            hanzo_ml::bail!(
1216                                "PackedExperts with FP8 requires weight_scale_inv tensors"
1217                            );
1218                        }
1219                    } else {
1220                        // Per-expert format: load each expert individually
1221                        let mut gs = Vec::new();
1222                        let mut us = Vec::new();
1223                        let mut ds = Vec::new();
1224
1225                        for i in 0..num_local_experts {
1226                            let expert_vb = vb.pp(i);
1227
1228                            // Load FP8 weights and scales for each projection
1229                            let gate_fp8 = expert_vb.get_with_hints_dtype(
1230                                (intermediate_size, hidden_size),
1231                                "gate_proj.weight",
1232                                Default::default(),
1233                                hanzo_ml::DType::F8E4M3,
1234                            )?;
1235                            let gate_scale = expert_vb.get_with_hints_dtype(
1236                                (
1237                                    intermediate_size.div_ceil(weight_block_size[0]),
1238                                    hidden_size.div_ceil(weight_block_size[1]),
1239                                ),
1240                                "gate_proj.weight_scale_inv",
1241                                Default::default(),
1242                                hanzo_ml::DType::F32,
1243                            )?;
1244
1245                            let up_fp8 = expert_vb.get_with_hints_dtype(
1246                                (intermediate_size, hidden_size),
1247                                "up_proj.weight",
1248                                Default::default(),
1249                                hanzo_ml::DType::F8E4M3,
1250                            )?;
1251                            let up_scale = expert_vb.get_with_hints_dtype(
1252                                (
1253                                    intermediate_size.div_ceil(weight_block_size[0]),
1254                                    hidden_size.div_ceil(weight_block_size[1]),
1255                                ),
1256                                "up_proj.weight_scale_inv",
1257                                Default::default(),
1258                                hanzo_ml::DType::F32,
1259                            )?;
1260
1261                            let down_fp8 = expert_vb.get_with_hints_dtype(
1262                                (hidden_size, intermediate_size),
1263                                "down_proj.weight",
1264                                Default::default(),
1265                                hanzo_ml::DType::F8E4M3,
1266                            )?;
1267                            let down_scale = expert_vb.get_with_hints_dtype(
1268                                (
1269                                    hidden_size.div_ceil(weight_block_size[0]),
1270                                    intermediate_size.div_ceil(weight_block_size[1]),
1271                                ),
1272                                "down_proj.weight_scale_inv",
1273                                Default::default(),
1274                                hanzo_ml::DType::F32,
1275                            )?;
1276
1277                            // Create BlockwiseFP8Linear for each projection
1278                            use crate::blockwise_fp8::BlockwiseFP8Linear;
1279                            use crate::QuantMethodConfig;
1280
1281                            let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1282                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1283                                    weight: gate_fp8,
1284                                    weight_scale_inv: gate_scale,
1285                                    bias: None,
1286                                    dequant_dtype: vb.dtype(),
1287                                    weight_block_size: weight_block_size.clone(),
1288                                })?,
1289                            );
1290                            let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
1291                                QuantMethodConfig::BlockwiseFP8 {
1292                                    weight: up_fp8,
1293                                    weight_scale_inv: up_scale,
1294                                    bias: None,
1295                                    dequant_dtype: vb.dtype(),
1296                                    weight_block_size: weight_block_size.clone(),
1297                                },
1298                            )?);
1299                            let down_layer: Arc<dyn QuantMethod> = Arc::new(
1300                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1301                                    weight: down_fp8,
1302                                    weight_scale_inv: down_scale,
1303                                    bias: None,
1304                                    dequant_dtype: vb.dtype(),
1305                                    weight_block_size: weight_block_size.clone(),
1306                                })?,
1307                            );
1308
1309                            gs.push(gate_layer);
1310                            us.push(up_layer);
1311                            ds.push(down_layer);
1312                        }
1313
1314                        (gs, us, ds)
1315                    }
1316                }
1317                QuantizedConfig::MXFP4 {} => {
1318                    // MXFP4 quantization for PackedExperts
1319                    // Keep weights as MXFP4 using MXFP4Layer to leverage native MXFP4 GEMM
1320                    // Note: MXFP4 models use stacked format, so we load directly as packed experts
1321                    let gate_proj = MXFP4Layer::packed_linear_b(
1322                        num_local_experts,
1323                        hidden_size,
1324                        intermediate_size,
1325                        quant_conf,
1326                        bias,
1327                        vb.pp("gate_proj"),
1328                    )?;
1329                    let up_proj = MXFP4Layer::packed_linear_b(
1330                        num_local_experts,
1331                        hidden_size,
1332                        intermediate_size,
1333                        quant_conf,
1334                        bias,
1335                        vb.pp("up_proj"),
1336                    )?;
1337                    let down_proj = MXFP4Layer::packed_linear_b(
1338                        num_local_experts,
1339                        intermediate_size,
1340                        hidden_size,
1341                        quant_conf,
1342                        bias,
1343                        vb.pp("down_proj"),
1344                    )?;
1345
1346                    (vec![gate_proj], vec![up_proj], vec![down_proj])
1347                }
1348                _ => hanzo_ml::bail!(
1349                    "PackedExperts with quantization config only allows AFQ, FP8, or MXFP4 quantization"
1350                ),
1351            }
1352        } else if !vb.contains_tensor("gate_up_proj") {
1353            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
1354            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
1355            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
1356            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
1357            for _ in 0..num_local_experts {
1358                gs.push(make_dummy_or_error(
1359                    "packed_experts_gate_proj",
1360                    &vb,
1361                    &["gate_up_proj"],
1362                )?);
1363                us.push(make_dummy_or_error(
1364                    "packed_experts_up_proj",
1365                    &vb,
1366                    &["gate_up_proj"],
1367                )?);
1368                ds.push(make_dummy_or_error(
1369                    "packed_experts_down_proj",
1370                    &vb,
1371                    &["gate_up_proj"],
1372                )?);
1373            }
1374            (gs, us, ds)
1375        } else {
1376            // Parallelized like:
1377            // Each gpu holds all experts.
1378            // Gate/Up proj is parallelized on dim 2 (column)
1379            // Down proj is parallelized on dim 1 (row)
1380            // All reduce at the end.
1381
1382            // Handle the case where the layer is dummy (no tensors)
1383            let gate_up_block_size = intermediate_size / comm.world_size();
1384            let gate_up_start = gate_up_block_size * comm.rank();
1385
1386            // Gate is right before Up in the gate_up
1387            let shard_gate = Shard::Offset {
1388                dim: 2,
1389                offset: gate_up_start,
1390                len: gate_up_block_size,
1391            };
1392            let shard_up = Shard::Offset {
1393                dim: 2,
1394                offset: intermediate_size + gate_up_start,
1395                len: gate_up_block_size,
1396            };
1397            let shard_down = Shard::Simple {
1398                dim: 1,
1399                rank: comm.rank(),
1400                world_size: comm.world_size(),
1401            };
1402
1403            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1404                vb.pp("gate_up_proj").set_device(Device::Cpu)
1405            } else {
1406                vb.pp("gate_up_proj")
1407            };
1408            let vb_down_proj = if should_apply_immediate_isq(&vb) {
1409                vb.pp("down_proj").set_device(Device::Cpu)
1410            } else {
1411                vb.pp("down_proj")
1412            };
1413
1414            let gate_proj = vb
1415                .get_with_hints(
1416                    (num_local_experts, hidden_size, intermediate_size * 2),
1417                    "gate_up_proj",
1418                    shard_gate,
1419                )?
1420                .t()?
1421                .contiguous()?;
1422            let up_proj = vb
1423                .get_with_hints(
1424                    (num_local_experts, hidden_size, intermediate_size * 2),
1425                    "gate_up_proj",
1426                    shard_up,
1427                )?
1428                .t()?
1429                .contiguous()?;
1430            let down_proj = vb
1431                .get_with_hints(
1432                    (num_local_experts, intermediate_size, hidden_size),
1433                    "down_proj",
1434                    shard_down,
1435                )?
1436                .t()?
1437                .contiguous()?;
1438
1439            let gc = gate_proj.chunk(num_local_experts, 0)?;
1440            let uc = up_proj.chunk(num_local_experts, 0)?;
1441            let dc = down_proj.chunk(num_local_experts, 0)?;
1442            drop((gate_proj, up_proj, down_proj));
1443
1444            let mut gs = Vec::new();
1445            let mut us = Vec::new();
1446            let mut ds = Vec::new();
1447            for ((mut gate_proj, mut up_proj), mut down_proj) in gc.into_iter().zip(uc).zip(dc) {
1448                gate_proj = gate_proj.squeeze(0)?;
1449                up_proj = up_proj.squeeze(0)?;
1450                down_proj = down_proj.squeeze(0)?;
1451                let gate_proj = merge_lora_weights(
1452                    &vb,
1453                    gate_proj,
1454                    hidden_size,
1455                    intermediate_size * 2,
1456                    shard_gate,
1457                )?;
1458                let up_proj =
1459                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1460                let down_proj =
1461                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1462
1463                let mut gate_proj: Arc<dyn QuantMethod> =
1464                    Arc::new(<UnquantLinear as QuantMethod>::new(
1465                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1466                    )?);
1467                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1468                let mut up_proj: Arc<dyn QuantMethod> =
1469                    Arc::new(<UnquantLinear as QuantMethod>::new(
1470                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1471                    )?);
1472                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1473                let mut down_proj: Arc<dyn QuantMethod> =
1474                    Arc::new(<UnquantLinear as QuantMethod>::new(
1475                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1476                    )?);
1477                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1478                gs.push(gate_proj);
1479                us.push(up_proj);
1480                ds.push(down_proj);
1481            }
1482            (gs, us, ds)
1483        };
1484
1485        Ok(Self {
1486            gate_proj,
1487            up_proj,
1488            down_proj,
1489        })
1490    }
1491}
1492
/// Expert projections held as one fused layer per projection, covering all experts at once.
///
/// NOTE(review): the constructor is only partially visible here. The FP8 path mentions use in
/// `gather_forward`; presumably all fused layers are consumed through an indexed/gather
/// forward, so confirm this against the callers.
pub struct FusedExperts {
    // Fused gate projection for all experts (hidden_size -> moe_intermediate_size).
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    // Fused up projection for all experts (hidden_size -> moe_intermediate_size).
    pub fused_up_proj: Arc<dyn QuantMethod>,
    // Fused down projection for all experts (moe_intermediate_size -> hidden_size).
    pub fused_down_proj: Arc<dyn QuantMethod>,
}
1498
1499impl FusedExperts {
1500    pub fn new(
1501        hidden_size: usize,
1502        moe_intermediate_size: usize,
1503        num_experts: usize,
1504        quantization_config: &Option<QuantizedConfig>,
1505        vb: ShardedVarBuilder,
1506    ) -> Result<Self> {
1507        // Detect if weights are in stacked format (e.g., Qwen3 VL MoE):
1508        // - experts.gate_up_proj: (num_experts, hidden_size, intermediate_size * 2)
1509        // - experts.down_proj: (num_experts, intermediate_size, hidden_size)
1510        // Or per-expert format (e.g., Qwen3 MoE):
1511        // - experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight, experts.{i}.down_proj.weight
1512        let experts_vb = vb.pp("experts");
1513        let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1514
1515        let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1516            &quantization_config,
1517            Some(QuantizedConfig::Afq { .. })
1518        ) {
1519            let quantization_config = quantization_config.as_ref().unwrap();
1520
1521            let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1522                num_experts,
1523                hidden_size,
1524                moe_intermediate_size,
1525                quantization_config,
1526                false,
1527                vb.pp("switch_mlp.gate_proj"),
1528            )?;
1529            let fused_up_proj = AfqLayer::afq_packed_linear_b(
1530                num_experts,
1531                hidden_size,
1532                moe_intermediate_size,
1533                quantization_config,
1534                false,
1535                vb.pp("switch_mlp.up_proj"),
1536            )?;
1537            let fused_down_proj = AfqLayer::afq_packed_linear_b(
1538                num_experts,
1539                moe_intermediate_size,
1540                hidden_size,
1541                quantization_config,
1542                false,
1543                vb.pp("switch_mlp.down_proj"),
1544            )?;
1545
1546            (fused_gate_proj, fused_up_proj, fused_down_proj)
1547        } else if is_stacked_format
1548            && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1549        {
1550            // Stacked format with FP8 quantization
1551            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1552            let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1553
1554            if has_fp8_scales {
1555                let weight_block_size = match quantization_config {
1556                    Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1557                    _ => unreachable!(),
1558                };
1559
1560                let Some(weight_block_size) = weight_block_size else {
1561                    hanzo_ml::bail!(
1562                        "Blockwise FP8 for stacked experts requires weight_block_size to be set."
1563                    )
1564                };
1565                if weight_block_size.len() != 2 {
1566                    hanzo_ml::bail!(
1567                        "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1568                    );
1569                }
1570
1571                // Load gate_up_proj FP8 tensor and scale
1572                // Shape: [num_experts, hidden_size, intermediate_size * 2]
1573                let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1574                    (num_experts, hidden_size, moe_intermediate_size * 2),
1575                    "gate_up_proj",
1576                    Default::default(),
1577                    hanzo_ml::DType::F8E4M3,
1578                )?;
1579                let gate_up_scale = experts_vb.get_with_hints_dtype(
1580                    (
1581                        num_experts,
1582                        hidden_size.div_ceil(weight_block_size[0]),
1583                        (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1584                    ),
1585                    "gate_up_proj.weight_scale_inv",
1586                    Default::default(),
1587                    hanzo_ml::DType::F32,
1588                )?;
1589
1590                // Load down_proj FP8 tensor and scale
1591                // Shape: [num_experts, intermediate_size, hidden_size]
1592                let down_fp8 = experts_vb.get_with_hints_dtype(
1593                    (num_experts, moe_intermediate_size, hidden_size),
1594                    "down_proj",
1595                    Default::default(),
1596                    hanzo_ml::DType::F8E4M3,
1597                )?;
1598                let down_scale = experts_vb.get_with_hints_dtype(
1599                    (
1600                        num_experts,
1601                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1602                        hidden_size.div_ceil(weight_block_size[1]),
1603                    ),
1604                    "down_proj.weight_scale_inv",
1605                    Default::default(),
1606                    hanzo_ml::DType::F32,
1607                )?;
1608
1609                // Split gate_up into gate and up
1610                let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1611                let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1612
1613                // Split scales similarly
1614                let gate_scale = gate_up_scale.narrow(
1615                    2,
1616                    0,
1617                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1618                )?;
1619                let up_scale = gate_up_scale.narrow(
1620                    2,
1621                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1622                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1623                )?;
1624
1625                // Transpose to match expected format: [num_experts, N, K]
1626                // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1627                let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1628                let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1629                // down: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1630                let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1631
1632                // Transpose scales to match weight layout
1633                let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1634                let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1635                let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1636
1637                // Create BlockwiseFP8Linear for each projection
1638                let fused_gate_proj =
1639                    blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1640                let fused_up_proj =
1641                    blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1642                let fused_down_proj =
1643                    blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1644
1645                (fused_gate_proj, fused_up_proj, fused_down_proj)
1646            } else {
1647                // FP8 config but no scale tensors - weights are actually unquantized
1648                tracing::warn!(
1649                        "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1650                        Loading as unquantized."
1651                    );
1652                let gate_up_proj = experts_vb
1653                    .get(
1654                        (num_experts, hidden_size, moe_intermediate_size * 2),
1655                        "gate_up_proj",
1656                    )
1657                    .or_else(|_| {
1658                        experts_vb
1659                            .get(
1660                                (num_experts, moe_intermediate_size * 2, hidden_size),
1661                                "gate_up_proj",
1662                            )
1663                            .and_then(|t| t.transpose(1, 2)?.contiguous())
1664                    })?;
1665                let down_proj_packed = experts_vb
1666                    .get(
1667                        (num_experts, moe_intermediate_size, hidden_size),
1668                        "down_proj",
1669                    )
1670                    .or_else(|_| {
1671                        experts_vb
1672                            .get(
1673                                (num_experts, hidden_size, moe_intermediate_size),
1674                                "down_proj",
1675                            )
1676                            .and_then(|t| t.transpose(1, 2)?.contiguous())
1677                    })?;
1678
1679                // Split gate_up_proj into gate_proj and up_proj along the last dimension
1680                let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1681                let up_proj =
1682                    gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1683
1684                // Transpose dims 1 and 2 to match GGUF format
1685                let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1686                let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1687                let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1688
1689                // When immediate ISQ targets these weights, move to CPU for GGML quantization.
1690                let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1691                let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1692                let target_device = gate_proj.device().clone();
1693                let (gate_proj, up_proj, down_proj) =
1694                    if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1695                        (
1696                            gate_proj.to_device(&Device::Cpu)?,
1697                            up_proj.to_device(&Device::Cpu)?,
1698                            down_proj.to_device(&Device::Cpu)?,
1699                        )
1700                    } else {
1701                        (gate_proj, up_proj, down_proj)
1702                    };
1703
1704                let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1705                    QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1706                )?);
1707                let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1708                    QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1709                )?);
1710                let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1711                    QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1712                )?);
1713                // Pass the original-device VB so apply_immediate_isq targets
1714                // the correct device and respects topology overrides.
1715                let vb_gate_up = experts_vb.pp("gate_up_proj");
1716                let vb_down = experts_vb.pp("down_proj");
1717                fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1718                fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1719                fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1720
1721                (fused_gate_proj, fused_up_proj, fused_down_proj)
1722            }
1723        } else if is_stacked_format
1724            && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
1725        {
1726            // Stacked format with MXFP4 quantization
1727            // For MXFP4, weights are stored as packed FP4 (2 values per byte)
1728            // with E8M0 scales
1729            let quantization_config = quantization_config.as_ref().unwrap();
1730
1731            // Load MXFP4 packed experts using MXFP4Layer::packed_linear_b
1732            // The tensors are expected at:
1733            //   gate_proj.blocks: [num_experts, intermediate_size, hidden_size/2]
1734            //   gate_proj.scales: [num_experts, intermediate_size, hidden_size/32]
1735            let fused_gate_proj = MXFP4Layer::packed_linear_b(
1736                num_experts,
1737                hidden_size,
1738                moe_intermediate_size,
1739                quantization_config,
1740                false,
1741                experts_vb.pp("gate_proj"),
1742            )?;
1743            let fused_up_proj = MXFP4Layer::packed_linear_b(
1744                num_experts,
1745                hidden_size,
1746                moe_intermediate_size,
1747                quantization_config,
1748                false,
1749                experts_vb.pp("up_proj"),
1750            )?;
1751            let fused_down_proj = MXFP4Layer::packed_linear_b(
1752                num_experts,
1753                moe_intermediate_size,
1754                hidden_size,
1755                quantization_config,
1756                false,
1757                experts_vb.pp("down_proj"),
1758            )?;
1759
1760            (fused_gate_proj, fused_up_proj, fused_down_proj)
1761        } else if is_stacked_format {
1762            // Stacked format from safetensors. Two conventions exist:
1763            // Convention A: [num_experts, hidden_size, intermediate_size * 2]
1764            // Convention B (nn.Linear): [num_experts, intermediate_size * 2, hidden_size]
1765            //
1766            // GGUF/indexed_moe_forward expects:
1767            // - gate/up: [num_experts, intermediate_size, hidden_size]
1768            // - down: [num_experts, hidden_size, intermediate_size]
1769            //
1770            // Try convention A first, fall back to convention B (transposing to A).
1771            let gate_up_proj = experts_vb
1772                .get(
1773                    (num_experts, hidden_size, moe_intermediate_size * 2),
1774                    "gate_up_proj",
1775                )
1776                .or_else(|_| {
1777                    experts_vb
1778                        .get(
1779                            (num_experts, moe_intermediate_size * 2, hidden_size),
1780                            "gate_up_proj",
1781                        )
1782                        .and_then(|t| t.transpose(1, 2)?.contiguous())
1783                })?;
1784            let down_proj_packed = experts_vb
1785                .get(
1786                    (num_experts, moe_intermediate_size, hidden_size),
1787                    "down_proj",
1788                )
1789                .or_else(|_| {
1790                    experts_vb
1791                        .get(
1792                            (num_experts, hidden_size, moe_intermediate_size),
1793                            "down_proj",
1794                        )
1795                        .and_then(|t| t.transpose(1, 2)?.contiguous())
1796                })?;
1797
1798            // Split gate_up_proj into gate_proj and up_proj along the last dimension
1799            // gate_proj: [num_experts, hidden_size, intermediate_size]
1800            // up_proj: [num_experts, hidden_size, intermediate_size]
1801            let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1802            let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1803
1804            // Transpose dims 1 and 2 to match GGUF format:
1805            // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1806            let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1807            let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1808            // down_proj: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1809            let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1810
1811            // When immediate ISQ targets these weights, move to CPU for GGML quantization.
1812            let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1813            let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1814            let target_device = gate_proj.device().clone();
1815            let (gate_proj, up_proj, down_proj) =
1816                if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1817                    (
1818                        gate_proj.to_device(&Device::Cpu)?,
1819                        up_proj.to_device(&Device::Cpu)?,
1820                        down_proj.to_device(&Device::Cpu)?,
1821                    )
1822                } else {
1823                    (gate_proj, up_proj, down_proj)
1824                };
1825
1826            let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1827                QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1828            )?);
1829            let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1830                QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1831            )?);
1832            let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1833                QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1834            )?);
1835            // Pass the original-device VB so apply_immediate_isq targets
1836            // the correct device and respects topology overrides.
1837            let vb_gate_up = experts_vb.pp("gate_up_proj");
1838            let vb_down = experts_vb.pp("down_proj");
1839            fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1840            fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1841            fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1842
1843            (fused_gate_proj, fused_up_proj, fused_down_proj)
1844        } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1845            // Per-expert format with FP8 quantization
1846            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1847            let weight_block_size = match quantization_config {
1848                Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1849                _ => unreachable!(),
1850            };
1851
1852            let Some(weight_block_size) = weight_block_size else {
1853                hanzo_ml::bail!(
1854                    "Blockwise FP8 for per-expert format requires weight_block_size to be set."
1855                )
1856            };
1857            if weight_block_size.len() != 2 {
1858                hanzo_ml::bail!(
1859                    "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1860                );
1861            }
1862
1863            let mut gate_fp8_vec = Vec::new();
1864            let mut gate_scale_vec = Vec::new();
1865            let mut up_fp8_vec = Vec::new();
1866            let mut up_scale_vec = Vec::new();
1867            let mut down_fp8_vec = Vec::new();
1868            let mut down_scale_vec = Vec::new();
1869
1870            for i in 0..num_experts {
1871                let expert_vb = experts_vb.pp(i);
1872
1873                // Load FP8 weights and scales for each projection
1874                let gate_fp8 = expert_vb.get_with_hints_dtype(
1875                    (moe_intermediate_size, hidden_size),
1876                    "gate_proj.weight",
1877                    Default::default(),
1878                    hanzo_ml::DType::F8E4M3,
1879                )?;
1880                let gate_scale = expert_vb.get_with_hints_dtype(
1881                    (
1882                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1883                        hidden_size.div_ceil(weight_block_size[1]),
1884                    ),
1885                    "gate_proj.weight_scale_inv",
1886                    Default::default(),
1887                    hanzo_ml::DType::F32,
1888                )?;
1889
1890                let up_fp8 = expert_vb.get_with_hints_dtype(
1891                    (moe_intermediate_size, hidden_size),
1892                    "up_proj.weight",
1893                    Default::default(),
1894                    hanzo_ml::DType::F8E4M3,
1895                )?;
1896                let up_scale = expert_vb.get_with_hints_dtype(
1897                    (
1898                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1899                        hidden_size.div_ceil(weight_block_size[1]),
1900                    ),
1901                    "up_proj.weight_scale_inv",
1902                    Default::default(),
1903                    hanzo_ml::DType::F32,
1904                )?;
1905
1906                let down_fp8 = expert_vb.get_with_hints_dtype(
1907                    (hidden_size, moe_intermediate_size),
1908                    "down_proj.weight",
1909                    Default::default(),
1910                    hanzo_ml::DType::F8E4M3,
1911                )?;
1912                let down_scale = expert_vb.get_with_hints_dtype(
1913                    (
1914                        hidden_size.div_ceil(weight_block_size[0]),
1915                        moe_intermediate_size.div_ceil(weight_block_size[1]),
1916                    ),
1917                    "down_proj.weight_scale_inv",
1918                    Default::default(),
1919                    hanzo_ml::DType::F32,
1920                )?;
1921
1922                gate_fp8_vec.push(gate_fp8);
1923                gate_scale_vec.push(gate_scale);
1924                up_fp8_vec.push(up_fp8);
1925                up_scale_vec.push(up_scale);
1926                down_fp8_vec.push(down_fp8);
1927                down_scale_vec.push(down_scale);
1928            }
1929
1930            // Stack into [num_experts, N, K]
1931            let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1932            let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1933            let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1934            let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1935            let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1936            let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1937
1938            // Create BlockwiseFP8Linear for each projection
1939            let fused_gate_proj =
1940                blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1941            let fused_up_proj =
1942                blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1943            let fused_down_proj =
1944                blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1945
1946            (fused_gate_proj, fused_up_proj, fused_down_proj)
1947        } else if !experts_vb.pp("0").contains_tensor("gate_proj.weight") {
1948            // Handle the case where the layer is dummy (no tensors) during UQFF loading.
1949            // Deserialize will handle it.
1950            let expert_vb = experts_vb.pp("0");
1951            let fused_gate_proj =
1952                make_dummy_or_error("fused_experts_gate_proj", &expert_vb, &["gate_proj.weight"])?;
1953            let fused_up_proj =
1954                make_dummy_or_error("fused_experts_up_proj", &expert_vb, &["gate_proj.weight"])?;
1955            let fused_down_proj =
1956                make_dummy_or_error("fused_experts_down_proj", &expert_vb, &["gate_proj.weight"])?;
1957            (fused_gate_proj, fused_up_proj, fused_down_proj)
1958        } else {
1959            // Per-expert format: load each expert individually and stack
1960            // When immediate ISQ is active, load on CPU for GGML quantization.
1961            let load_experts_vb =
1962                if crate::get_immediate_isq().is_some() && !experts_vb.device().is_cpu() {
1963                    experts_vb.clone().set_device(Device::Cpu)
1964                } else {
1965                    experts_vb.clone()
1966                };
1967            let mut gate_proj_vec = Vec::new();
1968            let mut up_proj_vec = Vec::new();
1969            let mut down_proj_vec = Vec::new();
1970            for i in 0..num_experts {
1971                let expert_vb = load_experts_vb.pp(i);
1972                let gate_proj =
1973                    expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1974                let up_proj =
1975                    expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1976                let down_proj =
1977                    expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1978
1979                gate_proj_vec.push(gate_proj);
1980                up_proj_vec.push(up_proj);
1981                down_proj_vec.push(down_proj);
1982            }
1983
1984            let mut gate_proj: Arc<dyn QuantMethod> =
1985                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1986                    Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1987                ))?);
1988            let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1989                QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1990            )?);
1991            let mut down_proj: Arc<dyn QuantMethod> =
1992                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1993                    Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1994                ))?);
1995            // Use experts.0.{proj} prefix to match the actual weight paths for ISQ predicate matching
1996            let expert0_vb = experts_vb.pp("0");
1997            gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1998            up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1999            down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
2000
2001            (gate_proj, up_proj, down_proj)
2002        };
2003
2004        Ok(Self {
2005            fused_gate_proj,
2006            fused_up_proj,
2007            fused_down_proj,
2008        })
2009    }
2010}
2011
2012/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
2013pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
2014    if comm.world_size() == 1 {
2015        return Shard::default();
2016    }
2017
2018    // Tensor parallelism case
2019
2020    // We may need to replicate the kv heads
2021    let kv_replicate = if comm.world_size() > total_num_kv_heads {
2022        comm.world_size() / total_num_kv_heads
2023    } else {
2024        return Shard::Simple {
2025            dim: 0,
2026            rank: comm.rank(),
2027            world_size: comm.world_size(),
2028        };
2029    };
2030
2031    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
2032    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
2033    Shard::Offset {
2034        dim: 0,
2035        offset: kv_shard_id * head_dim,
2036        len: head_dim,
2037    }
2038}
2039
2040/// Compute the number of KV groups, taking into account KV head replication.
2041pub fn compute_n_kv_groups(
2042    total_num_kv_heads: usize,
2043    num_attention_heads: usize,
2044    comm: &Comm,
2045) -> usize {
2046    let kv_replicate = if comm.world_size() > total_num_kv_heads {
2047        comm.world_size() / total_num_kv_heads
2048    } else {
2049        1
2050    };
2051    (num_attention_heads / total_num_kv_heads)
2052        .checked_div(kv_replicate)
2053        .unwrap_or(num_attention_heads / total_num_kv_heads)
2054}