//! Mixture of Experts - Sparse Expert Routing
//!
//! Implements a Mixture of Experts (MoE) layer where each token is routed to
//! a subset of expert MLPs. This allows massive model capacity (many experts)
//! while keeping per-token compute constant (only top-k experts activate).
//!
//! Key components:
//! - `Expert` — Standard SiLU-gated MLP (up_proj + gate_proj + down_proj)
//! - `MoERouter` — Learned routing: Linear -> softmax -> top-k selection
//! - `MoELayer` — Full MoE: routes tokens to top-k experts, combines outputs
//! - `load_balancing_loss()` — Auxiliary loss preventing expert collapse
//!
//! # File
//! `crates/axonml-nn/src/layers/moe.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 19, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::layers::Linear;
use crate::module::Module;
use crate::parameter::Parameter;
// =============================================================================
// Expert MLP
// =============================================================================

/// A single expert MLP with SiLU-gated architecture (SwiGLU variant).
///
/// Computes: down_proj(SiLU(gate_proj(x)) * up_proj(x))
///
/// This is the same MLP architecture used in LLaMA/Mistral/Phi models.
/// All three projections are bias-free (see [`Expert::new`]).
///
/// # Shape
/// - Input: (*, d_model)
/// - Output: (*, d_model)
pub struct Expert {
    /// Up projection: d_model -> intermediate_size
    up_proj: Linear,
    /// Gate projection: d_model -> intermediate_size (for SiLU gating)
    gate_proj: Linear,
    /// Down projection: intermediate_size -> d_model
    down_proj: Linear,
}

58impl Expert {
59    /// Creates a new Expert MLP.
60    ///
61    /// # Arguments
62    /// * `d_model` - Input/output dimension
63    /// * `intermediate_size` - Hidden dimension (typically 4 * d_model or 8/3 * d_model)
64    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
65        Self {
66            up_proj: Linear::with_bias(d_model, intermediate_size, false),
67            gate_proj: Linear::with_bias(d_model, intermediate_size, false),
68            down_proj: Linear::with_bias(intermediate_size, d_model, false),
69        }
70    }
71}
72
73impl Module for Expert {
74    fn forward(&self, input: &Variable) -> Variable {
75        // SwiGLU: down(SiLU(gate(x)) * up(x))
76        let gate = self.gate_proj.forward(input).silu();
77        let up = self.up_proj.forward(input);
78        let hidden = gate.mul_var(&up);
79        self.down_proj.forward(&hidden)
80    }
81
82    fn parameters(&self) -> Vec<Parameter> {
83        let mut params = Vec::new();
84        params.extend(self.up_proj.parameters());
85        params.extend(self.gate_proj.parameters());
86        params.extend(self.down_proj.parameters());
87        params
88    }
89
90    fn named_parameters(&self) -> HashMap<String, Parameter> {
91        let mut params = HashMap::new();
92        for (name, param) in self.up_proj.named_parameters() {
93            params.insert(format!("up_proj.{name}"), param);
94        }
95        for (name, param) in self.gate_proj.named_parameters() {
96            params.insert(format!("gate_proj.{name}"), param);
97        }
98        for (name, param) in self.down_proj.named_parameters() {
99            params.insert(format!("down_proj.{name}"), param);
100        }
101        params
102    }
103
104    fn name(&self) -> &'static str {
105        "Expert"
106    }
107}
108
109impl std::fmt::Debug for Expert {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        f.debug_struct("Expert")
112            .field("up_proj", &self.up_proj)
113            .field("gate_proj", &self.gate_proj)
114            .field("down_proj", &self.down_proj)
115            .finish()
116    }
117}
118
// =============================================================================
// MoE Router
// =============================================================================

/// Router for Mixture of Experts.
///
/// Maps each token's hidden state to expert selection probabilities via a
/// learned linear projection followed by softmax, then selects the top-k
/// experts per token.
///
/// # Routing
/// ```text
/// gate_logits = Linear(x)           [num_tokens, num_experts]
/// gate_probs  = softmax(gate_logits) [num_tokens, num_experts]
/// top_k_probs, top_k_indices = topk(gate_probs, k)
/// ```
pub struct MoERouter {
    /// Gate projection: d_model -> num_experts (bias-free; see `new`)
    gate: Linear,
    /// Number of experts
    num_experts: usize,
    /// Number of experts to select per token
    top_k: usize,
}

