mistralrs-core 0.8.1

Fast, flexible LLM inference.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
//! Unified MoE experts layer supporting multiple backends and weight formats.
//!
//! This module provides `MoEExperts`, a flexible experts layer that:
//! - Does NOT carry the gate (router) - gate is external
//! - Supports both per-expert and stacked weight formats
//! - Handles backend selection (fused/fast/slow)
//! - Manages tensor parallelism with all-reduce

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;
use mistralrs_quant::{
    apply_immediate_isq, should_apply_immediate_isq, DummyLayer, FusedExperts, MatMul,
    PackedExperts, QuantMethod, QuantMethodConfig, QuantizedConfig, ShardedVarBuilder,
    SumAllReduce, UnquantLinear,
};
use std::sync::Arc;

use crate::cuda::moe;
use crate::layers::Activation;
use crate::moe::shard;

/// Shape parameters needed to build a [`MoEExperts`] layer.
///
/// All sizes are the *unsharded* model dimensions; tensor-parallel sharding
/// (where a loader applies it) is derived from these at load time.
pub struct MoEExpertsConfig {
    /// Total number of experts stored in the checkpoint (`E`).
    pub num_experts: usize,
    /// How many experts each token is routed to (the `k` of top-k).
    pub num_experts_per_tok: usize,
    /// Model hidden dimension (input and output width of every expert).
    pub hidden_size: usize,
    /// Intermediate (gate/up output) width of a single expert MLP.
    pub moe_intermediate_size: usize,
}

/// Which implementation strategy a [`MoEExperts`] layer should use.
///
/// Normally chosen by [`MoEExpertsBackend::select`], but callers may force one
/// through `MoEExperts::new_with_backend`.
pub enum MoEExpertsBackend {
    /// Fused CUDA MoE GEMM kernels over raw stacked tensors (unquantized CUDA).
    Fused,
    /// Gather-based implementation over fused quantizable linears (Metal, ISQ).
    Fast,
    /// Per-expert loop over individual linears (generic fallback).
    Slow,
}

impl MoEExpertsBackend {
    /// Pick the backend best suited to `device` and the active quantization mode.
    ///
    /// - Metal always uses [`MoEExpertsBackend::Fast`].
    /// - CUDA uses `Fast` whenever any quantization is in play (ISQ being loaded,
    ///   a pre-quantized `quantization_config`, or immediate ISQ configured),
    ///   and [`MoEExpertsBackend::Fused`] otherwise.
    /// - Every other device falls back to [`MoEExpertsBackend::Slow`].
    pub fn select(
        device: &Device,
        loading_isq: bool,
        quantization_config: &Option<QuantizedConfig>,
    ) -> Self {
        let immediate_isq_active = mistralrs_quant::get_immediate_isq().is_some();
        let quantized = loading_isq || quantization_config.is_some() || immediate_isq_active;

        if device.is_metal() {
            return Self::Fast;
        }
        if device.is_cuda() {
            // The fused kernels operate on raw unquantized tensors only.
            return if quantized { Self::Fast } else { Self::Fused };
        }
        Self::Slow
    }
}

/// Expert weights laid out for the fused CUDA MoE GEMM kernels.
///
/// All experts are stacked on the leading axis `E`. Per-expert checkpoints are
/// stored as `[E, N, K]`; stacked checkpoints as `[E, K, N]`, which is what
/// `stacked_format` records so the right kernel variant is used.
struct FusedExpertsWeights {
    /// Gate and up projections concatenated per expert (`[gate | up]` along N).
    gate_up_w: Tensor,
    /// Down projection for every expert.
    down_w: Tensor,
    /// Width of one of the gate/up halves, i.e. the (possibly sharded)
    /// intermediate size used to split the first GEMM's output.
    w_size_n: usize,
    /// `true` for the `[E, K, N]` stacked layout, `false` for `[E, N, K]`.
    stacked_format: bool,
}

/// Expert weights for the gather-based backend: one fused (and possibly
/// quantized) layer per projection covering all experts.
struct FastExpertsWeights {
    /// Gate projection for all experts.
    fused_gate_proj: Arc<dyn QuantMethod>,
    /// Up projection for all experts.
    fused_up_proj: Arc<dyn QuantMethod>,
    /// Down projection for all experts.
    fused_down_proj: Arc<dyn QuantMethod>,
}

/// Expert weights for the loop-based backend: individual per-expert layers.
struct SlowExpertsWeights {
    /// Per-expert gate/up/down layers.
    experts: PackedExperts,
}

