1use std::sync::Arc;
2
3use hanzo_ml::{Context, Device, IndexOp, Result, Tensor, D};
4use hanzo_nn::Linear;
5
6use crate::{
7 blockwise_fp8::{blockwise_fp8_linear_b, blockwise_fp8_moe},
8 distributed,
9 gptq::gptq_linear,
10 lora::merge_lora_weights,
11 make_dummy_or_error,
12 pertensor_fp8::pertensor_fp8_linear_b,
13 should_apply_immediate_isq,
14 utils::isq::apply_immediate_isq,
15 AfqLayer, BnbLinear, DistributedKind, F8Q8Linear, FP8Linear, GgufMatMul, HqqLayer, MXFP4Layer,
16 QuantMethod, QuantMethodConfig, QuantizeOntoGuard, QuantizedConfig, QuantizedSerde,
17 QuantizedSerdeType, Shard, ShardedVarBuilder, UnquantLinear,
18};
19
20use super::{Comm, SumAllReduce};
21
/// Packs `dim`, `rank` and `world_size` into a [`Shard::Simple`].
///
/// The constructors in this file call it with `dim = 1` for row-parallel layers and
/// `dim = 0` for column-parallel layers.
fn shard(dim: usize, rank: usize, world_size: usize) -> Shard {
    Shard::Simple {
        dim,
        rank,
        world_size,
    }
}
29
/// Linear layer whose constructors in this file load the weight with a shard along
/// dimension 1.
///
/// The forward pass sum-all-reduces the local result and only then adds the optional bias.
#[derive(Debug)]
pub struct RowParallelLayer {
    /// Local weight, wrapped in whichever quantization method it was loaded with.
    weight: Arc<dyn QuantMethod>,
    /// Optional bias; loaded unsharded with shape `(out_dim,)` and added after the all-reduce.
    bias: Option<Tensor>,
    /// Sum all-reduce operation built from the communicator handed to the constructor.
    all_reduce: distributed::SumAllReduce,
}
38
39impl RowParallelLayer {
40 #[allow(clippy::new_ret_no_self)]
41 pub fn new(
42 in_dim: usize,
43 out_dim: usize,
44 config: &Option<QuantizedConfig>,
45 bias: bool,
46 comm: &Arc<crate::Comm>,
47 vb: ShardedVarBuilder,
48 ) -> Result<Arc<dyn QuantMethod>> {
49 let rank = comm.rank();
50 let world_size = comm.world_size();
51 let shard = shard(1, rank, world_size);
52
53 let base_vb = vb.clone();
54 let vb = if should_apply_immediate_isq(&vb) {
55 vb.set_device(Device::Cpu)
56 } else {
57 vb
58 };
59
60 let weight = if let Some(quant_conf) = &config {
61 if matches!(
63 quant_conf,
64 QuantizedConfig::GptqAwq { .. }
65 | QuantizedConfig::Bitsandbytes { .. }
66 | QuantizedConfig::Afq { .. }
67 ) && comm.world_size() != 1
68 {
69 hanzo_ml::bail!(
70 "GPTQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
71 comm.world_size()
72 );
73 }
74
75 match quant_conf {
76 QuantizedConfig::GptqAwq { .. } => {
77 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
78 }
79 QuantizedConfig::Fp8 { weight_block_size } => {
80 if weight_block_size.is_some() {
82 blockwise_fp8_linear_b(
83 in_dim,
84 out_dim,
85 quant_conf,
86 false,
87 shard,
88 vb.clone(),
89 )?
90 } else {
91 pertensor_fp8_linear_b(
92 in_dim,
93 out_dim,
94 quant_conf,
95 false,
96 shard,
97 vb.clone(),
98 )?
99 }
100 }
101 QuantizedConfig::Bitsandbytes { .. } => {
102 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
103 }
104 QuantizedConfig::Afq { .. } => {
105 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
106 }
107 QuantizedConfig::MXFP4 {} => {
108 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
109 }
110 }
111 } else {
112 if !vb.contains_tensor("weight") {
113 make_dummy_or_error("row_parallel_linear", &vb, &["weight"])?
114 } else {
115 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
116 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
117
118 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
119 Linear::new(weight, None),
120 ))?;
121 Arc::new(layer) as Arc<dyn QuantMethod>
122 }
123 };
124
125 let bias = if bias && vb.contains_tensor("bias") {
127 Some(vb.get((out_dim,), "bias")?)
128 } else {
129 None
130 };
131
132 let this_unquant = Arc::new(Self {
133 weight,
134 bias,
135 all_reduce: distributed::SumAllReduce::new(comm),
136 });
137 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
138 Ok(this)
139 }
140
141 #[allow(clippy::new_ret_no_self)]
142 pub fn new_matformer(
143 in_dim: usize,
144 out_dim: usize,
145 orig_intermediate_size: usize,
146 config: &Option<QuantizedConfig>,
147 bias: bool,
148 comm: &Arc<crate::Comm>,
149 vb: ShardedVarBuilder,
150 ) -> Result<Arc<dyn QuantMethod>> {
151 let rank = comm.rank();
152 let world_size = comm.world_size();
153 let shard = shard(1, rank, world_size);
154
155 let base_vb = vb.clone();
156 let vb = if should_apply_immediate_isq(&vb) {
157 vb.set_device(Device::Cpu)
158 } else {
159 vb
160 };
161
162 if config.is_some() {
163 hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
164 }
165
166 let weight = if !vb.contains_tensor("weight") {
167 make_dummy_or_error("row_parallel_matformer_linear", &vb, &["weight"])?
168 } else {
169 let weight = vb
170 .get_with_hints(
171 (out_dim, orig_intermediate_size),
172 "weight",
173 Default::default(),
174 )?
175 .i((.., ..in_dim))?
176 .contiguous()?;
177
178 let weight = shard.apply_to(&weight)?;
179 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
180
181 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
182 Linear::new(weight, None),
183 ))?;
184 Arc::new(layer) as Arc<dyn QuantMethod>
185 };
186
187 let bias = if bias && vb.contains_tensor("bias") {
189 Some(vb.get((out_dim,), "bias")?)
190 } else {
191 None
192 };
193
194 let this_unquant = Arc::new(Self {
195 weight,
196 bias,
197 all_reduce: distributed::SumAllReduce::new(comm),
198 });
199 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
200 Ok(this)
201 }
202}
203
204impl QuantMethod for RowParallelLayer {
205 fn new(_method: QuantMethodConfig) -> Result<Self>
206 where
207 Self: Sized,
208 {
209 hanzo_ml::bail!("RowParallelLayer should not be constructed with `QuantMethod::new`")
210 }
211
212 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
213 let mut xs = self.weight.forward_raw(a)?;
214 xs = self.all_reduce.sum_all_reduce(&xs.contiguous()?)?;
215 if let Some(bias) = &self.bias {
216 xs = xs.broadcast_add(bias)?;
217 }
218 Ok(xs)
219 }
220
221 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
222 let weight = self.weight.add_delta_w(delta)?;
223 Ok(Arc::new(Self {
224 weight,
225 bias: self.bias.clone(),
226 all_reduce: self.all_reduce.clone(),
227 }))
228 }
229
230 fn dequantize_w(&self) -> Result<Tensor> {
231 self.weight.dequantize_w()
232 }
233
234 fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
235 self.weight.dtype_and_device()
236 }
237
238 fn begin_track_stats(&mut self) -> Result<()> {
239 Arc::get_mut(&mut self.weight)
240 .context("Failed to get &mut to weight")?
241 .begin_track_stats()
242 }
243
244 fn end_track_stats(&self) -> Result<Tensor> {
245 self.weight.end_track_stats()
246 }
247
248 fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
249 self.weight.quantized_act_type()
250 }
251
252 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
253 self.weight.unquant_weight_bias()
254 }
255
256 fn has_bias(&self) -> bool {
257 self.bias.is_some() || self.weight.has_bias()
258 }
259
260 #[cfg(feature = "cuda")]
261 fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
262 self.weight.get_qtensor()
263 }
264
265 fn apply_isq(
266 self: Arc<Self>,
267 dtype: Option<crate::IsqType>,
268 device: hanzo_ml::Device,
269 n_quantized: &std::sync::atomic::AtomicUsize,
270 imatrix_weight: Option<Vec<f32>>,
271 guard: QuantizeOntoGuard,
272 ) -> Result<Arc<dyn QuantMethod>> {
273 let weight =
274 self.weight
275 .clone()
276 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
277 let bias = match &self.bias {
278 Some(b) => {
279 let (dtype, device) = weight.dtype_and_device();
280 Some(b.to_device(&device)?.to_dtype(dtype)?)
281 }
282 None => None,
283 };
284 Ok(Arc::new(Self {
285 weight,
286 bias,
287 all_reduce: self.all_reduce.clone(),
288 }))
289 }
290
291 fn is_distributed(&self) -> Option<DistributedKind> {
292 Some(DistributedKind::RowParallel)
293 }
294}
295
296impl QuantizedSerde for RowParallelLayer {
297 fn isq_serde_supported(&self) -> bool {
298 self.weight.isq_serde_supported()
299 }
300 fn name(&self) -> &'static str {
301 self.weight.name()
302 }
303 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
304 self.weight.serialize_with_bias(self.bias.clone())
305 }
306 fn deserialize(
307 data: std::borrow::Cow<[u8]>,
308 device: &hanzo_ml::Device,
309 comm: &Arc<crate::Comm>,
310 guard: QuantizeOntoGuard,
311 ) -> Result<Arc<dyn QuantMethod>>
312 where
313 Self: Sized,
314 {
315 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
317 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
318 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
319 QuantizedSerdeType::Unquant => {
320 UnquantLinear::deserialize_ext_bias(data, device, guard)?
321 }
322 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
323 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
324 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
325 QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
326 QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
327 };
328 Ok(Arc::new(Self {
329 weight,
330 bias,
331 all_reduce: SumAllReduce::new(comm),
332 }))
333 }
334}
335
/// Linear layer whose weight and bias are loaded with a caller-chosen shard (dimension 0
/// in `new`, `new_matformer` and `new_merged`).
///
/// The forward pass is purely local: the wrapped weight is applied and the optional bias
/// is added, with no collective communication.
#[derive(Debug)]
pub struct ColumnParallelLayer {
    /// Local weight, wrapped in whichever quantization method it was loaded with.
    weight: Arc<dyn QuantMethod>,
    /// Optional bias, loaded with the same shard hint as the weight.
    bias: Option<Tensor>,
}
343
344impl ColumnParallelLayer {
345 #[allow(clippy::new_ret_no_self)]
346 pub fn new_with_shard(
347 in_dim: usize,
348 out_dim: usize,
349 config: &Option<QuantizedConfig>,
350 bias: bool,
351 comm: &Arc<crate::Comm>,
352 shard: Shard,
353 vb: ShardedVarBuilder,
354 ) -> Result<Arc<dyn QuantMethod>> {
355 let base_vb = vb.clone();
356 let vb = if should_apply_immediate_isq(&vb) {
357 vb.set_device(Device::Cpu)
358 } else {
359 vb
360 };
361
362 let weight = if let Some(quant_conf) = &config {
363 if matches!(
365 quant_conf,
366 QuantizedConfig::GptqAwq { .. }
367 | QuantizedConfig::Bitsandbytes { .. }
368 | QuantizedConfig::Afq { .. }
369 ) && comm.world_size() != 1
370 {
371 hanzo_ml::bail!(
372 "GPTQ/AWQ and BNB and AFQ quantization types to not support tensor parallelism, but got a world size of {}",
373 comm.world_size()
374 );
375 }
376
377 match quant_conf {
378 QuantizedConfig::GptqAwq { .. } => {
379 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
380 }
381 QuantizedConfig::Fp8 { weight_block_size } => {
382 if weight_block_size.is_some() {
384 blockwise_fp8_linear_b(
385 in_dim,
386 out_dim,
387 quant_conf,
388 false,
389 shard,
390 vb.clone(),
391 )?
392 } else {
393 pertensor_fp8_linear_b(
394 in_dim,
395 out_dim,
396 quant_conf,
397 false,
398 shard,
399 vb.clone(),
400 )?
401 }
402 }
403 QuantizedConfig::Bitsandbytes { .. } => {
404 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
405 }
406 QuantizedConfig::Afq { .. } => {
407 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
408 }
409 QuantizedConfig::MXFP4 {} => {
410 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
411 }
412 }
413 } else {
414 if !vb.contains_tensor("weight") {
415 make_dummy_or_error("column_parallel_linear", &vb, &["weight"])?
416 } else {
417 let weight = vb.get_with_hints((out_dim, in_dim), "weight", shard)?;
418 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
419
420 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
421 Linear::new(weight, None),
422 ))?;
423 Arc::new(layer) as Arc<dyn QuantMethod>
424 }
425 };
426
427 let bias = if bias && vb.contains_tensor("bias") {
429 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
430 } else {
431 None
432 };
433
434 let this_unquant = Arc::new(Self { weight, bias });
435 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
436 Ok(this)
437 }
438
439 #[allow(clippy::new_ret_no_self)]
440 pub fn new(
441 in_dim: usize,
442 out_dim: usize,
443 config: &Option<QuantizedConfig>,
444 bias: bool,
445 comm: &Arc<crate::Comm>,
446 vb: ShardedVarBuilder,
447 ) -> Result<Arc<dyn QuantMethod>> {
448 let rank = comm.rank();
449 let world_size = comm.world_size();
450 let shard = shard(0, rank, world_size);
451
452 Self::new_with_shard(in_dim, out_dim, config, bias, comm, shard, vb)
453 }
454
455 #[allow(clippy::new_ret_no_self)]
456 pub fn new_matformer(
457 in_dim: usize,
458 out_dim: usize,
459 orig_intermediate_size: usize,
460 config: &Option<QuantizedConfig>,
461 bias: bool,
462 comm: &Arc<crate::Comm>,
463 vb: ShardedVarBuilder,
464 ) -> Result<Arc<dyn QuantMethod>> {
465 let rank = comm.rank();
466 let world_size = comm.world_size();
467 let shard = shard(0, rank, world_size);
468
469 let base_vb = vb.clone();
470 let vb = if should_apply_immediate_isq(&vb) {
471 vb.set_device(Device::Cpu)
472 } else {
473 vb
474 };
475
476 if config.is_some() {
477 hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
478 }
479
480 let weight = if !vb.contains_tensor("weight") {
481 make_dummy_or_error("column_parallel_matformer_linear", &vb, &["weight"])?
482 } else {
483 let weight = vb
484 .get_with_hints(
485 (orig_intermediate_size, in_dim),
486 "weight",
487 Default::default(),
488 )?
489 .i((..out_dim, ..))?
490 .contiguous()?;
491
492 let weight = shard.apply_to(&weight)?;
493 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, shard)?;
494
495 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
496 Linear::new(weight, None),
497 ))?;
498 Arc::new(layer) as Arc<dyn QuantMethod>
499 };
500
501 let bias = if bias && vb.contains_tensor("bias") {
503 Some(vb.get_with_hints((out_dim,), "bias", shard)?)
504 } else {
505 None
506 };
507
508 let this_unquant = Arc::new(Self { weight, bias });
509 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
510 Ok(this)
511 }
512
513 pub fn new_merged(
514 in_dim: usize,
515 out_dim: usize,
516 chunks: usize,
517 config: &Option<QuantizedConfig>,
518 bias: bool,
519 comm: &Arc<crate::Comm>,
520 vb: ShardedVarBuilder,
521 ) -> Result<Vec<Arc<dyn QuantMethod>>> {
522 let mut vec_layers = Vec::<Arc<dyn QuantMethod>>::new();
523 for chunk_idx in 0..chunks {
524 let layer = ColumnParallelLayer::new_with_shard(
525 in_dim,
526 out_dim,
527 config,
528 bias,
529 comm,
530 shard(
531 0,
532 chunk_idx * comm.world_size() + comm.rank(),
533 chunks * comm.world_size(),
534 ),
535 vb.clone(),
536 )?;
537 vec_layers.push(layer);
538 }
539 Ok(vec_layers)
540 }
541}
542
543impl QuantMethod for ColumnParallelLayer {
544 fn new(_method: QuantMethodConfig) -> Result<Self>
545 where
546 Self: Sized,
547 {
548 hanzo_ml::bail!("ColumnParallelLayer should not be constructed with `QuantMethod::new`")
549 }
550
551 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
552 let mut xs = self.weight.forward_raw(a)?;
553 if let Some(bias) = &self.bias {
554 xs = xs.broadcast_add(bias)?;
555 }
556 Ok(xs)
557 }
558
559 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
560 let weight = self.weight.add_delta_w(delta)?;
561 Ok(Arc::new(Self {
562 weight,
563 bias: self.bias.clone(),
564 }))
565 }
566
567 fn dequantize_w(&self) -> Result<Tensor> {
568 self.weight.dequantize_w()
569 }
570
571 fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
572 self.weight.dtype_and_device()
573 }
574
575 fn begin_track_stats(&mut self) -> Result<()> {
576 Arc::get_mut(&mut self.weight)
577 .context("Failed to get &mut to weight")?
578 .begin_track_stats()
579 }
580
581 fn end_track_stats(&self) -> Result<Tensor> {
582 self.weight.end_track_stats()
583 }
584
585 fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
586 self.weight.quantized_act_type()
587 }
588
589 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
590 self.weight.unquant_weight_bias()
591 }
592
593 fn has_bias(&self) -> bool {
594 self.bias.is_some() || self.weight.has_bias()
595 }
596
597 #[cfg(feature = "cuda")]
598 fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
599 self.weight.get_qtensor()
600 }
601
602 fn apply_isq(
603 self: Arc<Self>,
604 dtype: Option<crate::IsqType>,
605 device: hanzo_ml::Device,
606 n_quantized: &std::sync::atomic::AtomicUsize,
607 imatrix_weight: Option<Vec<f32>>,
608 guard: QuantizeOntoGuard,
609 ) -> Result<Arc<dyn QuantMethod>> {
610 let weight =
611 self.weight
612 .clone()
613 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)?;
614 let bias = match &self.bias {
615 Some(b) => {
616 let (dtype, device) = weight.dtype_and_device();
617 Some(b.to_device(&device)?.to_dtype(dtype)?)
618 }
619 None => None,
620 };
621 Ok(Arc::new(Self { weight, bias }))
622 }
623
624 fn is_distributed(&self) -> Option<DistributedKind> {
625 Some(DistributedKind::ColumnParallel)
626 }
627}
628
629impl QuantizedSerde for ColumnParallelLayer {
630 fn isq_serde_supported(&self) -> bool {
631 self.weight.isq_serde_supported()
632 }
633 fn name(&self) -> &'static str {
634 self.weight.name()
635 }
636 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
637 self.weight.serialize_with_bias(self.bias.clone())
638 }
639 fn deserialize(
640 data: std::borrow::Cow<[u8]>,
641 device: &hanzo_ml::Device,
642 _comm: &Arc<crate::Comm>,
643 guard: QuantizeOntoGuard,
644 ) -> Result<Arc<dyn QuantMethod>>
645 where
646 Self: Sized,
647 {
648 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
650 let (weight, bias) = match QuantizedSerdeType::try_from(isq_type as usize)? {
651 QuantizedSerdeType::Gguf => GgufMatMul::deserialize_ext_bias(data, device, guard)?,
652 QuantizedSerdeType::Unquant => {
653 UnquantLinear::deserialize_ext_bias(data, device, guard)?
654 }
655 QuantizedSerdeType::Hqq => HqqLayer::deserialize_ext_bias(data, device, guard)?,
656 QuantizedSerdeType::Fp8 => FP8Linear::deserialize_ext_bias(data, device, guard)?,
657 QuantizedSerdeType::Afq => AfqLayer::deserialize_ext_bias(data, device, guard)?,
658 QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize_ext_bias(data, device, guard)?,
659 QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize_ext_bias(data, device, guard)?,
660 };
661 Ok(Arc::new(Self { weight, bias }))
662 }
663}
664
/// Newtype around a quantization method whose weight is loaded without sharding
/// (its constructors pass `Default::default()` as the shard hint).
///
/// It forwards every `QuantMethod` call to the wrapped layer and reports itself as
/// `DistributedKind::Replicated`.
#[derive(Debug)]
pub struct ReplicatedLayer(Arc<dyn QuantMethod>);
668
669impl ReplicatedLayer {
670 pub fn from_linear(lin: Linear) -> Result<Arc<dyn QuantMethod>> {
671 let dev = lin.weight().device().clone();
672 if let Some(crate::ImmediateIsqParams {
673 guard,
674 ty: Some(immediate_isq),
675 pool,
676 backpressure,
677 ..
678 }) = crate::get_immediate_isq()
679 {
680 let lin = if !dev.is_cpu() {
683 Linear::new(lin.weight().to_device(&Device::Cpu)?, lin.bias().cloned())
684 } else {
685 lin
686 };
687 let layer: Arc<dyn QuantMethod> =
688 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lin))?);
689 if let Some(pool) = &pool {
690 backpressure.acquire();
691 let backpressure = backpressure.clone();
692 let dev = dev.clone();
693 let (tx, rx) = crate::pending_layer::pending_isq_channel();
694 pool.spawn(move || {
695 let result = layer.clone().apply_isq(
696 Some(immediate_isq),
697 dev,
698 &std::sync::atomic::AtomicUsize::new(0),
699 None,
700 guard,
701 );
702 let _ = tx.send(result);
703 backpressure.release();
704 });
705 Ok(Arc::new(crate::PendingIsqLayer::new(rx)))
706 } else {
707 layer.clone().apply_isq(
708 Some(immediate_isq),
709 dev,
710 &std::sync::atomic::AtomicUsize::new(0),
711 None,
712 guard,
713 )
714 }
715 } else {
716 Ok(Arc::new(UnquantLinear::new(
718 QuantMethodConfig::Unquantized(lin),
719 )?))
720 }
721 }
722
723 #[allow(clippy::new_ret_no_self)]
724 pub fn new(
725 in_dim: usize,
726 out_dim: usize,
727 config: &Option<QuantizedConfig>,
728 bias: bool,
729 vb: ShardedVarBuilder,
730 ) -> Result<Arc<dyn QuantMethod>> {
731 let base_vb = vb.clone();
732 let vb = if should_apply_immediate_isq(&vb) {
733 vb.set_device(Device::Cpu)
734 } else {
735 vb
736 };
737
738 let layer = if let Some(quant_conf) = &config {
739 match quant_conf {
740 QuantizedConfig::GptqAwq { .. } => {
741 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
742 }
743 QuantizedConfig::Fp8 { weight_block_size } => {
744 if weight_block_size.is_some() {
745 blockwise_fp8_linear_b(
746 in_dim,
747 out_dim,
748 quant_conf,
749 bias,
750 Default::default(),
751 vb.clone(),
752 )?
753 } else {
754 pertensor_fp8_linear_b(
755 in_dim,
756 out_dim,
757 quant_conf,
758 bias,
759 Default::default(),
760 vb.clone(),
761 )?
762 }
763 }
764 QuantizedConfig::Bitsandbytes { .. } => {
765 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
766 }
767 QuantizedConfig::Afq { .. } => {
768 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
769 }
770 QuantizedConfig::MXFP4 {} => {
771 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
772 }
773 }
774 } else {
775 if !vb.contains_tensor("weight") {
776 make_dummy_or_error("replicated_linear", &vb, &["weight"])?
777 } else {
778 let weight = vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
779 let weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
780
781 let bias = if bias {
782 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
783 } else {
784 None
785 };
786 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
787 Linear::new(weight, bias),
788 ))?;
789 Arc::new(layer) as Arc<dyn QuantMethod>
790 }
791 };
792
793 let this_unquant = Arc::new(Self(layer));
794 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
795 Ok(this)
796 }
797
798 #[allow(clippy::new_ret_no_self)]
799 pub fn new_layers_matformer_indices(
800 in_dim: usize,
801 out_dim: usize,
802 kept_layers_indices: Option<&Tensor>,
803 orig_num_hidden_layers: usize,
804 config: &Option<QuantizedConfig>,
805 bias: bool,
806 vb: ShardedVarBuilder,
807 ) -> Result<Arc<dyn QuantMethod>> {
808 let base_vb = vb.clone();
809 let vb = if should_apply_immediate_isq(&vb) {
810 vb.set_device(Device::Cpu)
811 } else {
812 vb
813 };
814
815 let layer = if let Some(quant_conf) = &config {
816 if kept_layers_indices.is_some() {
817 hanzo_ml::bail!("Cannot load a matformer layer with a pre-quantized model.");
818 }
819
820 match quant_conf {
821 QuantizedConfig::GptqAwq { .. } => {
822 gptq_linear(in_dim, out_dim, quant_conf, vb.clone())?
823 }
824 QuantizedConfig::Fp8 { weight_block_size } => {
825 if weight_block_size.is_some() {
826 blockwise_fp8_linear_b(
827 in_dim,
828 out_dim,
829 quant_conf,
830 bias,
831 Default::default(),
832 vb.clone(),
833 )?
834 } else {
835 pertensor_fp8_linear_b(
836 in_dim,
837 out_dim,
838 quant_conf,
839 bias,
840 Default::default(),
841 vb.clone(),
842 )?
843 }
844 }
845 QuantizedConfig::Bitsandbytes { .. } => {
846 Arc::new(BnbLinear::linear_b(in_dim, out_dim, bias, vb.clone())?) as Arc<_>
847 }
848 QuantizedConfig::Afq { .. } => {
849 AfqLayer::afq_linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
850 }
851 QuantizedConfig::MXFP4 {} => {
852 MXFP4Layer::linear_b(in_dim, out_dim, quant_conf, bias, vb.clone())?
853 }
854 }
855 } else {
856 if !vb.contains_tensor("weight") {
857 make_dummy_or_error("replicated_matformer_linear", &vb, &["weight"])?
858 } else {
859 let mut weight =
860 vb.get_with_hints((out_dim, in_dim), "weight", Default::default())?;
861
862 if let Some(kept_layers_indices) = &kept_layers_indices {
863 let weight_reshaped = weight.reshape((
864 orig_num_hidden_layers,
865 weight.dim(0)? / orig_num_hidden_layers,
866 weight.dim(1)?,
867 ))?;
868
869 weight = weight_reshaped
870 .index_select(&kept_layers_indices.to_device(weight.device())?, 0)?
871 .reshape(((), weight_reshaped.dim(D::Minus1)?))?
872 .contiguous()?;
873 }
874
875 weight = merge_lora_weights(&vb, weight, in_dim, out_dim, Default::default())?;
876
877 let bias = if bias {
878 Some(vb.get_with_hints((out_dim,), "bias", Default::default())?)
879 } else {
880 None
881 };
882 let layer = <UnquantLinear as QuantMethod>::new(QuantMethodConfig::Unquantized(
883 Linear::new(weight, bias),
884 ))?;
885 Arc::new(layer) as Arc<dyn QuantMethod>
886 }
887 };
888
889 let this_unquant = Arc::new(Self(layer));
890 let this: Arc<dyn QuantMethod> = apply_immediate_isq(this_unquant, base_vb)?;
891 Ok(this)
892 }
893}
894
895impl QuantMethod for ReplicatedLayer {
896 fn new(_method: QuantMethodConfig) -> Result<Self>
897 where
898 Self: Sized,
899 {
900 hanzo_ml::bail!("ReplicatedLayer should not be constructed with `QuantMethod::new`")
901 }
902
903 fn forward_raw(&self, a: &Tensor) -> Result<Tensor> {
904 self.0.forward_raw(a)
905 }
906
907 fn add_delta_w(&self, delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
908 self.0.add_delta_w(delta)
909 }
910
911 fn dequantize_w(&self) -> Result<Tensor> {
912 self.0.dequantize_w()
913 }
914
915 fn dtype_and_device(&self) -> (hanzo_ml::DType, hanzo_ml::Device) {
916 self.0.dtype_and_device()
917 }
918
919 fn begin_track_stats(&mut self) -> Result<()> {
920 Arc::get_mut(&mut self.0)
921 .context("Failed to get &mut to weight")?
922 .begin_track_stats()
923 }
924
925 fn end_track_stats(&self) -> Result<Tensor> {
926 self.0.end_track_stats()
927 }
928
929 fn quantized_act_type(&self) -> Option<hanzo_ml::DType> {
930 self.0.quantized_act_type()
931 }
932
933 fn unquant_weight_bias(&self) -> Option<(Tensor, Option<Tensor>)> {
934 self.0.unquant_weight_bias()
935 }
936
937 fn has_bias(&self) -> bool {
938 self.0.has_bias()
939 }
940
941 #[cfg(feature = "cuda")]
942 fn get_qtensor(&self) -> Option<&hanzo_ml::quantized::QTensor> {
943 self.0.get_qtensor()
944 }
945
946 fn apply_isq(
947 self: Arc<Self>,
948 dtype: Option<crate::IsqType>,
949 device: hanzo_ml::Device,
950 n_quantized: &std::sync::atomic::AtomicUsize,
951 imatrix_weight: Option<Vec<f32>>,
952 guard: QuantizeOntoGuard,
953 ) -> Result<Arc<dyn QuantMethod>> {
954 self.0
955 .clone()
956 .apply_isq(dtype, device, n_quantized, imatrix_weight, guard)
957 }
958
959 fn is_distributed(&self) -> Option<DistributedKind> {
960 Some(DistributedKind::Replicated)
961 }
962}
963
964impl QuantizedSerde for ReplicatedLayer {
965 fn isq_serde_supported(&self) -> bool {
966 self.0.isq_serde_supported()
967 }
968 fn name(&self) -> &'static str {
969 self.0.name()
970 }
971 fn serialize(&self) -> Result<std::borrow::Cow<'_, [u8]>> {
972 self.0.serialize()
973 }
974 fn deserialize(
975 data: std::borrow::Cow<[u8]>,
976 device: &hanzo_ml::Device,
977 comm: &Arc<crate::Comm>,
978 guard: QuantizeOntoGuard,
979 ) -> Result<Arc<dyn QuantMethod>>
980 where
981 Self: Sized,
982 {
983 let isq_type = data[crate::UQFF_QUANT_TYPE_OFFSET];
985 let deserialized = match QuantizedSerdeType::try_from(isq_type as usize)? {
986 QuantizedSerdeType::Gguf => GgufMatMul::deserialize(data, device, comm, guard)?,
987 QuantizedSerdeType::Unquant => UnquantLinear::deserialize(data, device, comm, guard)?,
988 QuantizedSerdeType::Hqq => HqqLayer::deserialize(data, device, comm, guard)?,
989 QuantizedSerdeType::Fp8 => FP8Linear::deserialize(data, device, comm, guard)?,
990 QuantizedSerdeType::Afq => AfqLayer::deserialize(data, device, comm, guard)?,
991 QuantizedSerdeType::F8Q8 => F8Q8Linear::deserialize(data, device, comm, guard)?,
992 QuantizedSerdeType::Mxfp4 => MXFP4Layer::deserialize(data, device, comm, guard)?,
993 };
994 Ok(Arc::new(Self(deserialized)))
995 }
996}
997
/// Gate, up and down projection layers for a set of experts.
///
/// Depending on the loading path, each vector holds either a single packed layer (the AFQ
/// path) or one layer per expert (the blockwise FP8 paths).
#[derive(Debug)]
pub struct PackedExperts {
    /// Gate projection layer(s).
    pub gate_proj: Vec<Arc<dyn QuantMethod>>,
    /// Up projection layer(s).
    pub up_proj: Vec<Arc<dyn QuantMethod>>,
    /// Down projection layer(s).
    pub down_proj: Vec<Arc<dyn QuantMethod>>,
}
1004
1005impl PackedExperts {
1006 #[allow(clippy::too_many_arguments)]
1008 pub fn new(
1009 num_local_experts: usize,
1010 hidden_size: usize,
1011 intermediate_size: usize,
1012 config: &Option<QuantizedConfig>,
1013 bias: bool,
1014 comm: &Arc<crate::Comm>,
1015 vb: ShardedVarBuilder,
1016 ) -> Result<Self> {
1017 if bias {
1018 hanzo_ml::bail!("PackedExperts does not support bias.");
1019 }
1020
1021 let (gate_proj, up_proj, down_proj) = if let Some(quant_conf) = &config {
1022 if comm.world_size() != 1 {
1024 hanzo_ml::bail!(
1025 "PackedExperts with quantization config does not support distributed (world size {}). Use ISQ.",
1026 comm.world_size()
1027 );
1028 }
1029
1030 match quant_conf {
1031 QuantizedConfig::Afq { .. } => {
1032 if !vb.contains_tensor("gate_up_proj")
1033 || !vb.contains_tensor("gate_up_proj.weight")
1034 {
1035 hanzo_ml::bail!("PackedExperts with AFQ quantization config does not support `gate_up_proj` format.");
1036 }
1037
1038 let base_vb = vb.clone();
1039
1040 let vb_gate_proj = if should_apply_immediate_isq(&vb) {
1041 vb.pp("gate_proj").set_device(Device::Cpu)
1042 } else {
1043 vb.pp("gate_proj")
1044 };
1045 let vb_up_proj = if should_apply_immediate_isq(&vb) {
1046 vb.pp("up_proj").set_device(Device::Cpu)
1047 } else {
1048 vb.pp("up_proj")
1049 };
1050 let vb_down_proj = if should_apply_immediate_isq(&vb) {
1051 vb.pp("down_proj").set_device(Device::Cpu)
1052 } else {
1053 vb.pp("down_proj")
1054 };
1055 let mut gate_proj = AfqLayer::afq_packed_linear_b(
1056 num_local_experts,
1057 hidden_size,
1058 intermediate_size,
1059 quant_conf,
1060 bias,
1061 vb_gate_proj,
1062 )?;
1063 let mut up_proj = AfqLayer::afq_packed_linear_b(
1064 num_local_experts,
1065 hidden_size,
1066 intermediate_size,
1067 quant_conf,
1068 bias,
1069 vb_up_proj,
1070 )?;
1071 let mut down_proj = AfqLayer::afq_packed_linear_b(
1072 num_local_experts,
1073 intermediate_size,
1074 hidden_size,
1075 quant_conf,
1076 bias,
1077 vb_down_proj,
1078 )?;
1079
1080 gate_proj = apply_immediate_isq(gate_proj, base_vb.pp("gate_proj"))?;
1081 up_proj = apply_immediate_isq(up_proj, base_vb.pp("up_proj"))?;
1082 down_proj = apply_immediate_isq(down_proj, base_vb.pp("down_proj"))?;
1083
1084 (vec![gate_proj], vec![up_proj], vec![down_proj])
1085 }
1086 QuantizedConfig::Fp8 { weight_block_size } => {
1087 let Some(weight_block_size) = weight_block_size else {
1090 hanzo_ml::bail!("Blockwise FP8 for PackedExperts requires weight_block_size to be set.")
1091 };
1092 if weight_block_size.len() != 2 {
1093 hanzo_ml::bail!(
1094 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1095 );
1096 }
1097
1098 let is_stacked_format = vb.contains_tensor("gate_up_proj");
1101
1102 if is_stacked_format {
1103 let has_fp8_scales = vb.contains_tensor("gate_up_proj.weight_scale_inv");
1105
1106 if has_fp8_scales {
1107 let gate_up_fp8 = vb.get_with_hints_dtype(
1109 (num_local_experts, hidden_size, intermediate_size * 2),
1110 "gate_up_proj",
1111 Default::default(),
1112 hanzo_ml::DType::F8E4M3,
1113 )?;
1114 let gate_up_scale = vb.get_with_hints_dtype(
1115 (
1116 num_local_experts,
1117 hidden_size.div_ceil(weight_block_size[0]),
1118 (intermediate_size * 2).div_ceil(weight_block_size[1]),
1119 ),
1120 "gate_up_proj.weight_scale_inv",
1121 Default::default(),
1122 hanzo_ml::DType::F32,
1123 )?;
1124
1125 let down_fp8 = vb.get_with_hints_dtype(
1127 (num_local_experts, intermediate_size, hidden_size),
1128 "down_proj",
1129 Default::default(),
1130 hanzo_ml::DType::F8E4M3,
1131 )?;
1132 let down_scale = vb.get_with_hints_dtype(
1133 (
1134 num_local_experts,
1135 intermediate_size.div_ceil(weight_block_size[0]),
1136 hidden_size.div_ceil(weight_block_size[1]),
1137 ),
1138 "down_proj.weight_scale_inv",
1139 Default::default(),
1140 hanzo_ml::DType::F32,
1141 )?;
1142
1143 let mut gs = Vec::new();
1145 let mut us = Vec::new();
1146 let mut ds = Vec::new();
1147
1148 for i in 0..num_local_experts {
1149 let gate_up_expert =
1151 gate_up_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1152 let gate_up_scale_expert = gate_up_scale.i(i)?.contiguous()?;
1153 let down_expert = down_fp8.i(i)?.transpose(0, 1)?.contiguous()?;
1154 let down_scale_expert = down_scale.i(i)?.contiguous()?;
1155
1156 let gate_expert = gate_up_expert.narrow(0, 0, intermediate_size)?;
1158 let up_expert = gate_up_expert.narrow(
1159 0,
1160 intermediate_size,
1161 intermediate_size,
1162 )?;
1163
1164 let gate_scale_expert = gate_up_scale_expert.narrow(
1166 1,
1167 0,
1168 intermediate_size.div_ceil(weight_block_size[1]),
1169 )?;
1170 let up_scale_expert = gate_up_scale_expert.narrow(
1171 1,
1172 intermediate_size.div_ceil(weight_block_size[1]),
1173 intermediate_size.div_ceil(weight_block_size[1]),
1174 )?;
1175
1176 use crate::blockwise_fp8::BlockwiseFP8Linear;
1178 use crate::QuantMethodConfig;
1179
1180 let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1181 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1182 weight: gate_expert,
1183 weight_scale_inv: gate_scale_expert.transpose(0, 1)?,
1184 bias: None,
1185 dequant_dtype: vb.dtype(),
1186 weight_block_size: weight_block_size.clone(),
1187 })?,
1188 );
1189 let up_layer: Arc<dyn QuantMethod> = Arc::new(
1190 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1191 weight: up_expert,
1192 weight_scale_inv: up_scale_expert.transpose(0, 1)?,
1193 bias: None,
1194 dequant_dtype: vb.dtype(),
1195 weight_block_size: weight_block_size.clone(),
1196 })?,
1197 );
1198 let down_layer: Arc<dyn QuantMethod> = Arc::new(
1199 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1200 weight: down_expert,
1201 weight_scale_inv: down_scale_expert.transpose(0, 1)?,
1202 bias: None,
1203 dequant_dtype: vb.dtype(),
1204 weight_block_size: weight_block_size.clone(),
1205 })?,
1206 );
1207
1208 gs.push(gate_layer);
1209 us.push(up_layer);
1210 ds.push(down_layer);
1211 }
1212
1213 (gs, us, ds)
1214 } else {
1215 hanzo_ml::bail!(
1216 "PackedExperts with FP8 requires weight_scale_inv tensors"
1217 );
1218 }
1219 } else {
1220 let mut gs = Vec::new();
1222 let mut us = Vec::new();
1223 let mut ds = Vec::new();
1224
1225 for i in 0..num_local_experts {
1226 let expert_vb = vb.pp(i);
1227
1228 let gate_fp8 = expert_vb.get_with_hints_dtype(
1230 (intermediate_size, hidden_size),
1231 "gate_proj.weight",
1232 Default::default(),
1233 hanzo_ml::DType::F8E4M3,
1234 )?;
1235 let gate_scale = expert_vb.get_with_hints_dtype(
1236 (
1237 intermediate_size.div_ceil(weight_block_size[0]),
1238 hidden_size.div_ceil(weight_block_size[1]),
1239 ),
1240 "gate_proj.weight_scale_inv",
1241 Default::default(),
1242 hanzo_ml::DType::F32,
1243 )?;
1244
1245 let up_fp8 = expert_vb.get_with_hints_dtype(
1246 (intermediate_size, hidden_size),
1247 "up_proj.weight",
1248 Default::default(),
1249 hanzo_ml::DType::F8E4M3,
1250 )?;
1251 let up_scale = expert_vb.get_with_hints_dtype(
1252 (
1253 intermediate_size.div_ceil(weight_block_size[0]),
1254 hidden_size.div_ceil(weight_block_size[1]),
1255 ),
1256 "up_proj.weight_scale_inv",
1257 Default::default(),
1258 hanzo_ml::DType::F32,
1259 )?;
1260
1261 let down_fp8 = expert_vb.get_with_hints_dtype(
1262 (hidden_size, intermediate_size),
1263 "down_proj.weight",
1264 Default::default(),
1265 hanzo_ml::DType::F8E4M3,
1266 )?;
1267 let down_scale = expert_vb.get_with_hints_dtype(
1268 (
1269 hidden_size.div_ceil(weight_block_size[0]),
1270 intermediate_size.div_ceil(weight_block_size[1]),
1271 ),
1272 "down_proj.weight_scale_inv",
1273 Default::default(),
1274 hanzo_ml::DType::F32,
1275 )?;
1276
1277 use crate::blockwise_fp8::BlockwiseFP8Linear;
1279 use crate::QuantMethodConfig;
1280
1281 let gate_layer: Arc<dyn QuantMethod> = Arc::new(
1282 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1283 weight: gate_fp8,
1284 weight_scale_inv: gate_scale,
1285 bias: None,
1286 dequant_dtype: vb.dtype(),
1287 weight_block_size: weight_block_size.clone(),
1288 })?,
1289 );
1290 let up_layer: Arc<dyn QuantMethod> = Arc::new(BlockwiseFP8Linear::new(
1291 QuantMethodConfig::BlockwiseFP8 {
1292 weight: up_fp8,
1293 weight_scale_inv: up_scale,
1294 bias: None,
1295 dequant_dtype: vb.dtype(),
1296 weight_block_size: weight_block_size.clone(),
1297 },
1298 )?);
1299 let down_layer: Arc<dyn QuantMethod> = Arc::new(
1300 BlockwiseFP8Linear::new(QuantMethodConfig::BlockwiseFP8 {
1301 weight: down_fp8,
1302 weight_scale_inv: down_scale,
1303 bias: None,
1304 dequant_dtype: vb.dtype(),
1305 weight_block_size: weight_block_size.clone(),
1306 })?,
1307 );
1308
1309 gs.push(gate_layer);
1310 us.push(up_layer);
1311 ds.push(down_layer);
1312 }
1313
1314 (gs, us, ds)
1315 }
1316 }
1317 QuantizedConfig::MXFP4 {} => {
1318 let gate_proj = MXFP4Layer::packed_linear_b(
1322 num_local_experts,
1323 hidden_size,
1324 intermediate_size,
1325 quant_conf,
1326 bias,
1327 vb.pp("gate_proj"),
1328 )?;
1329 let up_proj = MXFP4Layer::packed_linear_b(
1330 num_local_experts,
1331 hidden_size,
1332 intermediate_size,
1333 quant_conf,
1334 bias,
1335 vb.pp("up_proj"),
1336 )?;
1337 let down_proj = MXFP4Layer::packed_linear_b(
1338 num_local_experts,
1339 intermediate_size,
1340 hidden_size,
1341 quant_conf,
1342 bias,
1343 vb.pp("down_proj"),
1344 )?;
1345
1346 (vec![gate_proj], vec![up_proj], vec![down_proj])
1347 }
1348 _ => hanzo_ml::bail!(
1349 "PackedExperts with quantization config only allows AFQ, FP8, or MXFP4 quantization"
1350 ),
1351 }
1352 } else if !vb.contains_tensor("gate_up_proj") {
1353 let mut gs: Vec<Arc<dyn QuantMethod>> = Vec::new();
1355 let mut us: Vec<Arc<dyn QuantMethod>> = Vec::new();
1356 let mut ds: Vec<Arc<dyn QuantMethod>> = Vec::new();
1357 for _ in 0..num_local_experts {
1358 gs.push(make_dummy_or_error(
1359 "packed_experts_gate_proj",
1360 &vb,
1361 &["gate_up_proj"],
1362 )?);
1363 us.push(make_dummy_or_error(
1364 "packed_experts_up_proj",
1365 &vb,
1366 &["gate_up_proj"],
1367 )?);
1368 ds.push(make_dummy_or_error(
1369 "packed_experts_down_proj",
1370 &vb,
1371 &["gate_up_proj"],
1372 )?);
1373 }
1374 (gs, us, ds)
1375 } else {
1376 let gate_up_block_size = intermediate_size / comm.world_size();
1384 let gate_up_start = gate_up_block_size * comm.rank();
1385
1386 let shard_gate = Shard::Offset {
1388 dim: 2,
1389 offset: gate_up_start,
1390 len: gate_up_block_size,
1391 };
1392 let shard_up = Shard::Offset {
1393 dim: 2,
1394 offset: intermediate_size + gate_up_start,
1395 len: gate_up_block_size,
1396 };
1397 let shard_down = Shard::Simple {
1398 dim: 1,
1399 rank: comm.rank(),
1400 world_size: comm.world_size(),
1401 };
1402
1403 let vb_gate_up_proj = if should_apply_immediate_isq(&vb) {
1404 vb.pp("gate_up_proj").set_device(Device::Cpu)
1405 } else {
1406 vb.pp("gate_up_proj")
1407 };
1408 let vb_down_proj = if should_apply_immediate_isq(&vb) {
1409 vb.pp("down_proj").set_device(Device::Cpu)
1410 } else {
1411 vb.pp("down_proj")
1412 };
1413
1414 let gate_proj = vb
1415 .get_with_hints(
1416 (num_local_experts, hidden_size, intermediate_size * 2),
1417 "gate_up_proj",
1418 shard_gate,
1419 )?
1420 .t()?
1421 .contiguous()?;
1422 let up_proj = vb
1423 .get_with_hints(
1424 (num_local_experts, hidden_size, intermediate_size * 2),
1425 "gate_up_proj",
1426 shard_up,
1427 )?
1428 .t()?
1429 .contiguous()?;
1430 let down_proj = vb
1431 .get_with_hints(
1432 (num_local_experts, intermediate_size, hidden_size),
1433 "down_proj",
1434 shard_down,
1435 )?
1436 .t()?
1437 .contiguous()?;
1438
1439 let gc = gate_proj.chunk(num_local_experts, 0)?;
1440 let uc = up_proj.chunk(num_local_experts, 0)?;
1441 let dc = down_proj.chunk(num_local_experts, 0)?;
1442 drop((gate_proj, up_proj, down_proj));
1443
1444 let mut gs = Vec::new();
1445 let mut us = Vec::new();
1446 let mut ds = Vec::new();
1447 for ((mut gate_proj, mut up_proj), mut down_proj) in gc.into_iter().zip(uc).zip(dc) {
1448 gate_proj = gate_proj.squeeze(0)?;
1449 up_proj = up_proj.squeeze(0)?;
1450 down_proj = down_proj.squeeze(0)?;
1451 let gate_proj = merge_lora_weights(
1452 &vb,
1453 gate_proj,
1454 hidden_size,
1455 intermediate_size * 2,
1456 shard_gate,
1457 )?;
1458 let up_proj =
1459 merge_lora_weights(&vb, up_proj, hidden_size, intermediate_size * 2, shard_up)?;
1460 let down_proj =
1461 merge_lora_weights(&vb, down_proj, intermediate_size, hidden_size, shard_down)?;
1462
1463 let mut gate_proj: Arc<dyn QuantMethod> =
1464 Arc::new(<UnquantLinear as QuantMethod>::new(
1465 QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1466 )?);
1467 gate_proj = apply_immediate_isq(gate_proj, vb_gate_up_proj.clone())?;
1468 let mut up_proj: Arc<dyn QuantMethod> =
1469 Arc::new(<UnquantLinear as QuantMethod>::new(
1470 QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1471 )?);
1472 up_proj = apply_immediate_isq(up_proj, vb_gate_up_proj.clone())?;
1473 let mut down_proj: Arc<dyn QuantMethod> =
1474 Arc::new(<UnquantLinear as QuantMethod>::new(
1475 QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1476 )?);
1477 down_proj = apply_immediate_isq(down_proj, vb_down_proj.clone())?;
1478 gs.push(gate_proj);
1479 us.push(up_proj);
1480 ds.push(down_proj);
1481 }
1482 (gs, us, ds)
1483 };
1484
1485 Ok(Self {
1486 gate_proj,
1487 up_proj,
1488 down_proj,
1489 })
1490 }
1491}
1492
/// Gate, up and down projections of a mixture-of-experts block, where each
/// projection is held as a single [`QuantMethod`] covering all experts.
///
/// Instances are built by [`FusedExperts::new`], which selects how the weights
/// are loaded from the checkpoint layout and the quantization config.
pub struct FusedExperts {
    /// Gate projection for all experts.
    pub fused_gate_proj: Arc<dyn QuantMethod>,
    /// Up projection for all experts.
    pub fused_up_proj: Arc<dyn QuantMethod>,
    /// Down projection for all experts.
    pub fused_down_proj: Arc<dyn QuantMethod>,
}
1498
1499impl FusedExperts {
1500 pub fn new(
1501 hidden_size: usize,
1502 moe_intermediate_size: usize,
1503 num_experts: usize,
1504 quantization_config: &Option<QuantizedConfig>,
1505 vb: ShardedVarBuilder,
1506 ) -> Result<Self> {
1507 let experts_vb = vb.pp("experts");
1513 let is_stacked_format = experts_vb.contains_tensor("gate_up_proj");
1514
1515 let (fused_gate_proj, fused_up_proj, fused_down_proj) = if matches!(
1516 &quantization_config,
1517 Some(QuantizedConfig::Afq { .. })
1518 ) {
1519 let quantization_config = quantization_config.as_ref().unwrap();
1520
1521 let fused_gate_proj = AfqLayer::afq_packed_linear_b(
1522 num_experts,
1523 hidden_size,
1524 moe_intermediate_size,
1525 quantization_config,
1526 false,
1527 vb.pp("switch_mlp.gate_proj"),
1528 )?;
1529 let fused_up_proj = AfqLayer::afq_packed_linear_b(
1530 num_experts,
1531 hidden_size,
1532 moe_intermediate_size,
1533 quantization_config,
1534 false,
1535 vb.pp("switch_mlp.up_proj"),
1536 )?;
1537 let fused_down_proj = AfqLayer::afq_packed_linear_b(
1538 num_experts,
1539 moe_intermediate_size,
1540 hidden_size,
1541 quantization_config,
1542 false,
1543 vb.pp("switch_mlp.down_proj"),
1544 )?;
1545
1546 (fused_gate_proj, fused_up_proj, fused_down_proj)
1547 } else if is_stacked_format
1548 && matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. }))
1549 {
1550 let has_fp8_scales = experts_vb.contains_tensor("gate_up_proj.weight_scale_inv");
1553
1554 if has_fp8_scales {
1555 let weight_block_size = match quantization_config {
1556 Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1557 _ => unreachable!(),
1558 };
1559
1560 let Some(weight_block_size) = weight_block_size else {
1561 hanzo_ml::bail!(
1562 "Blockwise FP8 for stacked experts requires weight_block_size to be set."
1563 )
1564 };
1565 if weight_block_size.len() != 2 {
1566 hanzo_ml::bail!(
1567 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1568 );
1569 }
1570
1571 let gate_up_fp8 = experts_vb.get_with_hints_dtype(
1574 (num_experts, hidden_size, moe_intermediate_size * 2),
1575 "gate_up_proj",
1576 Default::default(),
1577 hanzo_ml::DType::F8E4M3,
1578 )?;
1579 let gate_up_scale = experts_vb.get_with_hints_dtype(
1580 (
1581 num_experts,
1582 hidden_size.div_ceil(weight_block_size[0]),
1583 (moe_intermediate_size * 2).div_ceil(weight_block_size[1]),
1584 ),
1585 "gate_up_proj.weight_scale_inv",
1586 Default::default(),
1587 hanzo_ml::DType::F32,
1588 )?;
1589
1590 let down_fp8 = experts_vb.get_with_hints_dtype(
1593 (num_experts, moe_intermediate_size, hidden_size),
1594 "down_proj",
1595 Default::default(),
1596 hanzo_ml::DType::F8E4M3,
1597 )?;
1598 let down_scale = experts_vb.get_with_hints_dtype(
1599 (
1600 num_experts,
1601 moe_intermediate_size.div_ceil(weight_block_size[0]),
1602 hidden_size.div_ceil(weight_block_size[1]),
1603 ),
1604 "down_proj.weight_scale_inv",
1605 Default::default(),
1606 hanzo_ml::DType::F32,
1607 )?;
1608
1609 let gate_fp8 = gate_up_fp8.narrow(2, 0, moe_intermediate_size)?;
1611 let up_fp8 = gate_up_fp8.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1612
1613 let gate_scale = gate_up_scale.narrow(
1615 2,
1616 0,
1617 moe_intermediate_size.div_ceil(weight_block_size[1]),
1618 )?;
1619 let up_scale = gate_up_scale.narrow(
1620 2,
1621 moe_intermediate_size.div_ceil(weight_block_size[1]),
1622 moe_intermediate_size.div_ceil(weight_block_size[1]),
1623 )?;
1624
1625 let gate_fp8 = gate_fp8.transpose(1, 2)?.contiguous()?;
1628 let up_fp8 = up_fp8.transpose(1, 2)?.contiguous()?;
1629 let down_fp8 = down_fp8.transpose(1, 2)?.contiguous()?;
1631
1632 let gate_scale = gate_scale.transpose(1, 2)?.contiguous()?;
1634 let up_scale = up_scale.transpose(1, 2)?.contiguous()?;
1635 let down_scale = down_scale.transpose(1, 2)?.contiguous()?;
1636
1637 let fused_gate_proj =
1639 blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1640 let fused_up_proj =
1641 blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1642 let fused_down_proj =
1643 blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1644
1645 (fused_gate_proj, fused_up_proj, fused_down_proj)
1646 } else {
1647 tracing::warn!(
1649 "FP8 quantization config specified but no scale tensors found for stacked MoE experts. \
1650 Loading as unquantized."
1651 );
1652 let gate_up_proj = experts_vb
1653 .get(
1654 (num_experts, hidden_size, moe_intermediate_size * 2),
1655 "gate_up_proj",
1656 )
1657 .or_else(|_| {
1658 experts_vb
1659 .get(
1660 (num_experts, moe_intermediate_size * 2, hidden_size),
1661 "gate_up_proj",
1662 )
1663 .and_then(|t| t.transpose(1, 2)?.contiguous())
1664 })?;
1665 let down_proj_packed = experts_vb
1666 .get(
1667 (num_experts, moe_intermediate_size, hidden_size),
1668 "down_proj",
1669 )
1670 .or_else(|_| {
1671 experts_vb
1672 .get(
1673 (num_experts, hidden_size, moe_intermediate_size),
1674 "down_proj",
1675 )
1676 .and_then(|t| t.transpose(1, 2)?.contiguous())
1677 })?;
1678
1679 let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1681 let up_proj =
1682 gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1683
1684 let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1686 let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1687 let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1688
1689 let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1691 let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1692 let target_device = gate_proj.device().clone();
1693 let (gate_proj, up_proj, down_proj) =
1694 if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1695 (
1696 gate_proj.to_device(&Device::Cpu)?,
1697 up_proj.to_device(&Device::Cpu)?,
1698 down_proj.to_device(&Device::Cpu)?,
1699 )
1700 } else {
1701 (gate_proj, up_proj, down_proj)
1702 };
1703
1704 let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1705 QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1706 )?);
1707 let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1708 QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1709 )?);
1710 let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1711 QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1712 )?);
1713 let vb_gate_up = experts_vb.pp("gate_up_proj");
1716 let vb_down = experts_vb.pp("down_proj");
1717 fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1718 fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1719 fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1720
1721 (fused_gate_proj, fused_up_proj, fused_down_proj)
1722 }
1723 } else if is_stacked_format
1724 && matches!(&quantization_config, Some(QuantizedConfig::MXFP4 {}))
1725 {
1726 let quantization_config = quantization_config.as_ref().unwrap();
1730
1731 let fused_gate_proj = MXFP4Layer::packed_linear_b(
1736 num_experts,
1737 hidden_size,
1738 moe_intermediate_size,
1739 quantization_config,
1740 false,
1741 experts_vb.pp("gate_proj"),
1742 )?;
1743 let fused_up_proj = MXFP4Layer::packed_linear_b(
1744 num_experts,
1745 hidden_size,
1746 moe_intermediate_size,
1747 quantization_config,
1748 false,
1749 experts_vb.pp("up_proj"),
1750 )?;
1751 let fused_down_proj = MXFP4Layer::packed_linear_b(
1752 num_experts,
1753 moe_intermediate_size,
1754 hidden_size,
1755 quantization_config,
1756 false,
1757 experts_vb.pp("down_proj"),
1758 )?;
1759
1760 (fused_gate_proj, fused_up_proj, fused_down_proj)
1761 } else if is_stacked_format {
1762 let gate_up_proj = experts_vb
1772 .get(
1773 (num_experts, hidden_size, moe_intermediate_size * 2),
1774 "gate_up_proj",
1775 )
1776 .or_else(|_| {
1777 experts_vb
1778 .get(
1779 (num_experts, moe_intermediate_size * 2, hidden_size),
1780 "gate_up_proj",
1781 )
1782 .and_then(|t| t.transpose(1, 2)?.contiguous())
1783 })?;
1784 let down_proj_packed = experts_vb
1785 .get(
1786 (num_experts, moe_intermediate_size, hidden_size),
1787 "down_proj",
1788 )
1789 .or_else(|_| {
1790 experts_vb
1791 .get(
1792 (num_experts, hidden_size, moe_intermediate_size),
1793 "down_proj",
1794 )
1795 .and_then(|t| t.transpose(1, 2)?.contiguous())
1796 })?;
1797
1798 let gate_proj = gate_up_proj.narrow(2, 0, moe_intermediate_size)?;
1802 let up_proj = gate_up_proj.narrow(2, moe_intermediate_size, moe_intermediate_size)?;
1803
1804 let gate_proj = gate_proj.transpose(1, 2)?.contiguous()?;
1807 let up_proj = up_proj.transpose(1, 2)?.contiguous()?;
1808 let down_proj = down_proj_packed.transpose(1, 2)?.contiguous()?;
1810
1811 let isq_gate_up = should_apply_immediate_isq(&experts_vb.pp("gate_up_proj"));
1813 let isq_down = should_apply_immediate_isq(&experts_vb.pp("down_proj"));
1814 let target_device = gate_proj.device().clone();
1815 let (gate_proj, up_proj, down_proj) =
1816 if (isq_gate_up || isq_down) && !target_device.is_cpu() {
1817 (
1818 gate_proj.to_device(&Device::Cpu)?,
1819 up_proj.to_device(&Device::Cpu)?,
1820 down_proj.to_device(&Device::Cpu)?,
1821 )
1822 } else {
1823 (gate_proj, up_proj, down_proj)
1824 };
1825
1826 let mut fused_gate_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1827 QuantMethodConfig::Unquantized(Linear::new(gate_proj, None)),
1828 )?);
1829 let mut fused_up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1830 QuantMethodConfig::Unquantized(Linear::new(up_proj, None)),
1831 )?);
1832 let mut fused_down_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1833 QuantMethodConfig::Unquantized(Linear::new(down_proj, None)),
1834 )?);
1835 let vb_gate_up = experts_vb.pp("gate_up_proj");
1838 let vb_down = experts_vb.pp("down_proj");
1839 fused_gate_proj = apply_immediate_isq(fused_gate_proj, vb_gate_up.clone())?;
1840 fused_up_proj = apply_immediate_isq(fused_up_proj, vb_gate_up)?;
1841 fused_down_proj = apply_immediate_isq(fused_down_proj, vb_down)?;
1842
1843 (fused_gate_proj, fused_up_proj, fused_down_proj)
1844 } else if matches!(&quantization_config, Some(QuantizedConfig::Fp8 { .. })) {
1845 let weight_block_size = match quantization_config {
1848 Some(QuantizedConfig::Fp8 { weight_block_size }) => weight_block_size.clone(),
1849 _ => unreachable!(),
1850 };
1851
1852 let Some(weight_block_size) = weight_block_size else {
1853 hanzo_ml::bail!(
1854 "Blockwise FP8 for per-expert format requires weight_block_size to be set."
1855 )
1856 };
1857 if weight_block_size.len() != 2 {
1858 hanzo_ml::bail!(
1859 "Expected weight_block_size to have length 2, got {weight_block_size:?}"
1860 );
1861 }
1862
1863 let mut gate_fp8_vec = Vec::new();
1864 let mut gate_scale_vec = Vec::new();
1865 let mut up_fp8_vec = Vec::new();
1866 let mut up_scale_vec = Vec::new();
1867 let mut down_fp8_vec = Vec::new();
1868 let mut down_scale_vec = Vec::new();
1869
1870 for i in 0..num_experts {
1871 let expert_vb = experts_vb.pp(i);
1872
1873 let gate_fp8 = expert_vb.get_with_hints_dtype(
1875 (moe_intermediate_size, hidden_size),
1876 "gate_proj.weight",
1877 Default::default(),
1878 hanzo_ml::DType::F8E4M3,
1879 )?;
1880 let gate_scale = expert_vb.get_with_hints_dtype(
1881 (
1882 moe_intermediate_size.div_ceil(weight_block_size[0]),
1883 hidden_size.div_ceil(weight_block_size[1]),
1884 ),
1885 "gate_proj.weight_scale_inv",
1886 Default::default(),
1887 hanzo_ml::DType::F32,
1888 )?;
1889
1890 let up_fp8 = expert_vb.get_with_hints_dtype(
1891 (moe_intermediate_size, hidden_size),
1892 "up_proj.weight",
1893 Default::default(),
1894 hanzo_ml::DType::F8E4M3,
1895 )?;
1896 let up_scale = expert_vb.get_with_hints_dtype(
1897 (
1898 moe_intermediate_size.div_ceil(weight_block_size[0]),
1899 hidden_size.div_ceil(weight_block_size[1]),
1900 ),
1901 "up_proj.weight_scale_inv",
1902 Default::default(),
1903 hanzo_ml::DType::F32,
1904 )?;
1905
1906 let down_fp8 = expert_vb.get_with_hints_dtype(
1907 (hidden_size, moe_intermediate_size),
1908 "down_proj.weight",
1909 Default::default(),
1910 hanzo_ml::DType::F8E4M3,
1911 )?;
1912 let down_scale = expert_vb.get_with_hints_dtype(
1913 (
1914 hidden_size.div_ceil(weight_block_size[0]),
1915 moe_intermediate_size.div_ceil(weight_block_size[1]),
1916 ),
1917 "down_proj.weight_scale_inv",
1918 Default::default(),
1919 hanzo_ml::DType::F32,
1920 )?;
1921
1922 gate_fp8_vec.push(gate_fp8);
1923 gate_scale_vec.push(gate_scale);
1924 up_fp8_vec.push(up_fp8);
1925 up_scale_vec.push(up_scale);
1926 down_fp8_vec.push(down_fp8);
1927 down_scale_vec.push(down_scale);
1928 }
1929
1930 let gate_fp8 = Tensor::stack(&gate_fp8_vec, 0)?;
1932 let gate_scale = Tensor::stack(&gate_scale_vec, 0)?;
1933 let up_fp8 = Tensor::stack(&up_fp8_vec, 0)?;
1934 let up_scale = Tensor::stack(&up_scale_vec, 0)?;
1935 let down_fp8 = Tensor::stack(&down_fp8_vec, 0)?;
1936 let down_scale = Tensor::stack(&down_scale_vec, 0)?;
1937
1938 let fused_gate_proj =
1940 blockwise_fp8_moe(gate_fp8, gate_scale, weight_block_size.clone(), vb.dtype())?;
1941 let fused_up_proj =
1942 blockwise_fp8_moe(up_fp8, up_scale, weight_block_size.clone(), vb.dtype())?;
1943 let fused_down_proj =
1944 blockwise_fp8_moe(down_fp8, down_scale, weight_block_size, vb.dtype())?;
1945
1946 (fused_gate_proj, fused_up_proj, fused_down_proj)
1947 } else if !experts_vb.pp("0").contains_tensor("gate_proj.weight") {
1948 let expert_vb = experts_vb.pp("0");
1951 let fused_gate_proj =
1952 make_dummy_or_error("fused_experts_gate_proj", &expert_vb, &["gate_proj.weight"])?;
1953 let fused_up_proj =
1954 make_dummy_or_error("fused_experts_up_proj", &expert_vb, &["gate_proj.weight"])?;
1955 let fused_down_proj =
1956 make_dummy_or_error("fused_experts_down_proj", &expert_vb, &["gate_proj.weight"])?;
1957 (fused_gate_proj, fused_up_proj, fused_down_proj)
1958 } else {
1959 let load_experts_vb =
1962 if crate::get_immediate_isq().is_some() && !experts_vb.device().is_cpu() {
1963 experts_vb.clone().set_device(Device::Cpu)
1964 } else {
1965 experts_vb.clone()
1966 };
1967 let mut gate_proj_vec = Vec::new();
1968 let mut up_proj_vec = Vec::new();
1969 let mut down_proj_vec = Vec::new();
1970 for i in 0..num_experts {
1971 let expert_vb = load_experts_vb.pp(i);
1972 let gate_proj =
1973 expert_vb.get((moe_intermediate_size, hidden_size), "gate_proj.weight")?;
1974 let up_proj =
1975 expert_vb.get((moe_intermediate_size, hidden_size), "up_proj.weight")?;
1976 let down_proj =
1977 expert_vb.get((hidden_size, moe_intermediate_size), "down_proj.weight")?;
1978
1979 gate_proj_vec.push(gate_proj);
1980 up_proj_vec.push(up_proj);
1981 down_proj_vec.push(down_proj);
1982 }
1983
1984 let mut gate_proj: Arc<dyn QuantMethod> =
1985 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1986 Linear::new(Tensor::stack(&gate_proj_vec, 0)?, None),
1987 ))?);
1988 let mut up_proj: Arc<dyn QuantMethod> = Arc::new(UnquantLinear::new(
1989 QuantMethodConfig::Unquantized(Linear::new(Tensor::stack(&up_proj_vec, 0)?, None)),
1990 )?);
1991 let mut down_proj: Arc<dyn QuantMethod> =
1992 Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
1993 Linear::new(Tensor::stack(&down_proj_vec, 0)?, None),
1994 ))?);
1995 let expert0_vb = experts_vb.pp("0");
1997 gate_proj = apply_immediate_isq(gate_proj, expert0_vb.pp("gate_proj"))?;
1998 up_proj = apply_immediate_isq(up_proj, expert0_vb.pp("up_proj"))?;
1999 down_proj = apply_immediate_isq(down_proj, expert0_vb.pp("down_proj"))?;
2000
2001 (gate_proj, up_proj, down_proj)
2002 };
2003
2004 Ok(Self {
2005 fused_gate_proj,
2006 fused_up_proj,
2007 fused_down_proj,
2008 })
2009 }
2010}
2011
2012pub fn compute_kv_shard(total_num_kv_heads: usize, head_dim: usize, comm: &Comm) -> Shard {
2014 if comm.world_size() == 1 {
2015 return Shard::default();
2016 }
2017
2018 let kv_replicate = if comm.world_size() > total_num_kv_heads {
2022 comm.world_size() / total_num_kv_heads
2023 } else {
2024 return Shard::Simple {
2025 dim: 0,
2026 rank: comm.rank(),
2027 world_size: comm.world_size(),
2028 };
2029 };
2030
2031 let num_kv_heads = (total_num_kv_heads / comm.world_size()).max(1);
2032 let kv_shard_id = (comm.rank() / kv_replicate) * num_kv_heads;
2033 Shard::Offset {
2034 dim: 0,
2035 offset: kv_shard_id * head_dim,
2036 len: head_dim,
2037 }
2038}
2039
2040pub fn compute_n_kv_groups(
2042 total_num_kv_heads: usize,
2043 num_attention_heads: usize,
2044 comm: &Comm,
2045) -> usize {
2046 let kv_replicate = if comm.world_size() > total_num_kv_heads {
2047 comm.world_size() / total_num_kv_heads
2048 } else {
2049 1
2050 };
2051 (num_attention_heads / total_num_kv_heads)
2052 .checked_div(kv_replicate)
2053 .unwrap_or(num_attention_heads / total_num_kv_heads)
2054}