use std::sync::Arc;

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use candle_nn::Linear;

use crate::{
    blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
    distributed,
    gptq::gptq_linear,
    lora::merge_lora_weights,
    pertensor_fp8::pertensor_fp8_linear_b,
    should_apply_immediate_isq,
    utils::isq::apply_immediate_isq,
    AfqLayer, BnbLinear, DistributedKind, DummyLayer, F8Q8Linear, FP8Linear, GgufMatMul, HqqLayer,
    MXFP4Layer, QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
    QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
};

use super::{Comm, SumAllReduce};

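/// Build a [`Shard::Simple`]: rank `rank` of `world_size` takes the `rank`-th
/// of `world_size` equal slices along dimension `dim`.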
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}

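/// A row-parallel linear layer for tensor parallelism.
///
/// The weight is sharded along the input dimension (dim 1 of the
/// `(out_dim, in_dim)` weight), so each rank computes a partial matmul and the
/// partial results are summed with an all-reduce in `forward`. The bias is
/// replicated and added once, after the reduction.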
#[derive(Debug)]
pub struct RowParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
    all_reduce: distributed::SumAllReduce,
}

impl RowParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

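    /// Like [`RowParallelLayer::new`], but for Matformer-style slicing: the
    /// full weight of input width `orig_intermediate_size` is loaded, truncated
    /// to its first `in_dim` columns, and only then sharded across ranks.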
    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(1, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (out_dim, orig_intermediate_size),
                    "weight",
                    Default::default(),
                )?
                .i((.., ..in_dim))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get((out_dim,), "bias")?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self {
            weight,
            bias,
            all_reduce: distributed::SumAllReduce::new(comm),
        });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for RowParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: self.all_reduce.clone(),
        }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::RowParallel)
    }
}

impl QuantizedSerde for RowParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self {
            weight,
            bias,
            all_reduce: SumAllReduce::new(comm),
        }))
    }
}

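/// A column-parallel linear layer for tensor parallelism.
///
/// The weight (and bias) are sharded along the output dimension (dim 0), so
/// each rank produces a disjoint slice of the output features and `forward`
/// needs no reduction. Column-parallel layers are typically paired with a
/// following [`RowParallelLayer`], which consumes the sliced activations and
/// performs the single all-reduce.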
#[derive(Debug)]
pub struct ColumnParallelLayer {
    weight: Arc<dyn QuantMethod>,
    bias: Option<Tensor>,
}

impl ColumnParallelLayer {
    #[allow(clippy::new_ret_no_self)]
    pub fn new_with_shard(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        shard: Shard,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let weight = if let Some(quant_conf) = &config {
            if matches!(
                quant_conf,
                QuantizedConfig::GptqAwq { .. }
                    | QuantizedConfig::Bitsandbytes { .. }
                    | QuantizedConfig::Afq { .. }
            ) && comm.world_size() != 1
            {
                candle_core::bail!(
                    "GPTQ/AWQ, BNB, and AFQ quantization types do not support tensor parallelism, but got a world size of {}",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            false,
                            shard,
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, None),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
    }

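    /// Like [`ColumnParallelLayer::new`], but for Matformer-style slicing: the
    /// full weight of output height `orig_intermediate_size` is loaded,
    /// truncated to its first `out_dim` rows, and only then sharded across
    /// ranks.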
    #[allow(clippy::new_ret_no_self)]
    pub fn new_matformer(
        in_dim: usize,
        out_dim: usize,
        orig_intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let rank = comm.rank();
        let world_size = comm.world_size();
        let shard = shard(0, rank, world_size);

        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        if config.is_some() {
            candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
        }

        let weight = if !vb.contains_tensor("weight") {
            let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        } else {
            let weight = vb
                .get_with_hints(
                    (orig_intermediate_size, in_dim),
                    "weight",
                    Default::default(),
                )?
                .i((..out_dim, ..))?
                .contiguous()?;

            let weight = shard.apply_to(&weight)?;
            let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;

            let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                Linear::new(weight, None),
            ))?;
            Arc::new(layer) as Arc<dyn QuantMethod>
        };

        let bias = if bias && vb.contains_tensor("bias") {
            Some(vb.get_with_hints((out_dim,), "bias", shard)?)
        } else {
            None
        };

        let this_unquant = Arc::new(Self { weight, bias });
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

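    /// Load `chunks` column-parallel layers from one merged weight (e.g. a
    /// fused QKV or gate/up projection). The merged dim 0 is split into
    /// `chunks * world_size` slices; chunk `i` on rank `r` takes slice
    /// `i * world_size + r`, i.e. the rank-`r` piece of chunk `i`.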
    pub fn new_merged(
        in_dim: usize,
        out_dim: usize,
        chunks: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Vec<Arc<dyn QuantMethod>>> {
        let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
        for chunk_idx in 0..chunks {
            let layer = ColumnParallelLayer::new_with_shard(
                in_dim,
                out_dim,
                config,
                bias,
                comm,
                shard(
                    0,
                    chunk_idx * comm.world_size() + comm.rank(),
                    chunks * comm.world_size(),
                ),
                vb.clone(),
            )?;
            vec_layers.push(layer);
        }
        Ok(vec_layers)
    }
}

impl QuantMethod for ColumnParallelLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        let mut xs = self.weight.forward(a)?;
        if let Some(bias) = &self.bias {
            xs = xs.broadcast_add(bias)?;
        }
        Ok(xs)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        let weight = self.weight.add_delta_w(delta)?;
        Ok(Arc::new(Self {
            weight,
            bias: self.bias.clone(),
        }))
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.weight.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.weight.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.weight)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.weight.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.weight.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.weight.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        let weight =
            self.weight
                .clone()
                .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
        let bias = match &self.bias {
            Some(b) => {
                let (dtype, device) = weight.dtype_and_device();
                Some(b.to_device(&device)?.to_dtype(dtype)?)
            }
            None => None,
        };
        Ok(Arc::new(Self { weight, bias }))
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::ColumnParallel)
    }
}

impl QuantizedSerde for ColumnParallelLayer {
    fn isq_serde_supported(&self) -> bool {
        self.weight.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.weight.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.weight.serialize_with_bias(self.bias.clone())
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        _comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Unquant => {
                UnquantLinear::deserialize_ext_bias(data, device, guard)?
            }
            QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
        };
        Ok(Arc::new(Self { weight, bias }))
    }
}

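/// A linear layer replicated on every rank: no sharding, and no collective
/// communication in `forward`.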
#[derive(Debug)]
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);

impl ReplicatedLayer {
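    /// Wrap an existing `Linear` as a replicated layer. When immediate ISQ is
    /// configured, the weight is staged on the CPU and quantized to the target
    /// device: synchronously, or on the ISQ thread pool behind a
    /// `PendingIsqLayer` (with `backpressure` bounding in-flight jobs) when a
    /// pool is available.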
    pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
        let dev = lin.weight().device().clone();
        if let Some(crate::ImmediateIsqParams {
            guard,
            ty: Some(immediate_isq),
            pool,
            backpressure,
            ..
        }) = crate::get_immediate_isq()
        {
            let lin = if !dev.is_cpu() {
                Linear::new(lin.weight().to_device(&Device::Cpu)?, lin.bias().cloned())
            } else {
                lin
            };
            let layer: Arc<dyn QuantMethod> =
                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
            if let Some(pool) = &pool {
                backpressure.acquire();
                let backpressure = backpressure.clone();
                let dev = dev.clone();
                let (tx, rx) = crate::pending_layer::pending_isq_channel();
                pool.spawn(move || {
                    let result = layer.clone().apply_isq(
                        Some(immediate_isq),
                        dev,
                        &std::sync::atomic::AtomicUsize::new(0),
                        None,
                        guard,
                    );
                    let _ = tx.send(result);
                    backpressure.release();
                });
                Ok(Arc::new(crate::PendingIsqLayer::new(rx)))
            } else {
                layer.clone().apply_isq(
                    Some(immediate_isq),
                    dev,
                    &std::sync::atomic::AtomicUsize::new(0),
                    None,
                    guard,
                )
            }
        } else {
            Ok(Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(lin),
            )?))
        }
    }

    #[allow(clippy::new_ret_no_self)]
    pub fn new(
        in_dim: usize,
        out_dim: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
                let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }

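    /// Like [`ReplicatedLayer::new`], but for Matformer layer selection: when
    /// `kept_layers_indices` is given, the weight is viewed as
    /// `orig_num_hidden_layers` equal row blocks and only the selected blocks
    /// are kept (via `index_select` along dim 0) before the layer is built.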
    #[allow(clippy::new_ret_no_self)]
    pub fn new_layers_matformer_indices(
        in_dim: usize,
        out_dim: usize,
        kept_layers_indices: Option<&Tensor>,
        orig_num_hidden_layers: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        vb: ShardedVarBuilder,
    ) -> Result<Arc<dyn QuantMethod>> {
        let base_vb = vb.clone();
        let vb = if should_apply_immediate_isq(&vb) {
            vb.set_device(Device::Cpu)
        } else {
            vb
        };

        let layer = if let Some(quant_conf) = &config {
            if kept_layers_indices.is_some() {
                candle_core::bail!("Cannot load a matformer layer with a pre-quantized model.");
            }

            match quant_conf {
                QuantizedConfig::GptqAwq { .. } => {
                    gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    if weight_block_size.is_some() {
                        blockwise_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    } else {
                        pertensor_fp8_linear_b(
                            in_dim,
                            out_dim,
                            quant_conf,
                            bias,
                            Default::default(),
                            vb.clone(),
                        )?
                    }
                }
                QuantizedConfig::Bitsandbytes { .. } => {
                    Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
                }
                QuantizedConfig::Afq { .. } => {
                    AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
                QuantizedConfig::MXFP4 {} => {
                    MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
                }
            }
        } else {
            if !vb.contains_tensor("weight") {
                let layer = <DummyLayer as QuantMethod>::new(QuantMethodConfig::Dummy)?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            } else {
                let mut weight =
                    vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;

                if let Some(kept_layers_indices) = &kept_layers_indices {
                    let weight_reshaped = weight.reshape((
                        orig_num_hidden_layers,
                        weight.dim(0)? / orig_num_hidden_layers,
                        weight.dim(1)?,
                    ))?;

                    weight = weight_reshaped
                        .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
                        .reshape(((), weight_reshaped.dim(D::Minus1)?))?
                        .contiguous()?;
                }

                weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;

                let bias = if bias {
                    Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
                } else {
                    None
                };
                let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
                    Linear::new(weight, bias),
                ))?;
                Arc::new(layer) as Arc<dyn QuantMethod>
            }
        };

        let this_unquant = Arc::new(Self(layer));
        let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
        Ok(this)
    }
}

impl QuantMethod for ReplicatedLayer {
    fn new(_method: QuantMethodConfig) -> Result<Self>
    where
        Self: Sized,
    {
        candle_core::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
    }

    fn forward(&self, a: &Tensor) -> Result<Tensor> {
        self.0.forward(a)
    }

    fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
        self.0.add_delta_w(delta)
    }

    fn dequantize_w(&self) -> Result<Tensor> {
        self.0.dequantize_w()
    }

    fn dtype_and_device(&self) -> (candle_core::DType, candle_core::Device) {
        self.0.dtype_and_device()
    }

    fn begin_track_stats(&mut self) -> Result<()> {
        Arc::get_mut(&mut self.0)
            .context("Failed to get &mut to weight")?
            .begin_track_stats()
    }

    fn end_track_stats(&self) -> Result<Tensor> {
        self.0.end_track_stats()
    }

    fn quantized_act_type(&self) -> Option<candle_core::DType> {
        self.0.quantized_act_type()
    }

    fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
        self.0.unquant_weight_bias()
    }

    fn apply_isq(
        self: Arc<Self>,
        dtype: Option<crate::IsqType>,
        device: candle_core::Device,
        n_quantized: &std::sync::atomic::AtomicUsize,
        imatrix_weight: Option<Vec<f32>>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>> {
        self.0
            .clone()
            .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
    }

    fn is_distributed(&self) -> Option<DistributedKind> {
        Some(DistributedKind::Replicated)
    }
}

impl QuantizedSerde for ReplicatedLayer {
    fn isq_serde_supported(&self) -> bool {
        self.0.isq_serde_supported()
    }
    fn name(&self) -> &'static str {
        self.0.name()
    }
    fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
        self.0.serialize()
    }
    fn deserialize(
        data: std::borrow::Cow<[u8]>,
        device: &candle_core::Device,
        comm: &Arc<crate::Comm>,
        guard: QuantizeOntoGuard,
    ) -> Result<Arc<dyn QuantMethod>>
    where
        Self: Sized,
    {
        let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
        let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
            QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(data, device, comm, guard)?,
            QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(data, device, comm, guard)?,
        };
        Ok(Arc::new(Self(deserialized)))
    }
}

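/// Per-expert gate/up/down projections for an MoE layer.
///
/// Quantized loading (AFQ, blockwise FP8, MXFP4) yields a single packed layer
/// per projection; the unquantized path yields one layer per local expert,
/// sharded across ranks along the intermediate dimension.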
#[derive(Debug)]
pub struct PackedExperts {
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}

impl PackedExperts {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        num_local_experts: usize,
        hidden_size: usize,
        intermediate_size: usize,
        config: &Option<QuantizedConfig>,
        bias: bool,
        comm: &Arc<crate::Comm>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        if bias {
            candle_core::bail!("PackedExperts does not support bias.");
        }

        let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
            if comm.world_size() != 1 {
                candle_core::bail!(
                    "PackedExperts with a quantization config does not support distributed inference (world size {}). Use ISQ.",
                    comm.world_size()
                );
            }

            match quant_conf {
                QuantizedConfig::Afq { .. } => {
                    if !vb.contains_tensor("gate_up_proj")
                        || !vb.contains_tensor("gate_up_proj.weight")
                    {
                        candle_core::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
                    }

                    let base_vb = vb.clone();

                    let vb_gate_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("gate_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("gate_proj")
                    };
                    let vb_up_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("up_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("up_proj")
                    };
                    let vb_down_proj = if should_apply_immediate_isq(&vb) {
                        vb.pp("down_proj").set_device(Device::Cpu)
                    } else {
                        vb.pp("down_proj")
                    };
                    let mut gate_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_gate_proj,
                    )?;
                    let mut up_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb_up_proj,
                    )?;
                    let mut down_proj = AfqLayer::afq_packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb_down_proj,
                    )?;

                    gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
                    up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
                    down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                QuantizedConfig::Fp8 { weight_block_size } => {
                    let Some(weight_block_size) = weight_block_size else {
                        candle_core::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
                    };
                    if weight_block_size.len() != 2 {
                        candle_core::bail!(
                            "Expected weight_block_size to have length 2, got {weight_block_size:?}"
                        );
                    }

                    let is_stacked_format = vb.contains_tensor("gate_up_proj");

                    if is_stacked_format {
                        let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");

                        if has_fp8_scales {
                            let gate_up_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, hidden_size, intermediate_size * 2),
                                "gate_up_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_up_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    (intermediate_size * 2).div_ceil(weight_block_size[1]),
                                ),
                                "gate_up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let down_fp8 = vb.get_with_hints_dtype(
                                (num_local_experts, intermediate_size, hidden_size),
                                "down_proj",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = vb.get_with_hints_dtype(
                                (
                                    num_local_experts,
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let mut gs = Vec::new();
                            let mut us = Vec::new();
                            let mut ds = Vec::new();

                            for i in 0..num_local_experts {
                                let gate_up_expert =
                                    gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
                                let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
                                let down_scale_expert = down_scale.i(i)?.contiguous()?;

                                let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
                                let up_expert = gate_up_expert.narrow(
                                    0,
                                    intermediate_size,
                                    intermediate_size,
                                )?;

                                let gate_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    0,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;
                                let up_scale_expert = gate_up_scale_expert.narrow(
                                    1,
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                )?;

                                use crate::blockwise_fp8::BlockwiseFP8Linear;
                                use crate::QuantMethodConfig;

                                let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: gate_expert,
                                        weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let up_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: up_expert,
                                        weight_scale_inv: up_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );
                                let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                    BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                        weight: down_expert,
                                        weight_scale_inv: down_scale_expert.transpose(0, 1)?,
                                        bias: None,
                                        dequant_dtype: vb.dtype(),
                                        weight_block_size: weight_block_size.clone(),
                                    })?,
                                );

                                gs.push(gate_layer);
                                us.push(up_layer);
                                ds.push(down_layer);
                            }

                            (gs, us, ds)
                        } else {
                            candle_core::bail!(
                                "PackedExperts with FP8 requires weight_scale_inv tensors"
                            );
                        }
                    } else {
                        let mut gs = Vec::new();
                        let mut us = Vec::new();
                        let mut ds = Vec::new();

                        for i in 0..num_local_experts {
                            let expert_vb = vb.pp(i);

                            let gate_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "gate_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let gate_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "gate_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let up_fp8 = expert_vb.get_with_hints_dtype(
                                (intermediate_size, hidden_size),
                                "up_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let up_scale = expert_vb.get_with_hints_dtype(
                                (
                                    intermediate_size.div_ceil(weight_block_size[0]),
                                    hidden_size.div_ceil(weight_block_size[1]),
                                ),
                                "up_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            let down_fp8 = expert_vb.get_with_hints_dtype(
                                (hidden_size, intermediate_size),
                                "down_proj.weight",
                                Default::default(),
                                candle_core::DType::F8E4M3,
                            )?;
                            let down_scale = expert_vb.get_with_hints_dtype(
                                (
                                    hidden_size.div_ceil(weight_block_size[0]),
                                    intermediate_size.div_ceil(weight_block_size[1]),
                                ),
                                "down_proj.weight_scale_inv",
                                Default::default(),
                                candle_core::DType::F32,
                            )?;

                            use crate::blockwise_fp8::BlockwiseFP8Linear;
                            use crate::QuantMethodConfig;

                            let gate_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: gate_fp8,
                                    weight_scale_inv: gate_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );
                            let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
                                QuantMethodConfig::BlockwiseFP8 {
                                    weight: up_fp8,
                                    weight_scale_inv: up_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                },
                            )?);
                            let down_layer: Arc<dyn QuantMethod> = Arc::new(
                                BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
                                    weight: down_fp8,
                                    weight_scale_inv: down_scale,
                                    bias: None,
                                    dequant_dtype: vb.dtype(),
                                    weight_block_size: weight_block_size.clone(),
                                })?,
                            );

                            gs.push(gate_layer);
                            us.push(up_layer);
                            ds.push(down_layer);
                        }

                        (gs, us, ds)
                    }
                }
                QuantizedConfig::MXFP4 {} => {
                    let gate_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("gate_proj"),
                    )?;
                    let up_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        hidden_size,
                        intermediate_size,
                        quant_conf,
                        bias,
                        vb.pp("up_proj"),
                    )?;
                    let down_proj = MXFP4Layer::packed_linear_b(
                        num_local_experts,
                        intermediate_size,
                        hidden_size,
                        quant_conf,
                        bias,
                        vb.pp("down_proj"),
                    )?;

                    (vec![gate_proj], vec![up_proj], vec![down_proj])
                }
                _ => candle_core::bail!(
                    "PackedExperts with quantization config only allows AFQ, FP8, or MXFP4 quantization"
                ),
            }
        } else if !vb.contains_tensor("gate_up_proj") {
            let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
            let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
            for _ in 0..num_local_experts {
                gs.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                us.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
                ds.push(Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?));
            }
            (gs, us, ds)
        } else {
            let gate_up_block_size = intermediate_size / comm.world_size();
            let gate_up_start = gate_up_block_size * comm.rank();

            let shard_gate = Shard::Offset {
                dim: 2,
                offset: gate_up_start,
                len: gate_up_block_size,
            };
            let shard_up = Shard::Offset {
                dim: 2,
                offset: intermediate_size + gate_up_start,
                len: gate_up_block_size,
            };
            let shard_down = Shard::Simple {
                dim: 1,
                rank: comm.rank(),
                world_size: comm.world_size(),
            };

            let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("gate_up_proj").set_device(Device::Cpu)
            } else {
                vb.pp("gate_up_proj")
            };
            let vb_down_proj = if should_apply_immediate_isq(&vb) {
                vb.pp("down_proj").set_device(Device::Cpu)
            } else {
                vb.pp("down_proj")
            };

            let gate_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_gate,
                )?
                .t()?
                .contiguous()?;
            let up_proj = vb
                .get_with_hints(
                    (num_local_experts, hidden_size, intermediate_size * 2),
                    "gate_up_proj",
                    shard_up,
                )?
                .t()?
                .contiguous()?;
            let down_proj = vb
                .get_with_hints(
                    (num_local_experts, intermediate_size, hidden_size),
                    "down_proj",
                    shard_down,
                )?
                .t()?
                .contiguous()?;

            let gc = gate_proj.chunk(num_local_experts, 0)?;
            let uc = up_proj.chunk(num_local_experts, 0)?;
            let dc = down_proj.chunk(num_local_experts, 0)?;
            drop((gate_proj, up_proj, down_proj));

            let mut gs = Vec::new();
            let mut us = Vec::new();
            let mut ds = Vec::new();
            for ((mut gate_proj, mut up_proj), mut down_proj) in
                gc.into_iter().zip(uc.into_iter()).zip(dc.into_iter())
            {
                gate_proj = gate_proj.squeeze(0)?;
                up_proj = up_proj.squeeze(0)?;
                down_proj = down_proj.squeeze(0)?;
                let gate_proj = merge_lora_weights(
                    &vb,
                    gate_proj,
                    hidden_size,
                    intermediate_size * 2,
                    shard_gate,
                )?;
                let up_proj =
                    merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
                let down_proj =
                    merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;

                let mut gate_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                    )?);
                gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
                let mut up_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                    )?);
                up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
                let mut down_proj: Arc<dyn QuantMethod> =
                    Arc::new(<UnquantLinear as QuantMethod>::new(
                        QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                    )?);
                down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
                gs.push(gate_proj);
                us.push(up_proj);
                ds.push(down_proj);
            }
            (gs, us, ds)
        };

        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
        })
    }
}

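/// Fused (stacked) expert projections for an MoE layer: each projection is a
/// single layer over a `(num_experts, out_dim, in_dim)` weight rather than a
/// `Vec` of per-expert layers.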
pub struct FusedExperts {
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    pub fused_up_proj: Arc<dyn QuantMethod>,
    pub fused_down_proj: Arc<dyn QuantMethod>,
}

impl FusedExperts {
    pub fn new(
        hidden_size: usize,
        moe_intermediate_size: usize,
        num_experts: usize,
        quantization_config: &Option<QuantizedConfig>,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let experts_vb = vb.pp("experts");
        let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");

        let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
            &quantization_config,
            Some(QuantizedConfig::Afq { .. })
        ) {
            let quantization_config = quantization_config.as_ref().unwrap();

            let fused_gate_proj = AfqLayer::afq_packed_linear_b(
                num_experts,
                hidden_size,
                moe_intermediate_size,
                quantization_config,
                false,
                vb.pp("switch_mlp.gate_proj"),
            )?;
            let fused_up_proj = AfqLayer::afq_packed_linear_b(
                num_experts,
                hidden_size,
                moe_intermediate_size,
                quantization_config,
                false,
                vb.pp("switch_mlp.up_proj"),
            )?;
            let fused_down_proj = AfqLayer::afq_packed_linear_b(
                num_experts,
                moe_intermediate_size,
                hidden_size,
                quantization_config,
                false,
                vb.pp("switch_mlp.down_proj"),
            )?;

            (fused_gate_proj, fused_up_proj, fused_down_proj)
        } else if is_stacked_format
            && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
        {
            let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");

            if has_fp8_scales {
                let weight_block_size = match quantization_config {
                    Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
                    _ => unreachable!(),
                };

                let Some(weight_block_size) = weight_block_size else {
                    candle_core::bail!(
                        "Blockwise FP8 for stacked experts requires weight_block_size to be set."
                    )
                };
                if weight_block_size.len() != 2 {
                    candle_core::bail!(
                        "Expected weight_block_size to have length 2, got {weight_block_size:?}"
                    );
                }

                let gate_up_fp8 = experts_vb.get_with_hints_dtype(
                    (num_experts, hidden_size, moe_intermediate_size * 2),
                    "gate_up_proj",
                    Default::default(),
                    candle_core::DType::F8E4M3,
                )?;
                let gate_up_scale = experts_vb.get_with_hints_dtype(
                    (
                        num_experts,
                        hidden_size.div_ceil(weight_block_size[0]),
                        (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
                    ),
                    "gate_up_proj.weight_scale_inv",
                    Default::default(),
                    candle_core::DType::F32,
                )?;

                let down_fp8 = experts_vb.get_with_hints_dtype(
                    (num_experts, moe_intermediate_size, hidden_size),
                    "down_proj",
                    Default::default(),
                    candle_core::DType::F8E4M3,
                )?;
                let down_scale = experts_vb.get_with_hints_dtype(
                    (
                        num_experts,
                        moe_intermediate_size.div_ceil(weight_block_size[0]),
                        hidden_size.div_ceil(weight_block_size[1]),
                    ),
                    "down_proj.weight_scale_inv",
                    Default::default(),
                    candle_core::DType::F32,
                )?;

                let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
                let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;

                let gate_scale = gate_up_scale.narrow(
                    2,
                    0,
                    moe_intermediate_size.div_ceil(weight_block_size[1]),
                )?;
                let up_scale = gate_up_scale.narrow(
                    2,
                    moe_intermediate_size.div_ceil(weight_block_size[1]),
                    moe_intermediate_size.div_ceil(weight_block_size[1]),
                )?;

                let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
                let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
                let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;

                let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
                let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
                let down_scale = down_scale.transpose(1, 2)?.contiguous()?;

                let fused_gate_proj =
                    blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
                let fused_up_proj =
                    blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
                let fused_down_proj =
                    blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            } else {
                tracing::warn!(
                    "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
                     Loading as unquantized."
                );
                let gate_up_proj = experts_vb
                    .get(
                        (num_experts, hidden_size, moe_intermediate_size * 2),
                        "gate_up_proj",
                    )
                    .or_else(|_| {
                        experts_vb
                            .get(
                                (num_experts, moe_intermediate_size * 2, hidden_size),
                                "gate_up_proj",
                            )
                            .and_then(|t| t.transpose(1, 2)?.contiguous())
                    })?;
                let down_proj_packed = experts_vb
                    .get(
                        (num_experts, moe_intermediate_size, hidden_size),
                        "down_proj",
                    )
                    .or_else(|_| {
                        experts_vb
                            .get(
                                (num_experts, hidden_size, moe_intermediate_size),
                                "down_proj",
                            )
                            .and_then(|t| t.transpose(1, 2)?.contiguous())
                    })?;

                let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
                let up_proj =
                    gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;

                let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
                let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
                let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;

                let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
                let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
                let target_device = gate_proj.device().clone();
                let (gate_proj, up_proj, down_proj) =
                    if (isq_gate_up || isq_down) && !target_device.is_cpu() {
                        (
                            gate_proj.to_device(&Device::Cpu)?,
                            up_proj.to_device(&Device::Cpu)?,
                            down_proj.to_device(&Device::Cpu)?,
                        )
                    } else {
                        (gate_proj, up_proj, down_proj)
                    };

                let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
                )?);
                let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
                )?);
                let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                    QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
                )?);
                let vb_gate_up = experts_vb.pp("gate_up_proj");
                let vb_down = experts_vb.pp("down_proj");
                fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
                fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
                fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;

                (fused_gate_proj, fused_up_proj, fused_down_proj)
            }
        } else if is_stacked_format
            && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
        {
            let quantization_config = quantization_config.as_ref().unwrap();

            let fused_gate_proj = MXFP4Layer::packed_linear_b(
                num_experts,
                hidden_size,
                moe_intermediate_size,
                quantization_config,
                false,
                experts_vb.pp("gate_proj"),
            )?;
            let fused_up_proj = MXFP4Layer::packed_linear_b(
                num_experts,
                hidden_size,
                moe_intermediate_size,
                quantization_config,
                false,
                experts_vb.pp("up_proj"),
            )?;
            let fused_down_proj = MXFP4Layer::packed_linear_b(
                num_experts,
                moe_intermediate_size,
                hidden_size,
                quantization_config,
                false,
                experts_vb.pp("down_proj"),
            )?;

            (fused_gate_proj, fused_up_proj, fused_down_proj)
        } else if is_stacked_format {
            let gate_up_proj = experts_vb
                .get(
                    (num_experts, hidden_size, moe_intermediate_size * 2),
                    "gate_up_proj",
                )
                .or_else(|_| {
                    experts_vb
                        .get(
                            (num_experts, moe_intermediate_size * 2, hidden_size),
                            "gate_up_proj",
                        )
                        .and_then(|t| t.transpose(1, 2)?.contiguous())
                })?;
            let down_proj_packed = experts_vb
                .get(
                    (num_experts, moe_intermediate_size, hidden_size),
                    "down_proj",
                )
                .or_else(|_| {
                    experts_vb
                        .get(
                            (num_experts, hidden_size, moe_intermediate_size),
                            "down_proj",
                        )
                        .and_then(|t| t.transpose(1, 2)?.contiguous())
                })?;

            let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
            let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;

            let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
            let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
            let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;

            let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
            let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
            let target_device = gate_proj.device().clone();
            let (gate_proj, up_proj, down_proj) =
                if (isq_gate_up || isq_down) && !target_device.is_cpu() {
                    (
                        gate_proj.to_device(&Device::Cpu)?,
                        up_proj.to_device(&Device::Cpu)?,
                        down_proj.to_device(&Device::Cpu)?,
                    )
                } else {
                    (gate_proj, up_proj, down_proj)
                };

            let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
            )?);
            let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
            )?);
            let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
            )?);
            let vb_gate_up = experts_vb.pp("gate_up_proj");
            let vb_down = experts_vb.pp("down_proj");
            fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
            fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
            fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;

            (fused_gate_proj, fused_up_proj, fused_down_proj)
        } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
            let weight_block_size = match quantization_config {
                Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
                _ => unreachable!(),
            };

            let Some(weight_block_size) = weight_block_size else {
                candle_core::bail!(
                    "Blockwise FP8 for per-expert format requires weight_block_size to be set."
                )
            };
            if weight_block_size.len() != 2 {
                candle_core::bail!(
                    "Expected weight_block_size to have length 2, got {weight_block_size:?}"
                );
            }

            let mut gate_fp8_vec = Vec::new();
            let mut gate_scale_vec = Vec::new();
            let mut up_fp8_vec = Vec::new();
            let mut up_scale_vec = Vec::new();
            let mut down_fp8_vec = Vec::new();
            let mut down_scale_vec = Vec::new();

            for i in 0..num_experts {
                let expert_vb = experts_vb.pp(i);

                let gate_fp8 = expert_vb.get_with_hints_dtype(
                    (moe_intermediate_size, hidden_size),
                    "gate_proj.weight",
                    Default::default(),
                    candle_core::DType::F8E4M3,
                )?;
                let gate_scale = expert_vb.get_with_hints_dtype(
                    (
                        moe_intermediate_size.div_ceil(weight_block_size[0]),
                        hidden_size.div_ceil(weight_block_size[1]),
                    ),
                    "gate_proj.weight_scale_inv",
                    Default::default(),
                    candle_core::DType::F32,
                )?;

                let up_fp8 = expert_vb.get_with_hints_dtype(
                    (moe_intermediate_size, hidden_size),
                    "up_proj.weight",
                    Default::default(),
                    candle_core::DType::F8E4M3,
                )?;
                let up_scale = expert_vb.get_with_hints_dtype(
                    (
                        moe_intermediate_size.div_ceil(weight_block_size[0]),
                        hidden_size.div_ceil(weight_block_size[1]),
                    ),
                    "up_proj.weight_scale_inv",
                    Default::default(),
                    candle_core::DType::F32,
                )?;

                let down_fp8 = expert_vb.get_with_hints_dtype(
                    (hidden_size, moe_intermediate_size),
                    "down_proj.weight",
                    Default::default(),
                    candle_core::DType::F8E4M3,
                )?;
                let down_scale = expert_vb.get_with_hints_dtype(
                    (
                        hidden_size.div_ceil(weight_block_size[0]),
                        moe_intermediate_size.div_ceil(weight_block_size[1]),
                    ),
                    "down_proj.weight_scale_inv",
                    Default::default(),
                    candle_core::DType::F32,
                )?;

                gate_fp8_vec.push(gate_fp8);
                gate_scale_vec.push(gate_scale);
                up_fp8_vec.push(up_fp8);
                up_scale_vec.push(up_scale);
                down_fp8_vec.push(down_fp8);
                down_scale_vec.push(down_scale);
            }

            let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
            let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
            let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
            let up_scale = Tensor::stack(&up_scale_vec, 0)?;
            let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
            let down_scale = Tensor::stack(&down_scale_vec, 0)?;

            let fused_gate_proj =
                blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
            let fused_up_proj =
                blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
            let fused_down_proj =
                blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;

            (fused_gate_proj, fused_up_proj, fused_down_proj)
        } else if !experts_vb.pp("0").contains_tensor("gate_proj.weight") {
            let fused_gate_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            let fused_up_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            let fused_down_proj: Arc<dyn QuantMethod> =
                Arc::new(DummyLayer::new(QuantMethodConfig::Dummy)?);
            (fused_gate_proj, fused_up_proj, fused_down_proj)
        } else {
            let load_experts_vb =
                if crate::get_immediate_isq().is_some() && !experts_vb.device().is_cpu() {
                    experts_vb.clone().set_device(Device::Cpu)
                } else {
                    experts_vb.clone()
                };
            let mut gate_proj_vec = Vec::new();
            let mut up_proj_vec = Vec::new();
            let mut down_proj_vec = Vec::new();
            for i in 0..num_experts {
                let expert_vb = load_experts_vb.pp(i);
                let gate_proj =
                    expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
                let up_proj =
                    expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
                let down_proj =
                    expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;

                gate_proj_vec.push(gate_proj);
                up_proj_vec.push(up_proj);
                down_proj_vec.push(down_proj);
            }

            let mut gate_proj: Arc<dyn QuantMethod> =
                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
                ))?);
            let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
                QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
            )?);
            let mut down_proj: Arc<dyn QuantMethod> =
                Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                    Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
                ))?);
            let expert0_vb = experts_vb.pp("0");
            gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
            up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
            down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;

            (gate_proj, up_proj, down_proj)
        };

        Ok(Self {
            fused_gate_proj,
            fused_up_proj,
            fused_down_proj,
        })
    }
}

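/// Compute the weight shard for K/V projections under tensor parallelism.
///
/// With `world_size <= total_num_kv_heads`, each rank takes a simple slice of
/// the KV heads along dim 0. With more ranks than KV heads, heads are
/// replicated: each group of `world_size / total_num_kv_heads` consecutive
/// ranks shares one KV head, and every rank loads just that head's `head_dim`
/// rows.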
pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
    if comm.world_size() == 1 {
        return Shard::default();
    }

    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        return Shard::Simple {
            dim: 0,
            rank: comm.rank(),
            world_size: comm.world_size(),
        };
    };

    let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
    let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
    Shard::Offset {
        dim: 0,
        offset: kv_shard_id * head_dim,
        len: head_dim,
    }
}

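/// Number of query-head groups per K/V head (the GQA group count) after tensor
/// parallelism: the base ratio `num_attention_heads / total_num_kv_heads`,
/// divided by the KV replication factor when ranks outnumber KV heads.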
pub fn compute_n_kv_groups(
    total_num_kv_heads: usize,
    num_attention_heads: usize,
    comm: &Comm,
) -> usize {
    let kv_replicate = if comm.world_size() > total_num_kv_heads {
        comm.world_size() / total_num_kv_heads
    } else {
        1
    };
    if kv_replicate != 0 {
        (num_attention_heads / total_num_kv_heads) / kv_replicate
    } else {
        num_attention_heads / total_num_kv_heads
    }
}