use serde::{Deserialize, Serialize};
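/// Top-level configuration for expert (Mixture-of-Experts) parallelism:
/// routing (experts per token, capacity factor), auxiliary loss weights,
/// sharding strategy, gradient accumulation and synchronization, plus the
/// optional gate-network, load-balancing, and migration sub-configs.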
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParallelismConfig {
pub num_experts: usize,
pub num_experts_per_token: usize,
pub capacity_factor: f32,
pub load_balance_loss_coeff: f32,
pub router_z_loss_coeff: f32,
pub expert_dropout: f32,
pub enable_load_balancing: bool,
pub sharding_strategy: ExpertShardingStrategy,
pub max_expert_batch_size: Option<usize>,
pub enable_gradient_accumulation: bool,
pub gradient_accumulation_steps: usize,
pub initialization_strategy: ExpertInitStrategy,
pub enable_expert_sync: bool,
pub sync_frequency: usize,
pub gate_network: Option<GateNetworkConfig>,
pub load_balancing: Option<LoadBalancingConfig>,
pub migration: Option<ExpertMigrationConfig>,
pub enable_expert_migration: bool,
pub migration_threshold: f32,
pub memory_per_expert_mb: usize,
pub communication_overlap: bool,
pub gradient_compression: bool,
}
impl Default for ExpertParallelismConfig {
fn default() -> Self {
Self {
num_experts: 8,
num_experts_per_token: 2,
capacity_factor: 1.25,
load_balance_loss_coeff: 0.01,
router_z_loss_coeff: 0.001,
expert_dropout: 0.0,
enable_load_balancing: true,
sharding_strategy: ExpertShardingStrategy::ModelParallel,
max_expert_batch_size: None,
enable_gradient_accumulation: false,
gradient_accumulation_steps: 1,
initialization_strategy: ExpertInitStrategy::Xavier,
enable_expert_sync: false,
sync_frequency: 100,
gate_network: None,
load_balancing: None,
migration: None,
enable_expert_migration: false,
migration_threshold: 0.3,
memory_per_expert_mb: 512,
communication_overlap: true,
gradient_compression: false,
}
}
}
impl ExpertParallelismConfig {
pub fn new() -> Self {
Self::default()
}
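    /// Preset for small deployments: 8 data-parallel experts with top-2 routing.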
pub fn small_scale() -> Self {
Self {
num_experts: 8,
num_experts_per_token: 2,
capacity_factor: 1.25,
load_balance_loss_coeff: 0.01,
sharding_strategy: ExpertShardingStrategy::DataParallel,
..Default::default()
}
}
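    /// Preset for large deployments: 128 model-parallel experts with top-2 routing,
    /// gradient accumulation, and periodic expert synchronization enabled.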
pub fn large_scale() -> Self {
Self {
num_experts: 128,
num_experts_per_token: 2,
capacity_factor: 1.5,
load_balance_loss_coeff: 0.001,
sharding_strategy: ExpertShardingStrategy::ModelParallel,
enable_gradient_accumulation: true,
gradient_accumulation_steps: 4,
enable_expert_sync: true,
sync_frequency: 50,
..Default::default()
}
}
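    /// Preset for inference: training-only features (expert dropout, load
    /// balancing, gradient accumulation, expert sync) are disabled.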
pub fn inference() -> Self {
Self {
expert_dropout: 0.0,
enable_load_balancing: false,
enable_gradient_accumulation: false,
enable_expert_sync: false,
..Default::default()
}
}
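    /// Checks that all numeric fields are in range; returns a descriptive
    /// error message for the first violated constraint.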
pub fn validate(&self) -> Result<(), String> {
if self.num_experts == 0 {
return Err("Number of experts must be greater than 0".to_string());
}
if self.num_experts_per_token == 0 || self.num_experts_per_token > self.num_experts {
return Err(
"Number of experts per token must be between 1 and num_experts".to_string(),
);
}
if self.capacity_factor <= 0.0 {
return Err("Capacity factor must be positive".to_string());
}
if self.load_balance_loss_coeff < 0.0 {
return Err("Load balance loss coefficient must be non-negative".to_string());
}
if self.router_z_loss_coeff < 0.0 {
return Err("Router z-loss coefficient must be non-negative".to_string());
}
if self.expert_dropout < 0.0 || self.expert_dropout > 1.0 {
return Err("Expert dropout must be between 0.0 and 1.0".to_string());
}
if self.gradient_accumulation_steps == 0 {
return Err("Gradient accumulation steps must be greater than 0".to_string());
}
if self.sync_frequency == 0 {
return Err("Sync frequency must be greater than 0".to_string());
}
Ok(())
}
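    /// Per-expert token capacity: the even share of routed tokens
    /// (`total_tokens * num_experts_per_token / num_experts`, integer division)
    /// scaled by `capacity_factor` and rounded up.
    ///
    /// With the defaults (8 experts, top-2 routing, capacity factor 1.25),
    /// 1000 tokens give `ceil(250 * 1.25) = 313`.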
pub fn calculate_expert_capacity(&self, total_tokens: usize) -> usize {
let tokens_per_expert = (total_tokens * self.num_experts_per_token) / self.num_experts;
(tokens_per_expert as f32 * self.capacity_factor).ceil() as usize
}
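    /// Heuristic device count for the configured sharding strategy: 1 for
    /// data parallelism, up to 64 (one device per expert) for model
    /// parallelism, and `num_experts / 4` or `num_experts / 2` (clamped)
    /// for the hybrid and dynamic strategies respectively.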
pub fn recommended_num_devices(&self) -> usize {
match self.sharding_strategy {
ExpertShardingStrategy::DataParallel => 1,
ExpertShardingStrategy::ModelParallel => self.num_experts.min(64),
ExpertShardingStrategy::Hybrid => (self.num_experts / 4).clamp(2, 16),
ExpertShardingStrategy::Dynamic => (self.num_experts / 2).clamp(4, 32),
}
}
}
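/// How experts are placed across devices; see `description` for a summary
/// of each variant.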
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertShardingStrategy {
DataParallel,
ModelParallel,
Hybrid,
Dynamic,
}
impl ExpertShardingStrategy {
pub fn description(&self) -> &'static str {
match self {
Self::DataParallel => "All experts replicated on each device",
Self::ModelParallel => "Experts partitioned across devices",
Self::Hybrid => "Mix of replicated and partitioned experts",
Self::Dynamic => "Dynamic expert placement based on load",
}
}
pub fn requires_load_balancing(&self) -> bool {
matches!(self, Self::ModelParallel | Self::Hybrid | Self::Dynamic)
}
pub fn supports_migration(&self) -> bool {
matches!(self, Self::Hybrid | Self::Dynamic)
}
}
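/// Architecture of a single expert feed-forward network: dimensions,
/// activation, depth, dropout, bias, layer-norm epsilon, and init scale.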
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertParameters {
pub input_dim: usize,
pub hidden_dim: usize,
pub output_dim: usize,
pub activation: String,
pub num_layers: usize,
pub dropout: f32,
pub use_bias: bool,
pub layer_norm_eps: f32,
pub init_scale: f32,
}
impl Default for ExpertParameters {
fn default() -> Self {
Self {
input_dim: 512,
hidden_dim: 2048,
output_dim: 512,
activation: "relu".to_string(),
num_layers: 2,
dropout: 0.1,
use_bias: true,
layer_norm_eps: 1e-5,
init_scale: 0.02,
}
}
}
impl ExpertParameters {
pub fn new(input_dim: usize, hidden_dim: usize, output_dim: usize) -> Self {
Self {
input_dim,
hidden_dim,
output_dim,
..Default::default()
}
}
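    /// Standard transformer FFN shape: hidden dimension is 4x the model
    /// dimension, with GELU activation.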
pub fn transformer_ffn(model_dim: usize) -> Self {
Self {
input_dim: model_dim,
hidden_dim: model_dim * 4,
output_dim: model_dim,
activation: "gelu".to_string(),
..Default::default()
}
}
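    /// Smaller expert: 2x hidden expansion, a single layer, and reduced dropout.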
pub fn lightweight(model_dim: usize) -> Self {
Self {
input_dim: model_dim,
hidden_dim: model_dim * 2,
output_dim: model_dim,
num_layers: 1,
dropout: 0.05,
..Default::default()
}
}
pub fn validate(&self) -> Result<(), String> {
if self.input_dim == 0 {
return Err("Input dimension must be greater than 0".to_string());
}
if self.hidden_dim == 0 {
return Err("Hidden dimension must be greater than 0".to_string());
}
if self.output_dim == 0 {
return Err("Output dimension must be greater than 0".to_string());
}
if self.num_layers == 0 {
return Err("Number of layers must be greater than 0".to_string());
}
if self.dropout < 0.0 || self.dropout > 1.0 {
return Err("Dropout must be between 0.0 and 1.0".to_string());
}
if self.layer_norm_eps <= 0.0 {
return Err("Layer norm epsilon must be positive".to_string());
}
if self.init_scale <= 0.0 {
return Err("Initialization scale must be positive".to_string());
}
let valid_activations = ["relu", "gelu", "swish", "tanh", "leaky_relu", "elu"];
if !valid_activations.contains(&self.activation.as_str()) {
return Err(format!(
"Unsupported activation function: {}. Supported: {:?}",
self.activation, valid_activations
));
}
Ok(())
}
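    /// Total weight (and optional bias) count for one expert.
    ///
    /// For `ExpertParameters::new(100, 200, 100)` with the default two layers
    /// and biases this is `100*200 + 200 + 200*100 + 100 = 40300`.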
    pub fn parameter_count(&self) -> usize {
        let bias = |dim: usize| if self.use_bias { dim } else { 0 };
        // Every expert has an input projection (input_dim -> hidden_dim) and an
        // output projection (hidden_dim -> output_dim); layers beyond the first
        // two contribute additional hidden_dim -> hidden_dim blocks.
        let input_layer = self.input_dim * self.hidden_dim + bias(self.hidden_dim);
        let hidden_layers = self.num_layers.saturating_sub(2)
            * (self.hidden_dim * self.hidden_dim + bias(self.hidden_dim));
        let output_layer = self.hidden_dim * self.output_dim + bias(self.output_dim);
        input_layer + hidden_layers + output_layer
    }
}
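/// Weight initialization scheme for expert parameters; see `description`.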
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpertInitStrategy {
Xavier,
Kaiming,
Normal,
Uniform,
TruncatedNormal,
}
impl ExpertInitStrategy {
pub fn description(&self) -> &'static str {
match self {
Self::Xavier => "Xavier/Glorot initialization for balanced gradients",
Self::Kaiming => "Kaiming/He initialization for ReLU networks",
Self::Normal => "Standard normal distribution",
Self::Uniform => "Uniform distribution",
Self::TruncatedNormal => "Truncated normal distribution",
}
}
}
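/// Configuration of the routing gate network, with an optional hierarchical
/// gating stage.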
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateNetworkConfig {
pub hierarchical: Option<HierarchicalGateConfig>,
pub enable_learned_gates: bool,
pub gate_dropout: f32,
pub num_gate_layers: usize,
}
impl Default for GateNetworkConfig {
fn default() -> Self {
Self {
hierarchical: None,
enable_learned_gates: true,
gate_dropout: 0.1,
num_gate_layers: 2,
}
}
}
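/// Hierarchical gating parameters: number of levels, group size, gate hidden
/// width, and how experts are grouped.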
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HierarchicalGateConfig {
pub levels: usize,
pub experts_per_group: usize,
pub gate_hidden_dim: usize,
pub use_learned_grouping: bool,
pub grouping_strategy: GroupingStrategy,
}
impl Default for HierarchicalGateConfig {
fn default() -> Self {
Self {
levels: 2,
experts_per_group: 8,
gate_hidden_dim: 512,
use_learned_grouping: true,
grouping_strategy: GroupingStrategy::LoadBased,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GroupingStrategy {
LoadBased,
SimilarityBased,
Static,
Dynamic,
}
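/// Automatic load-balancing settings: imbalance threshold, check frequency,
/// the cap on concurrent migrations, and a smoothing factor for load estimates.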
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancingConfig {
pub enable_auto_balancing: bool,
pub imbalance_threshold: f32,
pub check_frequency: usize,
pub max_concurrent_migrations: usize,
pub load_smoothing_factor: f32,
}
impl Default for LoadBalancingConfig {
fn default() -> Self {
Self {
enable_auto_balancing: true,
imbalance_threshold: 0.3,
check_frequency: 50,
max_concurrent_migrations: 2,
load_smoothing_factor: 0.9,
}
}
}
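/// Expert migration policy: what triggers a migration, which strategies are
/// preferred, the cooldown between migrations, and how far an expert may move.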
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpertMigrationConfig {
pub enable_migration: bool,
pub triggers: Vec<MigrationTrigger>,
pub preferred_strategies: Vec<MigrationStrategy>,
pub cooldown_period: usize,
pub max_migration_distance: usize,
}
impl Default for ExpertMigrationConfig {
fn default() -> Self {
Self {
enable_migration: false,
triggers: vec![MigrationTrigger::LoadImbalance],
preferred_strategies: vec![MigrationStrategy::GradualMigration],
cooldown_period: 100,
max_migration_distance: 1,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationTrigger {
LoadImbalance,
MemoryPressure,
PerformanceDegradation,
Periodic,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStrategy {
GradualMigration,
CompleteMigration,
LoadRedistribution,
Hybrid,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_expert_parallelism_config_default() {
let config = ExpertParallelismConfig::default();
assert_eq!(config.num_experts, 8);
assert_eq!(config.num_experts_per_token, 2);
assert_eq!(config.capacity_factor, 1.25);
assert!(config.validate().is_ok());
}
#[test]
fn test_expert_parallelism_config_validation() {
let config1 = ExpertParallelismConfig {
num_experts: 0,
..Default::default()
};
assert!(config1.validate().is_err());
let config2 = ExpertParallelismConfig {
num_experts: 8,
num_experts_per_token: 10,
..Default::default()
};
assert!(config2.validate().is_err());
let config3 = ExpertParallelismConfig {
num_experts: 8,
num_experts_per_token: 2,
capacity_factor: -1.0,
..Default::default()
};
assert!(config3.validate().is_err());
}
#[test]
fn test_expert_capacity_calculation() {
let config = ExpertParallelismConfig::default();
let capacity = config.calculate_expert_capacity(1000);
assert_eq!(capacity, 313);
}
#[test]
fn test_sharding_strategy_properties() {
assert!(ExpertShardingStrategy::ModelParallel.requires_load_balancing());
assert!(!ExpertShardingStrategy::DataParallel.requires_load_balancing());
assert!(ExpertShardingStrategy::Dynamic.supports_migration());
assert!(!ExpertShardingStrategy::DataParallel.supports_migration());
}
#[test]
fn test_expert_parameters_default() {
let params = ExpertParameters::default();
assert_eq!(params.input_dim, 512);
assert_eq!(params.hidden_dim, 2048);
assert_eq!(params.output_dim, 512);
assert!(params.validate().is_ok());
}
#[test]
fn test_expert_parameters_transformer_ffn() {
let params = ExpertParameters::transformer_ffn(768);
assert_eq!(params.input_dim, 768);
assert_eq!(params.hidden_dim, 768 * 4);
assert_eq!(params.output_dim, 768);
assert_eq!(params.activation, "gelu");
}
#[test]
fn test_expert_parameters_validation() {
let params1 = ExpertParameters {
input_dim: 0,
..Default::default()
};
assert!(params1.validate().is_err());
let params2 = ExpertParameters {
input_dim: 512,
dropout: 1.5,
..Default::default()
};
assert!(params2.validate().is_err());
let params3 = ExpertParameters {
input_dim: 512,
dropout: 0.1,
activation: "invalid".to_string(),
..Default::default()
};
assert!(params3.validate().is_err());
}
#[test]
fn test_expert_parameters_parameter_count() {
let params = ExpertParameters::new(100, 200, 100);
let count = params.parameter_count();
assert_eq!(count, 40300);
}
#[test]
fn test_preset_configs() {
let small = ExpertParallelismConfig::small_scale();
assert_eq!(small.num_experts, 8);
assert_eq!(
small.sharding_strategy,
ExpertShardingStrategy::DataParallel
);
let large = ExpertParallelismConfig::large_scale();
assert_eq!(large.num_experts, 128);
assert!(large.enable_gradient_accumulation);
let inference = ExpertParallelismConfig::inference();
assert_eq!(inference.expert_dropout, 0.0);
assert!(!inference.enable_load_balancing);
}
#[test]
fn test_recommended_num_devices() {
let config = ExpertParallelismConfig {
num_experts: 32,
sharding_strategy: ExpertShardingStrategy::ModelParallel,
..Default::default()
};
let num_devices = config.recommended_num_devices();
assert_eq!(num_devices, 32);
}
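    // Illustrative composition sketch: wires the optional sub-configs into a
    // large-scale config, using only types defined in this module.
    #[test]
    fn test_composed_config_example() {
        let config = ExpertParallelismConfig {
            gate_network: Some(GateNetworkConfig {
                hierarchical: Some(HierarchicalGateConfig::default()),
                ..Default::default()
            }),
            load_balancing: Some(LoadBalancingConfig::default()),
            migration: Some(ExpertMigrationConfig::default()),
            ..ExpertParallelismConfig::large_scale()
        };
        assert!(config.validate().is_ok());
        assert!(config.sharding_strategy.requires_load_balancing());
    }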
}