// oxigdal_cluster/network/mod.rs
1//! Network optimization for distributed computing.
2//!
3//! This module provides network-aware optimizations including:
4//! - Topology-aware scheduling (rack/datacenter awareness)
5//! - Network bandwidth tracking and monitoring
6//! - Congestion control and avoidance
7//! - Data compression for network transfers
8//! - Multicast support for broadcast operations
9//! - Network failure detection and recovery
10
11use crate::error::{ClusterError, Result};
12use crate::worker_pool::WorkerId;
13use dashmap::DashMap;
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::net::IpAddr;
18use std::sync::Arc;
19use std::time::Instant;
20use tracing::warn;
21
/// Network topology manager for rack/datacenter awareness.
///
/// Tracks where each worker lives in the physical hierarchy
/// (datacenter -> rack -> host) so callers can classify distances
/// and prefer network-local peers.
pub struct TopologyManager {
    /// Worker to location mapping (concurrent map; one entry per registered worker)
    worker_locations: Arc<DashMap<WorkerId, Location>>,
    /// Location hierarchy (datacenter -> racks -> workers)
    topology: Arc<RwLock<TopologyTree>>,
    /// Inter-location bandwidth (reserved for future use)
    #[allow(dead_code)]
    bandwidth_matrix: Arc<RwLock<HashMap<(LocationId, LocationId), f64>>>,
    /// Aggregate counters, refreshed on registration and transfer recording
    stats: Arc<RwLock<TopologyStats>>,
}
34
/// Physical location identifier.
///
/// Currently a plain string key (e.g. a datacenter or rack name).
pub type LocationId = String;
37
/// Worker location in the topology.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Location {
    /// Datacenter ID
    pub datacenter: String,
    /// Rack ID (compared within a datacenter when classifying distance)
    pub rack: String,
    /// Host ID (compared within a rack when classifying distance)
    pub host: String,
    /// IP address, if known
    pub ip_address: Option<IpAddr>,
}
50
/// Topology tree structure.
///
/// Root of the datacenter -> rack -> worker hierarchy.
#[derive(Debug, Clone, Default)]
pub struct TopologyTree {
    /// Datacenters, keyed by datacenter ID
    pub datacenters: HashMap<String, Datacenter>,
}
57
/// Datacenter in topology.
#[derive(Debug, Clone)]
pub struct Datacenter {
    /// Datacenter identifier (duplicates the key in `TopologyTree::datacenters`)
    pub id: String,
    /// Racks within this datacenter, keyed by rack ID
    pub racks: HashMap<String, Rack>,
}
66
/// Rack in topology.
#[derive(Debug, Clone)]
pub struct Rack {
    /// Rack identifier (duplicates the key in `Datacenter::racks`)
    pub id: String,
    /// Workers located in this rack (deduplicated on registration)
    pub workers: Vec<WorkerId>,
}
75
/// Topology statistics.
///
/// Counters are refreshed by `TopologyManager::register_worker` and
/// `TopologyManager::record_transfer`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TopologyStats {
    /// Total number of datacenters
    pub total_datacenters: usize,
    /// Total number of racks
    pub total_racks: usize,
    /// Total number of workers
    pub total_workers: usize,
    /// Number of cross-rack data transfers
    pub cross_rack_transfers: u64,
    /// Number of cross-datacenter data transfers
    pub cross_datacenter_transfers: u64,
    /// Number of same-rack data transfers (includes same-host)
    pub same_rack_transfers: u64,
}
92
93impl TopologyManager {
94    /// Create a new topology manager.
95    pub fn new() -> Self {
96        Self {
97            worker_locations: Arc::new(DashMap::new()),
98            topology: Arc::new(RwLock::new(TopologyTree::default())),
99            bandwidth_matrix: Arc::new(RwLock::new(HashMap::new())),
100            stats: Arc::new(RwLock::new(TopologyStats::default())),
101        }
102    }
103
104    /// Register a worker's location.
105    pub fn register_worker(&self, worker_id: WorkerId, location: Location) -> Result<()> {
106        self.worker_locations.insert(worker_id, location.clone());
107
108        {
109            let mut topology = self.topology.write();
110            let datacenter = topology
111                .datacenters
112                .entry(location.datacenter.clone())
113                .or_insert_with(|| Datacenter {
114                    id: location.datacenter.clone(),
115                    racks: HashMap::new(),
116                });
117
118            let rack = datacenter
119                .racks
120                .entry(location.rack.clone())
121                .or_insert_with(|| Rack {
122                    id: location.rack.clone(),
123                    workers: Vec::new(),
124                });
125
126            if !rack.workers.contains(&worker_id) {
127                rack.workers.push(worker_id);
128            }
129        } // Drop the write lock before calling update_topology_stats
130
131        self.update_topology_stats();
132
133        Ok(())
134    }
135
136    /// Calculate network distance between two workers.
137    pub fn calculate_distance(&self, worker1: &WorkerId, worker2: &WorkerId) -> NetworkDistance {
138        let loc1 = self.worker_locations.get(worker1);
139        let loc2 = self.worker_locations.get(worker2);
140
141        match (loc1, loc2) {
142            (Some(l1), Some(l2)) => {
143                if l1.datacenter != l2.datacenter {
144                    NetworkDistance::CrossDatacenter
145                } else if l1.rack != l2.rack {
146                    NetworkDistance::CrossRack
147                } else if l1.host != l2.host {
148                    NetworkDistance::SameRack
149                } else {
150                    NetworkDistance::SameHost
151                }
152            }
153            _ => NetworkDistance::Unknown,
154        }
155    }
156
157    /// Get workers in the same rack.
158    pub fn get_same_rack_workers(&self, worker_id: &WorkerId) -> Vec<WorkerId> {
159        let location = match self.worker_locations.get(worker_id) {
160            Some(loc) => loc.clone(),
161            None => return Vec::new(),
162        };
163
164        self.worker_locations
165            .iter()
166            .filter(|entry| {
167                let loc = entry.value();
168                loc.datacenter == location.datacenter && loc.rack == location.rack
169            })
170            .map(|entry| *entry.key())
171            .collect()
172    }
173
174    /// Get workers in the same datacenter.
175    pub fn get_same_datacenter_workers(&self, worker_id: &WorkerId) -> Vec<WorkerId> {
176        let location = match self.worker_locations.get(worker_id) {
177            Some(loc) => loc.clone(),
178            None => return Vec::new(),
179        };
180
181        self.worker_locations
182            .iter()
183            .filter(|entry| entry.value().datacenter == location.datacenter)
184            .map(|entry| *entry.key())
185            .collect()
186    }
187
188    /// Record a data transfer for statistics.
189    pub fn record_transfer(&self, from: &WorkerId, to: &WorkerId) {
190        let distance = self.calculate_distance(from, to);
191        let mut stats = self.stats.write();
192
193        match distance {
194            NetworkDistance::SameHost | NetworkDistance::SameRack => {
195                stats.same_rack_transfers += 1;
196            }
197            NetworkDistance::CrossRack => {
198                stats.cross_rack_transfers += 1;
199            }
200            NetworkDistance::CrossDatacenter => {
201                stats.cross_datacenter_transfers += 1;
202            }
203            NetworkDistance::Unknown => {}
204        }
205    }
206
207    fn update_topology_stats(&self) {
208        let topology = self.topology.read();
209        let mut stats = self.stats.write();
210
211        stats.total_datacenters = topology.datacenters.len();
212        stats.total_racks = topology.datacenters.values().map(|dc| dc.racks.len()).sum();
213        stats.total_workers = self.worker_locations.len();
214    }
215
216    /// Get topology statistics.
217    pub fn get_stats(&self) -> TopologyStats {
218        self.stats.read().clone()
219    }
220}
221
222impl Default for TopologyManager {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
/// Network distance between workers.
///
/// Variants are declared from closest to farthest, so the derived
/// `Ord`/`PartialOrd` allow direct locality comparisons
/// (smaller = closer).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum NetworkDistance {
    /// Workers on the same host
    SameHost = 0,
    /// Workers in the same rack
    SameRack = 1,
    /// Workers in different racks within same datacenter
    CrossRack = 2,
    /// Workers in different datacenters
    CrossDatacenter = 3,
    /// Unknown network distance (at least one worker is unregistered)
    Unknown = 4,
}
242
/// Bandwidth tracker for monitoring network usage.
///
/// Maintains byte counters per worker and per directed link, plus
/// global statistics, each behind its own lock for concurrent updates.
pub struct BandwidthTracker {
    /// Per-worker bandwidth usage (bytes sent and received)
    worker_bandwidth: Arc<DashMap<WorkerId, RwLock<BandwidthUsage>>>,
    /// Per-link bandwidth usage, keyed by (sender, receiver)
    link_bandwidth: Arc<DashMap<(WorkerId, WorkerId), RwLock<BandwidthUsage>>>,
    /// Bandwidth limits (checked on every recorded transfer)
    limits: Arc<RwLock<BandwidthLimits>>,
    /// Global statistics
    stats: Arc<RwLock<BandwidthStats>>,
}
254
/// Bandwidth usage tracking for one worker or one directed link.
#[derive(Debug, Clone)]
pub struct BandwidthUsage {
    /// Bytes sent
    pub bytes_sent: u64,
    /// Bytes received
    pub bytes_received: u64,
    /// When tracking began for this entry
    pub start_time: Instant,
    /// When the counters were last touched
    pub last_update: Instant,
}

