// pmetal_distributed — crate root (lib.rs)

1#![allow(unsafe_code)] // Raw pointer ops needed for zero-copy allreduce
2//! Distributed training backend for PMetal.
3//!
4//! Enables "Home Clusters" by synchronizing gradients across multiple devices
5//! (e.g., Mac Studio + MacBook Pro) over standard networks (TCP/IP, Wi-Fi).
6//!
7//! # Features
8//!
9//! - **Zero-Configuration Discovery**: Automatically finds peers using mDNS/Bonjour
10//! - **Ring All-Reduce**: Bandwidth-optimal gradient synchronization
11//! - **Persistent Identity**: Ed25519 keypairs stored at `~/.pmetal/node_keypair`
12//! - **Topology Awareness**: Graph-based cluster management with petgraph
13//! - **Master Election**: Distributed leader election for coordination
14//! - **Health Monitoring**: Heartbeat-based peer health tracking
15//! - **Gradient Compression**: TopK, quantization, and error feedback
16//! - **Network Isolation**: PSK-based namespace isolation
17//! - **Observability**: Comprehensive metrics and tracing
18//!
19//! # Quick Start (Auto-Discovery)
20//!
21//! ```ignore
22//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext};
23//! use std::time::Duration;
24//!
25//! // Create backend with automatic peer discovery
26//! let backend = AutoDiscoveryBackend::new().await?;
27//!
28//! // Wait for at least 1 peer to join
29//! backend.wait_for_peers(1, Duration::from_secs(30)).await?;
30//!
31//! // Create context for distributed operations
32//! let ctx = DistributedContext::new(Box::new(backend));
33//!
34//! // Synchronize gradients across cluster
35//! ctx.all_reduce(&mut gradient_buffer).await?;
36//! ```
37//!
38//! # Manual Configuration
39//!
40//! For advanced use cases, you can manually configure peers:
41//!
42//! ```ignore
43//! use pmetal_distributed::{DistributedConfig, RingBackend, DistributedContext};
44//!
45//! let config = DistributedConfig::new(
46//!     vec!["192.168.1.10:52416".parse()?, "192.168.1.11:52416".parse()?],
47//!     0, // This node's rank
48//! );
49//!
50//! let backend = RingBackend::new(config).await?;
51//! let ctx = DistributedContext::new(Box::new(backend));
52//! ```
53//!
54//! # Architecture
55//!
56//! ```text
57//! ┌─────────────────────────────────────────────────────────────────┐
58//! │                     AutoDiscoveryBackend                         │
59//! │                                                                  │
60//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
61//! │  │   Identity   │  │  Discovery   │  │  Topology    │           │
62//! │  │  (Ed25519)   │  │   (mDNS)     │  │  (petgraph)  │           │
63//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
64//! │          │                │                 │                    │
65//! │          └────────────────┼─────────────────┘                    │
66//! │                           ▼                                      │
67//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
68//! │  │  Election    │  │   Health     │  │  Collective  │           │
69//! │  │  (Master)    │  │  (Heartbeat) │  │  (Strategies)│           │
70//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
71//! │          │                │                 │                    │
72//! │          └────────────────┼─────────────────┘                    │
73//! │                           ▼                                      │
74//! │  ┌──────────────┐  ┌──────────────┐  ┌──────────────┐           │
75//! │  │ Compression  │  │   Metrics    │  │  Namespace   │           │
76//! │  │  (TopK/Quant)│  │ (Observ.)    │  │  (PSK)       │           │
77//! │  └──────────────┘  └──────────────┘  └──────────────┘           │
78//! └─────────────────────────────────────────────────────────────────┘
79//! ```
80
81use anyhow::Result;
82use async_trait::async_trait;
83
// Core modules: identity, discovery, transport, and the ring all-reduce backend.
/// Zero-configuration backend that finds peers automatically via mDNS.
pub mod auto;
/// Bridging to cloud resources; not re-exported here — see module docs.
pub mod cloud_bridge;
/// Manual cluster configuration ([`DistributedConfig`]).
pub mod config;
/// Peer discovery over mDNS/Bonjour.
pub mod discovery;
/// Error types ([`DistributedError`], [`DistributedResult`]).
pub mod error;
/// Persistent Ed25519 node identity.
pub mod identity;
/// Ring all-reduce backend ([`RingBackend`]).
pub mod ring;
/// Graph-based cluster topology management (petgraph).
pub mod topology;
/// Network transport layer; not re-exported here — see module docs.
pub mod transport;

// Advanced modules: coordination, resilience, and observability.
/// Collective-operation strategies (all-reduce, broadcast, reduce).
pub mod collective;
/// Gradient compression: TopK, quantization, error feedback.
pub mod compression;
/// Distributed master election.
pub mod election;
/// Heartbeat-based peer health monitoring.
pub mod health;
/// Metrics and observability.
pub mod metrics;
/// PSK-based network namespace isolation.
pub mod namespace;

