Skip to main content

pmetal_distributed/
auto.rs

1//! Auto-discovery distributed backend.
2//!
3//! This module provides a zero-configuration distributed training backend
4//! that automatically discovers peers on the local network using mDNS/Bonjour.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext};
10//!
11//! // Create backend with automatic peer discovery
12//! let backend = AutoDiscoveryBackend::new().await?;
13//!
14//! // Wait for peers to join
15//! backend.wait_for_peers(2, Duration::from_secs(30)).await?;
16//!
17//! // Use for distributed training
18//! let ctx = DistributedContext::new(Box::new(backend));
19//! ctx.all_reduce(&mut gradients).await?;
20//! ```
21
22use crate::DistributedBackend;
23use crate::discovery::{DiscoveryEvent, DiscoveryService};
24use crate::error::DistributedError;
25use crate::identity::NodeIdentity;
26use crate::topology::{NodeProfile, SharedTopology, new_shared_topology};
27use crate::transport::{TcpTransport, TransportReceiver, TransportSender};
28use anyhow::Result;
29use async_trait::async_trait;
30use libp2p::PeerId;
31use parking_lot::RwLock;
32use std::net::SocketAddr;
33use std::sync::Arc;
34use std::time::Duration;
35use tokio::sync::{Mutex, mpsc};
36use tracing::{debug, error, info, warn};
37use zerocopy::{FromBytes, IntoBytes};
38
/// Default TCP port for gradient-exchange (ring transport) traffic.
const DEFAULT_GRADIENT_PORT: u16 = 52416;

/// Default port for discovery/libp2p (one below the gradient port).
const DEFAULT_DISCOVERY_PORT: u16 = 52415;
44
45/// Configuration for auto-discovery backend.
/// Configuration for auto-discovery backend.
#[derive(Debug, Clone)]
pub struct AutoDiscoveryConfig {
    /// Port for gradient exchange (default: 52416).
    pub gradient_port: u16,
    /// Port for libp2p discovery (default: 52415).
    pub discovery_port: u16,
    /// Minimum peers required before training can start.
    /// NOTE(review): not read by `wait_for_peers`, which takes the threshold
    /// as an explicit argument — confirm callers wire this value through.
    pub min_peers: usize,
    /// Maximum time to wait for peers.
    /// NOTE(review): likewise passed explicitly to `wait_for_peers` rather
    /// than read from this field.
    pub peer_timeout: Duration,
    /// Local node profile (for topology awareness).
    pub profile: NodeProfile,
}
59
60impl Default for AutoDiscoveryConfig {
61    fn default() -> Self {
62        Self {
63            gradient_port: DEFAULT_GRADIENT_PORT,
64            discovery_port: DEFAULT_DISCOVERY_PORT,
65            min_peers: 1,
66            peer_timeout: Duration::from_secs(60),
67            profile: NodeProfile::default(),
68        }
69    }
70}
71
/// Auto-discovery distributed backend.
///
/// Automatically discovers peers on the local network using mDNS/Bonjour
/// and establishes connections for gradient synchronization.
pub struct AutoDiscoveryBackend {
    /// Our node identity (loaded from disk or freshly generated).
    identity: NodeIdentity,
    /// Configuration (ports, peer thresholds, local profile).
    config: AutoDiscoveryConfig,
    /// Cluster topology, shared with other components via `Arc`.
    topology: SharedTopology,
    /// Discovery state (connected-peer bookkeeping), shared with the
    /// background discovery task.
    discovery_state: Arc<RwLock<crate::discovery::DiscoveryState>>,
    /// Ring connections (sender to next, receiver from prev).
    /// `None` until `establish_ring` succeeds. Tokio mutex because the
    /// guard is held across `.await`s during all-reduce/barrier.
    ring_connections: Mutex<Option<(TransportSender, TransportReceiver)>>,
    /// Event receiver from discovery service; drained by `wait_for_peers`.
    event_rx: Mutex<mpsc::Receiver<DiscoveryEvent>>,
    /// Whether the ring is established; cleared on peer disconnect so the
    /// next collective op rebuilds the ring.
    ring_ready: Arc<std::sync::atomic::AtomicBool>,
}
92
impl AutoDiscoveryBackend {
    /// Create a new auto-discovery backend with default configuration.
    ///
    /// Equivalent to `Self::with_config(AutoDiscoveryConfig::default())`.
    pub async fn new() -> Result<Self> {
        Self::with_config(AutoDiscoveryConfig::default()).await
    }

