Skip to main content

apfsds_transport/
pool.rs

1//! Connection pool for WebSocket connections
2
3use parking_lot::RwLock;
4use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5use thiserror::Error;
6use tracing::{debug, info, warn};
7
8use crate::{WssClient, WssClientConfig, WssClientError};
9
10#[derive(Error, Debug)]
11pub enum PoolError {
12    #[error("Pool exhausted")]
13    PoolExhausted,
14
15    #[error("Connection failed: {0}")]
16    ConnectionFailed(#[from] WssClientError),
17
18    #[error("Pool is closed")]
19    PoolClosed,
20}
21
/// Settings controlling how a [`ConnectionPool`] is sized and populated.
#[derive(Clone, Debug)]
pub struct ConnectionPoolConfig {
    /// How many connection slots the pool keeps.
    pub pool_size: usize,

    /// WebSocket server URLs; slots are assigned to them round-robin.
    pub endpoints: Vec<String>,

    /// Optional authorization token passed to every client that is opened.
    pub token: Option<String>,

    /// Whether a connection that fails should be re-established.
    pub auto_reconnect: bool,
}
37
38impl Default for ConnectionPoolConfig {
39    fn default() -> Self {
40        Self {
41            pool_size: 6,
42            endpoints: Vec::new(),
43            token: None,
44            auto_reconnect: true,
45        }
46    }
47}
48
49/// A managed connection in the pool
50pub struct PooledConnection {
51    client: WssClient,
52    endpoint: String,
53    id: usize,
54}
55
56impl PooledConnection {
57    pub fn client(&self) -> &WssClient {
58        &self.client
59    }
60
61    pub fn client_mut(&mut self) -> &mut WssClient {
62        &mut self.client
63    }
64
65    pub fn id(&self) -> usize {
66        self.id
67    }
68}
69
70/// Connection pool for WebSocket connections
71pub struct ConnectionPool {
72    config: ConnectionPoolConfig,
73    connections: Vec<RwLock<Option<WssClient>>>,
74    robin_counter: AtomicUsize,
75    closed: AtomicBool,
76}
77
78impl ConnectionPool {
79    /// Create a new connection pool
80    pub fn new(config: ConnectionPoolConfig) -> Self {
81        let mut connections = Vec::with_capacity(config.pool_size);
82        for _ in 0..config.pool_size {
83            connections.push(RwLock::new(None));
84        }
85
86        Self {
87            config,
88            connections,
89            robin_counter: AtomicUsize::new(0),
90            closed: AtomicBool::new(false),
91        }
92    }
93
94    /// Initialize all connections
95    pub async fn connect_all(&self) -> Result<(), PoolError> {
96        if self.config.endpoints.is_empty() {
97            return Err(PoolError::ConnectionFailed(WssClientError::InvalidUrl(
98                "No endpoints configured".to_string(),
99            )));
100        }
101
102        for i in 0..self.config.pool_size {
103            let endpoint = &self.config.endpoints[i % self.config.endpoints.len()];
104            self.connect_slot(i, endpoint).await?;
105        }
106
107        info!(
108            "Connection pool initialized with {} connections",
109            self.config.pool_size
110        );
111
112        Ok(())
113    }
114
115    /// Connect a specific slot
116    async fn connect_slot(&self, slot: usize, endpoint: &str) -> Result<(), PoolError> {
117        let config = WssClientConfig {
118            url: endpoint.to_string(),
119            token: self.config.token.clone(),
120            ..Default::default()
121        };
122
123        let mut client = WssClient::connect(config).await?;
124
125        // Send initial frames
126        client.send_initial_frames().await?;
127
128        let mut guard = self.connections[slot].write();
129        *guard = Some(client);
130
131        debug!("Connected slot {} to {}", slot, endpoint);
132
133        Ok(())
134    }
135
136    /// Get the next connection (round-robin)
137    pub fn get_slot(&self) -> usize {
138        let slot = self.robin_counter.fetch_add(1, Ordering::Relaxed) % self.config.pool_size;
139        slot
140    }
141
142    /// Execute an operation on a connection
143    pub async fn with_connection<F, T>(&self, f: F) -> Result<T, PoolError>
144    where
145        F: FnOnce(
146            &mut WssClient,
147        ) -> std::pin::Pin<
148            Box<dyn std::future::Future<Output = Result<T, WssClientError>> + Send + '_>,
149        >,
150    {
151        if self.closed.load(Ordering::Relaxed) {
152            return Err(PoolError::PoolClosed);
153        }
154
155        let slot = self.get_slot();
156        let mut guard = self.connections[slot].write();
157
158        match guard.as_mut() {
159            Some(client) => {
160                let result = f(client).await;
161                match result {
162                    Ok(v) => Ok(v),
163                    Err(e) => {
164                        warn!("Connection error on slot {}: {}", slot, e);
165                        // Mark for reconnection
166                        *guard = None;
167                        Err(PoolError::ConnectionFailed(e))
168                    }
169                }
170            }
171            None => Err(PoolError::PoolExhausted),
172        }
173    }
174
175    /// Close all connections
176    pub async fn close(&self) {
177        self.closed.store(true, Ordering::Relaxed);
178
179        for i in 0..self.connections.len() {
180            let mut guard = self.connections[i].write();
181            if let Some(mut client) = guard.take() {
182                let _ = client.close().await;
183            }
184        }
185
186        info!("Connection pool closed");
187    }
188
189    /// Get pool statistics
190    pub fn stats(&self) -> PoolStats {
191        let mut active = 0;
192        for conn in &self.connections {
193            if conn.read().is_some() {
194                active += 1;
195            }
196        }
197
198        PoolStats {
199            pool_size: self.config.pool_size,
200            active_connections: active,
201            total_requests: self.robin_counter.load(Ordering::Relaxed),
202        }
203    }
204}
205
/// Snapshot of pool statistics, as returned by `ConnectionPool::stats`.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Configured number of slots.
    pub pool_size: usize,
    /// Number of slots that held a client when the snapshot was taken.
    pub active_connections: usize,
    /// Value of the round-robin counter, i.e. how many times a slot has been selected.
    pub total_requests: usize,
}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration keeps six slots and reconnects on failure.
    #[test]
    fn test_pool_config() {
        let defaults = ConnectionPoolConfig::default();
        assert!(defaults.auto_reconnect);
        assert_eq!(defaults.pool_size, 6);
    }

    /// Slot selection walks every slot in order and then wraps around.
    #[test]
    fn test_round_robin() {
        let pool = ConnectionPool::new(ConnectionPoolConfig {
            pool_size: 4,
            endpoints: vec!["ws://test".to_string()],
            ..Default::default()
        });

        for expected in [0, 1, 2, 3, 0] {
            assert_eq!(pool.get_slot(), expected);
        }
    }
}