impl Default for BandwidthUsage {
    /// Zeroed counters with both timestamps set to the moment of creation.
    fn default() -> Self {
        let created = Instant::now();
        Self {
            bytes_sent: 0,
            bytes_received: 0,
            start_time: created,
            last_update: created,
        }
    }
}

impl BandwidthUsage {
    /// Average bandwidth in MB/s (MiB-based) over the tracked interval.
    ///
    /// Computed as total traffic (sent + received) divided by the time
    /// between `start_time` and `last_update`; returns 0.0 when that
    /// interval is empty (e.g. right after creation).
    pub fn current_bandwidth_mbps(&self) -> f64 {
        let window = self
            .last_update
            .duration_since(self.start_time)
            .as_secs_f64();
        if window <= 0.0 {
            return 0.0;
        }
        let total = (self.bytes_sent + self.bytes_received) as f64;
        total / 1_048_576.0 / window
    }
}
295
/// Bandwidth limits configuration.
///
/// Only the per-worker limit is currently enforced by
/// `BandwidthTracker::check_limits`; the others are configuration
/// for future checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthLimits {
    /// Per-worker bandwidth limit (MB/s)
    pub worker_limit_mbps: f64,
    /// Per-link bandwidth limit (MB/s)
    pub link_limit_mbps: f64,
    /// Global bandwidth limit (MB/s)
    pub global_limit_mbps: f64,
}
306
307impl Default for BandwidthLimits {
308    fn default() -> Self {
309        Self {
310            worker_limit_mbps: 1000.0, // 1 GB/s
311            link_limit_mbps: 1000.0,
312            global_limit_mbps: 10000.0, // 10 GB/s
313        }
314    }
315}
316
/// Bandwidth statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BandwidthStats {
    /// Total bytes transferred across the network
    pub total_bytes_transferred: u64,
    /// Peak bandwidth in MB/s
    pub peak_bandwidth_mbps: f64,
    /// Average bandwidth in MB/s
    pub average_bandwidth_mbps: f64,
    /// Number of times bandwidth limits were exceeded
    pub bandwidth_limit_violations: u64,
}
329
330impl BandwidthTracker {
331    /// Create a new bandwidth tracker.
332    pub fn new(limits: BandwidthLimits) -> Self {
333        Self {
334            worker_bandwidth: Arc::new(DashMap::new()),
335            link_bandwidth: Arc::new(DashMap::new()),
336            limits: Arc::new(RwLock::new(limits)),
337            stats: Arc::new(RwLock::new(BandwidthStats::default())),
338        }
339    }
340
341    /// Record data transfer.
342    pub fn record_transfer(&self, from: WorkerId, to: WorkerId, bytes: u64) -> Result<()> {
343        let now = Instant::now();
344
345        // Update sender
346        self.update_worker_usage(from, bytes, 0, now);
347
348        // Update receiver
349        self.update_worker_usage(to, 0, bytes, now);
350
351        // Update link
352        self.update_link_usage(from, to, bytes, now);
353
354        // Update global stats
355        self.update_global_stats(bytes);
356
357        // Check limits
358        self.check_limits(from, to)?;
359
360        Ok(())
361    }
362
363    fn update_worker_usage(&self, worker: WorkerId, sent: u64, received: u64, now: Instant) {
364        let entry = self.worker_bandwidth.entry(worker).or_insert_with(|| {
365            RwLock::new(BandwidthUsage {
366                start_time: now,
367                last_update: now,
368                ..Default::default()
369            })
370        });
371
372        let mut usage = entry.write();
373        usage.bytes_sent += sent;
374        usage.bytes_received += received;
375        usage.last_update = now;
376    }
377
378    fn update_link_usage(&self, from: WorkerId, to: WorkerId, bytes: u64, now: Instant) {
379        let entry = self.link_bandwidth.entry((from, to)).or_insert_with(|| {
380            RwLock::new(BandwidthUsage {
381                start_time: now,
382                last_update: now,
383                ..Default::default()
384            })
385        });
386
387        let mut usage = entry.write();
388        usage.bytes_sent += bytes;
389        usage.last_update = now;
390    }
391
392    fn update_global_stats(&self, bytes: u64) {
393        let mut stats = self.stats.write();
394        stats.total_bytes_transferred += bytes;
395    }
396
397    fn check_limits(&self, from: WorkerId, _to: WorkerId) -> Result<()> {
398        let limits = self.limits.read();
399
400        // Check sender limit
401        if let Some(usage) = self.worker_bandwidth.get(&from) {
402            let mbps = usage.read().current_bandwidth_mbps();
403            if mbps > limits.worker_limit_mbps {
404                let mut stats = self.stats.write();
405                stats.bandwidth_limit_violations += 1;
406                warn!("Worker {} bandwidth limit exceeded: {} MB/s", from, mbps);
407            }
408        }
409
410        Ok(())
411    }
412
413    /// Get current bandwidth usage for a worker.
414    pub fn get_worker_bandwidth(&self, worker: &WorkerId) -> Option<f64> {
415        self.worker_bandwidth
416            .get(worker)
417            .map(|u| u.read().current_bandwidth_mbps())
418    }
419
420    /// Get bandwidth statistics.
421    pub fn get_stats(&self) -> BandwidthStats {
422        self.stats.read().clone()
423    }
424}
425
/// Congestion control manager.
///
/// Maintains a TCP-style AIMD congestion window per directed link.
pub struct CongestionController {
    /// Congestion windows per link, keyed by (sender, receiver)
    windows: Arc<DashMap<(WorkerId, WorkerId), RwLock<CongestionWindow>>>,
    /// Configuration (immutable after construction)
    config: CongestionConfig,
    /// Statistics
    stats: Arc<RwLock<CongestionStats>>,
}
435
/// Congestion window for flow control on one link.
#[derive(Debug, Clone)]
pub struct CongestionWindow {
    /// Current window size in bytes
    pub size: usize,
    /// Slow start threshold in bytes (below it the window grows exponentially)
    pub ssthresh: usize,
    /// Smoothed round-trip time estimate in milliseconds
    pub rtt_ms: f64,
    /// Last time the window was updated
    pub last_update: Instant,
}
448
/// Congestion control configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CongestionConfig {
    /// Initial window size in bytes for a newly seen link
    pub initial_window: usize,
    /// Maximum window size in bytes (growth cap)
    pub max_window: usize,
    /// Minimum RTT (ms), used as the initial RTT estimate
    pub min_rtt_ms: f64,
}
459
460impl Default for CongestionConfig {
461    fn default() -> Self {
462        Self {
463            initial_window: 65536, // 64 KB
464            max_window: 16777216,  // 16 MB
465            min_rtt_ms: 1.0,
466        }
467    }
468}
469
/// Congestion control statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CongestionStats {
    /// Total number of congestion events detected
    pub total_congestion_events: u64,
    /// Total number of backoff operations performed
    pub total_backoffs: u64,
    /// Average congestion window size in bytes
    pub average_window_size: usize,
}
480
481impl CongestionController {
482    /// Create a new congestion controller.
483    pub fn new(config: CongestionConfig) -> Self {
484        Self {
485            windows: Arc::new(DashMap::new()),
486            config,
487            stats: Arc::new(RwLock::new(CongestionStats::default())),
488        }
489    }
490
491    /// Report successful transfer.
492    pub fn report_success(&self, from: WorkerId, to: WorkerId, rtt_ms: f64) {
493        let now = Instant::now();
494
495        let entry = self.windows.entry((from, to)).or_insert_with(|| {
496            RwLock::new(CongestionWindow {
497                size: self.config.initial_window,
498                ssthresh: self.config.max_window / 2,
499                rtt_ms: self.config.min_rtt_ms,
500                last_update: now,
501            })
502        });
503
504        let mut window = entry.write();
505
506        // Update RTT estimate
507        window.rtt_ms = 0.875 * window.rtt_ms + 0.125 * rtt_ms;
508
509        // Increase window (AIMD)
510        if window.size < window.ssthresh {
511            // Slow start: exponential increase
512            window.size = (window.size * 2).min(self.config.max_window);
513        } else {
514            // Congestion avoidance: linear increase
515            window.size = (window.size + 1024).min(self.config.max_window);
516        }
517
518        window.last_update = now;
519    }
520
521    /// Report congestion event (packet loss, timeout).
522    pub fn report_congestion(&self, from: WorkerId, to: WorkerId) {
523        let entry = self.windows.entry((from, to)).or_insert_with(|| {
524            RwLock::new(CongestionWindow {
525                size: self.config.initial_window,
526                ssthresh: self.config.max_window / 2,
527                rtt_ms: self.config.min_rtt_ms,
528                last_update: Instant::now(),
529            })
530        });
531
532        let mut window = entry.write();
533
534        // Multiplicative decrease
535        window.ssthresh = window.size / 2;
536        window.size = window.ssthresh;
537
538        let mut stats = self.stats.write();
539        stats.total_congestion_events += 1;
540        stats.total_backoffs += 1;
541    }
542
543    /// Get current window size.
544    pub fn get_window_size(&self, from: &WorkerId, to: &WorkerId) -> usize {
545        self.windows
546            .get(&(*from, *to))
547            .map(|w| w.read().size)
548            .unwrap_or(self.config.initial_window)
549    }
550
551    /// Get congestion statistics.
552    pub fn get_stats(&self) -> CongestionStats {
553        self.stats.read().clone()
554    }
555}
556
/// Data compression manager for network transfers.
pub struct CompressionManager {
    /// Compression statistics per algorithm, created lazily on first use
    stats: Arc<DashMap<CompressionAlgorithm, RwLock<CompressionStats>>>,
    /// Algorithm used when callers do not request one explicitly
    default_algorithm: Arc<RwLock<CompressionAlgorithm>>,
}
564
/// Compression algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
    /// No compression
    None,
    /// Zstandard compression
    Zstd,
    /// LZ4 compression (currently a pass-through placeholder)
    Lz4,
    /// Snappy compression (currently a pass-through placeholder)
    Snappy,
}
577
/// Compression statistics, accumulated per algorithm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Total bytes before compression
    pub bytes_before: u64,
    /// Total bytes after compression
    pub bytes_after: u64,
    /// Compression ratio (after/before; smaller is better)
    pub compression_ratio: f64,
    /// Total time spent compressing in milliseconds
    pub compression_time_ms: f64,
}
590
591impl CompressionManager {
592    /// Create a new compression manager.
593    pub fn new(default_algorithm: CompressionAlgorithm) -> Self {
594        Self {
595            stats: Arc::new(DashMap::new()),
596            default_algorithm: Arc::new(RwLock::new(default_algorithm)),
597        }
598    }
599
600    /// Compress data.
601    pub fn compress(
602        &self,
603        data: &[u8],
604        algorithm: Option<CompressionAlgorithm>,
605    ) -> Result<Vec<u8>> {
606        let algo = algorithm.unwrap_or(*self.default_algorithm.read());
607        let start = Instant::now();
608
609        let compressed = match algo {
610            CompressionAlgorithm::None => data.to_vec(),
611            CompressionAlgorithm::Zstd => oxiarc_zstd::compress_with_level(data, 3)
612                .map_err(|e| ClusterError::CompressionError(e.to_string()))?,
613            CompressionAlgorithm::Lz4 | CompressionAlgorithm::Snappy => {
614                // Simplified - in production use actual libraries
615                data.to_vec()
616            }
617        };
618
619        let elapsed = start.elapsed().as_secs_f64() * 1000.0;
620
621        self.update_stats(algo, data.len(), compressed.len(), elapsed);
622
623        Ok(compressed)
624    }
625
626    fn update_stats(&self, algo: CompressionAlgorithm, before: usize, after: usize, time_ms: f64) {
627        let entry = self
628            .stats
629            .entry(algo)
630            .or_insert_with(|| RwLock::new(CompressionStats::default()));
631
632        let mut stats = entry.write();
633        stats.bytes_before += before as u64;
634        stats.bytes_after += after as u64;
635        stats.compression_ratio = stats.bytes_after as f64 / stats.bytes_before as f64;
636        stats.compression_time_ms += time_ms;
637    }
638
639    /// Get compression statistics.
640    pub fn get_stats(&self, algorithm: CompressionAlgorithm) -> Option<CompressionStats> {
641        self.stats.get(&algorithm).map(|s| s.read().clone())
642    }
643}
644
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Registers two workers in different racks of one datacenter and
    /// checks the distance classification plus the aggregate counters.
    /// The elapsed-time assertions guard against lock-contention hangs
    /// in register_worker / update_topology_stats.
    #[test]
    fn test_topology_manager() {
        use std::time::{Duration, Instant};

        let start = Instant::now();
        let manager = TopologyManager::new();

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        let loc1 = Location {
            datacenter: "dc1".to_string(),
            rack: "rack1".to_string(),
            host: "host1".to_string(),
            ip_address: None,
        };

        // Same datacenter, different rack -> expected CrossRack distance.
        let loc2 = Location {
            datacenter: "dc1".to_string(),
            rack: "rack2".to_string(),
            host: "host2".to_string(),
            ip_address: None,
        };

        // Register workers - should complete quickly
        manager
            .register_worker(worker1, loc1)
            .expect("Failed to register worker1");
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "Worker registration took too long: {:?}",
            start.elapsed()
        );

        manager
            .register_worker(worker2, loc2)
            .expect("Failed to register worker2");
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "Worker registration took too long: {:?}",
            start.elapsed()
        );

        // Calculate distance - should be instant
        let distance = manager.calculate_distance(&worker1, &worker2);
        assert_eq!(distance, NetworkDistance::CrossRack);

        // Verify stats were updated correctly
        let stats = manager.get_stats();
        assert_eq!(stats.total_datacenters, 1, "Should have 1 datacenter");
        assert_eq!(stats.total_racks, 2, "Should have 2 racks");
        assert_eq!(stats.total_workers, 2, "Should have 2 workers");

        // Entire test should complete in under 5 seconds
        assert!(
            start.elapsed() < Duration::from_secs(5),
            "Test took too long: {:?}",
            start.elapsed()
        );
    }

    /// Records a single 1 MB transfer and checks that a usage entry
    /// appears for the sender.
    #[test]
    fn test_bandwidth_tracker() {
        let limits = BandwidthLimits::default();
        let tracker = BandwidthTracker::new(limits);

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        let _ = tracker.record_transfer(worker1, worker2, 1048576); // 1 MB

        let bandwidth = tracker.get_worker_bandwidth(&worker1);
        assert!(bandwidth.is_some());
    }

    /// Checks that a successful transfer grows the link's congestion
    /// window above its initial size (slow start).
    #[test]
    fn test_congestion_controller() {
        let config = CongestionConfig::default();
        let controller = CongestionController::new(config);

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        let initial_window = controller.get_window_size(&worker1, &worker2);

        controller.report_success(worker1, worker2, 10.0);

        let new_window = controller.get_window_size(&worker1, &worker2);
        assert!(new_window > initial_window);
    }
}