    /// Create a new auto-discovery backend with custom configuration.
    ///
    /// Loads (or generates) the node identity, seeds the shared topology
    /// with the local node, and spawns the discovery service as a detached
    /// background task whose errors are logged, not propagated.
    ///
    /// # Errors
    /// Fails if the node identity cannot be loaded or generated.
    pub async fn with_config(config: AutoDiscoveryConfig) -> Result<Self> {
        let identity = NodeIdentity::load_or_generate()?;
        let topology = new_shared_topology(*identity.peer_id(), config.profile.clone());

        // Create event channel (bounded at 256; discovery backpressures if
        // events are not drained by `wait_for_peers`)
        let (event_tx, event_rx) = mpsc::channel(256);

        // Create and spawn discovery service
        let discovery = DiscoveryService::new(identity.clone(), config.discovery_port, event_tx);
        let discovery_state = discovery.state();

        // Spawn discovery in background; the task is detached and runs for
        // the life of the process.
        tokio::spawn(async move {
            if let Err(e) = discovery.run().await {
                error!("Discovery service error: {}", e);
            }
        });

        info!(
            "AutoDiscoveryBackend initialized: peer_id={}, gradient_port={}, discovery_port={}",
            identity.peer_id(),
            config.gradient_port,
            config.discovery_port
        );

        Ok(Self {
            identity,
            config,
            topology,
            discovery_state,
            ring_connections: Mutex::new(None),
            event_rx: Mutex::new(event_rx),
            ring_ready: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        })
    }

    /// Get the local node's peer ID.
    pub fn peer_id(&self) -> &PeerId {
        self.identity.peer_id()
    }

    /// Get the local node's peer ID as a string.
    pub fn peer_id_string(&self) -> String {
        self.identity.peer_id_string()
    }

    /// Get the current cluster topology (shared handle; cheap `Arc` clone).
    pub fn topology(&self) -> SharedTopology {
        Arc::clone(&self.topology)
    }

    /// Get the number of discovered peers, as reported by the discovery
    /// state (not the topology's node count).
    pub fn peer_count(&self) -> usize {
        self.discovery_state.read().connected_count()
    }

