1use crate::error::{ClusterError, Result};
7use crate::metrics::WorkerMetrics;
8use crate::task_graph::ResourceRequirements;
9use chrono::Utc;
10use dashmap::DashMap;
11use parking_lot::RwLock;
12use serde::{Deserialize, Serialize};
13use std::collections::{HashMap, HashSet};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16use uuid::Uuid;
17
/// Shared pool of cluster workers.
///
/// Cloning is cheap: every clone holds the same [`Arc`]-wrapped state, so
/// registrations and status changes made through one clone are visible
/// through all of them.
#[derive(Clone)]
pub struct WorkerPool {
    inner: Arc<WorkerPoolInner>,
}
23
/// State shared by all clones of a [`WorkerPool`].
struct WorkerPoolInner {
    /// Every registered worker, keyed by id; each worker sits behind its own lock.
    workers: DashMap<WorkerId, Arc<RwLock<Worker>>>,

    /// Ids of registered workers whose capabilities include CPU execution.
    cpu_workers: RwLock<HashSet<WorkerId>>,
    /// Ids of registered workers whose capabilities include a GPU.
    gpu_workers: RwLock<HashSet<WorkerId>>,
    /// Ids of registered workers whose capabilities include storage.
    storage_workers: RwLock<HashSet<WorkerId>>,

    /// Timeouts and size limits the pool was created with.
    config: WorkerPoolConfig,
}
36
/// Tuning parameters for a [`WorkerPool`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerPoolConfig {
    /// Heartbeat age after which a health check marks a worker `Unhealthy`.
    pub heartbeat_timeout: Duration,

    /// Intended spacing between health checks. Not read by `WorkerPool` itself;
    /// presumably consumed by whatever drives the periodic checks — confirm.
    pub health_check_interval: Duration,

    /// Heartbeat age after which a health check marks a worker `Offline`
    /// (`check_all_workers` then unregisters it).
    pub max_unhealthy_duration: Duration,

    /// Desired lower bound on pool size. NOTE(review): not enforced by
    /// `WorkerPool`; confirm where (or whether) it is used.
    pub min_workers: usize,

    /// Maximum number of workers; registering a new worker beyond this count
    /// fails with `CapacityExceeded`.
    pub max_workers: usize,
}
55
56impl Default for WorkerPoolConfig {
57 fn default() -> Self {
58 Self {
59 heartbeat_timeout: Duration::from_secs(30),
60 health_check_interval: Duration::from_secs(10),
61 max_unhealthy_duration: Duration::from_secs(120),
62 min_workers: 1,
63 max_workers: 1000,
64 }
65 }
66}
67
/// Unique identifier of a worker, backed by a UUID.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct WorkerId(pub Uuid);
71
72impl WorkerId {
73 pub fn new() -> Self {
75 Self(Uuid::new_v4())
76 }
77
78 pub fn from_uuid(uuid: Uuid) -> Self {
80 Self(uuid)
81 }
82}
83
84impl Default for WorkerId {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl std::fmt::Display for WorkerId {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 write!(f, "{}", self.0)
93 }
94}
95
/// A single worker as tracked by the pool.
#[derive(Debug, Clone)]
pub struct Worker {
    /// Identifier used as the key in the pool.
    pub id: WorkerId,

    /// Worker name; not interpreted by the pool.
    pub name: String,

    /// Address of the worker (presumably `host:port`); not interpreted by the pool.
    pub address: String,

    /// Capabilities; `cpu`/`gpu`/`storage` decide which pool indexes it joins.
    pub capabilities: WorkerCapabilities,

    /// Total resources the worker offers.
    pub capacity: WorkerCapacity,

    /// Most recently reported resource usage (replaced by `update_worker_usage`).
    pub usage: WorkerUsage,

    /// Current lifecycle / load status.
    pub status: WorkerStatus,

    /// Monotonic time of the most recent heartbeat; compared against the
    /// pool's heartbeat timeouts during health checks.
    pub last_heartbeat: Instant,

    /// Monotonic registration time; used to derive uptime in metrics.
    pub registered_at: Instant,

    /// Time of the most recent health check, or `None` if never checked.
    pub last_health_check: Option<Instant>,

    /// Number of consecutive failed health checks; reset when a check passes
    /// or when a heartbeat revives an `Unhealthy` worker.
    pub health_check_failures: u32,

