Skip to main content

reinhardt_tasks/
load_balancer.rs

1//! Worker load balancing
2//!
3//! Provides load balancing strategies for distributing tasks across multiple workers.
4
5use crate::{TaskError, TaskResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicUsize, Ordering};
9use std::time::Duration;
10use tokio::sync::RwLock;
11
/// Identifier used to address a single worker (an owned `String`).
pub type WorkerId = String;
14
15/// Load balancing strategy
16///
17/// # Examples
18///
19/// ```rust
20/// use reinhardt_tasks::{LoadBalancingStrategy};
21/// use std::collections::HashMap;
22///
23/// // Round-robin strategy
24/// let strategy = LoadBalancingStrategy::RoundRobin;
25///
26/// // Weighted strategy
27/// let mut weights = HashMap::new();
28/// weights.insert("worker-1".to_string(), 2);
29/// weights.insert("worker-2".to_string(), 1);
30/// let strategy = LoadBalancingStrategy::Weighted(weights);
31/// ```
32#[derive(Debug, Clone)]
33pub enum LoadBalancingStrategy {
34	/// Round-robin distribution
35	RoundRobin,
36	/// Least connections - select worker with fewest active tasks
37	LeastConnections,
38	/// Weighted distribution - workers with higher weights receive more tasks
39	Weighted(HashMap<WorkerId, u32>),
40	/// Random distribution
41	Random,
42}
43
44/// Worker information
45///
46/// # Examples
47///
48/// ```rust
49/// use reinhardt_tasks::WorkerInfo;
50///
51/// let worker = WorkerInfo::new("worker-1".to_string(), 1);
52/// assert_eq!(worker.id, "worker-1");
53/// assert_eq!(worker.weight, 1);
54/// assert_eq!(worker.active_tasks.load(std::sync::atomic::Ordering::SeqCst), 0);
55/// ```
56#[derive(Debug)]
57pub struct WorkerInfo {
58	/// Unique identifier for this worker.
59	pub id: WorkerId,
60	/// Weight used for weighted load balancing strategies.
61	pub weight: u32,
62	/// Number of currently active tasks on this worker.
63	pub active_tasks: AtomicUsize,
64}
65
66impl WorkerInfo {
67	/// Create a new worker info
68	///
69	/// # Examples
70	///
71	/// ```rust
72	/// use reinhardt_tasks::WorkerInfo;
73	///
74	/// let worker = WorkerInfo::new("worker-1".to_string(), 2);
75	/// assert_eq!(worker.id, "worker-1");
76	/// assert_eq!(worker.weight, 2);
77	/// ```
78	pub fn new(id: WorkerId, weight: u32) -> Self {
79		Self {
80			id,
81			weight,
82			active_tasks: AtomicUsize::new(0),
83		}
84	}
85
86	/// Increment active task count
87	pub fn increment_tasks(&self) {
88		self.active_tasks.fetch_add(1, Ordering::SeqCst);
89	}
90
91	/// Decrement active task count (saturates at 0 to prevent underflow wrap)
92	pub fn decrement_tasks(&self) {
93		// Use fetch_update with saturating_sub to prevent wrapping to usize::MAX
94		let _ = self
95			.active_tasks
96			.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
97				Some(current.saturating_sub(1))
98			});
99	}
100
101	/// Get current active task count
102	pub fn active_task_count(&self) -> usize {
103		self.active_tasks.load(Ordering::SeqCst)
104	}
105}
106
/// Aggregate execution statistics for a single worker.
///
/// # Examples
///
/// ```rust
/// use reinhardt_tasks::WorkerMetrics;
/// use std::time::Duration;
///
/// let metrics = WorkerMetrics::new();
/// assert_eq!(metrics.tasks_completed, 0);
/// assert_eq!(metrics.tasks_failed, 0);
/// assert_eq!(metrics.average_execution_time, Duration::from_secs(0));
/// ```
#[derive(Debug, Clone)]
pub struct WorkerMetrics {
	/// Total number of successfully completed tasks.
	pub tasks_completed: u64,
	/// Total number of failed tasks.
	pub tasks_failed: u64,
	/// Running mean execution time over every recorded task, successful or
	/// failed (both `record_success` and `record_failure` feed it).
	pub average_execution_time: Duration,
}