/// Mixture-of-experts expert layer, *without* the routing gate.
///
/// Owns the expert weights in whichever backend layout was chosen at load
/// time and runs the expert computation. Routing is the caller's job: it must
/// supply each token's top-k expert ids and routing weights to
/// [`MoEExperts::forward`].
pub struct MoEExperts {
    /// Backend-specific weight storage; selects the forward implementation.
    backend: MoEExpertsBackendImpl,
    /// Activation used inside each expert MLP.
    act: Activation,
    /// Experts selected per token (the `k` of top-k).
    num_experts_per_tok: usize,
    /// Sums partial expert outputs across tensor-parallel ranks.
    all_reduce: SumAllReduce,
    /// Tensor-parallel world size; the all-reduce is skipped when it is 1.
    world_size: usize,
}

/// Loaded weights, tagged by the backend that will consume them.
enum MoEExpertsBackendImpl {
    /// Raw stacked tensors for the fused CUDA kernels.
    Fused(FusedExpertsWeights),
    /// Fused per-projection layers for the gather-based path.
    Fast(FastExpertsWeights),
    /// Per-expert layers for the loop-based path.
    Slow(SlowExpertsWeights),
}

impl MoEExperts {
    /// Build an experts layer, choosing the backend automatically.
    ///
    /// The backend is picked by [`MoEExpertsBackend::select`] from
    /// `layer_device`, `loading_isq` and `quantization_config`. Loading then
    /// proceeds as in [`MoEExperts::new_with_backend`], which also detects
    /// whether the checkpoint stores stacked or per-expert weights.
    pub fn new(
        cfg: &MoEExpertsConfig,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
        loading_isq: bool,
        quantization_config: &Option<QuantizedConfig>,
        act: Activation,
    ) -> Result<Self> {
        // Decide before `layer_device` is moved into the loader.
        let chosen = MoEExpertsBackend::select(&layer_device, loading_isq, quantization_config);
        Self::new_with_backend(cfg, vb, layer_device, comm, chosen, quantization_config, act)
    }

    /// Build an experts layer using an explicitly chosen `backend`.
    ///
    /// Expert weights are looked up under `vb.pp("experts")`, moved to
    /// `layer_device`, for the Fused and Slow backends. The Fast backend is
    /// handed `vb` itself. The checkpoint is treated as stacked when
    /// `experts.gate_up_proj` exists and as per-expert otherwise.
    pub fn new_with_backend(
        cfg: &MoEExpertsConfig,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
        backend: MoEExpertsBackend,
        quantization_config: &Option<QuantizedConfig>,
        act: Activation,
    ) -> Result<Self> {
        let experts_vb = vb.pp("experts").set_device(layer_device);
        // Stacked checkpoints expose one `gate_up_proj`; per-expert ones use `{i}.gate_proj`.
        let stacked = experts_vb.contains_tensor("gate_up_proj");

        let weights = match (backend, stacked) {
            (MoEExpertsBackend::Fused, true) => {
                MoEExpertsBackendImpl::Fused(Self::load_fused_stacked(cfg, experts_vb, comm)?)
            }
            (MoEExpertsBackend::Fused, false) => {
                MoEExpertsBackendImpl::Fused(Self::load_fused_standard(cfg, experts_vb, comm)?)
            }
            (MoEExpertsBackend::Fast, true) => MoEExpertsBackendImpl::Fast(
                Self::load_fast_stacked(cfg, vb, quantization_config)?,
            ),
            (MoEExpertsBackend::Fast, false) => MoEExpertsBackendImpl::Fast(
                Self::load_fast_standard(cfg, vb, quantization_config)?,
            ),
            (MoEExpertsBackend::Slow, _) => MoEExpertsBackendImpl::Slow(Self::load_slow(
                cfg,
                experts_vb,
                comm,
                quantization_config,
            )?),
        };

        Ok(Self {
            backend: weights,
            act,
            num_experts_per_tok: cfg.num_experts_per_tok,
            all_reduce: SumAllReduce::new(comm),
            world_size: comm.world_size(),
        })
    }