// Re-exports so common types are reachable directly from the crate root.
pub use auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
pub use collective::{AllReduceStrategy, BroadcastStrategy, CollectiveConfig, ReduceStrategy};
pub use compression::{CompressionStrategy, GradientCompressor, QuantizationType};
pub use config::DistributedConfig;
pub use election::{ElectionConfig, ElectionEvent, ElectionManager, ElectionState};
pub use error::{DistributedError, DistributedResult};
pub use health::{HealthConfig, HealthEvent, HealthMonitor, HealthStatus, HealthSummary};
pub use identity::NodeIdentity;
pub use metrics::{DistributedMetrics, MetricsSnapshot, SharedMetrics};
pub use namespace::NetworkNamespace;
pub use ring::RingBackend;
pub use topology::{ClusterTopology, ConnectionProfile, NodeProfile, SharedTopology};
116
/// Interface for distributed operations.
///
/// Implemented by concrete backends such as [`RingBackend`] and
/// [`AutoDiscoveryBackend`]; consumed through [`DistributedContext`].
#[async_trait]
pub trait DistributedBackend: Send + Sync {
    /// Get the rank of this node (0 to world_size - 1).
    fn rank(&self) -> usize;

    /// Get the total number of nodes.
    fn world_size(&self) -> usize;

    /// Perform an all-reduce operation on a buffer (sum).
    ///
    /// The input buffer contains the local gradients.
    /// On return, it should contain the sum of gradients from all nodes.
    ///
    /// NOTE(review): the element type/width encoded in `buffer` is
    /// backend-defined — confirm callers and implementations agree.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the backend implementation
    /// (e.g. a failed exchange with a peer).
    async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()>;

    /// Barrier synchronization.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by the backend implementation.
    async fn barrier(&self) -> Result<()>;
}
135
/// A handle to the distributed runtime.
///
/// Wraps a [`DistributedBackend`] and optionally records per-operation
/// metrics (duration, bytes processed, completed/failed counters).
pub struct DistributedContext {
    // Backend that actually performs the collective operations.
    backend: Box<dyn DistributedBackend>,
    // Optional metrics sink; set only by `with_metrics`.
    metrics: Option<SharedMetrics>,
}
141
142impl DistributedContext {
143    /// Create a new distributed context with the given backend.
144    pub fn new(backend: Box<dyn DistributedBackend>) -> Self {
145        Self {
146            backend,
147            metrics: None,
148        }
149    }
150
151    /// Create a new distributed context with metrics enabled.
152    pub fn with_metrics(backend: Box<dyn DistributedBackend>, metrics: SharedMetrics) -> Self {
153        Self {
154            backend,
155            metrics: Some(metrics),
156        }
157    }
158
159    /// Get the rank of this node.
160    pub fn rank(&self) -> usize {
161        self.backend.rank()
162    }
163
164    /// Get the total number of nodes in the cluster.
165    pub fn world_size(&self) -> usize {
166        self.backend.world_size()
167    }
168
169    /// Perform an all-reduce operation (sum) on the buffer.
170    ///
171    /// After this call, all nodes will have the same values in their buffers,
172    /// which is the element-wise sum of all input buffers.
173    pub async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
174        let start = std::time::Instant::now();
175        let result = self.backend.all_reduce(buffer).await;
176
177        if let Some(ref metrics) = self.metrics {
178            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
179            metrics.all_reduce.duration_ms.observe(duration_ms);
180            metrics.all_reduce.bytes_processed.add(buffer.len() as u64);
181
182            if result.is_ok() {
183                metrics.all_reduce.completed.inc();
184            } else {
185                metrics.all_reduce.failed.inc();
186            }
187        }
188
189        result
190    }
191
192    /// Synchronize all nodes at a barrier.
193    ///
194    /// All nodes must call this method, and none will proceed until all have.
195    pub async fn barrier(&self) -> Result<()> {
196        let start = std::time::Instant::now();
197        let result = self.backend.barrier().await;
198
199        if let Some(ref metrics) = self.metrics {
200            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
201            metrics.barrier.duration_ms.observe(duration_ms);
202
203            if result.is_ok() {
204                metrics.barrier.completed.inc();
205            } else {
206                metrics.barrier.failed.inc();
207            }
208        }
209
210        result
211    }
212
213    /// Check if this is the master node (rank 0).
214    pub fn is_master(&self) -> bool {
215        self.rank() == 0
216    }
217
218    /// Get metrics snapshot if enabled.
219    pub fn metrics_snapshot(&self) -> Option<MetricsSnapshot> {
220        self.metrics.as_ref().map(|m| m.snapshot())
221    }
222}
223
/// Prelude for convenient imports.
///
/// `use pmetal_distributed::prelude::*;` pulls in the core trait, the
/// context type, and the most commonly used configuration/manager types.
pub mod prelude {
    pub use crate::DistributedBackend;
    pub use crate::DistributedContext;
    pub use crate::auto::{AutoDiscoveryBackend, AutoDiscoveryConfig};
    pub use crate::collective::{AllReduceStrategy, CollectiveConfig};
    pub use crate::compression::{CompressionStrategy, GradientCompressor};
    pub use crate::config::DistributedConfig;
    pub use crate::election::{ElectionConfig, ElectionManager};
    pub use crate::error::{DistributedError, DistributedResult};
    pub use crate::health::{HealthConfig, HealthMonitor, HealthStatus};
    pub use crate::identity::NodeIdentity;
    pub use crate::metrics::{DistributedMetrics, SharedMetrics};
    pub use crate::namespace::NetworkNamespace;
    pub use crate::ring::RingBackend;
    pub use crate::topology::{ClusterTopology, NodeProfile};
}
240}