use crate::error::{AprenderError, Result};
/// Strategy the MoE router uses to assign tokens to experts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RoutingMethod {
    /// Top-k routing (the default; see `MoeConfig::num_experts_per_tok`).
    #[default]
    TopK,
    /// Switch-Transformer-style routing.
    SwitchTransformer,
    /// Expert-choice routing.
    ExpertChoice,
}
/// Initialization scheme for the router gate weights
/// (consumed by `compute_gate_weights`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RouterInit {
    /// Scaled symmetric pseudo-random weights.
    Random,
    /// The same value (`1 / num_experts`) for every entry.
    Uniform,
    /// Uniform base value plus a small deterministic perturbation
    /// (the default).
    #[default]
    Balanced,
}
/// Configuration for building a mixture-of-experts model.
///
/// Invariants are enforced by [`MoeConfig::validate`], not by construction.
#[derive(Debug, Clone)]
pub struct MoeConfig {
    /// Total experts per MoE layer; must be > 0.
    pub num_experts: usize,
    /// Experts activated per token; must be > 0 and <= `num_experts`.
    pub num_experts_per_tok: usize,
    /// Strategy the router uses to pick experts.
    pub routing_method: RoutingMethod,
    /// Optional gate hidden dimension; must be > 0 when `Some`.
    pub gate_hidden_dim: Option<usize>,
}

