1use crate::{TaskError, TaskResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::time::Duration;
10use tokio::sync::RwLock;
11
/// Name that identifies a worker inside a [`LoadBalancer`]; an owned string.
pub type WorkerId = String;

/// Policy used by [`LoadBalancer::select_worker`] to choose among registered workers.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Cycle through workers in registration order.
    RoundRobin,
    /// Choose the worker with the fewest in-flight tasks; the earliest-registered
    /// worker wins a tie.
    LeastConnections,
    /// Choose randomly in proportion to weight. An entry in this map overrides the
    /// worker's own [`WorkerInfo::weight`]; workers missing from the map use their own.
    Weighted(HashMap<WorkerId, u32>),
    /// Choose a worker uniformly at random.
    Random,
}

/// A registered worker plus a live counter of the tasks currently assigned to it.
#[derive(Debug)]
pub struct WorkerInfo {
    /// Name handed back to callers of `select_worker`.
    pub id: WorkerId,
    /// Fallback weight for the weighted strategy when its map has no entry for `id`.
    pub weight: u32,
    /// In-flight task count, bumped on selection and lowered on completion.
    pub active_tasks: AtomicUsize,
}

impl WorkerInfo {
    /// Builds a worker named `id` with the given `weight` and no tasks in flight.
    pub fn new(id: WorkerId, weight: u32) -> Self {
        let active_tasks = AtomicUsize::new(0);
        Self {
            id,
            weight,
            active_tasks,
        }
    }

    /// Notes that one more task has been handed to this worker.
    pub fn increment_tasks(&self) {
        self.active_tasks.fetch_add(1, Ordering::SeqCst);
    }

    /// Notes that one task has finished. Calling this at zero leaves the count at zero.
    pub fn decrement_tasks(&self) {
        // `checked_sub` yields `None` at zero, which makes `fetch_update` keep the
        // current value instead of wrapping around to `usize::MAX`. The `Err` it then
        // returns is expected and deliberately ignored.
        let _ = self
            .active_tasks
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| {
                count.checked_sub(1)
            });
    }

    /// Current number of in-flight tasks.
    pub fn active_task_count(&self) -> usize {
        self.active_tasks.load(Ordering::SeqCst)
    }
}
106
/// Aggregated execution statistics for one worker.
#[derive(Debug, Clone)]
pub struct WorkerMetrics {
    /// Number of tasks recorded via [`WorkerMetrics::record_success`].
    pub tasks_completed: u64,
    /// Number of tasks recorded via [`WorkerMetrics::record_failure`].
    pub tasks_failed: u64,
    /// Mean execution time over all recorded tasks, successes and failures alike.
    pub average_execution_time: Duration,
}

impl WorkerMetrics {
    /// Creates empty metrics: no tasks recorded and a zero average.
    pub fn new() -> Self {
        Self {
            tasks_completed: 0,
            tasks_failed: 0,
            average_execution_time: Duration::from_secs(0),
        }
    }

    /// Creates metrics from explicit values.
    pub fn with_values(
        tasks_completed: u64,
        tasks_failed: u64,
        average_execution_time: Duration,
    ) -> Self {
        Self {
            tasks_completed,
            tasks_failed,
            average_execution_time,
        }
    }

    /// Folds one more sample, `duration`, into `average_execution_time`.
    ///
    /// The current average is treated as the mean of `tasks_completed + tasks_failed`
    /// earlier samples. The result equals `floor((avg * count + duration) / (count + 1))`
    /// computed in nanoseconds, but without ever forming `avg * count`. That product could
    /// overflow even a `u128` and previously forced a saturated, badly wrong mean.
    /// Sub-millisecond durations keep full nanosecond precision.
    fn update_execution_time(&mut self, duration: Duration) {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        // u128 cannot overflow here: each count is at most u64::MAX.
        let previous_samples = u128::from(self.tasks_completed) + u128::from(self.tasks_failed);
        let sample_count = previous_samples + 1;
        let average = self.average_execution_time.as_nanos();
        let sample = duration.as_nanos();

        // Incremental mean: avg + (sample - avg) / n. Moving down uses a ceiling step
        // so the result matches the floor of the exact mean:
        // floor(avg - d/n) == avg - ceil(d/n).
        let new_average = if sample >= average {
            average + (sample - average) / sample_count
        } else {
            average - (average - sample + sample_count - 1) / sample_count
        };

        // `new_average` lies between the old average and the sample, both valid
        // Durations, so the seconds always fit in u64. The fallback is purely defensive.
        let secs = u64::try_from(new_average / NANOS_PER_SEC).unwrap_or(u64::MAX);
        // The remainder is < 1e9 and therefore fits in u32.
        let subsec_nanos = (new_average % NANOS_PER_SEC) as u32;
        self.average_execution_time = Duration::new(secs, subsec_nanos);
    }

