1use crate::discovery::{DiscoveryEvent, DiscoveryService};
23use crate::error::DistributedError;
24use crate::identity::NodeIdentity;
25use crate::topology::{NodeProfile, SharedTopology, new_shared_topology};
26use crate::transport::{TcpTransport, TransportReceiver, TransportSender};
27use crate::{DistributedBackend, ReduceOp};
28use anyhow::Result;
29use async_trait::async_trait;
30use libp2p::PeerId;
31use parking_lot::RwLock;
32use std::net::SocketAddr;
33use std::sync::Arc;
34use std::time::Duration;
35use tokio::sync::{Mutex, OnceCell, mpsc};
36use tracing::{debug, error, info, warn};
37use zerocopy::{FromBytes, IntoBytes};
38
/// Default TCP port for the gradient-exchange ring; ring peer addresses are rebuilt with
/// `AutoDiscoveryConfig::gradient_port`, which defaults to this value.
const DEFAULT_GRADIENT_PORT: u16 = 52416;

/// Default port handed to `DiscoveryService::new` via `AutoDiscoveryConfig::discovery_port`.
const DEFAULT_DISCOVERY_PORT: u16 = 52415;
44
#[derive(Debug, Clone)]
/// Configuration for [`AutoDiscoveryBackend`].
pub struct AutoDiscoveryConfig {
    /// Port used for the TCP gradient ring; every ring peer's IP is combined with this port.
    pub gradient_port: u16,
    /// Port passed to the discovery service.
    pub discovery_port: u16,
    /// NOTE(review): not read anywhere in this file — presumably a default for
    /// `wait_for_peers`' `min_peers` argument; confirm against callers.
    pub min_peers: usize,
    /// NOTE(review): not read anywhere in this file — presumably meant to bound peer
    /// liveness; confirm where (if anywhere) it is consumed.
    pub peer_timeout: Duration,
    /// Profile of the local node, used to seed the shared topology.
    pub profile: NodeProfile,
}
59
60impl Default for AutoDiscoveryConfig {
61 fn default() -> Self {
62 Self {
63 gradient_port: DEFAULT_GRADIENT_PORT,
64 discovery_port: DEFAULT_DISCOVERY_PORT,
65 min_peers: 1,
66 peer_timeout: Duration::from_secs(60),
67 profile: NodeProfile::default(),
68 }
69 }
70}
71
/// Distributed backend that discovers peers automatically and runs collectives over a
/// TCP ring formed from the discovered topology.
pub struct AutoDiscoveryBackend {
    /// Identity of this node; source of the local peer id.
    identity: NodeIdentity,
    /// Configuration the backend was created with.
    config: AutoDiscoveryConfig,
    /// Shared cluster topology; updated from drained discovery events.
    topology: SharedTopology,
    /// State shared with the discovery service; used to count connected peers.
    discovery_state: Arc<RwLock<crate::discovery::DiscoveryState>>,
    /// Ring transport halves; `None` until the ring has been established.
    ring_connections: Mutex<Option<(TransportSender, TransportReceiver)>>,
    /// Receiving end of the discovery event channel.
    event_rx: Mutex<mpsc::Receiver<DiscoveryEvent>>,
    /// Set once ring establishment has succeeded; guards against re-establishing it.
    ring_init: OnceCell<()>,
}
94
95impl AutoDiscoveryBackend {
96 pub async fn new() -> Result<Self> {
98 Self::with_config(AutoDiscoveryConfig::default()).await
99 }
100
101 pub async fn with_config(config: AutoDiscoveryConfig) -> Result<Self> {
103 let identity = NodeIdentity::load_or_generate()?;
104 let topology = new_shared_topology(*identity.peer_id(), config.profile.clone());
105
106 let (event_tx, event_rx) = mpsc::channel(256);
108
109 let discovery = DiscoveryService::new(identity.clone(), config.discovery_port, event_tx);
111 let discovery_state = discovery.state();
112
113 tokio::spawn(async move {
115 if let Err(e) = discovery.run().await {
116 error!("Discovery service error: {}", e);
117 }
118 });
119
120 info!(
121 "AutoDiscoveryBackend initialized: peer_id={}, gradient_port={}, discovery_port={}",
122 identity.peer_id(),
123 config.gradient_port,
124 config.discovery_port
125 );
126
127 Ok(Self {
128 identity,
129 config,
130 topology,
131 discovery_state,
132 ring_connections: Mutex::new(None),
133 event_rx: Mutex::new(event_rx),
134 ring_init: OnceCell::new(),
135 })
136 }
137
138 pub fn peer_id(&self) -> &PeerId {
140 self.identity.peer_id()
141 }
142
143 pub fn peer_id_string(&self) -> String {
145 self.identity.peer_id_string()
146 }
147
148 pub fn topology(&self) -> SharedTopology {
150 Arc::clone(&self.topology)
151 }
152
153 pub fn peer_count(&self) -> usize {
155 self.discovery_state.read().connected_count()
156 }
157
158 pub async fn wait_for_peers(
162 &self,
163 min_peers: usize,
164 timeout_duration: Duration,
165 ) -> Result<usize> {
166 info!(
167 "Waiting for {} peers (timeout: {:?})",
168 min_peers, timeout_duration
169 );
170
171 let start = std::time::Instant::now();
172
173 while start.elapsed() < timeout_duration {
174 {
176 let mut rx = self.event_rx.lock().await;
177 while let Ok(event) = rx.try_recv() {
178 self.handle_discovery_event(event).await;
179 }
180 }
181
182 let count = self.peer_count();
183 if count >= min_peers {
184 info!("Found {} peers, proceeding", count);
185 return Ok(count);
186 }
187
188 tokio::time::sleep(Duration::from_millis(100)).await;
190 }
191
192 let count = self.peer_count();
193 if count >= min_peers {
194 Ok(count)
195 } else {
196 Err(DistributedError::Protocol(format!(
197 "Timeout waiting for peers: found {} of {} required",
198 count, min_peers
199 ))
200 .into())
201 }
202 }
203
204 async fn handle_discovery_event(&self, event: DiscoveryEvent) {
206 match event {
207 DiscoveryEvent::PeerDiscovered { peer_id, addresses } => {
208 debug!("Discovered peer: {} at {:?}", peer_id, addresses);
209 }
210 DiscoveryEvent::PeerConnected { peer_id, address } => {
211 info!("Connected to peer: {} at {}", peer_id, address);
212
213 let mut topology = self.topology.write();
214 topology.add_node(peer_id, Some(address));
215 }
216 DiscoveryEvent::PeerDisconnected { peer_id } => {
217 warn!("Disconnected from peer: {}", peer_id);
218
219 let mut topology = self.topology.write();
220 topology.remove_node(&peer_id);
221
222 }
226 DiscoveryEvent::PeerExpired { peer_id } => {
227 debug!("Peer expired: {}", peer_id);
228 }
229 DiscoveryEvent::Message { peer_id, data } => {
230 debug!("Message from {}: {} bytes", peer_id, data.len());
231 }
232 }
233 }
234
235 async fn establish_ring_inner(&self) -> Result<()> {
239 let (local_rank, world_size, node_addrs, peer_ids) = {
241 let topology = self.topology.read();
242
243 if !topology.can_form_ring() {
244 return Err(DistributedError::Protocol(
245 "Not enough peers to form ring (need at least 2 nodes)".into(),
246 )
247 .into());
248 }
249
250 let ring_order = topology.ring_order();
251 let local_rank = topology.local_rank();
252 let world_size = ring_order.len();
253
254 let node_addrs: Vec<SocketAddr> = ring_order
256 .iter()
257 .filter_map(|n| n.socket_addr)
258 .map(|a| SocketAddr::new(a.ip(), self.config.gradient_port))
259 .collect();
260
261 let peer_ids: Vec<String> = ring_order.iter().map(|n| n.peer_id.to_base58()).collect();
263
264 (local_rank, world_size, node_addrs, peer_ids)
265 }; info!(
268 "Establishing ring: rank={}/{}, peers={:?}",
269 local_rank, world_size, peer_ids
270 );
271
272 if node_addrs.len() < 2 {
273 return Err(DistributedError::Protocol(
274 "Not enough peers with known addresses to form ring".into(),
275 )
276 .into());
277 }
278
279 let config = crate::config::DistributedConfig {
281 nodes: node_addrs,
282 rank: local_rank,
283 connection_timeout_ms: 30000,
284 max_retries: 50,
285 };
286
287 let (sender, receiver) = TcpTransport::connect(&config).await?;
289
290 *self.ring_connections.lock().await = Some((sender, receiver));
291
292 info!("Ring established successfully");
293 Ok(())
294 }
295
296 pub async fn establish_ring(&self) -> Result<()> {
303 self.ring_init
304 .get_or_try_init(|| async { self.establish_ring_inner().await })
305 .await?;
306 Ok(())
307 }
308
309 pub fn is_ring_ready(&self) -> bool {
314 self.ring_init.initialized()
315 }
316}
317
318#[async_trait]
319impl DistributedBackend for AutoDiscoveryBackend {
320 fn rank(&self) -> usize {
321 self.topology.read().local_rank()
322 }
323
324 fn world_size(&self) -> usize {
325 self.topology.read().node_count()
326 }
327
328 async fn all_reduce(&self, buffer: &mut [u8], op: ReduceOp) -> Result<()> {
329 self.establish_ring().await?;
331
332 if !buffer.len().is_multiple_of(4) {
334 return Err(DistributedError::Protocol(format!(
335 "Buffer length {} is not a multiple of 4 (f32 size)",
336 buffer.len()
337 ))
338 .into());
339 }
340
341 if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
342 return Err(DistributedError::Protocol(
343 "Buffer is not properly aligned for f32 operations".into(),
344 )
345 .into());
346 }
347
348 let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
349 .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
350 let len = floats.len();
351 let world_size = self.world_size();
352 let rank = self.rank();
353
354 if world_size < 2 {
355 return Ok(()); }
357
358 let chunk_size = len / world_size;
359 let remainder = len % world_size;
360
361 let get_chunk_range = |idx: usize| -> (usize, usize) {
362 let start = idx * chunk_size + idx.min(remainder);
363 let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
364 (start, end)
365 };
366
367 let mut connections = self.ring_connections.lock().await;
368 let (sender, receiver) = connections
369 .as_mut()
370 .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
371
372 for step in 0..(world_size - 1) {
374 let send_idx = (rank + world_size - step) % world_size;
375 let recv_idx = (rank + world_size - step - 1) % world_size;
376
377 let (send_start, send_end) = get_chunk_range(send_idx);
378 let (recv_start, recv_end) = get_chunk_range(recv_idx);
379
380 let recv_bytes_len = (recv_end - recv_start) * 4;
381
382 let send_buf = floats[send_start..send_end].as_bytes().to_vec();
384
385 let mut recv_buf = vec![0u8; recv_bytes_len];
387 tokio::try_join!(sender.send(&send_buf), receiver.recv(&mut recv_buf))?;
388
389 let recv_floats =
391 <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
392 for (i, &val) in recv_floats.iter().enumerate() {
393 floats[recv_start + i] += val;
394 }
395 }
396
397 for step in 0..(world_size - 1) {
399 let send_idx = (rank + world_size - step) % world_size;
400 let recv_idx = (rank + world_size - step - 1) % world_size;
401
402 let (send_start, send_end) = get_chunk_range(send_idx);
403 let (recv_start, recv_end) = get_chunk_range(recv_idx);
404
405 let recv_bytes_len = (recv_end - recv_start) * 4;
406
407 let send_buf: &[u8] = floats[send_start..send_end].as_bytes();
408
409 let mut recv_buf = vec![0u8; recv_bytes_len];
410 tokio::try_join!(sender.send(send_buf), receiver.recv(&mut recv_buf))?;
411
412 let recv_floats =
414 <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
415 floats[recv_start..recv_end].copy_from_slice(recv_floats);
416 }
417
418 if op == ReduceOp::Mean {
420 let divisor = world_size as f32;
421 for f in floats.iter_mut() {
422 *f /= divisor;
423 }
424 }
425
426 Ok(())
427 }
428
429 async fn barrier(&self) -> Result<()> {
430 self.establish_ring().await?;
431
432 let world_size = self.world_size();
433 if world_size < 2 {
434 return Ok(());
435 }
436
437 let mut connections = self.ring_connections.lock().await;
438 let (sender, receiver) = connections
439 .as_mut()
440 .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
441
442 let token = [0u8; 4];
444
445 for _ in 0..(world_size - 1) {
446 let mut recv_buf = [0u8; 4];
447 tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))?;
448 }
449
450 Ok(())
451 }
452}
453
454impl std::fmt::Debug for AutoDiscoveryBackend {
455 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
456 f.debug_struct("AutoDiscoveryBackend")
457 .field("peer_id", &self.identity.peer_id_string())
458 .field("peer_count", &self.peer_count())
459 .field("ring_ready", &self.is_ring_ready())
460 .finish()
461 }
462}