impl WorkerMetrics {
	/// Nanoseconds per second, used to split a `u128` nanosecond count back
	/// into a `Duration`.
	const NANOS_PER_SEC: u128 = 1_000_000_000;

	/// Create new metrics with default values
	///
	/// # Examples
	///
	/// ```rust
	/// use reinhardt_tasks::WorkerMetrics;
	///
	/// let metrics = WorkerMetrics::new();
	/// assert_eq!(metrics.tasks_completed, 0);
	/// ```
	pub fn new() -> Self {
		Self {
			tasks_completed: 0,
			tasks_failed: 0,
			average_execution_time: Duration::from_secs(0),
		}
	}

	/// Create metrics with specific values
	pub fn with_values(
		tasks_completed: u64,
		tasks_failed: u64,
		average_execution_time: Duration,
	) -> Self {
		Self {
			tasks_completed,
			tasks_failed,
			average_execution_time,
		}
	}

	/// Fold `duration` into `average_execution_time`.
	///
	/// The mean is kept at nanosecond precision so that short tasks are not
	/// truncated away (averaging in whole milliseconds would turn every
	/// sub-millisecond sample into zero). Must be called *before* the success or
	/// failure counter is incremented, because the current counters are taken as
	/// the number of samples already in the mean.
	fn update_execution_time(&mut self, duration: Duration) {
		let samples = self.tasks_completed.saturating_add(self.tasks_failed);
		if samples == 0 {
			self.average_execution_time = duration;
			return;
		}

		let count = u128::from(samples);
		let current = self.average_execution_time.as_nanos();
		let sample = duration.as_nanos();
		// `count + 1` is at most 2^64, so it cannot overflow a u128.
		let next = match current
			.checked_mul(count)
			.and_then(|total| total.checked_add(sample))
		{
			// Exact mean of all samples: (mean * n + sample) / (n + 1).
			Some(total) => total / (count + 1),
			// The running total no longer fits in a u128 (only possible with
			// astronomically large counts and durations). Step the mean towards
			// the sample instead. The result stays between `current` and
			// `sample`, so it cannot overflow either.
			None if sample >= current => current + (sample - current) / (count + 1),
			None => current - (current - sample) / (count + 1),
		};
		self.average_execution_time = Self::duration_from_nanos(next);
	}

	/// Convert a `u128` nanosecond count into a `Duration`, saturating at
	/// `Duration::MAX` if the seconds part does not fit in a `u64`.
	fn duration_from_nanos(nanos: u128) -> Duration {
		let secs = nanos / Self::NANOS_PER_SEC;
		// The remainder is below 1_000_000_000, so it always fits in a u32.
		let subsec_nanos = (nanos % Self::NANOS_PER_SEC) as u32;
		match u64::try_from(secs) {
			Ok(secs) => Duration::new(secs, subsec_nanos),
			Err(_) => Duration::MAX,
		}
	}

	/// Record a successful task completion.
	///
	/// The counter saturates at `u64::MAX` instead of overflowing.
	pub fn record_success(&mut self, duration: Duration) {
		self.update_execution_time(duration);
		self.tasks_completed = self.tasks_completed.saturating_add(1);
	}

	/// Record a failed task.
	///
	/// The counter saturates at `u64::MAX` instead of overflowing.
	pub fn record_failure(&mut self, duration: Duration) {
		self.update_execution_time(duration);
		self.tasks_failed = self.tasks_failed.saturating_add(1);
	}
}

