Skip to main content

rlx_cpu/
moe_residency.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Per-forward MoE expert residency mask (TIDE placement) for CPU dispatch.
17//!
18//! Set by [`rlx_runtime::CompiledGraph::set_moe_resident_experts`] before
19//! `run`. [`crate::thunk::GroupedMatMul`] reads the mask for accounting;
20//! numerics still use the full expert stack in the arena (lossless on CPU).
21
22use std::cell::RefCell;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::{Arc, RwLock};
25
/// Raw host pointers to the per-expert gate/up/down weights of one MoE layer.
///
/// Each vector is indexed by expert id; [`host_expert_weight_ptr`] selects the vector by
/// GroupedMatMul ordinal (`ord % 3`: 0 = gate, 1 = up, 2 = down).
#[derive(Debug, Clone)]
pub struct LayerHostBind {
    /// Gate-projection weight pointer for each expert (`ord % 3 == 0`).
    pub gate: Vec<*const f32>,
    /// Up-projection weight pointer for each expert (`ord % 3 == 1`).
    pub up: Vec<*const f32>,
    /// Down-projection weight pointer for each expert (`ord % 3 == 2`).
    pub down: Vec<*const f32>,
    /// Stride carried with the pointers for consumers; not read by this module
    /// (units presumably elements — TODO confirm with the GroupedMatMul consumer).
    pub stride: usize,
}

/// Host weight lookup installed before a forward (TIDE CPU/GPU fallback path).
///
/// `layers` is indexed by MoE layer in forward order (GroupedMatMul ordinal `ord / 3`).
#[derive(Debug, Clone)]
pub struct MoeHostBind {
    pub layers: Vec<LayerHostBind>,
}

