Skip to main content

entrenar/moe/
mod.rs

1//! Mixture of Experts (MoE) layer
2//!
3//! Provides a sparse MoE layer where each input token is routed to a subset of
4//! expert networks via a learned gating mechanism. This enables scaling model
5//! capacity without proportionally increasing computation.
6//!
7//! ## Architecture
8//!
9//! - **Router**: Linear gating network with softmax that selects top-k experts per token
//! - **Experts**: Independent two-layer feed-forward networks (weights + biases, ReLU activation)
11//! - **MoeLayer**: Combines router and experts into a single forward pass
12//!
13//! ## Load Balancing
14//!
15//! The `balance_loss()` method computes a Switch Transformer-style auxiliary loss
16//! that penalizes uneven expert utilization, encouraging the router to distribute
17//! tokens uniformly across experts.
18//!
19//! ## References
20//!
21//! - Fedus, W., Zoph, B., & Shazeer, N. (2022). Switch Transformers. JMLR.
22//! - Lepikhin, D., et al. (2021). GShard: Scaling Giant Models. ICLR.
23
24pub mod router;
25
26#[cfg(test)]
27mod tests;
28
29use ndarray::{Array1, Array2};
30use serde::{Deserialize, Serialize};
31
32pub use router::{NoisyTopKRouter, RoutingResult, TopKRouter};
33
34/// Configuration for a Mixture of Experts layer.
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct MoeConfig {
37    /// Number of expert networks
38    pub num_experts: usize,
39    /// Number of experts each token is routed to
40    pub top_k: usize,
41    /// Capacity factor controlling max tokens per expert (typically 1.0-1.5)
42    pub capacity_factor: f32,
43    /// Standard deviation of noise for exploration (0.0 = deterministic)
44    pub noise_std: f32,
45    /// Input/output dimension of each expert
46    pub input_dim: usize,
47    /// Hidden dimension within each expert
48    pub hidden_dim: usize,
49}
50
51impl Default for MoeConfig {
52    fn default() -> Self {
53        Self {
54            num_experts: 8,
55            top_k: 2,
56            capacity_factor: 1.25,
57            noise_std: 0.0,
58            input_dim: 64,
59            hidden_dim: 128,
60        }
61    }
62}
63
64/// A single expert network: a two-layer feed-forward with ReLU activation.
65///
66/// Computes: output = ReLU(input @ W1 + b1) @ W2 + b2
67#[derive(Debug, Clone)]
68pub struct Expert {
69    /// First layer weights: [input_dim, hidden_dim]
70    pub w1: Array2<f32>,
71    /// First layer bias: [hidden_dim]
72    pub b1: Array1<f32>,
73    /// Second layer weights: [hidden_dim, input_dim]
74    pub w2: Array2<f32>,
75    /// Second layer bias: [input_dim]
76    pub b2: Array1<f32>,
77}
78
79impl Expert {
80    /// Create a new expert with Xavier-initialized weights and zero biases.
81    pub fn new(input_dim: usize, hidden_dim: usize) -> Self {
82        let scale1 = (2.0 / (input_dim + hidden_dim) as f32).sqrt();
83        let scale2 = (2.0 / (hidden_dim + input_dim) as f32).sqrt();
84
85        Self {
86            w1: Array2::from_shape_fn((input_dim, hidden_dim), |(i, j)| {
87                ((i * hidden_dim + j) as f32 * 0.3141).sin() * scale1
88            }),
89            b1: Array1::zeros(hidden_dim),
90            w2: Array2::from_shape_fn((hidden_dim, input_dim), |(i, j)| {
91                ((i * input_dim + j) as f32 * 0.2718).sin() * scale2
92            }),
93            b2: Array1::zeros(input_dim),
94        }
95    }
96
97    /// Forward pass through this expert for a single token.
98    ///
99    /// # Arguments
100    /// * `input` - Input vector of length input_dim
101    ///
102    /// # Returns
103    /// Output vector of length input_dim
104    pub fn forward(&self, input: &Array1<f32>) -> Array1<f32> {
105        // hidden = ReLU(input @ W1 + b1)
106        let hidden = input.dot(&self.w1) + &self.b1;
107        let hidden = hidden.mapv(|v| v.max(0.0)); // ReLU
108
109        // output = hidden @ W2 + b2
110        hidden.dot(&self.w2) + &self.b2
111    }
112
113    /// Forward pass for a batch of tokens.
114    ///
115    /// # Arguments
116    /// * `input` - Input matrix of shape [batch_size, input_dim]
117    ///
118    /// # Returns
119    /// Output matrix of shape [batch_size, input_dim]
120    pub fn forward_batch(&self, input: &Array2<f32>) -> Array2<f32> {
121        let hidden = input.dot(&self.w1) + &self.b1;
122        let hidden = hidden.mapv(|v| v.max(0.0));
123        hidden.dot(&self.w2) + &self.b2
124    }
125}
126
127/// Router variant: either deterministic or noisy.
128#[derive(Debug, Clone)]
129pub enum Router {
130    /// Deterministic top-k routing
131    Deterministic(TopKRouter),
132    /// Noisy top-k routing (adds Gaussian noise for exploration)
133    Noisy(NoisyTopKRouter),
134}
135
136impl Router {
137    /// Route input tokens to experts.
138    pub fn route(&self, input: &Array2<f32>) -> RoutingResult {
139        match self {
140            Router::Deterministic(r) => r.route(input),
141            Router::Noisy(r) => r.route(input),
142        }
143    }
144}
145
146/// Mixture of Experts layer combining a router with a set of expert networks.
147#[derive(Debug, Clone)]
148pub struct MoeLayer {
149    /// Configuration
150    pub config: MoeConfig,
151    /// Router (gating network)
152    pub router: Router,
153    /// Expert networks
154    pub experts: Vec<Expert>,
155}
156
157impl MoeLayer {
158    /// Create a new MoE layer from configuration.
159    pub fn new(config: MoeConfig) -> Self {
160        let router_config = router::RouterConfig {
161            input_dim: config.input_dim,
162            num_experts: config.num_experts,
163            top_k: config.top_k,
164            capacity_factor: config.capacity_factor,
165        };
166
167        let router = if config.noise_std > 0.0 {
168            Router::Noisy(NoisyTopKRouter::new(&router_config, config.noise_std))
169        } else {
170            Router::Deterministic(TopKRouter::new(&router_config))
171        };
172
173        let experts = (0..config.num_experts)
174            .map(|_| Expert::new(config.input_dim, config.hidden_dim))
175            .collect();
176
177        Self { config, router, experts }
178    }
179
180    /// Forward pass: route each token to top-k experts and combine outputs.
181    ///
182    /// # Arguments
183    /// * `input` - Input tensor of shape [batch_size, input_dim]
184    ///
185    /// # Returns
186    /// Tuple of (output tensor [batch_size, input_dim], routing result for loss computation)
187    pub fn forward(&self, input: &Array2<f32>) -> (Array2<f32>, RoutingResult) {
188        let batch_size = input.nrows();
189        let input_dim = input.ncols();
190        let routing = self.router.route(input);
191
192        let mut output = Array2::zeros((batch_size, input_dim));
193
194        for i in 0..batch_size {
195            let token = input.row(i).to_owned();
196            let mut combined = Array1::zeros(input_dim);
197
198            for (k, &expert_idx) in routing.expert_indices[i].iter().enumerate() {
199                let weight = routing.expert_weights[i][k];
200                if weight > 0.0 {
201                    let expert_output = self.experts[expert_idx].forward(&token);
202                    combined += &(expert_output * weight);
203                }
204            }
205
206            output.row_mut(i).assign(&combined);
207        }
208
209        (output, routing)
210    }
211
212    /// Compute the Switch Transformer-style load balancing auxiliary loss.
213    ///
214    /// The balance loss encourages uniform expert utilization:
215    ///
216    ///   L_balance = num_experts * sum_i(f_i * P_i)
217    ///
218    /// where:
219    /// - f_i = fraction of tokens dispatched to expert i
220    /// - P_i = mean routing probability for expert i
221    ///
222    /// A perfectly balanced router produces L_balance = 1.0.
223    /// Unbalanced routing produces L_balance > 1.0.
224    ///
225    /// # Arguments
226    /// * `routing` - The routing result from a forward pass
227    ///
228    /// # Returns
229    /// Scalar auxiliary loss value
230    pub fn balance_loss(&self, routing: &RoutingResult) -> f32 {
231        let num_experts = self.config.num_experts;
232        let batch_size = routing.routing_probs.nrows();
233
234        if batch_size == 0 {
235            return 0.0;
236        }
237
238        // f_i: fraction of tokens actually dispatched to each expert
239        let mut dispatch_counts = vec![0usize; num_experts];
240        for token_experts in &routing.expert_indices {
241            for &expert_idx in token_experts {
242                dispatch_counts[expert_idx] += 1;
243            }
244        }
245        let total_dispatches: usize = dispatch_counts.iter().sum();
246        let f: Vec<f32> = dispatch_counts
247            .iter()
248            .map(|&c| if total_dispatches > 0 { c as f32 / total_dispatches as f32 } else { 0.0 })
249            .collect();
250
251        // P_i: mean routing probability for each expert across the batch
252        let p = router::expert_load_fractions(&routing.routing_probs);
253
254        // L_balance = N * sum(f_i * P_i)
255        let dot: f32 = f.iter().zip(p.iter()).map(|(fi, pi)| fi * pi).sum();
256        num_experts as f32 * dot
257    }
258}