impl Default for WorkerMetrics {
	fn default() -> Self {
		Self::new()
	}
}
199
200/// Load balancer for distributing tasks across workers
201///
202/// # Examples
203///
204/// ```rust
205/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
206///
207/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
208/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
209/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
210/// balancer.register_worker(WorkerInfo::new("worker-2".to_string(), 1)).await?;
211///
212/// // Select worker for task
213/// let worker_id = balancer.select_worker().await?;
214/// println!("Selected worker: {}", worker_id);
215/// # Ok(())
216/// # }
217/// ```
218pub struct LoadBalancer {
219	strategy: LoadBalancingStrategy,
220	workers: Arc<RwLock<Vec<Arc<WorkerInfo>>>>,
221	metrics: Arc<RwLock<HashMap<WorkerId, WorkerMetrics>>>,
222	round_robin_index: Arc<AtomicUsize>,
223}
224
225impl LoadBalancer {
226	/// Create a new load balancer with the specified strategy
227	///
228	/// # Examples
229	///
230	/// ```rust
231	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy};
232	///
233	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);
234	/// ```
235	pub fn new(strategy: LoadBalancingStrategy) -> Self {
236		Self {
237			strategy,
238			workers: Arc::new(RwLock::new(Vec::new())),
239			metrics: Arc::new(RwLock::new(HashMap::new())),
240			round_robin_index: Arc::new(AtomicUsize::new(0)),
241		}
242	}
243
244	/// Register a new worker
245	///
246	/// # Examples
247	///
248	/// ```rust
249	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
250	///
251	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
252	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
253	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
254	/// # Ok(())
255	/// # }
256	/// ```
257	pub async fn register_worker(&self, worker: WorkerInfo) -> TaskResult<()> {
258		let worker_id = worker.id.clone();
259		self.workers.write().await.push(Arc::new(worker));
260		self.metrics
261			.write()
262			.await
263			.insert(worker_id, WorkerMetrics::new());
264		Ok(())
265	}
266
267	/// Unregister a worker
268	///
269	/// # Examples
270	///
271	/// ```rust
272	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
273	///
274	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
275	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
276	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
277	/// balancer.unregister_worker("worker-1").await?;
278	/// # Ok(())
279	/// # }
280	/// ```
281	pub async fn unregister_worker(&self, worker_id: &str) -> TaskResult<()> {
282		self.workers.write().await.retain(|w| w.id != worker_id);
283		self.metrics.write().await.remove(worker_id);
284		Ok(())
285	}
286
287	/// Select a worker based on the load balancing strategy
288	///
289	/// # Examples
290	///
291	/// ```rust
292	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
293	///
294	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
295	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
296	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
297	///
298	/// let worker_id = balancer.select_worker().await?;
299	/// assert_eq!(worker_id, "worker-1");
300	/// # Ok(())
301	/// # }
302	/// ```
303	pub async fn select_worker(&self) -> TaskResult<WorkerId> {
304		let workers = self.workers.read().await;
305		if workers.is_empty() {
306			return Err(TaskError::QueueError("No workers available".to_string()));
307		}
308
309		let selected = match &self.strategy {
310			LoadBalancingStrategy::RoundRobin => self.select_round_robin(&workers),
311			LoadBalancingStrategy::LeastConnections => self.select_least_connections(&workers),
312			LoadBalancingStrategy::Weighted(weights) => self.select_weighted(&workers, weights),
313			LoadBalancingStrategy::Random => self.select_random(&workers),
314		};
315
316		selected.increment_tasks();
317		Ok(selected.id.clone())
318	}
319
320	/// Round-robin selection
321	fn select_round_robin(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
322		let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst) % workers.len();
323		workers[index].clone()
324	}
325
326	/// Least connections selection
327	fn select_least_connections(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
328		workers
329			.iter()
330			.min_by_key(|w| w.active_task_count())
331			.unwrap()
332			.clone()
333	}
334
335	/// Weighted selection
336	fn select_weighted(
337		&self,
338		workers: &[Arc<WorkerInfo>],
339		weights: &HashMap<WorkerId, u32>,
340	) -> Arc<WorkerInfo> {
341		use rand::Rng;
342		let total_weight: u32 = workers
343			.iter()
344			.map(|w| weights.get(&w.id).copied().unwrap_or(w.weight))
345			.sum();
346
347		// Guard against zero total weight to prevent panic in random_range(0..0)
348		if total_weight == 0 {
349			return workers[0].clone();
350		}
351
352		let mut rng = rand::rng();
353		let mut random = rng.random_range(0..total_weight);
354		for worker in workers {
355			let weight = weights.get(&worker.id).copied().unwrap_or(worker.weight);
356			if random < weight {
357				return worker.clone();
358			}
359			random -= weight;
360		}
361
362		workers[0].clone()
363	}
364
365	/// Random selection
366	fn select_random(&self, workers: &[Arc<WorkerInfo>]) -> Arc<WorkerInfo> {
367		use rand::Rng;
368		let mut rng = rand::rng();
369		let index = rng.random_range(0..workers.len());
370		workers[index].clone()
371	}
372
373	/// Mark a task as completed on a worker
374	///
375	/// # Examples
376	///
377	/// ```rust
378	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
379	///
380	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
381	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
382	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
383	///
384	/// let worker_id = balancer.select_worker().await?;
385	/// balancer.task_completed(&worker_id).await?;
386	/// # Ok(())
387	/// # }
388	/// ```
389	pub async fn task_completed(&self, worker_id: &str) -> TaskResult<()> {
390		let workers = self.workers.read().await;
391		if let Some(worker) = workers.iter().find(|w| w.id == worker_id) {
392			worker.decrement_tasks();
393		}
394		Ok(())
395	}
396
397	/// Update worker metrics
398	///
399	/// # Examples
400	///
401	/// ```rust
402	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo, WorkerMetrics};
403	/// use std::time::Duration;
404	///
405	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
406	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
407	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
408	///
409	/// let metrics = WorkerMetrics::with_values(10, 1, Duration::from_millis(500));
410	/// balancer.update_metrics("worker-1", metrics).await?;
411	/// # Ok(())
412	/// # }
413	/// ```
414	pub async fn update_metrics(&self, worker_id: &str, metrics: WorkerMetrics) -> TaskResult<()> {
415		self.metrics
416			.write()
417			.await
418			.insert(worker_id.to_string(), metrics);
419		Ok(())
420	}
421
422	/// Get statistics for all workers
423	///
424	/// # Examples
425	///
426	/// ```rust
427	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
428	///
429	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
430	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
431	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
432	///
433	/// let stats = balancer.get_worker_stats().await;
434	/// assert_eq!(stats.len(), 1);
435	/// # Ok(())
436	/// # }
437	/// ```
438	pub async fn get_worker_stats(&self) -> HashMap<WorkerId, WorkerMetrics> {
439		self.metrics.read().await.clone()
440	}
441
442	/// Get active worker count
443	///
444	/// # Examples
445	///
446	/// ```rust
447	/// use reinhardt_tasks::{LoadBalancer, LoadBalancingStrategy, WorkerInfo};
448	///
449	/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
450	/// let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
451	/// balancer.register_worker(WorkerInfo::new("worker-1".to_string(), 1)).await?;
452	/// balancer.register_worker(WorkerInfo::new("worker-2".to_string(), 1)).await?;
453	///
454	/// assert_eq!(balancer.worker_count().await, 2);
455	/// # Ok(())
456	/// # }
457	/// ```
458	pub async fn worker_count(&self) -> usize {
459		self.workers.read().await.len()
460	}
461}
462
#[cfg(test)]
mod tests {
	use super::*;
	use rstest::rstest;
	use std::time::Duration;

