// http_pool/http2/pool.rs

1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::sync::atomic::Ordering::Relaxed;
7use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
8use std::sync::{Arc, LazyLock, Mutex};
9
/// Process-wide source of connection ids for pooled HTTP/2 connections.
///
/// The first id handed out is 1; every call to `fetch_add` yields the next value.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
12
13struct Inner(SendRequest, Arc<()>, u32);
14
15impl Inner {
16    fn new(sender: SendRequest) -> Self {
17        Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
18    }
19
20    fn id(&self) -> u32 {
21        self.2
22    }
23
24    fn new_sender(&self, base_url: String) -> Sender {
25        Sender::new(self.0.clone(), base_url, self.1.clone())
26    }
27
28    /// 引用数量
29    fn ref_count(&self) -> usize {
30        Arc::strong_count(&self.1)
31    }
32
33    /// 是否达到最大数
34    fn limited(&self, max_streams: Option<usize>) -> bool {
35        if let Some(max) = max_streams {
36            if self.ref_count() >= max + 1 {
37                return true;
38            }
39        }
40        return false;
41    }
42
43    fn is_closed(&self) -> bool {
44        self.0.is_closed()
45    }
46
47    fn reference(&self) -> Arc<()> {
48        self.1.clone()
49    }
50}
51
52/// http2连接池
53/// 连接池中的连接受max_conn数据限制
54/// 默认的keepalive是一分钟
55/// 连接池中的每个连接可并发流数受max_streams影响, 默认是20个
56/// 有三种行为会导致连接池中的连接被释放
57///     1.对方主动关闭连接
58///     2.remove_backend被调用清除后端地址, 这会导致该地址的所有连接失效
59///     3.在外界没有引用sender达到keepalive时间后, 注意如果没有设置keepalive则此条无效
60pub struct Pool {
61    id: String,
62    state: net_pool::pool::BaseState,
63    max_streams: AtomicUsize,
64    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
65    use_tls: AtomicBool,
66}
67
68impl Pool {
69    pub fn new(id: String, strategy: Arc<dyn Strategy>, mut max_streams: Option<usize>) -> Self {
70        if let Some(0) = max_streams {
71            max_streams = None;
72        }
73
74        let p = Pool {
75            id,
76            state: net_pool::pool::BaseState::new(strategy),
77            max_streams: AtomicUsize::new(usize::MAX),
78            free_conn_map: Mutex::new(HashMap::new()),
79            use_tls: AtomicBool::new(false),
80        };
81
82        p.set_max_streams(max_streams);
83        <Pool as net_pool::pool::Pool>::set_keepalive(
84            &p,
85            Some(std::time::Duration::from_secs(60 * 1)),
86        );
87        p
88    }
89
90    pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
91        if let Some(0) = max_streams {
92            max_streams = None;
93        }
94
95        match max_streams {
96            None => self.max_streams.store(usize::MAX, Relaxed),
97            Some(v) => self.max_streams.store(v, Relaxed),
98        }
99    }
100    pub fn get_max_streams(&self) -> Option<usize> {
101        match self.max_streams.load(Relaxed) {
102            usize::MAX => None,
103            v => Some(v),
104        }
105    }
106
107    pub fn set_id<I: Into<String>>(&mut self, id: I) {
108        self.id = id.into();
109    }
110}
111
112impl Default for Pool {
113    fn default() -> Self {
114        Pool::new(
115            "".to_string(),
116            Arc::new(net_pool::strategy::CHStrategy::default()),
117            Some(200),
118        )
119    }
120}
121
122impl<L: Strategy + 'static> From<L> for Pool {
123    fn from(value: L) -> Self {
124        Self::new("".to_string(), Arc::new(value), Some(200))
125    }
126}
127
128impl net_pool::pool::Pool for Pool {
129    net_pool::macros::base_pool_impl! {state}
130
131    fn id(&self) -> &str {
132        &self.id
133    }
134
135    fn remove_backend(&self, addr: &Address) -> bool {
136        if self.state.lb_strategy.remove_backend(addr) {
137            // 清除缓存
138            self.clear_bs_inner(addr);
139            true
140        } else {
141            false
142        }
143    }
144
145    fn use_tls(&self, tls: bool) {
146        self.use_tls.store(tls, Relaxed);
147    }
148
149    fn tls(&self) -> bool {
150        self.use_tls.load(Relaxed)
151    }
152}
153
154impl Pool {
155    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
156        let mut guard = self.free_conn_map.lock().unwrap();
157
158        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
159        let inners = guard.get_mut(&bs.hash_code())?;
160        for idx in (0..inners.len()).rev() {
161            let inner = &mut inners[idx];
162            if inner.is_closed() {
163                // 无效了
164                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
165                debug!(
166                    "[http/2.0 pool] [desc] current connection count: {}",
167                    self.state.cur_conn.load(Relaxed)
168                );
169                inners.remove(idx);
170                continue;
171            }
172
173            if inner.limited(self.get_max_streams()) {
174                continue;
175            }
176
177            let tls = <Pool as net_pool::pool::Pool>::tls(self);
178            return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
179        }
180
181        None
182    }
183
184    fn add_inner(&self, hash_code: u64, inner: Inner) {
185        let mut guard = self.free_conn_map.lock().unwrap();
186        if let Some(inners) = guard.get_mut(&hash_code) {
187            inners.push(inner);
188        } else {
189            guard.insert(hash_code, vec![inner]);
190        }
191    }
192
193    fn clear_bs_inner(&self, addr: &Address) {
194        let mut guard = self.free_conn_map.lock().unwrap();
195        if let Some(inners) = guard.remove(&addr.hash_code()) {
196            assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
197            debug!(
198                "[http/2.0 pool] [desc] current connection count: {}",
199                self.state.cur_conn.load(Relaxed)
200            );
201        }
202    }
203
204    fn remove_inner(&self, hash_code: u64, id: u32) {
205        let mut guard = self.free_conn_map.lock().unwrap();
206        if let Some(inners) = guard.get_mut(&hash_code) {
207            let mut del = false;
208            inners.retain(|inner| {
209                if inner.id() == id {
210                    del = true;
211                    assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
212                    false
213                } else {
214                    true
215                }
216            });
217            if del {
218                debug!(
219                    "[http/2.0 pool] [desc] current connection count: {}",
220                    self.state.cur_conn.load(Relaxed)
221                );
222            }
223        }
224    }
225
226    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
227        pool: Arc<Self>,
228        c: C,
229        sr: SendRequest,
230        bs: &BackendState,
231    ) -> Sender {
232        let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
233        let inner = Inner::new(sr);
234        let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
235
236        let tuple = (
237            bs.hash_code(),
238            <Pool as net_pool::pool::Pool>::get_keepalive(&pool),
239            inner.reference(),
240            inner.id(),
241        );
242
243        // 加入
244        // 添加inner进去缓存时, 可能此时bs已经被删除了
245        // 由于缓存中持有inner, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
246        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
247        pool.add_inner(tuple.0, inner);
248
249        // 驱动
250        tokio_spawn! {
251            instrument_current_span! {
252                async move {
253                    let _r = crate::utils::run_conn(tuple.2, c, tuple.1).await;
254                    debug!("[http/2.0 pool] connection closed: {:?}", _r);
255                    pool.remove_inner(tuple.0, tuple.3);
256                }
257            }
258        };
259
260        sender
261    }
262
263    async fn create_tls_sender(
264        self: Arc<Self>,
265        bs: &BackendState,
266        exec: hyper_util::rt::TokioExecutor,
267    ) -> Result<Sender, Error> {
268        let addr = bs.get_address();
269
270        // 获取连接
271        let tcp = crate::utils::create_https_stream(addr).await?;
272        let tls_tcp =
273            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
274                .await?;
275
276        // 握手建立 sender 和 connection
277        let max_streams = self.get_max_streams();
278        let io = hyper_util::rt::TokioIo::new(tls_tcp);
279        let pair = http2::Builder::new(exec)
280            .max_concurrent_streams(max_streams.map(|m| m as u32))
281            .handshake(io)
282            .await
283            .map_err(|e| Error::from_other(e))?;
284
285        // 启动连接驱动
286        let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
287        Ok(sender)
288    }
289
290    async fn create_non_tls_sender(
291        self: Arc<Self>,
292        bs: &BackendState,
293        exec: hyper_util::rt::TokioExecutor,
294    ) -> Result<Sender, Error> {
295        let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
296
297        // 握手建立 sender 和 connection
298        let io = hyper_util::rt::TokioIo::new(tcp);
299        let pair = http2::handshake(exec, io)
300            .await
301            .map_err(|e| Error::from_other(e))?;
302
303        // 启动连接驱动
304        let sender = Pool::run_conn(self, pair.1, pair.0, bs);
305        Ok(sender)
306    }
307
308    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
309        let exec = hyper_util::rt::tokio::TokioExecutor::new();
310        if <Pool as net_pool::pool::Pool>::tls(&self) {
311            self.create_tls_sender(bs, exec).await
312        } else {
313            self.create_non_tls_sender(bs, exec).await
314        }
315    }
316
317    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
318        let sender = match self.get_sender(&bs) {
319            Some(s) => Ok(s),
320            None => {
321                // 预分配数量
322                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
323                self.clone().create_sender(&bs).await.map(|s| {
324                    debug!(
325                        "[http/2.0 pool] [incr] current connection count: {}",
326                        self.state.cur_conn.load(Relaxed)
327                    );
328                    s
329                })
330            }
331        };
332
333        if sender.is_err() {
334            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
335        }
336
337        sender
338    }
339}
340
341pub trait HttpPool {
342    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
343
344    fn target(
345        self: Arc<Self>,
346        addr: &Address,
347    ) -> impl Future<Output = Result<Sender, Error>> + Send;
348}
349
350impl HttpPool for Pool {
351    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
352        let bs = self
353            .state
354            .lb_strategy
355            .get_backend(key)
356            .ok_or(Error::NoBackend);
357
358        self.real(bs?).await
359    }
360
361    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
362        if !self.state.lb_strategy.contain(addr) {
363            Err(Error::NoBackend)
364        } else {
365            self.real(BackendState::new(None, addr.clone())).await
366        }
367    }
368}
369
370pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
371    HttpPool::get(pool, key).await
372}