//! 3D parallelism: composes data parallelism (DP), tensor parallelism (TP),
//! and pipeline parallelism (PP) for distributed training. Submodules cover
//! configuration and rank mapping, the coordinator that ties the three axes
//! together, gradient synchronization, memory management, model sharding,
//! performance monitoring, and process-group management.
pub mod config;
pub mod coordinator;
pub mod gradient_sync;
pub mod memory_management;
pub mod model_shards;
pub mod performance;
pub mod process_group;
pub use config::{
CommunicationStrategy, MemoryOptimizationStrategy, MemoryRequirements, PipelineSchedule,
ProcessGroupIds, RankMapping, ThreeDParallelismConfig,
};
pub use coordinator::ThreeDParallelismCoordinator;
pub use gradient_sync::{
    GradientBucketingConfig, GradientCompressionConfig, GradientSynchronizer, SyncStatistics,
};
pub use memory_management::{MemoryManager, MemoryOptimizationResult, MemoryUsageStats};
pub use model_shards::{
    CommunicationPattern, LayerShard, LayerTensorParallelPlan, LayerType, ModelShard, ModelShards,
    ShardInfo, ShardStrategy, TensorParallelShardingPlan,
};
pub use performance::{
    BottleneckSeverity, CommunicationType, Memory3DStats, Performance3DMonitor, Performance3DStats,
    PerformanceAnalysis, PerformanceBottleneck,
};
pub use process_group::{CommunicationStats, ProcessGroupManager};
#[cfg(test)]
mod tests {
use super::*;
use crate::{init_process_group, BackendType};
use std::sync::Arc;
#[test]
fn test_3d_config_validation() {
let config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
pp_size: 2,
num_layers: 24,
..Default::default()
};
assert!(config.validate(8).is_ok());
assert!(config.validate(16).is_err());
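        // dp * tp * pp = 12 matches the world size passed below, so validation
        // should instead fail on the layer split: 25 layers do not divide
        // evenly across 3 pipeline stages.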
let invalid_config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
            pp_size: 3,
            num_layers: 25,
..Default::default()
};
assert!(invalid_config.validate(12).is_err());
}
#[test]
fn test_rank_mapping() {
let config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
pp_size: 2,
..Default::default()
};
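        // The assertions below pin down the rank layout: dp varies slowest and
        // tp fastest, i.e. rank = dp * (pp_size * tp_size) + pp * tp_size + tp.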
let mapping_0 = RankMapping::new(&config, 0);
assert_eq!(
(mapping_0.dp_rank, mapping_0.tp_rank, mapping_0.pp_rank),
(0, 0, 0)
);
let mapping_3 = RankMapping::new(&config, 3);
assert_eq!(
(mapping_3.dp_rank, mapping_3.tp_rank, mapping_3.pp_rank),
(0, 1, 1)
);
let mapping_7 = RankMapping::new(&config, 7);
assert_eq!(
(mapping_7.dp_rank, mapping_7.tp_rank, mapping_7.pp_rank),
(1, 1, 1)
);
let reconstructed_rank = RankMapping::from_3d_coords(&config, 1, 1, 1);
assert_eq!(reconstructed_rank, 7);
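        // Round-trip sketch: under the layout above, every rank should map to
        // unique 3D coordinates and back again.
        for rank in 0..8 {
            let mapping = RankMapping::new(&config, rank);
            assert_eq!(
                RankMapping::from_3d_coords(
                    &config,
                    mapping.dp_rank,
                    mapping.tp_rank,
                    mapping.pp_rank
                ),
                rank
            );
        }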
}
#[test]
fn test_memory_requirements() {
let config = ThreeDParallelismConfig {
dp_size: 1,
tp_size: 2,
pp_size: 4,
num_layers: 24,
memory_strategy: MemoryOptimizationStrategy::Standard,
..Default::default()
};
let requirements = config.memory_requirements();
assert!(requirements.model_memory_mb > 0.0);
assert!(requirements.activation_memory_mb > 0.0);
assert!(requirements.optimizer_memory_mb > 0.0);
assert!(requirements.total_memory_mb > requirements.model_memory_mb);
let basic_config = ThreeDParallelismConfig {
memory_strategy: MemoryOptimizationStrategy::Basic,
..config
};
let basic_requirements = basic_config.memory_requirements();
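        // Standard should have a smaller activation footprint than Basic,
        // presumably via more aggressive activation checkpointing.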
assert!(requirements.activation_memory_mb < basic_requirements.activation_memory_mb);
}
#[test]
fn test_model_shards_creation() {
let config = ThreeDParallelismConfig {
dp_size: 1,
tp_size: 2,
pp_size: 4,
num_layers: 24,
..Default::default()
};
        let model_shards = ModelShards::new(&config).expect("ModelShards::new should succeed");
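        // 24 layers over pp_size = 4 stages should yield 6 layers per stage.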
        assert_eq!(model_shards.pipeline_stages.len(), 4);
        assert_eq!(model_shards.pipeline_stages[0].len(), 6);
assert!(model_shards.total_parameters > 0);
assert_eq!(model_shards.parameters_per_stage.len(), 4);
assert!(!model_shards.shards.is_empty());
let layer_0 = model_shards.get_layer_shard(0);
assert!(layer_0.is_some());
        let layer_shard = layer_0.expect("layer 0 shard should exist");
assert_eq!(layer_shard.layer_id, 0);
assert!(layer_shard.parameter_count() > 0);
}
#[test]
fn test_layer_shard_parameters() {
        let layer = LayerShard::new(0, 4).expect("LayerShard::new should succeed");
assert_eq!(layer.layer_id, 0);
        assert_eq!(layer.weight.shape().dims()[1], 128);
        assert!(layer.parameter_count() > 0);
assert!(layer.memory_usage_bytes() > 0);
let mut layer_with_grads = layer;
layer_with_grads
.init_gradients()
.expect("gradient initialization should succeed");
assert!(layer_with_grads.grad_weight.is_some());
}
#[test]
fn test_layer_types() {
for layer_id in 0..8 {
            let layer = LayerShard::new(layer_id, 2).expect("LayerShard::new should succeed");
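            // The expected layer types cycle with period 4:
            // Embedding, Attention, MLP, Linear.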
let expected_type = match layer_id % 4 {
0 => LayerType::Embedding,
1 => LayerType::Attention,
2 => LayerType::MLP,
_ => LayerType::Linear,
};
assert_eq!(layer.layer_type, expected_type);
if matches!(layer.layer_type, LayerType::MLP) {
assert!(layer.down_projection_weight.is_some());
}
}
}
#[test]
fn test_memory_optimization_strategies() {
let config = ThreeDParallelismConfig {
dp_size: 1,
tp_size: 1,
pp_size: 1,
num_layers: 4,
..Default::default()
};
let rank_mapping = RankMapping::new(&config, 0);
let strategies = [
MemoryOptimizationStrategy::Basic,
MemoryOptimizationStrategy::Standard,
MemoryOptimizationStrategy::Aggressive,
MemoryOptimizationStrategy::Extreme,
];
for strategy in strategies {
let mut test_config = config.clone();
test_config.memory_strategy = strategy;
let memory_manager = MemoryManager::new(&test_config, &rank_mapping);
assert!(memory_manager.is_ok());
}
}
#[test]
fn test_communication_strategies() {
let strategies = [
CommunicationStrategy::AllReduce,
CommunicationStrategy::HierarchicalAllReduce,
CommunicationStrategy::RingAllReduce,
CommunicationStrategy::TreeAllReduce,
CommunicationStrategy::Adaptive,
];
for strategy in strategies {
let config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
pp_size: 2,
comm_strategy: strategy,
..Default::default()
};
assert!(config.validate(8).is_ok());
}
}
#[test]
fn test_performance_monitoring() {
let config = ThreeDParallelismConfig::default();
let rank_mapping = RankMapping::new(&config, 0);
let monitor = Performance3DMonitor::new(&rank_mapping);
let stats = monitor.get_stats();
assert_eq!(stats.forward_passes, 0);
assert_eq!(stats.backward_passes, 0);
assert_eq!(stats.tokens_per_second, 0.0);
let analysis = monitor.get_performance_analysis();
assert_eq!(analysis.overall_throughput, 0.0);
let report = monitor.generate_report();
assert!(report.contains("Performance Report"));
assert!(report.contains("Overall Performance"));
}
#[test]
fn test_gradient_synchronization() {
let config = ThreeDParallelismConfig {
dp_size: 4,
tp_size: 2,
pp_size: 1,
..Default::default()
};
let rank_mapping = RankMapping::new(&config, 0);
        let gradient_sync = GradientSynchronizer::new(&config, &rank_mapping)
            .expect("GradientSynchronizer::new should succeed");
let stats = gradient_sync.get_sync_stats();
assert_eq!(stats.total_sync_operations, 0);
let compression_config = GradientCompressionConfig {
enable_compression: true,
compression_ratio: 0.1,
error_feedback: true,
quantization_bits: 8,
};
assert!(compression_config.compression_ratio > 0.0);
assert!(compression_config.compression_ratio < 1.0);
assert!(compression_config.quantization_bits > 0);
}
#[test]
fn test_pipeline_scheduling() {
let schedules = [
PipelineSchedule::RoundRobin,
PipelineSchedule::Interleaved,
PipelineSchedule::GPipe,
PipelineSchedule::OneForwardOneBackward,
];
for schedule in schedules {
let config = ThreeDParallelismConfig {
dp_size: 1,
tp_size: 1,
pp_size: 4,
num_layers: 16,
pipeline_schedule: schedule,
..Default::default()
};
assert!(config.validate(4).is_ok());
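            // 16 layers over 4 stages = 4 layers per stage, regardless of schedule.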
assert_eq!(config.layers_per_stage(), 4);
}
}
#[test]
fn test_tensor_parallel_sharding() {
let config = ThreeDParallelismConfig {
dp_size: 1,
tp_size: 4,
pp_size: 1,
num_layers: 8,
..Default::default()
};
        let model_shards = ModelShards::new(&config).expect("ModelShards::new should succeed");
let sharding_plan = model_shards.create_tp_sharding_plan(config.tp_size);
let layer_plan = sharding_plan.get_layer_plan(0, 0);
assert!(layer_plan.is_some());
        let plan = layer_plan.expect("layer plan should exist");
assert!(!plan.shard_strategies.is_empty());
assert!(!plan.weight_shape.is_empty());
}
#[test]
fn test_shard_strategies() {
let strategies = [
ShardStrategy::ColumnParallel,
ShardStrategy::RowParallel,
ShardStrategy::VocabParallel,
ShardStrategy::Replicated,
];
for (i, &strategy1) in strategies.iter().enumerate() {
for &strategy2 in &strategies[i + 1..] {
assert_ne!(strategy1, strategy2);
}
}
}
#[test]
fn test_communication_patterns() {
let patterns = [
CommunicationPattern::AllReduce,
CommunicationPattern::AllGatherThenReduceScatter,
CommunicationPattern::ReduceScatterThenAllGather,
CommunicationPattern::None,
];
for (i, &pattern1) in patterns.iter().enumerate() {
for &pattern2 in &patterns[i + 1..] {
assert_ne!(pattern1, pattern2);
}
}
}
#[test]
fn test_bottleneck_severity() {
let severities = [
BottleneckSeverity::Low,
BottleneckSeverity::Medium,
BottleneckSeverity::High,
BottleneckSeverity::Critical,
];
for severity in severities {
let severity_str = severity.as_str();
assert!(!severity_str.is_empty());
}
assert_ne!(BottleneckSeverity::Low, BottleneckSeverity::Critical);
}
#[test]
fn test_memory_statistics() {
let mut stats = Memory3DStats::new();
assert_eq!(stats.model_memory, 0);
assert_eq!(stats.total_memory, 0);
assert_eq!(stats.memory_efficiency, 0.0);
stats.model_memory = 1000;
stats.activation_memory = 500;
stats.total_memory = stats.model_memory + stats.activation_memory;
assert_eq!(stats.total_memory, 1500);
}
#[test]
fn test_configuration_defaults() {
let config = ThreeDParallelismConfig::default();
assert_eq!(config.dp_size, 1);
assert_eq!(config.tp_size, 1);
assert_eq!(config.pp_size, 1);
assert_eq!(config.num_layers, 24);
assert_eq!(config.micro_batch_size, 1);
assert!(!config.enable_gradient_checkpointing);
assert!(!config.enable_mixed_precision);
assert!(config.max_memory_per_device > 0.0);
assert!(config.communication_timeout_ms > 0);
}
#[test]
fn test_rank_groups() {
let config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
pp_size: 2,
..Default::default()
};
let rank_mapping = RankMapping::new(&config, 5);
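        // Under the dp-major, tp-minor layout, rank 5 decomposes to
        // (dp = 1, tp = 1, pp = 0), so each of its three groups has size 2.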
let dp_group = rank_mapping.dp_group_ranks();
        assert_eq!(dp_group.len(), 2);
        assert!(dp_group.contains(&5));
let tp_group = rank_mapping.tp_group_ranks();
        assert_eq!(tp_group.len(), 2);
        assert!(tp_group.contains(&5));
let pp_group = rank_mapping.pp_group_ranks();
        assert_eq!(pp_group.len(), 2);
        assert!(pp_group.contains(&5));
        // Rank 5 sits at pipeline stage 0 (see above), so a next stage exists
        // and a previous one does not.
        let next_rank = rank_mapping.next_pp_rank();
        let prev_rank = rank_mapping.prev_pp_rank();
        assert!(next_rank.is_some());
        assert!(prev_rank.is_none());
}
#[tokio::test]
async fn test_3d_parallelism_integration() {
let pg = init_process_group(BackendType::Gloo, 0, 8, "127.0.0.1", 29500)
.await
.expect("operation should succeed");
let config = ThreeDParallelismConfig {
dp_size: 2,
tp_size: 2,
pp_size: 2,
num_layers: 8,
micro_batch_size: 2,
enable_gradient_checkpointing: true,
memory_strategy: MemoryOptimizationStrategy::Standard,
comm_strategy: CommunicationStrategy::Adaptive,
..Default::default()
};
let coordinator = ThreeDParallelismCoordinator::new(config, Arc::new(pg));
assert!(coordinator.is_ok());
        let coordinator = coordinator.expect("coordinator construction should succeed");
let retrieved_config = coordinator.get_config();
assert_eq!(retrieved_config.dp_size, 2);
assert_eq!(retrieved_config.tp_size, 2);
assert_eq!(retrieved_config.pp_size, 2);
let rank_mapping = coordinator.get_rank_mapping();
assert_eq!(rank_mapping.world_size, 8);
let model_shards = coordinator.get_model_shards();
        assert_eq!(model_shards.pipeline_stages.len(), 2);
        assert!(model_shards.total_parameters > 0);
}
#[test]
fn test_performance_characteristics() {
let config = ThreeDParallelismConfig {
dp_size: 4,
tp_size: 4,
pp_size: 2,
num_layers: 48,
..Default::default()
};
let requirements = config.memory_requirements();
        assert!(requirements.total_memory_mb < 50000.0);
        assert!(requirements.model_memory_mb > 0.0);
        let model_shards = ModelShards::new(&config).expect("ModelShards::new should succeed");
let memory_usage = model_shards.memory_usage_bytes();
assert!(memory_usage > 0);
assert!(memory_usage < 1_000_000_000);
assert!(model_shards
.parameters_per_stage
.iter()
.all(|&count| count > 0));
let total_from_stages: usize = model_shards.parameters_per_stage.iter().sum();
assert_eq!(total_from_stages, model_shards.total_parameters);
}
}