1use ghostflow_core::Tensor;
12use std::collections::HashMap;
13
/// Strategy used to assign tokens to experts.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` (a fieldless enum
/// has a total equality, and hashability lets the strategy key a map).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoutingStrategy {
    /// Each token is sent to its `top_k` highest-probability experts.
    TopK,
    /// Switch Transformer routing: each token goes to exactly one expert.
    Switch,
    /// Experts select tokens (capacity-driven) instead of tokens selecting experts.
    ExpertChoice,
}
24
/// Configuration for a mixture-of-experts layer.
#[derive(Debug, Clone)]
pub struct MoEConfig {
    /// Total number of expert feed-forward networks.
    pub num_experts: usize,
    /// Number of experts each token is routed to per forward pass.
    pub top_k: usize,
    /// Multiplier on the even-split per-expert token budget.
    /// NOTE(review): not read anywhere in this file — presumably enforced by a
    /// capacity-limited dispatcher elsewhere; confirm before relying on it.
    pub capacity_factor: f32,
    /// How tokens are assigned to experts (see [`RoutingStrategy`]).
    pub routing_strategy: RoutingStrategy,
    /// Scale applied to the load-balance auxiliary loss.
    pub load_balance_loss_weight: f32,
    /// Dropout applied to experts.
    /// NOTE(review): not used in this file — verify where it is consumed.
    pub expert_dropout: f32,
    /// Whether experts run in parallel.
    /// NOTE(review): not used in this file — verify where it is consumed.
    pub expert_parallel: bool,
}
43
44impl Default for MoEConfig {
45 fn default() -> Self {
46 MoEConfig {
47 num_experts: 8,
48 top_k: 2,
49 capacity_factor: 1.25,
50 routing_strategy: RoutingStrategy::TopK,
51 load_balance_loss_weight: 0.01,
52 expert_dropout: 0.0,
53 expert_parallel: false,
54 }
55 }
56}
57
58impl MoEConfig {
59 pub fn switch_transformer(num_experts: usize) -> Self {
61 MoEConfig {
62 num_experts,
63 top_k: 1,
64 routing_strategy: RoutingStrategy::Switch,
65 capacity_factor: 1.0,
66 ..Default::default()
67 }
68 }
69
70 pub fn gshard(num_experts: usize) -> Self {
72 MoEConfig {
73 num_experts,
74 top_k: 2,
75 routing_strategy: RoutingStrategy::TopK,
76 capacity_factor: 1.25,
77 ..Default::default()
78 }
79 }
80
81 pub fn expert_choice(num_experts: usize, capacity_factor: f32) -> Self {
83 MoEConfig {
84 num_experts,
85 top_k: 1,
86 routing_strategy: RoutingStrategy::ExpertChoice,
87 capacity_factor,
88 ..Default::default()
89 }
90 }
91}
92
/// A single feed-forward expert: two linear projections with an activation
/// in between (see [`Expert::forward`]).
pub struct Expert {
    // Index of this expert within the MoE layer.
    id: usize,
    // Input/output feature width.
    d_model: usize,
    // Hidden (feed-forward) width.
    d_ff: usize,
    // First projection, shape `[d_model, d_ff]`.
    w1: Tensor,
    // Second projection, shape `[d_ff, d_model]`.
    w2: Tensor,
}
105
106impl Expert {
107 pub fn new(id: usize, d_model: usize, d_ff: usize) -> Result<Self, String> {
109 let w1 = Tensor::randn(&[d_model, d_ff]);
110 let w2 = Tensor::randn(&[d_ff, d_model]);
111
112 Ok(Expert {
113 id,
114 d_model,
115 d_ff,
116 w1,
117 w2,
118 })
119 }
120
121 pub fn forward(&self, input: &Tensor) -> Result<Tensor, String> {
123 let hidden = input.matmul(&self.w1)
125 .map_err(|e| format!("Failed to compute W1: {:?}", e))?;
126 let activated = hidden.gelu();
127 activated.matmul(&self.w2)
128 .map_err(|e| format!("Failed to compute W2: {:?}", e))
129 }
130}
131
/// Learned gating network that maps a token embedding to a probability
/// distribution over experts.
pub struct Router {
    // Gating projection, shape `[d_model, num_experts]`.
    weights: Tensor,
    // Number of experts this router scores.
    num_experts: usize,
}
139
140impl Router {
141 pub fn new(d_model: usize, num_experts: usize) -> Result<Self, String> {
143 let weights = Tensor::randn(&[d_model, num_experts]);
144
145 Ok(Router {
146 weights,
147 num_experts,
148 })
149 }
150
151 pub fn route(&self, input: &Tensor) -> Result<Tensor, String> {
153 let logits = input.matmul(&self.weights)
155 .map_err(|e| format!("Failed to compute routing logits: {:?}", e))?;
156
157 Ok(logits.softmax(-1))
159 }
160
161 pub fn select_top_k(&self, probs: &Tensor, k: usize) -> Result<(Vec<usize>, Vec<f32>), String> {
163 let data = probs.data_f32();
164
165 let mut indexed: Vec<(usize, f32)> = data.iter()
167 .enumerate()
168 .map(|(i, &v)| (i, v))
169 .collect();
170
171 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
172
173 let top_k_indices: Vec<usize> = indexed.iter().take(k).map(|(i, _)| *i).collect();
174 let top_k_values: Vec<f32> = indexed.iter().take(k).map(|(_, v)| *v).collect();
175
176 Ok((top_k_indices, top_k_values))
177 }
178}
179
/// A full mixture-of-experts layer: a router plus `num_experts` expert FFNs,
/// with bookkeeping for load balancing.
pub struct MixtureOfExperts {
    config: MoEConfig,
    experts: Vec<Expert>,
    router: Router,
    // Cumulative count of tokens dispatched to each expert; reset via
    // `reset_usage_stats`.
    expert_usage: Vec<usize>,
    // Load-balance auxiliary loss from the most recent `forward` call.
    aux_loss: f32,
}
190
impl MixtureOfExperts {
    /// Builds the layer: `config.num_experts` independent experts plus a
    /// router, with usage counters zeroed and no auxiliary loss yet.
    pub fn new(config: MoEConfig, d_model: usize, d_ff: usize) -> Result<Self, String> {
        let mut experts = Vec::new();
        for i in 0..config.num_experts {
            experts.push(Expert::new(i, d_model, d_ff)?);
        }

        let router = Router::new(d_model, config.num_experts)?;
        let expert_usage = vec![0; config.num_experts];

        Ok(MixtureOfExperts {
            config,
            experts,
            router,
            expert_usage,
            aux_loss: 0.0,
        })
    }