144impl MoERouter {
145    /// Creates a new MoE router.
146    ///
147    /// # Arguments
148    /// * `d_model` - Input hidden dimension
149    /// * `num_experts` - Total number of experts
150    /// * `top_k` - Number of experts activated per token
151    pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
152        assert!(
153            top_k <= num_experts,
154            "top_k ({top_k}) must be <= num_experts ({num_experts})"
155        );
156        Self {
157            gate: Linear::with_bias(d_model, num_experts, false),
158            num_experts,
159            top_k,
160        }
161    }
162
163    /// Routes tokens to experts.
164    ///
165    /// # Arguments
166    /// * `x` - Hidden states [num_tokens, d_model]
167    ///
168    /// # Returns
169    /// * `gate_probs` - Full probability distribution [num_tokens, num_experts] (for load balancing)
170    /// * `top_k_weights` - Normalized weights for selected experts [num_tokens, top_k]
171    /// * `top_k_indices` - Indices of selected experts [num_tokens, top_k]
172    pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
173        let gate_logits = self.gate.forward(x);
174        let gate_probs = gate_logits.softmax(-1);
175
176        let probs_data = gate_probs.data();
177        let probs_vec = probs_data.to_vec();
178        let num_tokens = probs_data.shape()[0];
179
180        let mut top_k_weights = Vec::with_capacity(num_tokens);
181        let mut top_k_indices = Vec::with_capacity(num_tokens);
182
183        for t in 0..num_tokens {
184            let offset = t * self.num_experts;
185            let token_probs = &probs_vec[offset..offset + self.num_experts];
186
187            // Find top-k experts by probability
188            let mut indexed: Vec<(usize, f32)> = token_probs
189                .iter()
190                .enumerate()
191                .map(|(i, &p)| (i, p))
192                .collect();
193            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
194
195            let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
196            let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();
197
198            // Normalize top-k weights to sum to 1
199            let weight_sum: f32 = top_weights.iter().sum();
200            let normalized: Vec<f32> = if weight_sum > 0.0 {
201                top_weights.iter().map(|w| w / weight_sum).collect()
202            } else {
203                vec![1.0 / self.top_k as f32; self.top_k]
204            };
205
206            top_k_weights.push(normalized);
207            top_k_indices.push(top_indices);
208        }
209
210        (gate_probs, top_k_weights, top_k_indices)
211    }
212
213    /// Returns the number of experts.
214    pub fn num_experts(&self) -> usize {
215        self.num_experts
216    }
217
218    /// Returns top-k value.
219    pub fn top_k(&self) -> usize {
220        self.top_k
221    }
222}
223
224impl std::fmt::Debug for MoERouter {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("MoERouter")
227            .field("num_experts", &self.num_experts)
228            .field("top_k", &self.top_k)
229            .finish()
230    }
231}
232
// =============================================================================
// MoE Layer
// =============================================================================

/// Mixture of Experts layer.
///
/// Routes each token to the top-k experts (out of N total), applies the
/// selected expert MLPs, and combines outputs using the routing weights.
///
/// This provides N times the model capacity while only using top-k/N of
/// the compute per token.
///
/// # Shape
/// - Input: (batch, seq_len, d_model)
/// - Output: (batch, seq_len, d_model)
///
/// # Load Balancing
/// Call `load_balancing_loss()` after forward to get the auxiliary loss
/// that prevents expert collapse (all tokens routing to same expert).
pub struct MoELayer {
    /// Expert MLPs
    experts: Vec<Expert>,
    /// Router for expert selection
    router: MoERouter,
    /// Model dimension
    d_model: usize,
    /// Number of experts
    num_experts: usize,
    /// Top-k routing
    top_k: usize,
    /// Cached gate probabilities from last forward pass (for load balancing loss).
    /// RwLock gives interior mutability because `forward` takes `&self`.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Cached expert assignments from last forward (for utilization stats)
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}

