// pmetal_distributed — auto.rs

//! Auto-discovery distributed backend.
//!
//! This module provides a zero-configuration distributed training backend
//! that automatically discovers peers on the local network using mDNS/Bonjour.
//!
//! # Usage
//!
//! ```ignore
//! use pmetal_distributed::{AutoDiscoveryBackend, DistributedContext};
//!
//! // Create backend with automatic peer discovery
//! let backend = AutoDiscoveryBackend::new().await?;
//!
//! // Wait for peers to join
//! backend.wait_for_peers(2, Duration::from_secs(30)).await?;
//!
//! // Use for distributed training
//! let ctx = DistributedContext::new(Box::new(backend));
//! ctx.all_reduce(&mut gradients).await?;
//! ```

use crate::discovery::{DiscoveryEvent, DiscoveryService};
use crate::error::DistributedError;
use crate::identity::NodeIdentity;
use crate::topology::{NodeProfile, SharedTopology, new_shared_topology};
use crate::transport::{TcpTransport, TransportReceiver, TransportSender};
use crate::{DistributedBackend, ReduceOp};
use anyhow::Result;
use async_trait::async_trait;
use libp2p::PeerId;
use parking_lot::RwLock;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, OnceCell, mpsc};
use tracing::{debug, error, info, warn};
use zerocopy::{FromBytes, IntoBytes};

/// Default port for gradient exchange.
///
/// Deliberately one above the discovery port so both services can coexist
/// on the same host with a predictable pairing.
const DEFAULT_GRADIENT_PORT: u16 = 52416;

/// Default port for discovery/libp2p.
const DEFAULT_DISCOVERY_PORT: u16 = 52415;

45/// Configuration for auto-discovery backend.
46#[derive(Debug, Clone)]
47pub struct AutoDiscoveryConfig {
48    /// Port for gradient exchange (default: 52416).
49    pub gradient_port: u16,
50    /// Port for libp2p discovery (default: 52415).
51    pub discovery_port: u16,
52    /// Minimum peers required before training can start.
53    pub min_peers: usize,
54    /// Maximum time to wait for peers.
55    pub peer_timeout: Duration,
56    /// Local node profile (for topology awareness).
57    pub profile: NodeProfile,
58}
59
60impl Default for AutoDiscoveryConfig {
61    fn default() -> Self {
62        Self {
63            gradient_port: DEFAULT_GRADIENT_PORT,
64            discovery_port: DEFAULT_DISCOVERY_PORT,
65            min_peers: 1,
66            peer_timeout: Duration::from_secs(60),
67            profile: NodeProfile::default(),
68        }
69    }
70}
71
/// Auto-discovery distributed backend.
///
/// Automatically discovers peers on the local network using mDNS/Bonjour
/// and establishes connections for gradient synchronization.
///
/// Construction spawns the discovery service as a background task; the
/// gradient ring itself is established lazily on the first collective call.
pub struct AutoDiscoveryBackend {
    /// Our node identity.
    identity: NodeIdentity,
    /// Configuration (ports, peer quorum, local profile).
    config: AutoDiscoveryConfig,
    /// Cluster topology, shared with the discovery machinery.
    topology: SharedTopology,
    /// Discovery state; read to answer `peer_count`.
    discovery_state: Arc<RwLock<crate::discovery::DiscoveryState>>,
    /// Ring connections (sender to next, receiver from prev).
    /// A tokio `Mutex` (not `std`/`parking_lot`) because the guard is held
    /// across `.await`s while the collectives send/receive.
    ring_connections: Mutex<Option<(TransportSender, TransportReceiver)>>,
    /// Event receiver from discovery service; drained by `wait_for_peers`.
    event_rx: Mutex<mpsc::Receiver<DiscoveryEvent>>,
    /// Ensures `establish_ring_inner` runs exactly once, even under concurrent
    /// calls.  Replaces the former `AtomicBool` which had a TOCTOU race:
    /// two callers could both observe `false` and both attempt to connect.
    ring_init: OnceCell<()>,
}

