Skip to main content

rlx_runtime/
moe_expert_store.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-expert F32 weight slabs for MoE offload (TIDE-style migration source).
17
18use crate::ExpertPool;
19use std::sync::Arc;
20
/// One expert projection stack in GroupedMatMul layout `[num_experts, k, n]`.
///
/// All experts live in one contiguous buffer, expert after expert: expert `e`
/// occupies `data[e * k * n .. (e + 1) * k * n]`. The buffer is held behind an
/// [`Arc`], so cloning a stack bumps a reference count instead of copying weights.
#[derive(Debug, Clone)]
pub struct ExpertStackF32 {
    /// Number of experts stacked along the leading axis.
    pub num_experts: usize,
    /// Extent of the second axis of each expert's `[k, n]` block.
    pub k: usize,
    /// Extent of the third axis of each expert's `[k, n]` block.
    pub n: usize,
    /// Flattened weights; [`ExpertStackF32::new`] guarantees `num_experts * k * n` elements.
    pub data: Arc<[f32]>,
}

impl ExpertStackF32 {
    /// Wraps `data` as a `[num_experts, k, n]` expert stack.
    ///
    /// # Panics
    ///
    /// Panics if `num_experts * k * n` overflows `usize`, or if it does not equal
    /// `data.len()`.
    pub fn new(data: Vec<f32>, num_experts: usize, k: usize, n: usize) -> Self {
        // Checked multiply: in release builds a wrapping product could spuriously
        // equal `data.len()` and let a malformed stack through.
        let expected = num_experts
            .checked_mul(k)
            .and_then(|len| len.checked_mul(n))
            .unwrap_or_else(|| {
                panic!("expert stack shape [{num_experts}, {k}, {n}] overflows usize")
            });
        assert_eq!(
            data.len(),
            expected,
            "expert stack data length does not match shape [{num_experts}, {k}, {n}]"
        );
        Self {
            num_experts,
            k,
            n,
            data: Arc::from(data),
        }
    }

    /// Number of `f32` elements per expert (`k * n`).
    pub fn expert_stride(&self) -> usize {
        self.k * self.n
    }

    /// Returns the `k * n` weights of expert `e` as a flat slice.
    ///
    /// # Panics
    ///
    /// Panics if `e >= num_experts`. It also panics if `data` is shorter than the shape
    /// implies, which is possible when the public fields are set directly.
    pub fn expert_slice(&self, e: usize) -> &[f32] {
        // Explicit check: with a zero stride the slice below would be in bounds for
        // any `e`, silently yielding an empty slice for a nonexistent expert.
        assert!(
            e < self.num_experts,
            "expert index {e} out of range for stack of {} experts",
            self.num_experts
        );
        let stride = self.expert_stride();
        let start = e * stride;
        &self.data[start..start + stride]
    }

    /// The whole `[num_experts, k, n]` buffer as one flat slice.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }
}
55
/// Gate / up / down expert stacks for one decoder layer.
///
/// `layer_index` is the decoder block number; `MoeExpertStore::apply_to_compiled`
/// uses it to build the `blk.{layer_index}.ffn_{gate,up,down}_exps.weight` param names.
#[derive(Debug, Clone)]
pub struct LayerMoeWeights {
    /// Decoder block index these expert stacks belong to.
    pub layer_index: usize,
    /// Gate-projection expert stack (`ffn_gate_exps`).
    pub gate: ExpertStackF32,
    /// Up-projection expert stack (`ffn_up_exps`).
    pub up: ExpertStackF32,
    /// Down-projection expert stack (`ffn_down_exps`).
    pub down: ExpertStackF32,
}
64
65/// Host-side expert weights for all MoE layers (migration source of truth).
66#[derive(Debug, Clone)]
67pub struct MoeExpertStore {
68    pub layers: Vec<LayerMoeWeights>,
69}
70
71impl MoeExpertStore {
72    pub fn num_layers(&self) -> usize {
73        self.layers.len()
74    }
75
76    /// Apply captured TopK indices to per-layer pools (TIDE refresh).
77    pub fn refresh_pools(
78        &self,
79        pools: &mut [ExpertPool],
80        captured: &[Vec<u32>],
81        decode_step: usize,
82        is_prefill_block: bool,
83    ) -> bool {
84        let n = self.layers.len().min(pools.len()).min(captured.len());
85        if n == 0 {
86            return false;
87        }
88        let refresh =
89            pools[0].should_refresh(crate::MoEExecMode::Reuse, decode_step, is_prefill_block);
90        if !refresh {
91            return false;
92        }
93        for i in 0..n {
94            pools[i].refresh_from_indices(&captured[i]);
95        }
96        true
97    }
98
99    /// Push full host stacks into compiled params (lossless; refreshes arena bytes).
100    pub fn apply_to_compiled(&self, compiled: &mut crate::CompiledGraph) {
101        for layer in &self.layers {
102            let il = layer.layer_index;
103            compiled.set_param(
104                &format!("blk.{il}.ffn_gate_exps.weight"),
105                layer.gate.as_slice(),
106            );
107            compiled.set_param(&format!("blk.{il}.ffn_up_exps.weight"), layer.up.as_slice());
108            compiled.set_param(
109                &format!("blk.{il}.ffn_down_exps.weight"),
110                layer.down.as_slice(),
111            );
112        }
113    }
114}