	#[rstest]
	#[tokio::test]
	async fn test_worker_info_creation() {
		// Arrange
		let worker = WorkerInfo::new("worker-1".to_string(), 2);

		// Assert
		assert_eq!(worker.id, "worker-1");
		assert_eq!(worker.weight, 2);
		assert_eq!(worker.active_task_count(), 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_worker_info_task_count() {
		// Arrange
		let worker = WorkerInfo::new("worker-1".to_string(), 1);

		// Act & Assert
		worker.increment_tasks();
		assert_eq!(worker.active_task_count(), 1);
		worker.increment_tasks();
		assert_eq!(worker.active_task_count(), 2);
		worker.decrement_tasks();
		assert_eq!(worker.active_task_count(), 1);
	}

	#[rstest]
	#[tokio::test]
	async fn test_worker_metrics_creation() {
		// Arrange
		let metrics = WorkerMetrics::new();

		// Assert
		assert_eq!(metrics.tasks_completed, 0);
		assert_eq!(metrics.tasks_failed, 0);
		assert_eq!(metrics.average_execution_time, Duration::from_secs(0));
	}

	#[rstest]
	#[tokio::test]
	async fn test_worker_metrics_record_success() {
		// Arrange
		let mut metrics = WorkerMetrics::new();

		// Act
		metrics.record_success(Duration::from_millis(100));

		// Assert
		assert_eq!(metrics.tasks_completed, 1);
		assert_eq!(metrics.average_execution_time, Duration::from_millis(100));

		// Act
		metrics.record_success(Duration::from_millis(200));

		// Assert
		assert_eq!(metrics.tasks_completed, 2);
		assert_eq!(metrics.average_execution_time, Duration::from_millis(150));
	}

	#[rstest]
	#[tokio::test]
	async fn test_worker_metrics_record_failure() {
		// Arrange
		let mut metrics = WorkerMetrics::new();

		// Act
		metrics.record_failure(Duration::from_millis(50));

		// Assert
		assert_eq!(metrics.tasks_failed, 1);
		assert_eq!(metrics.average_execution_time, Duration::from_millis(50));
	}

	#[rstest]
	#[tokio::test]
	async fn test_load_balancer_creation() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

		// Assert
		assert_eq!(balancer.worker_count().await, 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_load_balancer_register_worker() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

		// Act
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();

		// Assert
		assert_eq!(balancer.worker_count().await, 1);
	}

	#[rstest]
	#[tokio::test]
	async fn test_load_balancer_unregister_worker() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();

		// Act
		balancer.unregister_worker("worker-1").await.unwrap();

		// Assert
		assert_eq!(balancer.worker_count().await, 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_round_robin_strategy() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();
		balancer
			.register_worker(WorkerInfo::new("worker-2".to_string(), 1))
			.await
			.unwrap();

		// Act
		let worker1 = balancer.select_worker().await.unwrap();
		let worker2 = balancer.select_worker().await.unwrap();
		let worker3 = balancer.select_worker().await.unwrap();

		// Assert
		assert_eq!(worker1, "worker-1");
		assert_eq!(worker2, "worker-2");
		assert_eq!(worker3, "worker-1");
	}

	#[rstest]
	#[tokio::test]
	async fn test_least_connections_strategy() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::LeastConnections);
		let worker1 = WorkerInfo::new("worker-1".to_string(), 1);
		let worker2 = WorkerInfo::new("worker-2".to_string(), 1);

		// Simulate worker-1 having more tasks
		worker1.increment_tasks();
		worker1.increment_tasks();

		balancer.register_worker(worker1).await.unwrap();
		balancer.register_worker(worker2).await.unwrap();

		// Act - should select worker-2 as it has fewer tasks
		let selected = balancer.select_worker().await.unwrap();

		// Assert
		assert_eq!(selected, "worker-2");
	}

