//! Distributed training backend for PMetal.
//!
//! Enables "Home Clusters" by synchronizing gradients across multiple devices
//! (e.g., Mac Studio + MacBook Pro) over standard networks (TCP/IP, Wi-Fi).
//!
//! # Features
//!
//! - **Zero-Configuration Discovery**: Automatically finds peers using mDNS/Bonjour
//! - **Ring All-Reduce**: Bandwidth-optimal gradient synchronization
//! - **Persistent Identity**: Ed25519 keypairs stored at `~/.pmetal/node_keypair`
//! - **Topology Awareness**: Graph-based cluster management with petgraph
//! - **Master Election**: Distributed leader election for coordination
//! - **Health Monitoring**: Heartbeat-based peer health tracking
//! - **Gradient Compression**: TopK, quantization, and error feedback
//! - **Network Isolation**: PSK-based namespace isolation
//! - **Observability**: Comprehensive metrics and tracing
//!
//! # Quick Start (Auto-Discovery)
//!
//! ```ignore
//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext, ReduceOp};
//! use std::time::Duration;
//!
//! // Create backend with automatic peer discovery
//! let backend = AutoDiscoveryBackend::new().await?;
//!
//! // Wait for at least 1 peer to join
//! backend.wait_for_peers(1, Duration::from_secs(30)).await?;
//!
//! // Create context for distributed operations
//! let ctx = DistributedContext::new(Box::new(backend));
//!
//! // Synchronize (average) gradients across the cluster
//! ctx.all_reduce(&mut gradient_buffer, ReduceOp::Mean).await?;
//! ```
//!
//! # Manual Configuration
//!
//! For advanced use cases, you can manually configure peers:
//!
//! ```ignore
//! use pmetal_distributed::{DistributedConfig, RingBackend, DistributedContext};
//!
//! let config = DistributedConfig::new(
//!     vec!["192.168.1.10:52416".parse()?, "192.168.1.11:52416".parse()?],
//!     0, // This node's rank
//! );
//!
//! let backend = RingBackend::new(config).await?;
//! let ctx = DistributedContext::new(Box::new(backend));
//! ```
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                     AutoDiscoveryBackend                         │
//! │                                                                  │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │   Identity   │  │  Discovery   │  │  Topology    │           │
//! │  │  (Ed25519)   │  │   (mDNS)     │  │  (petgraph)  │           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! │          │                │                 │                    │
//! │          └────────────────┼─────────────────┘                    │
//! │                           ▼                                      │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │  Election    │  │   Health     │  │  Collective  │           │
//! │  │  (Master)    │  │  (Heartbeat) │  │  (Strategies)│           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! │          │                │                 │                    │
//! │          └────────────────┼─────────────────┘                    │
//! │                           ▼                                      │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │ Compression  │  │   Metrics    │  │  Namespace   │           │
//! │  │  (TopK/Quant)│  │ (Observ.)    │  │  (PSK)       │           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! └─────────────────────────────────────────────────────────────────┘
//! ```

80use anyhow::Result;
81use async_trait::async_trait;
82
/// Reduction operation for `all_reduce`.
///
/// Fieldless and `Copy`; also derives `Hash` so callers can key hash-based
/// collections (e.g. per-operation tables) by reduction kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Sum all contributions across nodes.
    Sum,
    /// Average all contributions across nodes (sum divided by `world_size`).
    Mean,
}
91
92// Core modules
93pub mod auto;
94pub mod cloud_bridge;
95pub mod config;
96pub mod discovery;
97pub mod error;
98pub mod identity;
99pub mod ring;
100pub mod topology;
101pub mod transport;
102
103// Advanced modules
104pub mod collective;
105pub mod compression;
106pub mod election;
107pub mod health;
108pub mod metrics;
109pub mod namespace;
110
111// Pipeline inference modules
112pub mod activation_codec;
113pub mod activation_transport;
114pub mod layer_assignment;
115pub mod pipeline;
116pub mod solver;
117
118// Re-exports for convenience
119pub use activation_codec::ActivationCodec;
120pub use activation_transport::{ActivationMessage, DtypeTag};
121pub use auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
122pub use collective::{AllReduceStrategy, BroadcastStrategy, CollectiveConfig, ReduceStrategy};
123pub use compression::{CompressionStrategy, GradientCompressor, QuantizationType};
124pub use config::DistributedConfig;
125pub use election::{ElectionConfig, ElectionEvent, ElectionManager, ElectionState};
126pub use error::{DistributedError, DistributedResult};
127pub use health::{HealthConfig, HealthEvent, HealthMonitor, HealthStatus, HealthSummary};
128pub use identity::NodeIdentity;
129pub use layer_assignment::{assign_layers_bandwidth_aware, assign_layers_proportional};
130pub use metrics::{DistributedMetrics, MetricsSnapshot, SharedMetrics};
131pub use namespace::NetworkNamespace;
132pub use pipeline::{
133    PipelineGenerationLoop, PipelineStageConfig, PipelineStageRuntime, StreamMultiplexer,
134};
135pub use ring::RingBackend;
136pub use topology::{ClusterTopology, ConnectionProfile, NodeProfile, SharedTopology};
137// ReduceOp is already public via `pub enum ReduceOp` at module level
138
/// Interface for distributed operations.
///
/// Implemented by concrete transports (e.g. [`RingBackend`],
/// [`AutoDiscoveryBackend`]); [`DistributedContext`] drives implementors via
/// dynamic dispatch, hence the `Send + Sync` bound.
#[async_trait]
pub trait DistributedBackend: Send + Sync {
    /// Get the rank of this node (0 to world_size - 1).
    fn rank(&self) -> usize;

