1use std::{
27 collections::HashMap,
28 fmt::Debug,
29 hash::Hash,
30 sync::{
31 atomic::{AtomicU64, Ordering},
32 Arc,
33 },
34};
35
36use async_lock::Semaphore;
37use bytes::Bytes;
38use parking_lot::Mutex;
39
40use crate::Transport;
41
/// Concurrency and queueing limits for a [`PooledTransport`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Maximum number of sends that may hold a permit for a single peer at once
    /// (size of each peer's semaphore).
    pub max_concurrent_per_peer: usize,

    /// Maximum number of sends that may hold a permit across all peers at once
    /// (size of the pool-wide semaphore).
    pub max_concurrent_global: usize,

    /// Maximum number of sends to one peer that may be waiting for permits;
    /// further sends to that peer are rejected with
    /// [`PooledTransportError::QueueFull`].
    pub max_queue_per_peer: usize,

    /// When `true`, a send acquires the global permit before the per-peer
    /// permit; when `false`, the per-peer permit is acquired first.
    pub fair_scheduling: bool,
}
57
58impl Default for PoolConfig {
59 fn default() -> Self {
60 Self {
61 max_concurrent_per_peer: 8,
62 max_concurrent_global: 256,
63 max_queue_per_peer: 64,
64 fair_scheduling: true,
65 }
66 }
67}
68
69impl PoolConfig {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn high_throughput() -> Self {
77 Self {
78 max_concurrent_per_peer: 16,
79 max_concurrent_global: 512,
80 max_queue_per_peer: 128,
81 fair_scheduling: true,
82 }
83 }
84
85 pub fn low_latency() -> Self {
87 Self {
88 max_concurrent_per_peer: 4,
89 max_concurrent_global: 128,
90 max_queue_per_peer: 32,
91 fair_scheduling: false,
92 }
93 }
94
95 pub fn large_cluster() -> Self {
97 Self {
98 max_concurrent_per_peer: 4,
99 max_concurrent_global: 1024,
100 max_queue_per_peer: 32,
101 fair_scheduling: true,
102 }
103 }
104
105 pub const fn with_max_concurrent_per_peer(mut self, max: usize) -> Self {
107 self.max_concurrent_per_peer = max;
108 self
109 }
110
111 pub const fn with_max_concurrent_global(mut self, max: usize) -> Self {
113 self.max_concurrent_global = max;
114 self
115 }
116
117 pub const fn with_max_queue_per_peer(mut self, max: usize) -> Self {
119 self.max_queue_per_peer = max;
120 self
121 }
122}
123
/// Point-in-time snapshot of a [`PooledTransport`]'s counters, as returned by
/// [`PooledTransport::stats`].
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Sends that completed successfully through the inner transport.
    pub messages_sent: u64,

    /// Sends rejected because the target peer's queue was full.
    pub messages_dropped: u64,

    /// Sends currently holding both permits (in progress on the inner transport).
    pub active_sends: u64,

    /// Sends, summed over all peers, that passed the queue check but were not
    /// yet counted as active when the snapshot was taken.
    pub queued_sends: u64,

    /// Sends for which the inner transport returned an error.
    pub send_errors: u64,

    /// Number of peers currently tracked in the pool's peer table.
    pub active_peers: usize,

    /// Highest `active_sends` value observed since the pool was created or the
    /// stats were last reset.
    pub peak_concurrent: u64,
}
148
149#[derive(Debug)]
151struct PeerState {
152 semaphore: Arc<Semaphore>,
154 queue_depth: AtomicU64,
156 messages_sent: AtomicU64,
158 messages_dropped: AtomicU64,
160}
161
162impl PeerState {
163 fn new(max_concurrent: usize) -> Self {
164 Self {
165 semaphore: Arc::new(Semaphore::new(max_concurrent)),
166 queue_depth: AtomicU64::new(0),
167 messages_sent: AtomicU64::new(0),
168 messages_dropped: AtomicU64::new(0),
169 }
170 }
171}
172
/// Error returned by [`PooledTransport::send_to`].
#[derive(Debug)]
pub enum PooledTransportError<E> {
    /// The target peer already had `max_queue_per_peer` sends waiting, so this
    /// message was dropped without reaching the inner transport.
    QueueFull,
    /// The inner transport reported an error; the original error is wrapped.
    Transport(E),
}
181
182impl<E: std::fmt::Display> std::fmt::Display for PooledTransportError<E> {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 match self {
185 PooledTransportError::QueueFull => write!(f, "pool queue full, message dropped"),
186 PooledTransportError::Transport(e) => write!(f, "transport error: {}", e),
187 }
188 }
189}
190
191impl<E: std::error::Error + 'static> std::error::Error for PooledTransportError<E> {
192 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
193 match self {
194 PooledTransportError::QueueFull => None,
195 PooledTransportError::Transport(e) => Some(e),
196 }
197 }
198}
199
200#[derive(Debug)]
202pub struct PooledTransport<T, I> {
203 inner: T,
205 config: PoolConfig,
207 peers: Mutex<HashMap<I, Arc<PeerState>>>,
209 global_semaphore: Arc<Semaphore>,
211 stats: PoolStatsInner,
213}
214
/// Lock-free pool-wide counters; snapshotted into [`PoolStats`].
#[derive(Default, Debug)]
struct PoolStatsInner {
    /// Successful sends.
    messages_sent: AtomicU64,
    /// Sends rejected with a full queue.
    messages_dropped: AtomicU64,
    /// Sends currently in progress on the inner transport.
    active_sends: AtomicU64,
    /// Sends whose inner transport call failed.
    send_errors: AtomicU64,
    /// High-water mark of `active_sends`.
    peak_concurrent: AtomicU64,
}
224
225impl<T, I> PooledTransport<T, I>
226where
227 I: Clone + Eq + Hash + Debug + Send + Sync + 'static,
228 T: Transport<I>,
229{
230 pub fn new(inner: T, config: PoolConfig) -> Self {
232 let global_semaphore = Arc::new(Semaphore::new(config.max_concurrent_global));
233 Self {
234 inner,
235 config,
236 peers: Mutex::new(HashMap::new()),
237 global_semaphore,
238 stats: PoolStatsInner::default(),
239 }
240 }
241
242 pub fn with_defaults(inner: T) -> Self {
244 Self::new(inner, PoolConfig::default())
245 }
246
247 fn get_peer_state(&self, peer: &I) -> Arc<PeerState> {
249 let mut peers = self.peers.lock();
250 peers
251 .entry(peer.clone())
252 .or_insert_with(|| Arc::new(PeerState::new(self.config.max_concurrent_per_peer)))
253 .clone()
254 }
255
256 pub fn stats(&self) -> PoolStats {
258 let peers = self.peers.lock();
259 let queued: u64 = peers
260 .values()
261 .map(|p| p.queue_depth.load(Ordering::Relaxed))
262 .sum();
263
264 PoolStats {
265 messages_sent: self.stats.messages_sent.load(Ordering::Relaxed),
266 messages_dropped: self.stats.messages_dropped.load(Ordering::Relaxed),
267 active_sends: self.stats.active_sends.load(Ordering::Relaxed),
268 queued_sends: queued,
269 send_errors: self.stats.send_errors.load(Ordering::Relaxed),
270 active_peers: peers.len(),
271 peak_concurrent: self.stats.peak_concurrent.load(Ordering::Relaxed),
272 }
273 }
274
275 pub fn reset_stats(&self) {
277 self.stats.messages_sent.store(0, Ordering::Relaxed);
278 self.stats.messages_dropped.store(0, Ordering::Relaxed);
279 self.stats.send_errors.store(0, Ordering::Relaxed);
280 self.stats.peak_concurrent.store(0, Ordering::Relaxed);
281 }
282
283 pub fn remove_peer(&self, peer: &I) {
287 let mut peers = self.peers.lock();
288 peers.remove(peer);
289 }
290
291 pub fn clear(&self) {
293 let mut peers = self.peers.lock();
294 peers.clear();
295 }
296
297 pub fn config(&self) -> &PoolConfig {
299 &self.config
300 }
301
302 pub fn inner(&self) -> &T {
304 &self.inner
305 }
306
307 pub async fn send_to(
309 &self,
310 target: &I,
311 data: Bytes,
312 ) -> Result<(), PooledTransportError<T::Error>> {
313 let peer_state = self.get_peer_state(target);
314
315 let current_queue = peer_state.queue_depth.fetch_add(1, Ordering::Relaxed);
317 if current_queue >= self.config.max_queue_per_peer as u64 {
318 peer_state.queue_depth.fetch_sub(1, Ordering::Relaxed);
319 peer_state.messages_dropped.fetch_add(1, Ordering::Relaxed);
320 self.stats.messages_dropped.fetch_add(1, Ordering::Relaxed);
321 return Err(PooledTransportError::QueueFull);
322 }
323
324 let (global_permit, peer_permit) = if self.config.fair_scheduling {
327 let global = self.global_semaphore.acquire_arc().await;
328 let peer = peer_state.semaphore.acquire_arc().await;
329 (global, peer)
330 } else {
331 let peer = peer_state.semaphore.acquire_arc().await;
332 let global = self.global_semaphore.acquire_arc().await;
333 (global, peer)
334 };
335
336 let active = self.stats.active_sends.fetch_add(1, Ordering::Relaxed) + 1;
338 self.stats
339 .peak_concurrent
340 .fetch_max(active, Ordering::Relaxed);
341
342 peer_state.queue_depth.fetch_sub(1, Ordering::Relaxed);
344
345 let result = self.inner.send_to(target, data).await;
347
348 drop(peer_permit);
350 drop(global_permit);
351 self.stats.active_sends.fetch_sub(1, Ordering::Relaxed);
352
353 match result {
354 Ok(()) => {
355 self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
356 peer_state.messages_sent.fetch_add(1, Ordering::Relaxed);
357 Ok(())
358 }
359 Err(e) => {
360 self.stats.send_errors.fetch_add(1, Ordering::Relaxed);
361 Err(PooledTransportError::Transport(e))
362 }
363 }
364 }
365}
366
367impl<T, I> Clone for PooledTransport<T, I>
368where
369 T: Clone,
370{
371 fn clone(&self) -> Self {
372 Self {
373 inner: self.inner.clone(),
374 config: self.config.clone(),
375 peers: Mutex::new(HashMap::new()),
376 global_semaphore: Arc::new(Semaphore::new(self.config.max_concurrent_global)),
377 stats: PoolStatsInner::default(),
378 }
379 }
380}
381
382impl<T, I> Transport<I> for PooledTransport<T, I>
384where
385 I: Clone + Eq + Hash + Debug + Send + Sync + 'static,
386 T: Transport<I>,
387{
388 type Error = PooledTransportError<T::Error>;
389
390 async fn send_to(&self, target: &I, data: Bytes) -> Result<(), Self::Error> {
391 PooledTransport::send_to(self, target, data).await
392 }
393}
394
#[cfg(test)]
mod tests {
    // Unit tests for the pool configuration, statistics, error type and the
    // pooled send path. Names such as `Arc`, `Bytes` and `Ordering` come from
    // the parent module through `super::*`.
    use super::*;
    use crate::transport::{ChannelTransport, NoopTransport};

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.max_concurrent_per_peer, 8);
        assert_eq!(config.max_concurrent_global, 256);
        assert_eq!(config.max_queue_per_peer, 64);
        assert!(config.fair_scheduling);
    }

    #[test]
    fn test_pool_config_presets() {
        let high = PoolConfig::high_throughput();
        assert!(high.max_concurrent_per_peer > PoolConfig::default().max_concurrent_per_peer);

        let low = PoolConfig::low_latency();
        assert!(low.max_concurrent_per_peer < PoolConfig::default().max_concurrent_per_peer);
        assert!(!low.fair_scheduling);

        let large = PoolConfig::large_cluster();
        assert!(large.max_concurrent_global > PoolConfig::default().max_concurrent_global);
    }

    #[test]
    fn test_pool_config_builders() {
        let config = PoolConfig::new()
            .with_max_concurrent_per_peer(3)
            .with_max_concurrent_global(7)
            .with_max_queue_per_peer(11);
        assert_eq!(config.max_concurrent_per_peer, 3);
        assert_eq!(config.max_concurrent_global, 7);
        assert_eq!(config.max_queue_per_peer, 11);
    }

    #[test]
    fn test_pooled_transport_error_display_and_source() {
        use std::error::Error as _;

        let full: PooledTransportError<std::io::Error> = PooledTransportError::QueueFull;
        assert_eq!(full.to_string(), "pool queue full, message dropped");
        assert!(full.source().is_none());

        let inner = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "boom");
        let wrapped: PooledTransportError<std::io::Error> = PooledTransportError::Transport(inner);
        assert_eq!(wrapped.to_string(), "transport error: boom");
        assert!(wrapped.source().is_some());
    }

    #[tokio::test]
    async fn test_pooled_transport_basic() {
        let (inner, rx) = ChannelTransport::<u64>::bounded(16);
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&42u64, Bytes::from("hello")).await.unwrap();

        let (target, data) = rx.recv().await.unwrap();
        assert_eq!(target, 42);
        assert_eq!(data, Bytes::from("hello"));

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.active_peers, 1);
    }

    #[tokio::test]
    async fn test_pooled_transport_stats() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        for i in 0..10u64 {
            pooled.send_to(&(i % 3), Bytes::from("test")).await.unwrap();
        }

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 10);
        assert_eq!(stats.active_peers, 3);
        assert_eq!(stats.send_errors, 0);
        // Every completed send must have released its queue slot and active count.
        assert_eq!(stats.queued_sends, 0);
        assert_eq!(stats.active_sends, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_unfair_scheduling() {
        // `low_latency` disables fair scheduling, exercising the
        // per-peer-permit-first acquisition path.
        let pooled = PooledTransport::new(NoopTransport, PoolConfig::low_latency());

        for i in 0..5u64 {
            pooled.send_to(&i, Bytes::from("test")).await.unwrap();
        }

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 5);
        assert_eq!(stats.queued_sends, 0);
        assert_eq!(stats.active_sends, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_queue_full() {
        let inner = NoopTransport;
        let config = PoolConfig::default().with_max_queue_per_peer(0);
        let pooled = PooledTransport::new(inner, config);

        // With a queue bound of zero, every send is rejected before it can
        // wait for a permit.
        let result = pooled.send_to(&1u64, Bytes::from("test")).await;

        assert!(matches!(result, Err(PooledTransportError::QueueFull)));

        let stats = pooled.stats();
        assert_eq!(stats.messages_dropped, 1);
        // The rejected send must give back the queue slot it briefly reserved.
        assert_eq!(stats.queued_sends, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_remove_peer() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&2u64, Bytes::from("test")).await.unwrap();

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 2);

        pooled.remove_peer(&1u64);

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 1);
    }

    #[tokio::test]
    async fn test_pooled_transport_clear() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&2u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&3u64, Bytes::from("test")).await.unwrap();

        pooled.clear();

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_concurrent() {
        let inner = NoopTransport;
        let config = PoolConfig::default()
            .with_max_concurrent_per_peer(2)
            .with_max_concurrent_global(4);
        let pooled = Arc::new(PooledTransport::new(inner, config));

        let mut handles = vec![];
        for i in 0..10u64 {
            let p = Arc::clone(&pooled);
            handles.push(tokio::spawn(async move {
                p.send_to(&(i % 2), Bytes::from("test")).await.unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 10);
        // The global semaphore admits at most 4 concurrent sends.
        assert!(stats.peak_concurrent <= 4);
        assert!(stats.peak_concurrent >= 1);
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_dropped, 0);
        assert_eq!(stats.active_sends, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_reset_stats() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 1);

        pooled.reset_stats();

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 0);
        // Nothing is in flight, so the peak restarts at zero.
        assert_eq!(stats.peak_concurrent, 0);
    }
}