    /// Runs the layer over a `[seq_len, d_model]` input, one token at a time.
    ///
    /// For each token: compute routing probabilities, select the `top_k`
    /// experts, and accumulate their outputs weighted by routing probability
    /// (renormalised so the selected weights sum to 1). Side effects: bumps
    /// `expert_usage` per dispatch and recomputes `aux_loss` from this call's
    /// routing distribution. Returns a `[seq_len, d_model]` tensor.
    ///
    /// # Errors
    /// Fails if the input is not 2-D or if any routing / expert matmul fails.
    pub fn forward(&mut self, input: &Tensor) -> Result<Tensor, String> {
        let dims = input.dims();

        if dims.len() != 2 {
            return Err("Expected 2D tensor [seq_len, d_model]".to_string());
        }

        let seq_len = dims[0];
        let d_model = dims[1];

        let mut outputs = Vec::new();
        // One probability tensor per token, kept so the load-balance loss can
        // be computed over the whole sequence after the loop.
        let mut routing_probs = Vec::new();

        for i in 0..seq_len {
            let token = self.extract_token(input, i)?;

            let probs = self.router.route(&token)?;
            routing_probs.push(probs.clone());

            let (expert_ids, expert_weights) = self.router.select_top_k(&probs, self.config.top_k)?;

            let mut token_output = vec![0.0f32; d_model];
            let mut weight_sum = 0.0;

            for (expert_id, weight) in expert_ids.iter().zip(expert_weights.iter()) {
                // Guard against a routing index outside the expert table;
                // out-of-range ids are silently skipped.
                if *expert_id < self.experts.len() {
                    let expert = &self.experts[*expert_id];
                    let expert_output = expert.forward(&token)?;

                    let expert_data = expert_output.data_f32();
                    for j in 0..d_model {
                        token_output[j] += weight * expert_data[j];
                    }

                    weight_sum += weight;
                    self.expert_usage[*expert_id] += 1;
                }
            }

            // The top-k probabilities do not sum to 1; dividing by their sum
            // makes the expert mixture a convex combination.
            if weight_sum > 0.0 {
                for val in &mut token_output {
                    *val /= weight_sum;
                }
            }

            outputs.push(token_output);
        }

        self.aux_loss = self.compute_load_balance_loss(&routing_probs)?;

        let flattened: Vec<f32> = outputs.into_iter().flatten().collect();

        Tensor::from_slice(&flattened, &[seq_len, d_model])
            .map_err(|e| format!("Failed to create output tensor: {:?}", e))
    }

