/// Group-limited top-k expert selection over row-major router scores.
///
/// `scores` holds `num_tokens` rows of `num_experts` values each. For every row the experts
/// are split into `n_group` contiguous groups of `num_experts / n_group` experts, and each
/// group is scored by the sum of its two largest values (just its largest value when a group
/// holds a single expert). Only the best `topk_group` groups stay eligible. The `top_k`
/// highest-scoring eligible experts are picked in descending score order; ties keep the lower
/// index first, for both groups and experts.
///
/// Returns `(probs, indices)`, both laid out as `[num_tokens][min(top_k, num_experts)]`.
/// `probs` are the picked raw scores divided by their sum when `top_k > 1` (a `1e-20`
/// epsilon guards the division). With `top_k == 1` the raw score is returned unscaled.
///
/// Edge cases:
/// * NaN scores rank like `-inf`, so they are never preferred over a real score.
/// * `n_group == 0` is treated as a single group containing every expert.
/// * Experts outside the selected groups are masked to `-inf`. This includes trailing
///   experts that belong to no group when `num_experts % n_group != 0`. Masked experts are
///   only picked when fewer than `top_k` eligible experts score above `-inf`.
///
/// # Panics
///
/// Panics if `scores.len() < num_tokens * num_experts`.
pub fn group_limited_topk(
    scores: &[f32],
    num_tokens: usize,
    num_experts: usize,
    n_group: usize,
    topk_group: usize,
    top_k: usize,
) -> (Vec<f32>, Vec<u32>) {
    use std::cmp::Ordering;

    /// Descending-order comparator that ranks NaN as `-inf`.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` makes NaN "equal" to every value, which is not a
    /// total order. `sort_by` may then panic (Rust >= 1.81) or produce an arbitrary order.
    fn by_score_desc(a: f32, b: f32) -> Ordering {
        let demote_nan = |x: f32| if x.is_nan() { f32::NEG_INFINITY } else { x };
        // Neither operand is NaN after demotion, so `partial_cmp` always returns `Some`.
        demote_nan(b)
            .partial_cmp(&demote_nan(a))
            .unwrap_or(Ordering::Equal)
    }

    let needed = num_tokens * num_experts;
    assert!(
        scores.len() >= needed,
        "group_limited_topk: need {} scores ({} tokens x {} experts), got {}",
        needed,
        num_tokens,
        num_experts,
        scores.len()
    );

    // A zero group count would divide by zero; treat it as "no grouping".
    let n_group = n_group.max(1);
    let epg = num_experts / n_group;
    let picks_per_row = top_k.min(num_experts);

    let mut probs = Vec::with_capacity(num_tokens * picks_per_row);
    let mut indices = Vec::with_capacity(num_tokens * picks_per_row);

    // Scratch buffers reused across rows instead of reallocated per token.
    let mut group_scores = vec![f32::NEG_INFINITY; n_group];
    let mut group_order: Vec<usize> = Vec::with_capacity(n_group);
    let mut masked = vec![f32::NEG_INFINITY; num_experts];
    let mut order: Vec<usize> = Vec::with_capacity(num_experts);

    for t in 0..num_tokens {
        let row = &scores[t * num_experts..(t + 1) * num_experts];

        // 1. Score each group by the sum of its two best experts.
        for (g, group_score) in group_scores.iter_mut().enumerate() {
            let mut top2 = [f32::NEG_INFINITY; 2];
            for &v in &row[g * epg..(g + 1) * epg] {
                if v > top2[0] {
                    top2[1] = top2[0];
                    top2[0] = v;
                } else if v > top2[1] {
                    top2[1] = v;
                }
            }
            // With one expert per group top2[1] stays -inf. Adding it would make every
            // group score -inf and reduce group selection to "lowest group indices".
            *group_score = if epg >= 2 { top2[0] + top2[1] } else { top2[0] };
        }

        // 2. Keep only the `topk_group` best groups; everything else stays masked at -inf.
        group_order.clear();
        group_order.extend(0..n_group);
        group_order.sort_by(|&a, &b| by_score_desc(group_scores[a], group_scores[b]));
        masked.fill(f32::NEG_INFINITY);
        for &g in group_order.iter().take(topk_group) {
            let span = g * epg..(g + 1) * epg;
            masked[span.clone()].copy_from_slice(&row[span]);
        }

        // 3. Pick the best eligible experts (stable sort: ties keep index order).
        order.clear();
        order.extend(0..num_experts);
        order.sort_by(|&a, &b| by_score_desc(masked[a], masked[b]));
        let picked = &order[..picks_per_row];

        // 4. Renormalise the picked raw scores so they sum to 1 when top_k > 1.
        let sum: f32 = picked.iter().map(|&ei| row[ei]).sum::<f32>() + 1e-20;
        let scale = if top_k > 1 { 1.0 / sum } else { 1.0 };
        for &ei in picked {
            probs.push(row[ei] * scale);
            indices.push(ei as u32);
        }
    }
    (probs, indices)
}
80
/// Decoded attributes of the gate op (produced by `GateAttrs::from_bytes`).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GateAttrs {
    /// Number of contiguous expert groups a token's experts are split into.
    pub n_group: u32,
    /// Number of best-scoring groups a token may route to.
    pub topk_group: u32,
    /// Number of experts selected per token.
    pub top_k: u32,
    /// Factor multiplied onto every normalised expert weight.
    pub routed_scaling: f32,
    /// Number of experts, i.e. the length of one score row.
    pub num_experts: u32,
}
89
90impl GateAttrs {
91 pub fn from_bytes(attrs: &[u8]) -> Self {
92 if attrs.len() >= 20 {
93 let n_group = u32::from_le_bytes(attrs[0..4].try_into().unwrap());
94 let topk_group = u32::from_le_bytes(attrs[4..8].try_into().unwrap());
95 let top_k = u32::from_le_bytes(attrs[8..12].try_into().unwrap());
96 let routed_scaling = f32::from_le_bytes(attrs[12..16].try_into().unwrap());
97 let num_experts = u32::from_le_bytes(attrs[16..20].try_into().unwrap());
98 GateAttrs {
99 n_group,
100 topk_group,
101 top_k,
102 routed_scaling,
103 num_experts,
104 }
105 } else {
106 GateAttrs {
107 n_group: 8,
108 topk_group: 4,
109 top_k: 8,
110 routed_scaling: 2.5,
111 num_experts: 256,
112 }
113 }
114 }
115}
116
117pub fn execute_gate_in_f32_arena(
119 host: &mut [f32],
120 sig_f32_off: usize,
121 route_f32_off: usize,
122 out_f32_off: usize,
123 n_elems: usize,
124 attrs: &[u8],
125) -> Result<(), String> {
126 let a = GateAttrs::from_bytes(attrs);
127 let e = a.num_experts as usize;
128 let k = a.top_k as usize;
129 let rows = n_elems / e.max(1);
130 let out_end = out_f32_off + rows * k * 2;
131 let sig = host[sig_f32_off..sig_f32_off + n_elems].to_vec();
132 let route = host[route_f32_off..route_f32_off + n_elems].to_vec();
133 let out = &mut host[out_f32_off..out_end];
134 execute_gate_f32(&sig, &route, out, attrs)
135}
136
137pub fn execute_gate_f32(
139 scores_sigmoid: &[f32],
140 scores_route: &[f32],
141 out: &mut [f32],
142 attrs: &[u8],
143) -> Result<(), String> {
144 let a = GateAttrs::from_bytes(attrs);
145 let rows = scores_sigmoid.len() / a.num_experts as usize;
146 let e = a.num_experts as usize;
147 let k = a.top_k as usize;
148 if scores_route.len() != scores_sigmoid.len() {
149 return Err("gate: sigmoid and routing score lengths differ".into());
150 }
151 if out.len() != rows * k * 2 {
152 return Err(format!("output len {} != rows*k*2", out.len()));
153 }
154 let (_, idx) = group_limited_topk(
155 scores_route,
156 rows,
157 e,
158 a.n_group as usize,
159 a.topk_group as usize,
160 k,
161 );
162 for t in 0..rows {
163 let row_sig = &scores_sigmoid[t * e..(t + 1) * e];
164 let mut picked = Vec::with_capacity(k);
165 for ki in 0..k {
166 let ei = idx[t * k + ki] as usize;
167 picked.push(row_sig[ei]);
168 }
169 let sum: f32 = picked.iter().sum::<f32>() + 1e-20;
170 let norm = if k > 1 { 1.0 / sum } else { 1.0 };
171 for ki in 0..k {
172 out[t * k * 2 + ki] = idx[t * k + ki] as f32;
173 out[t * k * 2 + k + ki] = picked[ki] * norm * a.routed_scaling;
174 }
175 }
176 Ok(())
177}