1use crate::connection::VCLConnection;
30use crate::error::VCLError;
31use crate::packet::VCLPacket;
32use std::collections::HashMap;
33use tracing::{debug, info, warn};
34
/// Identifier the pool assigns to each connection it holds.
pub type ConnectionId = u64;
37
/// A bounded collection of [`VCLConnection`]s, each addressed by a [`ConnectionId`].
pub struct VCLPool {
    /// Connections currently held, keyed by the ID handed out when each was bound.
    connections: HashMap<ConnectionId, VCLConnection>,
    /// ID that the next successfully bound connection will receive.
    next_id: ConnectionId,
    /// Upper bound on the number of connections held at the same time.
    max_connections: usize,
}
47
48impl VCLPool {
49 pub fn new(max_connections: usize) -> Self {
57 info!(max_connections, "VCLPool created");
58 VCLPool {
59 connections: HashMap::new(),
60 next_id: 0,
61 max_connections,
62 }
63 }
64
65 pub async fn bind(&mut self, addr: &str) -> Result<ConnectionId, VCLError> {
73 if self.connections.len() >= self.max_connections {
74 warn!(
75 current = self.connections.len(),
76 max = self.max_connections,
77 "Pool is at maximum capacity"
78 );
79 return Err(VCLError::InvalidPacket(format!(
80 "Pool is full: max {} connections",
81 self.max_connections
82 )));
83 }
84
85 let conn = VCLConnection::bind(addr).await?;
86 let id = self.next_id;
87 self.next_id += 1;
88 self.connections.insert(id, conn);
89 info!(id, addr, "Connection added to pool");
90 Ok(id)
91 }
92
93 pub async fn connect(&mut self, id: ConnectionId, addr: &str) -> Result<(), VCLError> {
99 let conn = self.get_mut(id)?;
100 debug!(id, peer = %addr, "Pool: connecting");
101 conn.connect(addr).await
102 }
103
104 pub async fn accept_handshake(&mut self, id: ConnectionId) -> Result<(), VCLError> {
110 let conn = self.get_mut(id)?;
111 debug!(id, "Pool: accepting handshake");
112 conn.accept_handshake().await
113 }
114
115 pub async fn send(&mut self, id: ConnectionId, data: &[u8]) -> Result<(), VCLError> {
121 let conn = self.get_mut(id)?;
122 debug!(id, size = data.len(), "Pool: sending");
123 conn.send(data).await
124 }
125
126 pub async fn recv(&mut self, id: ConnectionId) -> Result<VCLPacket, VCLError> {
132 let conn = self.get_mut(id)?;
133 debug!(id, "Pool: waiting for packet");
134 conn.recv().await
135 }
136
137 pub async fn ping(&mut self, id: ConnectionId) -> Result<(), VCLError> {
143 let conn = self.get_mut(id)?;
144 debug!(id, "Pool: sending ping");
145 conn.ping().await
146 }
147
148 pub async fn rotate_keys(&mut self, id: ConnectionId) -> Result<(), VCLError> {
154 let conn = self.get_mut(id)?;
155 debug!(id, "Pool: rotating keys");
156 conn.rotate_keys().await
157 }
158
159 pub fn close(&mut self, id: ConnectionId) -> Result<(), VCLError> {
165 match self.connections.get_mut(&id) {
166 Some(conn) => {
167 conn.close()?;
168 self.connections.remove(&id);
169 info!(id, "Connection removed from pool");
170 Ok(())
171 }
172 None => {
173 warn!(id, "close() called with unknown connection ID");
174 Err(VCLError::InvalidPacket(format!(
175 "Connection ID {} not found in pool",
176 id
177 )))
178 }
179 }
180 }
181
182 pub fn close_all(&mut self) {
184 info!(count = self.connections.len(), "Closing all pool connections");
185 for (id, conn) in self.connections.iter_mut() {
186 if let Err(e) = conn.close() {
187 warn!(id, error = %e, "Error closing connection during close_all");
188 }
189 }
190 self.connections.clear();
191 }
192
193 pub fn len(&self) -> usize {
195 self.connections.len()
196 }
197
198 pub fn is_empty(&self) -> bool {
200 self.connections.is_empty()
201 }
202
203 pub fn is_full(&self) -> bool {
205 self.connections.len() >= self.max_connections
206 }
207
208 pub fn connection_ids(&self) -> Vec<ConnectionId> {
210 self.connections.keys().copied().collect()
211 }
212
213 pub fn contains(&self, id: ConnectionId) -> bool {
215 self.connections.contains_key(&id)
216 }
217
218 pub fn get(&self, id: ConnectionId) -> Result<&VCLConnection, VCLError> {
223 self.connections.get(&id).ok_or_else(|| {
224 VCLError::InvalidPacket(format!("Connection ID {} not found in pool", id))
225 })
226 }
227
228 fn get_mut(&mut self, id: ConnectionId) -> Result<&mut VCLConnection, VCLError> {
229 self.connections.get_mut(&id).ok_or_else(|| {
230 VCLError::InvalidPacket(format!("Connection ID {} not found in pool", id))
231 })
232 }
233}
234
235impl Drop for VCLPool {
236 fn drop(&mut self) {
237 if !self.connections.is_empty() {
238 self.close_all();
239 }
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Address handed to `VCLPool::bind` by every test that needs a connection.
    const TEST_ADDR: &str = "127.0.0.1:0";

    /// A freshly constructed pool holds nothing and is not at capacity.
    #[test]
    fn test_pool_new() {
        let fresh = VCLPool::new(5);
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
        assert!(!fresh.is_full());
    }

    /// Binding one connection makes it visible through `len`, `contains` and `is_empty`.
    #[tokio::test]
    async fn test_pool_bind() {
        let mut pool = VCLPool::new(5);
        let bound = pool.bind(TEST_ADDR).await.unwrap();
        assert!(!pool.is_empty());
        assert_eq!(pool.len(), 1);
        assert!(pool.contains(bound));
    }

    /// Once the limit is reached, the pool reports full and refuses another bind.
    #[tokio::test]
    async fn test_pool_max_capacity() {
        let mut pool = VCLPool::new(2);
        for _ in 0..2 {
            pool.bind(TEST_ADDR).await.unwrap();
        }
        assert!(pool.is_full());
        assert!(pool.bind(TEST_ADDR).await.is_err());
    }

    /// Closing a connection by ID removes it from the pool.
    #[tokio::test]
    async fn test_pool_close() {
        let mut pool = VCLPool::new(5);
        let bound = pool.bind(TEST_ADDR).await.unwrap();
        assert_eq!(pool.len(), 1);

        pool.close(bound).unwrap();
        assert!(!pool.contains(bound));
        assert_eq!(pool.len(), 0);
    }

    /// `close_all` leaves the pool empty regardless of how many connections it held.
    #[tokio::test]
    async fn test_pool_close_all() {
        let mut pool = VCLPool::new(5);
        for _ in 0..3 {
            pool.bind(TEST_ADDR).await.unwrap();
        }
        assert_eq!(pool.len(), 3);

        pool.close_all();
        assert!(pool.is_empty());
    }

    /// Closing an ID that was never handed out is reported as an error.
    #[tokio::test]
    async fn test_pool_unknown_id() {
        let mut pool = VCLPool::new(5);
        assert!(pool.close(999).is_err());
    }

    /// `connection_ids` lists exactly the IDs returned by `bind`.
    #[tokio::test]
    async fn test_pool_connection_ids() {
        let mut pool = VCLPool::new(5);
        let first = pool.bind(TEST_ADDR).await.unwrap();
        let second = pool.bind(TEST_ADDR).await.unwrap();

        // The pool makes no ordering promise, so sort before comparing.
        let mut listed = pool.connection_ids();
        listed.sort();
        assert_eq!(listed, vec![first, second]);
    }
}