aprender-core 0.29.3

Next-generation machine learning library in pure Rust

//! Mixture of Experts (MoE) Construction from Dense Models (GH-445)
//!
//! Constructs MoE architectures by combining multiple dense models into
//! a single sparse model with learned routing. Each dense model contributes
//! expert FFN weights, and a gating network learns to route tokens to the
//! most relevant experts.
//!
//! # Key Design Decisions
//!
//! - **Round-robin assignment**: Experts are assigned to layers in a balanced
//!   round-robin pattern across source models to ensure diversity
//! - **Load balancing**: Measured via coefficient of variation to detect
//!   routing collapse (all tokens routed to the same expert)
//! - **Router initialization**: Supports random, uniform, and balanced
//!   strategies to prevent early routing bias
//!
//! # References
//!
//! - Shazeer et al. 2017: "Outrageously Large Neural Networks: The
//!   Sparsely-Gated Mixture-of-Experts Layer"
//! - Fedus et al. 2022: "Switch Transformers: Scaling to Trillion
//!   Parameter Models with Simple and Efficient Sparsity"
//! - Zhou et al. 2022: "Mixture-of-Experts with Expert Choice Routing"
//!
//! # Toyota Way Principles
//!
//! - **Heijunka**: Load-balanced expert assignment prevents hotspots
//! - **Jidoka**: Validation stops construction on invalid configurations
//! - **Muda Elimination**: Only activate top-k experts per token
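//!
//! # Example
//!
//! A minimal planning sketch (illustrative, not compiled; the
//! `hidden_dim` of 4096 is a placeholder that would normally come from
//! the source models):
//!
//! ```rust,ignore
//! use aprender::online::moe_construction::{
//!     MoeConfig, compute_gate_weights, plan_moe_construction,
//! };
//!
//! // Combine 4 dense models into an 8-expert MoE with 32 layers.
//! let config = MoeConfig { num_experts: 8, ..Default::default() };
//! let plan = plan_moe_construction(4, 32, &config)?;
//!
//! // One router per layer: hidden_dim x num_experts gate weights.
//! let gate = compute_gate_weights(4096, config.num_experts, plan.router_init);
//! assert_eq!(gate.len(), 4096 * 8);
//! ```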

use crate::error::{AprenderError, Result};

/// Routing method for dispatching tokens to experts.
///
/// Each method trades off between load balance and quality:
/// - `TopK`: Best quality, potential load imbalance
/// - `SwitchTransformer`: Good balance with auxiliary loss
/// - `ExpertChoice`: Perfect balance, experts choose their tokens
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingMethod {
    /// Standard top-k gating (Shazeer et al. 2017).
    /// Each token selects its top-k experts by gate score.
    TopK,
    /// Switch Transformer routing (Fedus et al. 2022).
    /// Each token routed to exactly one expert with capacity factor.
    SwitchTransformer,
    /// Expert Choice routing (Zhou et al. 2022).
    /// Each expert selects its top-k tokens, guaranteeing perfect balance.
    ExpertChoice,
}

impl Default for RoutingMethod {
    fn default() -> Self {
        Self::TopK
    }
}

/// Router weight initialization strategy.
///
/// Controls how the gating network weights are initialized before
/// training. Proper initialization prevents early routing collapse.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterInit {
    /// Random initialization from uniform distribution.
    /// Each weight sampled from U(-scale, scale) where scale = 1/sqrt(hidden_dim).
    Random,
    /// Uniform initialization (all weights equal).
    /// Ensures equal probability for all experts at start.
    Uniform,
    /// Balanced initialization with small perturbation.
    /// Base uniform value with small noise to break symmetry.
    Balanced,
}

impl Default for RouterInit {
    fn default() -> Self {
        Self::Balanced
    }
}

/// Configuration for MoE construction from dense models.
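/// # Example
///
/// Illustrative construction (not compiled); override only the fields
/// you need and let `Default` fill in the rest:
///
/// ```rust,ignore
/// use aprender::online::moe_construction::MoeConfig;
///
/// let config = MoeConfig {
///     num_experts: 16,
///     num_experts_per_tok: 4,
///     ..Default::default()
/// };
/// config.validate()?;
/// ```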
#[derive(Debug, Clone)]
pub struct MoeConfig {
    /// Total number of experts in the MoE layer.
    pub num_experts: usize,
    /// Number of experts activated per token (default: 2).
    pub num_experts_per_tok: usize,
    /// Routing method for token-to-expert dispatch.
    pub routing_method: RoutingMethod,
    /// Hidden dimension for the gating network.
    /// If `None`, uses the model's hidden dimension directly.
    pub gate_hidden_dim: Option<usize>,
}

impl Default for MoeConfig {
    fn default() -> Self {
        Self {
            num_experts: 8,
            num_experts_per_tok: 2,
            routing_method: RoutingMethod::default(),
            gate_hidden_dim: None,
        }
    }
}

impl MoeConfig {
    /// Validate configuration constraints.
    ///
    /// # Errors
    ///
    /// Returns `AprenderError::FormatError` if:
    /// - `num_experts` is zero
    /// - `num_experts_per_tok` is zero or exceeds `num_experts`
    /// - `gate_hidden_dim` is `Some(0)`
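    ///
    /// # Example
    ///
    /// Illustrative failure case (not compiled): top-k may not exceed
    /// the expert count.
    ///
    /// ```rust,ignore
    /// use aprender::online::moe_construction::MoeConfig;
    ///
    /// let bad = MoeConfig { num_experts: 4, num_experts_per_tok: 8, ..Default::default() };
    /// assert!(bad.validate().is_err());
    /// ```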
    pub fn validate(&self) -> Result<()> {
        if self.num_experts == 0 {
            return Err(AprenderError::FormatError {
                message: "num_experts must be > 0".to_string(),
            });
        }
        if self.num_experts_per_tok == 0 {
            return Err(AprenderError::FormatError {
                message: "num_experts_per_tok must be > 0".to_string(),
            });
        }
        if self.num_experts_per_tok > self.num_experts {
            return Err(AprenderError::FormatError {
                message: format!(
                    "num_experts_per_tok ({}) must not exceed num_experts ({})",
                    self.num_experts_per_tok, self.num_experts
                ),
            });
        }
        if self.gate_hidden_dim == Some(0) {
            return Err(AprenderError::FormatError {
                message: "gate_hidden_dim must be > 0 when specified".to_string(),
            });
        }
        Ok(())
    }
}

/// Assignment of a single expert within a layer.
///
/// Maps an expert slot to its source dense model and layer,
/// enabling reconstruction of the MoE from original weights.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpertAssignment {
    /// Index of this expert within the MoE layer (0..num_experts).
    pub expert_index: usize,
    /// Index of the source dense model providing this expert's weights.
    pub source_model: usize,
    /// Layer index within the source model to extract weights from.
    pub source_layer: usize,
}

/// Complete construction plan for building an MoE from dense models.
///
/// Contains per-layer expert assignments and router initialization
/// strategy. Used as a blueprint before actual weight extraction.
#[derive(Debug, Clone)]
pub struct MoeConstructionPlan {
    /// Per-layer expert assignments. `assignments[layer][expert]`.
    pub assignments: Vec<Vec<ExpertAssignment>>,
    /// Number of transformer layers in the MoE model.
    pub num_layers: usize,
    /// Router weight initialization strategy.
    pub router_init: RouterInit,
}

/// Summary report of an MoE construction plan.
#[derive(Debug, Clone)]
pub struct MoeReport {
    /// Total number of experts per layer.
    pub num_experts: usize,
    /// Number of transformer layers.
    pub num_layers: usize,
    /// Load balance score (0.0 = perfectly balanced, higher = worse).
    pub load_balance: f64,
    /// Estimated total parameter count across all experts and routers.
    pub total_params_estimate: u64,
}

/// Plan MoE construction by assigning experts from source models.
///
/// Creates a round-robin assignment where experts are distributed
/// evenly across source models. For each layer, expert `i` is assigned
/// to model `i % num_models`, using the corresponding layer from that
/// source model.
///
/// # Arguments
///
/// * `num_models` - Number of source dense models
/// * `num_layers` - Number of transformer layers in the output MoE
/// * `config` - MoE configuration (validated internally)
///
/// # Errors
///
/// Returns error if `num_models` is zero, `num_layers` is zero, or
/// `config` fails validation.
///
/// # Example
///
/// ```rust,ignore
/// use aprender::online::moe_construction::{MoeConfig, plan_moe_construction};
///
/// let config = MoeConfig { num_experts: 8, ..Default::default() };
/// let plan = plan_moe_construction(4, 32, &config)?;
/// assert_eq!(plan.assignments.len(), 32);
/// assert_eq!(plan.assignments[0].len(), 8);
/// ```
pub fn plan_moe_construction(
    num_models: usize,
    num_layers: usize,
    config: &MoeConfig,
) -> Result<MoeConstructionPlan> {
    if num_models == 0 {
        return Err(AprenderError::FormatError {
            message: "num_models must be > 0".to_string(),
        });
    }
    if num_layers == 0 {
        return Err(AprenderError::FormatError {
            message: "num_layers must be > 0".to_string(),
        });
    }
    config.validate()?;

    let mut assignments = Vec::with_capacity(num_layers);

    for layer_idx in 0..num_layers {
        let mut layer_assignments = Vec::with_capacity(config.num_experts);

        for expert_idx in 0..config.num_experts {
            // Round-robin: distribute experts across source models
            let source_model = expert_idx % num_models;
            let source_layer = layer_idx;

            layer_assignments.push(ExpertAssignment {
                expert_index: expert_idx,
                source_model,
                source_layer,
            });
        }

        assignments.push(layer_assignments);
    }

    Ok(MoeConstructionPlan {
        assignments,
        num_layers,
        router_init: RouterInit::default(),
    })
}

/// Compute initial gate weights for the router network.
///
/// The gate projects from `hidden_dim` to `num_experts`, producing
/// logits that are softmaxed to get routing probabilities.
///
/// # Arguments
///
/// * `hidden_dim` - Input hidden dimension of the transformer
/// * `num_experts` - Number of experts to route to
/// * `init` - Initialization strategy
///
/// # Returns
///
/// Flattened weight matrix of shape `[hidden_dim, num_experts]` in
/// row-major order, consistent with LAYOUT-002.
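///
/// # Example
///
/// Illustrative shape and value check (not compiled); `Uniform` fills
/// every entry with `1.0 / num_experts`:
///
/// ```rust,ignore
/// use aprender::online::moe_construction::{RouterInit, compute_gate_weights};
///
/// let w = compute_gate_weights(16, 4, RouterInit::Uniform);
/// assert_eq!(w.len(), 16 * 4);
/// assert!(w.iter().all(|&x| (x - 0.25).abs() < 1e-12));
/// ```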
#[must_use]
pub fn compute_gate_weights(hidden_dim: usize, num_experts: usize, init: RouterInit) -> Vec<f64> {
    let total = hidden_dim * num_experts;
    if total == 0 {
        return vec![];
    }

    match init {
        RouterInit::Random => {
            // Xavier/Glorot-style scale: 1/sqrt(hidden_dim)
            let scale = 1.0 / (hidden_dim as f64).sqrt();
            // Deterministic pseudo-random values from a simple LCG with a
            // fixed seed. This avoids pulling in rand as a dependency.
            let mut weights = Vec::with_capacity(total);
            let mut state: u64 = 0x5DEE_CE66_D1A4_F681;
            for _ in 0..total {
                // LCG step (Knuth's MMIX multiplier, increment 1)
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1);
                // Top 32 bits give a fraction in [0, 1]; map to [-scale, scale]
                let frac = (state >> 32) as f64 / (u32::MAX as f64);
                weights.push((frac * 2.0 - 1.0) * scale);
            }
            weights
        }
        RouterInit::Uniform => {
            // Equal weight for all experts: 1/num_experts per output
            let val = 1.0 / num_experts as f64;
            vec![val; total]
        }
        RouterInit::Balanced => {
            // Uniform base with small symmetry-breaking perturbation.
            // Perturbation magnitude: 0.01 / sqrt(hidden_dim)
            let base = 1.0 / num_experts as f64;
            let perturbation_scale = 0.01 / (hidden_dim as f64).sqrt();
            let mut weights = Vec::with_capacity(total);
            let mut state: u64 = 0xCAFE_BABE_DEAD_BEEF;
            for _ in 0..total {
                state = state
                    .wrapping_mul(6_364_136_223_846_793_005)
                    .wrapping_add(1);
                let frac = (state >> 32) as f64 / (u32::MAX as f64);
                let noise = (frac * 2.0 - 1.0) * perturbation_scale;
                weights.push(base + noise);
            }
            weights
        }
    }
}

/// Compute load balance across expert assignments.
///
/// Measures how evenly source models are utilized across all layers.
/// Uses coefficient of variation (std_dev / mean) of per-model
/// assignment counts. Returns 0.0 for perfectly balanced plans.
///
/// # Arguments
///
/// * `assignments` - Per-layer expert assignments
///
/// # Returns
///
/// Load balance score where 0.0 = perfectly balanced and higher
/// values indicate worse imbalance.
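///
/// # Example
///
/// Illustrative check (not compiled): round-robin assignment over two
/// source models uses each model equally, so the score is exactly 0.0.
///
/// ```rust,ignore
/// use aprender::online::moe_construction::{
///     MoeConfig, compute_expert_load_balance, plan_moe_construction,
/// };
///
/// let config = MoeConfig { num_experts: 4, ..Default::default() };
/// let plan = plan_moe_construction(2, 1, &config)?;
/// assert_eq!(compute_expert_load_balance(&plan.assignments), 0.0);
/// ```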
#[must_use]
pub fn compute_expert_load_balance(assignments: &[Vec<ExpertAssignment>]) -> f64 {
    if assignments.is_empty() {
        return 0.0;
    }

    // Count how many times each source model is used across all layers
    let max_model = assignments
        .iter()
        .flat_map(|layer| layer.iter())
        .map(|a| a.source_model)
        .max()
        .unwrap_or(0);

    let num_models = max_model + 1;
    let mut counts = vec![0u64; num_models];

    for layer in assignments {
        for assignment in layer {
            counts[assignment.source_model] += 1;
        }
    }

    // Coefficient of variation: std_dev / mean
    let total: u64 = counts.iter().sum();
    if total == 0 {
        return 0.0;
    }

    let mean = total as f64 / num_models as f64;
    if mean == 0.0 {
        return 0.0;
    }

    let variance = counts
        .iter()
        .map(|&c| {
            let diff = c as f64 - mean;
            diff * diff
        })
        .sum::<f64>()
        / num_models as f64;

    variance.sqrt() / mean
}

impl MoeConstructionPlan {
    /// Generate a summary report of this construction plan.
    ///
    /// Estimates total parameters assuming each expert has
    /// `3 * hidden_dim * intermediate_dim` parameters (gate + up + down
    /// projections) plus router parameters.
    ///
    /// # Arguments
    ///
    /// * `hidden_dim` - Hidden dimension of the transformer
    /// * `intermediate_dim` - Intermediate (FFN) dimension
    /// * `num_experts` - Number of experts per layer
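    ///
    /// # Example
    ///
    /// Illustrative estimate (not compiled) for a small model:
    ///
    /// ```rust,ignore
    /// use aprender::online::moe_construction::{MoeConfig, plan_moe_construction};
    ///
    /// let plan = plan_moe_construction(2, 4, &MoeConfig::default())?;
    /// let report = plan.report(512, 2048, 8);
    /// // (3 FFN matrices per expert + router) per layer, times 4 layers
    /// assert_eq!(
    ///     report.total_params_estimate,
    ///     (8 * 3 * 512 * 2048 + 512 * 8) * 4,
    /// );
    /// ```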
    #[must_use]
    pub fn report(
        &self,
        hidden_dim: usize,
        intermediate_dim: usize,
        num_experts: usize,
    ) -> MoeReport {
        let load_balance = compute_expert_load_balance(&self.assignments);

        // Expert FFN params: gate_proj + up_proj + down_proj per expert per layer
        // gate_proj: hidden_dim * intermediate_dim
        // up_proj:   hidden_dim * intermediate_dim
        // down_proj: intermediate_dim * hidden_dim
        let expert_params_per_layer =
            num_experts as u64 * 3 * hidden_dim as u64 * intermediate_dim as u64;

        // Router params per layer: hidden_dim * num_experts
        let router_params_per_layer = hidden_dim as u64 * num_experts as u64;

        let total_params_estimate =
            (expert_params_per_layer + router_params_per_layer) * self.num_layers as u64;

        MoeReport {
            num_experts,
            num_layers: self.num_layers,
            load_balance,
            total_params_estimate,
        }
    }
}

#[cfg(test)]
#[path = "moe_construction_tests.rs"]
mod tests;