1use crate::balancer::{BalancingStrategy, EndpointId, LoadBalancer};
7use crate::circuit_breaker::CircuitBreaker;
8use crate::error::{NetError, NetResult};
9use async_trait::async_trait;
10use parking_lot::RwLock;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::time;
15use tonic::transport::{Channel, Endpoint};
16
17#[derive(Debug, Clone)]
19pub struct PoolConfig {
20 pub min_size: usize,
22 pub max_size: usize,
24 pub idle_timeout: Duration,
26 pub max_lifetime: Duration,
28 pub connect_timeout: Duration,
30 pub health_check_interval: Duration,
32 pub balancing_strategy: BalancingStrategy,
34 pub enable_circuit_breaker: bool,
36}
37
38impl Default for PoolConfig {
39 fn default() -> Self {
40 Self {
41 min_size: 2,
42 max_size: 10,
43 idle_timeout: Duration::from_secs(300), max_lifetime: Duration::from_secs(1800), connect_timeout: Duration::from_secs(10),
46 health_check_interval: Duration::from_secs(30),
47 balancing_strategy: BalancingStrategy::LeastConnections,
48 enable_circuit_breaker: true,
49 }
50 }
51}
52
/// Point-in-time counters describing a connection pool.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Idle plus active connections at the time the snapshot was taken.
    pub total_connections: usize,
    /// Connections currently checked out.
    pub active_connections: usize,
    /// Connections currently waiting in the idle queue.
    pub idle_connections: usize,
    /// Connection attempts that failed.
    pub failed_connections: u64,
    /// Connections successfully created.
    pub total_created: u64,
    /// Connections the pool has discarded (for example expired on return, or
    /// cleared at shutdown).
    pub total_closed: u64,
    /// How many times a caller found the pool at its size limit.
    pub pool_exhausted_count: u64,
    /// Smoothed wait time in milliseconds; each new sample is averaged with
    /// the previous value.
    pub avg_wait_time_ms: u64,
}
73
74#[derive(Debug)]
76struct ConnectionMeta {
77 channel: Channel,
79 endpoint_id: EndpointId,
81 created_at: Instant,
83 last_used: Instant,
85}
86
87impl ConnectionMeta {
88 fn new(channel: Channel, endpoint_id: EndpointId) -> Self {
90 let now = Instant::now();
91 Self {
92 channel,
93 endpoint_id,
94 created_at: now,
95 last_used: now,
96 }
97 }
98
99 fn is_idle_expired(&self, idle_timeout: Duration) -> bool {
101 self.last_used.elapsed() > idle_timeout
102 }
103
104 fn is_lifetime_expired(&self, max_lifetime: Duration) -> bool {
106 self.created_at.elapsed() > max_lifetime
107 }
108
109 fn touch(&mut self) {
111 self.last_used = Instant::now();
112 }
113}
114
115pub struct PooledConnection {
117 meta: Option<ConnectionMeta>,
118 pool: Arc<ConnectionPoolInner>,
119}
120
121impl PooledConnection {
122 pub fn channel(&self) -> &Channel {
124 &self.meta.as_ref().expect("connection should exist").channel
125 }
126
127 pub fn endpoint_id(&self) -> &str {
129 &self
130 .meta
131 .as_ref()
132 .expect("connection should exist")
133 .endpoint_id
134 }
135}
136
137impl Drop for PooledConnection {
138 fn drop(&mut self) {
139 if let Some(mut meta) = self.meta.take() {
140 meta.touch();
141 self.pool.return_connection(meta);
142 }
143 }
144}
145
146struct ConnectionPoolInner {
148 config: PoolConfig,
149 idle_connections: RwLock<VecDeque<ConnectionMeta>>,
150 active_count: std::sync::Mutex<usize>,
151 stats: RwLock<PoolStats>,
152 load_balancer: LoadBalancer,
153 circuit_breaker: Option<CircuitBreaker>,
154}
155
156impl ConnectionPoolInner {
157 fn return_connection(&self, meta: ConnectionMeta) {
159 if meta.is_idle_expired(self.config.idle_timeout)
161 || meta.is_lifetime_expired(self.config.max_lifetime)
162 {
163 self.stats.write().total_closed += 1;
165 let mut active = self
166 .active_count
167 .lock()
168 .expect("active count lock poisoned");
169 *active = active.saturating_sub(1);
170 return;
171 }
172
173 self.idle_connections.write().push_back(meta);
175 let mut active = self
176 .active_count
177 .lock()
178 .expect("active count lock poisoned");
179 *active = active.saturating_sub(1);
180 }
181
182 fn get_stats(&self) -> PoolStats {
184 let mut stats = self.stats.read().clone();
185 let idle = self.idle_connections.read().len();
186 let active = *self
187 .active_count
188 .lock()
189 .expect("active count lock poisoned");
190 stats.total_connections = idle + active;
191 stats.active_connections = active;
192 stats.idle_connections = idle;
193 stats
194 }
195}
196
197pub struct ConnectionPool {
199 inner: Arc<ConnectionPoolInner>,
200 shutdown_tx: tokio::sync::watch::Sender<bool>,
201}
202
203impl ConnectionPool {
204 pub fn new(config: PoolConfig) -> Self {
206 let load_balancer = LoadBalancer::new(config.balancing_strategy);
207 let circuit_breaker = if config.enable_circuit_breaker {
208 Some(CircuitBreaker::new())
209 } else {
210 None
211 };
212
213 let inner = Arc::new(ConnectionPoolInner {
214 config: config.clone(),
215 idle_connections: RwLock::new(VecDeque::new()),
216 active_count: std::sync::Mutex::new(0),
217 stats: RwLock::new(PoolStats::default()),
218 load_balancer,
219 circuit_breaker,
220 });
221
222 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
223
224 let health_check_inner = Arc::clone(&inner);
226 tokio::spawn(async move {
227 Self::health_check_loop(health_check_inner, shutdown_rx).await;
228 });
229
230 Self { inner, shutdown_tx }
231 }
232
233 pub fn add_endpoint(&self, id: EndpointId, address: String) {
235 self.add_endpoint_with_weight(id, address, 1);
236 }
237
238 pub fn add_endpoint_with_weight(&self, id: EndpointId, address: String, weight: u32) {
240 let endpoint = crate::balancer::Endpoint::with_weight(id, address, weight);
241 self.inner.load_balancer.add_endpoint(endpoint);
242 }
243
244 pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
246 let removed = self.inner.load_balancer.remove_endpoint(endpoint_id);
248
249 if removed {
251 let mut idle = self.inner.idle_connections.write();
252 idle.retain(|conn| conn.endpoint_id != endpoint_id);
253 }
254
255 removed
256 }
257
258 pub async fn get_connection(&self) -> NetResult<PooledConnection> {
260 let start = Instant::now();
261
262 if let Some(ref cb) = self.inner.circuit_breaker {
264 cb.is_request_allowed()?;
265 }
266
267 if let Some(mut meta) = self.inner.idle_connections.write().pop_front() {
269 meta.touch();
270 *self
271 .inner
272 .active_count
273 .lock()
274 .expect("active count lock poisoned") += 1;
275
276 return Ok(PooledConnection {
277 meta: Some(meta),
278 pool: Arc::clone(&self.inner),
279 });
280 }
281
282 let active = *self
284 .inner
285 .active_count
286 .lock()
287 .expect("active count lock poisoned");
288 let idle = self.inner.idle_connections.read().len();
289
290 if active + idle >= self.inner.config.max_size {
291 self.inner.stats.write().pool_exhausted_count += 1;
293
294 let timeout = Duration::from_secs(30);
296 let deadline = Instant::now() + timeout;
297
298 while Instant::now() < deadline {
299 if let Some(mut meta) = self.inner.idle_connections.write().pop_front() {
300 meta.touch();
301 *self
302 .inner
303 .active_count
304 .lock()
305 .expect("active count lock poisoned") += 1;
306
307 let wait_time = start.elapsed().as_millis() as u64;
309 let mut stats = self.inner.stats.write();
310 stats.avg_wait_time_ms = (stats.avg_wait_time_ms + wait_time) / 2;
311
312 return Ok(PooledConnection {
313 meta: Some(meta),
314 pool: Arc::clone(&self.inner),
315 });
316 }
317
318 time::sleep(Duration::from_millis(100)).await;
320 }
321
322 return Err(NetError::ServerOverloaded(
323 "Connection pool exhausted".to_string(),
324 ));
325 }
326
327 let meta = self.create_connection().await?;
329 *self
330 .inner
331 .active_count
332 .lock()
333 .expect("active count lock poisoned") += 1;
334
335 Ok(PooledConnection {
336 meta: Some(meta),
337 pool: Arc::clone(&self.inner),
338 })
339 }
340
341 async fn create_connection(&self) -> NetResult<ConnectionMeta> {
343 let endpoint = self.inner.load_balancer.select_endpoint()?;
345
346 let channel = Endpoint::from_shared(format!("http://{}", endpoint.address))
348 .map_err(|e| NetError::InvalidRequest(format!("Invalid endpoint: {}", e)))?
349 .connect_timeout(self.inner.config.connect_timeout)
350 .timeout(Duration::from_secs(30))
351 .connect()
352 .await
353 .map_err(|e| {
354 self.inner.stats.write().failed_connections += 1;
355 if let Some(ref cb) = self.inner.circuit_breaker {
356 cb.record_failure();
357 }
358 NetError::ConnectionRefused(format!("Failed to connect: {}", e))
359 })?;
360
361 if let Some(ref cb) = self.inner.circuit_breaker {
363 cb.record_success();
364 }
365
366 self.inner.stats.write().total_created += 1;
367
368 Ok(ConnectionMeta::new(channel, endpoint.id.clone()))
369 }
370
371 async fn health_check_loop(
373 inner: Arc<ConnectionPoolInner>,
374 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
375 ) {
376 let mut interval = time::interval(inner.config.health_check_interval);
377
378 loop {
379 tokio::select! {
380 _ = interval.tick() => {
381 Self::perform_health_check(&inner).await;
382 }
383 _ = shutdown_rx.changed() => {
384 if *shutdown_rx.borrow() {
385 break;
386 }
387 }
388 }
389 }
390 }
391
392 async fn perform_health_check(inner: &Arc<ConnectionPoolInner>) {
394 let needed = {
395 let mut idle = inner.idle_connections.write();
397 let config = &inner.config;
398
399 idle.retain(|conn| {
401 !conn.is_idle_expired(config.idle_timeout)
402 && !conn.is_lifetime_expired(config.max_lifetime)
403 });
404
405 let current_size = idle.len()
407 + *inner
408 .active_count
409 .lock()
410 .expect("active count lock poisoned");
411 config.min_size.saturating_sub(current_size)
412 }; for _ in 0..needed {
416 let _ = async {
419 }
421 .await;
422 }
423 }
424
425 pub fn stats(&self) -> PoolStats {
427 self.inner.get_stats()
428 }
429
430 pub fn circuit_breaker_stats(&self) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
432 self.inner.circuit_breaker.as_ref().map(|cb| cb.stats())
433 }
434
435 pub async fn shutdown(self) -> NetResult<()> {
437 self.shutdown_tx
439 .send(true)
440 .map_err(|_| NetError::ServerInternal("Failed to signal shutdown".to_string()))?;
441
442 time::sleep(Duration::from_millis(500)).await;
444
445 let mut idle = self.inner.idle_connections.write();
447 let count = idle.len();
448 idle.clear();
449
450 self.inner.stats.write().total_closed += count as u64;
451
452 Ok(())
453 }
454
455 pub async fn drain(&self) -> NetResult<()> {
457 let timeout = Duration::from_secs(30);
459 let deadline = Instant::now() + timeout;
460
461 while Instant::now() < deadline {
462 let active = *self
463 .inner
464 .active_count
465 .lock()
466 .expect("active count lock poisoned");
467 if active == 0 {
468 break;
469 }
470 time::sleep(Duration::from_millis(100)).await;
471 }
472
473 let active = *self
474 .inner
475 .active_count
476 .lock()
477 .expect("active count lock poisoned");
478 if active > 0 {
479 return Err(NetError::Timeout(format!(
480 "Drain timeout: {} active connections remaining",
481 active
482 )));
483 }
484
485 Ok(())
486 }
487}
488
489pub struct ConnectionPoolBuilder {
491 config: PoolConfig,
492 endpoints: Vec<(EndpointId, String, u32)>,
493}
494
495impl ConnectionPoolBuilder {
496 pub fn new() -> Self {
498 Self {
499 config: PoolConfig::default(),
500 endpoints: Vec::new(),
501 }
502 }
503
504 pub fn min_size(mut self, size: usize) -> Self {
506 self.config.min_size = size;
507 self
508 }
509
510 pub fn max_size(mut self, size: usize) -> Self {
512 self.config.max_size = size;
513 self
514 }
515
516 pub fn idle_timeout(mut self, timeout: Duration) -> Self {
518 self.config.idle_timeout = timeout;
519 self
520 }
521
522 pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
524 self.config.max_lifetime = lifetime;
525 self
526 }
527
528 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
530 self.config.connect_timeout = timeout;
531 self
532 }
533
534 pub fn health_check_interval(mut self, interval: Duration) -> Self {
536 self.config.health_check_interval = interval;
537 self
538 }
539
540 pub fn balancing_strategy(mut self, strategy: BalancingStrategy) -> Self {
542 self.config.balancing_strategy = strategy;
543 self
544 }
545
546 pub fn circuit_breaker(mut self, enabled: bool) -> Self {
548 self.config.enable_circuit_breaker = enabled;
549 self
550 }
551
552 pub fn add_endpoint(mut self, id: EndpointId, address: String) -> Self {
554 self.endpoints.push((id, address, 1));
555 self
556 }
557
558 pub fn add_endpoint_with_weight(
560 mut self,
561 id: EndpointId,
562 address: String,
563 weight: u32,
564 ) -> Self {
565 self.endpoints.push((id, address, weight));
566 self
567 }
568
569 pub fn build(self) -> ConnectionPool {
571 let pool = ConnectionPool::new(self.config);
572
573 for (id, address, weight) in self.endpoints {
574 pool.add_endpoint_with_weight(id, address, weight);
575 }
576
577 pool
578 }
579}
580
581impl Default for ConnectionPoolBuilder {
582 fn default() -> Self {
583 Self::new()
584 }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented size limits and has
    /// the circuit breaker enabled.
    #[test]
    fn test_pool_config_default() {
        let defaults = PoolConfig::default();
        assert_eq!(defaults.min_size, 2);
        assert_eq!(defaults.max_size, 10);
        assert!(defaults.enable_circuit_breaker);
    }

    /// A freshly created connection is neither idle-expired nor
    /// lifetime-expired. Only asserts when a server is reachable locally.
    #[tokio::test]
    async fn test_connection_meta_expiry() {
        let target = Endpoint::from_static("http://localhost:50051");
        if let Ok(channel) = target.connect().await {
            let fresh = ConnectionMeta::new(channel, "ep1".to_string());
            let window = Duration::from_secs(10);

            assert!(!fresh.is_idle_expired(window));
            assert!(!fresh.is_lifetime_expired(window));
        }
    }

    /// Building a pool through the builder opens no connections.
    #[tokio::test]
    async fn test_pool_builder() {
        let builder = ConnectionPoolBuilder::new()
            .min_size(5)
            .max_size(20)
            .idle_timeout(Duration::from_secs(600))
            .balancing_strategy(BalancingStrategy::RoundRobin)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .add_endpoint("ep2".to_string(), "localhost:50052".to_string());
        let pool = builder.build();

        let snapshot = pool.stats();
        assert_eq!(snapshot.active_connections, 0);
        assert_eq!(snapshot.idle_connections, 0);
    }

    /// Removing a registered endpoint succeeds; removing an unknown one does
    /// not.
    #[tokio::test]
    async fn test_pool_add_remove_endpoint() {
        let pool = ConnectionPool::new(PoolConfig::default());

        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        pool.add_endpoint("ep2".to_string(), "localhost:50052".to_string());

        let removed_known = pool.remove_endpoint("ep1");
        assert!(removed_known);
        let removed_unknown = pool.remove_endpoint("ep3");
        assert!(!removed_unknown);
    }

    /// A new pool reports zero connections of every kind.
    #[tokio::test]
    async fn test_pool_stats() {
        let pool = ConnectionPool::new(PoolConfig::default());
        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());

        let snapshot = pool.stats();
        assert_eq!(snapshot.total_connections, 0);
        assert_eq!(snapshot.active_connections, 0);
        assert_eq!(snapshot.idle_connections, 0);
    }

    /// Shutting down an unused pool succeeds.
    #[tokio::test]
    async fn test_pool_shutdown() {
        let pool = ConnectionPool::new(PoolConfig::default());
        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());

        let outcome = pool.shutdown().await;
        assert!(outcome.is_ok());
    }
}
659}