    /// Copies row `token_idx` of a `[seq_len, d_model]` tensor into a new
    /// `[1, d_model]` tensor.
    /// NOTE(review): assumes `data_f32` is row-major — confirm in ghostflow_core.
    fn extract_token(&self, input: &Tensor, token_idx: usize) -> Result<Tensor, String> {
        let data = input.data_f32();
        let d_model = input.dims()[1];

        let start = token_idx * d_model;
        let end = start + d_model;

        Tensor::from_slice(&data[start..end], &[1, d_model])
            .map_err(|e| format!("Failed to extract token: {:?}", e))
    }

    /// Auxiliary load-balance loss: the coefficient of variation (std / mean)
    /// of each expert's mean routing probability, scaled by
    /// `load_balance_loss_weight`. Zero when routing is perfectly uniform or
    /// when there were no tokens.
    fn compute_load_balance_loss(&self, routing_probs: &[Tensor]) -> Result<f32, String> {
        if routing_probs.is_empty() {
            return Ok(0.0);
        }

        let num_tokens = routing_probs.len() as f32;
        let num_experts = self.config.num_experts as f32;

        // Sum of routing probability mass received by each expert.
        let mut expert_fractions = vec![0.0f32; self.config.num_experts];

        for probs in routing_probs {
            let data = probs.data_f32();
            for (i, &prob) in data.iter().enumerate() {
                if i < expert_fractions.len() {
                    expert_fractions[i] += prob;
                }
            }
        }

        // Average over tokens; each entry is now in [0, 1].
        for frac in &mut expert_fractions {
            *frac /= num_tokens;
        }

        // With a softmax distribution per token, the fractions sum to 1, so
        // the uniform target per expert is 1 / num_experts.
        let mean = 1.0 / num_experts;
        let variance: f32 = expert_fractions.iter()
            .map(|&f| (f - mean).powi(2))
            .sum::<f32>() / num_experts;

        let cv = variance.sqrt() / mean;

        Ok(cv * self.config.load_balance_loss_weight)
    }

    /// Load-balance loss from the most recent `forward` call.
    pub fn get_aux_loss(&self) -> f32 {
        self.aux_loss
    }

    /// Per-expert dispatch counts accumulated since the last reset.
    pub fn get_expert_usage(&self) -> &[usize] {
        &self.expert_usage
    }

    /// Zeroes the per-expert usage counters.
    pub fn reset_usage_stats(&mut self) {
        self.expert_usage.fill(0);
    }

    /// Balance score in (0, 1]: `1 / (1 + cv)` over the usage counts, where
    /// `cv` is the coefficient of variation. Returns 1.0 for perfectly even
    /// usage or when there is no usage data yet.
    pub fn load_balance_factor(&self) -> f32 {
        if self.expert_usage.is_empty() {
            return 1.0;
        }

        let total: usize = self.expert_usage.iter().sum();
        if total == 0 {
            return 1.0;
        }

        let mean = total as f32 / self.expert_usage.len() as f32;
        let variance: f32 = self.expert_usage.iter()
            .map(|&u| (u as f32 - mean).powi(2))
            .sum::<f32>() / self.expert_usage.len() as f32;

        let std_dev = variance.sqrt();
        let cv = std_dev / mean;

        1.0 / (1.0 + cv)
    }

    /// Snapshot of the layer's configuration and balance statistics.
    pub fn get_stats(&self) -> MoEStats {
        MoEStats {
            num_experts: self.config.num_experts,
            top_k: self.config.top_k,
            routing_strategy: self.config.routing_strategy,
            aux_loss: self.aux_loss,
            load_balance_factor: self.load_balance_factor(),
            expert_usage: self.expert_usage.clone(),
        }
    }
}
375
/// Point-in-time statistics snapshot produced by `MixtureOfExperts::get_stats`.
#[derive(Debug, Clone)]
pub struct MoEStats {
    /// Number of experts in the layer.
    pub num_experts: usize,
    /// Experts selected per token.
    pub top_k: usize,
    /// Routing strategy the layer was configured with.
    pub routing_strategy: RoutingStrategy,
    /// Load-balance loss from the most recent forward pass.
    pub aux_loss: f32,
    /// `1 / (1 + cv)` balance score over the usage counts.
    pub load_balance_factor: f32,
    /// Per-expert dispatch counts at snapshot time.
    pub expert_usage: Vec<usize>,
}
386
#[cfg(test)]
mod tests {
    //! Unit tests. Several tests drive `Tensor::randn` inputs, so they check
    //! shapes, ranges, and invariants rather than exact values.
    use super::*;

