http_pool/http1/
pool.rs

1use crate::http1::{SendRequest, Sender};
2use hyper::client::conn::http1;
3use net_pool::backend::{Address, BackendState};
4use net_pool::strategy::LbStrategy;
5use net_pool::{Error, debug, instrument_current_span, tokio_spawn};
6use std::collections::HashMap;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicU32};
9use std::sync::{Arc, LazyLock, Mutex};
10
/// Source of sender ids; the first id handed out is 1.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
13
14struct Inner {
15    /// 空闲的sr
16    free: HashMap<u32, Arc<SendRequest>>,
17    /// 在用的sr
18    work: HashMap<u32, Arc<SendRequest>>,
19}
20
21impl Inner {
22    fn new() -> Self {
23        Inner {
24            free: HashMap::new(),
25            work: HashMap::new(),
26        }
27    }
28
29    fn add_work_sr(&mut self, sr: Arc<SendRequest>) -> u32 {
30        let id = SENDER_ID.fetch_add(1, Relaxed);
31        self.work.insert(id, sr);
32        id
33    }
34
35    fn remove_sr(&mut self, id: u32) -> bool {
36        // 不应该同时处于work和free中
37        if self.work.remove(&id).is_some() {
38            true
39        } else if self.free.remove(&id).is_some() {
40            true
41        } else {
42            false
43        }
44    }
45
46    /// 返回值是被删除的元素个数, 及查找到的sr
47    fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
48        if self.free.is_empty() {
49            // 挪动
50            self.work.retain(|id, sr| {
51                // 引用计数为2时,代表外界不存在引用, 队列里一份, run那里一份
52                if Arc::strong_count(sr) == 2 {
53                    self.free.insert(*id, sr.clone());
54                    false
55                } else {
56                    true
57                }
58            });
59        }
60
61        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
62        let mut del = 0;
63        let ks: Vec<u32> = self.free.keys().cloned().collect();
64        for k in ks {
65            if let Some(sr) = self.free.remove(&k) {
66                if sr.is_closed() {
67                    del += 1;
68                } else {
69                    // 挪入work
70                    self.work.insert(k, sr.clone());
71                    return (del, Some(sr));
72                }
73            }
74        }
75
76        (del, None)
77    }
78}
79
/// HTTP/1 connection pool.
///
/// The number of pooled connections is limited by `max_conn`.
/// The default keepalive is one minute.
///
/// A pooled connection is released in three situations:
/// 1. The peer closes the connection.
/// 2. `remove_backend` is called to remove a backend address, which
///    invalidates every connection to that address.
/// 3. No sender has been referenced from outside for the keepalive duration.
///    This does not apply when no keepalive is set.
pub struct Pool {
    // Identifier returned by `id()`.
    id: String,
    // Shared pool state (load-balancing strategy, connection limit and counter).
    state: net_pool::pool::BaseState,
    // Whether new connections are established over TLS.
    use_tls: AtomicBool,
    // Cached connections per backend, keyed by the backend's hash code.
    // Despite the name, each entry holds both idle and in-use connections.
    free_conn_map: Mutex<HashMap<u64, Inner>>,
}
93
94impl Pool {
95    pub fn new(id: String, strategy: Arc<dyn LbStrategy>) -> Self {
96        let p = Pool {
97            id,
98            state: net_pool::pool::BaseState::new(strategy),
99            use_tls: AtomicBool::new(false),
100            free_conn_map: Mutex::new(HashMap::new()),
101        };
102        <Pool as net_pool::pool::Pool>::set_keepalive(
103            &p,
104            Some(std::time::Duration::from_secs(60 * 1)),
105        );
106        p
107    }
108
109    pub fn set_id<I: Into<String>>(&mut self, id: I) {
110        self.id = id.into();
111    }
112}
113
114impl Default for Pool {
115    fn default() -> Self {
116        Pool::new(
117            "".to_string(),
118            Arc::new(net_pool::strategy::CHStrategy::default()),
119        )
120    }
121}
122
123impl<L: LbStrategy + 'static> From<L> for Pool {
124    fn from(value: L) -> Self {
125        Self::new("".to_string(), Arc::new(value))
126    }
127}
128
129impl net_pool::pool::Pool for Pool {
130    net_pool::macros::base_pool_impl! {state}
131
132    fn id(&self) -> &str {
133        &self.id
134    }
135
136    fn remove_backend(&self, addr: &Address) -> bool {
137        if self.state.lb_strategy.remove_backend(addr) {
138            // 清除缓存
139            self.clear_bs_sr(addr.hash_code());
140            true
141        } else {
142            false
143        }
144    }
145
146    fn use_tls(&self, tls: bool) {
147        self.use_tls.store(tls, Relaxed);
148    }
149
150    fn tls(&self) -> bool {
151        self.use_tls.load(Relaxed)
152    }
153}
154
155impl Pool {
156    fn clear_bs_sr(&self, hash_code: u64) {
157        let mut guard = self.free_conn_map.lock().unwrap();
158        if let Some(inner) = guard.remove(&hash_code) {
159            assert!(
160                self.state
161                    .cur_conn
162                    .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
163                    > 0
164            );
165            debug!(
166                "[http/1.1 pool] [desc] current connection count: {}",
167                self.state.cur_conn.load(Relaxed)
168            );
169        }
170    }
171
172    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
173        let mut guard = self.free_conn_map.lock().unwrap();
174        let inner = guard.get_mut(&bs.hash_code())?;
175
176        let sr = {
177            let (del_cnt, sr) = inner.get_sr();
178            if del_cnt > 0 {
179                assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
180                debug!(
181                    "[http/1.1 pool] [desc] current connection count: {}",
182                    self.state.cur_conn.load(Relaxed)
183                );
184            }
185            sr
186        };
187
188        let tls = <Pool as net_pool::pool::Pool>::tls(self);
189        sr.map(|s| Sender::new(s, crate::utils::base_url(tls, bs.get_address())))
190    }
191
192    fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) -> u32 {
193        let mut guard = self.free_conn_map.lock().unwrap();
194        if let Some(inner) = guard.get_mut(&hash_code) {
195            inner.add_work_sr(sr)
196        } else {
197            let mut inner = Inner::new();
198            let id = inner.add_work_sr(sr);
199            guard.insert(hash_code, inner);
200            id
201        }
202    }
203
204    fn remove_sr(&self, hash_code: u64, id: u32) {
205        let mut guard = self.free_conn_map.lock().unwrap();
206        if let Some(inner) = guard.get_mut(&hash_code) {
207            if inner.remove_sr(id) {
208                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
209                debug!(
210                    "[http/1.1 pool] [desc] current connection count: {}",
211                    self.state.cur_conn.load(Relaxed)
212                );
213            }
214        }
215    }
216
217    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
218        pool: Arc<Self>,
219        c: C,
220        sr: Arc<SendRequest>,
221        bs: &BackendState,
222    ) {
223        let hash_code = bs.hash_code();
224
225        // 加入work,且必须在下面执行super::run_conn之前加入
226        // 添加sr进去缓存时, 可能此时bs已经被删除了
227        // 由于缓存中持有sr, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
228        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
229        let id = pool.add_sr(hash_code, sr.clone());
230        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
231
232        tokio_spawn! {
233            instrument_current_span! {
234                async move {
235                    let _r = crate::utils::run_conn(sr, c, ka).await;
236                    debug!("[http/1.1 pool] connection closed: {:?}", _r);
237                    pool.remove_sr(hash_code, id);
238                }
239            }
240        };
241    }
242
243    async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
244        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
245        let addr = bs.get_address();
246
247        // 获取连接
248        let tcp = crate::utils::create_https_stream(addr).await?;
249        let tls_tcp =
250            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
251                .await?;
252
253        // 握手建立 sender 和 connection
254        let io = hyper_util::rt::TokioIo::new(tls_tcp);
255        let pair = http1::handshake(io)
256            .await
257            .map_err(|e| Error::from_other(e))?;
258
259        // 启动连接驱动
260        let sr = Arc::new(pair.0);
261        Pool::run_conn(self, pair.1, sr.clone(), &bs);
262
263        // 返回一份
264        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
265    }
266
267    async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
268        let addr = bs.get_address();
269        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
270
271        // 获取连接
272        let tcp = crate::utils::create_http_stream(addr).await?;
273
274        // 握手建立 sender 和 connection
275        let io = hyper_util::rt::TokioIo::new(tcp);
276        let pair = http1::handshake(io)
277            .await
278            .map_err(|e| Error::from_other(e))?;
279
280        // 启动连接驱动
281        let sr = Arc::new(pair.0);
282        Pool::run_conn(self, pair.1, sr.clone(), &bs);
283
284        // 返回一份
285        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
286    }
287
288    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
289        if <Pool as net_pool::pool::Pool>::tls(&self) {
290            self.create_tls_sender(bs).await
291        } else {
292            self.create_non_tls_sender(bs).await
293        }
294    }
295
296    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
297        if let Some(sender) = self.clone().get_sender(&bs) {
298            return Ok(sender);
299        }
300
301        let sender = {
302            // 预分配数量
303            net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
304            self.clone().create_sender(&bs).await.map(|s| {
305                debug!(
306                    "[http/1.1 pool] [incr] current connection count: {}",
307                    self.state.cur_conn.load(Relaxed)
308                );
309                s
310            })
311        };
312
313        if sender.is_err() {
314            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
315        }
316
317        sender
318    }
319}
320
pub trait HttpPool {
    /// Obtains a sender by applying the pool's strategy to `key`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Obtains a sender for the given address.
    fn target(
        self: Arc<Self>,
        _: &Address,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
331
332impl HttpPool for Pool {
333    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
334        let bs = self
335            .state
336            .lb_strategy
337            .get_backend(key)
338            .ok_or(Error::NoBackend)?;
339
340        self.real(bs).await
341    }
342
343    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
344        if !self.state.lb_strategy.contain(addr) {
345            return Err(Error::NoBackend);
346        }
347
348        self.real(BackendState::new(None, addr.clone())).await
349    }
350}
351
352pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
353    HttpPool::get(pool, key).await
354}