1use crate::error::{NetError, NetResult};
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10
/// Identifier of a backend endpoint.
///
/// Used as the `id` of an [`Endpoint`] and as the value stored per session
/// in the affinity table.
pub type EndpointId = String;

/// Relative share of traffic an endpoint receives under the weighted
/// balancing strategy.
pub type Weight = u32;
16
17#[derive(Debug, Clone)]
19pub struct Endpoint {
20 pub id: EndpointId,
22 pub address: String,
24 pub weight: Weight,
26 pub active_connections: Arc<AtomicUsize>,
28 pub total_requests: Arc<AtomicU64>,
30 pub healthy: Arc<parking_lot::RwLock<bool>>,
32}
33
34impl Endpoint {
35 pub fn new(id: EndpointId, address: String) -> Self {
37 Self::with_weight(id, address, 1)
38 }
39
40 pub fn with_weight(id: EndpointId, address: String, weight: Weight) -> Self {
42 Self {
43 id,
44 address,
45 weight,
46 active_connections: Arc::new(AtomicUsize::new(0)),
47 total_requests: Arc::new(AtomicU64::new(0)),
48 healthy: Arc::new(parking_lot::RwLock::new(true)),
49 }
50 }
51
52 pub fn is_healthy(&self) -> bool {
54 *self.healthy.read()
55 }
56
57 pub fn mark_healthy(&self) {
59 *self.healthy.write() = true;
60 }
61
62 pub fn mark_unhealthy(&self) {
64 *self.healthy.write() = false;
65 }
66
67 pub fn active_connections(&self) -> usize {
69 self.active_connections.load(Ordering::Relaxed)
70 }
71
72 pub fn increment_connections(&self) {
74 self.active_connections.fetch_add(1, Ordering::Relaxed);
75 self.total_requests.fetch_add(1, Ordering::Relaxed);
76 }
77
78 pub fn decrement_connections(&self) {
80 self.active_connections.fetch_sub(1, Ordering::Relaxed);
81 }
82
83 pub fn total_requests(&self) -> u64 {
85 self.total_requests.load(Ordering::Relaxed)
86 }
87}
88
/// Algorithm a load balancer uses to choose among healthy endpoints.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BalancingStrategy {
    /// Cycle through the endpoints in order.
    RoundRobin,
    /// Prefer the endpoint with the fewest active connections.
    LeastConnections,
    /// Distribute selections in proportion to each endpoint's weight.
    Weighted,
}
99
100#[derive(Debug, Clone)]
102pub struct Affinity {
103 sessions: Arc<RwLock<HashMap<String, EndpointId>>>,
105}
106
107impl Affinity {
108 pub fn new() -> Self {
110 Self {
111 sessions: Arc::new(RwLock::new(HashMap::new())),
112 }
113 }
114
115 pub fn get(&self, session_id: &str) -> Option<EndpointId> {
117 self.sessions.read().get(session_id).cloned()
118 }
119
120 pub fn set(&self, session_id: String, endpoint_id: EndpointId) {
122 self.sessions.write().insert(session_id, endpoint_id);
123 }
124
125 pub fn remove(&self, session_id: &str) {
127 self.sessions.write().remove(session_id);
128 }
129
130 pub fn clear(&self) {
132 self.sessions.write().clear();
133 }
134}
135
136impl Default for Affinity {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142#[derive(Debug)]
144pub struct LoadBalancer {
145 strategy: BalancingStrategy,
147 endpoints: Arc<RwLock<Vec<Arc<Endpoint>>>>,
149 round_robin_index: AtomicUsize,
151 affinity: Affinity,
153}
154
155impl LoadBalancer {
156 pub fn new(strategy: BalancingStrategy) -> Self {
158 Self {
159 strategy,
160 endpoints: Arc::new(RwLock::new(Vec::new())),
161 round_robin_index: AtomicUsize::new(0),
162 affinity: Affinity::new(),
163 }
164 }
165
166 pub fn add_endpoint(&self, endpoint: Endpoint) {
168 self.endpoints.write().push(Arc::new(endpoint));
169 }
170
171 pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
173 let mut endpoints = self.endpoints.write();
174 if let Some(pos) = endpoints.iter().position(|e| e.id == endpoint_id) {
175 endpoints.remove(pos);
176 true
177 } else {
178 false
179 }
180 }
181
182 pub fn endpoints(&self) -> Vec<Arc<Endpoint>> {
184 self.endpoints.read().clone()
185 }
186
187 pub fn healthy_endpoints(&self) -> Vec<Arc<Endpoint>> {
189 self.endpoints
190 .read()
191 .iter()
192 .filter(|e| e.is_healthy())
193 .cloned()
194 .collect()
195 }
196
197 pub fn select_endpoint(&self) -> NetResult<Arc<Endpoint>> {
199 let healthy_endpoints = self.healthy_endpoints();
200
201 if healthy_endpoints.is_empty() {
202 return Err(NetError::ServerUnavailable(
203 "No healthy endpoints available".to_string(),
204 ));
205 }
206
207 match self.strategy {
208 BalancingStrategy::RoundRobin => self.select_round_robin(&healthy_endpoints),
209 BalancingStrategy::LeastConnections => {
210 self.select_least_connections(&healthy_endpoints)
211 }
212 BalancingStrategy::Weighted => self.select_weighted(&healthy_endpoints),
213 }
214 }
215
216 pub fn select_with_affinity(&self, session_id: &str) -> NetResult<Arc<Endpoint>> {
218 if let Some(endpoint_id) = self.affinity.get(session_id) {
220 if let Some(endpoint) = self
222 .healthy_endpoints()
223 .iter()
224 .find(|e| e.id == endpoint_id)
225 {
226 return Ok(Arc::clone(endpoint));
227 }
228 }
229
230 let endpoint = self.select_endpoint()?;
232 self.affinity
233 .set(session_id.to_string(), endpoint.id.clone());
234 Ok(endpoint)
235 }
236
237 pub fn clear_affinity(&self, session_id: &str) {
239 self.affinity.remove(session_id);
240 }
241
242 pub fn stats(&self) -> BalancerStats {
244 let endpoints = self.endpoints.read();
245 let total_endpoints = endpoints.len();
246 let healthy_endpoints = endpoints.iter().filter(|e| e.is_healthy()).count();
247 let total_connections: usize = endpoints.iter().map(|e| e.active_connections()).sum();
248 let total_requests: u64 = endpoints.iter().map(|e| e.total_requests()).sum();
249
250 BalancerStats {
251 total_endpoints,
252 healthy_endpoints,
253 total_connections,
254 total_requests,
255 strategy: self.strategy,
256 }
257 }
258
259 fn select_round_robin(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
261 if endpoints.is_empty() {
262 return Err(NetError::ServerUnavailable(
263 "No endpoints available".to_string(),
264 ));
265 }
266
267 let index = self.round_robin_index.fetch_add(1, Ordering::Relaxed);
268 let endpoint = &endpoints[index % endpoints.len()];
269 Ok(Arc::clone(endpoint))
270 }
271
272 fn select_least_connections(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
274 endpoints
275 .iter()
276 .min_by_key(|e| e.active_connections())
277 .map(Arc::clone)
278 .ok_or_else(|| NetError::ServerUnavailable("No endpoints available".to_string()))
279 }
280
281 fn select_weighted(&self, endpoints: &[Arc<Endpoint>]) -> NetResult<Arc<Endpoint>> {
283 if endpoints.is_empty() {
284 return Err(NetError::ServerUnavailable(
285 "No endpoints available".to_string(),
286 ));
287 }
288
289 let total_weight: u32 = endpoints.iter().map(|e| e.weight).sum();
291
292 if total_weight == 0 {
293 return self.select_round_robin(endpoints);
295 }
296
297 let selector = self.round_robin_index.fetch_add(1, Ordering::Relaxed) as u32;
299 let target = selector % total_weight;
300
301 let mut cumulative = 0u32;
303 for endpoint in endpoints {
304 cumulative += endpoint.weight;
305 if target < cumulative {
306 return Ok(Arc::clone(endpoint));
307 }
308 }
309
310 Ok(Arc::clone(&endpoints[endpoints.len() - 1]))
312 }
313}
314
315#[derive(Debug, Clone)]
317pub struct BalancerStats {
318 pub total_endpoints: usize,
320 pub healthy_endpoints: usize,
322 pub total_connections: usize,
324 pub total_requests: u64,
326 pub strategy: BalancingStrategy,
328}
329
330pub struct ConnectionGuard {
332 endpoint: Arc<Endpoint>,
333}
334
335impl ConnectionGuard {
336 pub fn new(endpoint: Arc<Endpoint>) -> Self {
338 endpoint.increment_connections();
339 Self { endpoint }
340 }
341
342 pub fn endpoint(&self) -> &Arc<Endpoint> {
344 &self.endpoint
345 }
346}
347
348impl Drop for ConnectionGuard {
349 fn drop(&mut self) {
350 self.endpoint.decrement_connections();
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a weight-1 endpoint whose address is `localhost:<port>`.
    fn local(id: &str, port: u16) -> Endpoint {
        Endpoint::new(id.to_string(), format!("localhost:{}", port))
    }

    /// Creates a balancer using `strategy` and registers `endpoints` in order.
    fn balancer_with(strategy: BalancingStrategy, endpoints: Vec<Endpoint>) -> LoadBalancer {
        let lb = LoadBalancer::new(strategy);
        for endpoint in endpoints {
            lb.add_endpoint(endpoint);
        }
        lb
    }

    /// Selects the next endpoint, failing the test if none is available.
    fn pick(lb: &LoadBalancer) -> Arc<Endpoint> {
        lb.select_endpoint().expect("should select endpoint")
    }

    /// Selects an endpoint for `session_id`, failing the test if none is available.
    fn pick_for(lb: &LoadBalancer, session_id: &str) -> Arc<Endpoint> {
        lb.select_with_affinity(session_id)
            .expect("should select endpoint")
    }

    #[test]
    fn test_endpoint_creation() {
        let endpoint = local("ep1", 50051);

        assert_eq!(endpoint.id, "ep1");
        assert_eq!(endpoint.address, "localhost:50051");
        assert_eq!(endpoint.weight, 1);
        assert!(endpoint.is_healthy());
    }

    #[test]
    fn test_endpoint_health() {
        let endpoint = local("ep1", 50051);
        assert!(endpoint.is_healthy());

        endpoint.mark_unhealthy();
        assert!(!endpoint.is_healthy());

        endpoint.mark_healthy();
        assert!(endpoint.is_healthy());
    }

    #[test]
    fn test_endpoint_connections() {
        let endpoint = local("ep1", 50051);
        assert_eq!(endpoint.active_connections(), 0);

        endpoint.increment_connections();
        assert_eq!(endpoint.active_connections(), 1);

        endpoint.increment_connections();
        assert_eq!(endpoint.active_connections(), 2);

        endpoint.decrement_connections();
        assert_eq!(endpoint.active_connections(), 1);
    }

    #[test]
    fn test_load_balancer_round_robin() {
        let lb = balancer_with(
            BalancingStrategy::RoundRobin,
            vec![
                local("ep1", 50051),
                local("ep2", 50052),
                local("ep3", 50053),
            ],
        );

        // Four picks over three endpoints: the fourth wraps back to the first.
        let order: Vec<EndpointId> = (0..4).map(|_| pick(&lb).id.clone()).collect();

        assert_eq!(order, ["ep1", "ep2", "ep3", "ep1"]);
    }

    #[test]
    fn test_load_balancer_least_connections() {
        let lb = balancer_with(
            BalancingStrategy::LeastConnections,
            vec![local("ep1", 50051), local("ep2", 50052)],
        );

        // Load the first pick so the other endpoint becomes the least loaded.
        let first = pick(&lb);
        first.increment_connections();

        let second = pick(&lb);
        assert_eq!(second.id, "ep2");

        // Two connections on ep2 versus one on ep1 tips the balance back.
        second.increment_connections();
        second.increment_connections();

        let third = pick(&lb);
        assert_eq!(third.id, "ep1");
    }

    #[test]
    fn test_load_balancer_weighted() {
        let lb = balancer_with(
            BalancingStrategy::Weighted,
            vec![
                Endpoint::with_weight("ep1".to_string(), "localhost:50051".to_string(), 3),
                Endpoint::with_weight("ep2".to_string(), "localhost:50052".to_string(), 1),
            ],
        );

        let picks: Vec<EndpointId> = (0..40).map(|_| pick(&lb).id.clone()).collect();
        let ep1_count = picks.iter().filter(|id| id.as_str() == "ep1").count();
        let ep2_count = picks.iter().filter(|id| id.as_str() == "ep2").count();

        assert!(ep1_count > ep2_count);
        assert!(ep1_count >= 20);
    }

    #[test]
    fn test_load_balancer_no_endpoints() {
        let lb = LoadBalancer::new(BalancingStrategy::RoundRobin);

        assert!(lb.select_endpoint().is_err());
    }

    #[test]
    fn test_load_balancer_unhealthy_endpoints() {
        let sick = local("ep1", 50051);
        sick.mark_unhealthy();

        let lb = balancer_with(
            BalancingStrategy::RoundRobin,
            vec![sick, local("ep2", 50052)],
        );

        // Only the healthy endpoint may ever be returned.
        for _ in 0..5 {
            assert_eq!(pick(&lb).id, "ep2");
        }
    }

    #[test]
    fn test_load_balancer_affinity() {
        let lb = balancer_with(
            BalancingStrategy::RoundRobin,
            vec![local("ep1", 50051), local("ep2", 50052)],
        );
        let session_id = "session123";

        let first = pick_for(&lb, session_id);
        let second = pick_for(&lb, session_id);
        let third = pick_for(&lb, session_id);

        assert_eq!(first.id, second.id);
        assert_eq!(second.id, third.id);

        // After the pin is cleared, selection must still succeed.
        lb.clear_affinity(session_id);
        let _fourth = pick_for(&lb, session_id);
    }

    #[test]
    fn test_load_balancer_remove_endpoint() {
        let lb = balancer_with(
            BalancingStrategy::RoundRobin,
            vec![local("ep1", 50051), local("ep2", 50052)],
        );
        assert_eq!(lb.endpoints().len(), 2);

        lb.remove_endpoint("ep1");
        assert_eq!(lb.endpoints().len(), 1);

        assert_eq!(pick(&lb).id, "ep2");
    }

    #[test]
    fn test_load_balancer_stats() {
        let lb = balancer_with(
            BalancingStrategy::LeastConnections,
            vec![local("ep1", 50051), local("ep2", 50052)],
        );

        let stats = lb.stats();

        assert_eq!(stats.total_endpoints, 2);
        assert_eq!(stats.healthy_endpoints, 2);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.strategy, BalancingStrategy::LeastConnections);
    }

    #[test]
    fn test_connection_guard() {
        let endpoint = Arc::new(local("ep1", 50051));
        assert_eq!(endpoint.active_connections(), 0);

        let guard = ConnectionGuard::new(Arc::clone(&endpoint));
        assert_eq!(endpoint.active_connections(), 1);

        // Dropping the guard must release the connection it counted.
        drop(guard);
        assert_eq!(endpoint.active_connections(), 0);
    }

    #[test]
    fn test_affinity() {
        let affinity = Affinity::new();
        affinity.set("session1".to_string(), "ep1".to_string());
        affinity.set("session2".to_string(), "ep2".to_string());

        assert_eq!(affinity.get("session1"), Some("ep1".to_string()));
        assert_eq!(affinity.get("session2"), Some("ep2".to_string()));
        assert_eq!(affinity.get("session3"), None);

        affinity.remove("session1");
        assert_eq!(affinity.get("session1"), None);

        affinity.clear();
        assert_eq!(affinity.get("session2"), None);
    }
}