    /// Records a successful task that ran for `duration`.
    ///
    /// The counter saturates at `u64::MAX` instead of overflowing.
    pub fn record_success(&mut self, duration: Duration) {
        self.update_execution_time(duration);
        self.tasks_completed = self.tasks_completed.saturating_add(1);
    }

    /// Records a failed task that ran for `duration`.
    ///
    /// The counter saturates at `u64::MAX` instead of overflowing.
    pub fn record_failure(&mut self, duration: Duration) {
        self.update_execution_time(duration);
        self.tasks_failed = self.tasks_failed.saturating_add(1);
    }
}

impl Default for WorkerMetrics {
    /// Same as [`WorkerMetrics::new`].
    fn default() -> Self {
        Self::new()
    }
}
199
200pub struct LoadBalancer {
219 strategy: LoadBalancingStrategy,
220 workers: Arc<RwLock<Vec<Arc<WorkerInfo>>>>,
221 metrics: Arc<RwLock<HashMap<WorkerId, WorkerMetrics>>>,
222 round_robin_index: Arc<AtomicUsize>,
223}
224
225impl LoadBalancer {
226 pub fn new(strategy: LoadBalancingStrategy) -> Self {
236 Self {
237 strategy,
238 workers: Arc::new(RwLock::new(Vec::new())),
239 metrics: Arc::new(RwLock::new(HashMap::new())),
240 round_robin_index: Arc::new(AtomicUsize::new(0)),
241 }
242 }
243
244 pub async fn register_worker(&self, worker: WorkerInfo) -> TaskResult<()> {
258 let worker_id = worker.id.clone();
259 self.workers.write().await.push(Arc::new(worker));
260 self.metrics
261 .write()
262 .await
263 .insert(worker_id, WorkerMetrics::new());
264 Ok(())
265 }
266
267 pub async fn unregister_worker(&self, worker_id: &str) -> TaskResult<()> {
282 self.workers.write().await.retain(|w| w.id != worker_id);
283 self.metrics.write().await.remove(worker_id);
284 Ok(())
285 }
286
287 pub async fn select_worker(&self) -> TaskResult<WorkerId> {
304 let workers = self.workers.read().await;
305 if workers.is_empty() {
306 return Err(TaskError::QueueError("No workers available".to_string()));
307 }
308
309 let selected = match &self.strategy {
310 LoadBalancingStrategy::RoundRobin => self.select_round_robin(&workers),
311 LoadBalancingStrategy::LeastConnections => self.select_least_connections(&workers),
312 LoadBalancingStrategy::Weighted(weights) => self.select_weighted(&workers, weights),
313 LoadBalancingStrategy::Random => self.select_random(&workers),
314 };
315
316 selected.increment_tasks();
317 Ok(selected.id.clone())
318 }
319
320 fn select_round_robin(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
322 let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst) % workers.len();
323 workers[index].clone()
324 }
325
326 fn select_least_connections(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
328 workers
329 .iter()
330 .min_by_key(|w| w.active_task_count())
331 .unwrap()
332 .clone()
333 }
334
335 fn select_weighted(
337 &self,
338 workers: &[Arc<WorkerInfo>],
339 weights: &HashMap<WorkerId, u32>,
340 ) -> Arc<WorkerInfo> {
341 use rand::Rng;
342 let total_weight: u32 = workers
343 .iter()
344 .map(|w| weights.get(&w.id).copied().unwrap_or(w.weight))
345 .sum();
346
347 if total_weight == 0 {
349 return workers[0].clone();
350 }
351
352 let mut rng = rand::rng();
353 let mut random = rng.random_range(0..total_weight);
354 for worker in workers {
355 let weight = weights.get(&worker.id).copied().unwrap_or(worker.weight);
356 if random < weight {
357 return worker.clone();
358 }
359 random -= weight;
360 }
361
362 workers[0].clone()
363 }
364
365 fn select_random(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
367 use rand::Rng;
368 let mut rng = rand::rng();
369 let index = rng.random_range(0..workers.len());
370 workers[index].clone()
371 }
372
373 pub async fn task_completed(&self, worker_id: &str) -> TaskResult<()> {
390 let workers = self.workers.read().await;
391 if let Some(worker) = workers.iter().find(|w| w.id == worker_id) {
392 worker.decrement_tasks();
393 }
394 Ok(())
395 }
396
397 pub async fn update_metrics(&self, worker_id: &str, metrics: WorkerMetrics) -> TaskResult<()> {
415 self.metrics
416 .write()
417 .await
418 .insert(worker_id.to_string(), metrics);
419 Ok(())
420 }
421
422 pub async fn get_worker_stats(&self) -> HashMap<WorkerId, WorkerMetrics> {
439 self.metrics.read().await.clone()
440 }
441
442 pub async fn worker_count(&self) -> usize {
459 self.workers.read().await.len()
460 }
461}
462
#[cfg(test)]
mod tests {
    // Unit tests for worker bookkeeping, metrics averaging, and each selection strategy.
    use super::*;
    use rstest::rstest;
    use std::time::Duration;

