// pmetal_distributed — crate root (lib.rs)
#![allow(unsafe_code)] // Raw pointer ops needed for zero-copy allreduce
//! Distributed training backend for PMetal.
//!
//! Enables "Home Clusters" by synchronizing gradients across multiple devices
//! (e.g., Mac Studio + MacBook Pro) over standard networks (TCP/IP, Wi-Fi).
//!
//! # Features
//!
//! - **Zero-Configuration Discovery**: Automatically finds peers using mDNS/Bonjour
//! - **Ring All-Reduce**: Bandwidth-optimal gradient synchronization
//! - **Persistent Identity**: Ed25519 keypairs stored at `~/.pmetal/node_keypair`
//! - **Topology Awareness**: Graph-based cluster management with petgraph
//! - **Master Election**: Distributed leader election for coordination
//! - **Health Monitoring**: Heartbeat-based peer health tracking
//! - **Gradient Compression**: TopK, quantization, and error feedback
//! - **Network Isolation**: PSK-based namespace isolation
//! - **Observability**: Comprehensive metrics and tracing
//!
//! # Quick Start (Auto-Discovery)
//!
//! ```ignore
//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext};
//! use std::time::Duration;
//!
//! // Create backend with automatic peer discovery
//! let backend = AutoDiscoveryBackend::new().await?;
//!
//! // Wait for at least 1 peer to join
//! backend.wait_for_peers(1, Duration::from_secs(30)).await?;
//!
//! // Create context for distributed operations
//! let ctx = DistributedContext::new(Box::new(backend));
//!
//! // Synchronize gradients across cluster
//! ctx.all_reduce(&mut gradient_buffer).await?;
//! ```
//!
//! # Manual Configuration
//!
//! For advanced use cases, you can manually configure peers:
//!
//! ```ignore
//! use pmetal_distributed::{DistributedConfig, RingBackend, DistributedContext};
//!
//! let config = DistributedConfig::new(
//!     vec!["192.168.1.10:52416".parse()?, "192.168.1.11:52416".parse()?],
//!     0, // This node's rank
//! );
//!
//! let backend = RingBackend::new(config).await?;
//! let ctx = DistributedContext::new(Box::new(backend));
//! ```
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                     AutoDiscoveryBackend                         │
//! │                                                                  │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │   Identity   │  │  Discovery   │  │  Topology    │           │
//! │  │  (Ed25519)   │  │   (mDNS)     │  │  (petgraph)  │           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! │          │                │                 │                    │
//! │          └────────────────┼─────────────────┘                    │
//! │                           ▼                                      │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │  Election    │  │   Health     │  │  Collective  │           │
//! │  │  (Master)    │  │  (Heartbeat) │  │  (Strategies)│           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! │          │                │                 │                    │
//! │          └────────────────┼─────────────────┘                    │
//! │                           ▼                                      │
//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
//! │  │ Compression  │  │   Metrics    │  │  Namespace   │           │
//! │  │  (TopK/Quant)│  │ (Observ.)    │  │  (PSK)       │           │
//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
use anyhow::Result;
use async_trait::async_trait;

// Core modules: identity, discovery, transport, and the baseline ring backend.
pub mod auto;
pub mod config;
pub mod discovery;
pub mod error;
pub mod identity;
pub mod ring;
pub mod topology;
pub mod transport;

// Advanced modules: collective strategies, compression, election, health,
// metrics, and PSK-based network namespacing.
pub mod collective;
pub mod compression;
pub mod election;
pub mod health;
pub mod metrics;
pub mod namespace;

// Re-exports for convenience, so callers can write
// `pmetal_distributed::RingBackend` instead of the full module path.
pub use auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
pub use collective::{AllReduceStrategy, BroadcastStrategy, CollectiveConfig, ReduceStrategy};
pub use compression::{CompressionStrategy, GradientCompressor, QuantizationType};
pub use config::DistributedConfig;
pub use election::{ElectionConfig, ElectionEvent, ElectionManager, ElectionState};
pub use error::{DistributedError, DistributedResult};
pub use health::{HealthConfig, HealthEvent, HealthMonitor, HealthStatus, HealthSummary};
pub use identity::NodeIdentity;
pub use metrics::{DistributedMetrics, MetricsSnapshot, SharedMetrics};
pub use namespace::NetworkNamespace;
pub use ring::RingBackend;
pub use topology::{ClusterTopology, ConnectionProfile, NodeProfile, SharedTopology};
/// Interface for distributed operations.
///
/// Implementations (e.g. [`RingBackend`], [`AutoDiscoveryBackend`]) supply the
/// transport and collective algorithms; [`DistributedContext`] wraps a boxed
/// `DistributedBackend` and layers optional metrics on top.
#[async_trait]
pub trait DistributedBackend: Send + Sync {
    /// Get the rank of this node (0 to world_size - 1).
    fn rank(&self) -> usize;

