Skip to main content

trojan_server/
pool.rs

//! Connection pool for the fallback backend.
//!
//! Warm-pool strategy: pre-connect up to N fresh connections and hand each one out
//! exactly once. Connections are never returned to the pool after use.

6use std::{
7    collections::VecDeque,
8    net::SocketAddr,
9    sync::Arc,
10    time::{Duration, Instant},
11};
12
13use parking_lot::Mutex;
14use tokio::net::TcpStream;
15use tracing::debug;
16use trojan_metrics::{record_fallback_pool_warm_fail, set_fallback_pool_size};
17
18/// A pooled connection with metadata.
19#[derive(Debug)]
20struct PooledConnection {
21    stream: TcpStream,
22    created_at: Instant,
23}
24
25/// Connection pool for a single backend address.
26#[derive(Debug)]
27pub struct ConnectionPool {
28    addr: SocketAddr,
29    connections: Arc<Mutex<VecDeque<PooledConnection>>>,
30    max_idle: usize,
31    max_age: Duration,
32    fill_batch: usize,
33    fill_delay: Duration,
34}
35
36impl ConnectionPool {
37    /// Create a new connection pool.
38    pub fn new(
39        addr: SocketAddr,
40        max_idle: usize,
41        max_age_secs: u64,
42        fill_batch: usize,
43        fill_delay_ms: u64,
44    ) -> Self {
45        let pool = Self {
46            addr,
47            connections: Arc::new(Mutex::new(VecDeque::new())),
48            max_idle,
49            max_age: Duration::from_secs(max_age_secs),
50            fill_batch,
51            fill_delay: Duration::from_millis(fill_delay_ms),
52        };
53        set_fallback_pool_size(0);
54        pool
55    }
56
57    /// Get a fresh connection from the pool or create a new one.
58    pub async fn get(&self) -> std::io::Result<TcpStream> {
59        // Pop one fresh connection if available
60        let pooled = {
61            let mut pool = self.connections.lock();
62            let pooled = pool.pop_front();
63            set_fallback_pool_size(pool.len());
64            pooled
65        };
66        if let Some(pooled) = pooled {
67            if pooled.created_at.elapsed() < self.max_age {
68                debug!(addr = %self.addr, "using pooled connection");
69                return Ok(pooled.stream);
70            }
71            debug!(addr = %self.addr, "discarding expired pooled connection");
72        }
73
74        // No valid pooled connection, create new one
75        debug!(addr = %self.addr, "creating new connection");
76        TcpStream::connect(self.addr).await
77    }
78
79    /// Warm pool maintains fresh connections; used connections are not returned.
80    /// Clean up expired connections.
81    pub fn cleanup(&self) {
82        let mut pool = self.connections.lock();
83        let before = pool.len();
84        pool.retain(|conn| conn.created_at.elapsed() < self.max_age);
85        let removed = before - pool.len();
86        set_fallback_pool_size(pool.len());
87        if removed > 0 {
88            debug!(addr = %self.addr, removed, remaining = pool.len(), "cleaned up expired connections");
89        }
90    }
91
92    /// Start a background warm-fill task.
93    pub fn start_cleanup_task(self: &Arc<Self>, interval: Duration) {
94        let pool = self.clone();
95        tokio::spawn(async move {
96            loop {
97                tokio::time::sleep(interval).await;
98                pool.cleanup();
99                pool.warm_fill().await;
100            }
101        });
102    }
103
104    /// Get current pool size.
105    pub fn size(&self) -> usize {
106        self.connections.lock().len()
107    }
108
109    /// Fill the pool with fresh connections up to max_idle.
110    async fn warm_fill(&self) {
111        let need = {
112            let pool = self.connections.lock();
113            if pool.len() >= self.max_idle {
114                return;
115            }
116            self.max_idle - pool.len()
117        };
118        if need == 0 {
119            return;
120        }
121        let batch = self.fill_batch.min(need);
122        for idx in 0..batch {
123            match TcpStream::connect(self.addr).await {
124                Ok(stream) => {
125                    let mut pool = self.connections.lock();
126                    if pool.len() < self.max_idle {
127                        pool.push_back(PooledConnection {
128                            stream,
129                            created_at: Instant::now(),
130                        });
131                        set_fallback_pool_size(pool.len());
132                        debug!(addr = %self.addr, size = pool.len(), "warm connection added");
133                    }
134                }
135                Err(err) => {
136                    record_fallback_pool_warm_fail();
137                    debug!(addr = %self.addr, error = %err, "warm connection failed");
138                    break;
139                }
140            }
141            if self.fill_delay > Duration::from_millis(0) && idx + 1 < batch {
142                tokio::time::sleep(self.fill_delay).await;
143            }
144        }
145    }
146}
147
#[cfg(test)]
impl ConnectionPool {
    /// Test hook: run a single warm-fill pass, awaited inline by the caller,
    /// without spawning the periodic background task.
    async fn warm_fill_once(&self) {
        Self::warm_fill(self).await
    }
}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;

    /// Bind a loopback listener that accepts (and immediately drops) every
    /// connection on a background thread; returns the address to connect to.
    fn spawn_sink_listener() -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        std::thread::spawn(move || while let Ok((_, _)) = listener.accept() {});
        addr
    }

    #[tokio::test]
    async fn test_pool_basic() {
        let addr = spawn_sink_listener();
        let pool = ConnectionPool::new(addr, 2, 60, 2, 0);

        // A reachable backend lets one pass fill the pool up to max_idle.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        // Each get() hands out (and removes) exactly one pooled stream.
        let conn1 = pool.get().await.unwrap();
        assert_eq!(pool.size(), 1);

        drop(conn1);
    }

    #[tokio::test]
    async fn test_pool_max_idle() {
        let addr = spawn_sink_listener();
        // fill_batch (5) deliberately exceeds max_idle (2) so the cap is what limits the fill.
        let pool = ConnectionPool::new(addr, 2, 60, 5, 0);

        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);

        // A further pass on a full pool must not grow it.
        pool.warm_fill_once().await;
        assert_eq!(pool.size(), 2);
    }

    #[tokio::test]
    async fn test_pool_expired_connections_not_reused() {
        let addr = spawn_sink_listener();
        // max_age of zero: no pooled stream is ever considered fresh.
        let pool = ConnectionPool::new(addr, 2, 0, 2, 0);

        pool.warm_fill_once().await;
        pool.cleanup();
        assert_eq!(pool.size(), 0);

        // get() still succeeds by connecting on demand and leaves the pool empty.
        let conn = pool.get().await.unwrap();
        assert_eq!(pool.size(), 0);

        drop(conn);
    }
}