// oxigdal_distributed/coordinator.rs

1//! Coordinator for managing distributed task execution.
2//!
3//! This module implements the coordinator that schedules tasks across worker nodes,
4//! monitors progress, and aggregates results.
5
6use crate::error::{DistributedError, Result};
7use crate::task::{PartitionId, Task, TaskId, TaskOperation, TaskResult, TaskScheduler};
8use crate::worker::WorkerStatus;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc;
13use tracing::{debug, error, info, warn};
14
/// Coordinator configuration.
///
/// Construct with [`CoordinatorConfig::new`], which fills in defaults
/// (3 retries, 300 s task timeout, 60 s worker timeout, buffer of 1000),
/// then tune individual values with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// Listen address for Flight server.
    pub listen_addr: String,
    /// Maximum task retry attempts (copied onto each submitted task).
    pub max_retries: u32,
    /// Task timeout in seconds.
    pub task_timeout_secs: u64,
    /// Worker heartbeat timeout in seconds; workers silent longer than this
    /// are treated as dead by `Coordinator::check_worker_timeouts`.
    pub worker_timeout_secs: u64,
    /// Result buffer size.
    pub result_buffer_size: usize,
}
29
30impl CoordinatorConfig {
31    /// Create a new coordinator configuration.
32    pub fn new(listen_addr: String) -> Self {
33        Self {
34            listen_addr,
35            max_retries: 3,
36            task_timeout_secs: 300, // 5 minutes
37            worker_timeout_secs: 60,
38            result_buffer_size: 1000,
39        }
40    }
41
42    /// Set the maximum retry attempts.
43    pub fn with_max_retries(mut self, retries: u32) -> Self {
44        self.max_retries = retries;
45        self
46    }
47
48    /// Set the task timeout.
49    pub fn with_task_timeout(mut self, timeout_secs: u64) -> Self {
50        self.task_timeout_secs = timeout_secs;
51        self
52    }
53}
54
/// Information about a connected worker.
///
/// Maintained by the coordinator: `last_heartbeat` is refreshed on each
/// heartbeat, and the task counters are updated as tasks are assigned
/// and finished.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    /// Worker identifier.
    pub worker_id: String,
    /// Worker address.
    pub address: String,
    /// Current status.
    pub status: WorkerStatus,
    /// Last heartbeat timestamp (used for timeout detection).
    pub last_heartbeat: Instant,
    /// Number of active tasks.
    pub active_tasks: usize,
    /// Total tasks completed.
    pub completed_tasks: u64,
    /// Total tasks failed.
    pub failed_tasks: u64,
}
73
74impl WorkerInfo {
75    /// Create new worker info.
76    pub fn new(worker_id: String, address: String) -> Self {
77        Self {
78            worker_id,
79            address,
80            status: WorkerStatus::Idle,
81            last_heartbeat: Instant::now(),
82            active_tasks: 0,
83            completed_tasks: 0,
84            failed_tasks: 0,
85        }
86    }
87
88    /// Update heartbeat timestamp.
89    pub fn update_heartbeat(&mut self) {
90        self.last_heartbeat = Instant::now();
91    }
92
93    /// Check if the worker has timed out.
94    pub fn is_timed_out(&self, timeout: Duration) -> bool {
95        self.last_heartbeat.elapsed() > timeout
96    }
97
98    /// Get the success rate.
99    pub fn success_rate(&self) -> f64 {
100        let total = self.completed_tasks + self.failed_tasks;
101        if total == 0 {
102            1.0
103        } else {
104            self.completed_tasks as f64 / total as f64
105        }
106    }
107}
108
/// Coordinator for distributed task execution.
///
/// All mutable state lives behind `Arc<RwLock<..>>` so the coordinator can
/// be shared (as `Arc<Self>`) between the monitoring loop and other tasks
/// (see `start_monitoring`).
pub struct Coordinator {
    /// Coordinator configuration.
    config: CoordinatorConfig,
    /// Task scheduler (pending/running/completed/failed bookkeeping).
    scheduler: Arc<RwLock<TaskScheduler>>,
    /// Connected workers, keyed by worker ID.
    workers: Arc<RwLock<HashMap<String, WorkerInfo>>>,
    /// Task assignments (task_id -> worker_id) for tasks currently running.
    assignments: Arc<RwLock<HashMap<TaskId, String>>>,
    /// Task results, keyed by task ID.
    results: Arc<RwLock<HashMap<TaskId, TaskResult>>>,
    /// Task counter for generating unique, monotonically increasing IDs.
    next_task_id: Arc<RwLock<u64>>,
}
124
125impl Coordinator {
126    /// Create a new coordinator.
127    pub fn new(config: CoordinatorConfig) -> Self {
128        Self {
129            config,
130            scheduler: Arc::new(RwLock::new(TaskScheduler::new())),
131            workers: Arc::new(RwLock::new(HashMap::new())),
132            assignments: Arc::new(RwLock::new(HashMap::new())),
133            results: Arc::new(RwLock::new(HashMap::new())),
134            next_task_id: Arc::new(RwLock::new(0)),
135        }
136    }
137
138    /// Add a worker to the coordinator.
139    pub fn add_worker(&self, worker_id: String, address: String) -> Result<()> {
140        info!("Adding worker: {} at {}", worker_id, address);
141
142        let worker_info = WorkerInfo::new(worker_id.clone(), address);
143
144        let mut workers = self
145            .workers
146            .write()
147            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
148
149        if workers.contains_key(&worker_id) {
150            return Err(DistributedError::coordinator(format!(
151                "Worker {} already exists",
152                worker_id
153            )));
154        }
155
156        workers.insert(worker_id, worker_info);
157        Ok(())
158    }
159
160    /// Remove a worker from the coordinator.
161    pub fn remove_worker(&self, worker_id: &str) -> Result<()> {
162        info!("Removing worker: {}", worker_id);
163
164        let mut workers = self
165            .workers
166            .write()
167            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
168
169        workers.remove(worker_id);
170
171        // Reassign tasks from this worker
172        self.reassign_worker_tasks(worker_id)?;
173
174        Ok(())
175    }
176
177    /// Update worker heartbeat.
178    pub fn update_worker_heartbeat(&self, worker_id: &str) -> Result<()> {
179        let mut workers = self
180            .workers
181            .write()
182            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
183
184        if let Some(worker) = workers.get_mut(worker_id) {
185            worker.update_heartbeat();
186            debug!("Updated heartbeat for worker {}", worker_id);
187            Ok(())
188        } else {
189            Err(DistributedError::coordinator(format!(
190                "Worker {} not found",
191                worker_id
192            )))
193        }
194    }
195
196    /// Check for timed-out workers and reassign their tasks.
197    pub fn check_worker_timeouts(&self) -> Result<Vec<String>> {
198        let timeout = Duration::from_secs(self.config.worker_timeout_secs);
199        let mut timed_out = Vec::new();
200
201        let workers = self
202            .workers
203            .read()
204            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
205
206        for (worker_id, worker) in workers.iter() {
207            if worker.is_timed_out(timeout) {
208                warn!("Worker {} has timed out", worker_id);
209                timed_out.push(worker_id.clone());
210            }
211        }
212
213        drop(workers);
214
215        // Reassign tasks from timed-out workers
216        for worker_id in &timed_out {
217            self.reassign_worker_tasks(worker_id)?;
218            self.remove_worker(worker_id)?;
219        }
220
221        Ok(timed_out)
222    }
223
224    /// Submit a task for execution.
225    pub fn submit_task(
226        &self,
227        partition_id: PartitionId,
228        operation: TaskOperation,
229    ) -> Result<TaskId> {
230        let task_id = self.generate_task_id()?;
231        let mut task = Task::new(task_id, partition_id, operation);
232        task.max_retries = self.config.max_retries;
233
234        let mut scheduler = self
235            .scheduler
236            .write()
237            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
238
239        scheduler.add_task(task);
240        debug!("Submitted task {}", task_id);
241
242        Ok(task_id)
243    }
244
245    /// Get the next task to execute.
246    pub fn next_task(&self) -> Result<Option<Task>> {
247        let mut scheduler = self
248            .scheduler
249            .write()
250            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
251
252        Ok(scheduler.next_task())
253    }
254
255    /// Assign a task to a worker.
256    pub fn assign_task(&self, task: Task, worker_id: String) -> Result<()> {
257        // Mark task as running
258        let mut scheduler = self
259            .scheduler
260            .write()
261            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
262        scheduler.mark_running(task.clone(), worker_id.clone());
263        drop(scheduler);
264
265        // Record assignment
266        let mut assignments = self
267            .assignments
268            .write()
269            .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
270        assignments.insert(task.id, worker_id.clone());
271
272        // Update worker info
273        let mut workers = self
274            .workers
275            .write()
276            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
277        if let Some(worker) = workers.get_mut(&worker_id) {
278            worker.active_tasks += 1;
279            worker.status = WorkerStatus::Busy;
280        }
281
282        info!("Assigned task {} to worker {}", task.id, worker_id);
283        Ok(())
284    }
285
286    /// Record task completion.
287    pub fn complete_task(&self, task_id: TaskId, result: TaskResult) -> Result<()> {
288        let worker_id = {
289            let assignments = self
290                .assignments
291                .read()
292                .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
293            assignments.get(&task_id).cloned()
294        };
295
296        // Update scheduler
297        let mut scheduler = self
298            .scheduler
299            .write()
300            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
301
302        if result.is_success() {
303            scheduler.mark_completed(task_id)?;
304        } else {
305            scheduler.mark_failed(task_id)?;
306        }
307        drop(scheduler);
308
309        // Update worker info
310        if let Some(worker_id) = worker_id {
311            let mut workers = self
312                .workers
313                .write()
314                .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
315
316            if let Some(worker) = workers.get_mut(&worker_id) {
317                if worker.active_tasks > 0 {
318                    worker.active_tasks -= 1;
319                }
320                if result.is_success() {
321                    worker.completed_tasks += 1;
322                } else {
323                    worker.failed_tasks += 1;
324                }
325                if worker.active_tasks == 0 {
326                    worker.status = WorkerStatus::Idle;
327                }
328            }
329        }
330
331        // Store result
332        let mut results = self
333            .results
334            .write()
335            .map_err(|_| DistributedError::coordinator("Failed to acquire results lock"))?;
336        results.insert(task_id, result);
337
338        info!("Task {} completed", task_id);
339        Ok(())
340    }
341
342    /// Get the best available worker for a task.
343    pub fn get_available_worker(&self) -> Result<Option<String>> {
344        let workers = self
345            .workers
346            .read()
347            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
348
349        // Find idle worker with best success rate
350        let best_worker = workers
351            .values()
352            .filter(|w| w.status == WorkerStatus::Idle)
353            .max_by(|a, b| {
354                a.success_rate()
355                    .partial_cmp(&b.success_rate())
356                    .unwrap_or(std::cmp::Ordering::Equal)
357            })
358            .map(|w| w.worker_id.clone());
359
360        Ok(best_worker)
361    }
362
363    /// Get execution progress.
364    pub fn get_progress(&self) -> Result<CoordinatorProgress> {
365        let scheduler = self
366            .scheduler
367            .read()
368            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
369
370        let workers = self
371            .workers
372            .read()
373            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
374
375        Ok(CoordinatorProgress {
376            pending_tasks: scheduler.pending_count(),
377            running_tasks: scheduler.running_count(),
378            completed_tasks: scheduler.completed_count(),
379            failed_tasks: scheduler.failed_count(),
380            active_workers: workers.len(),
381            idle_workers: workers
382                .values()
383                .filter(|w| w.status == WorkerStatus::Idle)
384                .count(),
385        })
386    }
387
388    /// Collect all task results.
389    pub fn collect_results(&self) -> Result<Vec<TaskResult>> {
390        let results = self
391            .results
392            .read()
393            .map_err(|_| DistributedError::coordinator("Failed to acquire results lock"))?;
394
395        Ok(results.values().cloned().collect())
396    }
397
398    /// Check if all tasks are complete.
399    pub fn is_complete(&self) -> bool {
400        self.scheduler
401            .read()
402            .map(|s| s.is_complete())
403            .unwrap_or(false)
404    }
405
406    /// Generate a unique task ID.
407    fn generate_task_id(&self) -> Result<TaskId> {
408        let mut next_id = self
409            .next_task_id
410            .write()
411            .map_err(|_| DistributedError::coordinator("Failed to acquire task ID lock"))?;
412        let id = *next_id;
413        *next_id += 1;
414        Ok(TaskId(id))
415    }
416
417    /// Reassign tasks from a specific worker.
418    fn reassign_worker_tasks(&self, worker_id: &str) -> Result<()> {
419        let mut scheduler = self
420            .scheduler
421            .write()
422            .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
423
424        let mut assignments = self
425            .assignments
426            .write()
427            .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
428
429        // Find tasks assigned to this worker
430        let task_ids: Vec<TaskId> = assignments
431            .iter()
432            .filter(|(_, wid)| *wid == worker_id)
433            .map(|(tid, _)| *tid)
434            .collect();
435
436        // Mark them as failed (will be retried if possible)
437        for task_id in task_ids {
438            let _ = scheduler.mark_failed(task_id);
439            assignments.remove(&task_id);
440        }
441
442        Ok(())
443    }
444
445    /// Get list of all workers.
446    pub fn list_workers(&self) -> Result<Vec<WorkerInfo>> {
447        let workers = self
448            .workers
449            .read()
450            .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
451
452        Ok(workers.values().cloned().collect())
453    }
454
455    /// Start monitoring loop for worker health.
456    pub async fn start_monitoring(
457        self: Arc<Self>,
458        mut shutdown_rx: mpsc::Receiver<()>,
459    ) -> Result<()> {
460        info!("Starting coordinator monitoring loop");
461
462        let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(10));
463
464        loop {
465            tokio::select! {
466                _ = interval.tick() => {
467                    if let Err(e) = self.check_worker_timeouts() {
468                        error!("Error checking worker timeouts: {}", e);
469                    }
470
471                    let progress = self.get_progress().unwrap_or_default();
472                    debug!("Progress: {:?}", progress);
473                }
474                _ = shutdown_rx.recv() => {
475                    info!("Coordinator monitoring loop shutting down");
476                    break;
477                }
478            }
479        }
480
481        Ok(())
482    }
483}
484
/// Progress information for the coordinator.
///
/// A point-in-time snapshot produced by `Coordinator::get_progress`;
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Default)]
pub struct CoordinatorProgress {
    /// Number of pending tasks.
    pub pending_tasks: usize,
    /// Number of running tasks.
    pub running_tasks: usize,
    /// Number of completed tasks.
    pub completed_tasks: usize,
    /// Number of failed tasks.
    pub failed_tasks: usize,
    /// Number of active (registered) workers.
    pub active_workers: usize,
    /// Number of idle workers.
    pub idle_workers: usize,
}
501
502impl CoordinatorProgress {
503    /// Get the total number of tasks.
504    pub fn total_tasks(&self) -> usize {
505        self.pending_tasks + self.running_tasks + self.completed_tasks + self.failed_tasks
506    }
507
508    /// Get the completion percentage.
509    pub fn completion_percentage(&self) -> f64 {
510        let total = self.total_tasks();
511        if total == 0 {
512            0.0
513        } else {
514            (self.completed_tasks as f64 / total as f64) * 100.0
515        }
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods should override defaults while keeping the address.
    #[test]
    fn test_coordinator_config() {
        let config = CoordinatorConfig::new(String::from("localhost:50051"))
            .with_max_retries(5)
            .with_task_timeout(600);

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.task_timeout_secs, 600);
        assert_eq!(config.listen_addr, "localhost:50051");
    }

    /// Success rate reflects the completed/failed counters; a fresh worker
    /// is not timed out under a generous timeout.
    #[test]
    fn test_worker_info() {
        let mut worker =
            WorkerInfo::new(String::from("worker-1"), String::from("localhost:50052"));
        worker.completed_tasks = 8;
        worker.failed_tasks = 2;

        assert!(!worker.is_timed_out(Duration::from_secs(60)));
        assert_eq!(worker.success_rate(), 0.8);
    }

    /// A brand-new coordinator reports zero tasks and zero workers.
    #[test]
    fn test_coordinator_creation() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let coordinator =
            Coordinator::new(CoordinatorConfig::new(String::from("localhost:50051")));

        let progress = coordinator.get_progress()?;
        assert_eq!(progress.active_workers, 0);
        assert_eq!(progress.total_tasks(), 0);
        Ok(())
    }

    /// Registering a worker makes it visible via `list_workers`.
    #[test]
    fn test_add_worker() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let coordinator =
            Coordinator::new(CoordinatorConfig::new(String::from("localhost:50051")));
        coordinator.add_worker(String::from("worker-1"), String::from("localhost:50052"))?;

        let registered = coordinator.list_workers()?;
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].worker_id, "worker-1");
        Ok(())
    }

    /// The first submitted task gets ID 0 and lands in the pending queue.
    #[test]
    fn test_submit_task() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let coordinator =
            Coordinator::new(CoordinatorConfig::new(String::from("localhost:50051")));

        let operation = TaskOperation::Filter {
            expression: String::from("value > 10"),
        };
        let task_id = coordinator.submit_task(PartitionId(0), operation)?;
        assert_eq!(task_id, TaskId(0));

        assert_eq!(coordinator.get_progress()?.pending_tasks, 1);
        Ok(())
    }

    /// Totals and percentage are derived consistently from the counters.
    #[test]
    fn test_progress() {
        let progress = CoordinatorProgress {
            pending_tasks: 10,
            running_tasks: 5,
            completed_tasks: 30,
            failed_tasks: 5,
            active_workers: 4,
            idle_workers: 2,
        };

        assert_eq!(progress.completion_percentage(), 60.0);
        assert_eq!(progress.total_tasks(), 50);
    }
}