1use std::path::Path;
23use std::sync::atomic::{AtomicU64, Ordering};
24
25use candle_core::quantized::GgmlDType;
26use candle_core::{Device, Result as CandleResult};
27use ferrum_kernels::backend::cpu::CpuBackend;
28use ferrum_kernels::backend::{Backend, GgufQuantType};
29use ferrum_kernels::Linear;
30use ferrum_quantization::gguf::GgufFile;
31use ferrum_quantization::{DenseLinear, QuantLinear};
32use ferrum_types::{FerrumError, Result};
33
34use crate::moe::router::RouterOutput;
35
// Profiling counters for the MoE forward path.
//
// `moe_forward` only touches these when `moe_profile_enabled()` returns true.
// Each `*_US` counter accumulates elapsed wall-clock microseconds and its
// `*_CALLS` sibling accumulates how many operations were recorded.

/// Microseconds spent in the `B::sync` call at the top of `moe_forward`.
pub static MOE_SYNC_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed entry syncs.
pub static MOE_SYNC_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent in the per-expert fused gate/up `forward` call.
pub static MOE_GEMV_GATE_UP_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed gate/up `forward` calls.
pub static MOE_GEMV_GATE_UP_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent in `fused_silu_mul_split`.
pub static MOE_SILU_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `fused_silu_mul_split` calls.
pub static MOE_SILU_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent in the per-expert down-projection `forward` call.
pub static MOE_GEMV_DOWN_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed down-projection `forward` calls.
pub static MOE_GEMV_DOWN_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent in `scaled_add_inplace` (weighted accumulation).
pub static MOE_SCALED_ADD_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `scaled_add_inplace` calls.
pub static MOE_SCALED_ADD_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent issuing `copy_slice` calls (row staging and write-back).
pub static MOE_COPY_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed `copy_slice` calls (three per batch row).
pub static MOE_COPY_CALLS: AtomicU64 = AtomicU64::new(0);