    /// Wait for a minimum number of peers to be discovered.
    ///
    /// Polls every 100ms: drains pending discovery events into the topology,
    /// then checks the connected count against `min_peers`.
    ///
    /// NOTE(review): `config.min_peers`/`config.peer_timeout` are not
    /// consulted here — callers supply the threshold and timeout explicitly.
    ///
    /// Returns the number of peers found, or an error if timeout occurs.
    pub async fn wait_for_peers(
        &self,
        min_peers: usize,
        timeout_duration: Duration,
    ) -> Result<usize> {
        info!(
            "Waiting for {} peers (timeout: {:?})",
            min_peers, timeout_duration
        );

        let start = std::time::Instant::now();

        while start.elapsed() < timeout_duration {
            // Process discovery events (non-blocking drain of the channel)
            {
                let mut rx = self.event_rx.lock().await;
                while let Ok(event) = rx.try_recv() {
                    self.handle_discovery_event(event).await;
                }
            }

            let count = self.peer_count();
            if count >= min_peers {
                info!("Found {} peers, proceeding", count);
                return Ok(count);
            }

            // Brief sleep before checking again
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        // One last check: peers may have connected during the final sleep.
        let count = self.peer_count();
        if count >= min_peers {
            Ok(count)
        } else {
            Err(DistributedError::Protocol(format!(
                "Timeout waiting for peers: found {} of {} required",
                count, min_peers
            ))
            .into())
        }
    }

    /// Handle a discovery event.
    ///
    /// Only `PeerConnected`/`PeerDisconnected` mutate state (topology, ring
    /// readiness); the other variants are just logged.
    async fn handle_discovery_event(&self, event: DiscoveryEvent) {
        match event {
            DiscoveryEvent::PeerDiscovered { peer_id, addresses } => {
                debug!("Discovered peer: {} at {:?}", peer_id, addresses);
            }
            DiscoveryEvent::PeerConnected { peer_id, address } => {
                info!("Connected to peer: {} at {}", peer_id, address);

                let mut topology = self.topology.write();
                topology.add_node(peer_id, Some(address));
            }
            DiscoveryEvent::PeerDisconnected { peer_id } => {
                warn!("Disconnected from peer: {}", peer_id);

                let mut topology = self.topology.write();
                topology.remove_node(&peer_id);

                // Mark ring as not ready so the next collective op rebuilds it
                self.ring_ready
                    .store(false, std::sync::atomic::Ordering::SeqCst);
            }
            DiscoveryEvent::PeerExpired { peer_id } => {
                debug!("Peer expired: {}", peer_id);
            }
            DiscoveryEvent::Message { peer_id, data } => {
                debug!("Message from {}: {} bytes", peer_id, data.len());
            }
        }
    }

    /// Establish the ring topology for all-reduce operations.
    ///
    /// Takes a consistent snapshot of the topology (rank, world size,
    /// addresses) under the read lock, then connects the TCP ring with the
    /// lock released.
    ///
    /// This must be called before performing all-reduce operations.
    ///
    /// # Errors
    /// Fails if fewer than 2 nodes are known, if fewer than 2 nodes have
    /// usable socket addresses, or if the TCP transport cannot connect.
    pub async fn establish_ring(&self) -> Result<()> {
        // Collect all needed data from topology while holding the lock
        let (local_rank, world_size, node_addrs, peer_ids) = {
            let topology = self.topology.read();

            if !topology.can_form_ring() {
                return Err(DistributedError::Protocol(
                    "Not enough peers to form ring (need at least 2 nodes)".into(),
                )
                .into());
            }

            let ring_order = topology.ring_order();
            let local_rank = topology.local_rank();
            let world_size = ring_order.len();

            // Collect socket addresses in ring order, rewriting each port to
            // the gradient port (discovery reports discovery-port addresses).
            // NOTE(review): nodes without a known socket_addr are silently
            // skipped here while `local_rank`/`world_size` still count them;
            // if that ever happens, rank and address index could disagree —
            // confirm discovery guarantees an address for every ring member.
            let node_addrs: Vec<SocketAddr> = ring_order
                .iter()
                .filter_map(|n| n.socket_addr)
                .map(|a| SocketAddr::new(a.ip(), self.config.gradient_port))
                .collect();

            // Collect peer IDs for logging
            let peer_ids: Vec<String> = ring_order.iter().map(|n| n.peer_id.to_base58()).collect();

            (local_rank, world_size, node_addrs, peer_ids)
        }; // topology lock released here

        info!(
            "Establishing ring: rank={}/{}, peers={:?}",
            local_rank, world_size, peer_ids
        );

        if node_addrs.len() < 2 {
            return Err(DistributedError::Protocol(
                "Not enough peers with known addresses to form ring".into(),
            )
            .into());
        }

        // Create configuration for TCP transport
        let config = crate::config::DistributedConfig {
            nodes: node_addrs,
            rank: local_rank,
            connection_timeout_ms: 30000,
            max_retries: 50,
        };

        // Establish ring connections
        let (sender, receiver) = TcpTransport::connect(&config).await?;

        *self.ring_connections.lock().await = Some((sender, receiver));
        self.ring_ready
            .store(true, std::sync::atomic::Ordering::SeqCst);

        info!("Ring established successfully");
        Ok(())
    }

    /// Check if the ring is ready for all-reduce operations.
    pub fn is_ring_ready(&self) -> bool {
        self.ring_ready.load(std::sync::atomic::Ordering::SeqCst)
    }
}
301
302#[async_trait]
303impl DistributedBackend for AutoDiscoveryBackend {
304    fn rank(&self) -> usize {
305        self.topology.read().local_rank()
306    }
307
308    fn world_size(&self) -> usize {
309        self.topology.read().node_count()
310    }
311
312    async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
313        if !self.is_ring_ready() {
314            self.establish_ring().await?;
315        }
316
317        // Validate buffer
318        if !buffer.len().is_multiple_of(4) {
319            return Err(DistributedError::Protocol(format!(
320                "Buffer length {} is not a multiple of 4 (f32 size)",
321                buffer.len()
322            ))
323            .into());
324        }
325
326        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
327            return Err(DistributedError::Protocol(
328                "Buffer is not properly aligned for f32 operations".into(),
329            )
330            .into());
331        }
332
333        let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
334            .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
335        let len = floats.len();
336        let world_size = self.world_size();
337        let rank = self.rank();
338
339        if world_size < 2 {
340            return Ok(()); // Nothing to reduce
341        }
342
343        let chunk_size = len / world_size;
344        let remainder = len % world_size;
345
346        let get_chunk_range = |idx: usize| -> (usize, usize) {
347            let start = idx * chunk_size + idx.min(remainder);
348            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
349            (start, end)
350        };
351
352        let mut connections = self.ring_connections.lock().await;
353        let (sender, receiver) = connections
354            .as_mut()
355            .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
356
357        // === SCATTER-REDUCE PHASE ===
358        for step in 0..(world_size - 1) {
359            let send_idx = (rank + world_size - step) % world_size;
360            let recv_idx = (rank + world_size - step - 1) % world_size;
361
362            let (send_start, send_end) = get_chunk_range(send_idx);
363            let (recv_start, recv_end) = get_chunk_range(recv_idx);
364
365            let recv_bytes_len = (recv_end - recv_start) * 4;
366
367            // Copy data to send buffer
368            let send_buf = floats[send_start..send_end].as_bytes().to_vec();
369
370            // Send and receive concurrently
371            let mut recv_buf = vec![0u8; recv_bytes_len];
372            tokio::try_join!(sender.send(&send_buf), receiver.recv(&mut recv_buf))?;
373
374            // Reduce received data into local buffer
375            let recv_floats =
376                <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
377            for (i, &val) in recv_floats.iter().enumerate() {
378                floats[recv_start + i] += val;
379            }
380        }
381
382        // === ALL-GATHER PHASE ===
383        for step in 0..(world_size - 1) {
384            let send_idx = (rank + world_size - step + 1) % world_size;
385            let recv_idx = (rank + world_size - step) % world_size;
386
387            let (send_start, send_end) = get_chunk_range(send_idx);
388            let (recv_start, recv_end) = get_chunk_range(recv_idx);
389
390            let recv_bytes_len = (recv_end - recv_start) * 4;
391
392            let send_buf: &[u8] = floats[send_start..send_end].as_bytes();
393
394            let mut recv_buf = vec![0u8; recv_bytes_len];
395            tokio::try_join!(sender.send(send_buf), receiver.recv(&mut recv_buf))?;
396
397            // Copy received data to local buffer
398            let recv_floats =
399                <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
400            floats[recv_start..recv_end].copy_from_slice(recv_floats);
401        }
402
403        Ok(())
404    }
405
406    async fn barrier(&self) -> Result<()> {
407        if !self.is_ring_ready() {
408            self.establish_ring().await?;
409        }
410
411        let world_size = self.world_size();
412        if world_size < 2 {
413            return Ok(());
414        }
415
416        let mut connections = self.ring_connections.lock().await;
417        let (sender, receiver) = connections
418            .as_mut()
419            .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
420
421        // Simple barrier: send a token around the ring
422        let token = [0u8; 4];
423
424        for _ in 0..(world_size - 1) {
425            let mut recv_buf = [0u8; 4];
426            tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))?;
427        }
428
429        Ok(())
430    }
431}
432
433impl std::fmt::Debug for AutoDiscoveryBackend {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_struct("AutoDiscoveryBackend")
436            .field("peer_id", &self.identity.peer_id_string())
437            .field("peer_count", &self.peer_count())
438            .field("ring_ready", &self.is_ring_ready())
439            .finish()
440    }
441}