269impl MoELayer {
270    /// Creates a new MoE layer.
271    ///
272    /// # Arguments
273    /// * `d_model` - Hidden dimension
274    /// * `intermediate_size` - Expert MLP hidden dimension
275    /// * `num_experts` - Total number of expert MLPs
276    /// * `top_k` - Number of experts activated per token
277    pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
278        let experts: Vec<Expert> = (0..num_experts)
279            .map(|_| Expert::new(d_model, intermediate_size))
280            .collect();
281        let router = MoERouter::new(d_model, num_experts, top_k);
282
283        Self {
284            experts,
285            router,
286            d_model,
287            num_experts,
288            top_k,
289            last_gate_probs: std::sync::RwLock::new(None),
290            last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
291        }
292    }
293
294    /// Computes load balancing auxiliary loss.
295    ///
296    /// This loss encourages uniform routing across experts to prevent
297    /// expert collapse. Uses the formulation from Switch Transformer:
298    ///
299    ///   L_bal = num_experts * sum_i(f_i * P_i)
300    ///
301    /// where f_i = fraction of tokens assigned to expert i,
302    ///       P_i = mean routing probability for expert i.
303    ///
304    /// Returns zero if no forward pass has been done yet.
305    pub fn load_balancing_loss(&self) -> Variable {
306        let gate_probs_opt = self.last_gate_probs.read().unwrap();
307        if gate_probs_opt.is_none() {
308            return Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
309        }
310
311        let gate_probs = gate_probs_opt.as_ref().unwrap();
312        let probs_data = gate_probs.data();
313        let probs_vec = probs_data.to_vec();
314        let shape = probs_data.shape();
315        let num_tokens = shape[0];
316        let num_experts = shape[1];
317
318        if num_tokens == 0 {
319            return Variable::new(Tensor::from_vec(vec![0.0f32], &[1]).unwrap(), false);
320        }
321
322        let expert_counts = self.last_expert_counts.read().unwrap();
323
324        // f_i: fraction of tokens routed to expert i
325        let token_fractions: Vec<f32> = expert_counts
326            .iter()
327            .map(|&c| c as f32 / num_tokens as f32)
328            .collect();
329
330        // P_i: mean routing probability for expert i
331        let mut mean_probs = vec![0.0f32; num_experts];
332        for t in 0..num_tokens {
333            for e in 0..num_experts {
334                mean_probs[e] += probs_vec[t * num_experts + e];
335            }
336        }
337        for p in &mut mean_probs {
338            *p /= num_tokens as f32;
339        }
340
341        // L_bal = num_experts * sum(f_i * P_i)
342        let mut loss_val = 0.0f32;
343        for e in 0..num_experts {
344            loss_val += token_fractions[e] * mean_probs[e];
345        }
346        loss_val *= num_experts as f32;
347
348        Variable::new(Tensor::from_vec(vec![loss_val], &[1]).unwrap(), false)
349    }
350
351    /// Returns expert utilization counts from the last forward pass.
352    ///
353    /// Each element is the number of tokens routed to that expert.
354    pub fn expert_utilization(&self) -> Vec<usize> {
355        self.last_expert_counts.read().unwrap().clone()
356    }
357
358    /// Returns the number of experts.
359    pub fn num_experts(&self) -> usize {
360        self.num_experts
361    }
362
363    /// Returns the top-k value.
364    pub fn top_k(&self) -> usize {
365        self.top_k
366    }
367}
368
369impl Module for MoELayer {
370    fn forward(&self, input: &Variable) -> Variable {
371        let shape = input.shape();
372        let batch_size = shape[0];
373        let seq_len = shape[1];
374        let d_model = shape[2];
375        let num_tokens = batch_size * seq_len;
376
377        // Flatten to [num_tokens, d_model]
378        let flat_input = input.reshape(&[num_tokens, d_model]);
379
380        // Route tokens to experts
381        let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);
382
383        // Track expert utilization
384        let mut expert_counts = vec![0usize; self.num_experts];
385        for indices in &top_k_indices {
386            for &idx in indices {
387                expert_counts[idx] += 1;
388            }
389        }
390        *self.last_expert_counts.write().unwrap() = expert_counts;
391        *self.last_gate_probs.write().unwrap() = Some(gate_probs);
392
393        // Initialize output as zeros
394        let mut output_data = vec![0.0f32; num_tokens * d_model];
395
396        // Process each expert: gather tokens, forward, scatter back
397        for expert_idx in 0..self.num_experts {
398            // Find which tokens go to this expert and their weights
399            let mut token_indices = Vec::new();
400            let mut token_weights = Vec::new();
401
402            for (t, (indices, weights)) in
403                top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
404            {
405                for (k, (&idx, &w)) in indices.iter().zip(weights.iter()).enumerate() {
406                    if idx == expert_idx {
407                        token_indices.push(t);
408                        token_weights.push(w);
409                        let _ = k;
410                    }
411                }
412            }
413
414            if token_indices.is_empty() {
415                continue;
416            }
417
418            // Gather tokens for this expert
419            let flat_data = flat_input.data();
420            let flat_vec = flat_data.to_vec();
421            let n = token_indices.len();
422            let mut expert_input_data = Vec::with_capacity(n * d_model);
423            for &t in &token_indices {
424                let offset = t * d_model;
425                expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
426            }
427            let expert_input = Variable::new(
428                Tensor::from_vec(expert_input_data, &[n, d_model]).unwrap(),
429                true,
430            );
431
432            // Forward through expert
433            let expert_output = self.experts[expert_idx].forward(&expert_input);
434            let expert_out_vec = expert_output.data().to_vec();
435
436            // Scatter weighted outputs back
437            for (local_idx, &global_idx) in token_indices.iter().enumerate() {
438                let weight = token_weights[local_idx];
439                let src_offset = local_idx * d_model;
440                let dst_offset = global_idx * d_model;
441                for d in 0..d_model {
442                    output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
443                }
444            }
445        }
446
447        let output_tensor = Tensor::from_vec(output_data, &[num_tokens, d_model]).unwrap();
448        let output = Variable::new(output_tensor, true);
449
450        // Reshape back to [batch, seq_len, d_model]
451        output.reshape(&[batch_size, seq_len, d_model])
452    }
453
454    fn parameters(&self) -> Vec<Parameter> {
455        let mut params = Vec::new();
456        params.extend(self.router.gate.parameters());
457        for expert in &self.experts {
458            params.extend(expert.parameters());
459        }
460        params
461    }
462
463    fn named_parameters(&self) -> HashMap<String, Parameter> {
464        let mut params = HashMap::new();
465        for (name, param) in self.router.gate.named_parameters() {
466            params.insert(format!("router.gate.{name}"), param);
467        }
468        for (i, expert) in self.experts.iter().enumerate() {
469            for (name, param) in expert.named_parameters() {
470                params.insert(format!("experts.{i}.{name}"), param);
471            }
472        }
473        params
474    }
475
476    fn name(&self) -> &'static str {
477        "MoELayer"
478    }
479}
480
481impl std::fmt::Debug for MoELayer {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        f.debug_struct("MoELayer")
484            .field("d_model", &self.d_model)
485            .field("num_experts", &self.num_experts)
486            .field("top_k", &self.top_k)
487            .field("experts", &self.experts.len())
488            .finish()
489    }
490}
491
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Expert holds three bias-free Linear layers, so exactly 3 parameters.
    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(64, 256);
        let params = expert.parameters();
        // up_proj(w) + gate_proj(w) + down_proj(w) = 3 weights (no bias)
        assert_eq!(params.len(), 3);
    }

    // The expert MLP preserves the trailing d_model dimension.
    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).unwrap(),
            false,
        );
        let output = expert.forward(&input);
        assert_eq!(output.shape(), vec![4, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    // Routing returns per-token weight/index lists; weights renormalized to 1.
    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 4 * 64], &[4, 64]).unwrap(),
            false,
        );
        let (_gate_probs, weights, indices) = router.route(&input);

        assert_eq!(weights.len(), 4); // 4 tokens
        assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2); // top-2
            let sum: f32 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for idx in &indices {
            assert_eq!(idx.len(), 2);
            for &i in idx {
                assert!(i < 8, "Expert index should be < num_experts");
            }
        }
    }

    // Forward preserves the (batch, seq_len, d_model) shape.
    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let output = moe.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_moe_layer_parameters() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let params = moe.parameters();
        // Router: 1 weight (no bias)
        // 8 experts * 3 weights each = 24
        // Total = 25
        assert_eq!(params.len(), 25);
    }

    // After a forward pass the cached gate probs yield a positive balance loss.
    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let _output = moe.forward(&input);

        let lb_loss = moe.load_balancing_loss();
        let loss_val = lb_loss.data().to_vec()[0];
        // Load balancing loss should be positive
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    // Utilization counts must total num_tokens * top_k assignments.
    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![0.1; 2 * 5 * 64], &[2, 5, 64]).unwrap(),
            false,
        );
        let _output = moe.forward(&input);

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        let total: usize = util.iter().sum();
        // Each of 10 tokens selects top-2 experts = 20 assignments total
        assert_eq!(total, 20);
    }

    // Naming scheme: router.gate.* and experts.<i>.<proj>.*
    #[test]
    fn test_moe_named_parameters() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let named = moe.named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
}