// SAFETY: this module only stores and copies the raw pointers (see `host_expert_weight_ptr`);
// it never dereferences them. Moving a binding to another thread is sound provided whoever
// installs it keeps the pointed-to weights alive and unmodified until the binding is replaced
// or cleared, and any consumer that dereferences the pointers relies on that same guarantee.
unsafe impl Send for LayerHostBind {}
// SAFETY: shared access only ever reads the pointer values themselves; the same lifetime /
// immutability contract as for `Send` above applies to the pointees.
unsafe impl Sync for LayerHostBind {}
// `MoeHostBind` needs no `unsafe impl`: it only holds `Vec<LayerHostBind>`, so it is
// `Send + Sync` automatically through the impls above.
47
48static HOST_BIND: RwLock<Option<MoeHostBind>> = RwLock::new(None);
49static LAST_STATS: RwLock<Option<MoeResidencyStats>> = RwLock::new(None);
50/// Monotonic GroupedMatMul ordinal within one forward (layer = ord/3, matrix = ord%3).
51static GMM_ORD: AtomicUsize = AtomicUsize::new(0);
52
/// MoE dispatch accounting for one forward, split by whether the expert was marked resident.
///
/// Counters are only advanced by [`record_expert_tokens`] calls with a non-zero token count.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct MoeResidencyStats {
    /// Recorded expert invocations whose expert was marked GPU-resident.
    pub gpu_expert_calls: u64,
    /// Recorded expert invocations whose expert was not marked resident.
    pub cpu_expert_calls: u64,
    /// Tokens routed to resident experts.
    pub gpu_tokens: u64,
    /// Tokens routed to non-resident experts.
    pub cpu_tokens: u64,
}
60
61struct MoeResidencyCtx {
62    /// Union mask (legacy): expert on device if any layer has it resident.
63    merged: Option<Arc<[bool]>>,
64    /// TIDE per MoE layer (forward order); takes precedence over [`merged`].
65    per_layer: Option<Arc<Vec<Arc<[bool]>>>>,
66    stats: MoeResidencyStats,
67}
68
69thread_local! {
70    static CTX: RefCell<Option<MoeResidencyCtx>> = const { RefCell::new(None) };
71}
72
73/// Install merged (union) residency mask for the current thread until [`clear_mask`].
74pub fn set_mask(mask: Option<Arc<[bool]>>) {
75    CTX.with(|c| {
76        *c.borrow_mut() = Some(MoeResidencyCtx {
77            merged: mask,
78            per_layer: None,
79            stats: MoeResidencyStats::default(),
80        });
81    });
82}
83
84/// Install per-layer masks (one row per MoE FFN in forward order).
85pub fn set_per_layer_masks(layers: Option<Arc<Vec<Arc<[bool]>>>>) {
86    CTX.with(|c| {
87        *c.borrow_mut() = Some(MoeResidencyCtx {
88            merged: None,
89            per_layer: layers,
90            stats: MoeResidencyStats::default(),
91        });
92    });
93}
94
95pub fn clear_mask() {
96    CTX.with(|c| *c.borrow_mut() = None);
97}
98
99pub fn bind_host_weights(bind: Option<MoeHostBind>) {
100    *HOST_BIND.write().unwrap() = bind;
101}
102
103pub fn reset_gmm_counters() {
104    GMM_ORD.store(0, Ordering::Relaxed);
105}
106
107/// Next MoE GroupedMatMul ordinal for this forward (call once per kernel).
108pub fn next_gmm_ord() -> usize {
109    GMM_ORD.fetch_add(1, Ordering::Relaxed)
110}
111
112/// Host expert weight pointer for `ord` (gate/up/down = ord%3, layer = ord/3).
113pub fn host_expert_weight_ptr(ord: usize, expert: usize) -> Option<*const f32> {
114    let bind = HOST_BIND.read().unwrap();
115    let bind = bind.as_ref()?;
116    let layer = bind.layers.get(ord / 3)?;
117    let ptrs = match ord % 3 {
118        0 => &layer.gate,
119        1 => &layer.up,
120        _ => &layer.down,
121    };
122    ptrs.get(expert).copied()
123}
124
125pub fn peek_stats() -> Option<MoeResidencyStats> {
126    CTX.with(|c| c.borrow().as_ref().map(|ctx| ctx.stats.clone()))
127}
128
129/// Stats from the most recent forward on this thread (set when the residency guard drops).
130pub fn take_last_forward_stats() -> Option<MoeResidencyStats> {
131    LAST_STATS.write().unwrap().take()
132}
133
134pub(crate) fn stash_last_forward_stats(stats: MoeResidencyStats) {
135    *LAST_STATS.write().unwrap() = Some(stats);
136}
137
138fn expert_on_device_inner(ctx: &MoeResidencyCtx, layer: Option<usize>, e: usize) -> bool {
139    if let Some(layers) = ctx.per_layer.as_ref() {
140        if let Some(li) = layer {
141            return layers
142                .get(li)
143                .and_then(|m| m.get(e).copied())
144                .unwrap_or(true);
145        }
146    }
147    ctx.merged
148        .as_ref()
149        .and_then(|m| m.get(e).copied())
150        .unwrap_or(true)
151}
152
153/// True when expert `e` is GPU-resident for MoE layer `layer` (GMM ord / 3).
154pub fn expert_on_device_for_layer(layer: usize, e: usize) -> bool {
155    CTX.with(|c| {
156        let borrow = c.borrow();
157        let Some(ctx) = borrow.as_ref() else {
158            return true;
159        };
160        expert_on_device_inner(ctx, Some(layer), e)
161    })
162}
163
164/// True when expert `e` is marked GPU-resident (merged mask if no per-layer table).
165pub fn expert_on_device(e: usize) -> bool {
166    expert_on_device_for_layer(0, e)
167}
168
169pub fn record_expert_tokens(layer: usize, e: usize, num_tokens: usize) {
170    if num_tokens == 0 {
171        return;
172    }
173    CTX.with(|c| {
174        let mut borrow = c.borrow_mut();
175        let Some(ctx) = borrow.as_mut() else {
176            return;
177        };
178        let on_device = expert_on_device_inner(ctx, Some(layer), e);
179        if on_device {
180            ctx.stats.gpu_expert_calls += 1;
181            ctx.stats.gpu_tokens += num_tokens as u64;
182        } else {
183            ctx.stats.cpu_expert_calls += 1;
184            ctx.stats.cpu_tokens += num_tokens as u64;
185        }
186    });
187}
188
189/// Take stats and clear the thread-local context.
190pub fn take_stats() -> Option<MoeResidencyStats> {
191    CTX.with(|c| c.borrow_mut().take().map(|ctx| ctx.stats))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Clears the thread-local residency context on drop, even if an assertion panics.
    struct CtxGuard;

    impl Drop for CtxGuard {
        fn drop(&mut self) {
            clear_mask();
        }
    }

    #[test]
    fn per_layer_masks_are_layer_local() {
        let _guard = CtxGuard;
        let per = Arc::new(vec![
            Arc::from([false, true, true, true]),
            Arc::from([true, false, true, true]),
        ]);
        set_per_layer_masks(Some(per));
        assert!(!expert_on_device_for_layer(0, 0));
        assert!(expert_on_device_for_layer(0, 1));
        assert!(expert_on_device_for_layer(1, 0));
        assert!(!expert_on_device_for_layer(1, 1));
        // A missing layer row or expert entry defaults to resident.
        assert!(expert_on_device_for_layer(2, 0));
        assert!(expert_on_device_for_layer(0, 99));
        // The layer-agnostic query falls back to row 0 of the per-layer table.
        assert!(!expert_on_device(0));
    }

    #[test]
    fn merged_mask_applies_to_every_layer() {
        let _guard = CtxGuard;
        set_mask(Some(Arc::from([true, false])));
        assert!(expert_on_device(0));
        assert!(!expert_on_device(1));
        assert!(!expert_on_device_for_layer(5, 1));
        assert!(expert_on_device(7));
    }

    #[test]
    fn no_context_means_everything_resident() {
        clear_mask();
        assert!(expert_on_device_for_layer(3, 3));
        assert!(expert_on_device(0));
        assert!(peek_stats().is_none());
        record_expert_tokens(0, 0, 4);
        assert!(take_stats().is_none());
    }

    #[test]
    fn record_expert_tokens_splits_gpu_and_cpu() {
        let _guard = CtxGuard;
        set_per_layer_masks(Some(Arc::new(vec![Arc::from([true, false])])));
        record_expert_tokens(0, 0, 3);
        record_expert_tokens(0, 1, 5);
        record_expert_tokens(0, 1, 0); // zero-token calls are not counted
        let peeked = peek_stats().expect("context installed");
        assert_eq!(peeked.gpu_tokens, 3);
        let stats = take_stats().expect("context installed");
        assert_eq!(stats.gpu_expert_calls, 1);
        assert_eq!(stats.gpu_tokens, 3);
        assert_eq!(stats.cpu_expert_calls, 1);
        assert_eq!(stats.cpu_tokens, 5);
        assert!(take_stats().is_none(), "take_stats clears the context");
    }
}