    /// Get the total number of nodes.
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation on a buffer.
    ///
    /// The input buffer contains the local gradients encoded as little-endian
    /// `f32` values.  On return, all nodes hold the same result:
    /// - `ReduceOp::Sum`  – element-wise sum across all nodes.
    /// - `ReduceOp::Mean` – element-wise sum divided by `world_size`.
    async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()>;

    /// Barrier synchronization.
    ///
    /// NOTE(review): `DistributedContext::barrier` documents that no node
    /// proceeds until all have arrived — implementors must uphold that.
    async fn barrier(&self) -> Result<()>;
}
159
/// A handle to the distributed runtime.
///
/// Wraps a boxed [`DistributedBackend`] and optionally records per-operation
/// metrics (latency, bytes, success/failure) when built via
/// [`DistributedContext::with_metrics`].
pub struct DistributedContext {
    // Transport used for all collective operations (dynamic dispatch).
    backend: Box<dyn DistributedBackend>,
    // `Some` only when metrics collection was requested at construction.
    metrics: Option<SharedMetrics>,
}
165
166impl DistributedContext {
167    /// Create a new distributed context with the given backend.
168    pub fn new(backend: Box<dyn DistributedBackend>) -> Self {
169        Self {
170            backend,
171            metrics: None,
172        }
173    }
174
175    /// Create a new distributed context with metrics enabled.
176    pub fn with_metrics(backend: Box<dyn DistributedBackend>, metrics: SharedMetrics) -> Self {
177        Self {
178            backend,
179            metrics: Some(metrics),
180        }
181    }
182
183    /// Get the rank of this node.
184    pub fn rank(&self) -> usize {
185        self.backend.rank()
186    }
187
188    /// Get the total number of nodes in the cluster.
189    pub fn world_size(&self) -> usize {
190        self.backend.world_size()
191    }
192
193    /// Perform an all-reduce operation on the buffer.
194    ///
195    /// After this call, all nodes will have the same values in their buffers.
196    /// `op` controls whether the result is a sum or mean across nodes.
197    pub async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
198        let start = std::time::Instant::now();
199        let result = self.backend.all_reduce(buffer, op).await;
200
201        if let Some(ref metrics) = self.metrics {
202            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
203            metrics.all_reduce.duration_ms.observe(duration_ms);
204            metrics.all_reduce.bytes_processed.add(buffer.len() as u64);
205
206            if result.is_ok() {
207                metrics.all_reduce.completed.inc();
208            } else {
209                metrics.all_reduce.failed.inc();
210            }
211        }
212
213        result
214    }
215
216    /// Synchronize all nodes at a barrier.
217    ///
218    /// All nodes must call this method, and none will proceed until all have.
219    pub async fn barrier(&self) -> Result<()> {
220        let start = std::time::Instant::now();
221        let result = self.backend.barrier().await;
222
223        if let Some(ref metrics) = self.metrics {
224            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
225            metrics.barrier.duration_ms.observe(duration_ms);
226
227            if result.is_ok() {
228                metrics.barrier.completed.inc();
229            } else {
230                metrics.barrier.failed.inc();
231            }
232        }
233
234        result
235    }
236
237    /// Check if this is the master node (rank 0).
238    pub fn is_master(&self) -> bool {
239        self.rank() == 0
240    }
241
242    /// Get metrics snapshot if enabled.
243    pub fn metrics_snapshot(&self) -> Option<MetricsSnapshot> {
244        self.metrics.as_ref().map(|m| m.snapshot())
245    }
246}
247
248/// Prelude for convenient imports.
249pub mod prelude {
250    pub use crate::DistributedBackend;
251    pub use crate::DistributedContext;
252    pub use crate::ReduceOp;
253    pub use crate::auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
254    pub use crate::collective::{AllReduceStrategy, CollectiveConfig};
255    pub use crate::compression::{CompressionStrategy, GradientCompressor};
256    pub use crate::config::DistributedConfig;
257    pub use crate::election::{ElectionConfig, ElectionManager};
258    pub use crate::error::{DistributedError, DistributedResult};
259    pub use crate::health::{HealthConfig, HealthMonitor, HealthStatus};
260    pub use crate::identity::NodeIdentity;
261    pub use crate::metrics::{DistributedMetrics, SharedMetrics};
262    pub use crate::namespace::NetworkNamespace;
263    pub use crate::ring::RingBackend;
264    pub use crate::topology::{ClusterTopology, NodeProfile};
265}