pmetal_distributed/lib.rs
//! Distributed training backend for PMetal.
//!
//! Enables "Home Clusters" by synchronizing gradients across multiple devices
//! (e.g., Mac Studio + MacBook Pro) over standard networks (TCP/IP, Wi-Fi).
//!
//! # Features
//!
//! - **Zero-Configuration Discovery**: Automatically finds peers using mDNS/Bonjour
//! - **Ring All-Reduce**: Bandwidth-optimal gradient synchronization
//! - **Persistent Identity**: Ed25519 keypairs stored at `~/.pmetal/node_keypair`
//! - **Topology Awareness**: Graph-based cluster management with petgraph
//! - **Master Election**: Distributed leader election for coordination
//! - **Health Monitoring**: Heartbeat-based peer health tracking
//! - **Gradient Compression**: TopK, quantization, and error feedback
//! - **Network Isolation**: PSK-based namespace isolation
//! - **Observability**: Comprehensive metrics and tracing
//!
//! # Quick Start (Auto-Discovery)
//!
//! ```ignore
//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext, ReduceOp};
//! use std::time::Duration;
//!
//! // Create backend with automatic peer discovery
//! let backend = AutoDiscoveryBackend::new().await?;
//!
//! // Wait for at least 1 peer to join
//! backend.wait_for_peers(1, Duration::from_secs(30)).await?;
//!
//! // Create context for distributed operations
//! let ctx = DistributedContext::new(Box::new(backend));
//!
//! // Average gradients across the cluster (buffer holds little-endian f32 values)
//! ctx.all_reduce(&mut gradient_buffer, ReduceOp::Mean).await?;
//! ```
//!
//! # Manual Configuration
//!
//! For advanced use cases, you can manually configure peers:
//!
//! ```ignore
//! use pmetal_distributed::{DistributedConfig, RingBackend, DistributedContext};
//!
//! let config = DistributedConfig::new(
//!     vec!["192.168.1.10:52416".parse()?, "192.168.1.11:52416".parse()?],
//!     0, // This node's rank
//! );
//!
//! let backend = RingBackend::new(config).await?;
//! let ctx = DistributedContext::new(Box::new(backend));
//! ```
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                      AutoDiscoveryBackend                       │
//! │                                                                 │
//! │  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐         │
//! │  │   Identity   │   │  Discovery   │   │   Topology   │         │
//! │  │  (Ed25519)   │   │    (mDNS)    │   │  (petgraph)  │         │
//! │  └──────────────┘   └──────────────┘   └──────────────┘         │
//! │         │                  │                  │                 │
//! │         └──────────────────┼──────────────────┘                 │
//! │                            ▼                                    │
//! │  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐         │
//! │  │   Election   │   │    Health    │   │  Collective  │         │
//! │  │   (Master)   │   │ (Heartbeat)  │   │ (Strategies) │         │
//! │  └──────────────┘   └──────────────┘   └──────────────┘         │
//! │         │                  │                  │                 │
//! │         └──────────────────┼──────────────────┘                 │
//! │                            ▼                                    │
//! │  ┌──────────────┐   ┌──────────────┐   ┌──────────────┐         │
//! │  │ Compression  │   │   Metrics    │   │  Namespace   │         │
//! │  │ (TopK/Quant) │   │  (Observ.)   │   │    (PSK)     │         │
//! │  └──────────────┘   └──────────────┘   └──────────────┘         │
//! └─────────────────────────────────────────────────────────────────┘
//! ```

use anyhow::Result;
use async_trait::async_trait;

/// Reduction operation for `all_reduce`.
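///
/// For example, with two nodes holding `[1.0, 2.0]` and `[3.0, 4.0]`, `Sum`
/// leaves every node with `[4.0, 6.0]`, while `Mean` leaves `[2.0, 3.0]`.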
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum all contributions across nodes.
    Sum,
    /// Average all contributions across nodes (sum divided by `world_size`).
    Mean,
}

// Core modules
pub mod auto;
pub mod cloud_bridge;
pub mod config;
pub mod discovery;
pub mod error;
pub mod identity;
pub mod ring;
pub mod topology;
pub mod transport;

// Advanced modules
pub mod collective;
pub mod compression;
pub mod election;
pub mod health;
pub mod metrics;
pub mod namespace;

// Pipeline inference modules
pub mod activation_codec;
pub mod activation_transport;
pub mod layer_assignment;
pub mod pipeline;
pub mod solver;

// Re-exports for convenience
pub use activation_codec::ActivationCodec;
pub use activation_transport::{ActivationMessage, DtypeTag};
pub use auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
pub use collective::{AllReduceStrategy, BroadcastStrategy, CollectiveConfig, ReduceStrategy};
pub use compression::{CompressionStrategy, GradientCompressor, QuantizationType};
pub use config::DistributedConfig;
pub use election::{ElectionConfig, ElectionEvent, ElectionManager, ElectionState};
pub use error::{DistributedError, DistributedResult};
pub use health::{HealthConfig, HealthEvent, HealthMonitor, HealthStatus, HealthSummary};
pub use identity::NodeIdentity;
pub use layer_assignment::{assign_layers_bandwidth_aware, assign_layers_proportional};
pub use metrics::{DistributedMetrics, MetricsSnapshot, SharedMetrics};
pub use namespace::NetworkNamespace;
pub use pipeline::{
    PipelineGenerationLoop, PipelineStageConfig, PipelineStageRuntime, StreamMultiplexer,
};
pub use ring::RingBackend;
pub use topology::{ClusterTopology, ConnectionProfile, NodeProfile, SharedTopology};
// `ReduceOp` is already public via `pub enum ReduceOp` at module level.

/// Interface for distributed operations.
#[async_trait]
pub trait DistributedBackend: Send + Sync {
    /// Get the rank of this node (0 to `world_size - 1`).
    fn rank(&self) -> usize;

    /// Get the total number of nodes.
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation on a buffer.
    ///
    /// The input buffer contains the local gradients encoded as little-endian
    /// `f32` values. On return, all nodes hold the same result:
    /// - `ReduceOp::Sum` – element-wise sum across all nodes.
    /// - `ReduceOp::Mean` – element-wise sum divided by `world_size`.
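    ///
    /// A minimal sketch of the expected byte encoding (assuming the buffer
    /// holds `f32` gradients; `backend` is any `DistributedBackend`):
    ///
    /// ```ignore
    /// // Encode local gradients as little-endian bytes.
    /// let grads = [0.5f32, -1.25, 3.0];
    /// let mut buffer: Vec<u8> = grads.iter().flat_map(|g| g.to_le_bytes()).collect();
    ///
    /// backend.all_reduce(&mut buffer, ReduceOp::Sum).await?;
    ///
    /// // Decode the reduced values.
    /// let reduced: Vec<f32> = buffer
    ///     .chunks_exact(4)
    ///     .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
    ///     .collect();
    /// ```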
    async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()>;

    /// Barrier synchronization.
    async fn barrier(&self) -> Result<()>;
}

/// A handle to the distributed runtime.
pub struct DistributedContext {
    backend: Box<dyn DistributedBackend>,
    metrics: Option<SharedMetrics>,
}

impl DistributedContext {
    /// Create a new distributed context with the given backend.
    pub fn new(backend: Box<dyn DistributedBackend>) -> Self {
        Self {
            backend,
            metrics: None,
        }
    }

    /// Create a new distributed context with metrics enabled.
    pub fn with_metrics(backend: Box<dyn DistributedBackend>, metrics: SharedMetrics) -> Self {
        Self {
            backend,
            metrics: Some(metrics),
        }
    }

    /// Get the rank of this node.
    pub fn rank(&self) -> usize {
        self.backend.rank()
    }

    /// Get the total number of nodes in the cluster.
    pub fn world_size(&self) -> usize {
        self.backend.world_size()
    }

    /// Perform an all-reduce operation on the buffer.
    ///
    /// After this call, all nodes will have the same values in their buffers.
    /// `op` controls whether the result is a sum or mean across nodes.
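    ///
    /// A sketch of gradient averaging during a training step (`gradients` is a
    /// hypothetical local `Vec<f32>`):
    ///
    /// ```ignore
    /// // Pack, average across the cluster, then unpack in place.
    /// let mut buffer: Vec<u8> = gradients.iter().flat_map(|g| g.to_le_bytes()).collect();
    /// ctx.all_reduce(&mut buffer, ReduceOp::Mean).await?;
    /// for (g, c) in gradients.iter_mut().zip(buffer.chunks_exact(4)) {
    ///     *g = f32::from_le_bytes([c[0], c[1], c[2], c[3]]);
    /// }
    /// ```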
    pub async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
        let start = std::time::Instant::now();
        let result = self.backend.all_reduce(buffer, op).await;

        if let Some(ref metrics) = self.metrics {
            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
            metrics.all_reduce.duration_ms.observe(duration_ms);
            metrics.all_reduce.bytes_processed.add(buffer.len() as u64);

            if result.is_ok() {
                metrics.all_reduce.completed.inc();
            } else {
                metrics.all_reduce.failed.inc();
            }
        }

        result
    }

    /// Synchronize all nodes at a barrier.
    ///
    /// All nodes must call this method, and none will proceed until all have.
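    ///
    /// A minimal sketch of a typical use, e.g. at an epoch boundary (the
    /// checkpointing helper is hypothetical):
    ///
    /// ```ignore
    /// // Wait until every node has finished the epoch...
    /// ctx.barrier().await?;
    /// // ...then let only rank 0 write the checkpoint.
    /// if ctx.is_master() {
    ///     save_checkpoint(&model)?;
    /// }
    /// ```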
    pub async fn barrier(&self) -> Result<()> {
        let start = std::time::Instant::now();
        let result = self.backend.barrier().await;

        if let Some(ref metrics) = self.metrics {
            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
            metrics.barrier.duration_ms.observe(duration_ms);

            if result.is_ok() {
                metrics.barrier.completed.inc();
            } else {
                metrics.barrier.failed.inc();
            }
        }

        result
    }

    /// Check if this is the master node (rank 0).
    pub fn is_master(&self) -> bool {
        self.rank() == 0
    }

    /// Get a metrics snapshot if metrics are enabled.
    pub fn metrics_snapshot(&self) -> Option<MetricsSnapshot> {
        self.metrics.as_ref().map(|m| m.snapshot())
    }
}

/// Prelude for convenient imports.
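///
/// Typical usage:
///
/// ```ignore
/// use pmetal_distributed::prelude::*;
/// ```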
pub mod prelude {
    pub use crate::DistributedBackend;
    pub use crate::DistributedContext;
    pub use crate::ReduceOp;
    pub use crate::auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
    pub use crate::collective::{AllReduceStrategy, CollectiveConfig};
    pub use crate::compression::{CompressionStrategy, GradientCompressor};
    pub use crate::config::DistributedConfig;
    pub use crate::election::{ElectionConfig, ElectionManager};
    pub use crate::error::{DistributedError, DistributedResult};
    pub use crate::health::{HealthConfig, HealthMonitor, HealthStatus};
    pub use crate::identity::NodeIdentity;
    pub use crate::metrics::{DistributedMetrics, SharedMetrics};
    pub use crate::namespace::NetworkNamespace;
    pub use crate::ring::RingBackend;
    pub use crate::topology::{ClusterTopology, NodeProfile};
}