    // Default and Switch Transformer presets carry the expected settings.
    #[test]
    fn test_moe_config() {
        let config = MoEConfig::default();
        assert_eq!(config.num_experts, 8);
        assert_eq!(config.top_k, 2);

        let switch = MoEConfig::switch_transformer(16);
        assert_eq!(switch.num_experts, 16);
        assert_eq!(switch.top_k, 1);
        assert_eq!(switch.routing_strategy, RoutingStrategy::Switch);
    }

    // Constructor stores id and dimensions as given.
    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(0, 512, 2048).unwrap();
        assert_eq!(expert.id, 0);
        assert_eq!(expert.d_model, 512);
        assert_eq!(expert.d_ff, 2048);
    }

    // Expert FFN preserves the [1, d_model] shape.
    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(0, 64, 256).unwrap();
        let input = Tensor::randn(&[1, 64]);

        let output = expert.forward(&input).unwrap();
        assert_eq!(output.dims(), &[1, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = Router::new(512, 8).unwrap();
        assert_eq!(router.num_experts, 8);
    }

    // Routing output is one probability per expert, summing to ~1 (softmax).
    #[test]
    fn test_router_route() {
        let router = Router::new(64, 8).unwrap();
        let input = Tensor::randn(&[1, 64]);

        let probs = router.route(&input).unwrap();
        assert_eq!(probs.dims()[1], 8);

        let data = probs.data_f32();
        let sum: f32 = data.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }

    // Top-k selection returns k entries in descending probability order.
    #[test]
    fn test_router_top_k() {
        let router = Router::new(64, 8).unwrap();
        let probs = Tensor::from_slice(&[0.1f32, 0.3, 0.05, 0.25, 0.15, 0.05, 0.05, 0.05], &[1, 8]).unwrap();

        let (indices, values) = router.select_top_k(&probs, 2).unwrap();

        assert_eq!(indices.len(), 2);
        assert_eq!(values.len(), 2);
        assert!(values[0] >= values[1]);
    }

    #[test]
    fn test_moe_creation() {
        let config = MoEConfig::default();
        let moe = MixtureOfExperts::new(config, 128, 512).unwrap();

        assert_eq!(moe.experts.len(), 8);
    }

    // Forward pass preserves the [seq_len, d_model] shape.
    #[test]
    fn test_moe_forward() {
        let config = MoEConfig {
            num_experts: 4,
            top_k: 2,
            ..Default::default()
        };
        let mut moe = MixtureOfExperts::new(config, 64, 256).unwrap();

        let input = Tensor::randn(&[8, 64]);
        let output = moe.forward(&input).unwrap();

        assert_eq!(output.dims(), &[8, 64]);
    }

    // Balance factor is defined as 1/(1+cv), hence always in (0, 1].
    #[test]
    fn test_load_balance_factor() {
        let config = MoEConfig::default();
        let mut moe = MixtureOfExperts::new(config, 64, 256).unwrap();

        let input = Tensor::randn(&[16, 64]);
        moe.forward(&input).unwrap();

        let balance = moe.load_balance_factor();
        assert!(balance > 0.0);
        assert!(balance <= 1.0);
    }

    // Aux loss is a scaled coefficient of variation, so it is non-negative.
    #[test]
    fn test_aux_loss() {
        let config = MoEConfig::default();
        let mut moe = MixtureOfExperts::new(config, 64, 256).unwrap();

        let input = Tensor::randn(&[8, 64]);
        moe.forward(&input).unwrap();

        let aux_loss = moe.get_aux_loss();
        assert!(aux_loss >= 0.0);
    }

    // Usage counters accumulate during forward and clear on reset.
    #[test]
    fn test_expert_usage_stats() {
        let config = MoEConfig::default();
        let mut moe = MixtureOfExperts::new(config, 64, 256).unwrap();

        let input = Tensor::randn(&[16, 64]);
        moe.forward(&input).unwrap();

        let usage = moe.get_expert_usage();
        let total: usize = usage.iter().sum();
        assert!(total > 0);

        moe.reset_usage_stats();
        let usage_after = moe.get_expert_usage();
        assert_eq!(usage_after.iter().sum::<usize>(), 0);
    }

    #[test]
    fn test_gshard_config() {
        let config = MoEConfig::gshard(16);
        assert_eq!(config.num_experts, 16);
        assert_eq!(config.top_k, 2);
        assert_eq!(config.routing_strategy, RoutingStrategy::TopK);
    }
}