1use crate::error::{DistributedError, Result};
7use crate::task::{PartitionId, Task, TaskId, TaskOperation, TaskResult, TaskScheduler};
8use crate::worker::WorkerStatus;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc;
13use tracing::{debug, error, info, warn};
14
/// Static configuration for a [`Coordinator`].
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// Address the coordinator listens on (e.g. "host:port").
    /// NOTE(review): not read anywhere in this file — presumably consumed by
    /// the network layer; confirm.
    pub listen_addr: String,
    /// Maximum retry attempts per task; copied onto each submitted task.
    pub max_retries: u32,
    /// Per-task execution timeout in seconds.
    /// NOTE(review): not referenced in this file — confirm where it is enforced.
    pub task_timeout_secs: u64,
    /// Seconds without a heartbeat before a worker is declared timed out.
    pub worker_timeout_secs: u64,
    /// Capacity of the result buffer.
    /// NOTE(review): not referenced in this file — confirm the consumer.
    pub result_buffer_size: usize,
}
29
30impl CoordinatorConfig {
31 pub fn new(listen_addr: String) -> Self {
33 Self {
34 listen_addr,
35 max_retries: 3,
36 task_timeout_secs: 300, worker_timeout_secs: 60,
38 result_buffer_size: 1000,
39 }
40 }
41
42 pub fn with_max_retries(mut self, retries: u32) -> Self {
44 self.max_retries = retries;
45 self
46 }
47
48 pub fn with_task_timeout(mut self, timeout_secs: u64) -> Self {
50 self.task_timeout_secs = timeout_secs;
51 self
52 }
53}
54
/// Bookkeeping record for one registered worker.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
    /// Unique identifier the worker registered under.
    pub worker_id: String,
    /// Network address of the worker.
    pub address: String,
    /// Current status (Idle/Busy as set by the coordinator).
    pub status: WorkerStatus,
    /// Time of the most recent heartbeat; used for timeout detection.
    pub last_heartbeat: Instant,
    /// Number of tasks currently assigned and not yet completed.
    pub active_tasks: usize,
    /// Lifetime count of successfully completed tasks.
    pub completed_tasks: u64,
    /// Lifetime count of failed tasks.
    pub failed_tasks: u64,
}
73
74impl WorkerInfo {
75 pub fn new(worker_id: String, address: String) -> Self {
77 Self {
78 worker_id,
79 address,
80 status: WorkerStatus::Idle,
81 last_heartbeat: Instant::now(),
82 active_tasks: 0,
83 completed_tasks: 0,
84 failed_tasks: 0,
85 }
86 }
87
88 pub fn update_heartbeat(&mut self) {
90 self.last_heartbeat = Instant::now();
91 }
92
93 pub fn is_timed_out(&self, timeout: Duration) -> bool {
95 self.last_heartbeat.elapsed() > timeout
96 }
97
98 pub fn success_rate(&self) -> f64 {
100 let total = self.completed_tasks + self.failed_tasks;
101 if total == 0 {
102 1.0
103 } else {
104 self.completed_tasks as f64 / total as f64
105 }
106 }
107}
108
/// Central coordinator: registers workers, schedules tasks onto them, and
/// collects results. All shared state sits behind `Arc<RwLock<_>>` so the
/// coordinator can be shared across threads/async tasks.
pub struct Coordinator {
    /// Static configuration (timeouts, retry policy, buffer sizes).
    config: CoordinatorConfig,
    /// Task queue and lifecycle state (pending/running/completed/failed).
    scheduler: Arc<RwLock<TaskScheduler>>,
    /// Registered workers keyed by worker id.
    workers: Arc<RwLock<HashMap<String, WorkerInfo>>>,
    /// Maps each in-flight task to the worker it was assigned to.
    assignments: Arc<RwLock<HashMap<TaskId, String>>>,
    /// Results of finished tasks keyed by task id.
    results: Arc<RwLock<HashMap<TaskId, TaskResult>>>,
    /// Monotonic counter used to mint fresh task ids.
    next_task_id: Arc<RwLock<u64>>,
}
124
125impl Coordinator {
126 pub fn new(config: CoordinatorConfig) -> Self {
128 Self {
129 config,
130 scheduler: Arc::new(RwLock::new(TaskScheduler::new())),
131 workers: Arc::new(RwLock::new(HashMap::new())),
132 assignments: Arc::new(RwLock::new(HashMap::new())),
133 results: Arc::new(RwLock::new(HashMap::new())),
134 next_task_id: Arc::new(RwLock::new(0)),
135 }
136 }
137
138 pub fn add_worker(&self, worker_id: String, address: String) -> Result<()> {
140 info!("Adding worker: {} at {}", worker_id, address);
141
142 let worker_info = WorkerInfo::new(worker_id.clone(), address);
143
144 let mut workers = self
145 .workers
146 .write()
147 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
148
149 if workers.contains_key(&worker_id) {
150 return Err(DistributedError::coordinator(format!(
151 "Worker {} already exists",
152 worker_id
153 )));
154 }
155
156 workers.insert(worker_id, worker_info);
157 Ok(())
158 }
159
160 pub fn remove_worker(&self, worker_id: &str) -> Result<()> {
162 info!("Removing worker: {}", worker_id);
163
164 let mut workers = self
165 .workers
166 .write()
167 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
168
169 workers.remove(worker_id);
170
171 self.reassign_worker_tasks(worker_id)?;
173
174 Ok(())
175 }
176
177 pub fn update_worker_heartbeat(&self, worker_id: &str) -> Result<()> {
179 let mut workers = self
180 .workers
181 .write()
182 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
183
184 if let Some(worker) = workers.get_mut(worker_id) {
185 worker.update_heartbeat();
186 debug!("Updated heartbeat for worker {}", worker_id);
187 Ok(())
188 } else {
189 Err(DistributedError::coordinator(format!(
190 "Worker {} not found",
191 worker_id
192 )))
193 }
194 }
195
196 pub fn check_worker_timeouts(&self) -> Result<Vec<String>> {
198 let timeout = Duration::from_secs(self.config.worker_timeout_secs);
199 let mut timed_out = Vec::new();
200
201 let workers = self
202 .workers
203 .read()
204 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
205
206 for (worker_id, worker) in workers.iter() {
207 if worker.is_timed_out(timeout) {
208 warn!("Worker {} has timed out", worker_id);
209 timed_out.push(worker_id.clone());
210 }
211 }
212
213 drop(workers);
214
215 for worker_id in &timed_out {
217 self.reassign_worker_tasks(worker_id)?;
218 self.remove_worker(worker_id)?;
219 }
220
221 Ok(timed_out)
222 }
223
224 pub fn submit_task(
226 &self,
227 partition_id: PartitionId,
228 operation: TaskOperation,
229 ) -> Result<TaskId> {
230 let task_id = self.generate_task_id()?;
231 let mut task = Task::new(task_id, partition_id, operation);
232 task.max_retries = self.config.max_retries;
233
234 let mut scheduler = self
235 .scheduler
236 .write()
237 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
238
239 scheduler.add_task(task);
240 debug!("Submitted task {}", task_id);
241
242 Ok(task_id)
243 }
244
245 pub fn next_task(&self) -> Result<Option<Task>> {
247 let mut scheduler = self
248 .scheduler
249 .write()
250 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
251
252 Ok(scheduler.next_task())
253 }
254
255 pub fn assign_task(&self, task: Task, worker_id: String) -> Result<()> {
257 let mut scheduler = self
259 .scheduler
260 .write()
261 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
262 scheduler.mark_running(task.clone(), worker_id.clone());
263 drop(scheduler);
264
265 let mut assignments = self
267 .assignments
268 .write()
269 .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
270 assignments.insert(task.id, worker_id.clone());
271
272 let mut workers = self
274 .workers
275 .write()
276 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
277 if let Some(worker) = workers.get_mut(&worker_id) {
278 worker.active_tasks += 1;
279 worker.status = WorkerStatus::Busy;
280 }
281
282 info!("Assigned task {} to worker {}", task.id, worker_id);
283 Ok(())
284 }
285
286 pub fn complete_task(&self, task_id: TaskId, result: TaskResult) -> Result<()> {
288 let worker_id = {
289 let assignments = self
290 .assignments
291 .read()
292 .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
293 assignments.get(&task_id).cloned()
294 };
295
296 let mut scheduler = self
298 .scheduler
299 .write()
300 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
301
302 if result.is_success() {
303 scheduler.mark_completed(task_id)?;
304 } else {
305 scheduler.mark_failed(task_id)?;
306 }
307 drop(scheduler);
308
309 if let Some(worker_id) = worker_id {
311 let mut workers = self
312 .workers
313 .write()
314 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
315
316 if let Some(worker) = workers.get_mut(&worker_id) {
317 if worker.active_tasks > 0 {
318 worker.active_tasks -= 1;
319 }
320 if result.is_success() {
321 worker.completed_tasks += 1;
322 } else {
323 worker.failed_tasks += 1;
324 }
325 if worker.active_tasks == 0 {
326 worker.status = WorkerStatus::Idle;
327 }
328 }
329 }
330
331 let mut results = self
333 .results
334 .write()
335 .map_err(|_| DistributedError::coordinator("Failed to acquire results lock"))?;
336 results.insert(task_id, result);
337
338 info!("Task {} completed", task_id);
339 Ok(())
340 }
341
342 pub fn get_available_worker(&self) -> Result<Option<String>> {
344 let workers = self
345 .workers
346 .read()
347 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
348
349 let best_worker = workers
351 .values()
352 .filter(|w| w.status == WorkerStatus::Idle)
353 .max_by(|a, b| {
354 a.success_rate()
355 .partial_cmp(&b.success_rate())
356 .unwrap_or(std::cmp::Ordering::Equal)
357 })
358 .map(|w| w.worker_id.clone());
359
360 Ok(best_worker)
361 }
362
363 pub fn get_progress(&self) -> Result<CoordinatorProgress> {
365 let scheduler = self
366 .scheduler
367 .read()
368 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
369
370 let workers = self
371 .workers
372 .read()
373 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
374
375 Ok(CoordinatorProgress {
376 pending_tasks: scheduler.pending_count(),
377 running_tasks: scheduler.running_count(),
378 completed_tasks: scheduler.completed_count(),
379 failed_tasks: scheduler.failed_count(),
380 active_workers: workers.len(),
381 idle_workers: workers
382 .values()
383 .filter(|w| w.status == WorkerStatus::Idle)
384 .count(),
385 })
386 }
387
388 pub fn collect_results(&self) -> Result<Vec<TaskResult>> {
390 let results = self
391 .results
392 .read()
393 .map_err(|_| DistributedError::coordinator("Failed to acquire results lock"))?;
394
395 Ok(results.values().cloned().collect())
396 }
397
398 pub fn is_complete(&self) -> bool {
400 self.scheduler
401 .read()
402 .map(|s| s.is_complete())
403 .unwrap_or(false)
404 }
405
406 fn generate_task_id(&self) -> Result<TaskId> {
408 let mut next_id = self
409 .next_task_id
410 .write()
411 .map_err(|_| DistributedError::coordinator("Failed to acquire task ID lock"))?;
412 let id = *next_id;
413 *next_id += 1;
414 Ok(TaskId(id))
415 }
416
417 fn reassign_worker_tasks(&self, worker_id: &str) -> Result<()> {
419 let mut scheduler = self
420 .scheduler
421 .write()
422 .map_err(|_| DistributedError::coordinator("Failed to acquire scheduler lock"))?;
423
424 let mut assignments = self
425 .assignments
426 .write()
427 .map_err(|_| DistributedError::coordinator("Failed to acquire assignments lock"))?;
428
429 let task_ids: Vec<TaskId> = assignments
431 .iter()
432 .filter(|(_, wid)| *wid == worker_id)
433 .map(|(tid, _)| *tid)
434 .collect();
435
436 for task_id in task_ids {
438 let _ = scheduler.mark_failed(task_id);
439 assignments.remove(&task_id);
440 }
441
442 Ok(())
443 }
444
445 pub fn list_workers(&self) -> Result<Vec<WorkerInfo>> {
447 let workers = self
448 .workers
449 .read()
450 .map_err(|_| DistributedError::coordinator("Failed to acquire workers lock"))?;
451
452 Ok(workers.values().cloned().collect())
453 }
454
455 pub async fn start_monitoring(
457 self: Arc<Self>,
458 mut shutdown_rx: mpsc::Receiver<()>,
459 ) -> Result<()> {
460 info!("Starting coordinator monitoring loop");
461
462 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(10));
463
464 loop {
465 tokio::select! {
466 _ = interval.tick() => {
467 if let Err(e) = self.check_worker_timeouts() {
468 error!("Error checking worker timeouts: {}", e);
469 }
470
471 let progress = self.get_progress().unwrap_or_default();
472 debug!("Progress: {:?}", progress);
473 }
474 _ = shutdown_rx.recv() => {
475 info!("Coordinator monitoring loop shutting down");
476 break;
477 }
478 }
479 }
480
481 Ok(())
482 }
483}
484
/// Point-in-time snapshot of task and worker counts, as produced by
/// `Coordinator::get_progress`.
#[derive(Debug, Clone, Default)]
pub struct CoordinatorProgress {
    /// Tasks queued but not yet assigned.
    pub pending_tasks: usize,
    /// Tasks currently assigned to workers.
    pub running_tasks: usize,
    /// Tasks that finished successfully.
    pub completed_tasks: usize,
    /// Tasks that failed.
    pub failed_tasks: usize,
    /// Total number of registered workers.
    pub active_workers: usize,
    /// Registered workers currently idle.
    pub idle_workers: usize,
}
501
502impl CoordinatorProgress {
503 pub fn total_tasks(&self) -> usize {
505 self.pending_tasks + self.running_tasks + self.completed_tasks + self.failed_tasks
506 }
507
508 pub fn completion_percentage(&self) -> f64 {
510 let total = self.total_tasks();
511 if total == 0 {
512 0.0
513 } else {
514 (self.completed_tasks as f64 / total as f64) * 100.0
515 }
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;

    // Builder methods should override the defaults set by `new`.
    #[test]
    fn test_coordinator_config() {
        let config = CoordinatorConfig::new("localhost:50051".to_string())
            .with_max_retries(5)
            .with_task_timeout(600);

        assert_eq!(config.listen_addr, "localhost:50051");
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.task_timeout_secs, 600);
    }

    // success_rate = completed / (completed + failed); 8/10 is exactly
    // representable, so float equality is safe here. A just-created worker
    // must not read as timed out.
    #[test]
    fn test_worker_info() {
        let mut info = WorkerInfo::new("worker-1".to_string(), "localhost:50052".to_string());

        info.completed_tasks = 8;
        info.failed_tasks = 2;

        assert_eq!(info.success_rate(), 0.8);
        assert!(!info.is_timed_out(Duration::from_secs(60)));
    }

    // A fresh coordinator reports zero tasks and zero workers.
    #[test]
    fn test_coordinator_creation() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let config = CoordinatorConfig::new("localhost:50051".to_string());
        let coordinator = Coordinator::new(config);

        let progress = coordinator.get_progress()?;
        assert_eq!(progress.total_tasks(), 0);
        assert_eq!(progress.active_workers, 0);
        Ok(())
    }

    // Adding a worker makes it visible via list_workers.
    #[test]
    fn test_add_worker() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let config = CoordinatorConfig::new("localhost:50051".to_string());
        let coordinator = Coordinator::new(config);

        coordinator.add_worker("worker-1".to_string(), "localhost:50052".to_string())?;

        let workers = coordinator.list_workers()?;
        assert_eq!(workers.len(), 1);
        assert_eq!(workers[0].worker_id, "worker-1");
        Ok(())
    }

    // Task ids start at 0, and a submitted task shows up as pending.
    #[test]
    fn test_submit_task() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let config = CoordinatorConfig::new("localhost:50051".to_string());
        let coordinator = Coordinator::new(config);

        let task_id = coordinator.submit_task(
            PartitionId(0),
            TaskOperation::Filter {
                expression: "value > 10".to_string(),
            },
        )?;

        assert_eq!(task_id, TaskId(0));

        let progress = coordinator.get_progress()?;
        assert_eq!(progress.pending_tasks, 1);
        Ok(())
    }

    // total = 10+5+30+5 = 50; completion = 30/50 = 60% (exact in f64).
    #[test]
    fn test_progress() {
        let progress = CoordinatorProgress {
            pending_tasks: 10,
            running_tasks: 5,
            completed_tasks: 30,
            failed_tasks: 5,
            active_workers: 4,
            idle_workers: 2,
        };

        assert_eq!(progress.total_tasks(), 50);
        assert_eq!(progress.completion_percentage(), 60.0);
    }
}