// oxigdal_cluster/worker_pool.rs

//! Worker pool management for the cluster.
//!
//! This module manages worker nodes including registration, heartbeat monitoring,
//! capacity tracking, health checks, automatic failover, and worker pools by capability.

use crate::error::{ClusterError, Result};
use crate::metrics::WorkerMetrics;
use crate::task_graph::ResourceRequirements;

use chrono::Utc;
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

18/// Worker pool manager.
19#[derive(Clone)]
20pub struct WorkerPool {
21    inner: Arc<WorkerPoolInner>,
22}
23
24struct WorkerPoolInner {
25    /// All registered workers
26    workers: DashMap<WorkerId, Arc<RwLock<Worker>>>,
27
28    /// Worker capabilities index
29    cpu_workers: RwLock<HashSet<WorkerId>>,
30    gpu_workers: RwLock<HashSet<WorkerId>>,
31    storage_workers: RwLock<HashSet<WorkerId>>,
32
33    /// Configuration
34    config: WorkerPoolConfig,
35}
36
37/// Worker pool configuration.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct WorkerPoolConfig {
40    /// Heartbeat timeout duration
41    pub heartbeat_timeout: Duration,
42
43    /// Health check interval
44    pub health_check_interval: Duration,
45
46    /// Maximum unhealthy duration before removal
47    pub max_unhealthy_duration: Duration,
48
49    /// Minimum workers required
50    pub min_workers: usize,
51
52    /// Maximum workers allowed
53    pub max_workers: usize,
54}
55
56impl Default for WorkerPoolConfig {
57    fn default() -> Self {
58        Self {
59            heartbeat_timeout: Duration::from_secs(30),
60            health_check_interval: Duration::from_secs(10),
61            max_unhealthy_duration: Duration::from_secs(120),
62            min_workers: 1,
63            max_workers: 1000,
64        }
65    }
66}
67
68/// Worker identifier.
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
70pub struct WorkerId(pub Uuid);
71
72impl WorkerId {
73    /// Create a new random worker ID.
74    pub fn new() -> Self {
75        Self(Uuid::new_v4())
76    }
77
78    /// Create from UUID.
79    pub fn from_uuid(uuid: Uuid) -> Self {
80        Self(uuid)
81    }
82}
83
84impl Default for WorkerId {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl std::fmt::Display for WorkerId {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(f, "{}", self.0)
93    }
94}
95
96/// Worker node information.
97#[derive(Debug, Clone)]
98pub struct Worker {
99    /// Worker ID
100    pub id: WorkerId,
101
102    /// Worker name/hostname
103    pub name: String,
104
105    /// Network address
106    pub address: String,
107
108    /// Worker capabilities
109    pub capabilities: WorkerCapabilities,
110
111    /// Worker capacity
112    pub capacity: WorkerCapacity,
113
114    /// Current resource usage
115    pub usage: WorkerUsage,
116
117    /// Worker status
118    pub status: WorkerStatus,
119
120    /// Last heartbeat time
121    pub last_heartbeat: Instant,
122
123    /// Registration time
124    pub registered_at: Instant,
125
126    /// Last health check
127    pub last_health_check: Option<Instant>,
128
129    /// Health check failures
130    pub health_check_failures: u32,
131
132    /// Total tasks completed
133    pub tasks_completed: u64,
134
135    /// Total tasks failed
136    pub tasks_failed: u64,
137
138    /// Worker version
139    pub version: String,
140
141    /// Custom metadata
142    pub metadata: HashMap<String, String>,
143}
144
145/// Worker capabilities.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct WorkerCapabilities {
148    /// Has CPU processing
149    pub cpu: bool,
150
151    /// Has GPU processing
152    pub gpu: bool,
153
154    /// Has large storage
155    pub storage: bool,
156
157    /// Supported task types
158    pub task_types: Vec<String>,
159
160    /// Supported data formats
161    pub data_formats: Vec<String>,
162}
163
164impl Default for WorkerCapabilities {
165    fn default() -> Self {
166        Self {
167            cpu: true,
168            gpu: false,
169            storage: false,
170            task_types: vec![],
171            data_formats: vec![],
172        }
173    }
174}
175
176/// Worker capacity (total resources).
177#[derive(Debug, Clone, Serialize, Deserialize)]
178pub struct WorkerCapacity {
179    /// Total CPU cores
180    pub cpu_cores: f64,
181
182    /// Total memory (bytes)
183    pub memory_bytes: u64,
184
185    /// Total storage (bytes)
186    pub storage_bytes: u64,
187
188    /// Number of GPUs
189    pub gpu_count: u32,
190
191    /// Network bandwidth (bytes/sec)
192    pub network_bandwidth: u64,
193}
194
195impl Default for WorkerCapacity {
196    fn default() -> Self {
197        Self {
198            cpu_cores: 1.0,
199            memory_bytes: 1024 * 1024 * 1024, // 1 GB
200            storage_bytes: 0,
201            gpu_count: 0,
202            network_bandwidth: 0,
203        }
204    }
205}
206
207/// Worker resource usage (current).
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct WorkerUsage {
210    /// Used CPU cores
211    pub cpu_cores: f64,
212
213    /// Used memory (bytes)
214    pub memory_bytes: u64,
215
216    /// Used storage (bytes)
217    pub storage_bytes: u64,
218
219    /// Active tasks
220    pub active_tasks: u32,
221
222    /// Network sent (bytes)
223    pub network_sent: u64,
224
225    /// Network received (bytes)
226    pub network_received: u64,
227}
228
229impl Default for WorkerUsage {
230    fn default() -> Self {
231        Self {
232            cpu_cores: 0.0,
233            memory_bytes: 0,
234            storage_bytes: 0,
235            active_tasks: 0,
236            network_sent: 0,
237            network_received: 0,
238        }
239    }
240}
241
242/// Worker status.
243#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
244pub enum WorkerStatus {
245    /// Worker is active and healthy
246    Active,
247
248    /// Worker is idle (no tasks)
249    Idle,
250
251    /// Worker is busy (at capacity)
252    Busy,
253
254    /// Worker is unhealthy
255    Unhealthy,
256
257    /// Worker is draining (no new tasks)
258    Draining,
259
260    /// Worker is offline
261    Offline,
262}
263
/// Strategy used to pick a worker from the set of candidates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SelectionStrategy {
    /// Pick the worker running the fewest tasks.
    LeastLoaded,

    /// Pick the worker with the most free resources.
    MostAvailable,

    /// Rotate through candidates in order.
    RoundRobin,

    /// Pick a pseudo-random candidate.
    Random,
}

