1use crate::DistributedBackend;
23use crate::discovery::{DiscoveryEvent, DiscoveryService};
24use crate::error::DistributedError;
25use crate::identity::NodeIdentity;
26use crate::topology::{NodeProfile, SharedTopology, new_shared_topology};
27use crate::transport::{TcpTransport, TransportReceiver, TransportSender};
28use anyhow::Result;
29use async_trait::async_trait;
30use libp2p::PeerId;
31use parking_lot::RwLock;
32use std::net::SocketAddr;
33use std::sync::Arc;
34use std::time::Duration;
35use tokio::sync::{Mutex, mpsc};
36use tracing::{debug, error, info, warn};
37use zerocopy::{FromBytes, IntoBytes};
38
/// Port used by [`AutoDiscoveryConfig::default`] for `gradient_port`, i.e. the port
/// substituted into every ring member's address when the gradient ring is established.
const DEFAULT_GRADIENT_PORT: u16 = 52416;

/// Port used by [`AutoDiscoveryConfig::default`] for `discovery_port`, i.e. the port
/// handed to the discovery service.
const DEFAULT_DISCOVERY_PORT: u16 = 52415;
44
/// Settings used to construct an [`AutoDiscoveryBackend`].
#[derive(Debug, Clone)]
pub struct AutoDiscoveryConfig {
    /// Port applied to every ring member's IP address when the gradient ring is
    /// established (the port of the discovered address itself is discarded).
    pub gradient_port: u16,
    /// Port passed to `DiscoveryService::new`.
    pub discovery_port: u16,
    /// Minimum number of peers wanted before proceeding.
    ///
    /// NOTE(review): nothing in this file reads this field; `wait_for_peers` takes its
    /// own `min_peers` argument. Confirm who consumes it.
    pub min_peers: usize,
    /// Peer timeout.
    ///
    /// NOTE(review): nothing in this file reads this field; presumably consumed by the
    /// caller or the discovery layer. Confirm.
    pub peer_timeout: Duration,
    /// Profile of the local node; cloned into the shared topology at construction time.
    pub profile: NodeProfile,
}
59
60impl Default for AutoDiscoveryConfig {
61 fn default() -> Self {
62 Self {
63 gradient_port: DEFAULT_GRADIENT_PORT,
64 discovery_port: DEFAULT_DISCOVERY_PORT,
65 min_peers: 1,
66 peer_timeout: Duration::from_secs(60),
67 profile: NodeProfile::default(),
68 }
69 }
70}
71
/// Distributed backend that finds its peers through a discovery service and then
/// runs collectives over a TCP ring built from the discovered topology.
pub struct AutoDiscoveryBackend {
    /// Identity of the local node (source of the local peer id).
    identity: NodeIdentity,
    /// Configuration this backend was created with.
    config: AutoDiscoveryConfig,
    /// Shared view of the known nodes; updated from discovery events and read to
    /// derive the ring order, local rank and world size.
    topology: SharedTopology,
    /// State shared with the spawned discovery service; read for the connected-peer count.
    discovery_state: Arc<RwLock<crate::discovery::DiscoveryState>>,
    /// Sender/receiver pair of the established ring, or `None` before the first
    /// successful `establish_ring`.
    ring_connections: Mutex<Option<(TransportSender, TransportReceiver)>>,
    /// Receiving end of the channel the discovery service publishes events on.
    event_rx: Mutex<mpsc::Receiver<DiscoveryEvent>>,
    /// Set once the ring is established; cleared when a peer disconnects.
    ring_ready: Arc<std::sync::atomic::AtomicBool>,
}
92
93impl AutoDiscoveryBackend {
94 pub async fn new() -> Result<Self> {
96 Self::with_config(AutoDiscoveryConfig::default()).await
97 }
98
99 pub async fn with_config(config: AutoDiscoveryConfig) -> Result<Self> {
101 let identity = NodeIdentity::load_or_generate()?;
102 let topology = new_shared_topology(*identity.peer_id(), config.profile.clone());
103
104 let (event_tx, event_rx) = mpsc::channel(256);
106
107 let discovery = DiscoveryService::new(identity.clone(), config.discovery_port, event_tx);
109 let discovery_state = discovery.state();
110
111 tokio::spawn(async move {
113 if let Err(e) = discovery.run().await {
114 error!("Discovery service error: {}", e);
115 }
116 });
117
118 info!(
119 "AutoDiscoveryBackend initialized: peer_id={}, gradient_port={}, discovery_port={}",
120 identity.peer_id(),
121 config.gradient_port,
122 config.discovery_port
123 );
124
125 Ok(Self {
126 identity,
127 config,
128 topology,
129 discovery_state,
130 ring_connections: Mutex::new(None),
131 event_rx: Mutex::new(event_rx),
132 ring_ready: Arc::new(std::sync::atomic::AtomicBool::new(false)),
133 })
134 }
135
136 pub fn peer_id(&self) -> &PeerId {
138 self.identity.peer_id()
139 }
140
141 pub fn peer_id_string(&self) -> String {
143 self.identity.peer_id_string()
144 }
145
146 pub fn topology(&self) -> SharedTopology {
148 Arc::clone(&self.topology)
149 }
150
151 pub fn peer_count(&self) -> usize {
153 self.discovery_state.read().connected_count()
154 }
155
156 pub async fn wait_for_peers(
160 &self,
161 min_peers: usize,
162 timeout_duration: Duration,
163 ) -> Result<usize> {
164 info!(
165 "Waiting for {} peers (timeout: {:?})",
166 min_peers, timeout_duration
167 );
168
169 let start = std::time::Instant::now();
170
171 while start.elapsed() < timeout_duration {
172 {
174 let mut rx = self.event_rx.lock().await;
175 while let Ok(event) = rx.try_recv() {
176 self.handle_discovery_event(event).await;
177 }
178 }
179
180 let count = self.peer_count();
181 if count >= min_peers {
182 info!("Found {} peers, proceeding", count);
183 return Ok(count);
184 }
185
186 tokio::time::sleep(Duration::from_millis(100)).await;
188 }
189
190 let count = self.peer_count();
191 if count >= min_peers {
192 Ok(count)
193 } else {
194 Err(DistributedError::Protocol(format!(
195 "Timeout waiting for peers: found {} of {} required",
196 count, min_peers
197 ))
198 .into())
199 }
200 }
201
202 async fn handle_discovery_event(&self, event: DiscoveryEvent) {
204 match event {
205 DiscoveryEvent::PeerDiscovered { peer_id, addresses } => {
206 debug!("Discovered peer: {} at {:?}", peer_id, addresses);
207 }
208 DiscoveryEvent::PeerConnected { peer_id, address } => {
209 info!("Connected to peer: {} at {}", peer_id, address);
210
211 let mut topology = self.topology.write();
212 topology.add_node(peer_id, Some(address));
213 }
214 DiscoveryEvent::PeerDisconnected { peer_id } => {
215 warn!("Disconnected from peer: {}", peer_id);
216
217 let mut topology = self.topology.write();
218 topology.remove_node(&peer_id);
219
220 self.ring_ready
222 .store(false, std::sync::atomic::Ordering::SeqCst);
223 }
224 DiscoveryEvent::PeerExpired { peer_id } => {
225 debug!("Peer expired: {}", peer_id);
226 }
227 DiscoveryEvent::Message { peer_id, data } => {
228 debug!("Message from {}: {} bytes", peer_id, data.len());
229 }
230 }
231 }
232
233 pub async fn establish_ring(&self) -> Result<()> {
237 let (local_rank, world_size, node_addrs, peer_ids) = {
239 let topology = self.topology.read();
240
241 if !topology.can_form_ring() {
242 return Err(DistributedError::Protocol(
243 "Not enough peers to form ring (need at least 2 nodes)".into(),
244 )
245 .into());
246 }
247
248 let ring_order = topology.ring_order();
249 let local_rank = topology.local_rank();
250 let world_size = ring_order.len();
251
252 let node_addrs: Vec<SocketAddr> = ring_order
254 .iter()
255 .filter_map(|n| n.socket_addr)
256 .map(|a| SocketAddr::new(a.ip(), self.config.gradient_port))
257 .collect();
258
259 let peer_ids: Vec<String> = ring_order.iter().map(|n| n.peer_id.to_base58()).collect();
261
262 (local_rank, world_size, node_addrs, peer_ids)
263 }; info!(
266 "Establishing ring: rank={}/{}, peers={:?}",
267 local_rank, world_size, peer_ids
268 );
269
270 if node_addrs.len() < 2 {
271 return Err(DistributedError::Protocol(
272 "Not enough peers with known addresses to form ring".into(),
273 )
274 .into());
275 }
276
277 let config = crate::config::DistributedConfig {
279 nodes: node_addrs,
280 rank: local_rank,
281 connection_timeout_ms: 30000,
282 max_retries: 50,
283 };
284
285 let (sender, receiver) = TcpTransport::connect(&config).await?;
287
288 *self.ring_connections.lock().await = Some((sender, receiver));
289 self.ring_ready
290 .store(true, std::sync::atomic::Ordering::SeqCst);
291
292 info!("Ring established successfully");
293 Ok(())
294 }
295
296 pub fn is_ring_ready(&self) -> bool {
298 self.ring_ready.load(std::sync::atomic::Ordering::SeqCst)
299 }
300}
301
302#[async_trait]
303impl DistributedBackend for AutoDiscoveryBackend {
304 fn rank(&self) -> usize {
305 self.topology.read().local_rank()
306 }
307
308 fn world_size(&self) -> usize {
309 self.topology.read().node_count()
310 }
311
312 async fn all_reduce(&self, buffer: &mut [u8]) -> Result<()> {
313 if !self.is_ring_ready() {
314 self.establish_ring().await?;
315 }
316
317 if !buffer.len().is_multiple_of(4) {
319 return Err(DistributedError::Protocol(format!(
320 "Buffer length {} is not a multiple of 4 (f32 size)",
321 buffer.len()
322 ))
323 .into());
324 }
325
326 if !(buffer.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
327 return Err(DistributedError::Protocol(
328 "Buffer is not properly aligned for f32 operations".into(),
329 )
330 .into());
331 }
332
333 let floats: &mut [f32] = <[f32]>::mut_from_bytes(buffer)
334 .map_err(|e| DistributedError::Protocol(format!("Buffer cast failed: {e}")))?;
335 let len = floats.len();
336 let world_size = self.world_size();
337 let rank = self.rank();
338
339 if world_size < 2 {
340 return Ok(()); }
342
343 let chunk_size = len / world_size;
344 let remainder = len % world_size;
345
346 let get_chunk_range = |idx: usize| -> (usize, usize) {
347 let start = idx * chunk_size + idx.min(remainder);
348 let end = start + chunk_size + (if idx < remainder { 1 } else { 0 });
349 (start, end)
350 };
351
352 let mut connections = self.ring_connections.lock().await;
353 let (sender, receiver) = connections
354 .as_mut()
355 .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
356
357 for step in 0..(world_size - 1) {
359 let send_idx = (rank + world_size - step) % world_size;
360 let recv_idx = (rank + world_size - step - 1) % world_size;
361
362 let (send_start, send_end) = get_chunk_range(send_idx);
363 let (recv_start, recv_end) = get_chunk_range(recv_idx);
364
365 let recv_bytes_len = (recv_end - recv_start) * 4;
366
367 let send_buf = floats[send_start..send_end].as_bytes().to_vec();
369
370 let mut recv_buf = vec![0u8; recv_bytes_len];
372 tokio::try_join!(sender.send(&send_buf), receiver.recv(&mut recv_buf))?;
373
374 let recv_floats =
376 <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
377 for (i, &val) in recv_floats.iter().enumerate() {
378 floats[recv_start + i] += val;
379 }
380 }
381
382 for step in 0..(world_size - 1) {
384 let send_idx = (rank + world_size - step + 1) % world_size;
385 let recv_idx = (rank + world_size - step) % world_size;
386
387 let (send_start, send_end) = get_chunk_range(send_idx);
388 let (recv_start, recv_end) = get_chunk_range(recv_idx);
389
390 let recv_bytes_len = (recv_end - recv_start) * 4;
391
392 let send_buf: &[u8] = floats[send_start..send_end].as_bytes();
393
394 let mut recv_buf = vec![0u8; recv_bytes_len];
395 tokio::try_join!(sender.send(send_buf), receiver.recv(&mut recv_buf))?;
396
397 let recv_floats =
399 <[f32]>::ref_from_bytes(&recv_buf).expect("recv buffer aligned for f32");
400 floats[recv_start..recv_end].copy_from_slice(recv_floats);
401 }
402
403 Ok(())
404 }
405
406 async fn barrier(&self) -> Result<()> {
407 if !self.is_ring_ready() {
408 self.establish_ring().await?;
409 }
410
411 let world_size = self.world_size();
412 if world_size < 2 {
413 return Ok(());
414 }
415
416 let mut connections = self.ring_connections.lock().await;
417 let (sender, receiver) = connections
418 .as_mut()
419 .ok_or_else(|| DistributedError::Protocol("Ring not established".into()))?;
420
421 let token = [0u8; 4];
423
424 for _ in 0..(world_size - 1) {
425 let mut recv_buf = [0u8; 4];
426 tokio::try_join!(sender.send(&token), receiver.recv(&mut recv_buf))?;
427 }
428
429 Ok(())
430 }
431}
432
433impl std::fmt::Debug for AutoDiscoveryBackend {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 f.debug_struct("AutoDiscoveryBackend")
436 .field("peer_id", &self.identity.peer_id_string())
437 .field("peer_count", &self.peer_count())
438 .field("ring_ready", &self.is_ring_ready())
439 .finish()
440 }
441}