1use std::collections::HashMap;
28
29use axonml_autograd::Variable;
30use axonml_tensor::Tensor;
31
32use crate::layers::Linear;
33use crate::module::Module;
34use crate::parameter::Parameter;
35
/// A single feed-forward expert: a gated MLP computing
/// `down_proj(silu(gate_proj(x)) * up_proj(x))` (see `Module::forward` impl).
pub struct Expert {
    /// Projection from `d_model` to `intermediate_size` (bias-free).
    up_proj: Linear,
    /// Gating projection from `d_model` to `intermediate_size` (bias-free);
    /// its output is passed through SiLU before multiplying `up_proj`'s output.
    gate_proj: Linear,
    /// Projection from `intermediate_size` back to `d_model` (bias-free).
    down_proj: Linear,
}
57
58impl Expert {
59 pub fn new(d_model: usize, intermediate_size: usize) -> Self {
65 Self {
66 up_proj: Linear::with_bias(d_model, intermediate_size, false),
67 gate_proj: Linear::with_bias(d_model, intermediate_size, false),
68 down_proj: Linear::with_bias(intermediate_size, d_model, false),
69 }
70 }
71}
72
73impl Module for Expert {
74 fn forward(&self, input: &Variable) -> Variable {
75 let gate = self.gate_proj.forward(input).silu();
77 let up = self.up_proj.forward(input);
78 let hidden = gate.mul_var(&up);
79 self.down_proj.forward(&hidden)
80 }
81
82 fn parameters(&self) -> Vec<Parameter> {
83 let mut params = Vec::new();
84 params.extend(self.up_proj.parameters());
85 params.extend(self.gate_proj.parameters());
86 params.extend(self.down_proj.parameters());
87 params
88 }
89
90 fn named_parameters(&self) -> HashMap<String, Parameter> {
91 let mut params = HashMap::new();
92 for (name, param) in self.up_proj.named_parameters() {
93 params.insert(format!("up_proj.{name}"), param);
94 }
95 for (name, param) in self.gate_proj.named_parameters() {
96 params.insert(format!("gate_proj.{name}"), param);
97 }
98 for (name, param) in self.down_proj.named_parameters() {
99 params.insert(format!("down_proj.{name}"), param);
100 }
101 params
102 }
103
104 fn name(&self) -> &'static str {
105 "Expert"
106 }
107}
108
109impl std::fmt::Debug for Expert {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 f.debug_struct("Expert")
112 .field("up_proj", &self.up_proj)
113 .field("gate_proj", &self.gate_proj)
114 .field("down_proj", &self.down_proj)
115 .finish()
116 }
117}
118
/// Softmax gating network that picks the `top_k` highest-probability experts
/// for every token and renormalizes their weights.
pub struct MoERouter {
    /// Bias-free linear layer producing one logit per expert.
    gate: Linear,
    /// Total number of experts that can be routed to.
    num_experts: usize,
    /// Number of experts selected per token (`top_k <= num_experts`).
    top_k: usize,
}
143
144impl MoERouter {
145 pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
152 assert!(
153 top_k <= num_experts,
154 "top_k ({top_k}) must be <= num_experts ({num_experts})"
155 );
156 Self {
157 gate: Linear::with_bias(d_model, num_experts, false),
158 num_experts,
159 top_k,
160 }
161 }
162
163 pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
173 let gate_logits = self.gate.forward(x);
174 let gate_probs = gate_logits.softmax(-1);
175
176 let probs_data = gate_probs.data();
177 let probs_vec = probs_data.to_vec();
178 let num_tokens = probs_data.shape()[0];
179
180 let mut top_k_weights = Vec::with_capacity(num_tokens);
181 let mut top_k_indices = Vec::with_capacity(num_tokens);
182
183 for t in 0..num_tokens {
184 let offset = t * self.num_experts;
185 let token_probs = &probs_vec[offset..offset + self.num_experts];
186
187 let mut indexed: Vec<(usize, f32)> = token_probs
189 .iter()
190 .enumerate()
191 .map(|(i, &p)| (i, p))
192 .collect();
193 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
194
195 let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
196 let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();
197
198 let weight_sum: f32 = top_weights.iter().sum();
200 let normalized: Vec<f32> = if weight_sum > 0.0 {
201 top_weights.iter().map(|w| w / weight_sum).collect()
202 } else {
203 vec![1.0 / self.top_k as f32; self.top_k]
204 };
205
206 top_k_weights.push(normalized);
207 top_k_indices.push(top_indices);
208 }
209
210 (gate_probs, top_k_weights, top_k_indices)
211 }
212
213 pub fn num_experts(&self) -> usize {
215 self.num_experts
216 }
217
218 pub fn top_k(&self) -> usize {
220 self.top_k
221 }
222}
223
224impl std::fmt::Debug for MoERouter {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("MoERouter")
227 .field("num_experts", &self.num_experts)
228 .field("top_k", &self.top_k)
229 .finish()
230 }
231}
232
/// Mixture-of-Experts layer: each token is routed to `top_k` of
/// `num_experts` expert MLPs and their outputs are combined using the
/// router's normalized gate weights.
pub struct MoELayer {
    /// Pool of expert feed-forward networks (`num_experts` of them).
    experts: Vec<Expert>,
    /// Gating network that selects experts per token.
    router: MoERouter,
    /// Model (embedding) dimension of the layer's input/output.
    d_model: usize,
    /// Number of experts in `experts`.
    num_experts: usize,
    /// Experts consulted per token.
    top_k: usize,
    /// Gate probabilities from the most recent forward pass, consumed by
    /// `load_balancing_loss`. RwLock provides interior mutability so
    /// `forward(&self)` can record stats.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Per-expert token counts from the most recent forward pass.
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}
268
269impl MoELayer {
270 pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
278 let experts: Vec<Expert> = (0..num_experts)
279 .map(|_| Expert::new(d_model, intermediate_size))
280 .collect();
281 let router = MoERouter::new(d_model, num_experts, top_k);
282
283 Self {
284 experts,
285 router,
286 d_model,
287 num_experts,
288 top_k,
289 last_gate_probs: std::sync::RwLock::new(None),
290 last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
291 }
292 }
293
294 pub fn load_balancing_loss(&self) -> Variable {
306 let gate_probs_opt = self.last_gate_probs.read().unwrap();
307 if gate_probs_opt.is_none() {
308 return Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
309 }
310
311 let gate_probs = gate_probs_opt.as_ref().unwrap();
312 let probs_data = gate_probs.data();
313 let probs_vec = probs_data.to_vec();
314 let shape = probs_data.shape();
315 let num_tokens = shape[0];
316 let num_experts = shape[1];
317
318 if num_tokens == 0 {
319 return Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
320 }
321
322 let expert_counts = self.last_expert_counts.read().unwrap();
323
324 let token_fractions: Vec<f32> = expert_counts
326 .iter()
327 .map(|&c| c as f32 / num_tokens as f32)
328 .collect();
329
330 let mut mean_probs = vec![0.0f32; num_experts];
332 for t in 0..num_tokens {
333 for e in 0..num_experts {
334 mean_probs[e] += probs_vec[t * num_experts + e];
335 }
336 }
337 for p in &mut mean_probs {
338 *p /= num_tokens as f32;
339 }
340
341 let mut loss_val = 0.0f32;
343 for e in 0..num_experts {
344 loss_val += token_fractions[e] * mean_probs[e];
345 }
346 loss_val *= num_experts as f32;
347
348 Variable::new(Tensor::from_vec(vec![loss_val], &[1]).unwrap(), false)
349 }
350
351 pub fn expert_utilization(&self) -> Vec<usize> {
355 self.last_expert_counts.read().unwrap().clone()
356 }
357
358 pub fn num_experts(&self) -> usize {
360 self.num_experts
361 }
362
363 pub fn top_k(&self) -> usize {
365 self.top_k
366 }
367}
368
369impl Module for MoELayer {
370 fn forward(&self, input: &Variable) -> Variable {
371 let shape = input.shape();
372 let batch_size = shape[0];
373 let seq_len = shape[1];
374 let d_model = shape[2];
375 let num_tokens = batch_size * seq_len;
376
377 let flat_input = input.reshape(&[num_tokens, d_model]);
379
380 let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);
382
383 let mut expert_counts = vec![0usize; self.num_experts];
385 for indices in &top_k_indices {
386 for &idx in indices {
387 expert_counts[idx] += 1;
388 }
389 }
390 *self.last_expert_counts.write().unwrap() = expert_counts;
391 *self.last_gate_probs.write().unwrap() = Some(gate_probs);
392
393 let mut output_data = vec![0.0f32; num_tokens * d_model];
395
396 for expert_idx in 0..self.num_experts {
398 let mut token_indices = Vec::new();
400 let mut token_weights = Vec::new();
401
402 for (t, (indices, weights)) in
403 top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
404 {
405 for (k, (&idx, &w)) in indices.iter().zip(weights.iter()).enumerate() {
406 if idx == expert_idx {
407 token_indices.push(t);
408 token_weights.push(w);
409 let _ = k;
410 }
411 }
412 }
413
414 if token_indices.is_empty() {
415 continue;
416 }
417
418 let flat_data = flat_input.data();
420 let flat_vec = flat_data.to_vec();
421 let n = token_indices.len();
422 let mut expert_input_data = Vec::with_capacity(n * d_model);
423 for &t in &token_indices {
424 let offset = t * d_model;
425 expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
426 }
427 let expert_input = Variable::new(
428 Tensor::from_vec(expert_input_data, &[n, d_model]).unwrap(),
429 true,
430 );
431
432 let expert_output = self.experts[expert_idx].forward(&expert_input);
434 let expert_out_vec = expert_output.data().to_vec();
435
436 for (local_idx, &global_idx) in token_indices.iter().enumerate() {
438 let weight = token_weights[local_idx];
439 let src_offset = local_idx * d_model;
440 let dst_offset = global_idx * d_model;
441 for d in 0..d_model {
442 output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
443 }
444 }
445 }
446
447 let output_tensor = Tensor::from_vec(output_data, &[num_tokens, d_model]).unwrap();
448 let output = Variable::new(output_tensor, true);
449
450 output.reshape(&[batch_size, seq_len, d_model])
452 }
453
454 fn parameters(&self) -> Vec<Parameter> {
455 let mut params = Vec::new();
456 params.extend(self.router.gate.parameters());
457 for expert in &self.experts {
458 params.extend(expert.parameters());
459 }
460 params
461 }
462
463 fn named_parameters(&self) -> HashMap<String, Parameter> {
464 let mut params = HashMap::new();
465 for (name, param) in self.router.gate.named_parameters() {
466 params.insert(format!("router.gate.{name}"), param);
467 }
468 for (i, expert) in self.experts.iter().enumerate() {
469 for (name, param) in expert.named_parameters() {
470 params.insert(format!("experts.{i}.{name}"), param);
471 }
472 }
473 params
474 }
475
476 fn name(&self) -> &'static str {
477 "MoELayer"
478 }
479}
480
481impl std::fmt::Debug for MoELayer {
482 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483 f.debug_struct("MoELayer")
484 .field("d_model", &self.d_model)
485 .field("num_experts", &self.num_experts)
486 .field("top_k", &self.top_k)
487 .field("experts", &self.experts.len())
488 .finish()
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    // A bias-free expert owns exactly 3 weight matrices (up, gate, down).
    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(64, 256);
        let params = expert.parameters();
        assert_eq!(params.len(), 3);
    }

    // Expert forward preserves the [n, d_model] shape.
    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).unwrap(),
            false,
        );
        let output = expert.forward(&input);
        assert_eq!(output.shape(), vec![4, 64]);
    }

    // Router accessors echo the construction arguments.
    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    // route() returns top_k normalized weights and in-range expert
    // indices for every token.
    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).unwrap(),
            false,
        );
        let (_gate_probs, weights, indices) = router.route(&input);

        assert_eq!(weights.len(), 4); assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2); let sum: f32 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for idx in &indices {
            assert_eq!(idx.len(), 2);
            for &i in idx {
                assert!(i < 8, "Expert index should be < num_experts");
            }
        }
    }

    // MoE forward preserves the [batch, seq, d_model] shape.
    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = moe.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    // 1 gate weight + 8 experts * 3 matrices each = 25 parameters.
    #[test]
    fn test_moe_layer_parameters() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let params = moe.parameters();
        assert_eq!(params.len(), 25);
    }

    // After a forward pass, the auxiliary loss is strictly positive.
    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let _output = moe.forward(&input);

        let lb_loss = moe.load_balancing_loss();
        let loss_val = lb_loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    // 10 tokens * top_k=2 routed slots = 20 total utilization counts.
    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let _output = moe.forward(&input);

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        let total: usize = util.iter().sum();
        assert_eq!(total, 20);
    }

    // Named-parameter keys follow "router.gate.*" and "experts.{i}.*".
    #[test]
    fn test_moe_named_parameters() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let named = moe.named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
}