	#[rstest]
	#[tokio::test]
	async fn test_weighted_strategy() {
		// Arrange
		let mut weights = HashMap::new();
		weights.insert("worker-1".to_string(), 3);
		weights.insert("worker-2".to_string(), 1);

		let balancer = LoadBalancer::new(LoadBalancingStrategy::Weighted(weights));
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 3))
			.await
			.unwrap();
		balancer
			.register_worker(WorkerInfo::new("worker-2".to_string(), 1))
			.await
			.unwrap();

		// Act - run multiple selections
		let mut worker1_count = 0;
		let mut worker2_count = 0;

		for _ in 0..100 {
			let selected = balancer.select_worker().await.unwrap();
			balancer.task_completed(&selected).await.unwrap();
			if selected == "worker-1" {
				worker1_count += 1;
			} else {
				worker2_count += 1;
			}
		}

		// Assert - worker-1 should be selected approximately 3x more often
		assert!(worker1_count > worker2_count);
	}

	#[rstest]
	#[tokio::test]
	async fn test_random_strategy() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::Random);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();
		balancer
			.register_worker(WorkerInfo::new("worker-2".to_string(), 1))
			.await
			.unwrap();

		// Act
		let worker = balancer.select_worker().await.unwrap();

		// Assert
		assert!(worker == "worker-1" || worker == "worker-2");
	}

	#[rstest]
	#[tokio::test]
	async fn test_select_worker_no_workers() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

		// Act
		let result = balancer.select_worker().await;

		// Assert
		assert!(result.is_err());
	}

	#[rstest]
	#[tokio::test]
	async fn test_task_completed() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();

		let worker_id = balancer.select_worker().await.unwrap();
		let workers = balancer.workers.read().await;
		let worker = workers.iter().find(|w| w.id == worker_id).unwrap();
		assert_eq!(worker.active_task_count(), 1);
		drop(workers);

		// Act
		balancer.task_completed(&worker_id).await.unwrap();

		// Assert
		let workers = balancer.workers.read().await;
		let worker = workers.iter().find(|w| w.id == worker_id).unwrap();
		assert_eq!(worker.active_task_count(), 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_update_metrics() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();

		let metrics = WorkerMetrics::with_values(10, 2, Duration::from_millis(500));

		// Act
		balancer
			.update_metrics("worker-1", metrics.clone())
			.await
			.unwrap();

		// Assert
		let stats = balancer.get_worker_stats().await;
		let worker_metrics = stats.get("worker-1").unwrap();
		assert_eq!(worker_metrics.tasks_completed, 10);
		assert_eq!(worker_metrics.tasks_failed, 2);
		assert_eq!(
			worker_metrics.average_execution_time,
			Duration::from_millis(500)
		);
	}

	#[rstest]
	#[tokio::test]
	async fn test_get_worker_stats() {
		// Arrange
		let balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 1))
			.await
			.unwrap();
		balancer
			.register_worker(WorkerInfo::new("worker-2".to_string(), 1))
			.await
			.unwrap();

		// Act
		let stats = balancer.get_worker_stats().await;

		// Assert
		assert_eq!(stats.len(), 2);
		assert!(stats.contains_key("worker-1"));
		assert!(stats.contains_key("worker-2"));
	}

	#[rstest]
	#[tokio::test]
	async fn test_decrement_tasks_at_zero_does_not_underflow() {
		// Arrange
		let worker = WorkerInfo::new("worker-1".to_string(), 1);
		assert_eq!(worker.active_task_count(), 0);

		// Act - decrement at 0 should saturate, not wrap to usize::MAX
		worker.decrement_tasks();

		// Assert
		assert_eq!(worker.active_task_count(), 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_decrement_tasks_multiple_times_at_zero_stays_at_zero() {
		// Arrange
		let worker = WorkerInfo::new("worker-1".to_string(), 1);
		worker.increment_tasks();
		worker.decrement_tasks();
		assert_eq!(worker.active_task_count(), 0);

		// Act - multiple decrements below zero should all saturate at 0
		worker.decrement_tasks();
		worker.decrement_tasks();
		worker.decrement_tasks();

		// Assert
		assert_eq!(worker.active_task_count(), 0);
	}

	#[rstest]
	#[tokio::test]
	async fn test_weighted_strategy_zero_total_weight_does_not_panic() {
		// Arrange
		let mut weights = HashMap::new();
		weights.insert("worker-1".to_string(), 0);
		weights.insert("worker-2".to_string(), 0);

		let balancer = LoadBalancer::new(LoadBalancingStrategy::Weighted(weights));
		balancer
			.register_worker(WorkerInfo::new("worker-1".to_string(), 0))
			.await
			.unwrap();
		balancer
			.register_worker(WorkerInfo::new("worker-2".to_string(), 0))
			.await
			.unwrap();

		// Act - should not panic, returns first worker as fallback
		let selected = balancer.select_worker().await.unwrap();

		// Assert
		assert!(selected == "worker-1" || selected == "worker-2");
	}

	#[rstest]
	#[tokio::test]
	async fn test_update_execution_time_does_not_overflow() {
		// Arrange
		let mut metrics = WorkerMetrics::new();
		metrics.tasks_completed = u64::MAX - 1;
		metrics.average_execution_time = Duration::from_millis(u64::MAX);

		// Act - averaging over ~u64::MAX samples of a huge mean must not
		// overflow or panic
		metrics.record_success(Duration::from_millis(1000));

		// Assert - u64::MAX - 1 + 1 reaches u64::MAX exactly (no wrap-around),
		// and reaching this line proves the average computation did not panic
		assert_eq!(metrics.tasks_completed, u64::MAX);
	}

	#[rstest]
	#[tokio::test]
	async fn test_update_execution_time_saturates_at_u64_max() {
		// Arrange
		let mut metrics = WorkerMetrics::new();
		metrics.tasks_completed = 1;
		metrics.average_execution_time = Duration::from_millis(u64::MAX);

		// Act - saturating arithmetic should clamp instead of overflowing
		metrics.record_success(Duration::from_millis(u64::MAX));

		// Assert - the result should be clamped to u64::MAX milliseconds
		assert_eq!(
			metrics.average_execution_time,
			Duration::from_millis(u64::MAX)
		);
	}
}