//! Mixture-of-Experts building blocks: a SwiGLU [`Expert`], a top-k softmax
//! [`MoERouter`], and the combining [`MoELayer`].

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

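/// A single feed-forward expert using the SwiGLU formulation:
/// `down_proj(silu(gate_proj(x)) * up_proj(x))`.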
pub struct Expert {
    up_proj: Linear,
    gate_proj: Linear,
    down_proj: Linear,
}

impl Expert {
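    /// Creates a new expert with the given model and intermediate dimensions.
    /// All three projections are constructed without bias.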
    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
        Self {
            up_proj: Linear::with_bias(d_model, intermediate_size, false),
            gate_proj: Linear::with_bias(d_model, intermediate_size, false),
            down_proj: Linear::with_bias(intermediate_size, d_model, false),
        }
    }
}

impl Module for Expert {
    fn forward(&self, input: &Variable) -> Variable {
        let gate = self.gate_proj.forward(input).silu();
        let up = self.up_proj.forward(input);
        let hidden = gate.mul_var(&up);
        self.down_proj.forward(&hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.up_proj.parameters());
        params.extend(self.gate_proj.parameters());
        params.extend(self.down_proj.parameters());
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.up_proj.named_parameters() {
            params.insert(format!("up_proj.{name}"), param);
        }
        for (name, param) in self.gate_proj.named_parameters() {
            params.insert(format!("gate_proj.{name}"), param);
        }
        for (name, param) in self.down_proj.named_parameters() {
            params.insert(format!("down_proj.{name}"), param);
        }
        params
    }

    fn name(&self) -> &'static str {
        "Expert"
    }
}

impl std::fmt::Debug for Expert {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Expert")
            .field("up_proj", &self.up_proj)
            .field("gate_proj", &self.gate_proj)
            .field("down_proj", &self.down_proj)
            .finish()
    }
}

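/// A softmax router that scores each token against every expert and selects
/// the `top_k` highest-probability experts per token.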
pub struct MoERouter {
    gate: Linear,
    num_experts: usize,
    top_k: usize,
}

impl MoERouter {
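    /// Creates a router over `num_experts` experts that selects `top_k` experts per token.
    ///
    /// # Panics
    ///
    /// Panics if `top_k > num_experts`.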
    pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
        assert!(
            top_k <= num_experts,
            "top_k ({top_k}) must be <= num_experts ({num_experts})"
        );
        Self {
            gate: Linear::with_bias(d_model, num_experts, false),
            num_experts,
            top_k,
        }
    }

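    /// Routes a `[num_tokens, d_model]` input through the gate and returns the
    /// full gate probabilities (`[num_tokens, num_experts]`), the per-token
    /// top-k weights renormalized to sum to 1, and the corresponding expert indices.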
    pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
        let gate_logits = self.gate.forward(x);
        let gate_probs = gate_logits.softmax(-1);

        let probs_data = gate_probs.data();
        let probs_vec = probs_data.to_vec();
        let num_tokens = probs_data.shape()[0];

        let mut top_k_weights = Vec::with_capacity(num_tokens);
        let mut top_k_indices = Vec::with_capacity(num_tokens);

        for t in 0..num_tokens {
            let offset = t * self.num_experts;
            let token_probs = &probs_vec[offset..offset + self.num_experts];

            // Sort experts for this token by gate probability, descending.
            let mut indexed: Vec<(usize, f32)> = token_probs
                .iter()
                .enumerate()
                .map(|(i, &p)| (i, p))
                .collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

            let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
            let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();

            // Renormalize the selected weights so they sum to 1; fall back to a
            // uniform split if the selected probability mass is zero.
            let weight_sum: f32 = top_weights.iter().sum();
            let normalized: Vec<f32> = if weight_sum > 0.0 {
                top_weights.iter().map(|w| w / weight_sum).collect()
            } else {
                vec![1.0 / self.top_k as f32; self.top_k]
            };

            top_k_weights.push(normalized);
            top_k_indices.push(top_indices);
        }

        (gate_probs, top_k_weights, top_k_indices)
    }

    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    pub fn top_k(&self) -> usize {
        self.top_k
    }
}

impl std::fmt::Debug for MoERouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MoERouter")
            .field("num_experts", &self.num_experts)
            .field("top_k", &self.top_k)
            .finish()
    }
}

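/// A sparse Mixture-of-Experts layer: a [`MoERouter`] picks `top_k` experts per
/// token, the selected [`Expert`]s are applied, and their outputs are combined
/// weighted by the renormalized gate probabilities.
///
/// Illustrative usage (a minimal sketch; dimensions are arbitrary):
///
/// ```ignore
/// // 8 experts, 2 active per token, d_model = 64, intermediate size = 256.
/// let moe = MoELayer::new(64, 256, 8, 2);
/// // input shape: [batch, seq_len, d_model]
/// let output = moe.forward(&input);
/// let aux = moe.load_balancing_loss();
/// ```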
pub struct MoELayer {
    experts: Vec<Expert>,
    router: MoERouter,
    d_model: usize,
    num_experts: usize,
    top_k: usize,
    /// Gate probabilities from the most recent forward pass, used by `load_balancing_loss`.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Per-expert token counts from the most recent forward pass.
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}

impl MoELayer {
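    /// Creates a MoE layer with `num_experts` SwiGLU experts and a top-`top_k` router.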
    pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
        let experts: Vec<Expert> = (0..num_experts)
            .map(|_| Expert::new(d_model, intermediate_size))
            .collect();
        let router = MoERouter::new(d_model, num_experts, top_k);

        Self {
            experts,
            router,
            d_model,
            num_experts,
            top_k,
            last_gate_probs: std::sync::RwLock::new(None),
            last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
        }
    }

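    /// Computes the auxiliary load-balancing loss from the most recent forward
    /// pass, in the style of the Switch Transformer objective:
    /// `loss = num_experts * sum_e(f_e * P_e)`, where `f_e` is the fraction of
    /// tokens whose top-k selection includes expert `e` and `P_e` is the mean
    /// gate probability assigned to expert `e`. Returns a scalar zero if no
    /// forward pass has been run yet.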
    pub fn load_balancing_loss(&self) -> Variable {
        let gate_probs_opt = self.last_gate_probs.read().unwrap();
        if gate_probs_opt.is_none() {
            return Variable::new(
                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
                false,
            );
        }

        let gate_probs = gate_probs_opt.as_ref().unwrap();
        let probs_data = gate_probs.data();
        let probs_vec = probs_data.to_vec();
        let shape = probs_data.shape();
        let num_tokens = shape[0];
        let num_experts = shape[1];

        if num_tokens == 0 {
            return Variable::new(
                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
                false,
            );
        }

        let expert_counts = self.last_expert_counts.read().unwrap();

        // f_e: fraction of tokens whose top-k selection includes expert e.
        let token_fractions: Vec<f32> = expert_counts
            .iter()
            .map(|&c| c as f32 / num_tokens as f32)
            .collect();

        // P_e: mean gate probability assigned to expert e across all tokens.
        let mut mean_probs = vec![0.0f32; num_experts];
        for t in 0..num_tokens {
            for e in 0..num_experts {
                mean_probs[e] += probs_vec[t * num_experts + e];
            }
        }
        for p in &mut mean_probs {
            *p /= num_tokens as f32;
        }

        // loss = num_experts * sum_e(f_e * P_e)
        let mut loss_val = 0.0f32;
        for e in 0..num_experts {
            loss_val += token_fractions[e] * mean_probs[e];
        }
        loss_val *= num_experts as f32;

        Variable::new(
            Tensor::from_vec(vec![loss_val], &[1]).expect("tensor creation failed"),
            false,
        )
    }

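    /// Returns how many tokens were routed to each expert in the most recent
    /// forward pass.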
    pub fn expert_utilization(&self) -> Vec<usize> {
        self.last_expert_counts.read().unwrap().clone()
    }

    pub fn num_experts(&self) -> usize {
        self.num_experts
    }

    pub fn top_k(&self) -> usize {
        self.top_k
    }
}

impl Module for MoELayer {
    fn forward(&self, input: &Variable) -> Variable {
        let shape = input.shape();
        let batch_size = shape[0];
        let seq_len = shape[1];
        let d_model = shape[2];
        let num_tokens = batch_size * seq_len;

        // Flatten [batch, seq_len, d_model] into [num_tokens, d_model] for routing.
        let flat_input = input.reshape(&[num_tokens, d_model]);

        let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);

        // Record routing statistics for expert_utilization() and load_balancing_loss().
        let mut expert_counts = vec![0usize; self.num_experts];
        for indices in &top_k_indices {
            for &idx in indices {
                expert_counts[idx] += 1;
            }
        }
        *self.last_expert_counts.write().unwrap() = expert_counts;
        *self.last_gate_probs.write().unwrap() = Some(gate_probs);

        let flat_data = flat_input.data();
        let flat_vec = flat_data.to_vec();
        let mut output_data = vec![0.0f32; num_tokens * d_model];

        for expert_idx in 0..self.num_experts {
            // Gather the tokens (and their gate weights) assigned to this expert.
            let mut token_indices = Vec::new();
            let mut token_weights = Vec::new();

            for (t, (indices, weights)) in
                top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
            {
                for (&idx, &w) in indices.iter().zip(weights.iter()) {
                    if idx == expert_idx {
                        token_indices.push(t);
                        token_weights.push(w);
                    }
                }
            }

            if token_indices.is_empty() {
                continue;
            }

            // Build a batched input containing only this expert's tokens.
            let n = token_indices.len();
            let mut expert_input_data = Vec::with_capacity(n * d_model);
            for &t in &token_indices {
                let offset = t * d_model;
                expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
            }
            let expert_input = Variable::new(
                Tensor::from_vec(expert_input_data, &[n, d_model]).expect("tensor creation failed"),
                true,
            );

            let expert_output = self.experts[expert_idx].forward(&expert_input);
            let expert_out_vec = expert_output.data().to_vec();

            // Scatter the expert output back to its tokens, weighted by the gate.
            for (local_idx, &global_idx) in token_indices.iter().enumerate() {
                let weight = token_weights[local_idx];
                let src_offset = local_idx * d_model;
                let dst_offset = global_idx * d_model;
                for d in 0..d_model {
                    output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
                }
            }
        }

        let output_tensor =
            Tensor::from_vec(output_data, &[num_tokens, d_model]).expect("tensor creation failed");
        let output = Variable::new(output_tensor, true);

        output.reshape(&[batch_size, seq_len, d_model])
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.router.gate.parameters());
        for expert in &self.experts {
            params.extend(expert.parameters());
        }
        params
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        for (name, param) in self.router.gate.named_parameters() {
            params.insert(format!("router.gate.{name}"), param);
        }
        for (i, expert) in self.experts.iter().enumerate() {
            for (name, param) in expert.named_parameters() {
                params.insert(format!("experts.{i}.{name}"), param);
            }
        }
        params
    }

    fn name(&self) -> &'static str {
        "MoELayer"
    }
}

impl std::fmt::Debug for MoELayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MoELayer")
            .field("d_model", &self.d_model)
            .field("num_experts", &self.num_experts)
            .field("top_k", &self.top_k)
            .field("experts", &self.experts.len())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(64, 256);
        let params = expert.parameters();
        // Three bias-free projections: up, gate, down.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).expect("tensor creation failed"),
            false,
        );
        let output = expert.forward(&input);
        assert_eq!(output.shape(), vec![4, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).expect("tensor creation failed"),
            false,
        );
        let (_gate_probs, weights, indices) = router.route(&input);

        assert_eq!(weights.len(), 4);
        assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2);
            let sum: f32 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for idx in &indices {
            assert_eq!(idx.len(), 2);
            for &i in idx {
                assert!(i < 8, "Expert index should be < num_experts");
            }
        }
    }

    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let output = moe.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_moe_layer_parameters() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let params = moe.parameters();
        // 1 router gate weight + 8 experts * 3 projection weights = 25.
        assert_eq!(params.len(), 25);
    }

    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let _output = moe.forward(&input);

        let lb_loss = moe.load_balancing_loss();
        let loss_val = lb_loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).expect("tensor creation failed"),
            false,
        );
        let _output = moe.forward(&input);

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        // 2 * 5 = 10 tokens, each routed to top_k = 2 experts.
        let total: usize = util.iter().sum();
        assert_eq!(total, 20);
    }

    #[test]
    fn test_moe_named_parameters() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let named = moe.named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
}