Skip to main content

rlx_cpu/
moe_topk_capture.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Capture MoE router [`Op::TopK`] outputs during CPU forward (TIDE refresh input).
17
18use std::sync::{Arc, Mutex};
19
/// Thread-shared sink for MoE router `Op::TopK` results.
///
/// Every recorded entry is the flattened list of expert ids produced by one
/// router TopK, appended in the order the schedule executes those TopKs.
/// Access to the recorded entries is serialized through a [`Mutex`].
#[derive(Debug)]
pub struct MoeTopkCapture {
    /// Expert count this capture was created for; TopK outputs are only
    /// recorded when their reported `axis_dim` equals this value.
    pub num_experts: usize,
    /// One flattened `Vec<u32>` of expert ids per captured router TopK.
    layers: Mutex<Vec<Vec<u32>>>,
}
26
27impl MoeTopkCapture {
28    pub fn new(num_experts: usize) -> Arc<Self> {
29        Arc::new(Self {
30            num_experts,
31            layers: Mutex::new(Vec::new()),
32        })
33    }
34
35    pub fn clear(&self) {
36        self.layers.lock().unwrap().clear();
37    }
38
39    /// Record one router TopK output (`outer * k` f32-encoded expert ids).
40    pub fn push_topk_f32(&self, data: &[f32], axis_dim: usize) {
41        if axis_dim != self.num_experts {
42            return;
43        }
44        let flat: Vec<u32> = data.iter().map(|&v| v as u32).collect();
45        self.layers.lock().unwrap().push(flat);
46    }
47
48    pub fn take_layers(&self) -> Vec<Vec<u32>> {
49        std::mem::take(&mut *self.layers.lock().unwrap())
50    }
51}