1use std::cell::RefCell;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::{Arc, RwLock};
25
/// Host-side weight pointers for the three projections of one MoE layer.
///
/// Each pointer table is looked up by expert index (see
/// `host_expert_weight_ptr`). The pointers are stored and copied as plain
/// values; nothing in this type owns or frees the memory they refer to.
#[derive(Debug, Clone)]
pub struct LayerHostBind {
    /// Per-expert pointers for the gate projection.
    pub gate: Vec<*const f32>,
    /// Per-expert pointers for the up projection.
    pub up: Vec<*const f32>,
    /// Per-expert pointers for the down projection.
    pub down: Vec<*const f32>,
    /// Stride associated with the pointed-to weights.
    // NOTE(review): not read anywhere in this file; unit (elements vs. bytes)
    // should be confirmed against the code that consumes the pointers.
    pub stride: usize,
}
34
35#[derive(Debug, Clone)]
37pub struct MoeHostBind {
38 pub layers: Vec<LayerHostBind>,
39}
40
41unsafe impl Send for LayerHostBind {}
44unsafe impl Sync for LayerHostBind {}
45unsafe impl Send for MoeHostBind {}
46unsafe impl Sync for MoeHostBind {}
47
48static HOST_BIND: RwLock<Option<MoeHostBind>> = RwLock::new(None);
49static LAST_STATS: RwLock<Option<MoeResidencyStats>> = RwLock::new(None);
50static GMM_ORD: AtomicUsize = AtomicUsize::new(0);
52
/// Counters describing how expert work was split between resident
/// ("on device") and non-resident experts.
///
/// `record_expert_tokens` bumps the `gpu_*` pair when the active residency
/// mask reports the expert as on-device and the `cpu_*` pair otherwise.
///
/// The type compares by value (`PartialEq`/`Eq`) so snapshots can be checked
/// with `assert_eq!` or diffed between forwards.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MoeResidencyStats {
    /// Number of recorded non-empty calls for on-device experts.
    pub gpu_expert_calls: u64,
    /// Number of recorded non-empty calls for experts not on device.
    pub cpu_expert_calls: u64,
    /// Total tokens recorded for on-device experts.
    pub gpu_tokens: u64,
    /// Total tokens recorded for experts not on device.
    pub cpu_tokens: u64,
}
60
61struct MoeResidencyCtx {
62 merged: Option<Arc<[bool]>>,
64 per_layer: Option<Arc<Vec<Arc<[bool]>>>>,
66 stats: MoeResidencyStats,
67}
68
69thread_local! {
70 static CTX: RefCell<Option<MoeResidencyCtx>> = const { RefCell::new(None) };
71}
72
73pub fn set_mask(mask: Option<Arc<[bool]>>) {
75 CTX.with(|c| {
76 *c.borrow_mut() = Some(MoeResidencyCtx {
77 merged: mask,
78 per_layer: None,
79 stats: MoeResidencyStats::default(),
80 });
81 });
82}
83
84pub fn set_per_layer_masks(layers: Option<Arc<Vec<Arc<[bool]>>>>) {
86 CTX.with(|c| {
87 *c.borrow_mut() = Some(MoeResidencyCtx {
88 merged: None,
89 per_layer: layers,
90 stats: MoeResidencyStats::default(),
91 });
92 });
93}
94
95pub fn clear_mask() {
96 CTX.with(|c| *c.borrow_mut() = None);
97}
98
99pub fn bind_host_weights(bind: Option<MoeHostBind>) {
100 *HOST_BIND.write().unwrap() = bind;
101}
102
103pub fn reset_gmm_counters() {
104 GMM_ORD.store(0, Ordering::Relaxed);
105}
106
107pub fn next_gmm_ord() -> usize {
109 GMM_ORD.fetch_add(1, Ordering::Relaxed)
110}
111
112pub fn host_expert_weight_ptr(ord: usize, expert: usize) -> Option<*const f32> {
114 let bind = HOST_BIND.read().unwrap();
115 let bind = bind.as_ref()?;
116 let layer = bind.layers.get(ord / 3)?;
117 let ptrs = match ord % 3 {
118 0 => &layer.gate,
119 1 => &layer.up,
120 _ => &layer.down,
121 };
122 ptrs.get(expert).copied()
123}
124
125pub fn peek_stats() -> Option<MoeResidencyStats> {
126 CTX.with(|c| c.borrow().as_ref().map(|ctx| ctx.stats.clone()))
127}
128
129pub fn take_last_forward_stats() -> Option<MoeResidencyStats> {
131 LAST_STATS.write().unwrap().take()
132}
133
134pub(crate) fn stash_last_forward_stats(stats: MoeResidencyStats) {
135 *LAST_STATS.write().unwrap() = Some(stats);
136}
137
138fn expert_on_device_inner(ctx: &MoeResidencyCtx, layer: Option<usize>, e: usize) -> bool {
139 if let Some(layers) = ctx.per_layer.as_ref() {
140 if let Some(li) = layer {
141 return layers
142 .get(li)
143 .and_then(|m| m.get(e).copied())
144 .unwrap_or(true);
145 }
146 }
147 ctx.merged
148 .as_ref()
149 .and_then(|m| m.get(e).copied())
150 .unwrap_or(true)
151}
152
153pub fn expert_on_device_for_layer(layer: usize, e: usize) -> bool {
155 CTX.with(|c| {
156 let borrow = c.borrow();
157 let Some(ctx) = borrow.as_ref() else {
158 return true;
159 };
160 expert_on_device_inner(ctx, Some(layer), e)
161 })
162}
163
164pub fn expert_on_device(e: usize) -> bool {
166 expert_on_device_for_layer(0, e)
167}
168
169pub fn record_expert_tokens(layer: usize, e: usize, num_tokens: usize) {
170 if num_tokens == 0 {
171 return;
172 }
173 CTX.with(|c| {
174 let mut borrow = c.borrow_mut();
175 let Some(ctx) = borrow.as_mut() else {
176 return;
177 };
178 let on_device = expert_on_device_inner(ctx, Some(layer), e);
179 if on_device {
180 ctx.stats.gpu_expert_calls += 1;
181 ctx.stats.gpu_tokens += num_tokens as u64;
182 } else {
183 ctx.stats.cpu_expert_calls += 1;
184 ctx.stats.cpu_tokens += num_tokens as u64;
185 }
186 });
187}
188
189pub fn take_stats() -> Option<MoeResidencyStats> {
191 CTX.with(|c| c.borrow_mut().take().map(|ctx| ctx.stats))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Each layer must answer from its own mask, not from a neighbour's.
    #[test]
    fn per_layer_masks_are_layer_local() {
        let layer0: Arc<[bool]> = Arc::from([false, true, true, true]);
        let layer1: Arc<[bool]> = Arc::from([true, false, true, true]);
        set_per_layer_masks(Some(Arc::new(vec![layer0, layer1])));

        // ((layer, expert), expected residency)
        let expectations = [
            ((0, 0), false),
            ((0, 1), true),
            ((1, 0), true),
            ((1, 1), false),
        ];
        for ((layer, expert), resident) in expectations {
            assert_eq!(
                expert_on_device_for_layer(layer, expert),
                resident,
                "layer {layer}, expert {expert}"
            );
        }

        clear_mask();
    }
}