1use crate::error::{ClusterError, Result};
12use crate::worker_pool::WorkerId;
13use dashmap::DashMap;
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::net::IpAddr;
18use std::sync::Arc;
19use std::time::Instant;
20use tracing::warn;
21
/// Tracks where every worker sits in the physical network hierarchy
/// (datacenter -> rack -> host) and derives locality statistics from it.
pub struct TopologyManager {
    /// Flat worker-id -> location index for O(1) lookups.
    worker_locations: Arc<DashMap<WorkerId, Location>>,
    /// Hierarchical datacenter -> rack -> workers view of the cluster.
    topology: Arc<RwLock<TopologyTree>>,
    // NOTE(review): never read or written anywhere in this file —
    // presumably reserved for future bandwidth-aware placement; confirm
    // before removing.
    #[allow(dead_code)]
    bandwidth_matrix: Arc<RwLock<HashMap<(LocationId, LocationId), f64>>>,
    /// Aggregate counters exposed via `get_stats`.
    stats: Arc<RwLock<TopologyStats>>,
}
34
/// Identifier for a topology location (free-form string key used in the
/// bandwidth matrix).
pub type LocationId = String;
37
/// Physical placement of a worker in the cluster hierarchy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Location {
    /// Datacenter identifier (e.g. "dc1").
    pub datacenter: String,
    /// Rack identifier within the datacenter (e.g. "rack1").
    pub rack: String,
    /// Host identifier within the rack.
    pub host: String,
    /// Optional network address of the host; not required for
    /// distance calculations.
    pub ip_address: Option<IpAddr>,
}
50
/// Root of the hierarchical cluster view: datacenter id -> datacenter.
#[derive(Debug, Clone, Default)]
pub struct TopologyTree {
    pub datacenters: HashMap<String, Datacenter>,
}
57
/// One datacenter and the racks it contains, keyed by rack id.
#[derive(Debug, Clone)]
pub struct Datacenter {
    pub id: String,
    pub racks: HashMap<String, Rack>,
}
66
/// One rack and the workers registered in it.
#[derive(Debug, Clone)]
pub struct Rack {
    pub id: String,
    /// Workers currently placed in this rack; kept duplicate-free by
    /// `TopologyManager::register_worker`.
    pub workers: Vec<WorkerId>,
}
75
/// Aggregate topology counters, refreshed on every worker registration
/// and transfer record.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TopologyStats {
    pub total_datacenters: usize,
    pub total_racks: usize,
    pub total_workers: usize,
    /// Transfers classified as `CrossRack`.
    pub cross_rack_transfers: u64,
    /// Transfers classified as `CrossDatacenter`.
    pub cross_datacenter_transfers: u64,
    /// Transfers classified as `SameHost` or `SameRack`.
    pub same_rack_transfers: u64,
}
92
93impl TopologyManager {
94 pub fn new() -> Self {
96 Self {
97 worker_locations: Arc::new(DashMap::new()),
98 topology: Arc::new(RwLock::new(TopologyTree::default())),
99 bandwidth_matrix: Arc::new(RwLock::new(HashMap::new())),
100 stats: Arc::new(RwLock::new(TopologyStats::default())),
101 }
102 }
103
104 pub fn register_worker(&self, worker_id: WorkerId, location: Location) -> Result<()> {
106 self.worker_locations.insert(worker_id, location.clone());
107
108 {
109 let mut topology = self.topology.write();
110 let datacenter = topology
111 .datacenters
112 .entry(location.datacenter.clone())
113 .or_insert_with(|| Datacenter {
114 id: location.datacenter.clone(),
115 racks: HashMap::new(),
116 });
117
118 let rack = datacenter
119 .racks
120 .entry(location.rack.clone())
121 .or_insert_with(|| Rack {
122 id: location.rack.clone(),
123 workers: Vec::new(),
124 });
125
126 if !rack.workers.contains(&worker_id) {
127 rack.workers.push(worker_id);
128 }
129 } self.update_topology_stats();
132
133 Ok(())
134 }
135
136 pub fn calculate_distance(&self, worker1: &WorkerId, worker2: &WorkerId) -> NetworkDistance {
138 let loc1 = self.worker_locations.get(worker1);
139 let loc2 = self.worker_locations.get(worker2);
140
141 match (loc1, loc2) {
142 (Some(l1), Some(l2)) => {
143 if l1.datacenter != l2.datacenter {
144 NetworkDistance::CrossDatacenter
145 } else if l1.rack != l2.rack {
146 NetworkDistance::CrossRack
147 } else if l1.host != l2.host {
148 NetworkDistance::SameRack
149 } else {
150 NetworkDistance::SameHost
151 }
152 }
153 _ => NetworkDistance::Unknown,
154 }
155 }
156
157 pub fn get_same_rack_workers(&self, worker_id: &WorkerId) -> Vec<WorkerId> {
159 let location = match self.worker_locations.get(worker_id) {
160 Some(loc) => loc.clone(),
161 None => return Vec::new(),
162 };
163
164 self.worker_locations
165 .iter()
166 .filter(|entry| {
167 let loc = entry.value();
168 loc.datacenter == location.datacenter && loc.rack == location.rack
169 })
170 .map(|entry| *entry.key())
171 .collect()
172 }
173
174 pub fn get_same_datacenter_workers(&self, worker_id: &WorkerId) -> Vec<WorkerId> {
176 let location = match self.worker_locations.get(worker_id) {
177 Some(loc) => loc.clone(),
178 None => return Vec::new(),
179 };
180
181 self.worker_locations
182 .iter()
183 .filter(|entry| entry.value().datacenter == location.datacenter)
184 .map(|entry| *entry.key())
185 .collect()
186 }
187
188 pub fn record_transfer(&self, from: &WorkerId, to: &WorkerId) {
190 let distance = self.calculate_distance(from, to);
191 let mut stats = self.stats.write();
192
193 match distance {
194 NetworkDistance::SameHost | NetworkDistance::SameRack => {
195 stats.same_rack_transfers += 1;
196 }
197 NetworkDistance::CrossRack => {
198 stats.cross_rack_transfers += 1;
199 }
200 NetworkDistance::CrossDatacenter => {
201 stats.cross_datacenter_transfers += 1;
202 }
203 NetworkDistance::Unknown => {}
204 }
205 }
206
207 fn update_topology_stats(&self) {
208 let topology = self.topology.read();
209 let mut stats = self.stats.write();
210
211 stats.total_datacenters = topology.datacenters.len();
212 stats.total_racks = topology.datacenters.values().map(|dc| dc.racks.len()).sum();
213 stats.total_workers = self.worker_locations.len();
214 }
215
216 pub fn get_stats(&self) -> TopologyStats {
218 self.stats.read().clone()
219 }
220}
221
222impl Default for TopologyManager {
223 fn default() -> Self {
224 Self::new()
225 }
226}
227
/// Network proximity between two workers, ordered from closest to
/// farthest (the `Ord` derive makes `SameHost < SameRack < … < Unknown`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum NetworkDistance {
    /// Same physical host.
    SameHost = 0,
    /// Same rack, different host.
    SameRack = 1,
    /// Same datacenter, different rack.
    CrossRack = 2,
    /// Different datacenters.
    CrossDatacenter = 3,
    /// At least one worker is unregistered.
    Unknown = 4,
}
242
/// Accounts bytes transferred per worker and per directed link, and
/// warns when configured per-worker limits are exceeded.
pub struct BandwidthTracker {
    /// Per-worker send/receive accounting.
    worker_bandwidth: Arc<DashMap<WorkerId, RwLock<BandwidthUsage>>>,
    /// Per-directed-link (from, to) accounting.
    link_bandwidth: Arc<DashMap<(WorkerId, WorkerId), RwLock<BandwidthUsage>>>,
    /// Configured limits; only the per-worker limit is enforced by
    /// `check_limits` in this file.
    limits: Arc<RwLock<BandwidthLimits>>,
    /// Aggregate counters exposed via `get_stats`.
    stats: Arc<RwLock<BandwidthStats>>,
}
254
/// Running byte counters for one worker or link, with the time window
/// over which they were accumulated.
#[derive(Debug, Clone)]
pub struct BandwidthUsage {
    pub bytes_sent: u64,
    pub bytes_received: u64,
    /// When accounting for this entity began.
    pub start_time: Instant,
    /// Most recent update; `last_update - start_time` is the
    /// measurement window used by `current_bandwidth_mbps`.
    pub last_update: Instant,
}
267
268impl Default for BandwidthUsage {
269 fn default() -> Self {
270 let now = Instant::now();
271 Self {
272 bytes_sent: 0,
273 bytes_received: 0,
274 start_time: now,
275 last_update: now,
276 }
277 }
278}
279
280impl BandwidthUsage {
281 pub fn current_bandwidth_mbps(&self) -> f64 {
283 let elapsed = self
284 .last_update
285 .duration_since(self.start_time)
286 .as_secs_f64();
287 if elapsed > 0.0 {
288 let total_bytes = self.bytes_sent + self.bytes_received;
289 (total_bytes as f64 / 1_048_576.0) / elapsed
290 } else {
291 0.0
292 }
293 }
294}
295
/// Configured bandwidth ceilings (same unit as
/// `BandwidthUsage::current_bandwidth_mbps`).
///
/// NOTE(review): only `worker_limit_mbps` is consulted by
/// `BandwidthTracker::check_limits`; the link and global limits are
/// currently unenforced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BandwidthLimits {
    pub worker_limit_mbps: f64,
    pub link_limit_mbps: f64,
    pub global_limit_mbps: f64,
}
306
307impl Default for BandwidthLimits {
308 fn default() -> Self {
309 Self {
310 worker_limit_mbps: 1000.0, link_limit_mbps: 1000.0,
312 global_limit_mbps: 10000.0, }
314 }
315}
316
/// Aggregate transfer statistics for the whole tracker.
///
/// NOTE(review): only `total_bytes_transferred` and
/// `bandwidth_limit_violations` are updated in this file; the peak and
/// average fields are never written — confirm whether another module
/// fills them in.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BandwidthStats {
    pub total_bytes_transferred: u64,
    pub peak_bandwidth_mbps: f64,
    pub average_bandwidth_mbps: f64,
    pub bandwidth_limit_violations: u64,
}
329
330impl BandwidthTracker {
331 pub fn new(limits: BandwidthLimits) -> Self {
333 Self {
334 worker_bandwidth: Arc::new(DashMap::new()),
335 link_bandwidth: Arc::new(DashMap::new()),
336 limits: Arc::new(RwLock::new(limits)),
337 stats: Arc::new(RwLock::new(BandwidthStats::default())),
338 }
339 }
340
341 pub fn record_transfer(&self, from: WorkerId, to: WorkerId, bytes: u64) -> Result<()> {
343 let now = Instant::now();
344
345 self.update_worker_usage(from, bytes, 0, now);
347
348 self.update_worker_usage(to, 0, bytes, now);
350
351 self.update_link_usage(from, to, bytes, now);
353
354 self.update_global_stats(bytes);
356
357 self.check_limits(from, to)?;
359
360 Ok(())
361 }
362
363 fn update_worker_usage(&self, worker: WorkerId, sent: u64, received: u64, now: Instant) {
364 let entry = self.worker_bandwidth.entry(worker).or_insert_with(|| {
365 RwLock::new(BandwidthUsage {
366 start_time: now,
367 last_update: now,
368 ..Default::default()
369 })
370 });
371
372 let mut usage = entry.write();
373 usage.bytes_sent += sent;
374 usage.bytes_received += received;
375 usage.last_update = now;
376 }
377
378 fn update_link_usage(&self, from: WorkerId, to: WorkerId, bytes: u64, now: Instant) {
379 let entry = self.link_bandwidth.entry((from, to)).or_insert_with(|| {
380 RwLock::new(BandwidthUsage {
381 start_time: now,
382 last_update: now,
383 ..Default::default()
384 })
385 });
386
387 let mut usage = entry.write();
388 usage.bytes_sent += bytes;
389 usage.last_update = now;
390 }
391
392 fn update_global_stats(&self, bytes: u64) {
393 let mut stats = self.stats.write();
394 stats.total_bytes_transferred += bytes;
395 }
396
397 fn check_limits(&self, from: WorkerId, _to: WorkerId) -> Result<()> {
398 let limits = self.limits.read();
399
400 if let Some(usage) = self.worker_bandwidth.get(&from) {
402 let mbps = usage.read().current_bandwidth_mbps();
403 if mbps > limits.worker_limit_mbps {
404 let mut stats = self.stats.write();
405 stats.bandwidth_limit_violations += 1;
406 warn!("Worker {} bandwidth limit exceeded: {} MB/s", from, mbps);
407 }
408 }
409
410 Ok(())
411 }
412
413 pub fn get_worker_bandwidth(&self, worker: &WorkerId) -> Option<f64> {
415 self.worker_bandwidth
416 .get(worker)
417 .map(|u| u.read().current_bandwidth_mbps())
418 }
419
420 pub fn get_stats(&self) -> BandwidthStats {
422 self.stats.read().clone()
423 }
424}
425
/// AIMD-style congestion windows per directed worker link.
pub struct CongestionController {
    /// One window per (from, to) link, created lazily on first report.
    windows: Arc<DashMap<(WorkerId, WorkerId), RwLock<CongestionWindow>>>,
    /// Initial/maximum window sizes and RTT floor.
    config: CongestionConfig,
    /// Aggregate counters exposed via `get_stats`.
    stats: Arc<RwLock<CongestionStats>>,
}
435
/// Per-link congestion state.
#[derive(Debug, Clone)]
pub struct CongestionWindow {
    /// Current window size in bytes.
    pub size: usize,
    /// Slow-start threshold: below it the window doubles per success,
    /// above it the window grows additively.
    pub ssthresh: usize,
    /// Smoothed RTT estimate (EWMA).
    pub rtt_ms: f64,
    pub last_update: Instant,
}
448
/// Tunables for `CongestionController`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CongestionConfig {
    /// Window size assigned to a link on first use, in bytes.
    pub initial_window: usize,
    /// Hard cap on any link's window, in bytes.
    pub max_window: usize,
    /// Initial RTT estimate for new links, in milliseconds.
    pub min_rtt_ms: f64,
}
459
460impl Default for CongestionConfig {
461 fn default() -> Self {
462 Self {
463 initial_window: 65536, max_window: 16777216, min_rtt_ms: 1.0,
466 }
467 }
468}
469
/// Aggregate congestion counters.
///
/// NOTE(review): `average_window_size` is never written in this file —
/// confirm whether another module maintains it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CongestionStats {
    pub total_congestion_events: u64,
    pub total_backoffs: u64,
    pub average_window_size: usize,
}
480
481impl CongestionController {
482 pub fn new(config: CongestionConfig) -> Self {
484 Self {
485 windows: Arc::new(DashMap::new()),
486 config,
487 stats: Arc::new(RwLock::new(CongestionStats::default())),
488 }
489 }
490
491 pub fn report_success(&self, from: WorkerId, to: WorkerId, rtt_ms: f64) {
493 let now = Instant::now();
494
495 let entry = self.windows.entry((from, to)).or_insert_with(|| {
496 RwLock::new(CongestionWindow {
497 size: self.config.initial_window,
498 ssthresh: self.config.max_window / 2,
499 rtt_ms: self.config.min_rtt_ms,
500 last_update: now,
501 })
502 });
503
504 let mut window = entry.write();
505
506 window.rtt_ms = 0.875 * window.rtt_ms + 0.125 * rtt_ms;
508
509 if window.size < window.ssthresh {
511 window.size = (window.size * 2).min(self.config.max_window);
513 } else {
514 window.size = (window.size + 1024).min(self.config.max_window);
516 }
517
518 window.last_update = now;
519 }
520
521 pub fn report_congestion(&self, from: WorkerId, to: WorkerId) {
523 let entry = self.windows.entry((from, to)).or_insert_with(|| {
524 RwLock::new(CongestionWindow {
525 size: self.config.initial_window,
526 ssthresh: self.config.max_window / 2,
527 rtt_ms: self.config.min_rtt_ms,
528 last_update: Instant::now(),
529 })
530 });
531
532 let mut window = entry.write();
533
534 window.ssthresh = window.size / 2;
536 window.size = window.ssthresh;
537
538 let mut stats = self.stats.write();
539 stats.total_congestion_events += 1;
540 stats.total_backoffs += 1;
541 }
542
543 pub fn get_window_size(&self, from: &WorkerId, to: &WorkerId) -> usize {
545 self.windows
546 .get(&(*from, *to))
547 .map(|w| w.read().size)
548 .unwrap_or(self.config.initial_window)
549 }
550
551 pub fn get_stats(&self) -> CongestionStats {
553 self.stats.read().clone()
554 }
555}
556
/// Compresses payloads and keeps per-algorithm statistics.
pub struct CompressionManager {
    /// Cumulative stats per algorithm, created lazily on first use.
    stats: Arc<DashMap<CompressionAlgorithm, RwLock<CompressionStats>>>,
    /// Algorithm used when the caller does not specify one.
    default_algorithm: Arc<RwLock<CompressionAlgorithm>>,
}
564
/// Supported compression algorithms.
///
/// NOTE(review): in this file only `Zstd` actually compresses; `Lz4`
/// and `Snappy` are passthrough (see `CompressionManager::compress`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CompressionAlgorithm {
    None,
    Zstd,
    Lz4,
    Snappy,
}
577
/// Cumulative compression statistics for one algorithm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CompressionStats {
    /// Total uncompressed input bytes seen.
    pub bytes_before: u64,
    /// Total compressed output bytes produced.
    pub bytes_after: u64,
    /// bytes_after / bytes_before (smaller is better).
    pub compression_ratio: f64,
    /// Total time spent compressing, in milliseconds.
    pub compression_time_ms: f64,
}
590
591impl CompressionManager {
592 pub fn new(default_algorithm: CompressionAlgorithm) -> Self {
594 Self {
595 stats: Arc::new(DashMap::new()),
596 default_algorithm: Arc::new(RwLock::new(default_algorithm)),
597 }
598 }
599
600 pub fn compress(
602 &self,
603 data: &[u8],
604 algorithm: Option<CompressionAlgorithm>,
605 ) -> Result<Vec<u8>> {
606 let algo = algorithm.unwrap_or(*self.default_algorithm.read());
607 let start = Instant::now();
608
609 let compressed = match algo {
610 CompressionAlgorithm::None => data.to_vec(),
611 CompressionAlgorithm::Zstd => oxiarc_zstd::compress_with_level(data, 3)
612 .map_err(|e| ClusterError::CompressionError(e.to_string()))?,
613 CompressionAlgorithm::Lz4 | CompressionAlgorithm::Snappy => {
614 data.to_vec()
616 }
617 };
618
619 let elapsed = start.elapsed().as_secs_f64() * 1000.0;
620
621 self.update_stats(algo, data.len(), compressed.len(), elapsed);
622
623 Ok(compressed)
624 }
625
626 fn update_stats(&self, algo: CompressionAlgorithm, before: usize, after: usize, time_ms: f64) {
627 let entry = self
628 .stats
629 .entry(algo)
630 .or_insert_with(|| RwLock::new(CompressionStats::default()));
631
632 let mut stats = entry.write();
633 stats.bytes_before += before as u64;
634 stats.bytes_after += after as u64;
635 stats.compression_ratio = stats.bytes_after as f64 / stats.bytes_before as f64;
636 stats.compression_time_ms += time_ms;
637 }
638
639 pub fn get_stats(&self, algorithm: CompressionAlgorithm) -> Option<CompressionStats> {
641 self.stats.get(&algorithm).map(|s| s.read().clone())
642 }
643}
644
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Registers two workers in different racks of the same datacenter
    /// and checks distance classification plus the derived stats.
    /// The elapsed-time assertions guard against pathological lock
    /// contention; they depend on running the steps in this exact order.
    #[test]
    fn test_topology_manager() {
        use std::time::{Duration, Instant};

        let start = Instant::now();
        let manager = TopologyManager::new();

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        // Same datacenter, different racks -> expected CrossRack.
        let loc1 = Location {
            datacenter: "dc1".to_string(),
            rack: "rack1".to_string(),
            host: "host1".to_string(),
            ip_address: None,
        };

        let loc2 = Location {
            datacenter: "dc1".to_string(),
            rack: "rack2".to_string(),
            host: "host2".to_string(),
            ip_address: None,
        };

        manager
            .register_worker(worker1, loc1)
            .expect("Failed to register worker1");
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "Worker registration took too long: {:?}",
            start.elapsed()
        );

        manager
            .register_worker(worker2, loc2)
            .expect("Failed to register worker2");
        assert!(
            start.elapsed() < Duration::from_secs(1),
            "Worker registration took too long: {:?}",
            start.elapsed()
        );

        let distance = manager.calculate_distance(&worker1, &worker2);
        assert_eq!(distance, NetworkDistance::CrossRack);

        // Stats should reflect 1 datacenter containing 2 racks.
        let stats = manager.get_stats();
        assert_eq!(stats.total_datacenters, 1, "Should have 1 datacenter");
        assert_eq!(stats.total_racks, 2, "Should have 2 racks");
        assert_eq!(stats.total_workers, 2, "Should have 2 workers");

        assert!(
            start.elapsed() < Duration::from_secs(5),
            "Test took too long: {:?}",
            start.elapsed()
        );
    }

    /// Smoke test: recording one 1 MiB transfer makes the sender's
    /// bandwidth queryable (the value itself is time-dependent, so only
    /// presence is asserted).
    #[test]
    fn test_bandwidth_tracker() {
        let limits = BandwidthLimits::default();
        let tracker = BandwidthTracker::new(limits);

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        let _ = tracker.record_transfer(worker1, worker2, 1048576); let bandwidth = tracker.get_worker_bandwidth(&worker1);
        assert!(bandwidth.is_some());
    }

    /// A success report must grow the link's window above the
    /// configured initial size (slow start doubles it).
    #[test]
    fn test_congestion_controller() {
        let config = CongestionConfig::default();
        let controller = CongestionController::new(config);

        let worker1 = WorkerId(uuid::Uuid::new_v4());
        let worker2 = WorkerId(uuid::Uuid::new_v4());

        let initial_window = controller.get_window_size(&worker1, &worker2);

        controller.report_success(worker1, worker2, 10.0);

        let new_window = controller.get_window_size(&worker1, &worker2);
        assert!(new_window > initial_window);
    }
}