1use parking_lot::RwLock;
4use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5use thiserror::Error;
6use tracing::{debug, info, warn};
7
8use crate::{WssClient, WssClientConfig, WssClientError};
9
10#[derive(Error, Debug)]
11pub enum PoolError {
12 #[error("Pool exhausted")]
13 PoolExhausted,
14
15 #[error("Connection failed: {0}")]
16 ConnectionFailed(#[from] WssClientError),
17
18 #[error("Pool is closed")]
19 PoolClosed,
20}
21
/// Settings used to build a [`ConnectionPool`].
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Number of connection slots the pool allocates.
    pub pool_size: usize,

    /// Endpoint URLs; slots are assigned to them in rotation
    /// (slot `i` uses `endpoints[i % endpoints.len()]`).
    pub endpoints: Vec<String>,

    /// Optional token forwarded to every client configuration.
    pub token: Option<String>,

    /// Reconnect flag.
    // NOTE(review): not read anywhere in the visible part of this file —
    // confirm where (or whether) it is honoured.
    pub auto_reconnect: bool,
}
37
38impl Default for ConnectionPoolConfig {
39 fn default() -> Self {
40 Self {
41 pool_size: 6,
42 endpoints: Vec::new(),
43 token: None,
44 auto_reconnect: true,
45 }
46 }
47}
48
49pub struct PooledConnection {
51 client: WssClient,
52 endpoint: String,
53 id: usize,
54}
55
56impl PooledConnection {
57 pub fn client(&self) -> &WssClient {
58 &self.client
59 }
60
61 pub fn client_mut(&mut self) -> &mut WssClient {
62 &mut self.client
63 }
64
65 pub fn id(&self) -> usize {
66 self.id
67 }
68}
69
70pub struct ConnectionPool {
72 config: ConnectionPoolConfig,
73 connections: Vec<RwLock<Option<WssClient>>>,
74 robin_counter: AtomicUsize,
75 closed: AtomicBool,
76}
77
78impl ConnectionPool {
79 pub fn new(config: ConnectionPoolConfig) -> Self {
81 let mut connections = Vec::with_capacity(config.pool_size);
82 for _ in 0..config.pool_size {
83 connections.push(RwLock::new(None));
84 }
85
86 Self {
87 config,
88 connections,
89 robin_counter: AtomicUsize::new(0),
90 closed: AtomicBool::new(false),
91 }
92 }
93
94 pub async fn connect_all(&self) -> Result<(), PoolError> {
96 if self.config.endpoints.is_empty() {
97 return Err(PoolError::ConnectionFailed(WssClientError::InvalidUrl(
98 "No endpoints configured".to_string(),
99 )));
100 }
101
102 for i in 0..self.config.pool_size {
103 let endpoint = &self.config.endpoints[i % self.config.endpoints.len()];
104 self.connect_slot(i, endpoint).await?;
105 }
106
107 info!(
108 "Connection pool initialized with {} connections",
109 self.config.pool_size
110 );
111
112 Ok(())
113 }
114
115 async fn connect_slot(&self, slot: usize, endpoint: &str) -> Result<(), PoolError> {
117 let config = WssClientConfig {
118 url: endpoint.to_string(),
119 token: self.config.token.clone(),
120 ..Default::default()
121 };
122
123 let mut client = WssClient::connect(config).await?;
124
125 client.send_initial_frames().await?;
127
128 let mut guard = self.connections[slot].write();
129 *guard = Some(client);
130
131 debug!("Connected slot {} to {}", slot, endpoint);
132
133 Ok(())
134 }
135
136 pub fn get_slot(&self) -> usize {
138 let slot = self.robin_counter.fetch_add(1, Ordering::Relaxed) % self.config.pool_size;
139 slot
140 }
141
142 pub async fn with_connection<F, T>(&self, f: F) -> Result<T, PoolError>
144 where
145 F: FnOnce(
146 &mut WssClient,
147 ) -> std::pin::Pin<
148 Box<dyn std::future::Future<Output = Result<T, WssClientError>> + Send + '_>,
149 >,
150 {
151 if self.closed.load(Ordering::Relaxed) {
152 return Err(PoolError::PoolClosed);
153 }
154
155 let slot = self.get_slot();
156 let mut guard = self.connections[slot].write();
157
158 match guard.as_mut() {
159 Some(client) => {
160 let result = f(client).await;
161 match result {
162 Ok(v) => Ok(v),
163 Err(e) => {
164 warn!("Connection error on slot {}: {}", slot, e);
165 *guard = None;
167 Err(PoolError::ConnectionFailed(e))
168 }
169 }
170 }
171 None => Err(PoolError::PoolExhausted),
172 }
173 }
174
175 pub async fn close(&self) {
177 self.closed.store(true, Ordering::Relaxed);
178
179 for i in 0..self.connections.len() {
180 let mut guard = self.connections[i].write();
181 if let Some(mut client) = guard.take() {
182 let _ = client.close().await;
183 }
184 }
185
186 info!("Connection pool closed");
187 }
188
189 pub fn stats(&self) -> PoolStats {
191 let mut active = 0;
192 for conn in &self.connections {
193 if conn.read().is_some() {
194 active += 1;
195 }
196 }
197
198 PoolStats {
199 pool_size: self.config.pool_size,
200 active_connections: active,
201 total_requests: self.robin_counter.load(Ordering::Relaxed),
202 }
203 }
204}
205
/// Point-in-time counters describing a [`ConnectionPool`].
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Configured number of slots.
    pub pool_size: usize,
    /// Number of slots that held a client when the snapshot was taken.
    pub active_connections: usize,
    /// Value of the pool's round-robin counter when the snapshot was taken.
    pub total_requests: usize,
}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration has six slots and auto-reconnect enabled.
    #[test]
    fn test_pool_config() {
        let ConnectionPoolConfig {
            pool_size,
            auto_reconnect,
            ..
        } = ConnectionPoolConfig::default();

        assert_eq!(pool_size, 6);
        assert!(auto_reconnect);
    }

    /// Slot indices are handed out in order and wrap around at `pool_size`.
    #[test]
    fn test_round_robin() {
        let pool = ConnectionPool::new(ConnectionPoolConfig {
            pool_size: 4,
            endpoints: vec!["ws://test".to_string()],
            ..Default::default()
        });

        for &expected in &[0usize, 1, 2, 3, 0] {
            assert_eq!(pool.get_slot(), expected);
        }
    }
}