    #[rstest]
    #[tokio::test]
    async fn test_worker_info_creation() {
        let worker = WorkerInfo::new("worker-1".to_string(), 2);

        assert_eq!(worker.id, "worker-1");
        assert_eq!(worker.weight, 2);
        assert_eq!(worker.active_task_count(), 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_worker_info_task_count() {
        let worker = WorkerInfo::new("worker-1".to_string(), 1);

        // Two increments followed by one decrement should net to a single in-flight task.
        worker.increment_tasks();
        assert_eq!(worker.active_task_count(), 1);
        worker.increment_tasks();
        assert_eq!(worker.active_task_count(), 2);
        worker.decrement_tasks();
        assert_eq!(worker.active_task_count(), 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_worker_metrics_creation() {
        let metrics = WorkerMetrics::new();

        assert_eq!(metrics.tasks_completed, 0);
        assert_eq!(metrics.tasks_failed, 0);
        assert_eq!(metrics.average_execution_time, Duration::from_secs(0));
    }

    #[rstest]
    #[tokio::test]
    async fn test_worker_metrics_record_success() {
        let mut metrics = WorkerMetrics::new();

        // The first sample becomes the average outright.
        metrics.record_success(Duration::from_millis(100));

        assert_eq!(metrics.tasks_completed, 1);
        assert_eq!(metrics.average_execution_time, Duration::from_millis(100));

        // The second sample yields the mean of both: (100 + 200) / 2.
        metrics.record_success(Duration::from_millis(200));

        assert_eq!(metrics.tasks_completed, 2);
        assert_eq!(metrics.average_execution_time, Duration::from_millis(150));
    }

    #[rstest]
    #[tokio::test]
    async fn test_worker_metrics_record_failure() {
        let mut metrics = WorkerMetrics::new();

        // Failures also feed the execution-time average.
        metrics.record_failure(Duration::from_millis(50));

        assert_eq!(metrics.tasks_failed, 1);
        assert_eq!(metrics.average_execution_time, Duration::from_millis(50));
    }

    #[rstest]
    #[tokio::test]
    async fn test_load_balancer_creation() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        assert_eq!(balancer.worker_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_load_balancer_register_worker() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();

        assert_eq!(balancer.worker_count().await, 1);
    }

    #[rstest]
    #[tokio::test]
    async fn test_load_balancer_unregister_worker() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();

        balancer.unregister_worker("worker-1").await.unwrap();

        assert_eq!(balancer.worker_count().await, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_round_robin_strategy() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();
        balancer
            .register_worker(WorkerInfo::new("worker-2".to_string(), 1))
            .await
            .unwrap();

        // The round-robin counter starts at 0, so picks follow registration order and wrap.
        let worker1 = balancer.select_worker().await.unwrap();
        let worker2 = balancer.select_worker().await.unwrap();
        let worker3 = balancer.select_worker().await.unwrap();

        assert_eq!(worker1, "worker-1");
        assert_eq!(worker2, "worker-2");
        assert_eq!(worker3, "worker-1");
    }

    #[rstest]
    #[tokio::test]
    async fn test_least_connections_strategy() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);
        let worker1 = WorkerInfo::new("worker-1".to_string(), 1);
        let worker2 = WorkerInfo::new("worker-2".to_string(), 1);

        // Preload worker-1 with two in-flight tasks so worker-2 is the least loaded.
        worker1.increment_tasks();
        worker1.increment_tasks();

        balancer.register_worker(worker1).await.unwrap();
        balancer.register_worker(worker2).await.unwrap();

        let selected = balancer.select_worker().await.unwrap();

        assert_eq!(selected, "worker-2");
    }

    #[rstest]
    #[tokio::test]
    async fn test_weighted_strategy() {
        // Weights 3:1, so worker-1 should be picked about 75% of the time.
        let mut weights = HashMap::new();
        weights.insert("worker-1".to_string(), 3);
        weights.insert("worker-2".to_string(), 1);

        let balancer = LoadBalancer::new(LoadBalancingStrategy::Weighted(weights));
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 3))
            .await
            .unwrap();
        balancer
            .register_worker(WorkerInfo::new("worker-2".to_string(), 1))
            .await
            .unwrap();

        let mut worker1_count = 0;
        let mut worker2_count = 0;

        for _ in 0..100 {
            let selected = balancer.select_worker().await.unwrap();
            // Release each pick immediately so in-flight counters do not accumulate.
            balancer.task_completed(&selected).await.unwrap();
            if selected == "worker-1" {
                worker1_count += 1;
            } else {
                worker2_count += 1;
            }
        }

        // Probabilistic check; with a 75/25 split over 100 draws a failure is very unlikely.
        assert!(worker1_count > worker2_count);
    }

