use std::ops::Range;
use std::path::Path;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use super::DistributedError;
use super::tensor_parallel_distributed::ParallelismMode;
/// Description of a distributed inference cluster: the model to load and the
/// shards that serve it. Can be read from TOML via [`ClusterConfig::from_file`].
///
/// When deserializing, only `model_path` and `shards` are required; every other
/// field falls back to the serde default noted on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
    /// Location of the model (the tests use a `.gguf` file name).
    pub model_path: String,

    /// Worker shards. [`ClusterConfig::validate`] requires at least one.
    pub shards: Vec<ShardSpec>,

    /// Connect timeout in seconds (default: 10). Exposed as a `Duration` by
    /// [`ClusterConfig::connect_timeout`].
    #[serde(default = "default_connect_timeout")]
    pub connect_timeout_secs: u64,

    /// Request timeout in seconds (default: 30). Exposed as a `Duration` by
    /// [`ClusterConfig::request_timeout`].
    #[serde(default = "default_request_timeout")]
    pub request_timeout_secs: u64,

    /// Whether GPU use is requested (default: `true`).
    #[serde(default = "default_true")]
    pub use_gpu: bool,

    /// Maximum sequence length (default: 0).
    /// NOTE(review): 0 presumably means "no explicit limit / use the model's
    /// own" — confirm against consumers.
    #[serde(default)]
    pub max_seq_len: usize,

    /// Parallelism strategy (default: `ParallelismMode::default()`).
    #[serde(default)]
    pub parallelism: ParallelismMode,

    /// Auto-sharding flag (default: `false`).
    /// NOTE(review): not read anywhere in this module — confirm its consumer.
    #[serde(default)]
    pub auto_shard: bool,

    /// Optional fault-tolerance settings (default: `None`).
    #[serde(default)]
    pub fault_tolerance: Option<super::fault::FaultConfig>,
}
/// Fallback connect timeout, in seconds, when the config omits it.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;

/// Fallback request timeout, in seconds, when the config omits it.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Serde default for `ClusterConfig::connect_timeout_secs`.
fn default_connect_timeout() -> u64 {
    DEFAULT_CONNECT_TIMEOUT_SECS
}

/// Serde default for `ClusterConfig::request_timeout_secs`.
fn default_request_timeout() -> u64 {
    DEFAULT_REQUEST_TIMEOUT_SECS
}

