// memberlist_plumtree/pooled_transport.rs

//! Pooled transport with connection management and concurrency control.
//!
//! This module provides a transport wrapper that adds:
//!
//! - **Concurrency limiting**: Prevent overwhelming individual peers
//! - **Request queuing**: Buffer messages when at capacity
//! - **Fair scheduling**: Ensure all peers get fair access
//! - **Statistics**: Track pool usage and performance
//!
//! # Example
//!
//! ```ignore
//! use memberlist_plumtree::{PooledTransport, PoolConfig, Transport};
//!
//! let inner = MyTransport::new();
//! let pooled = PooledTransport::new(inner, PoolConfig::default());
//!
//! // Use like any other transport
//! pooled.send_to(&peer_id, data).await?;
//!
//! // Check pool statistics
//! let stats = pooled.stats();
//! println!("Active: {}, Queued: {}", stats.active_sends, stats.queued_sends);
//! ```

use std::{
    collections::HashMap,
    fmt::Debug,
    hash::Hash,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};

use async_lock::Semaphore;
use bytes::Bytes;
use parking_lot::Mutex;

use crate::Transport;

/// Configuration for the pooled transport.
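///
/// # Example
///
/// Presets cover common deployment shapes, and the `with_*` builders adjust
/// individual limits. A minimal sketch (crate-root re-exports are assumed, as
/// in the module-level example):
///
/// ```ignore
/// use memberlist_plumtree::PoolConfig;
///
/// // Start from the large-cluster preset, then allow a deeper per-peer queue.
/// let config = PoolConfig::large_cluster().with_max_queue_per_peer(64);
/// assert_eq!(config.max_queue_per_peer, 64);
/// ```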
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum concurrent sends per peer.
    pub max_concurrent_per_peer: usize,

    /// Global maximum concurrent sends across all peers.
    pub max_concurrent_global: usize,

    /// Maximum pending sends in queue per peer before dropping.
    pub max_queue_per_peer: usize,

    /// Whether to enable fair scheduling across peers.
    pub fair_scheduling: bool,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            max_concurrent_per_peer: 8,
            max_concurrent_global: 256,
            max_queue_per_peer: 64,
            fair_scheduling: true,
        }
    }
}

impl PoolConfig {
    /// Create a new pool configuration with defaults.
    pub fn new() -> Self {
        Self::default()
    }

    /// Configuration for high-throughput scenarios.
    pub fn high_throughput() -> Self {
        Self {
            max_concurrent_per_peer: 16,
            max_concurrent_global: 512,
            max_queue_per_peer: 128,
            fair_scheduling: true,
        }
    }

    /// Configuration for low-latency scenarios.
    pub fn low_latency() -> Self {
        Self {
            max_concurrent_per_peer: 4,
            max_concurrent_global: 128,
            max_queue_per_peer: 32,
            fair_scheduling: false,
        }
    }

    /// Configuration for large clusters.
    pub fn large_cluster() -> Self {
        Self {
            max_concurrent_per_peer: 4,
            max_concurrent_global: 1024,
            max_queue_per_peer: 32,
            fair_scheduling: true,
        }
    }

    /// Set max concurrent per peer (builder pattern).
    pub const fn with_max_concurrent_per_peer(mut self, max: usize) -> Self {
        self.max_concurrent_per_peer = max;
        self
    }

    /// Set global max concurrent (builder pattern).
    pub const fn with_max_concurrent_global(mut self, max: usize) -> Self {
        self.max_concurrent_global = max;
        self
    }

    /// Set max queue per peer (builder pattern).
    pub const fn with_max_queue_per_peer(mut self, max: usize) -> Self {
        self.max_queue_per_peer = max;
        self
    }
}

/// Statistics for the pooled transport.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Total messages sent successfully.
    pub messages_sent: u64,

    /// Total messages dropped due to queue overflow.
    pub messages_dropped: u64,

    /// Current number of active sends.
    pub active_sends: u64,

    /// Current number of queued sends (waiting).
    pub queued_sends: u64,

    /// Total send errors encountered.
    pub send_errors: u64,

    /// Number of peers currently tracked by the pool.
    pub active_peers: usize,

    /// Peak concurrent sends observed.
    pub peak_concurrent: u64,
}

/// Per-peer state in the pool.
#[derive(Debug)]
struct PeerState {
    /// Semaphore for limiting concurrent sends to this peer.
    semaphore: Arc<Semaphore>,
    /// Current queue depth for this peer.
    queue_depth: AtomicU64,
    /// Messages sent to this peer.
    messages_sent: AtomicU64,
    /// Messages dropped for this peer.
    messages_dropped: AtomicU64,
}

impl PeerState {
    fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            queue_depth: AtomicU64::new(0),
            messages_sent: AtomicU64::new(0),
            messages_dropped: AtomicU64::new(0),
        }
    }
}

/// Error type for pooled transport.
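///
/// `QueueFull` is returned before the inner transport is touched, so callers
/// can treat it as back-pressure; `Transport` wraps the inner transport's own
/// error. A minimal handling sketch (`pooled`, `peer`, and `payload` are
/// placeholders):
///
/// ```ignore
/// match pooled.send_to(&peer, payload).await {
///     Ok(()) => {}
///     // The per-peer queue is saturated; drop the message or retry later.
///     Err(PooledTransportError::QueueFull) => {}
///     // The underlying transport failed; surface its error.
///     Err(PooledTransportError::Transport(e)) => return Err(e),
/// }
/// ```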
#[derive(Debug)]
pub enum PooledTransportError<E> {
    /// Queue is full, message was dropped.
    QueueFull,
    /// Underlying transport error.
    Transport(E),
}

impl<E: std::fmt::Display> std::fmt::Display for PooledTransportError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PooledTransportError::QueueFull => write!(f, "pool queue full, message dropped"),
            PooledTransportError::Transport(e) => write!(f, "transport error: {}", e),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for PooledTransportError<E> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            PooledTransportError::QueueFull => None,
            PooledTransportError::Transport(e) => Some(e),
        }
    }
}

/// Pooled transport wrapper that provides connection management.
#[derive(Debug)]
pub struct PooledTransport<T, I> {
    /// Inner transport.
    inner: T,
    /// Configuration.
    config: PoolConfig,
    /// Per-peer state.
    peers: Mutex<HashMap<I, Arc<PeerState>>>,
    /// Global concurrency semaphore.
    global_semaphore: Arc<Semaphore>,
    /// Statistics.
    stats: PoolStatsInner,
}

/// Internal statistics tracking.
#[derive(Debug, Default)]
struct PoolStatsInner {
    messages_sent: AtomicU64,
    messages_dropped: AtomicU64,
    active_sends: AtomicU64,
    send_errors: AtomicU64,
    peak_concurrent: AtomicU64,
}

impl<T, I> PooledTransport<T, I>
where
    I: Clone + Eq + Hash + Debug + Send + Sync + 'static,
    T: Transport<I>,
{
    /// Create a new pooled transport wrapping the given inner transport.
    pub fn new(inner: T, config: PoolConfig) -> Self {
        let global_semaphore = Arc::new(Semaphore::new(config.max_concurrent_global));
        Self {
            inner,
            config,
            peers: Mutex::new(HashMap::new()),
            global_semaphore,
            stats: PoolStatsInner::default(),
        }
    }

    /// Create with default configuration.
    pub fn with_defaults(inner: T) -> Self {
        Self::new(inner, PoolConfig::default())
    }

    /// Get or create peer state.
    fn get_peer_state(&self, peer: &I) -> Arc<PeerState> {
        let mut peers = self.peers.lock();
        peers
            .entry(peer.clone())
            .or_insert_with(|| Arc::new(PeerState::new(self.config.max_concurrent_per_peer)))
            .clone()
    }

    /// Get current statistics.
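    ///
    /// The returned [`PoolStats`] is a point-in-time snapshot assembled on
    /// each call. A minimal sketch of polling it (the `pooled` transport is
    /// assumed to already exist, as in the module-level example):
    ///
    /// ```ignore
    /// let stats = pooled.stats();
    /// if stats.queued_sends > 0 {
    ///     // Some sends are still waiting for concurrency permits.
    /// }
    /// ```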
    pub fn stats(&self) -> PoolStats {
        let peers = self.peers.lock();
        let queued: u64 = peers
            .values()
            .map(|p| p.queue_depth.load(Ordering::Relaxed))
            .sum();

        PoolStats {
            messages_sent: self.stats.messages_sent.load(Ordering::Relaxed),
            messages_dropped: self.stats.messages_dropped.load(Ordering::Relaxed),
            active_sends: self.stats.active_sends.load(Ordering::Relaxed),
            queued_sends: queued,
            send_errors: self.stats.send_errors.load(Ordering::Relaxed),
            active_peers: peers.len(),
            peak_concurrent: self.stats.peak_concurrent.load(Ordering::Relaxed),
        }
    }

    /// Reset the cumulative counters; the live `active_sends` gauge is not cleared.
    pub fn reset_stats(&self) {
        self.stats.messages_sent.store(0, Ordering::Relaxed);
        self.stats.messages_dropped.store(0, Ordering::Relaxed);
        self.stats.send_errors.store(0, Ordering::Relaxed);
        self.stats.peak_concurrent.store(0, Ordering::Relaxed);
    }

    /// Remove a peer from the pool.
    ///
    /// Call this when a peer disconnects to clean up resources.
    pub fn remove_peer(&self, peer: &I) {
        let mut peers = self.peers.lock();
        peers.remove(peer);
    }

    /// Clear all peer state.
    pub fn clear(&self) {
        let mut peers = self.peers.lock();
        peers.clear();
    }

    /// Get the configuration.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }

    /// Get a reference to the inner transport.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Send a message to a peer with pooling and concurrency control.
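    ///
    /// The call first reserves a queue slot for the target peer and returns
    /// `PooledTransportError::QueueFull` immediately if `max_queue_per_peer`
    /// is already reached; otherwise it waits for a per-peer and a global
    /// concurrency permit before delegating to the inner transport.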
    pub async fn send_to(
        &self,
        target: &I,
        data: Bytes,
    ) -> Result<(), PooledTransportError<T::Error>> {
        let peer_state = self.get_peer_state(target);

        // Reserve a queue slot; `fetch_add` returns the previous depth, so a
        // peer already at `max_queue_per_peer` is rejected immediately.
        let current_queue = peer_state.queue_depth.fetch_add(1, Ordering::Relaxed);
        if current_queue >= self.config.max_queue_per_peer as u64 {
            peer_state.queue_depth.fetch_sub(1, Ordering::Relaxed);
            peer_state.messages_dropped.fetch_add(1, Ordering::Relaxed);
            self.stats.messages_dropped.fetch_add(1, Ordering::Relaxed);
            return Err(PooledTransportError::QueueFull);
        }

        // Acquire permits (peer + global)
        // In fair scheduling mode, we acquire global first to prevent per-peer starvation
        let (global_permit, peer_permit) = if self.config.fair_scheduling {
            let global = self.global_semaphore.acquire_arc().await;
            let peer = peer_state.semaphore.acquire_arc().await;
            (global, peer)
        } else {
            let peer = peer_state.semaphore.acquire_arc().await;
            let global = self.global_semaphore.acquire_arc().await;
            (global, peer)
        };

        // Update active count and track peak
        let active = self.stats.active_sends.fetch_add(1, Ordering::Relaxed) + 1;
        self.stats
            .peak_concurrent
            .fetch_max(active, Ordering::Relaxed);

        // Decrement queue depth now that we're active
        peer_state.queue_depth.fetch_sub(1, Ordering::Relaxed);

        // Send the message
        let result = self.inner.send_to(target, data).await;

        // Release permits and update stats
        drop(peer_permit);
        drop(global_permit);
        self.stats.active_sends.fetch_sub(1, Ordering::Relaxed);

        match result {
            Ok(()) => {
                self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
                peer_state.messages_sent.fetch_add(1, Ordering::Relaxed);
                Ok(())
            }
            Err(e) => {
                self.stats.send_errors.fetch_add(1, Ordering::Relaxed);
                Err(PooledTransportError::Transport(e))
            }
        }
    }
}

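// Cloning yields an independent pool over a clone of the inner transport: the
// per-peer map, global semaphore, and statistics all start fresh rather than
// being shared with the original.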
impl<T, I> Clone for PooledTransport<T, I>
where
    T: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            config: self.config.clone(),
            peers: Mutex::new(HashMap::new()),
            global_semaphore: Arc::new(Semaphore::new(self.config.max_concurrent_global)),
            stats: PoolStatsInner::default(),
        }
    }
}

// Implement Transport for PooledTransport so it can be used as a drop-in replacement
impl<T, I> Transport<I> for PooledTransport<T, I>
where
    I: Clone + Eq + Hash + Debug + Send + Sync + 'static,
    T: Transport<I>,
{
    type Error = PooledTransportError<T::Error>;

    async fn send_to(&self, target: &I, data: Bytes) -> Result<(), Self::Error> {
        PooledTransport::send_to(self, target, data).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::{ChannelTransport, NoopTransport};

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.max_concurrent_per_peer, 8);
        assert_eq!(config.max_concurrent_global, 256);
    }

    #[test]
    fn test_pool_config_presets() {
        let high = PoolConfig::high_throughput();
        assert!(high.max_concurrent_per_peer > PoolConfig::default().max_concurrent_per_peer);

        let low = PoolConfig::low_latency();
        assert!(low.max_concurrent_per_peer < PoolConfig::default().max_concurrent_per_peer);
    }
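
    // Additional check for the `with_*` builders (uses only items defined in
    // this module): chained setters should override the corresponding defaults.
    #[test]
    fn test_pool_config_builders() {
        let config = PoolConfig::default()
            .with_max_concurrent_per_peer(2)
            .with_max_concurrent_global(16)
            .with_max_queue_per_peer(4);
        assert_eq!(config.max_concurrent_per_peer, 2);
        assert_eq!(config.max_concurrent_global, 16);
        assert_eq!(config.max_queue_per_peer, 4);
    }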

    #[tokio::test]
    async fn test_pooled_transport_basic() {
        let (inner, rx) = ChannelTransport::<u64>::bounded(16);
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&42u64, Bytes::from("hello")).await.unwrap();

        let (target, data) = rx.recv().await.unwrap();
        assert_eq!(target, 42);
        assert_eq!(data, Bytes::from("hello"));

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 1);
        assert_eq!(stats.active_peers, 1);
    }

    #[tokio::test]
    async fn test_pooled_transport_stats() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        for i in 0..10u64 {
            pooled.send_to(&(i % 3), Bytes::from("test")).await.unwrap();
        }

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 10);
        assert_eq!(stats.active_peers, 3); // 0, 1, 2
        assert_eq!(stats.send_errors, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_queue_full() {
        let inner = NoopTransport;
        let config = PoolConfig::default().with_max_queue_per_peer(0);
        let pooled = PooledTransport::new(inner, config);

        // With max_queue_per_peer = 0 there is no queue capacity at all, so
        // even the very first send is rejected before reaching the transport.
        let result = pooled.send_to(&1u64, Bytes::from("test")).await;

        assert!(matches!(result, Err(PooledTransportError::QueueFull)));

        let stats = pooled.stats();
        assert_eq!(stats.messages_dropped, 1);
    }

    #[tokio::test]
    async fn test_pooled_transport_remove_peer() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&2u64, Bytes::from("test")).await.unwrap();

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 2);

        pooled.remove_peer(&1u64);

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 1);
    }

    #[tokio::test]
    async fn test_pooled_transport_clear() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&2u64, Bytes::from("test")).await.unwrap();
        pooled.send_to(&3u64, Bytes::from("test")).await.unwrap();

        pooled.clear();

        let stats = pooled.stats();
        assert_eq!(stats.active_peers, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_concurrent() {
        use std::sync::Arc;

        let inner = NoopTransport;
        let config = PoolConfig::default()
            .with_max_concurrent_per_peer(2)
            .with_max_concurrent_global(4);
        let pooled = Arc::new(PooledTransport::new(inner, config));

        // Spawn multiple concurrent sends
        let mut handles = vec![];
        for i in 0..10u64 {
            let p = pooled.clone();
            handles.push(tokio::spawn(async move {
                p.send_to(&(i % 2), Bytes::from("test")).await.unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 10);
        // Peak should be limited by global max
        assert!(stats.peak_concurrent <= 4);
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_dropped, 0);
        assert_eq!(stats.active_sends, 0);
    }

    #[tokio::test]
    async fn test_pooled_transport_reset_stats() {
        let inner = NoopTransport;
        let pooled = PooledTransport::with_defaults(inner);

        pooled.send_to(&1u64, Bytes::from("test")).await.unwrap();

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 1);

        pooled.reset_stats();

        let stats = pooled.stats();
        assert_eq!(stats.messages_sent, 0);
    }
}