http_pool/http2/
pool.rs

1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::collections::hash_map::Entry;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicUsize};
9use std::sync::{Arc, Mutex};
10use crate::body::VariantBody;
11
12struct Inner(Vec<SendRequest>);
13
14impl Inner {
15    fn new() -> Self {
16        Self(vec![])
17    }
18
19    fn add_sr(&mut self, sr: SendRequest) {
20        self.0.push(sr);
21    }
22
23    fn remove_sr(&mut self, other: SendRequest) -> bool {
24        if let Some(item) = self.0.iter().position(|sr| sr == &other) {
25            self.0.remove(item);
26            true
27        } else {
28            false
29        }
30    }
31
32    // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
33    fn get_sr(&mut self, max_streams: Option<usize>) -> (usize, Option<SendRequest>) {
34        // 删除个数
35        let mut dels = 0;
36
37        for idx in (0..self.0.len()).rev() {
38            let sr = &self.0[idx];
39            if sr.is_closed() {
40                dels += 1;
41                self.0.remove(idx);
42            } else if sr.limited(max_streams) {
43                continue;
44            } else {
45                return (dels, Some(sr.clone()));
46            }
47        }
48
49        (dels, None)
50    }
51
52    fn len(&self) -> usize {
53        self.0.len()
54    }
55}
56
57/// http2连接池
58/// 连接池中的连接受max_conn数据限制
59/// 默认的keepalive是一分钟
60/// 连接池中的每个连接可并发流数受max_streams影响, 默认是20个
61/// 有三种行为会导致连接池中的连接被释放
62///     1.对方主动关闭连接
63///     2.remove_backend被调用清除后端地址, 这会导致该地址的所有连接失效
64///     3.在外界没有引用sender达到keepalive时间后, 注意如果没有设置keepalive则此条无效
65pub struct Pool {
66    state: net_pool::pool::BaseState,
67    max_streams: AtomicUsize,
68    free_conn_map: Mutex<HashMap<u64, Inner>>,
69    use_tls: AtomicBool,
70}
71
72impl Pool {
73    pub fn new(strategy: Arc<dyn Strategy>, mut max_streams: Option<usize>) -> Self {
74        if let Some(0) = max_streams {
75            max_streams = None;
76        }
77
78        let p = Pool {
79            state: net_pool::pool::BaseState::new(strategy),
80            max_streams: AtomicUsize::new(usize::MAX),
81            free_conn_map: Mutex::new(HashMap::new()),
82            use_tls: AtomicBool::new(false),
83        };
84
85        p.set_max_streams(max_streams);
86        <Pool as net_pool::pool::Pool>::set_keepalive(
87            &p,
88            Some(std::time::Duration::from_secs(60 * 1)),
89        );
90        p
91    }
92
93    pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
94        if let Some(0) = max_streams {
95            max_streams = None;
96        }
97
98        match max_streams {
99            None => self.max_streams.store(usize::MAX, Relaxed),
100            Some(v) => self.max_streams.store(v, Relaxed),
101        }
102    }
103
104    pub fn get_max_streams(&self) -> Option<usize> {
105        match self.max_streams.load(Relaxed) {
106            usize::MAX => None,
107            v => Some(v),
108        }
109    }
110}
111
112impl Default for Pool {
113    fn default() -> Self {
114        Pool::new(
115            Arc::new(net_pool::strategy::CHStrategy::default()),
116            Some(200),
117        )
118    }
119}
120
121impl<L: Strategy + 'static> From<L> for Pool {
122    fn from(value: L) -> Self {
123        Self::new(Arc::new(value), Some(200))
124    }
125}
126
127impl net_pool::pool::Pool for Pool {
128    net_pool::macros::base_pool_impl! {state}
129
130    fn remove_backend(&self, addr: &Address) -> bool {
131        if self.state.lb_strategy.remove_backend(addr) {
132            // 清除缓存
133            self.clear_bs_sr(addr);
134            true
135        } else {
136            false
137        }
138    }
139
140    fn use_tls(&self, tls: bool) {
141        self.use_tls.store(tls, Relaxed);
142    }
143
144    fn tls(&self) -> bool {
145        self.use_tls.load(Relaxed)
146    }
147}
148
149impl Pool {
150    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
151        let mut guard = self.free_conn_map.lock().unwrap();
152        let inners = guard.get_mut(&bs.hash_code())?;
153
154        let sr = {
155            let (del_cnt, sr) = inners.get_sr(self.get_max_streams());
156            if del_cnt > 0 {
157                assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
158                debug!(
159                    "[desc] current connection count: {}",
160                    self.state.cur_conn.load(Relaxed)
161                );
162            }
163            sr
164        };
165
166        if let Some(sr) = sr {
167            let tls = <Pool as net_pool::pool::Pool>::tls(self);
168            Some(Sender::new(
169                sr,
170                crate::utils::base_url(tls, bs.get_address()),
171            ))
172        } else {
173            None
174        }
175    }
176
177    fn add_sr(&self, hash_code: u64, sr: SendRequest) {
178        let mut guard = self.free_conn_map.lock().unwrap();
179        match guard.entry(hash_code) {
180            Entry::Occupied(mut o) => {
181                o.get_mut().add_sr(sr);
182            }
183            Entry::Vacant(v) => {
184                let mut inner = Inner::new();
185                inner.add_sr(sr);
186                v.insert(inner);
187            }
188        }
189    }
190
191    fn remove_sr(&self, hash_code: u64, sr: SendRequest) {
192        let mut guard = self.free_conn_map.lock().unwrap();
193        if let Some(i) = guard.get_mut(&hash_code) {
194            if i.remove_sr(sr) {
195                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
196                debug!(
197                    "[desc] current connection count: {}",
198                    self.state.cur_conn.load(Relaxed)
199                );
200            }
201        }
202    }
203
204    fn clear_bs_sr(&self, addr: &Address) {
205        let mut guard = self.free_conn_map.lock().unwrap();
206        if let Some(inners) = guard.remove(&addr.hash_code()) {
207            assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
208            debug!(
209                "[desc] current connection count: {}",
210                self.state.cur_conn.load(Relaxed)
211            );
212        }
213    }
214
215    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
216        pool: Arc<Self>,
217        c: C,
218        sr: http2::SendRequest<VariantBody>,
219        bs: &BackendState,
220    ) -> Sender {
221        let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
222        let sr = SendRequest::new(sr);
223        let sender = Sender::new(sr.clone(), crate::utils::base_url(tls, bs.get_address()));
224
225        let code = bs.hash_code();
226
227        // 加入
228        // 添加sr进去缓存时, 可能此时bs已经被删除了
229        // 由于缓存中持有sr, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
230        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
231        pool.add_sr(code, sr.clone());
232        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
233
234        // 驱动
235        tokio_spawn! {
236            instrument_current_span! {
237                async move {
238                    let _r = crate::utils::run_conn(sr.clone(), c, ka).await;
239                    debug!("connection closed: {:?}", _r);
240                    pool.remove_sr(code, sr);
241                }
242            }
243        };
244
245        sender
246    }
247
248    async fn create_tls_sender(
249        self: Arc<Self>,
250        bs: &BackendState,
251        exec: hyper_util::rt::TokioExecutor,
252    ) -> Result<Sender, Error> {
253        let addr = bs.get_address();
254
255        // 获取连接
256        let tcp = crate::utils::create_https_stream(addr).await?;
257        let tls_tcp =
258            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
259                .await?;
260
261        // 握手建立 sender 和 connection
262        let max_streams = self.get_max_streams();
263        let io = hyper_util::rt::TokioIo::new(tls_tcp);
264        let pair = http2::Builder::new(exec)
265            .max_concurrent_streams(max_streams.map(|m| m as u32))
266            .handshake(io)
267            .await
268            .map_err(|e| Error::from_other(e))?;
269
270        // 启动连接驱动
271        let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
272        Ok(sender)
273    }
274
275    async fn create_non_tls_sender(
276        self: Arc<Self>,
277        bs: &BackendState,
278        exec: hyper_util::rt::TokioExecutor,
279    ) -> Result<Sender, Error> {
280        let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
281
282        // 握手建立 sender 和 connection
283        let io = hyper_util::rt::TokioIo::new(tcp);
284        let pair = http2::handshake(exec, io)
285            .await
286            .map_err(|e| Error::from_other(e))?;
287
288        // 启动连接驱动
289        let sender = Pool::run_conn(self, pair.1, pair.0, bs);
290        Ok(sender)
291    }
292
293    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
294        let exec = hyper_util::rt::tokio::TokioExecutor::new();
295        if <Pool as net_pool::pool::Pool>::tls(&self) {
296            self.create_tls_sender(bs, exec).await
297        } else {
298            self.create_non_tls_sender(bs, exec).await
299        }
300    }
301
302    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
303        let sender = match self.get_sender(&bs) {
304            Some(s) => Ok(s),
305            None => {
306                // 预分配数量
307                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
308                self.clone().create_sender(&bs).await.map(|s| {
309                    debug!(
310                        "[incr] current connection count: {}",
311                        self.state.cur_conn.load(Relaxed)
312                    );
313                    s
314                })
315            }
316        };
317
318        if sender.is_err() {
319            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
320        }
321
322        sender
323    }
324}
325
326pub trait HttpPool {
327    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
328
329    fn target(
330        self: Arc<Self>,
331        addr: &Address,
332    ) -> impl Future<Output = Result<Sender, Error>> + Send;
333}
334
335impl HttpPool for Pool {
336    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
337        let bs = self
338            .state
339            .lb_strategy
340            .get_backend(key)
341            .ok_or(Error::NoBackend);
342
343        self.real(bs?).await
344    }
345
346    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
347        if !self.state.lb_strategy.contain(addr) {
348            Err(Error::NoBackend)
349        } else {
350            self.real(BackendState::new(None, addr.clone())).await
351        }
352    }
353}
354
355pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
356    HttpPool::get(pool, key).await
357}