/// Microseconds spent reading router logits back to the host and routing them.
pub static MOE_HOST_TOPK_US: AtomicU64 = AtomicU64::new(0);
/// Number of timed host-side routing passes.
pub static MOE_HOST_TOPK_CALLS: AtomicU64 = AtomicU64::new(0);
53
/// Reports whether MoE profiling is switched on.
///
/// Profiling is on whenever the `FERRUM_MOE_PROFILE` environment variable is
/// set to any valid-Unicode value (including the empty string). The
/// environment is consulted on every call.
fn moe_profile_enabled() -> bool {
    matches!(std::env::var("FERRUM_MOE_PROFILE"), Ok(_))
}
57
58pub struct ExpertStack<B: Backend> {
67 pub gate_up: Vec<Box<dyn Linear<B>>>,
70 pub down: Vec<Box<dyn Linear<B>>>,
72 pub gate_stacked: Option<B::QuantStore>,
84 pub up_stacked: Option<B::QuantStore>,
85 pub down_stacked: Option<B::QuantStore>,
86}
87
88impl<B: Backend> ExpertStack<B> {
89 pub fn from_dense_stacks(
96 gate_stack: &[f32],
97 up_stack: &[f32],
98 down_stack: &[f32],
99 num_experts: usize,
100 hidden_size: usize,
101 expert_intermediate: usize,
102 ) -> Result<Self> {
103 let gate_up_per_expert = expert_intermediate * hidden_size;
104 let down_per_expert = hidden_size * expert_intermediate;
105
106 check_size(
107 gate_stack.len(),
108 num_experts * gate_up_per_expert,
109 "gate_stack",
110 )?;
111 check_size(up_stack.len(), num_experts * gate_up_per_expert, "up_stack")?;
112 check_size(
113 down_stack.len(),
114 num_experts * down_per_expert,
115 "down_stack",
116 )?;
117
118 let mut gate_up = Vec::with_capacity(num_experts);
119 let mut down = Vec::with_capacity(num_experts);
120 for e in 0..num_experts {
121 let g_off = e * gate_up_per_expert;
122 let g_slice = &gate_stack[g_off..g_off + gate_up_per_expert];
123 let u_slice = &up_stack[g_off..g_off + gate_up_per_expert];
124
125 let mut fused = Vec::with_capacity(2 * gate_up_per_expert);
130 fused.extend_from_slice(g_slice);
131 fused.extend_from_slice(u_slice);
132 gate_up.push(Box::new(DenseLinear::<B>::from_rows(
133 &fused,
134 2 * expert_intermediate,
135 hidden_size,
136 )) as Box<dyn Linear<B>>);
137
138 let d_off = e * down_per_expert;
139 let d_slice = &down_stack[d_off..d_off + down_per_expert];
140 down.push(Box::new(DenseLinear::<B>::from_rows(
141 d_slice,
142 hidden_size,
143 expert_intermediate,
144 )) as Box<dyn Linear<B>>);
145 }
146 Ok(Self {
147 gate_up,
148 down,
149 gate_stacked: None,
150 up_stacked: None,
151 down_stacked: None,
152 })
153 }
154
155 pub fn load_from_gguf(
175 gguf: &GgufFile,
176 layer_idx: usize,
177 num_experts: usize,
178 hidden_size: usize,
179 expert_intermediate: usize,
180 ) -> Result<Self> {
181 if let Some(quant) = Self::try_load_quantised(
182 gguf,
183 layer_idx,
184 num_experts,
185 hidden_size,
186 expert_intermediate,
187 )? {
188 if std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok() {
189 eprintln!("[moe-load] layer {layer_idx} → quantised expert path");
190 }
191 return Ok(quant);
192 }
193
194 if std::env::var("FERRUM_MOE_LOAD_TRACE").is_ok() {
195 eprintln!("[moe-load] layer {layer_idx} → eager fp32 dense fallback ⚠");
196 }
197
198 let device = Device::Cpu;
199 let gate = read_dequant_flat(
200 gguf,
201 &format!("blk.{layer_idx}.ffn_gate_exps.weight"),
202 &device,
203 )?;
204 let up = read_dequant_flat(
205 gguf,
206 &format!("blk.{layer_idx}.ffn_up_exps.weight"),
207 &device,
208 )?;
209 let down = read_dequant_flat(
210 gguf,
211 &format!("blk.{layer_idx}.ffn_down_exps.weight"),
212 &device,
213 )?;
214 Self::from_dense_stacks(
217 &gate,
218 &up,
219 &down,
220 num_experts,
221 hidden_size,
222 expert_intermediate,
223 )
224 }
225
226 fn try_load_quantised(
232 gguf: &GgufFile,
233 layer_idx: usize,
234 num_experts: usize,
235 hidden_size: usize,
236 expert_intermediate: usize,
237 ) -> Result<Option<Self>> {
238 let device = Device::Cpu;
239
240 let gate_name = format!("blk.{layer_idx}.ffn_gate_exps.weight");
241 let up_name = format!("blk.{layer_idx}.ffn_up_exps.weight");
242 let down_name = format!("blk.{layer_idx}.ffn_down_exps.weight");
243
244 let gate_kind = match quant_kind(gguf, &gate_name)? {
248 Some(k) => k,
249 None => return Ok(None),
250 };
251 let up_kind = match quant_kind(gguf, &up_name)? {
252 Some(k) => k,
253 None => return Ok(None),
254 };
255 let down_kind = match quant_kind(gguf, &down_name)? {
256 Some(k) => k,
257 None => return Ok(None),
258 };
259
260 let gate_bytes = gguf.tensor_byte_slice(&gate_name).ok_or_else(|| {
270 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{gate_name}'"))
271 })?;
272 let up_bytes = gguf.tensor_byte_slice(&up_name).ok_or_else(|| {
273 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{up_name}'"))
274 })?;
275 let down_bytes = gguf.tensor_byte_slice(&down_name).ok_or_else(|| {
276 FerrumError::model(format!("MoE: tensor_byte_slice failed for '{down_name}'"))
277 })?;
278 let _ = device; let gate_per = block_bytes_for(
284 gate_kind,
285 expert_intermediate * hidden_size,
286 "ffn_gate_exps",
287 )?;
288 let up_per = block_bytes_for(up_kind, expert_intermediate * hidden_size, "ffn_up_exps")?;
289 let down_per = block_bytes_for(
290 down_kind,
291 hidden_size * expert_intermediate,
292 "ffn_down_exps",
293 )?;
294
295 check_size(
296 gate_bytes.len(),
297 num_experts * gate_per,
298 "ffn_gate_exps bytes",
299 )?;
300 check_size(up_bytes.len(), num_experts * up_per, "ffn_up_exps bytes")?;
301 check_size(
302 down_bytes.len(),
303 num_experts * down_per,
304 "ffn_down_exps bytes",
305 )?;
306
307 let gate_stacked = B::load_quant_experts(
315 gate_kind,
316 gate_bytes,
317 num_experts,
318 expert_intermediate,
319 hidden_size,
320 )
321 .ok();
322 let up_stacked = B::load_quant_experts(
323 up_kind,
324 up_bytes,
325 num_experts,
326 expert_intermediate,
327 hidden_size,
328 )
329 .ok();
330 let down_stacked = B::load_quant_experts(
331 down_kind,
332 down_bytes,
333 num_experts,
334 hidden_size,
335 expert_intermediate,
336 )
337 .ok();
338
339 let stacked_complete =
347 gate_stacked.is_some() && up_stacked.is_some() && down_stacked.is_some();
348
349 let (gate_up, down) = if stacked_complete {
350 (Vec::new(), Vec::new())
353 } else {
354 let mut gate_up: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
355 let mut down: Vec<Box<dyn Linear<B>>> = Vec::with_capacity(num_experts);
356 for e in 0..num_experts {
357 let g_slice = &gate_bytes[e * gate_per..(e + 1) * gate_per];
358 let u_slice = &up_bytes[e * up_per..(e + 1) * up_per];
359 let d_slice = &down_bytes[e * down_per..(e + 1) * down_per];
360
361 let parts: [(GgufQuantType, &[u8], usize); 2] = [
362 (gate_kind, g_slice, expert_intermediate),
363 (up_kind, u_slice, expert_intermediate),
364 ];
365 let gate_up_e = match QuantLinear::<B>::from_gguf_fused(&parts, hidden_size) {
366 Ok(q) => q,
367 Err(_) => return Ok(None),
368 };
369 gate_up.push(Box::new(gate_up_e) as Box<dyn Linear<B>>);
370
371 let down_e = match QuantLinear::<B>::from_gguf_bytes(
372 down_kind,
373 d_slice,
374 hidden_size,
375 expert_intermediate,
376 ) {
377 Ok(q) => q,
378 Err(_) => return Ok(None),
379 };
380 down.push(Box::new(down_e) as Box<dyn Linear<B>>);
381 }
382 (gate_up, down)
383 };
384
385 Ok(Some(Self {
386 gate_up,
387 down,
388 gate_stacked,
389 up_stacked,
390 down_stacked,
391 }))
392 }
393
394 pub fn open_and_load(
398 path: impl AsRef<Path>,
399 layer_idx: usize,
400 num_experts: usize,
401 hidden_size: usize,
402 expert_intermediate: usize,
403 ) -> Result<Self> {
404 let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
405 Self::load_from_gguf(
406 &gguf,
407 layer_idx,
408 num_experts,
409 hidden_size,
410 expert_intermediate,
411 )
412 }
413
414 pub fn num_experts(&self) -> usize {
422 debug_assert_eq!(
423 self.gate_up.len(),
424 self.down.len(),
425 "ExpertStack: gate_up and down disagree on expert count"
426 );
427 self.gate_up.len()
428 }
429}
430
431#[allow(clippy::too_many_arguments)]
469pub fn moe_forward<B: Backend>(
470 ctx: &mut B::Context,
471 x: &B::Buffer,
472 router_logits: &B::Buffer,
473 out: &mut B::Buffer,
474 batch: usize,
475 hidden_size: usize,
476 expert_intermediate: usize,
477 num_experts: usize,
478 top_k: usize,
479 norm_topk_prob: bool,
480 experts: &ExpertStack<B>,
481 x_single: &mut B::Buffer,
482 acc_buf: &mut B::Buffer,
483 gate_up_buf: &mut B::Buffer,
484 silu_buf: &mut B::Buffer,
485 down_buf: &mut B::Buffer,
486 zero_hidden: &B::Buffer,
487) -> Result<()> {
488 let n_experts = experts.num_experts();
489 if n_experts != num_experts {
490 return Err(FerrumError::model(format!(
491 "moe_forward: experts.num_experts() = {n_experts} != cfg.num_experts = {num_experts}"
492 )));
493 }
494
495 let prof = moe_profile_enabled();
496
497 let t0 = if prof {
501 Some(std::time::Instant::now())
502 } else {
503 None
504 };
505 B::sync(ctx);
506 if let Some(t) = t0 {
507 MOE_SYNC_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
508 MOE_SYNC_CALLS.fetch_add(1, Ordering::Relaxed);
509 }
510
511 let t0 = if prof {
512 Some(std::time::Instant::now())
513 } else {
514 None
515 };
516 let logits_host = B::to_vec(router_logits, batch * num_experts);
517 let route_out =
518 crate::moe::router::route(&logits_host, batch, num_experts, top_k, norm_topk_prob);
519 if let Some(t) = t0 {
520 MOE_HOST_TOPK_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
521 MOE_HOST_TOPK_CALLS.fetch_add(1, Ordering::Relaxed);
522 }
523
524 for b in 0..batch {
525 let t0 = if prof {
527 Some(std::time::Instant::now())
528 } else {
529 None
530 };
531 B::copy_slice(ctx, x, b * hidden_size, x_single, 0, hidden_size);
532 B::copy_slice(ctx, zero_hidden, 0, acc_buf, 0, hidden_size);
533 if let Some(t) = t0 {
534 MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
535 MOE_COPY_CALLS.fetch_add(2, Ordering::Relaxed);
536 }
537
538 for k in 0..top_k {
539 let pair = b * top_k + k;
540 let expert_id = route_out.expert_ids[pair] as usize;
541 let weight = route_out.expert_weights[pair];
542 if expert_id >= num_experts {
543 return Err(FerrumError::model(format!(
544 "moe_forward: routed expert {expert_id} >= num_experts {num_experts}"
545 )));
546 }
547
548 let t0 = if prof {
550 B::sync(ctx);
551 Some(std::time::Instant::now())
552 } else {
553 None
554 };
555 experts.gate_up[expert_id].forward(ctx, x_single, gate_up_buf, 1);
556 if let Some(t) = t0 {
557 B::sync(ctx);
558 MOE_GEMV_GATE_UP_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
559 MOE_GEMV_GATE_UP_CALLS.fetch_add(1, Ordering::Relaxed);
560 }
561
562 let t0 = if prof {
564 Some(std::time::Instant::now())
565 } else {
566 None
567 };
568 B::fused_silu_mul_split(ctx, gate_up_buf, silu_buf, 1, expert_intermediate);
569 if let Some(t) = t0 {
570 B::sync(ctx);
571 MOE_SILU_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
572 MOE_SILU_CALLS.fetch_add(1, Ordering::Relaxed);
573 }
574
575 let t0 = if prof {
577 Some(std::time::Instant::now())
578 } else {
579 None
580 };
581 experts.down[expert_id].forward(ctx, silu_buf, down_buf, 1);
582 if let Some(t) = t0 {
583 B::sync(ctx);
584 MOE_GEMV_DOWN_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
585 MOE_GEMV_DOWN_CALLS.fetch_add(1, Ordering::Relaxed);
586 }
587
588 let t0 = if prof {
590 Some(std::time::Instant::now())
591 } else {
592 None
593 };
594 B::scaled_add_inplace(ctx, acc_buf, down_buf, weight, hidden_size);
595 if let Some(t) = t0 {
596 B::sync(ctx);
597 MOE_SCALED_ADD_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
598 MOE_SCALED_ADD_CALLS.fetch_add(1, Ordering::Relaxed);
599 }
600 }
601
602 let t0 = if prof {
604 Some(std::time::Instant::now())
605 } else {
606 None
607 };
608 B::copy_slice(ctx, acc_buf, 0, out, b * hidden_size, hidden_size);
609 if let Some(t) = t0 {
610 MOE_COPY_US.fetch_add(t.elapsed().as_micros() as u64, Ordering::Relaxed);
611 MOE_COPY_CALLS.fetch_add(1, Ordering::Relaxed);
612 }
613 }
614
615 Ok(())
616}
617
618pub fn moe_forward_cpu(
633 x: &[f32],
634 batch: usize,
635 hidden_size: usize,
636 expert_intermediate: usize,
637 top_k: usize,
638 router: &RouterOutput,
639 experts: &ExpertStack<CpuBackend>,
640 out: &mut Vec<f32>,
641) -> Result<()> {
642 let n_experts = experts.num_experts();
643
644 if x.len() != batch * hidden_size {
645 return Err(FerrumError::model(format!(
646 "moe_forward_cpu: x len {} doesn't match batch*hidden = {}*{} = {}",
647 x.len(),
648 batch,
649 hidden_size,
650 batch * hidden_size
651 )));
652 }
653 if router.expert_ids.len() != batch * top_k {
654 return Err(FerrumError::model(format!(
655 "moe_forward_cpu: router has {} expert_ids but expected batch*top_k = {}*{} = {}",
656 router.expert_ids.len(),
657 batch,
658 top_k,
659 batch * top_k
660 )));
661 }
662
663 out.clear();
664 out.resize(batch * hidden_size, 0.0);
665
666 let mut ctx = <CpuBackend as Backend>::new_context();
667 let mut x_b: Vec<f32> = vec![0.0; hidden_size];
668 let mut gate_up_buf: Vec<f32> = vec![0.0; 2 * expert_intermediate];
669 let mut silu_mul_buf: Vec<f32> = vec![0.0; expert_intermediate];
670 let mut down_out: Vec<f32> = vec![0.0; hidden_size];
671
672 for b in 0..batch {
673 x_b.copy_from_slice(&x[b * hidden_size..(b + 1) * hidden_size]);
674
675 for k in 0..top_k {
676 let pair_idx = b * top_k + k;
677 let expert_id = router.expert_ids[pair_idx] as usize;
678 let weight = router.expert_weights[pair_idx];
679
680 if expert_id >= n_experts {
681 return Err(FerrumError::model(format!(
682 "moe_forward_cpu: router selected expert {expert_id} >= num_experts {n_experts}"
683 )));
684 }
685
686 experts.gate_up[expert_id].forward(&mut ctx, &x_b, &mut gate_up_buf, 1);
688
689 <CpuBackend as Backend>::fused_silu_mul_split(
691 &mut ctx,
692 &gate_up_buf,
693 &mut silu_mul_buf,
694 1,
695 expert_intermediate,
696 );
697
698 experts.down[expert_id].forward(&mut ctx, &silu_mul_buf, &mut down_out, 1);
700
701 let out_row = &mut out[b * hidden_size..(b + 1) * hidden_size];
705 for (o, d) in out_row.iter_mut().zip(down_out.iter()) {
706 *o += weight * *d;
707 }
708 }
709 }
710
711 Ok(())
712}
713
714fn check_size(actual: usize, expected: usize, label: &str) -> Result<()> {
715 if actual != expected {
716 return Err(FerrumError::model(format!(
717 "ExpertStack: {label} size mismatch (got {actual}, expected {expected})"
718 )));
719 }
720 Ok(())
721}
722
723fn quant_kind(gguf: &GgufFile, name: &str) -> Result<Option<GgufQuantType>> {
727 let info = gguf.tensor_info(name).ok_or_else(|| {
728 FerrumError::model(format!("ExpertStack: tensor info missing for '{name}'"))
729 })?;
730 Ok(match info.ggml_dtype {
731 GgmlDType::Q4K => Some(GgufQuantType::Q4K),
732 GgmlDType::Q6K => Some(GgufQuantType::Q6K),
733 _ => None,
734 })
735}
736
737fn block_bytes_for(kind: GgufQuantType, n_elems: usize, label: &str) -> Result<usize> {
742 const QK_K: usize = 256;
743 if n_elems % QK_K != 0 {
744 return Err(FerrumError::model(format!(
745 "ExpertStack {label}: per-expert element count {n_elems} not a multiple of {QK_K}"
746 )));
747 }
748 let block_bytes = match kind {
749 GgufQuantType::Q4K => 144,
750 GgufQuantType::Q6K => 210,
751 other => {
754 return Err(FerrumError::model(format!(
755 "ExpertStack {label}: unsupported k-quant flavour {other:?}"
756 )))
757 }
758 };
759 Ok((n_elems / QK_K) * block_bytes)
760}
761
762fn read_dequant_flat(gguf: &GgufFile, name: &str, device: &Device) -> Result<Vec<f32>> {
763 let qt = gguf.read_tensor(name, device).map_err(candle_to_ferrum)?;
764 let dense = qt.dequantize(device).map_err(candle_to_ferrum)?;
765 let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
766 flat.to_vec1::<f32>().map_err(candle_to_ferrum)
767}
768
769fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
770 FerrumError::model(format!("candle: {e}"))
771}
772
773#[allow(dead_code)]
776type _CandleResult<T> = CandleResult<T>;