// rlx_cpu/llada2_gate.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// RLX — LLaDA2 group-limited MoE gate (shared CPU reference for all backends).

/// Group-limited TopK (TIDE `LLaDA2MoeGate.group_limited_topk`).
///
/// `scores` is read as `num_tokens` consecutive rows of `num_experts` values.
/// Each row's experts are split into `n_group` contiguous groups of
/// `num_experts / n_group` experts:
///
/// 1. Each group is scored by the sum of its two largest values.
/// 2. The `topk_group` best groups are kept and every other expert is masked out.
/// 3. The `top_k` best unmasked experts are selected. If fewer than `top_k`
///    experts survive, the remaining picks fall back to the masked experts in
///    index order.
///
/// Returns `(probs, indices)`, both token-major with `min(top_k, num_experts)`
/// entries per token. `probs` holds the selected experts' raw scores. When
/// `top_k > 1` they are divided by their sum (plus `1e-20`); with `top_k == 1`
/// they are returned unscaled.
///
/// Ranking is descending and stable: ties keep the lower index first. NaN
/// scores rank as `-inf`, so they never outrank a real score.
///
/// # Panics
///
/// Panics if `n_group == 0` or `scores.len() < num_tokens * num_experts`.
/// When `num_experts` is not a multiple of `n_group`, the trailing
/// `num_experts % n_group` experts are never kept by group selection.
pub fn group_limited_topk(
    scores: &[f32],
    num_tokens: usize,
    num_experts: usize,
    n_group: usize,
    topk_group: usize,
    top_k: usize,
) -> (Vec<f32>, Vec<u32>) {
    let epg = num_experts / n_group;
    // `take(top_k)` over `num_experts` candidates yields at most this many picks.
    let picks = top_k.min(num_experts);
    let mut probs = Vec::with_capacity(num_tokens * picks);
    let mut indices = Vec::with_capacity(num_tokens * picks);

    // Scratch buffers reused across tokens instead of being reallocated per row.
    let mut group_scores = vec![0f32; n_group];
    let mut group_order: Vec<usize> = Vec::with_capacity(n_group);
    let mut masked = vec![f32::NEG_INFINITY; num_experts];
    let mut order: Vec<usize> = Vec::with_capacity(num_experts);

    for t in 0..num_tokens {
        let row = &scores[t * num_experts..(t + 1) * num_experts];

        for (g, gs) in group_scores.iter_mut().enumerate() {
            *gs = group_top2_sum(&row[g * epg..(g + 1) * epg]);
        }

        // Keep only the `topk_group` best groups; everything else stays at -inf.
        group_order.clear();
        group_order.extend(0..n_group);
        group_order.sort_by(|&a, &b| cmp_score_desc(group_scores[a], group_scores[b]));
        masked.fill(f32::NEG_INFINITY);
        for &g in group_order.iter().take(topk_group) {
            let span = g * epg..(g + 1) * epg;
            masked[span.clone()].copy_from_slice(&row[span]);
        }

        // Rank experts on the masked scores, but report the raw row scores.
        order.clear();
        order.extend(0..num_experts);
        order.sort_by(|&a, &b| cmp_score_desc(masked[a], masked[b]));
        let picked = &order[..picks];
        let sum: f32 = picked.iter().map(|&ei| row[ei]).sum::<f32>() + 1e-20;
        let scale = if top_k > 1 { 1.0 / sum } else { 1.0 };
        for &ei in picked {
            probs.push(row[ei] * scale);
            indices.push(ei as u32);
        }
    }
    (probs, indices)
}

/// Sum of the two largest values in `slice`.
///
/// NaNs never replace a slot. Missing entries (fewer than two usable values)
/// count as `-inf`.
fn group_top2_sum(slice: &[f32]) -> f32 {
    let mut top2 = [f32::NEG_INFINITY; 2];
    for &v in slice {
        if v > top2[0] {
            top2[1] = top2[0];
            top2[0] = v;
        } else if v > top2[1] {
            top2[1] = v;
        }
    }
    top2[0] + top2[1]
}

/// Descending comparator for scores that ranks NaN as `-inf`.
///
/// Plain `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN
/// appears, and `slice::sort_by` may panic on such comparators. After mapping
/// NaN away, every pair is comparable, so the fallback below is never taken.
fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
    let key = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
    key(b)
        .partial_cmp(&key(a))
        .unwrap_or(std::cmp::Ordering::Equal)
}

