use std::fmt;
use serde::{Deserialize, Serialize};
use crate::hardware::AcceleratorType;
/// How a model is split across devices.
///
/// `None` is the [`Default`] variant. Each parallel variant carries a single
/// count, which [`ShardingStrategy::min_devices`] reports for that variant.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
// NOTE(review): this type derives `Serialize`/`Deserialize` and `Hash`; variant
// names and order are part of its serialized/hashed form, so renaming or
// reordering variants is a compatibility change.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ShardingStrategy {
    /// No sharding. This is the default variant.
    #[default]
    None,
    /// Pipeline parallelism over `num_stages` stages.
    PipelineParallel { num_stages: u32 },
    /// Tensor parallelism across `num_devices` devices.
    TensorParallel { num_devices: u32 },
    /// Data parallelism with `num_replicas` replicas.
    DataParallel { num_replicas: u32 },
}
impl ShardingStrategy {
    /// Returns the minimum number of devices this strategy needs.
    ///
    /// `None` needs exactly one device. Every parallel variant reports the
    /// count it carries (stages, devices, or replicas) unchanged. A stored
    /// count of `0` is passed through as-is rather than clamped.
    #[must_use]
    #[inline]
    pub fn min_devices(&self) -> u32 {
        match *self {
            Self::None => 1,
            // All parallel variants carry exactly one count, so one arm suffices.
            Self::PipelineParallel { num_stages: count }
            | Self::TensorParallel { num_devices: count }
            | Self::DataParallel { num_replicas: count } => count,
        }
    }
}
/// Human-readable strategy name, with the variant's count and unit in
/// parentheses, e.g. `Pipeline Parallel (4 stages)`.
///
/// The unit word is not singularized, so a count of 1 still renders as plural.
impl fmt::Display for ShardingStrategy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::None => f.write_str("None"),
            Self::PipelineParallel { num_stages: n } => write!(f, "Pipeline Parallel ({} stages)", n),
            Self::TensorParallel { num_devices: n } => write!(f, "Tensor Parallel ({} devices)", n),
            Self::DataParallel { num_replicas: n } => write!(f, "Data Parallel ({} replicas)", n),
        }
    }
}
/// One contiguous slice of a model's layers, assigned to a single device.
///
/// `deny_unknown_fields` makes deserialization fail if the input carries any
/// field other than the four declared below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ModelShard {
    /// Identifier of this shard (presumably its position within the owning
    /// plan's shard list — TODO confirm against callers).
    pub shard_id: u32,
    /// Inclusive `(first, last)` layer indices covered by this shard.
    /// A range with `first > last` is reported invalid by `is_valid`.
    pub layer_range: (u32, u32),
    /// Accelerator this shard is placed on.
    pub device: AcceleratorType,
    /// Memory attributed to this shard, in bytes.
    pub memory_bytes: u64,
}
impl ModelShard {
    /// Returns how many layers this shard's inclusive `layer_range` covers.
    ///
    /// Returns `0` for an inverted range (`first > last`, see
    /// [`ModelShard::is_valid`]). The full range `(0, u32::MAX)` holds 2^32
    /// layers, which does not fit in a `u32`. In that case the count saturates
    /// at `u32::MAX` instead of overflowing, which would panic in debug builds
    /// and wrap to `0` in release builds.
    #[must_use]
    #[inline]
    pub fn num_layers(&self) -> u32 {
        let (first, last) = self.layer_range;
        // `checked_sub` is `None` exactly when the range is inverted. The `+ 1`
        // for the inclusive upper bound saturates so `(0, u32::MAX)` is safe.
        last.checked_sub(first)
            .map_or(0, |span| span.saturating_add(1))
    }

    /// Returns `true` when the layer range is ordered (`first <= last`).
    #[must_use]
    #[inline]
    pub fn is_valid(&self) -> bool {
        let (first, last) = self.layer_range;
        first <= last
    }
}
/// A complete placement of a model: its shards, the strategy used, and
/// aggregate estimates.
///
/// `deny_unknown_fields` makes deserialization fail if the input carries any
/// field other than the four declared below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ShardingPlan {
    /// Shards in plan order. Crate-private; read it from outside the crate
    /// via `ShardingPlan::shards`.
    pub(crate) shards: Vec<ModelShard>,
    /// Strategy this plan was built with.
    pub strategy: ShardingStrategy,
    /// Total memory for the plan, in bytes.
    // NOTE(review): presumably the sum of every shard's `memory_bytes`; nothing
    // here enforces that — confirm with the code that builds plans.
    pub total_memory_bytes: u64,
    /// Optional throughput estimate in tokens per second; `None` when no
    /// estimate is available (the `Display` output then omits the line).
    pub estimated_tokens_per_sec: Option<f64>,
}
impl ShardingPlan {
    /// Returns the plan's shards as a read-only slice, in stored order.
    ///
    /// The backing `Vec` is `pub(crate)`, so this is the accessor available to
    /// code outside the crate.
    #[must_use]
    #[inline]
    pub fn shards(&self) -> &[ModelShard] {
        &self.shards
    }
}
/// Multi-line summary of a plan: strategy, total memory, optional throughput,
/// then shard placement.
///
/// - With several shards, it lists each shard's id, device, layer range, and memory.
/// - With exactly one shard, it prints only that shard's device.
/// - With no shards, it prints no placement line.
///
/// Every line ends with a newline.
// NOTE(review): memory is divided by 1024^3 (i.e. GiB) but labelled "GB"; the
// label is left unchanged to keep the output stable for existing consumers.
impl fmt::Display for ShardingPlan {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Byte count -> value shown next to the "GB" label.
        let gb = |bytes: u64| bytes as f64 / (1024.0 * 1024.0 * 1024.0);

        writeln!(f, "Strategy: {}", self.strategy)?;
        writeln!(f, "Total memory: {:.1} GB", gb(self.total_memory_bytes))?;
        if let Some(tps) = self.estimated_tokens_per_sec {
            writeln!(f, "Est. throughput: {:.0} tok/s", tps)?;
        }

        match self.shards.as_slice() {
            [] => {}
            [only] => writeln!(f, "Device: {}", only.device)?,
            many => {
                writeln!(f, "Shards:")?;
                for shard in many {
                    let (first, last) = shard.layer_range;
                    writeln!(
                        f,
                        " [{}] {} — layers {}-{} ({:.1} GB)",
                        shard.shard_id,
                        shard.device,
                        first,
                        last,
                        gb(shard.memory_bytes)
                    )?;
                }
            }
        }
        Ok(())
    }
}