impl Default for MoeConfig {
    /// Defaults to 8 experts with top-2 routing and no gate hidden layer.
    fn default() -> Self {
        Self {
            gate_hidden_dim: None,
            routing_method: RoutingMethod::default(),
            num_experts_per_tok: 2,
            num_experts: 8,
        }
    }
}
impl MoeConfig {
    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    ///
    /// Returns `AprenderError::FormatError` when a count is zero, when
    /// `num_experts_per_tok` exceeds `num_experts`, or when
    /// `gate_hidden_dim` is `Some(0)`.
    pub fn validate(&self) -> Result<()> {
        // Local constructor so every violation is reported uniformly.
        let invalid = |message: String| AprenderError::FormatError { message };
        if self.num_experts == 0 {
            return Err(invalid("num_experts must be > 0".to_string()));
        }
        if self.num_experts_per_tok == 0 {
            return Err(invalid("num_experts_per_tok must be > 0".to_string()));
        }
        if self.num_experts_per_tok > self.num_experts {
            return Err(invalid(format!(
                "num_experts_per_tok ({}) must not exceed num_experts ({})",
                self.num_experts_per_tok, self.num_experts
            )));
        }
        if matches!(self.gate_hidden_dim, Some(0)) {
            return Err(invalid(
                "gate_hidden_dim must be > 0 when specified".to_string(),
            ));
        }
        Ok(())
    }
}
/// Maps one expert slot in an MoE layer back to the dense source model and
/// layer its weights are taken from.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExpertAssignment {
/// Position of this expert within its MoE layer.
pub expert_index: usize,
/// Index of the source model supplying the weights.
pub source_model: usize,
/// Layer in the source model the weights come from.
pub source_layer: usize,
}
/// Blueprint for assembling an MoE model: per-layer expert assignments plus
/// the router-weight initialization scheme.
#[derive(Debug, Clone)]
pub struct MoeConstructionPlan {
/// `assignments[layer][expert]` records where each expert's weights come from.
pub assignments: Vec<Vec<ExpertAssignment>>,
/// Number of layers covered by `assignments`.
pub num_layers: usize,
/// How router gate weights should be initialized.
pub router_init: RouterInit,
}
/// Summary statistics for a construction plan, produced by
/// `MoeConstructionPlan::report`.
#[derive(Debug, Clone)]
pub struct MoeReport {
pub num_experts: usize,
pub num_layers: usize,
/// Coefficient of variation of per-source-model expert counts;
/// 0.0 means perfectly balanced.
pub load_balance: f64,
/// Rough parameter count: expert weights plus router weights per layer.
pub total_params_estimate: u64,
}
pub fn plan_moe_construction(
num_models: usize,
num_layers: usize,
config: &MoeConfig,
) -> Result<MoeConstructionPlan> {
if num_models == 0 {
return Err(AprenderError::FormatError {
message: "num_models must be > 0".to_string(),
});
}
if num_layers == 0 {
return Err(AprenderError::FormatError {
message: "num_layers must be > 0".to_string(),
});
}
config.validate()?;
let mut assignments = Vec::with_capacity(num_layers);
for layer_idx in 0..num_layers {
let mut layer_assignments = Vec::with_capacity(config.num_experts);
for expert_idx in 0..config.num_experts {
let source_model = expert_idx % num_models;
let source_layer = layer_idx;
layer_assignments.push(ExpertAssignment {
expert_index: expert_idx,
source_model,
source_layer,
});
}
assignments.push(layer_assignments);
}
Ok(MoeConstructionPlan {
assignments,
num_layers,
router_init: RouterInit::default(),
})
}
/// Produces a flat `hidden_dim * num_experts` gate-weight matrix for the
/// given initialization scheme. Generation is fully deterministic (fixed
/// LCG seeds), so results are reproducible across runs.
///
/// Returns an empty vector when either dimension is zero.
#[must_use]
pub fn compute_gate_weights(hidden_dim: usize, num_experts: usize, init: RouterInit) -> Vec<f64> {
    let total = hidden_dim * num_experts;
    if total == 0 {
        return vec![];
    }
    // Advances the LCG and returns a uniform sample in [0, 1).
    //
    // BUG FIX: `state >> 33` is a 31-bit quantity, so the divisor must be
    // 2^31. The previous divisor (`u32::MAX`, ~2^32) capped the sample near
    // 0.5, which made every `2x - 1` mapping below land in [-1, 0) — all
    // "symmetric" noise was systematically negative.
    fn lcg_unit(state: &mut u64) -> f64 {
        *state = state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1);
        (*state >> 33) as f64 / (1u64 << 31) as f64
    }
    match init {
        RouterInit::Random => {
            // Symmetric noise in [-scale, scale), scaled by 1/sqrt(hidden_dim).
            let scale = 1.0 / (hidden_dim as f64).sqrt();
            let mut state: u64 = 0x5DEE_CE66_D1A4_F681;
            let mut weights = Vec::with_capacity(total);
            for _ in 0..total {
                weights.push((lcg_unit(&mut state) * 2.0 - 1.0) * scale);
            }
            weights
        }
        RouterInit::Uniform => {
            // Every entry gets the same 1/num_experts weight.
            let val = 1.0 / num_experts as f64;
            vec![val; total]
        }
        RouterInit::Balanced => {
            // Uniform base plus tiny symmetric noise to break ties without
            // disturbing the initial expert balance.
            let base = 1.0 / num_experts as f64;
            let perturbation_scale = 0.01 / (hidden_dim as f64).sqrt();
            let mut state: u64 = 0xCAFE_BABE_DEAD_BEEF;
            let mut weights = Vec::with_capacity(total);
            for _ in 0..total {
                let noise = (lcg_unit(&mut state) * 2.0 - 1.0) * perturbation_scale;
                weights.push(base + noise);
            }
            weights
        }
    }
}
/// Measures how evenly expert slots are spread over source models, as the
/// coefficient of variation (stddev / mean) of per-model counts.
/// 0.0 means perfectly balanced; returns 0.0 for empty input.
#[must_use]
pub fn compute_expert_load_balance(assignments: &[Vec<ExpertAssignment>]) -> f64 {
    if assignments.is_empty() {
        return 0.0;
    }
    // Model indices are dense 0..=max, so the highest index fixes the count.
    let num_models = assignments
        .iter()
        .flatten()
        .map(|a| a.source_model)
        .max()
        .map_or(1, |m| m + 1);
    let mut counts = vec![0u64; num_models];
    for assignment in assignments.iter().flatten() {
        counts[assignment.source_model] += 1;
    }
    let total: u64 = counts.iter().sum();
    if total == 0 {
        // Layers present but all empty — treat as balanced.
        return 0.0;
    }
    // total >= 1 and num_models >= 1, so mean > 0 from here on.
    let mean = total as f64 / num_models as f64;
    let variance = counts
        .iter()
        .map(|&c| (c as f64 - mean).powi(2))
        .sum::<f64>()
        / num_models as f64;
    variance.sqrt() / mean
}
impl MoeConstructionPlan {
    /// Summarizes the plan: the load-balance score for its assignments and
    /// a rough parameter-count estimate.
    ///
    /// The estimate assumes, per layer, `num_experts * 3 * hidden_dim *
    /// intermediate_dim` expert parameters plus a `hidden_dim * num_experts`
    /// router matrix.
    #[must_use]
    pub fn report(
        &self,
        hidden_dim: usize,
        intermediate_dim: usize,
        num_experts: usize,
    ) -> MoeReport {
        let hidden = hidden_dim as u64;
        let inter = intermediate_dim as u64;
        let experts = num_experts as u64;
        // Expert FFN weights plus the router gate, for a single layer.
        let params_per_layer = experts * 3 * hidden * inter + hidden * experts;
        MoeReport {
            num_experts,
            num_layers: self.num_layers,
            load_balance: compute_expert_load_balance(&self.assignments),
            total_params_estimate: params_per_layer * self.num_layers as u64,
        }
    }
}
// Unit tests live in a sibling file; compiled only under `cargo test`.
#[cfg(test)]
#[path = "moe_construction_tests.rs"]
mod tests;