rlx_cpu/moe_topk_capture.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Capture MoE router [`Op::TopK`] outputs during CPU forward (TIDE refresh input).
17
18use std::sync::{Arc, Mutex};
19
/// Shared capture buffer for MoE router TopK outputs.
///
/// Each recorded TopK call contributes one entry (a flat list of expert ids), stored in the
/// order [`MoeTopkCapture::push_topk_f32`] is called. The buffer sits behind a [`Mutex`], so a
/// single `Arc<MoeTopkCapture>` can be shared between the code that records into it and the
/// code that drains it.
#[derive(Debug)]
pub struct MoeTopkCapture {
    /// Number of experts. Only TopK outputs whose reduced axis has exactly this length are
    /// recorded; all others are ignored by [`MoeTopkCapture::push_topk_f32`].
    pub num_experts: usize,
    /// One flat `Vec<u32>` of expert ids per recorded TopK call, in push order.
    layers: Mutex<Vec<Vec<u32>>>,
}

impl MoeTopkCapture {
    /// Creates an empty capture buffer for a model with `num_experts` experts.
    ///
    /// Returned in an [`Arc`] so the same buffer can be shared with whatever records into it
    /// and still be read afterwards by the caller.
    pub fn new(num_experts: usize) -> Arc<Self> {
        Arc::new(Self {
            num_experts,
            layers: Mutex::new(Vec::new()),
        })
    }

    /// Discards every recorded entry.
    pub fn clear(&self) {
        self.lock_layers().clear();
    }

    /// Records one router TopK output (`outer * k` f32-encoded expert ids).
    ///
    /// `axis_dim` is the length of the axis the TopK selected over. The call is a no-op unless
    /// it equals [`Self::num_experts`] (presumably so that only router TopKs, which select over
    /// the expert axis, are captured).
    ///
    /// Ids are converted with Rust's saturating float-to-int `as` cast: fractional values are
    /// truncated toward zero, NaN and negative values become `0`, and values above `u32::MAX`
    /// clamp to `u32::MAX`.
    pub fn push_topk_f32(&self, data: &[f32], axis_dim: usize) {
        if axis_dim != self.num_experts {
            return;
        }
        // Convert before taking the lock so the critical section is just the push.
        let flat: Vec<u32> = data.iter().map(|&v| v as u32).collect();
        self.lock_layers().push(flat);
    }

    /// Removes and returns all recorded entries, leaving the buffer empty.
    pub fn take_layers(&self) -> Vec<Vec<u32>> {
        std::mem::take(&mut *self.lock_layers())
    }

    /// Locks the buffer, recovering the guard if the mutex was poisoned.
    ///
    /// Each critical section in this type is a single `Vec` operation (`clear`, `push`,
    /// `mem::take`). Such an operation either completes or leaves the outer `Vec` untouched, so
    /// the data behind a poisoned lock is still consistent. Recovering stops another thread's
    /// panic from turning every later capture call into a panic as well.
    fn lock_layers(&self) -> std::sync::MutexGuard<'_, Vec<Vec<u32>>> {
        self.layers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
51}