//! Mixture-of-Experts (MoE) components: a SwiGLU feed-forward `Expert`,
//! a top-k softmax `MoERouter`, and an `MoELayer` that dispatches tokens
//! to experts, combines their weighted outputs, and tracks routing
//! statistics for an auxiliary load-balancing loss.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
/// A single feed-forward expert: a SwiGLU MLP that expands to
/// `intermediate_size` and projects back to `d_model`.
pub struct Expert {
    /// Up projection: `d_model -> intermediate_size`.
    up_proj: Linear,
    /// Gating projection: `d_model -> intermediate_size`.
    gate_proj: Linear,
    /// Down projection: `intermediate_size -> d_model`.
    down_proj: Linear,
}

impl Expert {
    /// Creates an expert with bias-free linear projections.
    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
        Self {
            up_proj: Linear::with_bias(d_model, intermediate_size, false),
            gate_proj: Linear::with_bias(d_model, intermediate_size, false),
            down_proj: Linear::with_bias(intermediate_size, d_model, false),
        }
    }
}

impl Module for Expert {
    fn forward(&self, input: &Variable) -> Variable {
        // SwiGLU: gate the up-projection with silu(gate_proj(x)).
        let gate = self.gate_proj.forward(input).silu();
        let up = self.up_proj.forward(input);
        let hidden = gate.mul_var(&up);
        self.down_proj.forward(&hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.up_proj.parameters());
        params.extend(self.gate_proj.parameters());
        params.extend(self.down_proj.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.up_proj.named_parameters() {
            params.insert(format!("up_proj.{name}"), param);
        }
        for (name, param) in self.gate_proj.named_parameters() {
            params.insert(format!("gate_proj.{name}"), param);
        }
        for (name, param) in self.down_proj.named_parameters() {
            params.insert(format!("down_proj.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Expert"
    }
}

impl std::fmt::Debug for Expert {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Expert")
            .field("up_proj", &self.up_proj)
            .field("gate_proj", &self.gate_proj)
            .field("down_proj", &self.down_proj)
            .finish()
    }
}
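
// For intuition, what one `Expert` computes for a token vector `x` of width
// `d_model` (this restates `forward` above, nothing new):
//
//     y = down_proj( silu(gate_proj(x)) * up_proj(x) )
//
// where `*` is element-wise multiplication (`mul_var`) and all three
// projections are the bias-free `Linear` layers created in `Expert::new`.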

/// Softmax router that assigns each token to its `top_k` highest-probability
/// experts.
pub struct MoERouter {
    /// Linear gate producing one logit per expert.
    gate: Linear,
    /// Total number of experts.
    num_experts: usize,
    /// Number of experts selected per token.
    top_k: usize,
}

impl MoERouter {
    /// Creates a router. Panics if `top_k` exceeds `num_experts`.
    pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
        assert!(
            top_k <= num_experts,
            "top_k ({top_k}) must be <= num_experts ({num_experts})"
        );
        Self {
            gate: Linear::with_bias(d_model, num_experts, false),
            num_experts,
            top_k,
        }
    }

    /// Routes a `[num_tokens, d_model]` input. Returns the full softmax gate
    /// probabilities plus, per token, the renormalized top-k weights and the
    /// matching expert indices.
    pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
        let gate_logits = self.gate.forward(x);
        let gate_probs = gate_logits.softmax(-1);

        let probs_data = gate_probs.data();
        let probs_vec = probs_data.to_vec();
        let num_tokens = probs_data.shape()[0];

        let mut top_k_weights = Vec::with_capacity(num_tokens);
        let mut top_k_indices = Vec::with_capacity(num_tokens);

        for t in 0..num_tokens {
            let offset = t * self.num_experts;
            let token_probs = &probs_vec[offset..offset + self.num_experts];

            // Sort expert indices by probability, highest first. `total_cmp`
            // avoids the panic that `partial_cmp(..).unwrap()` would hit on NaN.
            let mut indexed: Vec<(usize, f32)> = token_probs
                .iter()
                .enumerate()
                .map(|(i, &p)| (i, p))
                .collect();
            indexed.sort_by(|a, b| b.1.total_cmp(&a.1));

            let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
            let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();

            // Renormalize the selected weights so they sum to 1, falling back
            // to a uniform split if the selected mass is zero.
            let weight_sum: f32 = top_weights.iter().sum();
            let normalized: Vec<f32> = if weight_sum > 0.0 {
                top_weights.iter().map(|w| w / weight_sum).collect()
            } else {
                vec![1.0 / self.top_k as f32; self.top_k]
            };

            top_k_weights.push(normalized);
            top_k_indices.push(top_indices);
        }

        (gate_probs, top_k_weights, top_k_indices)
    }

    /// Returns the total number of experts.
    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    /// Returns the number of experts selected per token.
    pub fn top_k(&self) -> usize {
        self.top_k
    }
}

impl std::fmt::Debug for MoERouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MoERouter")
            .field("num_experts", &self.num_experts)
            .field("top_k", &self.top_k)
            .finish()
    }
}
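
// A worked routing example, purely illustrative (the probabilities are made
// up): with 4 experts, top_k = 2, and one token whose softmax gate
// probabilities are [0.1, 0.5, 0.1, 0.3], `route` selects experts 1 and 3
// with raw weights [0.5, 0.3], renormalized to [0.625, 0.375]. The MoE layer
// then computes 0.625 * expert_1(x) + 0.375 * expert_3(x) for that token.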

/// A Mixture-of-Experts feed-forward layer: routes each token to its top-k
/// experts, runs the selected experts, and sums their outputs weighted by the
/// renormalized gate probabilities. Routing statistics from the most recent
/// forward pass are cached for the auxiliary load-balancing loss.
pub struct MoELayer {
    /// The pool of feed-forward experts.
    experts: Vec<Expert>,
    /// Router selecting experts per token.
    router: MoERouter,
    /// Model (input/output) dimension.
    d_model: usize,
    /// Total number of experts.
    num_experts: usize,
    /// Number of experts used per token.
    top_k: usize,
    /// Gate probabilities from the last forward pass.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Per-expert token counts from the last forward pass.
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}

impl MoELayer {
    /// Creates a layer with `num_experts` experts of the given sizes and a
    /// top-k router.
    pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
        let experts: Vec<Expert> = (0..num_experts)
            .map(|_| Expert::new(d_model, intermediate_size))
            .collect();
        let router = MoERouter::new(d_model, num_experts, top_k);

        Self {
            experts,
            router,
            d_model,
            num_experts,
            top_k,
            last_gate_probs: std::sync::RwLock::new(None),
            last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
        }
    }

    /// Auxiliary load-balancing loss in the style of the Switch Transformer:
    /// `num_experts * sum_e(f_e * P_e)`, where `f_e` is the number of token
    /// assignments sent to expert `e` divided by the token count, and `P_e`
    /// is the mean gate probability for expert `e`. Returns a constant zero
    /// before the first forward pass.
    pub fn load_balancing_loss(&self) -> Variable {
        let gate_probs_opt = self.last_gate_probs.read().unwrap();
        if gate_probs_opt.is_none() {
            return Variable::new(
                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
                false,
            );
        }

        let gate_probs = gate_probs_opt.as_ref().unwrap();
        let probs_data = gate_probs.data();
        let probs_vec = probs_data.to_vec();
        let shape = probs_data.shape();
        let num_tokens = shape[0];
        let num_experts = shape[1];

        if num_tokens == 0 {
            return Variable::new(
                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
                false,
            );
        }

        let expert_counts = self.last_expert_counts.read().unwrap();

        // f_e: per-expert assignment count divided by the token count.
        let token_fractions: Vec<f32> = expert_counts
            .iter()
            .map(|&c| c as f32 / num_tokens as f32)
            .collect();

        // P_e: mean gate probability per expert across tokens.
        let mut mean_probs = vec![0.0f32; num_experts];
        for t in 0..num_tokens {
            for e in 0..num_experts {
                mean_probs[e] += probs_vec[t * num_experts + e];
            }
        }
        for p in &mut mean_probs {
            *p /= num_tokens as f32;
        }

        // loss = num_experts * sum_e(f_e * P_e).
        let mut loss_val = 0.0f32;
        for e in 0..num_experts {
            loss_val += token_fractions[e] * mean_probs[e];
        }
        loss_val *= num_experts as f32;

        Variable::new(
            Tensor::from_vec(vec![loss_val], &[1]).expect("tensor creation failed"),
            false,
        )
    }

    /// Per-expert token counts from the most recent forward pass.
    pub fn expert_utilization(&self) -> Vec<usize> {
        self.last_expert_counts.read().unwrap().clone()
    }

    /// Returns the total number of experts.
    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    /// Returns the number of experts used per token.
    pub fn top_k(&self) -> usize {
        self.top_k
    }
}
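
// A worked example of the loss, with made-up numbers: 4 experts, 10 tokens,
// top_k = 2. If the assignment counts are [10, 10, 0, 0] (two experts take
// everything) and the mean gate probabilities are [0.5, 0.5, 0.0, 0.0], the
// loss is 4 * (1.0 * 0.5 + 1.0 * 0.5) = 4.0. A perfectly uniform router
// instead gives counts [5, 5, 5, 5] and mean probabilities of 0.25 each,
// for a loss of 4 * (4 * 0.5 * 0.25) = 2.0 -- balanced routing is rewarded
// with a lower value.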

impl Module for MoELayer {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let d_model = shape[2];
        let num_tokens = batch_size * seq_len;

        // Flatten [batch, seq, d_model] to [tokens, d_model] for routing.
        let flat_input = input.reshape(&[num_tokens, d_model]);

        let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);

        // Cache routing statistics for load_balancing_loss / expert_utilization.
        let mut expert_counts = vec![0usize; self.num_experts];
        for indices in &top_k_indices {
            for &idx in indices {
                expert_counts[idx] += 1;
            }
        }
        *self.last_expert_counts.write().unwrap() = expert_counts;
        *self.last_gate_probs.write().unwrap() = Some(gate_probs);

        let flat_data = flat_input.data();
        let flat_vec = flat_data.to_vec();
        let mut output_data = vec![0.0f32; num_tokens * d_model];

        // Dispatch: run each expert once over the batch of tokens routed to it.
        for expert_idx in 0..self.num_experts {
            // Collect the tokens (and their gate weights) assigned to this expert.
            let mut token_indices = Vec::new();
            let mut token_weights = Vec::new();

            for (t, (indices, weights)) in
                top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
            {
                for (&idx, &w) in indices.iter().zip(weights.iter()) {
                    if idx == expert_idx {
                        token_indices.push(t);
                        token_weights.push(w);
                    }
                }
            }

            if token_indices.is_empty() {
                continue;
            }

            // Gather the assigned token rows into a contiguous expert batch.
            let n = token_indices.len();
            let mut expert_input_data = Vec::with_capacity(n * d_model);
            for &t in &token_indices {
                let offset = t * d_model;
                expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
            }
            let expert_input = Variable::new(
                Tensor::from_vec(expert_input_data, &[n, d_model]).expect("tensor creation failed"),
                true,
            );

            let expert_output = self.experts[expert_idx].forward(&expert_input);
            let expert_out_vec = expert_output.data().to_vec();

            // Combine: scatter the weighted expert outputs back to their
            // original token positions.
            for (local_idx, &global_idx) in token_indices.iter().enumerate() {
                let weight = token_weights[local_idx];
                let src_offset = local_idx * d_model;
                let dst_offset = global_idx * d_model;
                for d in 0..d_model {
                    output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[num_tokens, d_model]).expect("tensor creation failed");
        let output = Variable::new(output_tensor, true);

        output.reshape(&[batch_size, seq_len, d_model])
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.router.gate.parameters());
        for expert in &self.experts {
            params.extend(expert.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.router.gate.named_parameters() {
            params.insert(format!("router.gate.{name}"), param);
        }
        for (i, expert) in self.experts.iter().enumerate() {
            for (name, param) in expert.named_parameters() {
                params.insert(format!("experts.{i}.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "MoELayer"
    }
}
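
// Putting it together, a minimal usage sketch. Shapes and constructor calls
// mirror the tests below:
//
//     let moe = MoELayer::new(64, 256, 8, 2);
//     let x = Variable::new(
//         Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
//         false,
//     );
//     let y = moe.forward(&x);              // [2, 5, 64]
//     let aux = moe.load_balancing_loss();  // scalar; typically added to the
//                                           // task loss with a small coefficient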

impl std::fmt::Debug for MoELayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MoELayer")
            .field("d_model", &self.d_model)
            .field("num_experts", &self.num_experts)
            .field("top_k", &self.top_k)
            .field("experts", &self.experts.len())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(64, 256);
        let params = expert.parameters();
        // Three bias-free projections: up, gate, and down.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).expect("tensor creation failed"),
            false,
        );
        let output = expert.forward(&input);
        assert_eq!(output.shape(), vec![4, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).expect("tensor creation failed"),
            false,
        );
        let (_gate_probs, weights, indices) = router.route(&input);

        assert_eq!(weights.len(), 4);
        assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2);
            let sum: f32 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for idx in &indices {
            assert_eq!(idx.len(), 2);
            for &i in idx {
                assert!(i < 8, "Expert index should be < num_experts");
            }
        }
    }

    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let output = moe.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_moe_layer_parameters() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let params = moe.parameters();
        // 8 experts * 3 projections + 1 router gate = 25 weight parameters.
        assert_eq!(params.len(), 25);
    }

    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let _output = moe.forward(&input);

        let lb_loss = moe.load_balancing_loss();
        let loss_val = lb_loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let _output = moe.forward(&input);

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        // Every token is assigned to exactly top_k experts: 2 * 5 * 2 = 20.
        let total: usize = util.iter().sum();
        assert_eq!(total, 20);
    }

    #[test]
    fn test_moe_named_parameters() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let named = moe.named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
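
    // An extra check, sketched here from the fallback visible in
    // `load_balancing_loss`: before any forward pass there are no cached
    // gate probabilities, so the auxiliary loss should be the constant zero.
    #[test]
    fn test_moe_load_balancing_loss_before_forward() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let loss = moe.load_balancing_loss();
        assert_eq!(loss.data().to_vec()[0], 0.0);
    }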
}