Skip to main content

rlx_llada2/tide/
moe_state.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// Shared TIDE MoE offload state (Qwen3.5 + LLaDA2).

use rlx_runtime::{
    ExpertPool, ExpertRefreshPolicy, MoEExecMode, MoeExpertStore, MoeResidencyStats,
    merged_resident_mask, per_layer_resident_masks,
};

use super::{PredictiveOffloadInfo, TideOffloadStats, aggregate_offload_stats, refresh_experts};

13/// Per-layer TIDE expert pools (one per MoE FFN in trunk order).
14#[derive(Debug)]
15pub struct MoeOffloadState {
16    pub pools: Vec<ExpertPool>,
17    pub refresh: ExpertRefreshPolicy,
18    pub info: PredictiveOffloadInfo,
19    pub predictive_enabled: bool,
20    pub jump_steps: usize,
21    pub collect_stats: bool,
22}
23
24impl MoeOffloadState {
25    pub fn num_layers(&self) -> usize {
26        self.pools.len()
27    }
28
29    pub fn merged_resident_mask(&self) -> Vec<bool> {
30        merged_resident_mask(&self.pools)
31    }
32
33    pub fn per_layer_resident_masks(&self) -> Vec<Vec<bool>> {
34        per_layer_resident_masks(&self.pools)
35    }
36
37    /// AR prefill or block-diffusion prefill block → always refresh; else `step % jump_steps == 0`.
38    pub fn should_refresh_forward(&self, denoise_step: usize, is_prefill_block: bool) -> bool {
39        if !self.predictive_enabled {
40            return false;
41        }
42        if is_prefill_block {
43            return true;
44        }
45        self.pools
46            .first()
47            .is_some_and(|p| p.should_refresh(rlx_runtime::MoEExecMode::Reuse, denoise_step, false))
48    }
49
50    /// Block diffusion: TIDE `generate` refresh line.
51    pub fn should_refresh_block(
52        &self,
53        num_block: usize,
54        prefill_blocks: usize,
55        denoise_step: usize,
56    ) -> bool {
57        if !self.predictive_enabled {
58            return false;
59        }
60        refresh_experts(
61            true,
62            self.jump_steps,
63            num_block,
64            prefill_blocks,
65            denoise_step,
66        )
67    }
68
69    /// Apply captured TopK indices per layer; returns true if any layer refreshed.
70    pub fn refresh_from_capture(
71        &mut self,
72        layer_indices: &[Vec<u32>],
73        denoise_step: usize,
74        is_prefill_block: bool,
75    ) -> bool {
76        let n = self.pools.len().min(layer_indices.len());
77        if n == 0 {
78            return false;
79        }
80        if !self.should_refresh_forward(denoise_step, is_prefill_block) {
81            return false;
82        }
83        for (pool, idx) in self.pools.iter_mut().zip(&layer_indices[..n]) {
84            pool.refresh_from_indices(idx);
85        }
86        true
87    }
88
89    pub fn refresh_from_capture_with_store(
90        &mut self,
91        store: &MoeExpertStore,
92        captured: &[Vec<u32>],
93        denoise_step: usize,
94        is_prefill_block: bool,
95    ) -> bool {
96        if !self.should_refresh_forward(denoise_step, is_prefill_block) {
97            return false;
98        }
99        store.refresh_pools(&mut self.pools, captured, denoise_step, is_prefill_block)
100    }
101
102    pub fn tide_offload_stats(&self, residency: Option<&MoeResidencyStats>) -> TideOffloadStats {
103        aggregate_offload_stats(&self.pools, residency)
104    }
105}