/// Serde default for boolean fields that should be enabled unless the config
/// explicitly disables them (e.g. `ClusterConfig::use_gpu`).
fn default_true() -> bool {
    true
}
/// One worker in the cluster and, optionally, the block of layers it owns.
///
/// `layer_start` / `layer_end` describe the half-open range
/// `layer_start..layer_end`. Both must be set or both omitted; when every
/// shard omits them, layers are distributed automatically by
/// [`ClusterConfig::compute_layer_assignments`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShardSpec {
    /// Identifier used in validation error messages.
    pub name: String,

    /// Network address of the shard (the tests use `host:port` form).
    /// [`ClusterConfig::validate`] rejects an empty address.
    pub address: String,

    /// First layer owned by this shard (inclusive). Defaults to `None`.
    #[serde(default)]
    pub layer_start: Option<usize>,

    /// One past the last layer owned by this shard (exclusive). Defaults to `None`.
    #[serde(default)]
    pub layer_end: Option<usize>,
}
impl ShardSpec {
    /// Returns the manually configured half-open range `layer_start..layer_end`,
    /// or `None` unless both bounds are present.
    ///
    /// The bounds are not checked for ordering here; see
    /// [`ClusterConfig::validate`] for that.
    pub fn layer_range(&self) -> Option<Range<usize>> {
        let (first, past_last) = self.layer_start.zip(self.layer_end)?;
        Some(first..past_last)
    }
}
impl Default for ClusterConfig {
fn default() -> Self {
Self {
model_path: String::new(),
shards: Vec::new(),
connect_timeout_secs: default_connect_timeout(),
request_timeout_secs: default_request_timeout(),
use_gpu: true,
max_seq_len: 0,
parallelism: ParallelismMode::default(),
auto_shard: false,
fault_tolerance: None,
}
}
}
impl ClusterConfig {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self, DistributedError> {
let contents = std::fs::read_to_string(path.as_ref()).map_err(DistributedError::Io)?;
let config: Self = toml::from_str(&contents).map_err(|e| {
DistributedError::Config(format!("failed to parse cluster config: {}", e))
})?;
config.validate()?;
Ok(config)
}
pub fn connect_timeout(&self) -> Duration {
Duration::from_secs(self.connect_timeout_secs)
}
pub fn request_timeout(&self) -> Duration {
Duration::from_secs(self.request_timeout_secs)
}
pub fn validate(&self) -> Result<(), DistributedError> {
if self.shards.is_empty() {
return Err(DistributedError::Config(
"cluster must have at least one shard".into(),
));
}
for shard in &self.shards {
if shard.address.is_empty() {
return Err(DistributedError::Config(format!(
"shard '{}' has empty address",
shard.name
)));
}
if shard.layer_start.is_some() != shard.layer_end.is_some() {
return Err(DistributedError::Config(format!(
"shard '{}': layer_start and layer_end must both be set or both omitted",
shard.name
)));
}
if let (Some(start), Some(end)) = (shard.layer_start, shard.layer_end) {
if start >= end {
return Err(DistributedError::Config(format!(
"shard '{}': layer_start ({}) must be less than layer_end ({})",
shard.name, start, end
)));
}
}
}
Ok(())
}
pub fn compute_layer_assignments(
&self,
num_layers: usize,
) -> Result<Vec<Range<usize>>, DistributedError> {
let all_manual = self.shards.iter().all(|s| s.layer_range().is_some());
let all_auto = self.shards.iter().all(|s| s.layer_range().is_none());
if !all_manual && !all_auto {
return Err(DistributedError::Config(
"either all shards must have manual layer assignments or none".into(),
));
}
if all_manual {
let ranges: Vec<Range<usize>> = self
.shards
.iter()
.map(|s| s.layer_range().unwrap())
.collect();
let total_assigned: usize = ranges.iter().map(|r| r.len()).sum();
if total_assigned != num_layers {
return Err(DistributedError::LayerMismatch {
model_layers: num_layers,
assigned_layers: total_assigned,
});
}
return Ok(ranges);
}
let n_shards = self.shards.len();
let base_layers = num_layers / n_shards;
let remainder = num_layers % n_shards;
let mut assignments = Vec::with_capacity(n_shards);
let mut offset = 0;
for i in 0..n_shards {
let count = base_layers + if i < remainder { 1 } else { 0 };
assignments.push(offset..offset + count);
offset += count;
}
Ok(assignments)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shard with no manual layer assignment.
    fn test_shard(name: &str, addr: &str) -> ShardSpec {
        ShardSpec {
            name: name.into(),
            address: addr.into(),
            layer_start: None,
            layer_end: None,
        }
    }

    /// Shard owning the manual half-open range `start..end`.
    fn manual_shard(name: &str, addr: &str, start: usize, end: usize) -> ShardSpec {
        ShardSpec {
            layer_start: Some(start),
            layer_end: Some(end),
            ..test_shard(name, addr)
        }
    }

    /// Config with the given shards and every other field defaulted.
    fn test_config(shards: Vec<ShardSpec>) -> ClusterConfig {
        ClusterConfig {
            model_path: "model.gguf".into(),
            shards,
            ..Default::default()
        }
    }

    #[test]
    fn test_auto_partition_even() {
        let config = test_config(vec![
            test_shard("a", "host1:50051"),
            test_shard("b", "host2:50051"),
        ]);
        let assignments = config.compute_layer_assignments(32).unwrap();
        assert_eq!(assignments, vec![0..16, 16..32]);
    }

    #[test]
    fn test_auto_partition_uneven() {
        let config = test_config(vec![
            test_shard("a", "h1:50051"),
            test_shard("b", "h2:50051"),
            test_shard("c", "h3:50051"),
        ]);
        let assignments = config.compute_layer_assignments(10).unwrap();
        assert_eq!(assignments, vec![0..4, 4..7, 7..10]);
    }

    #[test]
    fn test_auto_partition_fewer_layers_than_shards() {
        // Trailing shards end up with empty ranges rather than an error.
        let config = test_config(vec![
            test_shard("a", "h1:50051"),
            test_shard("b", "h2:50051"),
            test_shard("c", "h3:50051"),
        ]);
        let assignments = config.compute_layer_assignments(2).unwrap();
        assert_eq!(assignments, vec![0..1, 1..2, 2..2]);
    }

    #[test]
    fn test_manual_partition() {
        let config = test_config(vec![
            manual_shard("a", "h1:50051", 0, 10),
            manual_shard("b", "h2:50051", 10, 32),
        ]);
        let assignments = config.compute_layer_assignments(32).unwrap();
        assert_eq!(assignments, vec![0..10, 10..32]);
    }

    #[test]
    fn test_manual_partition_mismatch() {
        let config = test_config(vec![manual_shard("a", "h1:50051", 0, 10)]);
        let result = config.compute_layer_assignments(32);
        assert!(result.is_err());
    }

    #[test]
    fn test_mixed_manual_and_auto_rejected() {
        let config = test_config(vec![
            manual_shard("a", "h1:50051", 0, 16),
            test_shard("b", "h2:50051"),
        ]);
        assert!(config.compute_layer_assignments(32).is_err());
    }

    #[test]
    fn test_validate_empty_shards() {
        let config = test_config(Vec::new());
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_accepts_well_formed_config() {
        let config = test_config(vec![
            manual_shard("a", "h1:50051", 0, 16),
            manual_shard("b", "h2:50051", 16, 32),
        ]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_empty_address() {
        let config = test_config(vec![test_shard("a", "")]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_half_specified_range() {
        let shard = ShardSpec {
            layer_start: Some(0),
            ..test_shard("a", "h1:50051")
        };
        assert!(test_config(vec![shard]).validate().is_err());
    }

    #[test]
    fn test_validate_rejects_empty_or_inverted_range() {
        let empty = test_config(vec![manual_shard("a", "h1:50051", 8, 8)]);
        assert!(empty.validate().is_err());
        let inverted = test_config(vec![manual_shard("a", "h1:50051", 9, 3)]);
        assert!(inverted.validate().is_err());
    }

    #[test]
    fn test_layer_range() {
        assert_eq!(manual_shard("a", "h1:50051", 4, 9).layer_range(), Some(4..9));
        assert_eq!(test_shard("a", "h1:50051").layer_range(), None);
        let half = ShardSpec {
            layer_end: Some(9),
            ..test_shard("a", "h1:50051")
        };
        assert_eq!(half.layer_range(), None);
    }

    #[test]
    fn test_toml_parse() {
        let toml_str = r#"
model_path = "model.gguf"
connect_timeout_secs = 15
[[shards]]
name = "gpu1"
address = "192.168.1.10:50051"
[[shards]]
name = "gpu2"
address = "192.168.1.11:50051"
"#;
        let config: ClusterConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.shards.len(), 2);
        assert_eq!(config.connect_timeout_secs, 15);
        assert_eq!(config.request_timeout_secs, 30);
        assert_eq!(config.connect_timeout(), Duration::from_secs(15));
        assert_eq!(config.request_timeout(), Duration::from_secs(30));
        // Omitted optional fields take their serde defaults.
        assert!(config.use_gpu);
        assert!(!config.auto_shard);
        assert_eq!(config.max_seq_len, 0);
        assert!(config.fault_tolerance.is_none());
        assert!(config.shards.iter().all(|s| s.layer_range().is_none()));
    }
}