// crates/axonml-nn/src/layers/moe.rs
//! Mixture of Experts - Sparse Expert Routing
//!
//! Implements a Mixture of Experts (MoE) layer where each token is routed to
//! a subset of expert MLPs. This allows massive model capacity (many experts)
//! while keeping per-token compute constant (only top-k experts activate).
//!
//! Key components:
//! - `Expert` — Standard SiLU-gated MLP (up_proj + gate_proj + down_proj)
//! - `MoERouter` — Learned routing: Linear -> softmax -> top-k selection
//! - `MoELayer` — Full MoE: routes tokens to top-k experts, combines outputs
//! - `load_balancing_loss()` — Auxiliary loss preventing expert collapse
//!
//! # File
//! `crates/axonml-nn/src/layers/moe.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// Expert MLP
// =============================================================================

/// A single expert MLP with SiLU-gated architecture (SwiGLU variant).
///
/// Computes: down_proj(SiLU(gate_proj(x)) * up_proj(x))
///
/// This is the same MLP architecture used in LLaMA/Mistral/Phi models.
///
/// # Shape
/// - Input: (*, d_model)
/// - Output: (*, d_model)
pub struct Expert {
    /// Up projection: d_model -> intermediate_size (linear branch of the gated product).
    up_proj: Linear,
    /// Gate projection: d_model -> intermediate_size; its SiLU output multiplies
    /// the up projection elementwise.
    gate_proj: Linear,
    /// Down projection: intermediate_size -> d_model, back to the model width.
    down_proj: Linear,
}

59impl Expert {
60    /// Creates a new Expert MLP.
61    ///
62    /// # Arguments
63    /// * `d_model` - Input/output dimension
64    /// * `intermediate_size` - Hidden dimension (typically 4 * d_model or 8/3 * d_model)
65    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
66        Self {
67            up_proj: Linear::with_bias(d_model, intermediate_size, false),
68            gate_proj: Linear::with_bias(d_model, intermediate_size, false),
69            down_proj: Linear::with_bias(intermediate_size, d_model, false),
70        }
71    }
72}
73
74impl Module for Expert {
75    fn forward(&self, input: &Variable) -> Variable {
76        // SwiGLU: down(SiLU(gate(x)) * up(x))
77        let gate = self.gate_proj.forward(input).silu();
78        let up = self.up_proj.forward(input);
79        let hidden = gate.mul_var(&up);
80        self.down_proj.forward(&hidden)
81    }
82
83    fn parameters(&self) -> Vec<Parameter> {
84        let mut params = Vec::new();
85        params.extend(self.up_proj.parameters());
86        params.extend(self.gate_proj.parameters());
87        params.extend(self.down_proj.parameters());
88        params
89    }
90
91    fn named_parameters(&self) -> HashMap<String, Parameter> {
92        let mut params = HashMap::new();
93        for (name, param) in self.up_proj.named_parameters() {
94            params.insert(format!("up_proj.{name}"), param);
95        }
96        for (name, param) in self.gate_proj.named_parameters() {
97            params.insert(format!("gate_proj.{name}"), param);
98        }
99        for (name, param) in self.down_proj.named_parameters() {
100            params.insert(format!("down_proj.{name}"), param);
101        }
102        params
103    }
104
105    fn name(&self) -> &'static str {
106        "Expert"
107    }
108}
109
110impl std::fmt::Debug for Expert {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        f.debug_struct("Expert")
113            .field("up_proj", &self.up_proj)
114            .field("gate_proj", &self.gate_proj)
115            .field("down_proj", &self.down_proj)
116            .finish()
117    }
118}
119
// =============================================================================
// MoE Router
// =============================================================================

/// Router for Mixture of Experts.
///
/// Maps each token's hidden state to expert selection probabilities via a
/// learned linear projection followed by softmax, then selects the top-k
/// experts per token.
///
/// # Routing
/// ```text
/// gate_logits = Linear(x)           [num_tokens, num_experts]
/// gate_probs  = softmax(gate_logits) [num_tokens, num_experts]
/// top_k_probs, top_k_indices = topk(gate_probs, k)
/// ```
pub struct MoERouter {
    /// Gate projection: d_model -> num_experts (bias-free; see `new`).
    gate: Linear,
    /// Total number of experts the router can select from.
    num_experts: usize,
    /// Number of experts selected per token (invariant: top_k <= num_experts).
    top_k: usize,
}