    /// Create MoEExperts from a VarBuilder already at the experts level.
    ///
    /// Unlike `new` which does `vb.pp("experts")` internally, this takes the VB
    /// already pointing at the experts-level path. Use this when the model's weight
    /// structure doesn't have an "experts" sublevel (e.g., Gemma 4 uses `moe.*` directly).
    ///
    /// Supports two weight formats:
    /// - Combined stacked: `gate_up_proj` [E, hidden, 2*inter]
    /// - Per-expert: `{i}/gate_proj/weight` [inter, hidden]
    pub fn new_direct(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
        loading_isq: bool,
        quantization_config: &Option<QuantizedConfig>,
        act: Activation,
    ) -> Result<Self> {
        let layer_device = experts_vb.device().clone();
        let backend = MoEExpertsBackend::select(&layer_device, loading_isq, quantization_config);

        let is_stacked_combined = experts_vb.contains_tensor("gate_up_proj");

        let backend_impl = match backend {
            MoEExpertsBackend::Fused => {
                if is_stacked_combined {
                    MoEExpertsBackendImpl::Fused(Self::load_fused_stacked(cfg, experts_vb, comm)?)
                } else {
                    MoEExpertsBackendImpl::Fused(Self::load_fused_standard(cfg, experts_vb, comm)?)
                }
            }
            MoEExpertsBackend::Fast => {
                if is_stacked_combined && quantization_config.is_none() {
                    MoEExpertsBackendImpl::Fast(Self::load_fast_combined_stacked(cfg, experts_vb)?)
                } else if is_stacked_combined {
                    MoEExpertsBackendImpl::Slow(Self::load_slow_from_combined_stacked(
                        cfg, experts_vb,
                    )?)
                } else {
                    MoEExpertsBackendImpl::Fast(Self::load_fast_direct_standard(cfg, experts_vb)?)
                }
            }
            MoEExpertsBackend::Slow => {
                if is_stacked_combined {
                    MoEExpertsBackendImpl::Slow(Self::load_slow_from_combined_stacked(
                        cfg, experts_vb,
                    )?)
                } else {
                    MoEExpertsBackendImpl::Slow(Self::load_slow(
                        cfg,
                        experts_vb,
                        comm,
                        quantization_config,
                    )?)
                }
            }
        };

        Ok(Self {
            backend: backend_impl,
            act,
            num_experts_per_tok: cfg.num_experts_per_tok,
            all_reduce: SumAllReduce::new(comm),
            world_size: comm.world_size(),
        })
    }

