Skip to main content

pmetal_distributed/
config.rs

1use crate::error::DistributedError;
2use anyhow::Result;
3use serde::{Deserialize, Serialize};
4use std::collections::HashSet;
5use std::net::SocketAddr;
6
7/// Configuration for distributed training.
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct DistributedConfig {
10    /// List of all nodes in the cluster (IP:Port).
11    /// The order must be consistent across all nodes.
12    pub nodes: Vec<SocketAddr>,
13
14    /// Rank of this node (index into nodes list).
15    pub rank: usize,
16
17    /// Connection timeout in milliseconds (default: 30000).
18    #[serde(default = "default_connection_timeout_ms")]
19    pub connection_timeout_ms: u64,
20
21    /// Maximum connection retry attempts (default: 50).
22    #[serde(default = "default_max_retries")]
23    pub max_retries: u32,
24}
25
/// Default for `DistributedConfig::connection_timeout_ms`: 30 seconds,
/// expressed in milliseconds.
///
/// Referenced by name from the field's `#[serde(default = ...)]` attribute,
/// so the function name must stay in sync with that string.
const fn default_connection_timeout_ms() -> u64 {
    30_000
}
29
/// Default for `DistributedConfig::max_retries`: 50 attempts.
///
/// Referenced by name from the field's `#[serde(default = ...)]` attribute,
/// so the function name must stay in sync with that string.
const fn default_max_retries() -> u32 {
    50
}
33
34impl DistributedConfig {
35    /// Create a new configuration.
36    pub fn new(nodes: Vec<SocketAddr>, rank: usize) -> Self {
37        Self {
38            nodes,
39            rank,
40            connection_timeout_ms: default_connection_timeout_ms(),
41            max_retries: default_max_retries(),
42        }
43    }
44
45    /// Validate the configuration.
46    pub fn validate(&self) -> Result<()> {
47        if self.nodes.is_empty() {
48            return Err(DistributedError::Config("nodes list cannot be empty".to_string()).into());
49        }
50
51        if self.rank >= self.nodes.len() {
52            return Err(DistributedError::Config(format!(
53                "rank {} is out of bounds for {} nodes",
54                self.rank,
55                self.nodes.len()
56            ))
57            .into());
58        }
59
60        // Check for duplicate addresses
61        let unique: HashSet<_> = self.nodes.iter().collect();
62        if unique.len() != self.nodes.len() {
63            return Err(DistributedError::Config(
64                "nodes list contains duplicate addresses".to_string(),
65            )
66            .into());
67        }
68
69        Ok(())
70    }
71
72    /// Get the world size (number of nodes).
73    pub fn world_size(&self) -> usize {
74        self.nodes.len()
75    }
76}