use std::cell::RefCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
/// Raw weight pointers for one layer, one entry per expert in each table.
///
/// `host_expert_weight_ptr` indexes these tables by expert id; which table it
/// picks depends on the call ordinal modulo three (0 = `gate`, 1 = `up`,
/// anything else = `down`).
#[derive(Clone, Debug)]
pub struct LayerHostBind {
    /// Per-expert pointers for the gate projection.
    pub gate: Vec<*const f32>,
    /// Per-expert pointers for the up projection.
    pub up: Vec<*const f32>,
    /// Per-expert pointers for the down projection.
    pub down: Vec<*const f32>,
    /// NOTE(review): not read anywhere in the visible code; presumably a
    /// stride (in elements) used by consumers of the pointers — confirm.
    pub stride: usize,
}
/// Weight-pointer tables for every layer of the model.
///
/// `host_expert_weight_ptr` selects the entry at index `ord / 3`, so the
/// vector is expected to be ordered by layer.
#[derive(Clone, Debug)]
pub struct MoeHostBind {
    /// One [`LayerHostBind`] per layer.
    pub layers: Vec<LayerHostBind>,
}
// Raw pointers are neither `Send` nor `Sync`, so these impls are what allow a
// `MoeHostBind` to live inside the `HOST_BIND` static (an `RwLock` in a static
// requires its contents to be `Send + Sync`).
//
// SAFETY: NOTE(review) — nothing in this file proves soundness. The impls are
// only sound if the pointed-to `f32` buffers outlive the binding and are never
// written while bound; confirm against the code that builds the bind.
unsafe impl Send for LayerHostBind {}
unsafe impl Sync for LayerHostBind {}
// NOTE(review): `MoeHostBind` only holds a `Vec<LayerHostBind>`, so it would
// already be `Send + Sync` through the impls above; these look redundant.
unsafe impl Send for MoeHostBind {}
unsafe impl Sync for MoeHostBind {}
/// Process-wide weight-pointer binding; replaced by `bind_host_weights` and
/// read by `host_expert_weight_ptr`.
static HOST_BIND: RwLock<Option<MoeHostBind>> = RwLock::new(None);
/// Most recently stashed stats; written by `stash_last_forward_stats` and
/// consumed (reset to `None`) by `take_last_forward_stats`.
static LAST_STATS: RwLock<Option<MoeResidencyStats>> = RwLock::new(None);
/// Process-wide counter handed out by `next_gmm_ord` and zeroed by
/// `reset_gmm_counters`.
static GMM_ORD: AtomicUsize = AtomicUsize::new(0);
/// Counters accumulated by `record_expert_tokens` for the current thread's
/// residency context.
///
/// The type is plain data, so it also derives `PartialEq`/`Eq`; this lets
/// callers and tests compare snapshots directly (e.g. with `assert_eq!`).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MoeResidencyStats {
    /// Number of recorded expert invocations whose expert was on-device.
    pub gpu_expert_calls: u64,
    /// Number of recorded expert invocations whose expert was not on-device.
    pub cpu_expert_calls: u64,
    /// Total tokens recorded for on-device experts.
    pub gpu_tokens: u64,
    /// Total tokens recorded for experts that were not on-device.
    pub cpu_tokens: u64,
}
/// Per-thread residency state: the active mask(s) plus the counters gathered
/// while they were installed.
struct MoeResidencyCtx {
    /// Single mask indexed by expert id; consulted when no per-layer mask
    /// applies to a lookup.
    merged: Option<Arc<[bool]>>,
    /// One mask per layer, each indexed by expert id; takes precedence over
    /// `merged` whenever a layer index is supplied.
    per_layer: Option<Arc<Vec<Arc<[bool]>>>>,
    /// Counters updated by `record_expert_tokens`.
    stats: MoeResidencyStats,
}
// Residency context of the current thread. `None` means no mask has been
// installed (or it was cleared/taken); lookups then report every expert as
// on-device and `record_expert_tokens` does nothing.
thread_local! {
static CTX: RefCell<Option<MoeResidencyCtx>> = const { RefCell::new(None) };
}
/// Installs a fresh residency context on the calling thread that uses a
/// single merged mask.
///
/// Any previously installed context (including per-layer masks and gathered
/// stats) is discarded, and the counters start from zero. Passing `None`
/// still installs a context: stats are recorded, and every expert is
/// reported as on-device.
///
/// # Panics
///
/// Panics if the thread-local context is currently borrowed.
pub fn set_mask(mask: Option<Arc<[bool]>>) {
    let fresh = MoeResidencyCtx {
        merged: mask,
        per_layer: None,
        stats: MoeResidencyStats::default(),
    };
    CTX.with(|cell| {
        // The previous context (if any) is simply dropped.
        drop(cell.replace(Some(fresh)));
    });
}
/// Installs a fresh residency context on the calling thread that uses one
/// mask per layer.
///
/// Any previously installed context (including a merged mask and gathered
/// stats) is discarded, and the counters start from zero. Passing `None`
/// still installs a context with no masks at all.
///
/// # Panics
///
/// Panics if the thread-local context is currently borrowed.
pub fn set_per_layer_masks(layers: Option<Arc<Vec<Arc<[bool]>>>>) {
    let fresh = MoeResidencyCtx {
        merged: None,
        per_layer: layers,
        stats: MoeResidencyStats::default(),
    };
    CTX.with(|cell| {
        let mut slot = cell.borrow_mut();
        *slot = Some(fresh);
    });
}
/// Removes the calling thread's residency context, discarding its masks and
/// any stats gathered so far.
///
/// # Panics
///
/// Panics if the thread-local context is currently borrowed.
pub fn clear_mask() {
    CTX.with(|cell| {
        // `take` leaves `None` behind; the old context is dropped here.
        drop(cell.take());
    });
}
/// Replaces the process-wide weight-pointer binding.
///
/// Passing `None` removes the binding, after which
/// `host_expert_weight_ptr` returns `None` for every query.
///
/// # Panics
///
/// Panics if the binding lock is poisoned.
pub fn bind_host_weights(bind: Option<MoeHostBind>) {
    let mut current = HOST_BIND.write().unwrap();
    *current = bind;
}
/// Resets the process-wide GMM ordinal counter so the next call to
/// `next_gmm_ord` returns `0`.
pub fn reset_gmm_counters() {
    // Relaxed ordering: the store is atomic but imposes no ordering on
    // surrounding memory accesses.
    GMM_ORD.store(0, Ordering::Relaxed);
}
/// Returns the current GMM ordinal and advances the process-wide counter by
/// one.
///
/// The first call after `reset_gmm_counters` yields `0`.
pub fn next_gmm_ord() -> usize {
    // `fetch_add` hands back the value from *before* the increment.
    GMM_ORD.fetch_add(1, Ordering::Relaxed)
}
/// Looks up the bound weight pointer for `expert` at call ordinal `ord`.
///
/// The ordinal is split into a layer index (`ord / 3`) and a projection
/// selector (`ord % 3`): `0` picks the gate table, `1` the up table and `2`
/// the down table.
///
/// Returns `None` when no binding is installed, when the layer index is out
/// of range, or when the selected table has no entry for `expert`.
///
/// NOTE(review): the returned pointer is handed out after the read lock is
/// released; nothing here guarantees the pointed-to memory is still valid —
/// confirm the lifetime contract with the code that creates the binding.
///
/// # Panics
///
/// Panics if the binding lock is poisoned.
pub fn host_expert_weight_ptr(ord: usize, expert: usize) -> Option<*const f32> {
    let layer_idx = ord / 3;
    let projection = ord % 3;

    let guard = HOST_BIND.read().unwrap();
    let layer = guard.as_ref()?.layers.get(layer_idx)?;

    let table = if projection == 0 {
        &layer.gate
    } else if projection == 1 {
        &layer.up
    } else {
        &layer.down
    };
    table.get(expert).copied()
}
/// Returns a copy of the stats gathered so far on the calling thread without
/// disturbing the installed context.
///
/// Returns `None` when no context is installed.
///
/// # Panics
///
/// Panics if the thread-local context is currently mutably borrowed.
pub fn peek_stats() -> Option<MoeResidencyStats> {
    CTX.with(|cell| {
        let guard = cell.borrow();
        let ctx = guard.as_ref()?;
        Some(ctx.stats.clone())
    })
}
/// Removes and returns the stats most recently stored by
/// `stash_last_forward_stats`.
///
/// A second call without an intervening stash returns `None`.
///
/// # Panics
///
/// Panics if the stats lock is poisoned.
pub fn take_last_forward_stats() -> Option<MoeResidencyStats> {
    let mut slot = LAST_STATS.write().unwrap();
    // Leaves `None` in the slot and hands the previous content to the caller.
    std::mem::take(&mut *slot)
}
/// Stores `stats` in the process-wide slot read by
/// `take_last_forward_stats`, overwriting whatever was there.
///
/// # Panics
///
/// Panics if the stats lock is poisoned.
pub(crate) fn stash_last_forward_stats(stats: MoeResidencyStats) {
    let mut slot = LAST_STATS.write().unwrap();
    *slot = Some(stats);
}
/// Decides whether expert `e` counts as on-device under `ctx`.
///
/// Mask selection:
/// * per-layer masks installed **and** a layer index given: only that
///   layer's mask is consulted (the merged mask is never used as a fallback,
///   even if the layer index is out of range);
/// * otherwise: the merged mask is consulted.
///
/// Whenever the selected mask is absent or has no entry for `e`, the expert
/// is treated as on-device (`true`).
fn expert_on_device_inner(ctx: &MoeResidencyCtx, layer: Option<usize>, e: usize) -> bool {
    let mask: Option<&[bool]> = match (ctx.per_layer.as_deref(), layer) {
        (Some(layers), Some(li)) => layers.get(li).map(|m| &**m),
        _ => ctx.merged.as_deref(),
    };
    match mask {
        Some(flags) => flags.get(e).copied().unwrap_or(true),
        None => true,
    }
}
/// Reports whether expert `e` of `layer` is on-device according to the
/// calling thread's residency context.
///
/// Returns `true` when no context is installed, and also when the applicable
/// mask has no entry for the requested layer/expert.
///
/// # Panics
///
/// Panics if the thread-local context is currently mutably borrowed.
pub fn expert_on_device_for_layer(layer: usize, e: usize) -> bool {
    CTX.with(|cell| {
        let guard = cell.borrow();
        match guard.as_ref() {
            Some(ctx) => expert_on_device_inner(ctx, Some(layer), e),
            // No context installed: everything is considered on-device.
            None => true,
        }
    })
}
/// Layer-agnostic convenience wrapper around
/// [`expert_on_device_for_layer`].
///
/// The query is issued for layer `0`: with a merged mask installed the layer
/// index is irrelevant, but with per-layer masks installed this consults the
/// mask of layer 0 specifically.
pub fn expert_on_device(e: usize) -> bool {
    expert_on_device_for_layer(0, e)
}
/// Accounts one expert invocation that processed `num_tokens` tokens.
///
/// The invocation is attributed to the GPU counters when expert `e` of
/// `layer` is on-device under the calling thread's context, and to the CPU
/// counters otherwise. Nothing is recorded when `num_tokens` is zero or when
/// no context is installed on this thread.
///
/// # Panics
///
/// Panics if the thread-local context is currently borrowed.
pub fn record_expert_tokens(layer: usize, e: usize, num_tokens: usize) {
    if num_tokens == 0 {
        return;
    }
    CTX.with(|cell| {
        let mut guard = cell.borrow_mut();
        if let Some(ctx) = guard.as_mut() {
            let resident = expert_on_device_inner(ctx, Some(layer), e);
            let stats = &mut ctx.stats;
            // Pick the pair of counters that matches where the expert lives.
            let (calls, tokens) = if resident {
                (&mut stats.gpu_expert_calls, &mut stats.gpu_tokens)
            } else {
                (&mut stats.cpu_expert_calls, &mut stats.cpu_tokens)
            };
            *calls += 1;
            *tokens += num_tokens as u64;
        }
    });
}
/// Removes the calling thread's residency context and returns the stats it
/// had gathered.
///
/// Because the whole context is removed, the masks are discarded as well:
/// until a new mask is installed, every expert is reported as on-device and
/// no further stats are recorded. Returns `None` when no context was
/// installed.
///
/// # Panics
///
/// Panics if the thread-local context is currently borrowed.
pub fn take_stats() -> Option<MoeResidencyStats> {
    CTX.with(|cell| {
        let previous = cell.take();
        previous.map(|ctx| ctx.stats)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Each layer's mask must only affect lookups for that layer.
    #[test]
    fn per_layer_masks_are_layer_local() {
        let layer0: Arc<[bool]> = Arc::from([false, true, true, true]);
        let layer1: Arc<[bool]> = Arc::from([true, false, true, true]);
        set_per_layer_masks(Some(Arc::new(vec![layer0, layer1])));

        // (layer, expert, expected residency)
        let cases = [
            (0, 0, false),
            (0, 1, true),
            (1, 0, true),
            (1, 1, false),
        ];
        for (layer, expert, expected) in cases {
            assert_eq!(expert_on_device_for_layer(layer, expert), expected);
        }

        clear_mask();
    }
}