    /// Completed-task counter reported in metrics (not updated by the pool).
    pub tasks_completed: u64,

    /// Failed-task counter reported in metrics (not updated by the pool).
    pub tasks_failed: u64,

    /// Version string of the worker; not interpreted by the pool.
    pub version: String,

    /// Free-form key/value annotations.
    pub metadata: HashMap<String, String>,
}
144
/// What kinds of work a worker advertises it can perform.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerCapabilities {
    /// Worker joins the CPU index (used for non-GPU task selection).
    pub cpu: bool,

    /// Worker joins the GPU index (used when a task requires a GPU).
    pub gpu: bool,

    /// Worker joins the storage index (counted in statistics).
    pub storage: bool,

    /// Task types the worker supports; not consulted by the pool's selection.
    pub task_types: Vec<String>,

    /// Data formats the worker supports; not consulted by the pool's selection.
    pub data_formats: Vec<String>,
}
163
164impl Default for WorkerCapabilities {
165 fn default() -> Self {
166 Self {
167 cpu: true,
168 gpu: false,
169 storage: false,
170 task_types: vec![],
171 data_formats: vec![],
172 }
173 }
174}
175
/// Total resources a worker offers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerCapacity {
    /// CPU cores (fractional values allowed).
    pub cpu_cores: f64,

    /// Memory in bytes.
    pub memory_bytes: u64,

    /// Storage in bytes.
    pub storage_bytes: u64,

    /// Number of GPUs.
    pub gpu_count: u32,

    /// Network bandwidth. NOTE(review): unit is not specified here — confirm.
    pub network_bandwidth: u64,
}
194
195impl Default for WorkerCapacity {
196 fn default() -> Self {
197 Self {
198 cpu_cores: 1.0,
199 memory_bytes: 1024 * 1024 * 1024, storage_bytes: 0,
201 gpu_count: 0,
202 network_bandwidth: 0,
203 }
204 }
205}
206
/// Resource usage most recently reported by a worker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerUsage {
    /// CPU cores in use (compared against `WorkerCapacity::cpu_cores`).
    pub cpu_cores: f64,

    /// Memory in use, in bytes.
    pub memory_bytes: u64,

    /// Storage in use, in bytes.
    pub storage_bytes: u64,

    /// Number of running tasks; zero marks the worker `Idle`, and the
    /// least-loaded strategy prefers the smallest value.
    pub active_tasks: u32,

    /// Outbound network traffic counter. NOTE(review): unit not specified — confirm.
    pub network_sent: u64,

    /// Inbound network traffic counter. NOTE(review): unit not specified — confirm.
    pub network_received: u64,
}
228
229impl Default for WorkerUsage {
230 fn default() -> Self {
231 Self {
232 cpu_cores: 0.0,
233 memory_bytes: 0,
234 storage_bytes: 0,
235 active_tasks: 0,
236 network_sent: 0,
237 network_received: 0,
238 }
239 }
240}
241
/// Lifecycle / load state of a worker. Only `Active` and `Idle` workers are
/// eligible for task selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum WorkerStatus {
    /// Running tasks with CPU and memory utilization below 90%.
    Active,

    /// Reported zero active tasks.
    Idle,

    /// CPU or memory utilization at or above 90%; not selected for new tasks.
    Busy,

    /// Heartbeat older than `heartbeat_timeout`; a later heartbeat or passing
    /// health check restores `Active`.
    Unhealthy,

    /// Set by `drain_worker`; not selected for new tasks until `resume_worker`.
    Draining,

    /// Heartbeat older than `max_unhealthy_duration`; `check_all_workers`
    /// unregisters workers in this state.
    Offline,
}
263
/// How `WorkerPool::select_worker` picks among the eligible candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SelectionStrategy {
    /// Fewest active tasks.
    LeastLoaded,

    /// Largest combined score of free CPU and free memory.
    MostAvailable,

    /// Round-robin over the candidates (see `WorkerPool::select_round_robin`
    /// for the exact rotation behavior).
    RoundRobin,

    /// Pseudo-random pick (not cryptographically random).
    Random,
}
279
280impl WorkerPool {
281 pub fn new(config: WorkerPoolConfig) -> Self {
283 Self {
284 inner: Arc::new(WorkerPoolInner {
285 workers: DashMap::new(),
286 cpu_workers: RwLock::new(HashSet::new()),
287 gpu_workers: RwLock::new(HashSet::new()),
288 storage_workers: RwLock::new(HashSet::new()),
289 config,
290 }),
291 }
292 }
293
294 pub fn with_defaults() -> Self {
296 Self::new(WorkerPoolConfig::default())
297 }
298
299 pub fn register_worker(&self, worker: Worker) -> Result<WorkerId> {
301 let worker_id = worker.id;
302
303 if self.inner.workers.len() >= self.inner.config.max_workers {
305 return Err(ClusterError::CapacityExceeded(
306 "Worker pool at maximum capacity".to_string(),
307 ));
308 }
309
310 if worker.capabilities.cpu {
312 self.inner.cpu_workers.write().insert(worker_id);
313 }
314 if worker.capabilities.gpu {
315 self.inner.gpu_workers.write().insert(worker_id);
316 }
317 if worker.capabilities.storage {
318 self.inner.storage_workers.write().insert(worker_id);
319 }
320
321 self.inner
323 .workers
324 .insert(worker_id, Arc::new(RwLock::new(worker)));
325
326 Ok(worker_id)
327 }
328
329 pub fn unregister_worker(&self, worker_id: WorkerId) -> Result<()> {
331 self.inner.cpu_workers.write().remove(&worker_id);
333 self.inner.gpu_workers.write().remove(&worker_id);
334 self.inner.storage_workers.write().remove(&worker_id);
335
336 self.inner.workers.remove(&worker_id);
338
339 Ok(())
340 }
341
342 pub fn get_worker(&self, worker_id: WorkerId) -> Result<Arc<RwLock<Worker>>> {
344 self.inner
345 .workers
346 .get(&worker_id)
347 .map(|entry| Arc::clone(entry.value()))
348 .ok_or_else(|| ClusterError::WorkerNotFound(worker_id.to_string()))
349 }
350
351 pub fn get_all_workers(&self) -> Vec<Arc<RwLock<Worker>>> {
353 self.inner
354 .workers
355 .iter()
356 .map(|entry| Arc::clone(entry.value()))
357 .collect()
358 }
359
360 pub fn get_workers_by_status(&self, status: WorkerStatus) -> Vec<Arc<RwLock<Worker>>> {
362 self.inner
363 .workers
364 .iter()
365 .filter(|entry| entry.value().read().status == status)
366 .map(|entry| Arc::clone(entry.value()))
367 .collect()
368 }
369
370 pub fn heartbeat(&self, worker_id: WorkerId) -> Result<()> {
372 let worker = self.get_worker(worker_id)?;
373 let mut worker = worker.write();
374
375 worker.last_heartbeat = Instant::now();
376
377 if worker.status == WorkerStatus::Unhealthy {
379 worker.status = WorkerStatus::Active;
380 worker.health_check_failures = 0;
381 }
382
383 Ok(())
384 }
385
386 pub fn update_worker_usage(&self, worker_id: WorkerId, usage: WorkerUsage) -> Result<()> {
388 let worker = self.get_worker(worker_id)?;
389 let mut worker = worker.write();
390
391 let cpu_utilization = usage.cpu_cores / worker.capacity.cpu_cores;
393 let memory_utilization = usage.memory_bytes as f64 / worker.capacity.memory_bytes as f64;
394 let active_tasks = usage.active_tasks;
395
396 worker.usage = usage;
397
398 if active_tasks == 0 {
400 worker.status = WorkerStatus::Idle;
401 } else {
402 if cpu_utilization >= 0.9 || memory_utilization >= 0.9 {
403 worker.status = WorkerStatus::Busy;
404 } else {
405 worker.status = WorkerStatus::Active;
406 }
407 }
408
409 Ok(())
410 }
411
412 pub fn check_worker_health(&self, worker_id: WorkerId) -> Result<bool> {
414 let worker = self.get_worker(worker_id)?;
415 let mut worker = worker.write();
416
417 let now = Instant::now();
418 worker.last_health_check = Some(now);
419
420 let heartbeat_age = now.duration_since(worker.last_heartbeat);
422 if heartbeat_age > self.inner.config.heartbeat_timeout {
423 worker.health_check_failures += 1;
424 worker.status = WorkerStatus::Unhealthy;
425
426 if heartbeat_age > self.inner.config.max_unhealthy_duration {
428 worker.status = WorkerStatus::Offline;
429 return Ok(false);
430 }
431 } else {
432 worker.health_check_failures = 0;
433 if worker.status == WorkerStatus::Unhealthy {
434 worker.status = WorkerStatus::Active;
435 }
436 }
437
438 Ok(worker.status != WorkerStatus::Offline)
439 }
440
441 pub fn check_all_workers(&self) -> Result<Vec<WorkerId>> {
443 let mut failed_workers = Vec::new();
444
445 for entry in self.inner.workers.iter() {
446 let worker_id = *entry.key();
447 let is_healthy = self.check_worker_health(worker_id)?;
448
449 if !is_healthy {
450 failed_workers.push(worker_id);
451 }
452 }
453
454 for worker_id in &failed_workers {
456 self.unregister_worker(*worker_id)?;
457 }
458
459 Ok(failed_workers)
460 }
461
462 pub fn select_worker(
464 &self,
465 requirements: &ResourceRequirements,
466 strategy: SelectionStrategy,
467 ) -> Result<WorkerId> {
468 let candidates = self.get_candidate_workers(requirements)?;
470
471 if candidates.is_empty() {
472 return Err(ClusterError::WorkerPoolError(
473 "No available workers matching requirements".to_string(),
474 ));
475 }
476
477 let selected = match strategy {
479 SelectionStrategy::LeastLoaded => self.select_least_loaded(&candidates)?,
480 SelectionStrategy::MostAvailable => self.select_most_available(&candidates)?,
481 SelectionStrategy::RoundRobin => self.select_round_robin(&candidates)?,
482 SelectionStrategy::Random => self.select_random(&candidates)?,
483 };
484
485 Ok(selected)
486 }
487
488 fn get_candidate_workers(&self, requirements: &ResourceRequirements) -> Result<Vec<WorkerId>> {
490 let mut candidates = Vec::new();
491
492 let capability_workers = if requirements.gpu {
494 self.inner.gpu_workers.read().clone()
495 } else {
496 self.inner.cpu_workers.read().clone()
497 };
498
499 for worker_id in capability_workers {
500 if let Ok(worker) = self.get_worker(worker_id) {
501 let worker = worker.read();
502
503 if !matches!(worker.status, WorkerStatus::Active | WorkerStatus::Idle) {
505 continue;
506 }
507
508 let available_cpu = worker.capacity.cpu_cores - worker.usage.cpu_cores;
510 let available_memory = worker.capacity.memory_bytes - worker.usage.memory_bytes;
511
512 if available_cpu >= requirements.cpu_cores
513 && available_memory >= requirements.memory_bytes
514 {
515 candidates.push(worker_id);
516 }
517 }
518 }
519
520 Ok(candidates)
521 }
522
523 fn select_least_loaded(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
525 candidates
526 .iter()
527 .min_by_key(|worker_id| {
528 self.get_worker(**worker_id)
529 .map(|w| w.read().usage.active_tasks)
530 .unwrap_or(u32::MAX)
531 })
532 .copied()
533 .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
534 }
535
536 fn select_most_available(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
538 candidates
539 .iter()
540 .max_by_key(|worker_id| {
541 self.get_worker(**worker_id)
542 .map(|w| {
543 let worker = w.read();
544 let available_cpu = worker.capacity.cpu_cores - worker.usage.cpu_cores;
545 let available_memory =
546 worker.capacity.memory_bytes - worker.usage.memory_bytes;
547 (available_cpu * 1000.0) as u64 + available_memory / 1_000_000
548 })
549 .unwrap_or(0)
550 })
551 .copied()
552 .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
553 }
554
555 fn select_round_robin(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
557 candidates
560 .first()
561 .copied()
562 .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
563 }
564
565 fn select_random(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
567 use std::collections::hash_map::RandomState;
568 use std::hash::BuildHasher;
569
570 let state = RandomState::new();
571 let index = (state.hash_one(Instant::now()) as usize) % candidates.len();
572
573 candidates
574 .get(index)
575 .copied()
576 .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
577 }
578
579 pub fn get_worker_metrics(&self, worker_id: WorkerId) -> Result<WorkerMetrics> {
581 let worker = self.get_worker(worker_id)?;
582 let worker = worker.read();
583
584 let cpu_utilization = worker.usage.cpu_cores / worker.capacity.cpu_cores;
585 let memory_utilization =
586 worker.usage.memory_bytes as f64 / worker.capacity.memory_bytes as f64;
587
588 let uptime = worker.registered_at.elapsed();
589
590 Ok(WorkerMetrics {
591 worker_id: worker_id.to_string(),
592 tasks_completed: worker.tasks_completed,
593 tasks_failed: worker.tasks_failed,
594 cpu_utilization,
595 memory_utilization,
596 network_sent: worker.usage.network_sent,
597 network_received: worker.usage.network_received,
598 last_heartbeat: Utc::now(),
599 uptime,
600 })
601 }
602
603 pub fn get_statistics(&self) -> WorkerPoolStatistics {
605 let total_workers = self.inner.workers.len();
606 let mut status_counts = HashMap::new();
607
608 let mut total_capacity = WorkerCapacity::default();
609 let mut total_usage = WorkerUsage::default();
610
611 for entry in self.inner.workers.iter() {
612 let worker = entry.value().read();
613
614 *status_counts.entry(worker.status).or_insert(0) += 1;
615
616 total_capacity.cpu_cores += worker.capacity.cpu_cores;
617 total_capacity.memory_bytes += worker.capacity.memory_bytes;
618 total_capacity.storage_bytes += worker.capacity.storage_bytes;
619 total_capacity.gpu_count += worker.capacity.gpu_count;
620
621 total_usage.cpu_cores += worker.usage.cpu_cores;
622 total_usage.memory_bytes += worker.usage.memory_bytes;
623 total_usage.storage_bytes += worker.usage.storage_bytes;
624 total_usage.active_tasks += worker.usage.active_tasks;
625 }
626
627 WorkerPoolStatistics {
628 total_workers,
629 status_counts,
630 total_capacity,
631 total_usage,
632 cpu_workers: self.inner.cpu_workers.read().len(),
633 gpu_workers: self.inner.gpu_workers.read().len(),
634 storage_workers: self.inner.storage_workers.read().len(),
635 }
636 }
637
638 pub fn drain_worker(&self, worker_id: WorkerId) -> Result<()> {
640 let worker = self.get_worker(worker_id)?;
641 let mut worker = worker.write();
642
643 worker.status = WorkerStatus::Draining;
644
645 Ok(())
646 }
647
648 pub fn resume_worker(&self, worker_id: WorkerId) -> Result<()> {
650 let worker = self.get_worker(worker_id)?;
651 let mut worker = worker.write();
652
653 if worker.status == WorkerStatus::Draining {
654 worker.status = WorkerStatus::Active;
655 }
656
657 Ok(())
658 }
659
660 pub fn get_worker_count(&self) -> usize {
662 self.inner.workers.len()
663 }
664}
665
/// Snapshot of pool-wide counts and resource totals, produced by
/// `WorkerPool::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerPoolStatistics {
    /// Number of registered workers.
    pub total_workers: usize,

    /// Number of workers per status; statuses with no workers are absent.
    pub status_counts: HashMap<WorkerStatus, usize>,

    /// Aggregated capacity of the registered workers, as computed by
    /// `WorkerPool::get_statistics`.
    pub total_capacity: WorkerCapacity,

    /// Aggregated usage of the registered workers, as computed by
    /// `WorkerPool::get_statistics`.
    pub total_usage: WorkerUsage,

    /// Size of the CPU capability index.
    pub cpu_workers: usize,

    /// Size of the GPU capability index.
    pub gpu_workers: usize,

    /// Size of the storage capability index.
    pub storage_workers: usize,
}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Active`, CPU-only worker with default capacity and zero usage.
    fn create_test_worker(name: &str) -> Worker {
        Worker {
            id: WorkerId::new(),
            name: name.to_string(),
            address: "localhost:8080".to_string(),
            capabilities: WorkerCapabilities::default(),
            capacity: WorkerCapacity::default(),
            usage: WorkerUsage::default(),
            status: WorkerStatus::Active,
            last_heartbeat: Instant::now(),
            registered_at: Instant::now(),
            last_health_check: None,
            health_check_failures: 0,
            tasks_completed: 0,
            tasks_failed: 0,
            version: "1.0.0".to_string(),
            metadata: HashMap::new(),
        }
    }

    /// Reads the current status of a registered worker.
    fn status_of(pool: &WorkerPool, worker_id: WorkerId) -> WorkerStatus {
        let worker = pool
            .get_worker(worker_id)
            .expect("worker should be registered");
        let guard = worker.read();
        guard.status
    }

    #[test]
    fn test_worker_pool_creation() {
        let pool = WorkerPool::with_defaults();
        let stats = pool.get_statistics();
        assert_eq!(stats.total_workers, 0);
        assert_eq!(pool.get_worker_count(), 0);
    }

    #[test]
    fn test_register_worker() {
        let pool = WorkerPool::with_defaults();
        let worker_id = pool
            .register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");

        let stats = pool.get_statistics();
        assert_eq!(stats.total_workers, 1);
        assert_eq!(stats.cpu_workers, 1);
        assert_eq!(stats.gpu_workers, 0);
        assert!(pool.get_worker(worker_id).is_ok());
    }

    #[test]
    fn test_register_respects_max_workers() {
        let pool = WorkerPool::new(WorkerPoolConfig {
            max_workers: 1,
            ..WorkerPoolConfig::default()
        });
        assert!(pool.register_worker(create_test_worker("worker1")).is_ok());
        assert!(pool.register_worker(create_test_worker("worker2")).is_err());
        assert_eq!(pool.get_worker_count(), 1);
    }

    #[test]
    fn test_unregister_worker() {
        let pool = WorkerPool::with_defaults();
        let worker_id = pool
            .register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");

        assert!(pool.unregister_worker(worker_id).is_ok());
        assert_eq!(pool.get_worker_count(), 0);
        assert!(pool.get_worker(worker_id).is_err());
        assert_eq!(pool.get_statistics().cpu_workers, 0);
    }

    #[test]
    fn test_heartbeat() {
        let pool = WorkerPool::with_defaults();
        let worker_id = pool
            .register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");

        assert!(pool.heartbeat(worker_id).is_ok());
    }

    #[test]
    fn test_heartbeat_unknown_worker() {
        let pool = WorkerPool::with_defaults();
        assert!(pool.heartbeat(WorkerId::new()).is_err());
    }

    #[test]
    fn test_drain_and_resume() {
        let pool = WorkerPool::with_defaults();
        let worker_id = pool
            .register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");

        pool.drain_worker(worker_id).expect("drain should succeed");
        assert_eq!(status_of(&pool, worker_id), WorkerStatus::Draining);

        pool.resume_worker(worker_id).expect("resume should succeed");
        assert_eq!(status_of(&pool, worker_id), WorkerStatus::Active);
    }

    #[test]
    fn test_worker_selection() {
        let pool = WorkerPool::with_defaults();

        let mut worker = create_test_worker("worker1");
        worker.capacity.cpu_cores = 8.0;
        worker.capacity.memory_bytes = 16_000_000_000;
        let worker_id = pool
            .register_worker(worker)
            .expect("registration should succeed");

        let requirements = ResourceRequirements {
            cpu_cores: 2.0,
            memory_bytes: 4_000_000_000,
            gpu: false,
            storage_bytes: 0,
        };

        for strategy in [
            SelectionStrategy::LeastLoaded,
            SelectionStrategy::MostAvailable,
            SelectionStrategy::RoundRobin,
            SelectionStrategy::Random,
        ] {
            let selected = pool
                .select_worker(&requirements, strategy)
                .expect("a matching worker exists");
            assert_eq!(selected, worker_id);
        }
    }

    #[test]
    fn test_drained_worker_is_not_selected() {
        let pool = WorkerPool::with_defaults();
        let worker_id = pool
            .register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");
        pool.drain_worker(worker_id).expect("drain should succeed");

        let requirements = ResourceRequirements {
            cpu_cores: 0.5,
            memory_bytes: 0,
            gpu: false,
            storage_bytes: 0,
        };
        assert!(pool
            .select_worker(&requirements, SelectionStrategy::LeastLoaded)
            .is_err());
    }

    #[test]
    fn test_selection_without_matching_workers_fails() {
        let pool = WorkerPool::with_defaults();
        pool.register_worker(create_test_worker("worker1"))
            .expect("registration should succeed");

        let gpu_requirements = ResourceRequirements {
            cpu_cores: 0.5,
            memory_bytes: 0,
            gpu: true,
            storage_bytes: 0,
        };
        assert!(pool
            .select_worker(&gpu_requirements, SelectionStrategy::LeastLoaded)
            .is_err());
    }
}