Skip to main content

rlx_runtime/
expert_pool.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! MoE expert residency pool (TIDE-style predictive offload).
17//!
18//! Mirrors the policy in [ims-kdks/TIDE](https://github.com/ims-kdks/TIDE)
19//! `LLaDA2MoeSparseMoeBlock`: rank experts by token hits, refresh placement
20//! every τ steps, paired promote/demote to limit PCIe churn.
21//!
22//! Router logits and expert indices are unchanged — placement only.
23
24use std::collections::HashSet;
25
/// Cadence at which expert hit counts are recomputed and placement is refreshed.
///
/// The interval-based variants treat an interval of `0` as `1` when the schedule
/// is evaluated by [`ExpertPool::should_refresh`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpertRefreshPolicy {
    /// Refresh before every forward pass (τ = 1, as in Mixtral-Offload).
    EveryForward,
    /// Autoregressive decoding: refresh once every `N` generated tokens / steps.
    EveryDecodeSteps(usize),
    /// Block-wise diffusion decoding: refresh once every `N` denoise steps inside
    /// a block (`jump_steps` in the TIDE reference implementation).
    EveryDenoiseSteps(usize),
}
37
38impl Default for ExpertRefreshPolicy {
39    fn default() -> Self {
40        Self::EveryDenoiseSteps(1)
41    }
42}
43
/// Per-forward instruction from the runner (TIDE's `refresh_experts` flag).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MoEExecMode {
    /// Keep the current accelerator/host split (TIDE `moe_infer`). A scheduled
    /// refresh may still fire; see [`ExpertPool::should_refresh`].
    Reuse,
    /// Rebuild placement from this step's routing
    /// (TIDE `moe_infer_with_expert_refresh`).
    Refresh,
}
52
53/// Configuration for [`ExpertPool`].
54#[derive(Debug, Clone)]
55pub struct ExpertPoolConfig {
56    pub num_experts: usize,
57    /// Max experts resident on the accelerator per MoE layer.
58    pub gpu_budget: usize,
59    pub refresh: ExpertRefreshPolicy,
60}
61
62impl ExpertPoolConfig {
63    pub fn new(num_experts: usize, gpu_budget: usize, refresh: ExpertRefreshPolicy) -> Self {
64        Self {
65            num_experts,
66            gpu_budget: gpu_budget.min(num_experts),
67            refresh,
68        }
69    }
70
71    /// All experts pinned on device (offload disabled).
72    pub fn all_resident(num_experts: usize) -> Self {
73        Self::new(num_experts, num_experts, ExpertRefreshPolicy::EveryForward)
74    }
75}
76
/// Outcome of a single placement refresh.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExpertRefreshResult {
    /// Experts selected for the accelerator, in the order the placement call
    /// received them (hit-ranked when produced from routing indices).
    pub target_gpu: Vec<usize>,
    /// Number of experts moved onto the accelerator by this refresh.
    pub promotions: usize,
    /// Number of experts moved off the accelerator by this refresh.
    pub demotions: usize,
}
84
/// Running totals accumulated across refreshes (TIDE `offload_stats`).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ExpertPoolStats {
    /// Number of placement refreshes applied.
    pub refreshes: u64,
    /// Total experts promoted onto the accelerator.
    pub promotions: u64,
    /// Total experts demoted off the accelerator.
    pub demotions: u64,
}
92
93/// Tracks which logical experts are GPU-resident and applies TIDE placement updates.
94#[derive(Debug, Clone)]
95pub struct ExpertPool {
96    num_experts: usize,
97    gpu_budget: usize,
98    refresh: ExpertRefreshPolicy,
99    resident: HashSet<usize>,
100    /// Steps since last refresh (decode / denoise counter).
101    steps_since_refresh: usize,
102    stats: ExpertPoolStats,
103}
104
105impl ExpertPool {
106    pub fn new(config: ExpertPoolConfig) -> Self {
107        let gpu_budget = config.gpu_budget.min(config.num_experts);
108        let mut resident = HashSet::new();
109        for e in 0..gpu_budget {
110            resident.insert(e);
111        }
112        Self {
113            num_experts: config.num_experts,
114            gpu_budget,
115            refresh: config.refresh,
116            resident,
117            steps_since_refresh: 0,
118            stats: ExpertPoolStats::default(),
119        }
120    }
121
122    pub fn num_experts(&self) -> usize {
123        self.num_experts
124    }
125
126    pub fn gpu_budget(&self) -> usize {
127        self.gpu_budget
128    }
129
130    pub fn refresh_policy(&self) -> ExpertRefreshPolicy {
131        self.refresh
132    }
133
134    pub fn stats(&self) -> &ExpertPoolStats {
135        &self.stats
136    }
137
138    /// TIDE `LLaDA2MoeSparseMoeBlock.reset_stats()` — clear per-step counters before next forward.
139    pub fn reset_step_stats(&mut self) {
140        self.stats = ExpertPoolStats::default();
141    }
142
143    pub fn resident_gpu_experts(&self) -> impl Iterator<Item = usize> + '_ {
144        self.resident.iter().copied()
145    }
146
147    /// Bitmask for [`crate::CompiledGraph::set_moe_resident_experts`].
148    pub fn resident_mask(&self) -> Vec<bool> {
149        (0..self.num_experts)
150            .map(|e| self.resident.contains(&e))
151            .collect()
152    }
153
154    pub fn is_gpu_resident(&self, expert: usize) -> bool {
155        self.resident.contains(&expert)
156    }
157
158    /// Whether offload is active (budget < total experts).
159    pub fn offload_enabled(&self) -> bool {
160        self.gpu_budget < self.num_experts
161    }
162
163    /// TIDE `generate`: `refresh_experts = prefill_block || (offload && step % τ == 0)`.
164    pub fn should_refresh(
165        &self,
166        mode: MoEExecMode,
167        denoise_step: usize,
168        is_prefill_block: bool,
169    ) -> bool {
170        if !self.offload_enabled() {
171            return false;
172        }
173        match mode {
174            MoEExecMode::Refresh => true,
175            MoEExecMode::Reuse => {
176                if is_prefill_block {
177                    return true;
178                }
179                match self.refresh {
180                    ExpertRefreshPolicy::EveryForward => true,
181                    ExpertRefreshPolicy::EveryDecodeSteps(n)
182                    | ExpertRefreshPolicy::EveryDenoiseSteps(n) => {
183                        let interval = n.max(1);
184                        denoise_step.is_multiple_of(interval)
185                    }
186                }
187            }
188        }
189    }
190
191    /// Advance the step counter; returns whether this forward should refresh.
192    pub fn on_forward_step(
193        &mut self,
194        mode: MoEExecMode,
195        denoise_step: usize,
196        is_prefill_block: bool,
197    ) -> bool {
198        let refresh = self.should_refresh(mode, denoise_step, is_prefill_block);
199        if refresh {
200            self.steps_since_refresh = 0;
201        } else {
202            self.steps_since_refresh = self.steps_since_refresh.saturating_add(1);
203        }
204        refresh
205    }
206
207    /// Count token hits per expert from flat or per-token indices (TIDE `bincount`).
208    pub fn count_hits(expert_idx: &[u32], num_experts: usize) -> Vec<u64> {
209        let mut counts = vec![0u64; num_experts];
210        for &e in expert_idx {
211            let e = e as usize;
212            if e < num_experts {
213                counts[e] += 1;
214            }
215        }
216        counts
217    }
218
219    /// Top-`gpu_budget` experts by hit count (TIDE `torch.topk` on bincount).
220    pub fn target_gpu_from_counts(counts: &[u64], gpu_budget: usize) -> Vec<usize> {
221        let mut ranked: Vec<(u64, usize)> = counts
222            .iter()
223            .enumerate()
224            .filter(|&(_, c)| *c > 0)
225            .map(|(e, &c)| (c, e))
226            .collect();
227        ranked.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
228        ranked
229            .into_iter()
230            .take(gpu_budget)
231            .map(|(_, e)| e)
232            .collect()
233    }
234
235    /// TIDE `update_expert_placement` + hit-based target selection.
236    pub fn refresh_from_indices(&mut self, expert_idx: &[u32]) -> ExpertRefreshResult {
237        let counts = Self::count_hits(expert_idx, self.num_experts);
238        let target_order = Self::target_gpu_from_counts(&counts, self.gpu_budget);
239        self.apply_target_placement(&target_order)
240    }
241
242    /// Apply a precomputed target GPU set (paired promote/demote).
243    pub fn apply_target_placement(&mut self, target_order: &[usize]) -> ExpertRefreshResult {
244        let target_set: HashSet<usize> = target_order.iter().copied().collect();
245
246        let to_promote: Vec<usize> = target_order
247            .iter()
248            .copied()
249            .filter(|e| !self.resident.contains(e))
250            .collect();
251        let can_demote: Vec<usize> = self
252            .resident
253            .iter()
254            .copied()
255            .filter(|e| !target_set.contains(e))
256            .collect();
257        let to_demote: Vec<usize> = can_demote.iter().copied().take(to_promote.len()).collect();
258
259        let mut new_resident = target_set;
260        for e in can_demote.iter().skip(to_promote.len()) {
261            new_resident.insert(*e);
262        }
263
264        let promotions = to_promote.len();
265        let demotions = to_demote.len();
266        self.resident = new_resident;
267        self.stats.refreshes += 1;
268        self.stats.promotions += promotions as u64;
269        self.stats.demotions += demotions as u64;
270
271        ExpertRefreshResult {
272            target_gpu: target_order.to_vec(),
273            promotions,
274            demotions,
275        }
276    }
277}
278
279/// Per-layer resident bitmasks (TIDE placement; one row per MoE FFN in forward order).
280pub fn per_layer_resident_masks(pools: &[ExpertPool]) -> Vec<Vec<bool>> {
281    pools.iter().map(|p| p.resident_mask()).collect()
282}
283
284/// Union of GPU-resident experts across per-layer pools (legacy single graph mask).
285pub fn merged_resident_mask(pools: &[ExpertPool]) -> Vec<bool> {
286    let Some(first) = pools.first() else {
287        return Vec::new();
288    };
289    let n = first.num_experts();
290    (0..n)
291        .map(|e| pools.iter().any(|p| p.is_gpu_resident(e)))
292        .collect()
293}
294
/// Estimates how many experts per MoE layer fit in free accelerator memory.
///
/// `reserve_bytes` is held back from `free_bytes`. The remainder must hold the same
/// number of experts in each of `num_moe_layers` layers, each expert costing
/// `expert_param_bytes`. The result is capped by both `max_gpu_experts_per_layer`
/// and `num_experts`.
///
/// When either `expert_param_bytes` or `num_moe_layers` is zero, the memory
/// estimate is skipped and only the two caps apply.
pub fn gpu_expert_budget_from_vram(
    free_bytes: usize,
    reserve_bytes: usize,
    expert_param_bytes: usize,
    num_moe_layers: usize,
    max_gpu_experts_per_layer: usize,
    num_experts: usize,
) -> usize {
    let cap = max_gpu_experts_per_layer.min(num_experts);
    if expert_param_bytes == 0 || num_moe_layers == 0 {
        return cap;
    }
    // Bytes needed to add one expert to every MoE layer.
    let bytes_per_slot_all_layers = expert_param_bytes.saturating_mul(num_moe_layers);
    let fitting = free_bytes.saturating_sub(reserve_bytes) / bytes_per_slot_all_layers;
    fitting.min(cap)
}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn per_layer_masks_differ_from_merged_union() {
        let mut p0 = ExpertPool::new(ExpertPoolConfig::new(
            4,
            2,
            ExpertRefreshPolicy::EveryForward,
        ));
        let mut p1 = ExpertPool::new(ExpertPoolConfig::new(
            4,
            2,
            ExpertRefreshPolicy::EveryForward,
        ));
        p0.refresh_from_indices(&[0, 1]);
        p1.refresh_from_indices(&[2, 3]);
        let pools = [p0, p1];
        let merged = merged_resident_mask(&pools);
        assert_eq!(merged, vec![true, true, true, true]);
        let per = per_layer_resident_masks(&pools);
        assert_eq!(per[0], vec![true, true, false, false]);
        assert_eq!(per[1], vec![false, false, true, true]);
    }

    #[test]
    fn count_hits_matches_bincount() {
        let idx = [1u32, 0, 1, 2, 1];
        let c = ExpertPool::count_hits(&idx, 4);
        assert_eq!(c, [1, 3, 1, 0]);
    }

    #[test]
    fn target_gpu_picks_top_by_count() {
        let counts = [10, 50, 30, 0, 50];
        let t = ExpertPool::target_gpu_from_counts(&counts, 3);
        assert_eq!(t, vec![1, 4, 2]); // tie-break: lower expert id first
    }

    #[test]
    fn paired_swap_limits_demotions() {
        let mut pool = ExpertPool::new(ExpertPoolConfig::new(
            8,
            2,
            ExpertRefreshPolicy::EveryForward,
        ));
        pool.resident = [0, 1].into_iter().collect();
        let r = pool.apply_target_placement(&[6, 7]);
        assert_eq!(r.promotions, 2);
        assert_eq!(r.demotions, 2);
        assert_eq!(pool.resident, [6, 7].into_iter().collect::<HashSet<_>>());
    }

    #[test]
    fn paired_swap_keeps_extra_residents() {
        let mut pool = ExpertPool::new(ExpertPoolConfig::new(
            8,
            4,
            ExpertRefreshPolicy::EveryForward,
        ));
        pool.resident = [0, 1, 2, 3].into_iter().collect();
        // Half of the target is already resident: two promotions are paired with
        // two demotions, so exactly the target ends up on device.
        let r = pool.apply_target_placement(&[2, 3, 4, 5]);
        assert_eq!(r.promotions, 2);
        assert_eq!(r.demotions, 2);
        assert_eq!(pool.resident.len(), 4);
        for e in [2, 3, 4, 5] {
            assert!(pool.is_gpu_resident(e));
        }
        assert!(!pool.is_gpu_resident(0));
    }

    #[test]
    fn paired_swap_retains_unpaired_residents() {
        let mut pool = ExpertPool::new(ExpertPoolConfig::new(
            8,
            4,
            ExpertRefreshPolicy::EveryForward,
        ));
        pool.resident = [0, 1, 2, 3].into_iter().collect();
        // One promotion (5) pairs with one demotion; the other two non-target
        // residents stay on device (matches TIDE `can_demote[len(to_promote):]`).
        let r = pool.apply_target_placement(&[2, 5]);
        assert_eq!(r.promotions, 1);
        assert_eq!(r.demotions, 1);
        assert_eq!(pool.resident.len(), 4);
        assert!(pool.is_gpu_resident(2));
        assert!(pool.is_gpu_resident(5));
    }

    #[test]
    fn all_resident_never_refreshes() {
        let pool = ExpertPool::new(ExpertPoolConfig::all_resident(8));
        assert!(!pool.offload_enabled());
        assert!(!pool.should_refresh(MoEExecMode::Refresh, 0, true));
        assert_eq!(pool.resident_mask(), vec![true; 8]);
    }

    #[test]
    fn jump_steps_refresh_schedule() {
        let pool = ExpertPool::new(ExpertPoolConfig::new(
            256,
            64,
            ExpertRefreshPolicy::EveryDenoiseSteps(3),
        ));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 0, false));
        assert!(!pool.should_refresh(MoEExecMode::Reuse, 1, false));
        assert!(!pool.should_refresh(MoEExecMode::Reuse, 2, false));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 3, false));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 0, true)); // prefill block
    }

    #[test]
    fn vram_budget_formula() {
        let b = gpu_expert_budget_from_vram(
            40 * 1024 * 1024 * 1024,
            2 * 1024 * 1024 * 1024,
            50 * 1024 * 1024,
            20,
            128,
            256,
        );
        assert!(b > 0 && b <= 128);
    }
}