    /// Load fused weights in standard per-expert format
    fn load_fused_standard(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<FusedExpertsWeights> {
        let num_experts = cfg.num_experts;
        let mut gate_up_experts = Vec::with_capacity(num_experts);
        let mut down_experts = Vec::with_capacity(num_experts);

        for i in 0..num_experts {
            let expert_vb = experts_vb.pp(i.to_string());
            // n x k format
            let gate_expert = expert_vb.pp("gate_proj").get_with_hints(
                (cfg.moe_intermediate_size, cfg.hidden_size),
                "weight",
                shard(0, comm.rank(), comm.world_size()),
            )?;
            let up_expert = expert_vb.pp("up_proj").get_with_hints(
                (cfg.moe_intermediate_size, cfg.hidden_size),
                "weight",
                shard(0, comm.rank(), comm.world_size()),
            )?;
            let down_expert = expert_vb.pp("down_proj").get_with_hints(
                (cfg.hidden_size, cfg.moe_intermediate_size),
                "weight",
                shard(1, comm.rank(), comm.world_size()),
            )?;
            // Pack gate_proj and up_proj
            let gate_up_expert = Tensor::cat(&[&gate_expert, &up_expert], 0)?;

            gate_up_experts.push(gate_up_expert);
            down_experts.push(down_expert);
        }

        let gate_up_w = Tensor::stack(&gate_up_experts, 0)?;
        let down_w = Tensor::stack(&down_experts, 0)?;
        let w_size_n = gate_up_w.dim(1)? / 2;

        Ok(FusedExpertsWeights {
            gate_up_w,
            down_w,
            w_size_n,
            stacked_format: false,
        })
    }

    /// Load fused weights in stacked format
    fn load_fused_stacked(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<FusedExpertsWeights> {
        let num_experts = cfg.num_experts;

        // Stacked format has two conventions:
        // Convention A: [num_experts, hidden, inter*2] (CUDA kernel format)
        // Convention B (nn.Linear): [num_experts, inter*2, hidden]
        // Try A first, fall back to B with transpose.
        let gate_up_w = experts_vb
            .get_with_hints(
                (num_experts, cfg.hidden_size, cfg.moe_intermediate_size * 2),
                "gate_up_proj",
                shard(2, comm.rank(), comm.world_size()),
            )
            .or_else(|_| {
                experts_vb
                    .get_with_hints(
                        (num_experts, cfg.moe_intermediate_size * 2, cfg.hidden_size),
                        "gate_up_proj",
                        shard(1, comm.rank(), comm.world_size()),
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;

        let down_w = experts_vb
            .get_with_hints(
                (num_experts, cfg.moe_intermediate_size, cfg.hidden_size),
                "down_proj",
                shard(1, comm.rank(), comm.world_size()),
            )
            .or_else(|_| {
                experts_vb
                    .get_with_hints(
                        (num_experts, cfg.hidden_size, cfg.moe_intermediate_size),
                        "down_proj",
                        shard(2, comm.rank(), comm.world_size()),
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;

        let w_size_n = gate_up_w.dim(2)? / 2;

        Ok(FusedExpertsWeights {
            gate_up_w,
            down_w,
            w_size_n,
            stacked_format: true,
        })
    }

    /// Load fast (gather-based) weights from a combined stacked `gate_up_proj`.
    ///
    /// Supports:
    /// - `gate_up_proj`: [E, hidden, 2*inter] or [E, 2*inter, hidden]
    /// - `down_proj`: [E, inter, hidden] or [E, hidden, inter]
    ///
    /// The combined tensor is split into per-projection stacks of shape
    /// `[E, inter, hidden]` (gate, up) and `[E, hidden, inter]` (down). These are
    /// wrapped as unquantized linears and then passed through immediate ISQ.
    ///
    /// When immediate ISQ targets at least one projection, the weights are staged
    /// on the CPU. Projections that ISQ will not quantize are explicitly moved
    /// back to the layer's device. Without this, a partial ISQ match (e.g.
    /// `gate_up_proj` but not `down_proj`) would leave an unquantized weight on
    /// the CPU.
    fn load_fast_combined_stacked(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
    ) -> Result<FastExpertsWeights> {
        let num_experts = cfg.num_experts;

        let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
        let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));

        // Device the finished layers must end up on (the builder's own device).
        let target_device = experts_vb.device().clone();

        // When immediate ISQ is active, load directly on CPU to avoid creating
        // large GPU buffers that will be immediately copied to CPU for quantization.
        // On unified memory systems (Metal), this prevents doubling memory usage.
        let load_vb = if (isq_gate_up || isq_down) && !target_device.is_cpu() {
            experts_vb.clone().set_device(Device::Cpu)
        } else {
            experts_vb.clone()
        };

        let gate_up_proj = load_vb
            .get(
                (num_experts, cfg.hidden_size, cfg.moe_intermediate_size * 2),
                "gate_up_proj",
            )
            .or_else(|_| {
                load_vb
                    .get(
                        (num_experts, cfg.moe_intermediate_size * 2, cfg.hidden_size),
                        "gate_up_proj",
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;
        let down_proj_packed = load_vb
            .get(
                (num_experts, cfg.moe_intermediate_size, cfg.hidden_size),
                "down_proj",
            )
            .or_else(|_| {
                load_vb
                    .get(
                        (num_experts, cfg.hidden_size, cfg.moe_intermediate_size),
                        "down_proj",
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;

        // Only projections that immediate ISQ will quantize may stay on the CPU
        // staging device; `to_device` is a no-op when nothing was staged.
        let gate_up_proj = if isq_gate_up {
            gate_up_proj
        } else {
            gate_up_proj.to_device(&target_device)?
        };
        let down_proj_packed = if isq_down {
            down_proj_packed
        } else {
            down_proj_packed.to_device(&target_device)?
        };

        let gate_proj = gate_up_proj
            .narrow(2, 0, cfg.moe_intermediate_size)?
            .transpose(1, 2)?
            .contiguous()?;
        let up_proj = gate_up_proj
            .narrow(2, cfg.moe_intermediate_size, cfg.moe_intermediate_size)?
            .transpose(1, 2)?
            .contiguous()?;
        // Drop gate_up_proj early to free memory before creating more tensors
        drop(gate_up_proj);
        let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
        drop(down_proj_packed);

        let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
        )?);
        let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
        )?);
        let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
        )?);

        // Pass the original-device VB (not CPU) so apply_immediate_isq targets
        // the correct device for the quantized weights.
        let vb_gate_up = experts_vb.pp("gate_up_proj");
        let vb_down = experts_vb.pp("down_proj");
        fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
        fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
        fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;

        Ok(FastExpertsWeights {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }

    /// Load fast (gather-based) weights in per-expert format from a VB already
    /// at the experts level (no `.pp("experts")` applied).
    ///
    /// Handles both real per-expert weights and UQFF dummy layers.
    fn load_fast_direct_standard(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
    ) -> Result<FastExpertsWeights> {
        let num_experts = cfg.num_experts;

        // UQFF loading: experts have no real tensors yet, create dummy layers
        // that will be replaced during deserialization.
        if !experts_vb.pp("0").contains_tensor("gate_proj.weight") {
            let fused_gate_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            let fused_up_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            let fused_down_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            return Ok(FastExpertsWeights {
                fused_gate_proj,
                fused_up_proj,
                fused_down_proj,
            });
        }

        // Real per-expert weights: load, stack, and optionally ISQ
        let load_experts_vb =
            if mistralrs_quant::get_immediate_isq().is_some() && !experts_vb.device().is_cpu() {
                experts_vb.clone().set_device(Device::Cpu)
            } else {
                experts_vb.clone()
            };

        let mut gate_proj_vec = Vec::with_capacity(num_experts);
        let mut up_proj_vec = Vec::with_capacity(num_experts);
        let mut down_proj_vec = Vec::with_capacity(num_experts);

        for i in 0..num_experts {
            let expert_vb = load_experts_vb.pp(i.to_string());
            gate_proj_vec.push(expert_vb.get(
                (cfg.moe_intermediate_size, cfg.hidden_size),
                "gate_proj.weight",
            )?);
            up_proj_vec.push(expert_vb.get(
                (cfg.moe_intermediate_size, cfg.hidden_size),
                "up_proj.weight",
            )?);
            down_proj_vec.push(expert_vb.get(
                (cfg.hidden_size, cfg.moe_intermediate_size),
                "down_proj.weight",
            )?);
        }

        let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None)),
        )?);
        let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
        )?);
        let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
            QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&down_proj_vec, 0)?, None)),
        )?);

        let expert0_vb = experts_vb.pp("0");
        fused_gate_proj = apply_immediate_isq(fused_gate_proj, expert0_vb.pp("gate_proj"))?;
        fused_up_proj = apply_immediate_isq(fused_up_proj, expert0_vb.pp("up_proj"))?;
        fused_down_proj = apply_immediate_isq(fused_down_proj, expert0_vb.pp("down_proj"))?;

        Ok(FastExpertsWeights {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }

    /// Load slow (loop-based) weights from a combined stacked `gate_up_proj`.
    ///
    /// Supports both direct stacked conventions used by Gemma4 checkpoints:
    /// - `gate_up_proj`: [E, hidden, 2*inter] or [E, 2*inter, hidden]
    /// - `down_proj`: [E, inter, hidden] or [E, hidden, inter]
    ///
    /// Each expert is sliced into its own gate `[inter, hidden]`, up
    /// `[inter, hidden]` and down `[hidden, inter]` unquantized linears, which
    /// are then passed through immediate ISQ.
    ///
    /// When immediate ISQ targets at least one projection, the packed tensors
    /// are staged on the CPU. A packed tensor that ISQ will not quantize is moved
    /// to the layer's device before slicing, so a partial ISQ match does not
    /// leave unquantized experts on the CPU.
    fn load_slow_from_combined_stacked(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
    ) -> Result<SlowExpertsWeights> {
        let num_experts = cfg.num_experts;

        let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
        let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));

        // Device the finished layers must end up on (the builder's own device).
        let target_device = experts_vb.device().clone();

        // When immediate ISQ is active, load directly on CPU to avoid creating
        // large GPU buffers that will be immediately copied to CPU for quantization.
        let load_vb = if (isq_gate_up || isq_down) && !target_device.is_cpu() {
            experts_vb.clone().set_device(Device::Cpu)
        } else {
            experts_vb.clone()
        };

        let gate_up_proj = load_vb
            .get(
                (num_experts, cfg.hidden_size, cfg.moe_intermediate_size * 2),
                "gate_up_proj",
            )
            .or_else(|_| {
                load_vb
                    .get(
                        (num_experts, cfg.moe_intermediate_size * 2, cfg.hidden_size),
                        "gate_up_proj",
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;
        let down_proj_packed = load_vb
            .get(
                (num_experts, cfg.moe_intermediate_size, cfg.hidden_size),
                "down_proj",
            )
            .or_else(|_| {
                load_vb
                    .get(
                        (num_experts, cfg.hidden_size, cfg.moe_intermediate_size),
                        "down_proj",
                    )
                    .and_then(|t| t.transpose(1, 2)?.contiguous())
            })?;

        // Only tensors that immediate ISQ will quantize may stay on the CPU
        // staging device; `to_device` is a no-op when nothing was staged.
        let gate_up_proj = if isq_gate_up {
            gate_up_proj
        } else {
            gate_up_proj.to_device(&target_device)?
        };
        let down_proj_packed = if isq_down {
            down_proj_packed
        } else {
            down_proj_packed.to_device(&target_device)?
        };

        // Pass the original-device VB (not CPU) so apply_immediate_isq targets
        // the correct device for the quantized weights.
        let vb_gate_up = experts_vb.pp("gate_up_proj");
        let vb_down = experts_vb.pp("down_proj");

        let mut gate_proj = Vec::with_capacity(num_experts);
        let mut up_proj = Vec::with_capacity(num_experts);
        let mut down_proj = Vec::with_capacity(num_experts);

        for i in 0..num_experts {
            let gate_up_expert = gate_up_proj.i(i)?;
            let gate = gate_up_expert
                .narrow(1, 0, cfg.moe_intermediate_size)?
                .transpose(0, 1)?
                .contiguous()?;
            let up = gate_up_expert
                .narrow(1, cfg.moe_intermediate_size, cfg.moe_intermediate_size)?
                .transpose(0, 1)?
                .contiguous()?;
            let down = down_proj_packed.i(i)?.transpose(0, 1)?.contiguous()?;

            let mut gate_layer: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(candle_nn::Linear::new(gate, None)),
            )?);
            let mut up_layer: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(candle_nn::Linear::new(up, None)),
            )?);
            let mut down_layer: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(candle_nn::Linear::new(down, None)),
            )?);

            gate_layer = apply_immediate_isq(gate_layer, vb_gate_up.clone())?;
            up_layer = apply_immediate_isq(up_layer, vb_gate_up.clone())?;
            down_layer = apply_immediate_isq(down_layer, vb_down.clone())?;

            gate_proj.push(gate_layer);
            up_proj.push(up_layer);
            down_proj.push(down_layer);
        }

        Ok(SlowExpertsWeights {
            experts: PackedExperts {
                gate_proj,
                up_proj,
                down_proj,
            },
        })
    }

    /// Load gather-backend weights from a per-expert checkpoint.
    ///
    /// Delegates to `FusedExperts::new`, which receives the caller's `vb`
    /// together with the shape parameters and `quantization_config`, and
    /// repackages its three fused projections.
    fn load_fast_standard(
        cfg: &MoEExpertsConfig,
        vb: ShardedVarBuilder,
        quantization_config: &Option<QuantizedConfig>,
    ) -> Result<FastExpertsWeights> {
        let loaded = FusedExperts::new(
            cfg.hidden_size,
            cfg.moe_intermediate_size,
            cfg.num_experts,
            quantization_config,
            vb,
        )?;

        Ok(FastExpertsWeights {
            fused_gate_proj: loaded.fused_gate_proj,
            fused_up_proj: loaded.fused_up_proj,
            fused_down_proj: loaded.fused_down_proj,
        })
    }

    /// Load fast (gather-based) weights in stacked format
    fn load_fast_stacked(
        cfg: &MoEExpertsConfig,
        vb: ShardedVarBuilder,
        quantization_config: &Option<QuantizedConfig>,
    ) -> Result<FastExpertsWeights> {
        // FusedExperts auto-detects stacked format
        let FusedExperts {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        } = FusedExperts::new(
            cfg.hidden_size,
            cfg.moe_intermediate_size,
            cfg.num_experts,
            quantization_config,
            vb,
        )?;

        Ok(FastExpertsWeights {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }

    /// Load loop-backend weights as individual per-expert layers.
    ///
    /// Thin wrapper over `PackedExperts::new`, which is given the expert count,
    /// shape parameters, `quantization_config`, the communicator (for sharding)
    /// and `experts_vb`.
    fn load_slow(
        cfg: &MoEExpertsConfig,
        experts_vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
        quantization_config: &Option<QuantizedConfig>,
    ) -> Result<SlowExpertsWeights> {
        PackedExperts::new(
            cfg.num_experts,
            cfg.hidden_size,
            cfg.moe_intermediate_size,
            quantization_config,
            // NOTE(review): meaning of this flag is defined by PackedExperts::new.
            false,
            comm,
            experts_vb,
        )
        .map(|experts| SlowExpertsWeights { experts })
    }

    /// Run the experts on already-routed tokens.
    ///
    /// # Arguments
    /// * `xs` - Input tensor of shape [batch, seq_len, hidden_dim]
    /// * `topk_weights` - Top-k routing weights of shape [num_tokens, num_experts_per_tok]
    /// * `topk_ids` - Top-k expert indices of shape [num_tokens, num_experts_per_tok]
    ///
    /// Dispatches to the backend chosen at load time. When `world_size > 1`,
    /// partial results are summed across ranks.
    ///
    /// # Returns
    /// Output tensor of shape [batch, seq_len, hidden_dim]
    pub fn forward(&self, xs: &Tensor, topk_weights: Tensor, topk_ids: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        // More than one position means prefill; a single position is a decode step.
        let is_prefill = seq_len > 1;

        let partial = match &self.backend {
            MoEExpertsBackendImpl::Fused(w) => {
                self.forward_fused(xs, &topk_weights, topk_ids, w, is_prefill)?
            }
            MoEExpertsBackendImpl::Fast(w) => self.forward_fast(xs, &topk_weights, topk_ids, w)?,
            MoEExpertsBackendImpl::Slow(w) => self.forward_slow(xs, &topk_weights, topk_ids, w)?,
        };

        let combined = if self.world_size > 1 {
            self.all_reduce.sum_all_reduce(&partial)?
        } else {
            partial
        };

        combined.reshape((b_size, seq_len, hidden_dim))
    }

    /// Forward pass using the fused CUDA MoE GEMM kernels.
    ///
    /// Token-to-expert assignments are sorted so the kernels can process the
    /// tokens of each expert together. The first GEMM produces `[gate | up]`
    /// for every assignment. The activation is applied to `gate` and multiplied
    /// into `up`. The second GEMM applies the down projection, with
    /// `topk_weights` passed to the kernel (presumably to scale each expert's
    /// contribution). The per-assignment rows are finally summed per token.
    ///
    /// Returns a `[num_tokens, hidden_dim]` tensor.
    fn forward_fused(
        &self,
        xs: &Tensor,
        topk_weights: &Tensor,
        topk_ids: &Tensor,
        weights: &FusedExpertsWeights,
        is_prefill: bool,
    ) -> Result<Tensor> {
        let (_, _, hidden_dim) = xs.dims3()?;
        let tokens = xs.reshape(((), hidden_dim))?;
        let (num_tokens, _) = tokens.dims2()?;

        // Sort the flattened expert ids so assignments to the same expert are
        // adjacent; the second result holds the sorting permutation.
        let flat_ids = topk_ids.flatten_all()?;
        let (expert_ids, sorted_token_ids) = if is_prefill {
            #[cfg(feature = "cuda")]
            {
                use crate::ops::ArgSortOp;
                flat_ids.sort(true)?
            }
            #[cfg(not(feature = "cuda"))]
            flat_ids.sort_last_dim(true)?
        } else {
            flat_ids.sort_last_dim(true)?
        };

        // Both GEMMs pick the kernel variant that matches the weight layout.
        let grouped_gemm = |input: &Tensor, w: &Tensor, scale: &Option<Tensor>| {
            if weights.stacked_format {
                moe::moe_gemm_transposed(
                    input,
                    w,
                    scale,
                    &sorted_token_ids,
                    &expert_ids,
                    self.num_experts_per_tok,
                    is_prefill,
                )
            } else {
                moe::moe_gemm(
                    input,
                    w,
                    scale,
                    &sorted_token_ids,
                    &expert_ids,
                    self.num_experts_per_tok,
                    is_prefill,
                )
            }
        };

        let gate_up = grouped_gemm(&tokens, &weights.gate_up_w, &None)?;

        let half = weights.w_size_n;
        let gate = gate_up.narrow(D::Minus1, 0, half)?.contiguous()?;
        let up = gate_up.narrow(D::Minus1, half, half)?.contiguous()?;
        let activated = gate.apply(&self.act)?;
        let down_inputs = (up * activated)?.reshape(((), half))?;

        let ys = grouped_gemm(
            &down_inputs,
            &weights.down_w,
            &Some(topk_weights.clone()),
        )?;

        ys.reshape((num_tokens, (), hidden_dim))?.sum(D::Minus2)
    }

    /// Gather-based forward pass (Metal/ISQ).
    ///
    /// Applies the fused gate/up/down projections with `gather_forward_autocast`, indexed by
    /// `topk_ids`. Each selected expert's output is then weighted by `topk_weights` and summed
    /// over the expert axis. The weighting and reduction happen in F32, and the result is cast
    /// back to `xs`'s dtype.
    ///
    /// Gather input layouts by device:
    /// - CUDA: `(num_tokens, 1, hidden_dim)` inputs with `topk_ids` as-is.
    /// - Other devices: `(b_size, seq_len, 1, 1, hidden_dim)` inputs with
    ///   `(b_size, seq_len, num_experts_per_tok)` indices.
    ///
    /// On the non-CUDA path the per-expert outputs are reshaped to
    /// `(num_tokens, num_experts_per_tok, hidden_dim)` before the reduction.
    fn forward_fast(
        &self,
        xs: &Tensor,
        topk_weights: &Tensor,
        topk_ids: &Tensor,
        weights: &FastExpertsWeights,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let num_tokens = b_size * seq_len;

        let ys = if xs.device().is_cuda() {
            // CUDA path: use indexed_moe_forward compatible shapes
            let xs = xs.reshape((num_tokens, 1, hidden_dim))?;
            let gate = weights
                .fused_gate_proj
                .gather_forward_autocast(&xs, topk_ids)?;
            let up = weights
                .fused_up_proj
                .gather_forward_autocast(&xs, topk_ids)?;
            weights
                .fused_down_proj
                .gather_forward_autocast(&(up * gate.apply(&self.act)?)?, topk_ids)?
        } else {
            // Metal path: use broadcast gather shapes
            let xs = xs.reshape((b_size, seq_len, 1, 1, hidden_dim))?;
            let indices = topk_ids.reshape((b_size, seq_len, self.num_experts_per_tok))?;
            let gate = weights
                .fused_gate_proj
                .gather_forward_autocast(&xs, &indices)?;
            let up = weights
                .fused_up_proj
                .gather_forward_autocast(&xs, &indices)?;
            let xs = weights
                .fused_down_proj
                .gather_forward_autocast(&(up * gate.apply(&self.act)?)?, &indices)?;
            xs.squeeze(D::Minus2)?
                .reshape((num_tokens, self.num_experts_per_tok, hidden_dim))?
        };

        // `broadcast_mul` requires both operands to share a dtype. The routing weights are
        // cast to F32 together with `ys`, so lower-precision routing weights do not cause a
        // dtype-mismatch error. When they are already F32 this cast is cheap.
        let routing_weights = topk_weights.to_dtype(DType::F32)?.unsqueeze(D::Minus1)?;
        ys.to_dtype(DType::F32)?
            .broadcast_mul(&routing_weights)?
            .sum(D::Minus2)?
            .to_dtype(original_dtype)
    }

    /// Loop-based forward pass (quantized fallback).
    ///
    /// Steps:
    /// 1. Copy the routing weights (as F32) and expert ids into host `Vec`s.
    /// 2. Group token rows by the expert each row was routed to.
    /// 3. For every expert with at least one row, run `down(act(gate(x)) * up(x))` on those rows.
    /// 4. Scale each result by its routing weight and scatter-add it into a zero tensor with
    ///    `index_add`.
    ///
    /// If an expert's gate projection reports a `quantized_act_type`, the expert input is cast
    /// to that dtype and the output is cast back to the input's original dtype.
    ///
    /// Returns a `(b_size * seq_len, hidden_dim)` tensor.
    fn forward_slow(
        &self,
        xs: &Tensor,
        topk_weights: &Tensor,
        topk_ids: &Tensor,
        weights: &SlowExpertsWeights,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let flat = xs.reshape(((), hidden_dim))?;

        let weights_by_row = topk_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
        let ids_by_row = topk_ids.to_vec2::<u32>()?;
        let experts = &weights.experts;
        let expert_count = experts.gate_proj.len();

        // Bucket every (row, routing weight) pair under the expert it was routed to.
        let mut rows_per_expert: Vec<Vec<u32>> = vec![Vec::new(); expert_count];
        let mut scales_per_expert: Vec<Vec<f32>> = vec![Vec::new(); expert_count];
        for (row, (row_weights, row_ids)) in weights_by_row.iter().zip(&ids_by_row).enumerate() {
            for (&scale, &id) in row_weights.iter().zip(row_ids) {
                let slot = id as usize;
                #[allow(clippy::cast_possible_truncation)]
                rows_per_expert[slot].push(row as u32);
                scales_per_expert[slot].push(scale);
            }
        }

        let mut out = flat.zeros_like()?;
        for (idx, (rows, scales)) in rows_per_expert.iter().zip(&scales_per_expert).enumerate() {
            if rows.is_empty() {
                continue;
            }
            let row_index = Tensor::new(rows.as_slice(), flat.device())?;
            let scale = Tensor::new(scales.as_slice(), flat.device())?
                .reshape(((), 1))?
                .to_dtype(flat.dtype())?;
            let picked = flat.index_select(&row_index, 0)?.reshape(((), hidden_dim))?;
            let input_dtype = picked.dtype();

            // Some quantized experts want their activations in a specific dtype; remember
            // whether we converted so the output can be restored afterwards.
            let (expert_in, cast_back) = match experts.gate_proj[idx].quantized_act_type() {
                Some(act_dtype) => (picked.to_dtype(act_dtype)?, true),
                None => (picked, false),
            };

            let gated = MatMul
                .qmethod_matmul(&expert_in, &*experts.gate_proj[idx])?
                .apply(&self.act)?;
            let projected_up = MatMul.qmethod_matmul(&expert_in, &*experts.up_proj[idx])?;
            let mut expert_out =
                MatMul.qmethod_matmul(&(gated * projected_up)?, &*experts.down_proj[idx])?;
            if cast_back {
                expert_out = expert_out.to_dtype(input_dtype)?;
            }

            let weighted = expert_out.broadcast_mul(&scale)?;
            out = out.index_add(&row_index, &weighted, 0)?;
        }

        out.reshape((b_size * seq_len, hidden_dim))
    }

    /// Collects mutable handles to every ISQ-quantizable layer owned by this backend.
    ///
    /// What each backend exposes:
    /// - Fused: nothing.
    /// - Fast: its three fused projections (gate, up, down).
    /// - Slow: every expert's gate, up and down projections, grouped per expert in that order.
    ///
    /// Callers rely on the returned length matching `num_isq_layers`.
    pub fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        match &mut self.backend {
            MoEExpertsBackendImpl::Fused(_) => Vec::new(),
            MoEExpertsBackendImpl::Fast(fast) => vec![
                &mut fast.fused_gate_proj,
                &mut fast.fused_up_proj,
                &mut fast.fused_down_proj,
            ],
            MoEExpertsBackendImpl::Slow(slow) => slow
                .experts
                .gate_proj
                .iter_mut()
                .zip(slow.experts.up_proj.iter_mut())
                .zip(slow.experts.down_proj.iter_mut())
                .flat_map(|((gate, up), down)| [gate, up, down])
                .collect(),
        }
    }

    /// Reports how many layers `get_isq_layers` hands out.
    ///
    /// Per backend:
    /// - Fused: zero.
    /// - Fast: the three fused projections.
    /// - Slow: three projections (gate, up, down) for each entry in `gate_proj`.
    pub fn num_isq_layers(&self) -> usize {
        const PROJECTIONS: usize = 3;
        match &self.backend {
            MoEExpertsBackendImpl::Fused(_) => 0,
            MoEExpertsBackendImpl::Fast(_) => PROJECTIONS,
            MoEExpertsBackendImpl::Slow(slow) => PROJECTIONS * slow.experts.gate_proj.len(),
        }
    }
}