1use std::collections::HashSet;
25
/// How often an `ExpertPool` running in `MoEExecMode::Reuse` re-evaluates which experts
/// should be resident on the GPU.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExpertRefreshPolicy {
    /// Re-evaluate the placement on every forward pass.
    EveryForward,
    /// Re-evaluate when the step counter is a multiple of the interval.
    /// `ExpertPool::should_refresh` treats an interval of 0 as 1.
    EveryDecodeSteps(usize),
    /// Re-evaluate when the denoise step is a multiple of the interval.
    /// `ExpertPool::should_refresh` treats an interval of 0 as 1.
    EveryDenoiseSteps(usize),
}

impl Default for ExpertRefreshPolicy {
    /// Re-evaluates the placement on every denoise step.
    fn default() -> Self {
        ExpertRefreshPolicy::EveryDenoiseSteps(1)
    }
}
43
/// Execution mode of an MoE forward pass, as consumed by `ExpertPool::should_refresh`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MoEExecMode {
    /// Keep the current placement unless the pass is a prefill block or the configured
    /// `ExpertRefreshPolicy` schedule fires.
    Reuse,
    /// Always re-evaluate the placement (when offloading is enabled at all).
    Refresh,
}
52
53#[derive(Debug, Clone)]
55pub struct ExpertPoolConfig {
56 pub num_experts: usize,
57 pub gpu_budget: usize,
59 pub refresh: ExpertRefreshPolicy,
60}
61
62impl ExpertPoolConfig {
63 pub fn new(num_experts: usize, gpu_budget: usize, refresh: ExpertRefreshPolicy) -> Self {
64 Self {
65 num_experts,
66 gpu_budget: gpu_budget.min(num_experts),
67 refresh,
68 }
69 }
70
71 pub fn all_resident(num_experts: usize) -> Self {
73 Self::new(num_experts, num_experts, ExpertRefreshPolicy::EveryForward)
74 }
75}
76
/// Outcome of a single placement refresh performed by an `ExpertPool`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExpertRefreshResult {
    /// Experts the refresh targeted for GPU residency, highest priority first.
    pub target_gpu: Vec<usize>,
    /// Number of targeted experts that were not resident before the refresh.
    pub promotions: usize,
    /// Number of previously resident experts evicted to make room for promotions.
    pub demotions: usize,
}
84
/// Running counters accumulated by an `ExpertPool` across placement refreshes.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ExpertPoolStats {
    /// Number of refreshes applied.
    pub refreshes: u64,
    /// Total experts promoted to the GPU across those refreshes.
    pub promotions: u64,
    /// Total experts evicted from the GPU across those refreshes.
    pub demotions: u64,
}
92
93#[derive(Debug, Clone)]
95pub struct ExpertPool {
96 num_experts: usize,
97 gpu_budget: usize,
98 refresh: ExpertRefreshPolicy,
99 resident: HashSet<usize>,
100 steps_since_refresh: usize,
102 stats: ExpertPoolStats,
103}
104
105impl ExpertPool {
106 pub fn new(config: ExpertPoolConfig) -> Self {
107 let gpu_budget = config.gpu_budget.min(config.num_experts);
108 let mut resident = HashSet::new();
109 for e in 0..gpu_budget {
110 resident.insert(e);
111 }
112 Self {
113 num_experts: config.num_experts,
114 gpu_budget,
115 refresh: config.refresh,
116 resident,
117 steps_since_refresh: 0,
118 stats: ExpertPoolStats::default(),
119 }
120 }
121
122 pub fn num_experts(&self) -> usize {
123 self.num_experts
124 }
125
126 pub fn gpu_budget(&self) -> usize {
127 self.gpu_budget
128 }
129
130 pub fn refresh_policy(&self) -> ExpertRefreshPolicy {
131 self.refresh
132 }
133
134 pub fn stats(&self) -> &ExpertPoolStats {
135 &self.stats
136 }
137
138 pub fn reset_step_stats(&mut self) {
140 self.stats = ExpertPoolStats::default();
141 }
142
143 pub fn resident_gpu_experts(&self) -> impl Iterator<Item = usize> + '_ {
144 self.resident.iter().copied()
145 }
146
147 pub fn resident_mask(&self) -> Vec<bool> {
149 (0..self.num_experts)
150 .map(|e| self.resident.contains(&e))
151 .collect()
152 }
153
154 pub fn is_gpu_resident(&self, expert: usize) -> bool {
155 self.resident.contains(&expert)
156 }
157
158 pub fn offload_enabled(&self) -> bool {
160 self.gpu_budget < self.num_experts
161 }
162
163 pub fn should_refresh(
165 &self,
166 mode: MoEExecMode,
167 denoise_step: usize,
168 is_prefill_block: bool,
169 ) -> bool {
170 if !self.offload_enabled() {
171 return false;
172 }
173 match mode {
174 MoEExecMode::Refresh => true,
175 MoEExecMode::Reuse => {
176 if is_prefill_block {
177 return true;
178 }
179 match self.refresh {
180 ExpertRefreshPolicy::EveryForward => true,
181 ExpertRefreshPolicy::EveryDecodeSteps(n)
182 | ExpertRefreshPolicy::EveryDenoiseSteps(n) => {
183 let interval = n.max(1);
184 denoise_step.is_multiple_of(interval)
185 }
186 }
187 }
188 }
189 }
190
191 pub fn on_forward_step(
193 &mut self,
194 mode: MoEExecMode,
195 denoise_step: usize,
196 is_prefill_block: bool,
197 ) -> bool {
198 let refresh = self.should_refresh(mode, denoise_step, is_prefill_block);
199 if refresh {
200 self.steps_since_refresh = 0;
201 } else {
202 self.steps_since_refresh = self.steps_since_refresh.saturating_add(1);
203 }
204 refresh
205 }
206
207 pub fn count_hits(expert_idx: &[u32], num_experts: usize) -> Vec<u64> {
209 let mut counts = vec![0u64; num_experts];
210 for &e in expert_idx {
211 let e = e as usize;
212 if e < num_experts {
213 counts[e] += 1;
214 }
215 }
216 counts
217 }
218
219 pub fn target_gpu_from_counts(counts: &[u64], gpu_budget: usize) -> Vec<usize> {
221 let mut ranked: Vec<(u64, usize)> = counts
222 .iter()
223 .enumerate()
224 .filter(|&(_, c)| *c > 0)
225 .map(|(e, &c)| (c, e))
226 .collect();
227 ranked.sort_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
228 ranked
229 .into_iter()
230 .take(gpu_budget)
231 .map(|(_, e)| e)
232 .collect()
233 }
234
235 pub fn refresh_from_indices(&mut self, expert_idx: &[u32]) -> ExpertRefreshResult {
237 let counts = Self::count_hits(expert_idx, self.num_experts);
238 let target_order = Self::target_gpu_from_counts(&counts, self.gpu_budget);
239 self.apply_target_placement(&target_order)
240 }
241
242 pub fn apply_target_placement(&mut self, target_order: &[usize]) -> ExpertRefreshResult {
244 let target_set: HashSet<usize> = target_order.iter().copied().collect();
245
246 let to_promote: Vec<usize> = target_order
247 .iter()
248 .copied()
249 .filter(|e| !self.resident.contains(e))
250 .collect();
251 let can_demote: Vec<usize> = self
252 .resident
253 .iter()
254 .copied()
255 .filter(|e| !target_set.contains(e))
256 .collect();
257 let to_demote: Vec<usize> = can_demote.iter().copied().take(to_promote.len()).collect();
258
259 let mut new_resident = target_set;
260 for e in can_demote.iter().skip(to_promote.len()) {
261 new_resident.insert(*e);
262 }
263
264 let promotions = to_promote.len();
265 let demotions = to_demote.len();
266 self.resident = new_resident;
267 self.stats.refreshes += 1;
268 self.stats.promotions += promotions as u64;
269 self.stats.demotions += demotions as u64;
270
271 ExpertRefreshResult {
272 target_gpu: target_order.to_vec(),
273 promotions,
274 demotions,
275 }
276 }
277}
278
279pub fn per_layer_resident_masks(pools: &[ExpertPool]) -> Vec<Vec<bool>> {
281 pools.iter().map(|p| p.resident_mask()).collect()
282}
283
284pub fn merged_resident_mask(pools: &[ExpertPool]) -> Vec<bool> {
286 let Some(first) = pools.first() else {
287 return Vec::new();
288 };
289 let n = first.num_experts();
290 (0..n)
291 .map(|e| pools.iter().any(|p| p.is_gpu_resident(e)))
292 .collect()
293}
294
/// Estimates how many experts per MoE layer can be kept on the GPU given free VRAM.
///
/// `reserve_bytes` is held back from `free_bytes`. The remainder is divided by the cost of
/// keeping one expert resident in every MoE layer (`expert_param_bytes * num_moe_layers`).
/// The result is capped by both `max_gpu_experts_per_layer` and `num_experts`. If
/// `expert_param_bytes` or `num_moe_layers` is zero, the size cannot be computed and the
/// cap is returned directly.
pub fn gpu_expert_budget_from_vram(
    free_bytes: usize,
    reserve_bytes: usize,
    expert_param_bytes: usize,
    num_moe_layers: usize,
    max_gpu_experts_per_layer: usize,
    num_experts: usize,
) -> usize {
    let cap = max_gpu_experts_per_layer.min(num_experts);
    let unsizable = expert_param_bytes == 0 || num_moe_layers == 0;
    if unsizable {
        return cap;
    }
    let bytes_per_slot = expert_param_bytes.saturating_mul(num_moe_layers);
    let affordable = free_bytes.saturating_sub(reserve_bytes) / bytes_per_slot;
    affordable.min(cap)
}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a pool straight from its config parameters.
    fn pool(num_experts: usize, gpu_budget: usize, refresh: ExpertRefreshPolicy) -> ExpertPool {
        ExpertPool::new(ExpertPoolConfig::new(num_experts, gpu_budget, refresh))
    }

    #[test]
    fn per_layer_masks_differ_from_merged_union() {
        let mut p0 = pool(4, 2, ExpertRefreshPolicy::EveryForward);
        let mut p1 = pool(4, 2, ExpertRefreshPolicy::EveryForward);
        p0.refresh_from_indices(&[0, 1]);
        p1.refresh_from_indices(&[2, 3]);
        let pools = [p0, p1];
        let merged = merged_resident_mask(&pools);
        assert_eq!(merged, vec![true, true, true, true]);
        let per = per_layer_resident_masks(&pools);
        assert_eq!(per[0], vec![true, true, false, false]);
        assert_eq!(per[1], vec![false, false, true, true]);
    }

    #[test]
    fn masks_of_no_pools_are_empty() {
        assert!(merged_resident_mask(&[]).is_empty());
        assert!(per_layer_resident_masks(&[]).is_empty());
    }

    #[test]
    fn count_hits_matches_bincount() {
        let idx = [1u32, 0, 1, 2, 1];
        let c = ExpertPool::count_hits(&idx, 4);
        assert_eq!(c, [1, 3, 1, 0]);
    }

    #[test]
    fn count_hits_ignores_out_of_range() {
        let c = ExpertPool::count_hits(&[0, 4, 9], 4);
        assert_eq!(c, [1, 0, 0, 0]);
    }

    #[test]
    fn target_gpu_picks_top_by_count() {
        let counts = [10, 50, 30, 0, 50];
        let t = ExpertPool::target_gpu_from_counts(&counts, 3);
        assert_eq!(t, vec![1, 4, 2]);
    }

    #[test]
    fn target_gpu_skips_unused_and_respects_zero_budget() {
        assert_eq!(ExpertPool::target_gpu_from_counts(&[0, 0, 3], 4), vec![2]);
        assert!(ExpertPool::target_gpu_from_counts(&[5, 5], 0).is_empty());
    }

    #[test]
    fn paired_swap_limits_demotions() {
        let mut pool = pool(8, 2, ExpertRefreshPolicy::EveryForward);
        pool.resident = [0, 1].into_iter().collect();
        let r = pool.apply_target_placement(&[6, 7]);
        assert_eq!(r.promotions, 2);
        assert_eq!(r.demotions, 2);
        assert_eq!(pool.resident, [6, 7].into_iter().collect::<HashSet<_>>());
    }

    #[test]
    fn paired_swap_keeps_extra_residents() {
        let mut pool = pool(8, 4, ExpertRefreshPolicy::EveryForward);
        pool.resident = [0, 1, 2, 3].into_iter().collect();
        let r = pool.apply_target_placement(&[2, 3, 4, 5]);
        assert_eq!(r.promotions, 2);
        assert_eq!(r.demotions, 2);
        assert_eq!(pool.resident.len(), 4);
        for e in [2, 3, 4, 5] {
            assert!(pool.is_gpu_resident(e));
        }
        assert!(!pool.is_gpu_resident(0));
    }

    #[test]
    fn refresh_updates_and_resets_stats() {
        let mut pool = pool(8, 2, ExpertRefreshPolicy::EveryForward);
        let r = pool.refresh_from_indices(&[6, 6, 7]);
        assert_eq!(r.target_gpu, vec![6, 7]);
        assert_eq!(
            *pool.stats(),
            ExpertPoolStats {
                refreshes: 1,
                promotions: 2,
                demotions: 2
            }
        );
        pool.reset_step_stats();
        assert_eq!(*pool.stats(), ExpertPoolStats::default());
    }

    #[test]
    fn jump_steps_refresh_schedule() {
        let pool = pool(256, 64, ExpertRefreshPolicy::EveryDenoiseSteps(3));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 0, false));
        assert!(!pool.should_refresh(MoEExecMode::Reuse, 1, false));
        assert!(!pool.should_refresh(MoEExecMode::Reuse, 2, false));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 3, false));
        assert!(pool.should_refresh(MoEExecMode::Reuse, 0, true));
    }

    #[test]
    fn refresh_mode_and_zero_interval_always_refresh() {
        let sparse = pool(8, 2, ExpertRefreshPolicy::EveryDenoiseSteps(100));
        assert!(sparse.should_refresh(MoEExecMode::Refresh, 1, false));
        let zero = pool(8, 2, ExpertRefreshPolicy::EveryDenoiseSteps(0));
        assert!(zero.should_refresh(MoEExecMode::Reuse, 7, false));
    }

    #[test]
    fn offload_disabled_never_refreshes() {
        let pool = ExpertPool::new(ExpertPoolConfig::all_resident(8));
        assert!(!pool.offload_enabled());
        assert!(!pool.should_refresh(MoEExecMode::Refresh, 0, true));
        assert_eq!(pool.resident_mask(), vec![true; 8]);
    }

    #[test]
    fn on_forward_step_tracks_steps_since_refresh() {
        let mut pool = pool(8, 2, ExpertRefreshPolicy::EveryDenoiseSteps(2));
        assert!(!pool.on_forward_step(MoEExecMode::Reuse, 1, false));
        assert_eq!(pool.steps_since_refresh, 1);
        assert!(!pool.on_forward_step(MoEExecMode::Reuse, 3, false));
        assert_eq!(pool.steps_since_refresh, 2);
        assert!(pool.on_forward_step(MoEExecMode::Reuse, 4, false));
        assert_eq!(pool.steps_since_refresh, 0);
    }

    #[test]
    fn vram_budget_formula() {
        // 38 GiB usable / (50 MiB * 20 layers) = 38.9 -> 38 experts per layer.
        let b = gpu_expert_budget_from_vram(
            40 * 1024 * 1024 * 1024,
            2 * 1024 * 1024 * 1024,
            50 * 1024 * 1024,
            20,
            128,
            256,
        );
        assert_eq!(b, 38);
    }

    #[test]
    fn vram_budget_edge_cases() {
        // Unknown expert size falls back to the cap.
        assert_eq!(gpu_expert_budget_from_vram(0, 0, 0, 4, 16, 8), 8);
        // Reserve larger than free memory leaves nothing.
        assert_eq!(gpu_expert_budget_from_vram(10, 20, 1, 1, 5, 5), 0);
    }
}