rlx_runtime/
moe_expert_store.rs1use crate::ExpertPool;
19use std::sync::Arc;
20
/// Dense `f32` weights for a stack of experts, stored in one flat buffer.
///
/// The buffer holds `num_experts * k * n` values; expert `e` occupies the
/// `k * n` consecutive elements starting at offset `e * k * n`. The buffer is
/// reference-counted, so cloning a stack shares the data instead of copying it.
#[derive(Debug, Clone)]
pub struct ExpertStackF32 {
    /// Number of experts stored in `data`.
    pub num_experts: usize,
    /// First per-expert dimension.
    /// NOTE(review): presumably the rows (or input width) of each expert matrix — confirm.
    pub k: usize,
    /// Second per-expert dimension.
    /// NOTE(review): presumably the columns (or output width) of each expert matrix — confirm.
    pub n: usize,
    /// Flat storage; [`ExpertStackF32::new`] guarantees its length is `num_experts * k * n`.
    pub data: Arc<[f32]>,
}

impl ExpertStackF32 {
    /// Wraps `data` as a stack of `num_experts` experts holding `k * n` values each.
    ///
    /// # Panics
    ///
    /// Panics if `num_experts * k * n` overflows `usize`, or if that product
    /// differs from `data.len()`.
    pub fn new(data: Vec<f32>, num_experts: usize, k: usize, n: usize) -> Self {
        // Plain `*` wraps silently in release builds, which would let a bogus
        // shape slip past the length check; compute the product with overflow
        // detection in every build profile instead.
        let expected_len = num_experts
            .checked_mul(k)
            .and_then(|partial| partial.checked_mul(n))
            .expect("ExpertStackF32::new: num_experts * k * n overflows usize");
        assert_eq!(
            data.len(),
            expected_len,
            "ExpertStackF32::new: data length does not match num_experts * k * n"
        );
        Self {
            num_experts,
            k,
            n,
            data: Arc::from(data),
        }
    }

    /// Number of `f32` values occupied by a single expert (`k * n`).
    pub fn expert_stride(&self) -> usize {
        self.k * self.n
    }

    /// Returns the `k * n` values belonging to expert `e`.
    ///
    /// # Panics
    ///
    /// Panics if the range computed for expert `e` lies outside the buffer.
    pub fn expert_slice(&self, e: usize) -> &[f32] {
        let stride = self.expert_stride();
        let start = e * stride;
        &self.data[start..start + stride]
    }

    /// Returns the whole flat buffer covering every expert.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }
}
55
56#[derive(Debug, Clone)]
58pub struct LayerMoeWeights {
59 pub layer_index: usize,
60 pub gate: ExpertStackF32,
61 pub up: ExpertStackF32,
62 pub down: ExpertStackF32,
63}
64
65#[derive(Debug, Clone)]
67pub struct MoeExpertStore {
68 pub layers: Vec<LayerMoeWeights>,
69}
70
71impl MoeExpertStore {
72 pub fn num_layers(&self) -> usize {
73 self.layers.len()
74 }
75
76 pub fn refresh_pools(
78 &self,
79 pools: &mut [ExpertPool],
80 captured: &[Vec<u32>],
81 decode_step: usize,
82 is_prefill_block: bool,
83 ) -> bool {
84 let n = self.layers.len().min(pools.len()).min(captured.len());
85 if n == 0 {
86 return false;
87 }
88 let refresh =
89 pools[0].should_refresh(crate::MoEExecMode::Reuse, decode_step, is_prefill_block);
90 if !refresh {
91 return false;
92 }
93 for i in 0..n {
94 pools[i].refresh_from_indices(&captured[i]);
95 }
96 true
97 }
98
99 pub fn apply_to_compiled(&self, compiled: &mut crate::CompiledGraph) {
101 for layer in &self.layers {
102 let il = layer.layer_index;
103 compiled.set_param(
104 &format!("blk.{il}.ffn_gate_exps.weight"),
105 layer.gate.as_slice(),
106 );
107 compiled.set_param(&format!("blk.{il}.ffn_up_exps.weight"), layer.up.as_slice());
108 compiled.set_param(
109 &format!("blk.{il}.ffn_down_exps.weight"),
110 layer.down.as_slice(),
111 );
112 }
113 }
114}