use crate::*;
/// Checks `min_devices` for each strategy listed below: `None` needs a single device, and
/// each parallel strategy needs exactly the count it was constructed with.
#[test]
fn sharding_strategy_min_devices() {
    let cases = [
        (ShardingStrategy::None, 1),
        (ShardingStrategy::PipelineParallel { num_stages: 4 }, 4),
        (ShardingStrategy::TensorParallel { num_devices: 8 }, 8),
        (ShardingStrategy::DataParallel { num_replicas: 2 }, 2),
    ];
    for (strategy, expected) in cases {
        assert_eq!(strategy.min_devices(), expected);
    }
}
/// A shard whose `layer_range` is `(0, 31)` reports 32 layers (the range is counted
/// inclusively) and is considered valid.
#[test]
fn model_shard_num_layers() {
    let shard = ModelShard {
        shard_id: 0,
        device: AcceleratorType::Cpu,
        layer_range: (0, 31),
        memory_bytes: 1024,
    };
    let layer_count = shard.num_layers();
    assert_eq!(layer_count, 32);
    assert!(shard.is_valid());
}
/// A degenerate range whose start equals its end, `(5, 5)`, still counts as one layer and
/// is valid, even with zero memory attached.
#[test]
fn model_shard_single_layer() {
    let one_layer = ModelShard {
        shard_id: 0,
        device: AcceleratorType::Cpu,
        layer_range: (5, 5),
        memory_bytes: 0,
    };
    let layer_count = one_layer.num_layers();
    assert_eq!(layer_count, 1);
    assert!(one_layer.is_valid());
}
/// An inverted range (start after end) yields zero layers rather than underflowing, and the
/// shard is reported as invalid.
#[test]
fn model_shard_invalid_range() {
    let inverted = ModelShard {
        shard_id: 0,
        device: AcceleratorType::Cpu,
        layer_range: (10, 5),
        memory_bytes: 0,
    };
    let layer_count = inverted.num_layers();
    assert_eq!(layer_count, 0);
    assert!(!inverted.is_valid());
}
/// A 1B-parameter model at Int4 on the default registry produces an unsharded, single-shard
/// plan that carries a throughput estimate.
#[test]
fn plan_sharding_small_model_single_device() {
    // NOTE(review): `AcceleratorRegistry::new()` may reflect the host's hardware — confirm
    // this test is deterministic across machines.
    let registry = AcceleratorRegistry::new();
    let param_count = 1_000_000_000;
    let plan = registry.plan_sharding(param_count, &QuantizationLevel::Int4);

    assert_eq!(plan.strategy, ShardingStrategy::None);
    assert_eq!(plan.shards.len(), 1);
    assert!(plan.estimated_tokens_per_sec.is_some());
}
/// With one CPU plus four TPU v5p profiles, a 70B model at BF16 is tensor-parallel over all
/// four TPUs, yielding one shard per TPU.
#[test]
fn plan_sharding_tpu_tensor_parallel() {
    let mut profiles = vec![AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024)];
    profiles.extend((0..4).map(|chip| AcceleratorProfile::tpu(chip, 1, TpuVersion::V5p)));

    let registry = AcceleratorRegistry::from_profiles(profiles);
    let plan = registry.plan_sharding(70_000_000_000, &QuantizationLevel::BFloat16);

    let is_four_way_tp = matches!(
        plan.strategy,
        ShardingStrategy::TensorParallel { num_devices: 4 }
    );
    assert!(is_four_way_tp);
    assert_eq!(plan.shards.len(), 4);
}
/// A 7B model at FP16 over two CUDA profiles (presumably 8 GiB of device memory each) is
/// pipelined into two stages: the first starts at layer 0, the second strictly after it.
#[test]
fn plan_sharding_multi_gpu_pipeline() {
    let profiles = vec![
        AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024),
        AcceleratorProfile::cuda(0, 8 * 1024 * 1024 * 1024),
        AcceleratorProfile::cuda(1, 8 * 1024 * 1024 * 1024),
    ];
    let registry = AcceleratorRegistry::from_profiles(profiles);
    let plan = registry.plan_sharding(7_000_000_000, &QuantizationLevel::Float16);

    assert!(matches!(
        plan.strategy,
        ShardingStrategy::PipelineParallel { .. }
    ));
    assert_eq!(plan.shards.len(), 2);

    let (first_stage, second_stage) = (&plan.shards[0], &plan.shards[1]);
    assert_eq!(first_stage.layer_range.0, 0);
    assert!(second_stage.layer_range.0 > 0);
}
/// On the default registry, a 70B model with no quantization is planned unsharded on the CPU.
#[test]
fn plan_sharding_cpu_fallback() {
    // NOTE(review): relies on `AcceleratorRegistry::new()` exposing no accelerator that could
    // take this model — confirm `new()` does not probe host GPUs, or the test is host-dependent.
    let registry = AcceleratorRegistry::new();
    let plan = registry.plan_sharding(70_000_000_000, &QuantizationLevel::None);

    let first_device = &plan.shards[0].device;
    assert_eq!(first_device, &AcceleratorType::Cpu);
    assert_eq!(plan.strategy, ShardingStrategy::None);
}
/// A 7B model at FP16 on a registry with one large CUDA profile is placed, unsharded, as a
/// single shard on that CUDA device.
#[test]
fn plan_sharding_fits_single_gpu() {
    let cpu = AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024);
    let big_gpu = AcceleratorProfile::cuda(0, 80 * 1024 * 1024 * 1024);
    let registry = AcceleratorRegistry::from_profiles(vec![cpu, big_gpu]);

    let plan = registry.plan_sharding(7_000_000_000, &QuantizationLevel::Float16);

    assert_eq!(plan.strategy, ShardingStrategy::None);
    assert_eq!(plan.shards.len(), 1);
    let placed_on_cuda = matches!(plan.shards[0].device, AcceleratorType::CudaGpu { .. });
    assert!(placed_on_cuda);
}
/// A 70B model at BF16 across two Gaudi3 profiles is split into a two-stage pipeline.
#[test]
fn plan_sharding_gaudi_pipeline() {
    let mut profiles = vec![AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024)];
    profiles.extend((0..2).map(|card| AcceleratorProfile::gaudi(card, GaudiGeneration::Gaudi3)));

    let registry = AcceleratorRegistry::from_profiles(profiles);
    let plan = registry.plan_sharding(70_000_000_000, &QuantizationLevel::BFloat16);

    let is_two_stage_pipeline = matches!(
        plan.strategy,
        ShardingStrategy::PipelineParallel { num_stages: 2 }
    );
    assert!(is_two_stage_pipeline);
}
/// A 1B model at Int4 is kept unsharded on a CUDA profile given `600_000_000` as its second
/// argument, rather than falling back to the CPU.
#[test]
fn plan_sharding_model_exactly_fits_single_device() {
    // NOTE(review): 1B params at 4 bits is presumably ~500_000_000 bytes of weights, so this
    // leaves headroom rather than an exact fit — confirm the test name matches the intent.
    let profiles = vec![
        AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024),
        AcceleratorProfile::cuda(0, 600_000_000),
    ];
    let registry = AcceleratorRegistry::from_profiles(profiles);
    let plan = registry.plan_sharding(1_000_000_000, &QuantizationLevel::Int4);

    assert_eq!(plan.strategy, ShardingStrategy::None);
    let chosen = &plan.shards[0].device;
    assert!(matches!(chosen, AcceleratorType::CudaGpu { .. }));
}
/// A plan produced by `plan_sharding` for a 1B FP16 model on a CPU + CUDA registry is
/// reported as fitting in that same registry's memory.
#[test]
fn plan_fits_in_memory() {
    let reg = AcceleratorRegistry::from_profiles(vec![
        AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024),
        AcceleratorProfile::cuda(0, 24 * 1024 * 1024 * 1024),
    ]);
    let plan = reg.plan_sharding(1_000_000_000, &QuantizationLevel::Float16);
    // Check against the registry the plan was built from.
    assert!(plan.fits_in_memory(&reg));
}
/// A hand-built plan requesting `999 * 1024^3` bytes on the CPU is rejected by a registry
/// whose only profile is a CPU built with `16 * 1024^3`.
#[test]
fn plan_does_not_fit_in_memory() {
    let plan = ShardingPlan {
        shards: vec![ModelShard {
            shard_id: 0,
            layer_range: (0, 0),
            device: AcceleratorType::Cpu,
            memory_bytes: 999 * 1024 * 1024 * 1024,
        }],
        strategy: ShardingStrategy::None,
        total_memory_bytes: 999 * 1024 * 1024 * 1024,
        estimated_tokens_per_sec: None,
    };
    let reg =
        AcceleratorRegistry::from_profiles(vec![AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024)]);
    assert!(!plan.fits_in_memory(&reg));
}
/// A zero-parameter model must not break planning: the plan stays unsharded, and any
/// throughput estimate it reports is finite (e.g. no division-by-zero infinity or NaN).
#[test]
fn plan_sharding_zero_params() {
    let registry = AcceleratorRegistry::new();
    let plan = registry.plan_sharding(0, &QuantizationLevel::Float16);

    assert_eq!(plan.strategy, ShardingStrategy::None);
    // An absent estimate is acceptable; a present one must be finite.
    let estimate_is_sane = plan
        .estimated_tokens_per_sec
        .iter()
        .all(|tps| tps.is_finite());
    assert!(estimate_is_sane);
}
/// A 30B model at FP16 across three identical CUDA profiles becomes a three-stage pipeline.
/// The final stage must end no earlier than the first, and it must be a valid shard.
#[test]
fn plan_sharding_three_gpus_uneven_layers() {
    let mut profiles = vec![AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024)];
    profiles.extend(
        (0..3).map(|ordinal| AcceleratorProfile::cuda(ordinal, 40 * 1024 * 1024 * 1024)),
    );
    let registry = AcceleratorRegistry::from_profiles(profiles);

    let plan = registry.plan_sharding(30_000_000_000, &QuantizationLevel::Float16);
    assert!(matches!(
        plan.strategy,
        ShardingStrategy::PipelineParallel { num_stages: 3 }
    ));

    let final_stage = plan.shards.last().unwrap();
    let opening_stage = &plan.shards[0];
    assert!(final_stage.layer_range.1 >= opening_stage.layer_range.1);
    assert!(final_stage.is_valid());
}
/// When the only GPU profile is marked `available = false`, a 70B FP16 model is planned with
/// its first shard on the CPU.
#[test]
fn plan_sharding_all_devices_unavailable() {
    let disabled_gpu = {
        let mut profile = AcceleratorProfile::cuda(0, 80 * 1024 * 1024 * 1024);
        profile.available = false;
        profile
    };
    let profiles = vec![
        AcceleratorProfile::cpu(16 * 1024 * 1024 * 1024),
        disabled_gpu,
    ];
    let registry = AcceleratorRegistry::from_profiles(profiles);

    let plan = registry.plan_sharding(70_000_000_000, &QuantizationLevel::Float16);
    let first_device = &plan.shards[0].device;
    assert_eq!(first_device, &AcceleratorType::Cpu);
}