95impl AutoDiscoveryBackend {
96    /// Create a new auto-discovery backend with default configuration.
97    pub async fn new() -> Result<Self> {
98        Self::with_config(AutoDiscoveryConfig::default()).await
99    }
100
101    /// Create a new auto-discovery backend with custom configuration.
102    pub async fn with_config(config: AutoDiscoveryConfig) -> Result<Self> {
103        let identity = NodeIdentity::load_or_generate()?;
104        let topology = new_shared_topology(*identity.peer_id(), config.profile.clone());
105
106        // Create event channel
107        let (event_tx, event_rx) = mpsc::channel(256);
108
109        // Create and spawn discovery service
110        let discovery = DiscoveryService::new(identity.clone(), config.discovery_port, event_tx);
111        let discovery_state = discovery.state();
112
113        // Spawn discovery in background
114        tokio::spawn(async move {
115            if let Err(e) = discovery.run().await {
116                error!("Discovery service error: {}", e);
117            }
118        });
119
120        info!(
121            "AutoDiscoveryBackend initialized: peer_id={}, gradient_port={}, discovery_port={}",
122            identity.peer_id(),
123            config.gradient_port,
124            config.discovery_port
125        );
126
127        Ok(Self {
128            identity,
129            config,
130            topology,
131            discovery_state,
132            ring_connections: Mutex::new(None),
133            event_rx: Mutex::new(event_rx),
134            ring_init: OnceCell::new(),
135        })
136    }
137
138    /// Get the local node's peer ID.
139    pub fn peer_id(&self) -> &PeerId {
140        self.identity.peer_id()
141    }
142
143    /// Get the local node's peer ID as a string.
144    pub fn peer_id_string(&self) -> String {
145        self.identity.peer_id_string()
146    }
147
148    /// Get the current cluster topology.
149    pub fn topology(&self) -> SharedTopology {
150        Arc::clone(&self.topology)
151    }
152
153    /// Get the number of discovered peers.
154    pub fn peer_count(&self) -> usize {
155        self.discovery_state.read().connected_count()
156    }
157
158    /// Wait for a minimum number of peers to be discovered.
159    ///
160    /// Returns the number of peers found, or an error if timeout occurs.
161    pub async fn wait_for_peers(
162        &self,
163        min_peers: usize,
164        timeout_duration: Duration,
165    ) -> Result<usize> {
166        info!(
167            "Waiting for {} peers (timeout: {:?})",
168            min_peers, timeout_duration
169        );
170
171        let start = std::time::Instant::now();
172
173        while start.elapsed() < timeout_duration {
174            // Process discovery events
175            {
176                let mut rx = self.event_rx.lock().await;
177                while let Ok(event) = rx.try_recv() {
178                    self.handle_discovery_event(event).await;
179                }
180            }
181
182            let count = self.peer_count();
183            if count >= min_peers {
184                info!("Found {} peers, proceeding", count);
185                return Ok(count);
186            }
187
188            // Brief sleep before checking again
189            tokio::time::sleep(Duration::from_millis(100)).await;
190        }
191
192        let count = self.peer_count();
193        if count >= min_peers {
194            Ok(count)
195        } else {
196            Err(DistributedError::Protocol(format!(
197                "Timeout waiting for peers: found {} of {} required",
198                count, min_peers
199            ))
200            .into())
201        }
202    }
203
204    /// Handle a discovery event.
205    async fn handle_discovery_event(&self, event: DiscoveryEvent) {
206        match event {
207            DiscoveryEvent::PeerDiscovered { peer_id, addresses } => {
208                debug!("Discovered peer: {} at {:?}", peer_id, addresses);
209            }
210            DiscoveryEvent::PeerConnected { peer_id, address } => {
211                info!("Connected to peer: {} at {}", peer_id, address);
212
213                let mut topology = self.topology.write();
214                topology.add_node(peer_id, Some(address));
215            }
216            DiscoveryEvent::PeerDisconnected { peer_id } => {
217                warn!("Disconnected from peer: {}", peer_id);
218
219                let mut topology = self.topology.write();
220                topology.remove_node(&peer_id);
221
222                // Note: ring_init is a OnceCell and cannot be reset.  A peer
223                // disconnect means the ring is broken; callers must create a
224                // new AutoDiscoveryBackend to reform the ring.
225            }
226            DiscoveryEvent::PeerExpired { peer_id } => {
227                debug!("Peer expired: {}", peer_id);
228            }
229            DiscoveryEvent::Message { peer_id, data } => {
230                debug!("Message from {}: {} bytes", peer_id, data.len());
231            }
232        }
233    }
234
235    /// Internal ring setup — performs the actual TCP connection work.
236    ///
237    /// Called at most once via `ring_init.get_or_init(...)`.
238    async fn establish_ring_inner(&self) -> Result<()> {
239        // Collect all needed data from topology while holding the lock
240        let (local_rank, world_size, node_addrs, peer_ids) = {
241            let topology = self.topology.read();
242
243            if !topology.can_form_ring() {
244                return Err(DistributedError::Protocol(
245                    "Not enough peers to form ring (need at least 2 nodes)".into(),
246                )
247                .into());
248            }
249
250            let ring_order = topology.ring_order();
251            let local_rank = topology.local_rank();
252            let world_size = ring_order.len();
253
254            // Collect socket addresses in ring order
255            let node_addrs: Vec<SocketAddr> = ring_order
256                .iter()
257                .filter_map(|n| n.socket_addr)
258                .map(|a| SocketAddr::new(a.ip(), self.config.gradient_port))
259                .collect();
260
261            // Collect peer IDs for logging
262            let peer_ids: Vec<String> = ring_order.iter().map(|n| n.peer_id.to_base58()).collect();
263
264            (local_rank, world_size, node_addrs, peer_ids)
265        }; // topology lock released here
266
267        info!(
268            "Establishing ring: rank={}/{}, peers={:?}",
269            local_rank, world_size, peer_ids
270        );
271
272        if node_addrs.len() < 2 {
273            return Err(DistributedError::Protocol(
274                "Not enough peers with known addresses to form ring".into(),
275            )
276            .into());
277        }
278
279        // Create configuration for TCP transport
280        let config = crate::config::DistributedConfig {
281            nodes: node_addrs,
282            rank: local_rank,
283            connection_timeout_ms: 30000,
284            max_retries: 50,
285        };
286
287        // Establish ring connections
288        let (sender, receiver) = TcpTransport::connect(&config).await?;
289
290        *self.ring_connections.lock().await = Some((sender, receiver));
291
292        info!("Ring established successfully");
293        Ok(())
294    }
295
296    /// Ensure the ring is established, initialising it exactly once.
297    ///
298    /// Uses `tokio::sync::OnceCell::get_or_try_init` so that exactly one
299    /// concurrent caller performs the TCP connection and all others wait.
300    /// If the connection attempt fails the cell remains unset, allowing the
301    /// caller to retry on the next all-reduce.
302    pub async fn establish_ring(&self) -> Result<()> {
303        self.ring_init
304            .get_or_try_init(|| async { self.establish_ring_inner().await })
305            .await?;
306        Ok(())
307    }
308
309    /// Check if the ring has been successfully established.
310    ///
311    /// Returns `true` iff `establish_ring` has completed successfully at least
312    /// once.  This is a cheap non-blocking check suitable for logging.
313    pub fn is_ring_ready(&self) -> bool {
314        self.ring_init.initialized()
315    }
316}
317
318#[async_trait]
319impl DistributedBackend for AutoDiscoveryBackend {
320    fn rank(&self) -> usize {
321        self.topology.read().local_rank()
322    }
323
324    fn world_size(&self) -> usize {
325        self.topology.read().node_count()
326    }
327
328    async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
329        // establish_ring is idempotent — a no-op after the first successful call.
330        self.establish_ring().await?;
331
332        // Validate buffer
333        if !buffer.len().is_multiple_of(4) {
334            return Err(DistributedError::Protocol(format!(
335                "Buffer length {} is not a multiple of 4 (f32 size)",
336                buffer.len()
337            ))
338            .into());
339        }
340
341        if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
342            return Err(DistributedError::Protocol(
343                "Buffer is not properly aligned for f32 operations".into(),
344            )
345            .into());
346        }
347
348        let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
349            .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
350        let len = floats.len();
351        let world_size = self.world_size();
352        let rank = self.rank();
353
354        if world_size < 2 {
355            return Ok(()); // Nothing to reduce
356        }
357
358        let chunk_size = len / world_size;
359        let remainder = len % world_size;
360
361        let get_chunk_range = |idx: usize| -> (usize, usize) {
362            let start = idx * chunk_size + idx.min(remainder);
363            let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
364            (start, end)
365        };
366
367        let mut connections = self.ring_connections.lock().await;
368        let (sender, receiver) = connections
369            .as_mut()
370            .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
371
372        // === SCATTER-REDUCE PHASE ===
373        for step in 0..(world_size - 1) {
374            let send_idx = (rank + world_size - step) % world_size;
375            let recv_idx = (rank + world_size - step - 1) % world_size;
376
377            let (send_start, send_end) = get_chunk_range(send_idx);
378            let (recv_start, recv_end) = get_chunk_range(recv_idx);
379
380            let recv_bytes_len = (recv_end - recv_start) * 4;
381
382            // Copy data to send buffer
383            let send_buf = floats[send_start..send_end].as_bytes().to_vec();
384
385            // Send and receive concurrently
386            let mut recv_buf = vec![0u8; recv_bytes_len];
387            tokio::try_join!(sender.send(&send_buf), receiver.recv(&mut recv_buf))?;
388
389            // Reduce received data into local buffer
390            let recv_floats =
391                <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
392            for (i, &val) in recv_floats.iter().enumerate() {
393                floats[recv_start + i] += val;
394            }
395        }
396
397        // === ALL-GATHER PHASE ===
398        for step in 0..(world_size - 1) {
399            let send_idx = (rank + world_size - step) % world_size;
400            let recv_idx = (rank + world_size - step - 1) % world_size;
401
402            let (send_start, send_end) = get_chunk_range(send_idx);
403            let (recv_start, recv_end) = get_chunk_range(recv_idx);
404
405            let recv_bytes_len = (recv_end - recv_start) * 4;
406
407            let send_buf: &[u8] = floats[send_start..send_end].as_bytes();
408
409            let mut recv_buf = vec![0u8; recv_bytes_len];
410            tokio::try_join!(sender.send(send_buf), receiver.recv(&mut recv_buf))?;
411
412            // Copy received data to local buffer
413            let recv_floats =
414                <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
415            floats[recv_start..recv_end].copy_from_slice(recv_floats);
416        }
417
418        // Apply mean reduction after the ring has summed all contributions.
419        if op == ReduceOp::Mean {
420            let divisor = world_size as f32;
421            for f in floats.iter_mut() {
422                *f /= divisor;
423            }
424        }
425
426        Ok(())
427    }
428
429    async fn barrier(&self) -> Result<()> {
430        self.establish_ring().await?;
431
432        let world_size = self.world_size();
433        if world_size < 2 {
434            return Ok(());
435        }
436
437        let mut connections = self.ring_connections.lock().await;
438        let (sender, receiver) = connections
439            .as_mut()
440            .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
441
442        // Simple barrier: send a token around the ring
443        let token = [0u8; 4];
444
445        for _ in 0..(world_size - 1) {
446            let mut recv_buf = [0u8; 4];
447            tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))?;
448        }
449
450        Ok(())
451    }
452}
453
454impl std::fmt::Debug for AutoDiscoveryBackend {
455    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456        f.debug_struct("AutoDiscoveryBackend")
457            .field("peer_id", &self.identity.peer_id_string())
458            .field("peer_count", &self.peer_count())
459            .field("ring_ready", &self.is_ring_ready())
460            .finish()
461    }
462}