// (removed web-viewer extraction residue: "Skip to main content" and the
// breadcrumb path "axonml_nn/layers/moe.rs" — the real path is documented
// in the module header below)
//! Mixture of Experts - Sparse Expert Routing
//!
//! Implements a Mixture of Experts (MoE) layer where each token is routed to
//! a subset of expert MLPs. This allows massive model capacity (many experts)
//! while keeping per-token compute constant (only top-k experts activate).
//!
//! Key components:
//! - `Expert` — Standard SiLU-gated MLP (up_proj + gate_proj + down_proj)
//! - `MoERouter` — Learned routing: Linear -> softmax -> top-k selection
//! - `MoELayer` — Full MoE: routes tokens to top-k experts, combines outputs
//! - `load_balancing_loss()` — Auxiliary loss preventing expert collapse
//!
//! # File
//! `crates/axonml-nn/src/layers/moe.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 19, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.
27use std::collections::HashMap;
28
29use axonml_autograd::Variable;
30use axonml_tensor::Tensor;
31
32use crate::layers::Linear;
33use crate::module::Module;
34use crate::parameter::Parameter;
35
36// =============================================================================
37// Expert MLP
38// =============================================================================
39
/// A single expert MLP with SiLU-gated architecture (SwiGLU variant).
///
/// Computes: down_proj(SiLU(gate_proj(x)) * up_proj(x))
///
/// This is the same MLP architecture used in LLaMA/Mistral/Phi models.
/// All three projections are constructed bias-free (see `Expert::new`).
///
/// # Shape
/// - Input: (*, d_model)
/// - Output: (*, d_model)
pub struct Expert {
    /// Up projection: d_model -> intermediate_size
    up_proj: Linear,
    /// Gate projection: d_model -> intermediate_size (for SiLU gating)
    gate_proj: Linear,
    /// Down projection: intermediate_size -> d_model
    down_proj: Linear,
}
57
58impl Expert {
59    /// Creates a new Expert MLP.
60    ///
61    /// # Arguments
62    /// * `d_model` - Input/output dimension
63    /// * `intermediate_size` - Hidden dimension (typically 4 * d_model or 8/3 * d_model)
64    pub fn new(d_model: usize, intermediate_size: usize) -> Self {
65        Self {
66            up_proj: Linear::with_bias(d_model, intermediate_size, false),
67            gate_proj: Linear::with_bias(d_model, intermediate_size, false),
68            down_proj: Linear::with_bias(intermediate_size, d_model, false),
69        }
70    }
71}
72
73impl Module for Expert {
74    fn forward(&self, input: &Variable) -> Variable {
75        // SwiGLU: down(SiLU(gate(x)) * up(x))
76        let gate = self.gate_proj.forward(input).silu();
77        let up = self.up_proj.forward(input);
78        let hidden = gate.mul_var(&up);
79        self.down_proj.forward(&hidden)
80    }
81
82    fn parameters(&self) -> Vec<Parameter> {
83        let mut params = Vec::new();
84        params.extend(self.up_proj.parameters());
85        params.extend(self.gate_proj.parameters());
86        params.extend(self.down_proj.parameters());
87        params
88    }
89
90    fn named_parameters(&self) -> HashMap<String, Parameter> {
91        let mut params = HashMap::new();
92        for (name, param) in self.up_proj.named_parameters() {
93            params.insert(format!("up_proj.{name}"), param);
94        }
95        for (name, param) in self.gate_proj.named_parameters() {
96            params.insert(format!("gate_proj.{name}"), param);
97        }
98        for (name, param) in self.down_proj.named_parameters() {
99            params.insert(format!("down_proj.{name}"), param);
100        }
101        params
102    }
103
104    fn name(&self) -> &'static str {
105        "Expert"
106    }
107}
108
109impl std::fmt::Debug for Expert {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        f.debug_struct("Expert")
112            .field("up_proj", &self.up_proj)
113            .field("gate_proj", &self.gate_proj)
114            .field("down_proj", &self.down_proj)
115            .finish()
116    }
117}
118
119// =============================================================================
120// MoE Router
121// =============================================================================
122
/// Router for Mixture of Experts.
///
/// Maps each token's hidden state to expert selection probabilities via a
/// learned linear projection followed by softmax, then selects the top-k
/// experts per token.
///
/// # Routing
/// ```text
/// gate_logits = Linear(x)           [num_tokens, num_experts]
/// gate_probs  = softmax(gate_logits) [num_tokens, num_experts]
/// top_k_probs, top_k_indices = topk(gate_probs, k)
/// ```
pub struct MoERouter {
    /// Gate projection: d_model -> num_experts (bias-free)
    gate: Linear,
    /// Total number of experts the router chooses from
    num_experts: usize,
    /// Number of experts to select per token
    /// (invariant: top_k <= num_experts, enforced in `new`)
    top_k: usize,
}
143
144impl MoERouter {
145    /// Creates a new MoE router.
146    ///
147    /// # Arguments
148    /// * `d_model` - Input hidden dimension
149    /// * `num_experts` - Total number of experts
150    /// * `top_k` - Number of experts activated per token
151    pub fn new(d_model: usize, num_experts: usize, top_k: usize) -> Self {
152        assert!(
153            top_k <= num_experts,
154            "top_k ({top_k}) must be <= num_experts ({num_experts})"
155        );
156        Self {
157            gate: Linear::with_bias(d_model, num_experts, false),
158            num_experts,
159            top_k,
160        }
161    }
162
163    /// Routes tokens to experts.
164    ///
165    /// # Arguments
166    /// * `x` - Hidden states [num_tokens, d_model]
167    ///
168    /// # Returns
169    /// * `gate_probs` - Full probability distribution [num_tokens, num_experts] (for load balancing)
170    /// * `top_k_weights` - Normalized weights for selected experts [num_tokens, top_k]
171    /// * `top_k_indices` - Indices of selected experts [num_tokens, top_k]
172    pub fn route(&self, x: &Variable) -> (Variable, Vec<Vec<f32>>, Vec<Vec<usize>>) {
173        let gate_logits = self.gate.forward(x);
174        let gate_probs = gate_logits.softmax(-1);
175
176        let probs_data = gate_probs.data();
177        let probs_vec = probs_data.to_vec();
178        let num_tokens = probs_data.shape()[0];
179
180        let mut top_k_weights = Vec::with_capacity(num_tokens);
181        let mut top_k_indices = Vec::with_capacity(num_tokens);
182
183        for t in 0..num_tokens {
184            let offset = t * self.num_experts;
185            let token_probs = &probs_vec[offset..offset + self.num_experts];
186
187            // Find top-k experts by probability
188            let mut indexed: Vec<(usize, f32)> = token_probs
189                .iter()
190                .enumerate()
191                .map(|(i, &p)| (i, p))
192                .collect();
193            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
194
195            let top_indices: Vec<usize> = indexed[..self.top_k].iter().map(|(i, _)| *i).collect();
196            let top_weights: Vec<f32> = indexed[..self.top_k].iter().map(|(_, w)| *w).collect();
197
198            // Normalize top-k weights to sum to 1
199            let weight_sum: f32 = top_weights.iter().sum();
200            let normalized: Vec<f32> = if weight_sum > 0.0 {
201                top_weights.iter().map(|w| w / weight_sum).collect()
202            } else {
203                vec![1.0 / self.top_k as f32; self.top_k]
204            };
205
206            top_k_weights.push(normalized);
207            top_k_indices.push(top_indices);
208        }
209
210        (gate_probs, top_k_weights, top_k_indices)
211    }
212
213    /// Returns the number of experts.
214    pub fn num_experts(&self) -> usize {
215        self.num_experts
216    }
217
218    /// Returns top-k value.
219    pub fn top_k(&self) -> usize {
220        self.top_k
221    }
222}
223
224impl std::fmt::Debug for MoERouter {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("MoERouter")
227            .field("num_experts", &self.num_experts)
228            .field("top_k", &self.top_k)
229            .finish()
230    }
231}
232
233// =============================================================================
234// MoE Layer
235// =============================================================================
236
/// Mixture of Experts layer.
///
/// Routes each token to the top-k experts (out of N total), applies the
/// selected expert MLPs, and combines outputs using the routing weights.
///
/// This provides N times the model capacity while only using top-k/N of
/// the compute per token.
///
/// # Shape
/// - Input: (batch, seq_len, d_model)
/// - Output: (batch, seq_len, d_model)
///
/// # Load Balancing
/// Call `load_balancing_loss()` after forward to get the auxiliary loss
/// that prevents expert collapse (all tokens routing to same expert).
pub struct MoELayer {
    /// Expert MLPs (length == `num_experts`)
    experts: Vec<Expert>,
    /// Router for expert selection
    router: MoERouter,
    /// Model dimension
    d_model: usize,
    /// Number of experts
    num_experts: usize,
    /// Top-k routing
    top_k: usize,
    /// Cached gate probabilities from last forward pass (for load balancing loss).
    /// `RwLock` gives interior mutability so `forward(&self)` can record
    /// state through a shared reference.
    last_gate_probs: std::sync::RwLock<Option<Variable>>,
    /// Cached expert assignments from last forward (for utilization stats).
    /// Same `RwLock` interior-mutability pattern as `last_gate_probs`.
    last_expert_counts: std::sync::RwLock<Vec<usize>>,
}
268
269impl MoELayer {
270    /// Creates a new MoE layer.
271    ///
272    /// # Arguments
273    /// * `d_model` - Hidden dimension
274    /// * `intermediate_size` - Expert MLP hidden dimension
275    /// * `num_experts` - Total number of expert MLPs
276    /// * `top_k` - Number of experts activated per token
277    pub fn new(d_model: usize, intermediate_size: usize, num_experts: usize, top_k: usize) -> Self {
278        let experts: Vec<Expert> = (0..num_experts)
279            .map(|_| Expert::new(d_model, intermediate_size))
280            .collect();
281        let router = MoERouter::new(d_model, num_experts, top_k);
282
283        Self {
284            experts,
285            router,
286            d_model,
287            num_experts,
288            top_k,
289            last_gate_probs: std::sync::RwLock::new(None),
290            last_expert_counts: std::sync::RwLock::new(vec![0; num_experts]),
291        }
292    }
293
294    /// Computes load balancing auxiliary loss.
295    ///
296    /// This loss encourages uniform routing across experts to prevent
297    /// expert collapse. Uses the formulation from Switch Transformer:
298    ///
299    ///   L_bal = num_experts * sum_i(f_i * P_i)
300    ///
301    /// where f_i = fraction of tokens assigned to expert i,
302    ///       P_i = mean routing probability for expert i.
303    ///
304    /// Returns zero if no forward pass has been done yet.
305    pub fn load_balancing_loss(&self) -> Variable {
306        let gate_probs_opt = self.last_gate_probs.read().unwrap();
307        if gate_probs_opt.is_none() {
308            return Variable::new(
309                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
310                false,
311            );
312        }
313
314        let gate_probs = gate_probs_opt.as_ref().unwrap();
315        let probs_data = gate_probs.data();
316        let probs_vec = probs_data.to_vec();
317        let shape = probs_data.shape();
318        let num_tokens = shape[0];
319        let num_experts = shape[1];
320
321        if num_tokens == 0 {
322            return Variable::new(
323                Tensor::from_vec(vec![0.0f32], &[1]).expect("tensor creation failed"),
324                false,
325            );
326        }
327
328        let expert_counts = self.last_expert_counts.read().unwrap();
329
330        // f_i: fraction of tokens routed to expert i
331        let token_fractions: Vec<f32> = expert_counts
332            .iter()
333            .map(|&c| c as f32 / num_tokens as f32)
334            .collect();
335
336        // P_i: mean routing probability for expert i
337        let mut mean_probs = vec![0.0f32; num_experts];
338        for t in 0..num_tokens {
339            for e in 0..num_experts {
340                mean_probs[e] += probs_vec[t * num_experts + e];
341            }
342        }
343        for p in &mut mean_probs {
344            *p /= num_tokens as f32;
345        }
346
347        // L_bal = num_experts * sum(f_i * P_i)
348        let mut loss_val = 0.0f32;
349        for e in 0..num_experts {
350            loss_val += token_fractions[e] * mean_probs[e];
351        }
352        loss_val *= num_experts as f32;
353
354        Variable::new(
355            Tensor::from_vec(vec![loss_val], &[1]).expect("tensor creation failed"),
356            false,
357        )
358    }
359
360    /// Returns expert utilization counts from the last forward pass.
361    ///
362    /// Each element is the number of tokens routed to that expert.
363    pub fn expert_utilization(&self) -> Vec<usize> {
364        self.last_expert_counts.read().unwrap().clone()
365    }
366
367    /// Returns the number of experts.
368    pub fn num_experts(&self) -> usize {
369        self.num_experts
370    }
371
372    /// Returns the top-k value.
373    pub fn top_k(&self) -> usize {
374        self.top_k
375    }
376}
377
378impl Module for MoELayer {
379    fn forward(&self, input: &Variable) -> Variable {
380        let shape = input.shape();
381        let batch_size = shape[0];
382        let seq_len = shape[1];
383        let d_model = shape[2];
384        let num_tokens = batch_size * seq_len;
385
386        // Flatten to [num_tokens, d_model]
387        let flat_input = input.reshape(&[num_tokens, d_model]);
388
389        // Route tokens to experts
390        let (gate_probs, top_k_weights, top_k_indices) = self.router.route(&flat_input);
391
392        // Track expert utilization
393        let mut expert_counts = vec![0usize; self.num_experts];
394        for indices in &top_k_indices {
395            for &idx in indices {
396                expert_counts[idx] += 1;
397            }
398        }
399        *self.last_expert_counts.write().unwrap() = expert_counts;
400        *self.last_gate_probs.write().unwrap() = Some(gate_probs);
401
402        // Initialize output as zeros
403        let mut output_data = vec![0.0f32; num_tokens * d_model];
404
405        // Process each expert: gather tokens, forward, scatter back
406        for expert_idx in 0..self.num_experts {
407            // Find which tokens go to this expert and their weights
408            let mut token_indices = Vec::new();
409            let mut token_weights = Vec::new();
410
411            for (t, (indices, weights)) in
412                top_k_indices.iter().zip(top_k_weights.iter()).enumerate()
413            {
414                for (k, (&idx, &w)) in indices.iter().zip(weights.iter()).enumerate() {
415                    if idx == expert_idx {
416                        token_indices.push(t);
417                        token_weights.push(w);
418                        let _ = k;
419                    }
420                }
421            }
422
423            if token_indices.is_empty() {
424                continue;
425            }
426
427            // Gather tokens for this expert
428            let flat_data = flat_input.data();
429            let flat_vec = flat_data.to_vec();
430            let n = token_indices.len();
431            let mut expert_input_data = Vec::with_capacity(n * d_model);
432            for &t in &token_indices {
433                let offset = t * d_model;
434                expert_input_data.extend_from_slice(&flat_vec[offset..offset + d_model]);
435            }
436            let expert_input = Variable::new(
437                Tensor::from_vec(expert_input_data, &[n, d_model]).expect("tensor creation failed"),
438                true,
439            );
440
441            // Forward through expert
442            let expert_output = self.experts[expert_idx].forward(&expert_input);
443            let expert_out_vec = expert_output.data().to_vec();
444
445            // Scatter weighted outputs back
446            for (local_idx, &global_idx) in token_indices.iter().enumerate() {
447                let weight = token_weights[local_idx];
448                let src_offset = local_idx * d_model;
449                let dst_offset = global_idx * d_model;
450                for d in 0..d_model {
451                    output_data[dst_offset + d] += weight * expert_out_vec[src_offset + d];
452                }
453            }
454        }
455
456        let output_tensor =
457            Tensor::from_vec(output_data, &[num_tokens, d_model]).expect("tensor creation failed");
458        let output = Variable::new(output_tensor, true);
459
460        // Reshape back to [batch, seq_len, d_model]
461        output.reshape(&[batch_size, seq_len, d_model])
462    }
463
464    fn parameters(&self) -> Vec<Parameter> {
465        let mut params = Vec::new();
466        params.extend(self.router.gate.parameters());
467        for expert in &self.experts {
468            params.extend(expert.parameters());
469        }
470        params
471    }
472
473    fn named_parameters(&self) -> HashMap<String, Parameter> {
474        let mut params = HashMap::new();
475        for (name, param) in self.router.gate.named_parameters() {
476            params.insert(format!("router.gate.{name}"), param);
477        }
478        for (i, expert) in self.experts.iter().enumerate() {
479            for (name, param) in expert.named_parameters() {
480                params.insert(format!("experts.{i}.{name}"), param);
481            }
482        }
483        params
484    }
485
486    fn name(&self) -> &'static str {
487        "MoELayer"
488    }
489}
490
491impl std::fmt::Debug for MoELayer {
492    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
493        f.debug_struct("MoELayer")
494            .field("d_model", &self.d_model)
495            .field("num_experts", &self.num_experts)
496            .field("top_k", &self.top_k)
497            .field("experts", &self.experts.len())
498            .finish()
499    }
500}
501
502// =============================================================================
503// Tests
504// =============================================================================
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-differentiable input of the given shape, filled with 0.1.
    fn const_input(shape: &[usize]) -> Variable {
        let numel: usize = shape.iter().product();
        Variable::new(
            Tensor::from_vec(vec![0.1; numel], shape).expect("tensor creation failed"),
            false,
        )
    }

    #[test]
    fn test_expert_creation() {
        let expert = Expert::new(64, 256);
        // Three bias-free projections -> exactly 3 weight parameters.
        assert_eq!(expert.parameters().len(), 3);
    }

    #[test]
    fn test_expert_forward() {
        let expert = Expert::new(64, 256);
        let output = expert.forward(&const_input(&[4, 64]));
        assert_eq!(output.shape(), vec![4, 64]);
    }

    #[test]
    fn test_router_creation() {
        let router = MoERouter::new(64, 8, 2);
        assert_eq!(router.num_experts(), 8);
        assert_eq!(router.top_k(), 2);
    }

    #[test]
    fn test_router_route() {
        let router = MoERouter::new(64, 8, 2);
        let (_gate_probs, weights, indices) = router.route(&const_input(&[4, 64]));

        // One weight/index list per token.
        assert_eq!(weights.len(), 4);
        assert_eq!(indices.len(), 4);
        for w in &weights {
            assert_eq!(w.len(), 2); // top-2
            let sum: f32 = w.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "Weights should sum to 1");
        }
        for idx in &indices {
            assert_eq!(idx.len(), 2);
            for &i in idx {
                assert!(i < 8, "Expert index should be < num_experts");
            }
        }
    }

    #[test]
    fn test_moe_layer_forward() {
        let moe = MoELayer::new(64, 256, 8, 2);
        let output = moe.forward(&const_input(&[2, 5, 64]));
        assert_eq!(output.shape(), vec![2, 5, 64]);
    }

    #[test]
    fn test_moe_layer_parameters() {
        let moe = MoELayer::new(64, 256, 8, 2);
        // Router contributes 1 weight; each of the 8 experts contributes 3.
        assert_eq!(moe.parameters().len(), 25);
    }

    #[test]
    fn test_moe_load_balancing_loss() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let _output = moe.forward(&const_input(&[2, 5, 64]));

        let loss_val = moe.load_balancing_loss().data().to_vec()[0];
        assert!(loss_val > 0.0, "Load balancing loss should be > 0");
    }

    #[test]
    fn test_moe_expert_utilization() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let _output = moe.forward(&const_input(&[2, 5, 64]));

        let util = moe.expert_utilization();
        assert_eq!(util.len(), 4);
        // 10 tokens, each picking top-2 experts -> 20 assignments total.
        assert_eq!(util.iter().sum::<usize>(), 20);
    }

    #[test]
    fn test_moe_named_parameters() {
        let moe = MoELayer::new(64, 256, 4, 2);
        let named = moe.named_parameters();
        assert!(named.contains_key("router.gate.weight"));
        assert!(named.contains_key("experts.0.up_proj.weight"));
        assert!(named.contains_key("experts.3.down_proj.weight"));
    }
}