145impl MoERouter {
146    /// Creates a new MoE router.
147    ///
148    /// # Arguments
149    /// * `d_model` - Input hidden dimension
150    /// * `num_experts` - Total number of experts
151    /// * `top_k` - Number of experts activated per token
152    pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
153        assert!(
154            top_k <= num_experts,
155            "top_k ({top_k}) must be <= num_experts ({num_experts})"
156        );
157        Self {
158            gate: Linear::with_bias(d_model, num_experts, false),
159            num_experts,
160            top_k,
161        }
162    }
163
164    /// Routes tokens to experts.
165    ///
166    /// # Arguments
167    /// * `x` - Hidden states [num_tokens, d_model]
168    ///
169    /// # Returns
170    /// * `gate_probs` - Full probability distribution [num_tokens, num_experts] (for load balancing)
171    /// * `top_k_weights` - Normalized weights for selected experts [num_tokens, top_k]
172    /// * `top_k_indices` - Indices of selected experts [num_tokens, top_k]
173    pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
174        let gate_logits = self.gate.forward(x);
175        let gate_probs = gate_logits.softmax(-1);
176
177        let probs_data = gate_probs.data();
178        let probs_vec = probs_data.to_vec();
179        let num_tokens = probs_data.shape()[0];
180
181        let mut top_k_weights = Vec::with_capacity(num_tokens);
182        let mut top_k_indices = Vec::with_capacity(num_tokens);
183
184        for t in 0..num_tokens {
185            let offset = t * self.num_experts;
186            let token_probs = &probs_vec[offset..offset + self.num_experts];
187
188            // Find top-k experts by probability
189            let mut indexed: Vec<(usize, f32)> = token_probs
190                .iter()
191                .enumerate()
192                .map(|(i, &p)| (i, p))
193                .collect();
194            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
195
196            let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
197            let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();
198
199            // Normalize top-k weights to sum to 1
200            let weight_sum: f32 = top_weights.iter().sum();
201            let normalized: Vec<f32> = if weight_sum > 0.0 {
202                top_weights.iter().map(|w| w / weight_sum).collect()
203            } else {
204                vec![1.0 / self.top_k as f32; self.top_k]
205            };
206
207            top_k_weights.push(normalized);
208            top_k_indices.push(top_indices);
209        }
210
211        (gate_probs, top_k_weights, top_k_indices)
212    }
213
214    /// Returns the number of experts.
215    pub fn num_experts(&self) -> usize {
216        self.num_experts
217    }
218
219    /// Returns top-k value.
220    pub fn top_k(&self) -> usize {
221        self.top_k
222    }
223}
224
225impl std::fmt::Debug for MoERouter {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        f.debug_struct("MoERouter")
228            .field("num_experts", &self.num_experts)
229            .field("top_k", &self.top_k)
230            .finish()
231    }
232}
233
// =============================================================================
// MoE Layer
// =============================================================================

/// Mixture of Experts layer.
///
/// Routes each token to the top-k experts (out of N total), applies the
/// selected expert MLPs, and combines outputs using the routing weights.
///
/// This provides N times the model capacity while only using top-k/N of
/// the compute per token.
///
/// # Shape
/// - Input: (batch, seq_len, d_model)
/// - Output: (batch, seq_len, d_model)
///
/// # Load Balancing
/// Call `load_balancing_loss()` after forward to get the auxiliary loss
/// that prevents expert collapse (all tokens routing to same expert).
pub struct MoELayer {
    /// Expert MLPs (length == `num_experts`).
    experts: Vec<Expert>,
    /// Router for expert selection.
    router: MoERouter,
    /// Model dimension (kept for Debug output).
    d_model: usize,
    /// Number of experts.
    num_experts: usize,
    /// Number of experts activated per token.
    top_k: usize,
    /// Cached gate probabilities from last forward pass (for load balancing loss).
    /// RwLock gives interior mutability so `forward(&self)` can update the cache.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Cached per-expert token counts from last forward (for utilization stats).
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}