280impl WorkerPool {
281    /// Create a new worker pool.
282    pub fn new(config: WorkerPoolConfig) -> Self {
283        Self {
284            inner: Arc::new(WorkerPoolInner {
285                workers: DashMap::new(),
286                cpu_workers: RwLock::new(HashSet::new()),
287                gpu_workers: RwLock::new(HashSet::new()),
288                storage_workers: RwLock::new(HashSet::new()),
289                config,
290            }),
291        }
292    }
293
294    /// Create with default configuration.
295    pub fn with_defaults() -> Self {
296        Self::new(WorkerPoolConfig::default())
297    }
298
299    /// Register a new worker.
300    pub fn register_worker(&self, worker: Worker) -> Result<WorkerId> {
301        let worker_id = worker.id;
302
303        // Check if we're at capacity
304        if self.inner.workers.len() >= self.inner.config.max_workers {
305            return Err(ClusterError::CapacityExceeded(
306                "Worker pool at maximum capacity".to_string(),
307            ));
308        }
309
310        // Update capability indices
311        if worker.capabilities.cpu {
312            self.inner.cpu_workers.write().insert(worker_id);
313        }
314        if worker.capabilities.gpu {
315            self.inner.gpu_workers.write().insert(worker_id);
316        }
317        if worker.capabilities.storage {
318            self.inner.storage_workers.write().insert(worker_id);
319        }
320
321        // Store worker
322        self.inner
323            .workers
324            .insert(worker_id, Arc::new(RwLock::new(worker)));
325
326        Ok(worker_id)
327    }
328
329    /// Unregister a worker.
330    pub fn unregister_worker(&self, worker_id: WorkerId) -> Result<()> {
331        // Remove from capability indices
332        self.inner.cpu_workers.write().remove(&worker_id);
333        self.inner.gpu_workers.write().remove(&worker_id);
334        self.inner.storage_workers.write().remove(&worker_id);
335
336        // Remove worker
337        self.inner.workers.remove(&worker_id);
338
339        Ok(())
340    }
341
342    /// Get a worker by ID.
343    pub fn get_worker(&self, worker_id: WorkerId) -> Result<Arc<RwLock<Worker>>> {
344        self.inner
345            .workers
346            .get(&worker_id)
347            .map(|entry| Arc::clone(entry.value()))
348            .ok_or_else(|| ClusterError::WorkerNotFound(worker_id.to_string()))
349    }
350
351    /// Get all workers.
352    pub fn get_all_workers(&self) -> Vec<Arc<RwLock<Worker>>> {
353        self.inner
354            .workers
355            .iter()
356            .map(|entry| Arc::clone(entry.value()))
357            .collect()
358    }
359
360    /// Get workers by status.
361    pub fn get_workers_by_status(&self, status: WorkerStatus) -> Vec<Arc<RwLock<Worker>>> {
362        self.inner
363            .workers
364            .iter()
365            .filter(|entry| entry.value().read().status == status)
366            .map(|entry| Arc::clone(entry.value()))
367            .collect()
368    }
369
370    /// Update worker heartbeat.
371    pub fn heartbeat(&self, worker_id: WorkerId) -> Result<()> {
372        let worker = self.get_worker(worker_id)?;
373        let mut worker = worker.write();
374
375        worker.last_heartbeat = Instant::now();
376
377        // If worker was unhealthy, mark as active
378        if worker.status == WorkerStatus::Unhealthy {
379            worker.status = WorkerStatus::Active;
380            worker.health_check_failures = 0;
381        }
382
383        Ok(())
384    }
385
386    /// Update worker resource usage.
387    pub fn update_worker_usage(&self, worker_id: WorkerId, usage: WorkerUsage) -> Result<()> {
388        let worker = self.get_worker(worker_id)?;
389        let mut worker = worker.write();
390
391        // Calculate utilizationsbefore moving usage
392        let cpu_utilization = usage.cpu_cores / worker.capacity.cpu_cores;
393        let memory_utilization = usage.memory_bytes as f64 / worker.capacity.memory_bytes as f64;
394        let active_tasks = usage.active_tasks;
395
396        worker.usage = usage;
397
398        // Update status based on usage
399        if active_tasks == 0 {
400            worker.status = WorkerStatus::Idle;
401        } else {
402            if cpu_utilization >= 0.9 || memory_utilization >= 0.9 {
403                worker.status = WorkerStatus::Busy;
404            } else {
405                worker.status = WorkerStatus::Active;
406            }
407        }
408
409        Ok(())
410    }
411
412    /// Check worker health.
413    pub fn check_worker_health(&self, worker_id: WorkerId) -> Result<bool> {
414        let worker = self.get_worker(worker_id)?;
415        let mut worker = worker.write();
416
417        let now = Instant::now();
418        worker.last_health_check = Some(now);
419
420        // Check heartbeat timeout
421        let heartbeat_age = now.duration_since(worker.last_heartbeat);
422        if heartbeat_age > self.inner.config.heartbeat_timeout {
423            worker.health_check_failures += 1;
424            worker.status = WorkerStatus::Unhealthy;
425
426            // Check if worker should be removed
427            if heartbeat_age > self.inner.config.max_unhealthy_duration {
428                worker.status = WorkerStatus::Offline;
429                return Ok(false);
430            }
431        } else {
432            worker.health_check_failures = 0;
433            if worker.status == WorkerStatus::Unhealthy {
434                worker.status = WorkerStatus::Active;
435            }
436        }
437
438        Ok(worker.status != WorkerStatus::Offline)
439    }
440
441    /// Run health checks on all workers.
442    pub fn check_all_workers(&self) -> Result<Vec<WorkerId>> {
443        let mut failed_workers = Vec::new();
444
445        for entry in self.inner.workers.iter() {
446            let worker_id = *entry.key();
447            let is_healthy = self.check_worker_health(worker_id)?;
448
449            if !is_healthy {
450                failed_workers.push(worker_id);
451            }
452        }
453
454        // Remove offline workers
455        for worker_id in &failed_workers {
456            self.unregister_worker(*worker_id)?;
457        }
458
459        Ok(failed_workers)
460    }
461
462    /// Select a worker for task execution.
463    pub fn select_worker(
464        &self,
465        requirements: &ResourceRequirements,
466        strategy: SelectionStrategy,
467    ) -> Result<WorkerId> {
468        // Get candidate workers based on requirements
469        let candidates = self.get_candidate_workers(requirements)?;
470
471        if candidates.is_empty() {
472            return Err(ClusterError::WorkerPoolError(
473                "No available workers matching requirements".to_string(),
474            ));
475        }
476
477        // Select worker based on strategy
478        let selected = match strategy {
479            SelectionStrategy::LeastLoaded => self.select_least_loaded(&candidates)?,
480            SelectionStrategy::MostAvailable => self.select_most_available(&candidates)?,
481            SelectionStrategy::RoundRobin => self.select_round_robin(&candidates)?,
482            SelectionStrategy::Random => self.select_random(&candidates)?,
483        };
484
485        Ok(selected)
486    }
487
488    /// Get candidate workers matching requirements.
489    fn get_candidate_workers(&self, requirements: &ResourceRequirements) -> Result<Vec<WorkerId>> {
490        let mut candidates = Vec::new();
491
492        // Filter by capability
493        let capability_workers = if requirements.gpu {
494            self.inner.gpu_workers.read().clone()
495        } else {
496            self.inner.cpu_workers.read().clone()
497        };
498
499        for worker_id in capability_workers {
500            if let Ok(worker) = self.get_worker(worker_id) {
501                let worker = worker.read();
502
503                // Check status
504                if !matches!(worker.status, WorkerStatus::Active | WorkerStatus::Idle) {
505                    continue;
506                }
507
508                // Check resource availability
509                let available_cpu = worker.capacity.cpu_cores - worker.usage.cpu_cores;
510                let available_memory = worker.capacity.memory_bytes - worker.usage.memory_bytes;
511
512                if available_cpu >= requirements.cpu_cores
513                    && available_memory >= requirements.memory_bytes
514                {
515                    candidates.push(worker_id);
516                }
517            }
518        }
519
520        Ok(candidates)
521    }
522
523    /// Select least loaded worker.
524    fn select_least_loaded(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
525        candidates
526            .iter()
527            .min_by_key(|worker_id| {
528                self.get_worker(**worker_id)
529                    .map(|w| w.read().usage.active_tasks)
530                    .unwrap_or(u32::MAX)
531            })
532            .copied()
533            .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
534    }
535
536    /// Select worker with most available resources.
537    fn select_most_available(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
538        candidates
539            .iter()
540            .max_by_key(|worker_id| {
541                self.get_worker(**worker_id)
542                    .map(|w| {
543                        let worker = w.read();
544                        let available_cpu = worker.capacity.cpu_cores - worker.usage.cpu_cores;
545                        let available_memory =
546                            worker.capacity.memory_bytes - worker.usage.memory_bytes;
547                        (available_cpu * 1000.0) as u64 + available_memory / 1_000_000
548                    })
549                    .unwrap_or(0)
550            })
551            .copied()
552            .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
553    }
554
555    /// Select worker using round-robin.
556    fn select_round_robin(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
557        // Simple round-robin: select first candidate
558        // In production, maintain a counter for true round-robin
559        candidates
560            .first()
561            .copied()
562            .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
563    }
564
565    /// Select random worker.
566    fn select_random(&self, candidates: &[WorkerId]) -> Result<WorkerId> {
567        use std::collections::hash_map::RandomState;
568        use std::hash::BuildHasher;
569
570        let state = RandomState::new();
571        let index = (state.hash_one(Instant::now()) as usize) % candidates.len();
572
573        candidates
574            .get(index)
575            .copied()
576            .ok_or_else(|| ClusterError::WorkerPoolError("No workers available".to_string()))
577    }
578
579    /// Get worker metrics.
580    pub fn get_worker_metrics(&self, worker_id: WorkerId) -> Result<WorkerMetrics> {
581        let worker = self.get_worker(worker_id)?;
582        let worker = worker.read();
583
584        let cpu_utilization = worker.usage.cpu_cores / worker.capacity.cpu_cores;
585        let memory_utilization =
586            worker.usage.memory_bytes as f64 / worker.capacity.memory_bytes as f64;
587
588        let uptime = worker.registered_at.elapsed();
589
590        Ok(WorkerMetrics {
591            worker_id: worker_id.to_string(),
592            tasks_completed: worker.tasks_completed,
593            tasks_failed: worker.tasks_failed,
594            cpu_utilization,
595            memory_utilization,
596            network_sent: worker.usage.network_sent,
597            network_received: worker.usage.network_received,
598            last_heartbeat: Utc::now(),
599            uptime,
600        })
601    }
602
603    /// Get pool statistics.
604    pub fn get_statistics(&self) -> WorkerPoolStatistics {
605        let total_workers = self.inner.workers.len();
606        let mut status_counts = HashMap::new();
607
608        let mut total_capacity = WorkerCapacity::default();
609        let mut total_usage = WorkerUsage::default();
610
611        for entry in self.inner.workers.iter() {
612            let worker = entry.value().read();
613
614            *status_counts.entry(worker.status).or_insert(0) += 1;
615
616            total_capacity.cpu_cores += worker.capacity.cpu_cores;
617            total_capacity.memory_bytes += worker.capacity.memory_bytes;
618            total_capacity.storage_bytes += worker.capacity.storage_bytes;
619            total_capacity.gpu_count += worker.capacity.gpu_count;
620
621            total_usage.cpu_cores += worker.usage.cpu_cores;
622            total_usage.memory_bytes += worker.usage.memory_bytes;
623            total_usage.storage_bytes += worker.usage.storage_bytes;
624            total_usage.active_tasks += worker.usage.active_tasks;
625        }
626
627        WorkerPoolStatistics {
628            total_workers,
629            status_counts,
630            total_capacity,
631            total_usage,
632            cpu_workers: self.inner.cpu_workers.read().len(),
633            gpu_workers: self.inner.gpu_workers.read().len(),
634            storage_workers: self.inner.storage_workers.read().len(),
635        }
636    }
637
638    /// Drain a worker (no new tasks).
639    pub fn drain_worker(&self, worker_id: WorkerId) -> Result<()> {
640        let worker = self.get_worker(worker_id)?;
641        let mut worker = worker.write();
642
643        worker.status = WorkerStatus::Draining;
644
645        Ok(())
646    }
647
648    /// Resume a drained worker.
649    pub fn resume_worker(&self, worker_id: WorkerId) -> Result<()> {
650        let worker = self.get_worker(worker_id)?;
651        let mut worker = worker.write();
652
653        if worker.status == WorkerStatus::Draining {
654            worker.status = WorkerStatus::Active;
655        }
656
657        Ok(())
658    }
659
660    /// Get the current number of workers in the pool.
661    pub fn get_worker_count(&self) -> usize {
662        self.inner.workers.len()
663    }
664}
665
666/// Worker pool statistics.
667#[derive(Debug, Clone, Serialize, Deserialize)]
668pub struct WorkerPoolStatistics {
669    /// Total number of workers
670    pub total_workers: usize,
671
672    /// Worker counts by status
673    pub status_counts: HashMap<WorkerStatus, usize>,
674
675    /// Total capacity across all workers
676    pub total_capacity: WorkerCapacity,
677
678    /// Total usage across all workers
679    pub total_usage: WorkerUsage,
680
681    /// Number of CPU workers
682    pub cpu_workers: usize,
683
684    /// Number of GPU workers
685    pub gpu_workers: usize,
686
687    /// Number of storage workers
688    pub storage_workers: usize,
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimally-populated worker for tests.
    fn create_test_worker(name: &str) -> Worker {
        Worker {
            id: WorkerId::new(),
            name: name.to_string(),
            address: "localhost:8080".to_string(),
            capabilities: WorkerCapabilities::default(),
            capacity: WorkerCapacity::default(),
            usage: WorkerUsage::default(),
            status: WorkerStatus::Active,
            last_heartbeat: Instant::now(),
            registered_at: Instant::now(),
            last_health_check: None,
            health_check_failures: 0,
            tasks_completed: 0,
            tasks_failed: 0,
            version: "1.0.0".to_string(),
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_worker_pool_creation() {
        let pool = WorkerPool::with_defaults();
        let stats = pool.get_statistics();
        assert_eq!(stats.total_workers, 0);
    }

    #[test]
    fn test_register_worker() {
        let pool = WorkerPool::with_defaults();
        let worker = create_test_worker("worker1");

        pool.register_worker(worker).expect("registration failed");

        let stats = pool.get_statistics();
        assert_eq!(stats.total_workers, 1);
    }

    #[test]
    fn test_heartbeat() {
        let pool = WorkerPool::with_defaults();
        let worker = create_test_worker("worker1");
        // Fail loudly if registration fails instead of heartbeating a random
        // default ID (the old `.ok().unwrap_or_default()` masked failures).
        let worker_id = pool.register_worker(worker).expect("registration failed");

        assert!(pool.heartbeat(worker_id).is_ok());
    }

    #[test]
    fn test_worker_selection() {
        let pool = WorkerPool::with_defaults();

        let mut worker = create_test_worker("worker1");
        worker.capacity.cpu_cores = 8.0;
        worker.capacity.memory_bytes = 16_000_000_000;

        pool.register_worker(worker).expect("registration failed");

        let requirements = ResourceRequirements {
            cpu_cores: 2.0,
            memory_bytes: 4_000_000_000,
            gpu: false,
            storage_bytes: 0,
        };

        let result = pool.select_worker(&requirements, SelectionStrategy::LeastLoaded);
        assert!(result.is_ok());
    }
}