    #[rstest]
    #[tokio::test]
    async fn test_random_strategy() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::Random);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();
        balancer
            .register_worker(WorkerInfo::new("worker-2".to_string(), 1))
            .await
            .unwrap();

        let worker = balancer.select_worker().await.unwrap();

        assert!(worker == "worker-1" || worker == "worker-2");
    }

    #[rstest]
    #[tokio::test]
    async fn test_select_worker_no_workers() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        // An empty pool must produce an error rather than panic.
        let result = balancer.select_worker().await;

        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_task_completed() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();

        // Selecting a worker counts one in-flight task against it.
        let worker_id = balancer.select_worker().await.unwrap();
        let workers = balancer.workers.read().await;
        let worker = workers.iter().find(|w| w.id == worker_id).unwrap();
        assert_eq!(worker.active_task_count(), 1);
        // Release the read guard before task_completed acquires the lock again.
        drop(workers);

        balancer.task_completed(&worker_id).await.unwrap();

        let workers = balancer.workers.read().await;
        let worker = workers.iter().find(|w| w.id == worker_id).unwrap();
        assert_eq!(worker.active_task_count(), 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_update_metrics() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();

        let metrics = WorkerMetrics::with_values(10, 2, Duration::from_millis(500));

        // Stored metrics replace the empty ones created at registration.
        balancer
            .update_metrics("worker-1", metrics.clone())
            .await
            .unwrap();

        let stats = balancer.get_worker_stats().await;
        let worker_metrics = stats.get("worker-1").unwrap();
        assert_eq!(worker_metrics.tasks_completed, 10);
        assert_eq!(worker_metrics.tasks_failed, 2);
        assert_eq!(
            worker_metrics.average_execution_time,
            Duration::from_millis(500)
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_get_worker_stats() {
        let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 1))
            .await
            .unwrap();
        balancer
            .register_worker(WorkerInfo::new("worker-2".to_string(), 1))
            .await
            .unwrap();

        // Registration alone creates a metrics entry per worker.
        let stats = balancer.get_worker_stats().await;

        assert_eq!(stats.len(), 2);
        assert!(stats.contains_key("worker-1"));
        assert!(stats.contains_key("worker-2"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_decrement_tasks_at_zero_does_not_underflow() {
        let worker = WorkerInfo::new("worker-1".to_string(), 1);
        assert_eq!(worker.active_task_count(), 0);

        // Must not wrap around to usize::MAX.
        worker.decrement_tasks();

        assert_eq!(worker.active_task_count(), 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_decrement_tasks_multiple_times_at_zero_stays_at_zero() {
        let worker = WorkerInfo::new("worker-1".to_string(), 1);
        worker.increment_tasks();
        worker.decrement_tasks();
        assert_eq!(worker.active_task_count(), 0);

        // Repeated extra decrements stay clamped at zero.
        worker.decrement_tasks();
        worker.decrement_tasks();
        worker.decrement_tasks();

        assert_eq!(worker.active_task_count(), 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_weighted_strategy_zero_total_weight_does_not_panic() {
        // All-zero weights would make `random_range(0..0)` panic without the guard.
        let mut weights = HashMap::new();
        weights.insert("worker-1".to_string(), 0);
        weights.insert("worker-2".to_string(), 0);

        let balancer = LoadBalancer::new(LoadBalancingStrategy::Weighted(weights));
        balancer
            .register_worker(WorkerInfo::new("worker-1".to_string(), 0))
            .await
            .unwrap();
        balancer
            .register_worker(WorkerInfo::new("worker-2".to_string(), 0))
            .await
            .unwrap();

        let selected = balancer.select_worker().await.unwrap();

        assert!(selected == "worker-1" || selected == "worker-2");
    }

    #[rstest]
    #[tokio::test]
    async fn test_update_execution_time_does_not_overflow() {
        // Extreme count and average; this only checks that recording does not panic.
        let mut metrics = WorkerMetrics::new();
        metrics.tasks_completed = u64::MAX - 1;
        metrics.average_execution_time = Duration::from_millis(u64::MAX);

        metrics.record_success(Duration::from_millis(1000));

        assert!(metrics.tasks_completed > 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_update_execution_time_saturates_at_u64_max() {
        // Averaging two maximal samples must yield the same maximal value.
        let mut metrics = WorkerMetrics::new();
        metrics.tasks_completed = 1;
        metrics.average_execution_time = Duration::from_millis(u64::MAX);

        metrics.record_success(Duration::from_millis(u64::MAX));

        assert_eq!(
            metrics.average_execution_time,
            Duration::from_millis(u64::MAX)
        );
    }
}