http_pool/http2/
pool.rs

1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::strategy::LbStrategy;
5use net_pool::{Error, debug, instrument_current_span, tokio_spawn};
6use std::collections::HashMap;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
9use std::sync::{Arc, LazyLock, Mutex};
10
11/// sender id 分配
12static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::new(1));
13
14struct Inner(SendRequest, Arc<()>, u32);
15
16impl Inner {
17    fn new(sender: SendRequest) -> Self {
18        Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
19    }
20
21    fn id(&self) -> u32 {
22        self.2
23    }
24
25    fn new_sender(&self, base_url: String) -> Sender {
26        Sender::new(self.0.clone(), base_url, self.1.clone())
27    }
28
29    /// 引用数量
30    fn ref_count(&self) -> usize {
31        Arc::strong_count(&self.1)
32    }
33
34    /// 是否达到最大数
35    fn limited(&self, max_streams: Option<usize>) -> bool {
36        if let Some(max) = max_streams {
37            if self.ref_count() >= max + 1 {
38                return true;
39            }
40        }
41        return false;
42    }
43
44    fn is_closed(&self) -> bool {
45        self.0.is_closed()
46    }
47
48    fn reference(&self) -> Arc<()> {
49        self.1.clone()
50    }
51}
52
53/// http2连接池
54/// 连接池中的连接受max_conn数据限制
55/// 默认的keepalive是一分钟
56/// 连接池中的每个连接可并发流数受max_streams影响, 默认是20个
57/// 有三种行为会导致连接池中的连接被释放
58///     1.对方主动关闭连接
59///     2.remove_backend被调用清除后端地址, 这会导致该地址的所有连接失效
60///     3.在外界没有引用sender达到keepalive时间后, 注意如果没有设置keepalive则此条无效
61pub struct Pool {
62    id: String,
63    state: net_pool::pool::BaseState,
64    max_streams: AtomicUsize,
65    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
66    use_tls: AtomicBool,
67}
68
69impl Pool {
70    pub fn new(id: String, strategy: Arc<dyn LbStrategy>, mut max_streams: Option<usize>) -> Self {
71        if let Some(0) = max_streams {
72            max_streams = None;
73        }
74
75        let p = Pool {
76            id,
77            state: net_pool::pool::BaseState::new(strategy),
78            max_streams: AtomicUsize::new(usize::MAX),
79            free_conn_map: Mutex::new(HashMap::new()),
80            use_tls: AtomicBool::new(false),
81        };
82
83        p.set_max_streams(max_streams);
84        <Pool as net_pool::pool::Pool>::set_keepalive(
85            &p,
86            Some(std::time::Duration::from_secs(60 * 1)),
87        );
88        p
89    }
90
91    pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
92        if let Some(0) = max_streams {
93            max_streams = None;
94        }
95
96        match max_streams {
97            None => self.max_streams.store(usize::MAX, Relaxed),
98            Some(v) => self.max_streams.store(v, Relaxed),
99        }
100    }
101    pub fn get_max_streams(&self) -> Option<usize> {
102        match self.max_streams.load(Relaxed) {
103            usize::MAX => None,
104            v => Some(v),
105        }
106    }
107
108    pub fn set_id<I: Into<String>>(&mut self, id: I) {
109        self.id = id.into();
110    }
111}
112
113impl Default for Pool {
114    fn default() -> Self {
115        Pool::new(
116            "".to_string(),
117            Arc::new(net_pool::strategy::CHStrategy::default()),
118            Some(200),
119        )
120    }
121}
122
123impl<L: LbStrategy + 'static> From<L> for Pool {
124    fn from(value: L) -> Self {
125        Self::new("".to_string(), Arc::new(value), Some(200))
126    }
127}
128
129impl net_pool::pool::Pool for Pool {
130    net_pool::macros::base_pool_impl! {state}
131
132    fn id(&self) -> &str {
133        &self.id
134    }
135
136    fn remove_backend(&self, addr: &Address) -> bool {
137        if self.state.lb_strategy.remove_backend(addr) {
138            // 清除缓存
139            self.clear_bs_inner(addr);
140            true
141        } else {
142            false
143        }
144    }
145
146    fn use_tls(&self, tls: bool) {
147        self.use_tls.store(tls, Relaxed);
148    }
149
150    fn tls(&self) -> bool {
151        self.use_tls.load(Relaxed)
152    }
153}
154
155impl Pool {
156    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
157        let mut guard = self.free_conn_map.lock().unwrap();
158
159        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
160        let inners = guard.get_mut(&bs.hash_code())?;
161        for idx in (0..inners.len()).rev() {
162            let inner = &mut inners[idx];
163            if inner.is_closed() {
164                // 无效了
165                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
166                debug!(
167                    "[http/2.0 pool] [desc] current connection count: {}",
168                    self.state.cur_conn.load(Relaxed)
169                );
170                inners.remove(idx);
171                continue;
172            }
173
174            if inner.limited(self.get_max_streams()) {
175                continue;
176            }
177
178            let tls = <Pool as net_pool::pool::Pool>::tls(self);
179            return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
180        }
181
182        None
183    }
184
185    fn add_inner(&self, hash_code: u64, inner: Inner) {
186        let mut guard = self.free_conn_map.lock().unwrap();
187        if let Some(inners) = guard.get_mut(&hash_code) {
188            inners.push(inner);
189        } else {
190            guard.insert(hash_code, vec![inner]);
191        }
192    }
193
194    fn clear_bs_inner(&self, addr: &Address) {
195        let mut guard = self.free_conn_map.lock().unwrap();
196        if let Some(inners) = guard.remove(&addr.hash_code()) {
197            assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
198            debug!(
199                "[http/2.0 pool] [desc] current connection count: {}",
200                self.state.cur_conn.load(Relaxed)
201            );
202        }
203    }
204
205    fn remove_inner(&self, hash_code: u64, id: u32) {
206        let mut guard = self.free_conn_map.lock().unwrap();
207        if let Some(inners) = guard.get_mut(&hash_code) {
208            let mut del = false;
209            inners.retain(|inner| {
210                if inner.id() == id {
211                    del = true;
212                    assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
213                    false
214                } else {
215                    true
216                }
217            });
218            if del {
219                debug!(
220                    "[http/2.0 pool] [desc] current connection count: {}",
221                    self.state.cur_conn.load(Relaxed)
222                );
223            }
224        }
225    }
226
227    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
228        pool: Arc<Self>,
229        c: C,
230        sr: SendRequest,
231        bs: &BackendState,
232    ) -> Sender {
233        let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
234        let inner = Inner::new(sr);
235        let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
236
237        let tuple = (
238            bs.hash_code(),
239            <Pool as net_pool::pool::Pool>::get_keepalive(&pool),
240            inner.reference(),
241            inner.id(),
242        );
243
244        // 加入
245        // 添加inner进去缓存时, 可能此时bs已经被删除了
246        // 由于缓存中持有inner, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
247        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
248        pool.add_inner(tuple.0, inner);
249
250        // 驱动
251        tokio_spawn! {
252            instrument_current_span! {
253                async move {
254                    let _r = crate::utils::run_conn(tuple.2, c, tuple.1).await;
255                    debug!("[http/2.0 pool] connection closed: {:?}", _r);
256                    pool.remove_inner(tuple.0, tuple.3);
257                }
258            }
259        };
260
261        sender
262    }
263
264    async fn create_tls_sender(
265        self: Arc<Self>,
266        bs: &BackendState,
267        exec: hyper_util::rt::TokioExecutor,
268    ) -> Result<Sender, Error> {
269        let addr = bs.get_address();
270
271        // 获取连接
272        let tcp = crate::utils::create_https_stream(addr).await?;
273        let tls_tcp =
274            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
275                .await?;
276
277        // 握手建立 sender 和 connection
278        let max_streams = self.get_max_streams();
279        let io = hyper_util::rt::TokioIo::new(tls_tcp);
280        let pair = http2::Builder::new(exec)
281            .max_concurrent_streams(max_streams.map(|m| m as u32))
282            .handshake(io)
283            .await
284            .map_err(|e| Error::from_other(e))?;
285
286        // 启动连接驱动
287        let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
288        Ok(sender)
289    }
290
291    async fn create_non_tls_sender(
292        self: Arc<Self>,
293        bs: &BackendState,
294        exec: hyper_util::rt::TokioExecutor,
295    ) -> Result<Sender, Error> {
296        let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
297
298        // 握手建立 sender 和 connection
299        let io = hyper_util::rt::TokioIo::new(tcp);
300        let pair = http2::handshake(exec, io)
301            .await
302            .map_err(|e| Error::from_other(e))?;
303
304        // 启动连接驱动
305        let sender = Pool::run_conn(self, pair.1, pair.0, bs);
306        Ok(sender)
307    }
308
309    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
310        let exec = hyper_util::rt::tokio::TokioExecutor::new();
311        if <Pool as net_pool::pool::Pool>::tls(&self) {
312            self.create_tls_sender(bs, exec).await
313        } else {
314            self.create_non_tls_sender(bs, exec).await
315        }
316    }
317
318    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
319        let sender = match self.get_sender(&bs) {
320            Some(s) => Ok(s),
321            None => {
322                // 预分配数量
323                net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
324                self.clone().create_sender(&bs).await.map(|s| {
325                    debug!(
326                        "[http/2.0 pool] [incr] current connection count: {}",
327                        self.state.cur_conn.load(Relaxed)
328                    );
329                    s
330                })
331            }
332        };
333
334        if sender.is_err() {
335            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
336        }
337
338        sender
339    }
340}
341
342pub trait HttpPool {
343    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
344
345    fn target(
346        self: Arc<Self>,
347        addr: &Address,
348    ) -> impl Future<Output = Result<Sender, Error>> + Send;
349}
350
351impl HttpPool for Pool {
352    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
353        let bs = self
354            .state
355            .lb_strategy
356            .get_backend(key)
357            .ok_or(Error::NoBackend);
358
359        self.real(bs?).await
360    }
361
362    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
363        if !self.state.lb_strategy.contain(addr) {
364            Err(Error::NoBackend)
365        } else {
366            self.real(BackendState::new(None, addr.clone())).await   
367        }
368    }
369}
370
371pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
372    HttpPool::get(pool, key).await
373}