1use std::{
7 collections::VecDeque,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tracing::debug;
16use trojan_metrics::{record_fallback_pool_warm_fail, set_fallback_pool_size};
17
18#[derive(Debug)]
20struct PooledConnection {
21 stream: TcpStream,
22 created_at: Instant,
23}
24
25#[derive(Debug)]
27pub struct ConnectionPool {
28 addr: SocketAddr,
29 connections: Arc<Mutex<VecDeque<PooledConnection>>>,
30 max_idle: usize,
31 max_age: Duration,
32 fill_batch: usize,
33 fill_delay: Duration,
34}
35
36impl ConnectionPool {
37 pub fn new(
39 addr: SocketAddr,
40 max_idle: usize,
41 max_age_secs: u64,
42 fill_batch: usize,
43 fill_delay_ms: u64,
44 ) -> Self {
45 let pool = Self {
46 addr,
47 connections: Arc::new(Mutex::new(VecDeque::new())),
48 max_idle,
49 max_age: Duration::from_secs(max_age_secs),
50 fill_batch,
51 fill_delay: Duration::from_millis(fill_delay_ms),
52 };
53 set_fallback_pool_size(0);
54 pool
55 }
56
57 pub async fn get(&self) -> std::io::Result<TcpStream> {
59 let pooled = {
61 let mut pool = self.connections.lock();
62 let pooled = pool.pop_front();
63 set_fallback_pool_size(pool.len());
64 pooled
65 };
66 if let Some(pooled) = pooled {
67 if pooled.created_at.elapsed() < self.max_age {
68 debug!(addr = %self.addr, "using pooled connection");
69 return Ok(pooled.stream);
70 }
71 debug!(addr = %self.addr, "discarding expired pooled connection");
72 }
73
74 debug!(addr = %self.addr, "creating new connection");
76 TcpStream::connect(self.addr).await
77 }
78
79 pub fn cleanup(&self) {
82 let mut pool = self.connections.lock();
83 let before = pool.len();
84 pool.retain(|conn| conn.created_at.elapsed() < self.max_age);
85 let removed = before - pool.len();
86 set_fallback_pool_size(pool.len());
87 if removed > 0 {
88 debug!(addr = %self.addr, removed, remaining = pool.len(), "cleaned up expired connections");
89 }
90 }
91
92 pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
94 let pool = self.clone();
95 tokio::spawn(async move {
96 loop {
97 tokio::time::sleep(interval).await;
98 pool.cleanup();
99 pool.warm_fill().await;
100 }
101 });
102 }
103
104 pub fn size(&self) -> usize {
106 self.connections.lock().len()
107 }
108
109 async fn warm_fill(&self) {
111 let need = {
112 let pool = self.connections.lock();
113 if pool.len() >= self.max_idle {
114 return;
115 }
116 self.max_idle - pool.len()
117 };
118 if need == 0 {
119 return;
120 }
121 let batch = self.fill_batch.min(need);
122 for idx in 0..batch {
123 match TcpStream::connect(self.addr).await {
124 Ok(stream) => {
125 let mut pool = self.connections.lock();
126 if pool.len() < self.max_idle {
127 pool.push_back(PooledConnection {
128 stream,
129 created_at: Instant::now(),
130 });
131 set_fallback_pool_size(pool.len());
132 debug!(addr = %self.addr, size = pool.len(), "warm connection added");
133 }
134 }
135 Err(err) => {
136 record_fallback_pool_warm_fail();
137 debug!(addr = %self.addr, error = %err, "warm connection failed");
138 break;
139 }
140 }
141 if self.fill_delay > Duration::from_millis(0) && idx + 1 < batch {
142 tokio::time::sleep(self.fill_delay).await;
143 }
144 }
145 }
146}
147
#[cfg(test)]
impl ConnectionPool {
    /// Test hook: runs exactly one warm-fill pass in the caller's task, without
    /// starting the periodic background task.
    async fn warm_fill_once(&self) {
        Self::warm_fill(self).await
    }
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Binds a loopback listener that accepts (and immediately drops) every incoming
    /// connection on a background thread, and returns its address.
    fn spawn_sink_listener() -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || while let Ok((_, _)) = listener.accept() {});
        addr
    }

    #[tokio::test]
    async fn test_pool_basic() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 2, 60, 2, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2, "warm fill should open fill_batch connections");

        let conn1 = pool.get().await.unwrap();
        assert_eq!(pool.size(), 1, "get should take exactly one pooled connection");

        drop(conn1);
    }

    #[tokio::test]
    async fn test_pool_max_idle() {
        let addr = spawn_sink_listener();
        // fill_batch larger than max_idle, so the max_idle cap is actually exercised.
        let pool = ConnectionPool::new(addr, 2, 60, 5, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        // A second pass on a full pool must not grow it.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
    }

    #[tokio::test]
    async fn test_cleanup_removes_expired() {
        let addr = spawn_sink_listener();
        // max_age of zero: every pooled connection is expired as soon as it is added.
        let pool = ConnectionPool::new(addr, 2, 0, 2, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        pool.cleanup();
        assert_eq!(pool.size(), 0);
    }
}