1use std::{
7 collections::VecDeque,
8 net::SocketAddr,
9 sync::Arc,
10 time::{Duration, Instant},
11};
12
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tracing::debug;
16use trojan_metrics::{record_fallback_pool_warm_fail, set_fallback_pool_size};
17
18struct PooledConnection {
20 stream: TcpStream,
21 created_at: Instant,
22}
23
24pub struct ConnectionPool {
26 addr: SocketAddr,
27 connections: Arc<Mutex<VecDeque<PooledConnection>>>,
28 max_idle: usize,
29 max_age: Duration,
30 fill_batch: usize,
31 fill_delay: Duration,
32}
33
34impl ConnectionPool {
35 pub fn new(
37 addr: SocketAddr,
38 max_idle: usize,
39 max_age_secs: u64,
40 fill_batch: usize,
41 fill_delay_ms: u64,
42 ) -> Self {
43 let pool = Self {
44 addr,
45 connections: Arc::new(Mutex::new(VecDeque::new())),
46 max_idle,
47 max_age: Duration::from_secs(max_age_secs),
48 fill_batch,
49 fill_delay: Duration::from_millis(fill_delay_ms),
50 };
51 set_fallback_pool_size(0);
52 pool
53 }
54
55 pub async fn get(&self) -> std::io::Result<TcpStream> {
57 let pooled = {
59 let mut pool = self.connections.lock();
60 let pooled = pool.pop_front();
61 set_fallback_pool_size(pool.len());
62 pooled
63 };
64 if let Some(pooled) = pooled {
65 if pooled.created_at.elapsed() < self.max_age {
66 debug!(addr = %self.addr, "using pooled connection");
67 return Ok(pooled.stream);
68 }
69 debug!(addr = %self.addr, "discarding expired pooled connection");
70 }
71
72 debug!(addr = %self.addr, "creating new connection");
74 TcpStream::connect(self.addr).await
75 }
76
77 pub fn cleanup(&self) {
80 let mut pool = self.connections.lock();
81 let before = pool.len();
82 pool.retain(|conn| conn.created_at.elapsed() < self.max_age);
83 let removed = before - pool.len();
84 set_fallback_pool_size(pool.len());
85 if removed > 0 {
86 debug!(addr = %self.addr, removed, remaining = pool.len(), "cleaned up expired connections");
87 }
88 }
89
90 pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
92 let pool = self.clone();
93 tokio::spawn(async move {
94 loop {
95 tokio::time::sleep(interval).await;
96 pool.cleanup();
97 pool.warm_fill().await;
98 }
99 });
100 }
101
102 pub fn size(&self) -> usize {
104 self.connections.lock().len()
105 }
106
107 async fn warm_fill(&self) {
109 let need = {
110 let pool = self.connections.lock();
111 if pool.len() >= self.max_idle {
112 return;
113 }
114 self.max_idle - pool.len()
115 };
116 if need == 0 {
117 return;
118 }
119 let batch = self.fill_batch.min(need);
120 for idx in 0..batch {
121 match TcpStream::connect(self.addr).await {
122 Ok(stream) => {
123 let mut pool = self.connections.lock();
124 if pool.len() < self.max_idle {
125 pool.push_back(PooledConnection {
126 stream,
127 created_at: Instant::now(),
128 });
129 set_fallback_pool_size(pool.len());
130 debug!(addr = %self.addr, size = pool.len(), "warm connection added");
131 }
132 }
133 Err(err) => {
134 record_fallback_pool_warm_fail();
135 debug!(addr = %self.addr, error = %err, "warm connection failed");
136 break;
137 }
138 }
139 if self.fill_delay > Duration::from_millis(0) && idx + 1 < batch {
140 tokio::time::sleep(self.fill_delay).await;
141 }
142 }
143 }
144}
145
#[cfg(test)]
impl ConnectionPool {
    /// Test hook: performs exactly one refill pass without spawning the
    /// background task.
    async fn warm_fill_once(&self) {
        Self::warm_fill(self).await
    }
}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Binds a loopback listener on an ephemeral port. A background thread
    /// accepts, and immediately drops, every incoming connection.
    fn spawn_acceptor() -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || while let Ok((_, _)) = listener.accept() {});
        addr
    }

    #[tokio::test]
    async fn test_pool_basic() {
        let addr = spawn_acceptor();
        let pool = ConnectionPool::new(addr, 2, 60, 2, 0);

        // The listener is reachable, so one pass fills the whole batch.
        pool.warm_fill_once().await;
        let initial_size = pool.size();
        assert_eq!(initial_size, 2);

        // A fresh pooled connection is handed out and removed from the pool.
        let conn1 = pool.get().await.unwrap();
        assert_eq!(pool.size(), initial_size - 1);

        drop(conn1);
    }

    #[tokio::test]
    async fn test_pool_max_idle() {
        let addr = spawn_acceptor();
        // fill_batch exceeds max_idle, so the cap (not the batch) must bound the pool.
        let pool = ConnectionPool::new(addr, 2, 60, 5, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        // Another pass over a full pool must not grow it.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
    }

    #[tokio::test]
    async fn test_cleanup_removes_expired() {
        let addr = spawn_acceptor();
        // max_age of zero: every connection counts as expired once added.
        let pool = ConnectionPool::new(addr, 2, 0, 2, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        pool.cleanup();
        assert_eq!(pool.size(), 0);
    }
}