/// Decoded attributes of the LLaDA2 gate op.
///
/// The serialized form (see [`GateAttrs::from_bytes`]) is five little-endian
/// 32-bit words, in this order: `n_group`, `topk_group`, `top_k`,
/// `routed_scaling` (f32) and `num_experts`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GateAttrs {
    /// Number of contiguous expert groups each token's scores are split into.
    pub n_group: u32,
    /// Number of highest-scoring groups whose experts stay eligible.
    pub topk_group: u32,
    /// Number of experts selected per token.
    pub top_k: u32,
    /// Multiplier applied to every selected expert weight.
    pub routed_scaling: f32,
    /// Number of experts, i.e. scores per token row.
    pub num_experts: u32,
}

impl GateAttrs {
    /// Size in bytes of the serialized attribute block.
    const ENCODED_LEN: usize = 20;

    /// Decodes gate attributes from their little-endian byte encoding.
    ///
    /// Inputs shorter than 20 bytes, including an empty slice, decode to
    /// [`GateAttrs::default`]. Bytes past offset 20 are ignored.
    pub fn from_bytes(attrs: &[u8]) -> Self {
        match attrs.get(..Self::ENCODED_LEN) {
            Some(bytes) => {
                let mut words = [[0u8; 4]; 5];
                for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(4)) {
                    word.copy_from_slice(chunk);
                }
                GateAttrs {
                    n_group: u32::from_le_bytes(words[0]),
                    topk_group: u32::from_le_bytes(words[1]),
                    top_k: u32::from_le_bytes(words[2]),
                    routed_scaling: f32::from_le_bytes(words[3]),
                    num_experts: u32::from_le_bytes(words[4]),
                }
            }
            None => Self::default(),
        }
    }
}

impl Default for GateAttrs {
    /// Fallback used when no full attribute block is supplied: 256 experts in
    /// 8 groups, 4 groups kept, top-8 experts, routed scaling 2.5.
    fn default() -> Self {
        GateAttrs {
            n_group: 8,
            topk_group: 4,
            top_k: 8,
            routed_scaling: 2.5,
            num_experts: 256,
        }
    }
}

117/// Run the gate inside a contiguous f32 arena (CUDA/ROCm/WGPU host segments).
118pub fn execute_gate_in_f32_arena(
119    host: &mut [f32],
120    sig_f32_off: usize,
121    route_f32_off: usize,
122    out_f32_off: usize,
123    n_elems: usize,
124    attrs: &[u8],
125) -> Result<(), String> {
126    let a = GateAttrs::from_bytes(attrs);
127    let e = a.num_experts as usize;
128    let k = a.top_k as usize;
129    let rows = n_elems / e.max(1);
130    let out_end = out_f32_off + rows * k * 2;
131    let sig = host[sig_f32_off..sig_f32_off + n_elems].to_vec();
132    let route = host[route_f32_off..route_f32_off + n_elems].to_vec();
133    let out = &mut host[out_f32_off..out_end];
134    execute_gate_f32(&sig, &route, out, attrs)
135}
136
137/// Execute gate: inputs = [sigmoid scores, routing scores]; output = [idx, weights] packed.
138pub fn execute_gate_f32(
139    scores_sigmoid: &[f32],
140    scores_route: &[f32],
141    out: &mut [f32],
142    attrs: &[u8],
143) -> Result<(), String> {
144    let a = GateAttrs::from_bytes(attrs);
145    let rows = scores_sigmoid.len() / a.num_experts as usize;
146    let e = a.num_experts as usize;
147    let k = a.top_k as usize;
148    if scores_route.len() != scores_sigmoid.len() {
149        return Err("gate: sigmoid and routing score lengths differ".into());
150    }
151    if out.len() != rows * k * 2 {
152        return Err(format!("output len {} != rows*k*2", out.len()));
153    }
154    let (_, idx) = group_limited_topk(
155        scores_route,
156        rows,
157        e,
158        a.n_group as usize,
159        a.topk_group as usize,
160        k,
161    );
162    for t in 0..rows {
163        let row_sig = &scores_sigmoid[t * e..(t + 1) * e];
164        let mut picked = Vec::with_capacity(k);
165        for ki in 0..k {
166            let ei = idx[t * k + ki] as usize;
167            picked.push(row_sig[ei]);
168        }
169        let sum: f32 = picked.iter().sum::<f32>() + 1e-20;
170        let norm = if k > 1 { 1.0 / sum } else { 1.0 };
171        for ki in 0..k {
172            out[t * k * 2 + ki] = idx[t * k + ki] as f32;
173            out[t * k * 2 + k + ki] = picked[ki] * norm * a.routed_scaling;
174        }
175    }
176    Ok(())
177}