entrenar/moe/mod.rs
//! Mixture of Experts (MoE) layer
//!
//! Provides a sparse MoE layer in which each input token is routed to a subset of
//! expert networks via a learned gating mechanism. This enables scaling model
//! capacity without proportionally increasing computation.
//!
//! ## Architecture
//!
//! - **Router**: Linear gating network with softmax that selects top-k experts per token
//! - **Experts**: Independent two-layer feed-forward networks (Linear → ReLU → Linear)
//! - **MoeLayer**: Combines router and experts into a single forward pass
//!
//! ## Load Balancing
//!
//! The `balance_loss()` method computes a Switch Transformer-style auxiliary loss
//! that penalizes uneven expert utilization, encouraging the router to distribute
//! tokens uniformly across experts.
//!
//! ## References
//!
//! - Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers: Scaling to Trillion
//!   Parameter Models with Simple and Efficient Sparsity. JMLR.
//! - Lepikhin, D., et al. (2021). GShard: Scaling Giant Models with Conditional
//!   Computation and Automatic Sharding. ICLR.
23
24pub mod router;
25
26#[cfg(test)]
27mod tests;
28
29use ndarray::{Array1, Array2};
30use serde::{Deserialize, Serialize};
31
32pub use router::{NoisyTopKRouter, RoutingResult, TopKRouter};
33
34/// Configuration for a Mixture of Experts layer.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct MoeConfig {
37 /// Number of expert networks
38 pub num_experts: usize,
39 /// Number of experts each token is routed to
40 pub top_k: usize,
41 /// Capacity factor controlling max tokens per expert (typically 1.0-1.5)
42 pub capacity_factor: f32,
43 /// Standard deviation of noise for exploration (0.0 = deterministic)
44 pub noise_std: f32,
45 /// Input/output dimension of each expert
46 pub input_dim: usize,
47 /// Hidden dimension within each expert
48 pub hidden_dim: usize,
49}
50
51impl Default for MoeConfig {
52 fn default() -> Self {
53 Self {
54 num_experts: 8,
55 top_k: 2,
56 capacity_factor: 1.25,
57 noise_std: 0.0,
58 input_dim: 64,
59 hidden_dim: 128,
60 }
61 }
62}
63
64/// A single expert network: a two-layer feed-forward with ReLU activation.
65///
66/// Computes: output = ReLU(input @ W1 + b1) @ W2 + b2
67#[derive(Debug, Clone)]
68pub struct Expert {
69 /// First layer weights: [input_dim, hidden_dim]
70 pub w1: Array2<f32>,
71 /// First layer bias: [hidden_dim]
72 pub b1: Array1<f32>,
73 /// Second layer weights: [hidden_dim, input_dim]
74 pub w2: Array2<f32>,
75 /// Second layer bias: [input_dim]
76 pub b2: Array1<f32>,
77}
78
79impl Expert {
80 /// Create a new expert with Xavier-initialized weights and zero biases.
81 pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
82 let scale1 = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
83 let scale2 = (2.0 / (hidden_dim + input_dim) as f32).sqrt();
84
85 Self {
86 w1: Array2::from_shape_fn((input_dim, hidden_dim), |(i, j)| {
87 ((i * hidden_dim + j) as f32 * 0.3141).sin() * scale1
88 }),
89 b1: Array1::zeros(hidden_dim),
90 w2: Array2::from_shape_fn((hidden_dim, input_dim), |(i, j)| {
91 ((i * input_dim + j) as f32 * 0.2718).sin() * scale2
92 }),
93 b2: Array1::zeros(input_dim),
94 }
95 }
96
97 /// Forward pass through this expert for a single token.
98 ///
99 /// # Arguments
100 /// * `input` - Input vector of length input_dim
101 ///
102 /// # Returns
103 /// Output vector of length input_dim
104 pub fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
105 // hidden = ReLU(input @ W1 + b1)
106 let hidden = input.dot(&self.w1) + &self.b1;
107 let hidden = hidden.mapv(|v| v.max(0.0)); // ReLU
108
109 // output = hidden @ W2 + b2
110 hidden.dot(&self.w2) + &self.b2
111 }
112
113 /// Forward pass for a batch of tokens.
114 ///
115 /// # Arguments
116 /// * `input` - Input matrix of shape [batch_size, input_dim]
117 ///
118 /// # Returns
119 /// Output matrix of shape [batch_size, input_dim]
120 pub fn forward_batch(&self, input: &Array2<f32>) -> Array2<f32> {
121 let hidden = input.dot(&self.w1) + &self.b1;
122 let hidden = hidden.mapv(|v| v.max(0.0));
123 hidden.dot(&self.w2) + &self.b2
124 }
125}
126
127/// Router variant: either deterministic or noisy.
128#[derive(Debug, Clone)]
129pub enum Router {
130 /// Deterministic top-k routing
131 Deterministic(TopKRouter),
132 /// Noisy top-k routing (adds Gaussian noise for exploration)
133 Noisy(NoisyTopKRouter),
134}
135
136impl Router {
137 /// Route input tokens to experts.
138 pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
139 match self {
140 Router::Deterministic(r) => r.route(input),
141 Router::Noisy(r) => r.route(input),
142 }
143 }
144}
145
146/// Mixture of Experts layer combining a router with a set of expert networks.
147#[derive(Debug, Clone)]
148pub struct MoeLayer {
149 /// Configuration
150 pub config: MoeConfig,
151 /// Router (gating network)
152 pub router: Router,
153 /// Expert networks
154 pub experts: Vec<Expert>,
155}
156
157impl MoeLayer {
158 /// Create a new MoE layer from configuration.
159 pub fn new(config: MoeConfig) -> Self {
160 let router_config = router::RouterConfig {
161 input_dim: config.input_dim,
162 num_experts: config.num_experts,
163 top_k: config.top_k,
164 capacity_factor: config.capacity_factor,
165 };
166
167 let router = if config.noise_std > 0.0 {
168 Router::Noisy(NoisyTopKRouter::new(&router_config, config.noise_std))
169 } else {
170 Router::Deterministic(TopKRouter::new(&router_config))
171 };
172
173 let experts = (0..config.num_experts)
174 .map(|_| Expert::new(config.input_dim, config.hidden_dim))
175 .collect();
176
177 Self { config, router, experts }
178 }
179
180 /// Forward pass: route each token to top-k experts and combine outputs.
181 ///
182 /// # Arguments
183 /// * `input` - Input tensor of shape [batch_size, input_dim]
184 ///
185 /// # Returns
186 /// Tuple of (output tensor [batch_size, input_dim], routing result for loss computation)
187 pub fn forward(&self, input: &Array2<f32>) -> (Array2<f32>, RoutingResult) {
188 let batch_size = input.nrows();
189 let input_dim = input.ncols();
190 let routing = self.router.route(input);
191
192 let mut output = Array2::zeros((batch_size, input_dim));
193
194 for i in 0..batch_size {
195 let token = input.row(i).to_owned();
196 let mut combined = Array1::zeros(input_dim);
197
198 for (k, &expert_idx) in routing.expert_indices[i].iter().enumerate() {
199 let weight = routing.expert_weights[i][k];
200 if weight > 0.0 {
201 let expert_output = self.experts[expert_idx].forward(&token);
202 combined += &(expert_output * weight);
203 }
204 }
205
206 output.row_mut(i).assign(&combined);
207 }
208
209 (output, routing)
210 }
211
212 /// Compute the Switch Transformer-style load balancing auxiliary loss.
213 ///
214 /// The balance loss encourages uniform expert utilization:
215 ///
216 /// L_balance = num_experts * sum_i(f_i * P_i)
217 ///
218 /// where:
219 /// - f_i = fraction of tokens dispatched to expert i
220 /// - P_i = mean routing probability for expert i
221 ///
222 /// A perfectly balanced router produces L_balance = 1.0.
223 /// Unbalanced routing produces L_balance > 1.0.
224 ///
225 /// # Arguments
226 /// * `routing` - The routing result from a forward pass
227 ///
228 /// # Returns
229 /// Scalar auxiliary loss value
230 pub fn balance_loss(&self, routing: &RoutingResult) -> f32 {
231 let num_experts = self.config.num_experts;
232 let batch_size = routing.routing_probs.nrows();
233
234 if batch_size == 0 {
235 return 0.0;
236 }
237
238 // f_i: fraction of tokens actually dispatched to each expert
239 let mut dispatch_counts = vec![0usize; num_experts];
240 for token_experts in &routing.expert_indices {
241 for &expert_idx in token_experts {
242 dispatch_counts[expert_idx] += 1;
243 }
244 }
245 let total_dispatches: usize = dispatch_counts.iter().sum();
246 let f: Vec<f32> = dispatch_counts
247 .iter()
248 .map(|&c| if total_dispatches > 0 { c as f32 / total_dispatches as f32 } else { 0.0 })
249 .collect();
250
251 // P_i: mean routing probability for each expert across the batch
252 let p = router::expert_load_fractions(&routing.routing_probs);
253
254 // L_balance = N * sum(f_i * P_i)
255 let dot: f32 = f.iter().zip(p.iter()).map(|(fi, pi)| fi * pi).sum();
256 num_experts as f32 * dot
257 }
258}