mistralrs_quant/distributed/layers.rs

use std::sync::Arc;

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    pertensor_fp8::pertensor_fp8_linear_b,
    should_apply_immediate_isq,
    utils::isq::apply_immediate_isq,
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, F8Q8Linear, FP8Linear, GgufMatMul, HqqLayer,
    MXFP4Layer, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

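/// Build a `Shard::Simple`, which keeps the `rank`-th of `world_size` contiguous chunks
/// along `dim`. For example (illustrative, assuming the dimension divides evenly),
/// `shard(0, 1, 4)` selects rows 1024..2048 of a (4096, 1024) weight when it is loaded.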
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

/// This layer has a weight that is parallelized along the input dimension,
/// returning the "full" output dimension.
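///
/// Conceptually, each rank multiplies its input shard against its slice of the weight and
/// the partial results are then summed across ranks with an all-reduce (see `forward`
/// below), so every rank ends up with the full `out_dim`-sized output.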
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };
        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (out_dim, orig_intermediate_size),
                    "weight",
                    Default::default(),
                )?
                .i((.., ..in_dim))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}
#[derive(Debug)]
/// This layer has a weight that is parallelized along the output dimension,
/// taking the "full" input dimension.
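///
/// Conceptually, each rank multiplies the full input against its output-dimension slice of
/// the weight, producing a disjoint shard of the output; no all-reduce is performed here.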
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            // GPTQ/AWQ, BNB, and AFQ do not support tensor parallelism
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // NOTE: no bias for fp8 as it might be parallelized
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        // Handle the case where the layer is dummy (no tensors)
        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (orig_intermediate_size, in_dim),
                    "weight",
                    Default::default(),
                )?
                .i((..out_dim, ..))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

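    /// Loads `chunks` column-parallel layers that share one underlying packed weight
    /// (illustrative use case: a fused projection stored as a single tensor). Chunk `i` on
    /// rank `r` reads shard index `i * world_size + r` out of `chunks * world_size` along
    /// dim 0, matching the `shard(...)` call below.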
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

#[derive(Debug)]
/// This layer has no parallelization
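/// It wraps a single inner `QuantMethod` and delegates every call to it.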
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        if let Some(crate::ImmediateIsqParams {
            guard,
            ty: Some(immediate_isq),
            pool,
            backpressure,
            ..
        }) = crate::get_immediate_isq()
        {
            // Global ISQ type is set — move to CPU for GGML quantization,
            // then quantize onto the original device.
            let lin = if !dev.is_cpu() {
                Linear::new(lin.weight().to_device(&Device::Cpu)?, lin.bias().cloned())
            } else {
                lin
            };
            let layer: Arc<dyn QuantMethod> =
                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
            if let Some(pool) = &pool {
                backpressure.acquire();
                let backpressure = backpressure.clone();
                let dev = dev.clone();
                let (tx, rx) = crate::pending_layer::pending_isq_channel();
                pool.spawn(move || {
                    let result = layer.clone().apply_isq(
                        Some(immediate_isq),
                        dev,
                        &std::sync::atomic::AtomicUsize::new(0),
                        None,
                        guard,
                    );
                    let _ = tx.send(result);
                    backpressure.release();
                });
                Ok(Arc::new(crate::PendingIsqLayer::new(rx)))
            } else {
                layer.clone().apply_isq(
                    Some(immediate_isq),
                    dev,
                    &std::sync::atomic::AtomicUsize::new(0),
                    None,
                    guard,
                )
            }
        } else {
            // No global ISQ — keep as unquantized on original device.
            Ok(Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(lin),
            )?))
        }
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new_layers_matformer_indices(
        in_dim: usize,
        out_dim: usize,
        kept_layers_indices: Option<&Tensor>,
        orig_num_hidden_layers: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            if kept_layers_indices.is_some() {
                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            // Handle the case where the layer is dummy (no tensors)
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let mut weight =
                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;

                if let Some(kept_layers_indices) = &kept_layers_indices {
                    let weight_reshaped = weight.reshape((
                        orig_num_hidden_layers,
                        weight.dim(0)? / orig_num_hidden_layers,
                        weight.dim(1)?,
                    ))?;

                    weight = weight_reshaped
                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
                        .contiguous()?;
                }

                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}
#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    /// Note: only AFQ, FP8, MXFP4, and unquantized weights are supported here because they are the only ones that support indexed.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            // Quantized PackedExperts do not support tensor parallelism
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    // FP8 quantization for PackedExperts
                    // Keep weights as FP8 using BlockwiseFP8Linear to leverage native FP8 GEMM
                    let Some(weight_block_size) = weight_block_size else {
                        candle_core::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
                    };
                    if weight_block_size.len() != 2 {
                        candle_core::bail!(
                            "Expected weight_block_size to have length 2, got {weight_block_size:?}"
                        );
                    }

                    // Check if we have stacked format (gate_up_proj) or per-expert format
                    // Note: vb already has the "experts" prefix from the caller (experts.rs)
                    let is_stacked_format = vb.contains_tensor("gate_up_proj");

                    if is_stacked_format {
                        // Stacked format: load FP8 tensors and split
                        let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");

                        if has_fp8_scales {
                            // Load gate_up_proj FP8 tensor and scale
                            let gate_up_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, hidden_size, intermediate_size * 2),
                                "gate_up_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_up_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    (intermediate_size * 2).div_ceil(weight_block_size[1]),
                                ),
                                "gate_up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Load down_proj FP8 tensor and scale
                            let down_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, intermediate_size, hidden_size),
                                "down_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Split and create individual BlockwiseFP8Linear for each expert
                            let mut gs = Vec::new();
                            let mut us = Vec::new();
                            let mut ds = Vec::new();

                            for i in 0..num_local_experts {
                                // Extract this expert's weights
                                let gate_up_expert =
                                    gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
                                let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let down_scale_expert = down_scale.i(i)?.contiguous()?;

                                // Split gate_up into gate and up
                                let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
                                let up_expert = gate_up_expert.narrow(
                                    0,
                                    intermediate_size,
                                    intermediate_size,
                                )?;

                                // Split scales
                                let gate_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    0,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;
                                let up_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;

                                // Create BlockwiseFP8Linear for each projection
                                use crate::blockwise_fp8::BlockwiseFP8Linear;
                                use crate::QuantMethodConfig;

                                let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: gate_expert,
                                        weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let up_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: up_expert,
                                        weight_scale_inv: up_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: down_expert,
                                        weight_scale_inv: down_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );

                                gs.push(gate_layer);
                                us.push(up_layer);
                                ds.push(down_layer);
                            }

                            (gs, us, ds)
                        } else {
                            candle_core::bail!(
                                "PackedExperts with FP8 requires weight_scale_inv tensors"
                            );
                        }
                    } else {
                        // Per-expert format: load each expert individually
                        let mut gs = Vec::new();
                        let mut us = Vec::new();
                        let mut ds = Vec::new();

                        for i in 0..num_local_experts {
                            let expert_vb = vb.pp(i);

                            // Load FP8 weights and scales for each projection
                            let gate_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "gate_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "gate_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let up_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "up_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let up_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let down_fp8 = expert_vb.get_with_hints_dtype(
                                (hidden_size, intermediate_size),
                                "down_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = expert_vb.get_with_hints_dtype(
                                (
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            // Create BlockwiseFP8Linear for each projection
                            use crate::blockwise_fp8::BlockwiseFP8Linear;
                            use crate::QuantMethodConfig;

                            let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: gate_fp8,
                                    weight_scale_inv: gate_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );
                            let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
                                QuantMethodConfig::BlockwiseFP8 {
                                    weight: up_fp8,
                                    weight_scale_inv: up_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                },
                            )?);
                            let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: down_fp8,
                                    weight_scale_inv: down_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );

                            gs.push(gate_layer);
                            us.push(up_layer);
                            ds.push(down_layer);
                        }

                        (gs, us, ds)
                    }
                }
                QuantizedConfig::MXFP4 {} => {
                    // MXFP4 quantization for PackedExperts
                    // Keep weights as MXFP4 using MXFP4Layer to leverage native MXFP4 GEMM
                    // Note: MXFP4 models use stacked format, so we load directly as packed experts
                    let gate_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("gate_proj"),
                    )?;
                    let up_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("up_proj"),
                    )?;
                    let down_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb.pp("down_proj"),
                    )?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ, FP8, or MXFP4 quantization"
                ),
            }
1336        } else if !vb.contains_tensor("gate_up_proj") {
1337            // Handle the case where the layer is dummy (no tensors) during UQFF loading. Deserialize will handle it.
1338            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
1339            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
1340            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
1341            for _ in 0..num_local_experts {
1342                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1343                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1344                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
1345            }
1346            (gs, us, ds)
1347        } else {
1348            // Parallelized as follows:
1349            // Each GPU holds all experts.
1350            // Gate/Up proj is parallelized on dim 2 (column)
1351            // Down proj is parallelized on dim 1 (row)
1352            // All-reduce at the end.
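            //
            // For example (a sketch with assumed toy sizes, not taken from any real model):
            // with intermediate_size = 8 and world_size = 2, gate_up_block_size = 4, so
            //   rank 0 reads gate columns [0, 4) and up columns [8, 12),
            //   rank 1 reads gate columns [4, 8) and up columns [12, 16)
            // along dim 2 of the packed gate_up_proj tensor, and each rank's partial
            // down_proj output is later combined with an all-reduce.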
1353
1354            // Non-dummy case: load the stacked gate_up_proj/down_proj tensors and shard them across ranks.
1355            let gate_up_block_size = intermediate_size / comm.world_size();
1356            let gate_up_start = gate_up_block_size * comm.rank();
1357
1358            // Gate comes right before Up within the packed gate_up_proj tensor.
1359            let shard_gate = Shard::Offset {
1360                dim: 2,
1361                offset: gate_up_start,
1362                len: gate_up_block_size,
1363            };
1364            let shard_up = Shard::Offset {
1365                dim: 2,
1366                offset: intermediate_size + gate_up_start,
1367                len: gate_up_block_size,
1368            };
1369            let shard_down = Shard::Simple {
1370                dim: 1,
1371                rank: comm.rank(),
1372                world_size: comm.world_size(),
1373            };
1374
1375            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1376                vb.pp("gate_up_proj").set_device(Device::Cpu)
1377            } else {
1378                vb.pp("gate_up_proj")
1379            };
1380            let vb_down_proj = if should_apply_immediate_isq(&vb) {
1381                vb.pp("down_proj").set_device(Device::Cpu)
1382            } else {
1383                vb.pp("down_proj")
1384            };
1385
1386            let gate_proj = vb
1387                .get_with_hints(
1388                    (num_local_experts, hidden_size, intermediate_size * 2),
1389                    "gate_up_proj",
1390                    shard_gate,
1391                )?
1392                .t()?
1393                .contiguous()?;
1394            let up_proj = vb
1395                .get_with_hints(
1396                    (num_local_experts, hidden_size, intermediate_size * 2),
1397                    "gate_up_proj",
1398                    shard_up,
1399                )?
1400                .t()?
1401                .contiguous()?;
1402            let down_proj = vb
1403                .get_with_hints(
1404                    (num_local_experts, intermediate_size, hidden_size),
1405                    "down_proj",
1406                    shard_down,
1407                )?
1408                .t()?
1409                .contiguous()?;
1410
1411            let gc = gate_proj.chunk(num_local_experts, 0)?;
1412            let uc = up_proj.chunk(num_local_experts, 0)?;
1413            let dc = down_proj.chunk(num_local_experts, 0)?;
1414            drop((gate_proj, up_proj, down_proj));
1415
1416            let mut gs = Vec::new();
1417            let mut us = Vec::new();
1418            let mut ds = Vec::new();
1419            for ((mut gate_proj, mut up_proj), mut down_proj) in
1420                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
1421            {
1422                gate_proj = gate_proj.squeeze(0)?;
1423                up_proj = up_proj.squeeze(0)?;
1424                down_proj = down_proj.squeeze(0)?;
1425                let gate_proj = merge_lora_weights(
1426                    &vb,
1427                    gate_proj,
1428                    hidden_size,
1429                    intermediate_size * 2,
1430                    shard_gate,
1431                )?;
1432                let up_proj =
1433                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1434                let down_proj =
1435                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1436
1437                let mut gate_proj: Arc<dyn QuantMethod> =
1438                    Arc::new(<UnquantLinear as QuantMethod>::new(
1439                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1440                    )?);
1441                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1442                let mut up_proj: Arc<dyn QuantMethod> =
1443                    Arc::new(<UnquantLinear as QuantMethod>::new(
1444                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1445                    )?);
1446                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1447                let mut down_proj: Arc<dyn QuantMethod> =
1448                    Arc::new(<UnquantLinear as QuantMethod>::new(
1449                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1450                    )?);
1451                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1452                gs.push(gate_proj);
1453                us.push(up_proj);
1454                ds.push(down_proj);
1455            }
1456            (gs, us, ds)
1457        };
1458
1459        Ok(Self {
1460            gate_proj,
1461            up_proj,
1462            down_proj,
1463        })
1464    }
1465}
1466
1467pub struct FusedExperts {
1468    pub fused_gate_proj: Arc<dyn QuantMethod>,
1469    pub fused_up_proj: Arc<dyn QuantMethod>,
1470    pub fused_down_proj: Arc<dyn QuantMethod>,
1471}
1472
1473impl FusedExperts {
1474    pub fn new(
1475        hidden_size: usize,
1476        moe_intermediate_size: usize,
1477        num_experts: usize,
1478        quantization_config: &Option<QuantizedConfig>,
1479        vb: ShardedVarBuilder,
1480    ) -> Result<Self> {
1481        // Detect if weights are in stacked format (e.g., Qwen3 VL MoE):
1482        // - experts.gate_up_proj: (num_experts, hidden_size, intermediate_size * 2)
1483        // - experts.down_proj: (num_experts, intermediate_size, hidden_size)
1484        // Or per-expert format (e.g., Qwen3 MoE):
1485        // - experts.{i}.gate_proj.weight, experts.{i}.up_proj.weight, experts.{i}.down_proj.weight
1486        let experts_vb = vb.pp("experts");
1487        let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1488
1489        let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1490            &quantization_config,
1491            Some(QuantizedConfig::Afq { .. })
1492        ) {
1493            let quantization_config = quantization_config.as_ref().unwrap();
1494
1495            let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1496                num_experts,
1497                hidden_size,
1498                moe_intermediate_size,
1499                quantization_config,
1500                false,
1501                vb.pp("switch_mlp.gate_proj"),
1502            )?;
1503            let fused_up_proj = AfqLayer::afq_packed_linear_b(
1504                num_experts,
1505                hidden_size,
1506                moe_intermediate_size,
1507                quantization_config,
1508                false,
1509                vb.pp("switch_mlp.up_proj"),
1510            )?;
1511            let fused_down_proj = AfqLayer::afq_packed_linear_b(
1512                num_experts,
1513                moe_intermediate_size,
1514                hidden_size,
1515                quantization_config,
1516                false,
1517                vb.pp("switch_mlp.down_proj"),
1518            )?;
1519
1520            (fused_gate_proj, fused_up_proj, fused_down_proj)
1521        } else if is_stacked_format
1522            && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1523        {
1524            // Stacked format with FP8 quantization
1525            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1526            let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1527
1528            if has_fp8_scales {
1529                let weight_block_size = match quantization_config {
1530                    Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1531                    _ => unreachable!(),
1532                };
1533
1534                let Some(weight_block_size) = weight_block_size else {
1535                    candle_core::bail!(
1536                        "Blockwise FP8 for stacked experts requires weight_block_size to be set."
1537                    )
1538                };
1539                if weight_block_size.len() != 2 {
1540                    candle_core::bail!(
1541                        "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1542                    );
1543                }
1544
1545                // Load gate_up_proj FP8 tensor and scale
1546                // Shape: [num_experts, hidden_size, intermediate_size * 2]
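                // Worked example (toy sizes, purely illustrative): with
                // weight_block_size = [128, 128], hidden_size = 512 and
                // moe_intermediate_size = 256, each expert's FP8 slab is [512, 512]
                // and its scale grid is [4, 4]: one scale entry per 128 x 128 block,
                // hence the div_ceil shapes below.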
1547                let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1548                    (num_experts, hidden_size, moe_intermediate_size * 2),
1549                    "gate_up_proj",
1550                    Default::default(),
1551                    candle_core::DType::F8E4M3,
1552                )?;
1553                let gate_up_scale = experts_vb.get_with_hints_dtype(
1554                    (
1555                        num_experts,
1556                        hidden_size.div_ceil(weight_block_size[0]),
1557                        (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1558                    ),
1559                    "gate_up_proj.weight_scale_inv",
1560                    Default::default(),
1561                    candle_core::DType::F32,
1562                )?;
1563
1564                // Load down_proj FP8 tensor and scale
1565                // Shape: [num_experts, intermediate_size, hidden_size]
1566                let down_fp8 = experts_vb.get_with_hints_dtype(
1567                    (num_experts, moe_intermediate_size, hidden_size),
1568                    "down_proj",
1569                    Default::default(),
1570                    candle_core::DType::F8E4M3,
1571                )?;
1572                let down_scale = experts_vb.get_with_hints_dtype(
1573                    (
1574                        num_experts,
1575                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1576                        hidden_size.div_ceil(weight_block_size[1]),
1577                    ),
1578                    "down_proj.weight_scale_inv",
1579                    Default::default(),
1580                    candle_core::DType::F32,
1581                )?;
1582
1583                // Split gate_up into gate and up
1584                let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1585                let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1586
1587                // Split scales similarly
1588                let gate_scale = gate_up_scale.narrow(
1589                    2,
1590                    0,
1591                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1592                )?;
1593                let up_scale = gate_up_scale.narrow(
1594                    2,
1595                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1596                    moe_intermediate_size.div_ceil(weight_block_size[1]),
1597                )?;
1598
1599                // Transpose to match expected format: [num_experts, N, K]
1600                // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1601                let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1602                let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1603                // down: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1604                let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1605
1606                // Transpose scales to match weight layout
1607                let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1608                let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1609                let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1610
1611                // Create BlockwiseFP8Linear for each projection
1612                let fused_gate_proj =
1613                    blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1614                let fused_up_proj =
1615                    blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1616                let fused_down_proj =
1617                    blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1618
1619                (fused_gate_proj, fused_up_proj, fused_down_proj)
1620            } else {
1621                // FP8 config but no scale tensors - weights are actually unquantized
1622                tracing::warn!(
1623                    "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1624                     Loading as unquantized."
1625                );
1626                let gate_up_proj = experts_vb
1627                    .get(
1628                        (num_experts, hidden_size, moe_intermediate_size * 2),
1629                        "gate_up_proj",
1630                    )
1631                    .or_else(|_| {
1632                        experts_vb
1633                            .get(
1634                                (num_experts, moe_intermediate_size * 2, hidden_size),
1635                                "gate_up_proj",
1636                            )
1637                            .and_then(|t| t.transpose(1, 2)?.contiguous())
1638                    })?;
1639                let down_proj_packed = experts_vb
1640                    .get(
1641                        (num_experts, moe_intermediate_size, hidden_size),
1642                        "down_proj",
1643                    )
1644                    .or_else(|_| {
1645                        experts_vb
1646                            .get(
1647                                (num_experts, hidden_size, moe_intermediate_size),
1648                                "down_proj",
1649                            )
1650                            .and_then(|t| t.transpose(1, 2)?.contiguous())
1651                    })?;
1652
1653                // Split gate_up_proj into gate_proj and up_proj along the last dimension
1654                let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1655                let up_proj =
1656                    gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1657
1658                // Transpose dims 1 and 2 to match GGUF format
1659                let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1660                let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1661                let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1662
1663                // When immediate ISQ targets these weights, move to CPU for GGML quantization.
1664                let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1665                let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1666                let target_device = gate_proj.device().clone();
1667                let (gate_proj, up_proj, down_proj) =
1668                    if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1669                        (
1670                            gate_proj.to_device(&Device::Cpu)?,
1671                            up_proj.to_device(&Device::Cpu)?,
1672                            down_proj.to_device(&Device::Cpu)?,
1673                        )
1674                    } else {
1675                        (gate_proj, up_proj, down_proj)
1676                    };
1677
1678                let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1679                    QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1680                )?);
1681                let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1682                    QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1683                )?);
1684                let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1685                    QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1686                )?);
1687                // Pass the original-device VB so apply_immediate_isq targets
1688                // the correct device and respects topology overrides.
1689                let vb_gate_up = experts_vb.pp("gate_up_proj");
1690                let vb_down = experts_vb.pp("down_proj");
1691                fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1692                fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1693                fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1694
1695                (fused_gate_proj, fused_up_proj, fused_down_proj)
1696            }
1697        } else if is_stacked_format
1698            && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
1699        {
1700            // Stacked format with MXFP4 quantization
1701            // For MXFP4, weights are stored as packed FP4 (2 values per byte)
1702            // with E8M0 scales
1703            let quantization_config = quantization_config.as_ref().unwrap();
1704
1705            // Load MXFP4 packed experts using MXFP4Layer::packed_linear_b
1706            // The tensors are expected at:
1707            //   gate_proj.blocks: [num_experts, intermediate_size, hidden_size/2]
1708            //   gate_proj.scales: [num_experts, intermediate_size, hidden_size/32]
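            // E.g. (illustrative sizes): hidden_size = 128 packs into 64 bytes of FP4
            // nibbles per row (two values per byte) with 4 E8M0 scales per row (one per
            // 32-value block).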
1709            let fused_gate_proj = MXFP4Layer::packed_linear_b(
1710                num_experts,
1711                hidden_size,
1712                moe_intermediate_size,
1713                quantization_config,
1714                false,
1715                experts_vb.pp("gate_proj"),
1716            )?;
1717            let fused_up_proj = MXFP4Layer::packed_linear_b(
1718                num_experts,
1719                hidden_size,
1720                moe_intermediate_size,
1721                quantization_config,
1722                false,
1723                experts_vb.pp("up_proj"),
1724            )?;
1725            let fused_down_proj = MXFP4Layer::packed_linear_b(
1726                num_experts,
1727                moe_intermediate_size,
1728                hidden_size,
1729                quantization_config,
1730                false,
1731                experts_vb.pp("down_proj"),
1732            )?;
1733
1734            (fused_gate_proj, fused_up_proj, fused_down_proj)
1735        } else if is_stacked_format {
1736            // Stacked format from safetensors. Two conventions exist:
1737            // Convention A: [num_experts, hidden_size, intermediate_size * 2]
1738            // Convention B (nn.Linear): [num_experts, intermediate_size * 2, hidden_size]
1739            //
1740            // GGUF/indexed_moe_forward expects:
1741            // - gate/up: [num_experts, intermediate_size, hidden_size]
1742            // - down: [num_experts, hidden_size, intermediate_size]
1743            //
1744            // Try convention A first, fall back to convention B (transposing to A).
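            // E.g. (toy sizes for illustration): num_experts = 2, hidden_size = 6,
            // intermediate_size = 4 gives gate_up_proj of shape (2, 6, 8) under
            // convention A and (2, 8, 6) under convention B; the fallback below
            // transposes B back into A before splitting.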
1745            let gate_up_proj = experts_vb
1746                .get(
1747                    (num_experts, hidden_size, moe_intermediate_size * 2),
1748                    "gate_up_proj",
1749                )
1750                .or_else(|_| {
1751                    experts_vb
1752                        .get(
1753                            (num_experts, moe_intermediate_size * 2, hidden_size),
1754                            "gate_up_proj",
1755                        )
1756                        .and_then(|t| t.transpose(1, 2)?.contiguous())
1757                })?;
1758            let down_proj_packed = experts_vb
1759                .get(
1760                    (num_experts, moe_intermediate_size, hidden_size),
1761                    "down_proj",
1762                )
1763                .or_else(|_| {
1764                    experts_vb
1765                        .get(
1766                            (num_experts, hidden_size, moe_intermediate_size),
1767                            "down_proj",
1768                        )
1769                        .and_then(|t| t.transpose(1, 2)?.contiguous())
1770                })?;
1771
1772            // Split gate_up_proj into gate_proj and up_proj along the last dimension
1773            // gate_proj: [num_experts, hidden_size, intermediate_size]
1774            // up_proj: [num_experts, hidden_size, intermediate_size]
1775            let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1776            let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1777
1778            // Transpose dims 1 and 2 to match GGUF format:
1779            // gate/up: [num_experts, hidden_size, intermediate_size] -> [num_experts, intermediate_size, hidden_size]
1780            let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1781            let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1782            // down_proj: [num_experts, intermediate_size, hidden_size] -> [num_experts, hidden_size, intermediate_size]
1783            let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1784
1785            // When immediate ISQ targets these weights, move to CPU for GGML quantization.
1786            let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1787            let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1788            let target_device = gate_proj.device().clone();
1789            let (gate_proj, up_proj, down_proj) =
1790                if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1791                    (
1792                        gate_proj.to_device(&Device::Cpu)?,
1793                        up_proj.to_device(&Device::Cpu)?,
1794                        down_proj.to_device(&Device::Cpu)?,
1795                    )
1796                } else {
1797                    (gate_proj, up_proj, down_proj)
1798                };
1799
1800            let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1801                QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1802            )?);
1803            let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1804                QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1805            )?);
1806            let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1807                QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1808            )?);
1809            // Pass the original-device VB so apply_immediate_isq targets
1810            // the correct device and respects topology overrides.
1811            let vb_gate_up = experts_vb.pp("gate_up_proj");
1812            let vb_down = experts_vb.pp("down_proj");
1813            fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1814            fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1815            fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1816
1817            (fused_gate_proj, fused_up_proj, fused_down_proj)
1818        } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1819            // Per-expert format with FP8 quantization
1820            // Keep weights as FP8 using BlockwiseFP8 to leverage native FP8 GEMM in gather_forward
1821            let weight_block_size = match quantization_config {
1822                Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1823                _ => unreachable!(),
1824            };
1825
1826            let Some(weight_block_size) = weight_block_size else {
1827                candle_core::bail!(
1828                    "Blockwise FP8 for per-expert format requires weight_block_size to be set."
1829                )
1830            };
1831            if weight_block_size.len() != 2 {
1832                candle_core::bail!(
1833                    "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1834                );
1835            }
1836
1837            let mut gate_fp8_vec = Vec::new();
1838            let mut gate_scale_vec = Vec::new();
1839            let mut up_fp8_vec = Vec::new();
1840            let mut up_scale_vec = Vec::new();
1841            let mut down_fp8_vec = Vec::new();
1842            let mut down_scale_vec = Vec::new();
1843
1844            for i in 0..num_experts {
1845                let expert_vb = experts_vb.pp(i);
1846
1847                // Load FP8 weights and scales for each projection
1848                let gate_fp8 = expert_vb.get_with_hints_dtype(
1849                    (moe_intermediate_size, hidden_size),
1850                    "gate_proj.weight",
1851                    Default::default(),
1852                    candle_core::DType::F8E4M3,
1853                )?;
1854                let gate_scale = expert_vb.get_with_hints_dtype(
1855                    (
1856                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1857                        hidden_size.div_ceil(weight_block_size[1]),
1858                    ),
1859                    "gate_proj.weight_scale_inv",
1860                    Default::default(),
1861                    candle_core::DType::F32,
1862                )?;
1863
1864                let up_fp8 = expert_vb.get_with_hints_dtype(
1865                    (moe_intermediate_size, hidden_size),
1866                    "up_proj.weight",
1867                    Default::default(),
1868                    candle_core::DType::F8E4M3,
1869                )?;
1870                let up_scale = expert_vb.get_with_hints_dtype(
1871                    (
1872                        moe_intermediate_size.div_ceil(weight_block_size[0]),
1873                        hidden_size.div_ceil(weight_block_size[1]),
1874                    ),
1875                    "up_proj.weight_scale_inv",
1876                    Default::default(),
1877                    candle_core::DType::F32,
1878                )?;
1879
1880                let down_fp8 = expert_vb.get_with_hints_dtype(
1881                    (hidden_size, moe_intermediate_size),
1882                    "down_proj.weight",
1883                    Default::default(),
1884                    candle_core::DType::F8E4M3,
1885                )?;
1886                let down_scale = expert_vb.get_with_hints_dtype(
1887                    (
1888                        hidden_size.div_ceil(weight_block_size[0]),
1889                        moe_intermediate_size.div_ceil(weight_block_size[1]),
1890                    ),
1891                    "down_proj.weight_scale_inv",
1892                    Default::default(),
1893                    candle_core::DType::F32,
1894                )?;
1895
1896                gate_fp8_vec.push(gate_fp8);
1897                gate_scale_vec.push(gate_scale);
1898                up_fp8_vec.push(up_fp8);
1899                up_scale_vec.push(up_scale);
1900                down_fp8_vec.push(down_fp8);
1901                down_scale_vec.push(down_scale);
1902            }
1903
1904            // Stack into [num_experts, N, K]
1905            let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1906            let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1907            let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1908            let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1909            let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1910            let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1911
1912            // Create BlockwiseFP8Linear for each projection
1913            let fused_gate_proj =
1914                blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1915            let fused_up_proj =
1916                blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1917            let fused_down_proj =
1918                blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1919
1920            (fused_gate_proj, fused_up_proj, fused_down_proj)
1921        } else if !experts_vb.pp("0").contains_tensor("gate_proj.weight") {
1922            // Handle the case where the layer is dummy (no tensors) during UQFF loading.
1923            // Deserialize will handle it.
1924            let fused_gate_proj: Arc<dyn QuantMethod> =
1925                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
1926            let fused_up_proj: Arc<dyn QuantMethod> =
1927                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
1928            let fused_down_proj: Arc<dyn QuantMethod> =
1929                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
1930            (fused_gate_proj, fused_up_proj, fused_down_proj)
1931        } else {
1932            // Per-expert format: load each expert individually and stack
1933            // When immediate ISQ is active, load on CPU for GGML quantization.
1934            let load_experts_vb =
1935                if crate::get_immediate_isq().is_some() && !experts_vb.device().is_cpu() {
1936                    experts_vb.clone().set_device(Device::Cpu)
1937                } else {
1938                    experts_vb.clone()
1939                };
1940            let mut gate_proj_vec = Vec::new();
1941            let mut up_proj_vec = Vec::new();
1942            let mut down_proj_vec = Vec::new();
1943            for i in 0..num_experts {
1944                let expert_vb = load_experts_vb.pp(i);
1945                let gate_proj =
1946                    expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1947                let up_proj =
1948                    expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1949                let down_proj =
1950                    expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1951
1952                gate_proj_vec.push(gate_proj);
1953                up_proj_vec.push(up_proj);
1954                down_proj_vec.push(down_proj);
1955            }
1956
1957            let mut gate_proj: Arc<dyn QuantMethod> =
1958                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1959                    Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1960                ))?);
1961            let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1962                QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1963            )?);
1964            let mut down_proj: Arc<dyn QuantMethod> =
1965                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1966                    Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1967                ))?);
1968            // Use experts.0.{proj} prefix to match the actual weight paths for ISQ predicate matching
1969            let expert0_vb = experts_vb.pp("0");
1970            gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1971            up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1972            down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
1973
1974            (gate_proj, up_proj, down_proj)
1975        };
1976
1977        Ok(Self {
1978            fused_gate_proj,
1979            fused_up_proj,
1980            fused_down_proj,
1981        })
1982    }
1983}
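
// The stacked-expert paths above all rely on the same shape handling: split the packed
// gate_up_proj along its last dim, then transpose dims 1 and 2 so the per-expert matmul
// sees [num_experts, N, K]. The test below is a minimal sketch of that with toy sizes;
// the module and test names are ours for illustration, not part of the crate's API.
#[cfg(test)]
mod stacked_gate_up_shape_sketch {
    use candle_core::{DType, Device, Result, Tensor};

    #[test]
    fn split_then_transpose_matches_expected_dims() -> Result<()> {
        let (num_experts, hidden_size, intermediate_size) = (2usize, 6usize, 4usize);
        // Convention A layout: [num_experts, hidden_size, intermediate_size * 2].
        let gate_up = Tensor::zeros(
            (num_experts, hidden_size, intermediate_size * 2),
            DType::F32,
            &Device::Cpu,
        )?;
        // Gate occupies the first half of the last dim, Up the second half.
        let gate = gate_up.narrow(2, 0, intermediate_size)?;
        let up = gate_up.narrow(2, intermediate_size, intermediate_size)?;
        // Transpose to the [num_experts, intermediate_size, hidden_size] layout that the
        // fused expert matmul expects, and make the result contiguous.
        let gate = gate.transpose(1, 2)?.contiguous()?;
        let up = up.transpose(1, 2)?.contiguous()?;
        assert_eq!(gate.dims(), &[num_experts, intermediate_size, hidden_size]);
        assert_eq!(up.dims(), &[num_experts, intermediate_size, hidden_size]);
        Ok(())
    }
}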
1984
1985/// Compute the appropriate KV shard. This handles KV head replication. Be sure to use `compute_n_kv_groups` in tandem.
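/// For example (illustrative numbers): with `total_num_kv_heads = 2`, `head_dim = 64`,
/// and a world size of 4, each KV head is replicated twice, so ranks 0 and 1 read the
/// dim-0 slice at offset 0 and ranks 2 and 3 read the slice at offset 64, each of
/// length `head_dim`.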
1986pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
1987    if comm.world_size() == 1 {
1988        return Shard::default();
1989    }
1990
1991    // Tensor parallelism case
1992
1993    // We may need to replicate the kv heads
1994    let kv_replicate = if comm.world_size() > total_num_kv_heads {
1995        comm.world_size() / total_num_kv_heads
1996    } else {
1997        return Shard::Simple {
1998            dim: 0,
1999            rank: comm.rank(),
2000            world_size: comm.world_size(),
2001        };
2002    };
2003
2004    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
2005    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
2006    Shard::Offset {
2007        dim: 0,
2008        offset: kv_shard_id * head_dim,
2009        len: head_dim,
2010    }
2011}
2012
2013/// Compute the number of KV groups, taking into account KV head replication.
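/// For example (illustrative numbers): with 32 attention heads and 8 KV heads, a world
/// size of 1 gives 32 / 8 = 4 groups, while a world size of 16 replicates each KV head
/// twice and gives (32 / 8) / 2 = 2 groups.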
2014pub fn compute_n_kv_groups(
2015    total_num_kv_heads: usize,
2016    num_attention_heads: usize,
2017    comm: &Comm,
2018) -> usize {
2019    let kv_replicate = if comm.world_size() > total_num_kv_heads {
2020        comm.world_size() / total_num_kv_heads
2021    } else {
2022        1
2023    };
2024    if kv_replicate != 0 {
2025        (num_attention_heads / total_num_kv_heads) / kv_replicate
2026    } else {
2027        num_attention_heads / total_num_kv_heads
2028    }
2029}