rlx_llada2/tide/
moe_state.rs1use rlx_runtime::{
7 ExpertPool, ExpertRefreshPolicy, MoeExpertStore, MoeResidencyStats, merged_resident_mask,
8 per_layer_resident_masks,
9};
10
11use super::{PredictiveOffloadInfo, TideOffloadStats, aggregate_offload_stats, refresh_experts};
12
13#[derive(Debug)]
15pub struct MoeOffloadState {
16 pub pools: Vec<ExpertPool>,
17 pub refresh: ExpertRefreshPolicy,
18 pub info: PredictiveOffloadInfo,
19 pub predictive_enabled: bool,
20 pub jump_steps: usize,
21 pub collect_stats: bool,
22}
23
24impl MoeOffloadState {
25 pub fn num_layers(&self) -> usize {
26 self.pools.len()
27 }
28
29 pub fn merged_resident_mask(&self) -> Vec<bool> {
30 merged_resident_mask(&self.pools)
31 }
32
33 pub fn per_layer_resident_masks(&self) -> Vec<Vec<bool>> {
34 per_layer_resident_masks(&self.pools)
35 }
36
37 pub fn should_refresh_forward(&self, denoise_step: usize, is_prefill_block: bool) -> bool {
39 if !self.predictive_enabled {
40 return false;
41 }
42 if is_prefill_block {
43 return true;
44 }
45 self.pools
46 .first()
47 .is_some_and(|p| p.should_refresh(rlx_runtime::MoEExecMode::Reuse, denoise_step, false))
48 }
49
50 pub fn should_refresh_block(
52 &self,
53 num_block: usize,
54 prefill_blocks: usize,
55 denoise_step: usize,
56 ) -> bool {
57 if !self.predictive_enabled {
58 return false;
59 }
60 refresh_experts(
61 true,
62 self.jump_steps,
63 num_block,
64 prefill_blocks,
65 denoise_step,
66 )
67 }
68
69 pub fn refresh_from_capture(
71 &mut self,
72 layer_indices: &[Vec<u32>],
73 denoise_step: usize,
74 is_prefill_block: bool,
75 ) -> bool {
76 let n = self.pools.len().min(layer_indices.len());
77 if n == 0 {
78 return false;
79 }
80 if !self.should_refresh_forward(denoise_step, is_prefill_block) {
81 return false;
82 }
83 for (pool, idx) in self.pools.iter_mut().zip(&layer_indices[..n]) {
84 pool.refresh_from_indices(idx);
85 }
86 true
87 }
88
89 pub fn refresh_from_capture_with_store(
90 &mut self,
91 store: &MoeExpertStore,
92 captured: &[Vec<u32>],
93 denoise_step: usize,
94 is_prefill_block: bool,
95 ) -> bool {
96 if !self.should_refresh_forward(denoise_step, is_prefill_block) {
97 return false;
98 }
99 store.refresh_pools(&mut self.pools, captured, denoise_step, is_prefill_block)
100 }
101
102 pub fn tide_offload_stats(&self, residency: Option<&MoeResidencyStats>) -> TideOffloadStats {
103 aggregate_offload_stats(&self.pools, residency)
104 }
105}