    /// Get the total number of nodes.
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation on a buffer (sum).
    ///
    /// The input buffer contains the local gradients.
    /// On return, it should contain the sum of gradients from all nodes.
    ///
    /// # Errors
    ///
    /// Implementation-defined; expected to fail if communication with a peer
    /// fails.
    async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()>;

    /// Barrier synchronization.
    ///
    /// All nodes must call this, and none should proceed until all have.
    ///
    /// # Errors
    ///
    /// Implementation-defined; expected to fail if communication with a peer
    /// fails.
    async fn barrier(&self) -> Result<()>;
}
134
/// A handle to the distributed runtime.
///
/// Wraps a [`DistributedBackend`] and, when constructed via
/// [`DistributedContext::with_metrics`], records per-operation latency,
/// byte counts, and success/failure counters.
pub struct DistributedContext {
    // The underlying transport/collective implementation.
    backend: Box<dyn DistributedBackend>,
    // Optional metrics sink; `None` means metrics collection is disabled.
    metrics: Option<SharedMetrics>,
}
140
141impl DistributedContext {
142    /// Create a new distributed context with the given backend.
143    pub fn new(backend: Box<dyn DistributedBackend>) -> Self {
144        Self {
145            backend,
146            metrics: None,
147        }
148    }
149
150    /// Create a new distributed context with metrics enabled.
151    pub fn with_metrics(backend: Box<dyn DistributedBackend>, metrics: SharedMetrics) -> Self {
152        Self {
153            backend,
154            metrics: Some(metrics),
155        }
156    }
157
158    /// Get the rank of this node.
159    pub fn rank(&self) -> usize {
160        self.backend.rank()
161    }
162
163    /// Get the total number of nodes in the cluster.
164    pub fn world_size(&self) -> usize {
165        self.backend.world_size()
166    }
167
168    /// Perform an all-reduce operation (sum) on the buffer.
169    ///
170    /// After this call, all nodes will have the same values in their buffers,
171    /// which is the element-wise sum of all input buffers.
172    pub async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
173        let start = std::time::Instant::now();
174        let result = self.backend.all_reduce(buffer).await;
175
176        if let Some(ref metrics) = self.metrics {
177            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
178            metrics.all_reduce.duration_ms.observe(duration_ms);
179            metrics.all_reduce.bytes_processed.add(buffer.len() as u64);
180
181            if result.is_ok() {
182                metrics.all_reduce.completed.inc();
183            } else {
184                metrics.all_reduce.failed.inc();
185            }
186        }
187
188        result
189    }
190
191    /// Synchronize all nodes at a barrier.
192    ///
193    /// All nodes must call this method, and none will proceed until all have.
194    pub async fn barrier(&self) -> Result<()> {
195        let start = std::time::Instant::now();
196        let result = self.backend.barrier().await;
197
198        if let Some(ref metrics) = self.metrics {
199            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
200            metrics.barrier.duration_ms.observe(duration_ms);
201
202            if result.is_ok() {
203                metrics.barrier.completed.inc();
204            } else {
205                metrics.barrier.failed.inc();
206            }
207        }
208
209        result
210    }
211
212    /// Check if this is the master node (rank 0).
213    pub fn is_master(&self) -> bool {
214        self.rank() == 0
215    }
216
217    /// Get metrics snapshot if enabled.
218    pub fn metrics_snapshot(&self) -> Option<MetricsSnapshot> {
219        self.metrics.as_ref().map(|m| m.snapshot())
220    }
221}
222
223/// Prelude for convenient imports.
224pub mod prelude {
225    pub use crate::DistributedBackend;
226    pub use crate::DistributedContext;
227    pub use crate::auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
228    pub use crate::collective::{AllReduceStrategy, CollectiveConfig};
229    pub use crate::compression::{CompressionStrategy, GradientCompressor};
230    pub use crate::config::DistributedConfig;
231    pub use crate::election::{ElectionConfig, ElectionManager};
232    pub use crate::error::{DistributedError, DistributedResult};
233    pub use crate::health::{HealthConfig, HealthMonitor, HealthStatus};
234    pub use crate::identity::NodeIdentity;
235    pub use crate::metrics::{DistributedMetrics, SharedMetrics};
236    pub use crate::namespace::NetworkNamespace;
237    pub use crate::ring::RingBackend;
238    pub use crate::topology::{ClusterTopology, NodeProfile};
239}