270impl MoELayer {
271    /// Creates a new MoE layer.
272    ///
273    /// # Arguments
274    /// * `d_model` - Hidden dimension
275    /// * `intermediate_size` - Expert MLP hidden dimension
276    /// * `num_experts` - Total number of expert MLPs
277    /// * `top_k` - Number of experts activated per token
278    pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
279        let experts: Vec<Expert> = (0..num_experts)
280            .map(|_| Expert::new(d_model, intermediate_size))
281            .collect();
282        let router = MoERouter::new(d_model, num_experts, top_k);
283
284        Self {
285            experts,
286            router,
287            d_model,
288            num_experts,
289            top_k,
290            last_gate_probs: std::sync::RwLock::new(None),
291            last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
292        }
293    }
294
295    /// Computes load balancing auxiliary loss.
296    ///
297    /// This loss encourages uniform routing across experts to prevent
298    /// expert collapse. Uses the formulation from Switch Transformer:
299    ///
300    ///   L_bal = num_experts * sum_i(f_i * P_i)
301    ///
302    /// where f_i = fraction of tokens assigned to expert i,
303    ///       P_i = mean routing probability for expert i.
304    ///
305    /// Returns zero if no forward pass has been done yet.
306    pub fn load_balancing_loss(&self) -> Variable {
307        let gate_probs_opt = self.last_gate_probs.read().unwrap();
308        if gate_probs_opt.is_none() {
309            return Variable::new(
310                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
311                false,
312            );
313        }
314
315        let gate_probs = gate_probs_opt.as_ref().unwrap();
316        let probs_data = gate_probs.data();
317        let probs_vec = probs_data.to_vec();
318        let shape = probs_data.shape();
319        let num_tokens = shape[0];
320        let num_experts = shape[1];
321
322        if num_tokens == 0 {
323            return Variable::new(
324                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
325                false,
326            );
327        }
328
329        let expert_counts = self.last_expert_counts.read().unwrap();
330
331        // f_i: fraction of tokens routed to expert i
332        let token_fractions: Vec<f32> = expert_counts
333            .iter()
334            .map(|&c| c as f32 / num_tokens as f32)
335            .collect();
336
337        // P_i: mean routing probability for expert i
338        let mut mean_probs = vec![0.0f32; num_experts];
339        for t in 0..num_tokens {
340            for e in 0..num_experts {
341                mean_probs[e] += probs_vec[t * num_experts + e];
342            }
343        }
344        for p in &mut mean_probs {
345            *p /= num_tokens as f32;
346        }
347
348        // L_bal = num_experts * sum(f_i * P_i)
349        let mut loss_val = 0.0f32;
350        for e in 0..num_experts {
351            loss_val += token_fractions[e] * mean_probs[e];
352        }
353        loss_val *= num_experts as f32;
354
355        Variable::new(
356            Tensor::from_vec(vec![loss_val], &[1]).expect("tensor creation failed"),
357            false,
358        )
359    }
360
361    /// Returns expert utilization counts from the last forward pass.
362    ///
363    /// Each element is the number of tokens routed to that expert.
364    pub fn expert_utilization(&self) -> Vec<usize> {
365        self.last_expert_counts.read().unwrap().clone()
366    }
367
368    /// Returns the number of experts.
369    pub fn num_experts(&self) -> usize {
370        self.num_experts
371    }
372
373    /// Returns the top-k value.
374    pub fn top_k(&self) -> usize {
375        self.top_k
376    }
377}
378
379impl Module for MoELayer {
380    fn forward(&self, input: &Variable) -> Variable {
381        let shape = input.shape();
382        let batch_size = shape[0];
383        let seq_len = shape[1];
384        let d_model = shape[2];
385        let num_tokens = batch_size * seq_len;
386
387        // Flatten to [num_tokens, d_model]
388        let flat_input = input.reshape(&[num_tokens, d_model]);
389
390        // Route tokens to experts
391        let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);
392
393        // Track expert utilization
394        let mut expert_counts = vec![0usize; self.num_experts];
395        for indices in &top_k_indices {
396            for &idx in indices {
397                expert_counts[idx] += 1;
398            }
399        }
400        *self.last_expert_counts.write().unwrap() = expert_counts;
401        *self.last_gate_probs.write().unwrap() = Some(gate_probs);
402
403        // Initialize output as zeros
404        let mut output_data = vec![0.0f32; num_tokens * d_model];
405
406        // Process each expert: gather tokens, forward, scatter back
407        for expert_idx in 0..self.num_experts {
408            // Find which tokens go to this expert and their weights
409            let mut token_indices = Vec::new();
410            let mut token_weights = Vec::new();
411
412            for (t, (indices, weights)) in
413                top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
414            {
415                for (k, (&idx, &w)) in indices.iter().zip(weights.iter()).enumerate() {
416                    if idx == expert_idx {
417                        token_indices.push(t);
418                        token_weights.push(w);
419                        let _ = k;
420                    }
421                }
422            }
423
424            if token_indices.is_empty() {
425                continue;
426            }
427
428            // Gather tokens for this expert
429            let flat_data = flat_input.data();
430            let flat_vec = flat_data.to_vec();
431            let n = token_indices.len();
432            let mut expert_input_data = Vec::with_capacity(n * d_model);
433            for &t in &token_indices {
434                let offset = t * d_model;
435                expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
436            }
437            let expert_input = Variable::new(
438                Tensor::from_vec(expert_input_data, &[n, d_model]).expect("tensor creation failed"),
439                true,
440            );
441
442            // Forward through expert
443            let expert_output = self.experts[expert_idx].forward(&expert_input);
444            let expert_out_vec = expert_output.data().to_vec();
445
446            // Scatter weighted outputs back
447            for (local_idx, &global_idx) in token_indices.iter().enumerate() {
448                let weight = token_weights[local_idx];
449                let src_offset = local_idx * d_model;
450                let dst_offset = global_idx * d_model;
451                for d in 0..d_model {
452                    output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
453                }
454            }
455        }
456
457        let output_tensor =
458            Tensor::from_vec(output_data, &[num_tokens, d_model]).expect("tensor creation failed");
459        let output = Variable::new(output_tensor, true);
460
461        // Reshape back to [batch, seq_len, d_model]
462        output.reshape(&[batch_size, seq_len, d_model])
463    }
464
465    fn parameters(&self) -> Vec<Parameter> {
466        let mut params = Vec::new();
467        params.extend(self.router.gate.parameters());
468        for expert in &self.experts {
469            params.extend(expert.parameters());
470        }
471        params
472    }
473
474    fn named_parameters(&self) -> HashMap<String, Parameter> {
475        let mut params = HashMap::new();
476        for (name, param) in self.router.gate.named_parameters() {
477            params.insert(format!("router.gate.{name}"), param);
478        }
479        for (i, expert) in self.experts.iter().enumerate() {
480            for (name, param) in expert.named_parameters() {
481                params.insert(format!("experts.{i}.{name}"), param);
482            }
483        }
484        params
485    }
486
487    fn name(&self) -> &'static str {
488        "MoELayer"
489    }
490}
491
492impl std::fmt::Debug for MoELayer {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        f.debug_struct("MoELayer")
495            .field("d_model", &self.d_model)
496            .field("num_experts", &self.num_experts)
497            .field("top_k", &self.top_k)
498            .field("experts", &self.experts.len())
499            .finish()
500    }
501}
502
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a constant-valued (0.1) input of the given shape, without
    /// gradient tracking. Shared by every forward-pass test below.
    fn const_input(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![0.1; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_expert_creation() {
        // up_proj(w) + gate_proj(w) + down_proj(w) = 3 weights (no bias)
        assert_eq!(Expert::new(64, 256).parameters().len(), 3);
    }

    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let output = expert.forward(&const_input(&[4, 64]));
        assert_eq!(output.shape(), vec![4, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let (_gate_probs, weights, indices) = router.route(&const_input(&[4, 64]));

        assert_eq!(weights.len(), 4); // one entry per token
        assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2); // top-2
            let total: f32 = w.iter().sum();
            assert!((total - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for token_experts in &indices {
            assert_eq!(token_experts.len(), 2);
            assert!(
                token_experts.iter().all(|&i| i < 8),
                "Expert index should be < num_experts"
            );
        }
    }

    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let output = moe.forward(&const_input(&[2, 5, 64]));
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_moe_layer_parameters() {
        // Router: 1 weight (no bias); 8 experts * 3 weights each = 24; total 25.
        assert_eq!(MoELayer::new(64, 256, 8, 2).parameters().len(), 25);
    }

    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let _output = moe.forward(&const_input(&[2, 5, 64]));

        let loss_val = moe.load_balancing_loss().data().to_vec()[0];
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let _output = moe.forward(&const_input(&[2, 5, 64]));

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        // Each of the 10 tokens selects top-2 experts = 20 assignments total.
        assert_eq!(util.iter().sum::<usize>(), 20);
    }

    #[test]
    fn test_moe_named_parameters() {
        let named = MoELayer::new(64, 256, 4, 2).named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
}