http_pool/http1/
pool.rs

1use crate::http1::{SendRequest, Sender};
2use hyper::client::conn::http1;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::sync::atomic::Ordering::Relaxed;
7use std::sync::atomic::{AtomicBool, AtomicU32};
8use std::sync::{Arc, LazyLock, Mutex};
9
/// Process-wide allocator of sender ids: the first id handed out is 1 and each
/// registration takes the next value (`fetch_add` wraps around on overflow).
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
12
13struct Inner {
14    /// 空闲的sr
15    free: HashMap<u32, Arc<SendRequest>>,
16    /// 在用的sr
17    work: HashMap<u32, Arc<SendRequest>>,
18}
19
20impl Inner {
21    fn new() -> Self {
22        Inner {
23            free: HashMap::new(),
24            work: HashMap::new(),
25        }
26    }
27
28    fn add_work_sr(&mut self, sr: Arc<SendRequest>) -> u32 {
29        let id = SENDER_ID.fetch_add(1, Relaxed);
30        self.work.insert(id, sr);
31        id
32    }
33
34    fn remove_sr(&mut self, id: u32) -> bool {
35        // 不应该同时处于work和free中
36        if self.work.remove(&id).is_some() {
37            true
38        } else if self.free.remove(&id).is_some() {
39            true
40        } else {
41            false
42        }
43    }
44
45    /// 返回值是被删除的元素个数, 及查找到的sr
46    fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
47        if self.free.is_empty() {
48            // 挪动
49            self.work.retain(|id, sr| {
50                // 引用计数为2时,代表外界不存在引用, 队列里一份, run那里一份
51                if Arc::strong_count(sr) == 2 {
52                    self.free.insert(*id, sr.clone());
53                    false
54                } else {
55                    true
56                }
57            });
58        }
59
60        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
61        let mut del = 0;
62        let ks: Vec<u32> = self.free.keys().cloned().collect();
63        for k in ks {
64            if let Some(sr) = self.free.remove(&k) {
65                if sr.is_closed() {
66                    del += 1;
67                } else {
68                    // 挪入work
69                    self.work.insert(k, sr.clone());
70                    return (del, Some(sr));
71                }
72            }
73        }
74
75        (del, None)
76    }
77}
78
79/// http1连接池
80/// 连接池中的连接受max_conn数据限制
81/// 默认的keepalive是一分钟
82/// 有三种行为会导致连接池中的连接被释放
83///     1.对方主动关闭连接
84///     2.remove_backend被调用清除后端地址, 这会导致该地址的所有连接失效
85///     3.在外界没有引用sender达到keepalive时间后, 注意如果没有设置keepalive则此条无效
86pub struct Pool {
87    id: String,
88    state: net_pool::pool::BaseState,
89    use_tls: AtomicBool,
90    free_conn_map: Mutex<HashMap<u64, Inner>>,
91}
92
93impl Pool {
94    pub fn new(id: String, strategy: Arc<dyn Strategy>) -> Self {
95        let p = Pool {
96            id,
97            state: net_pool::pool::BaseState::new(strategy),
98            use_tls: AtomicBool::new(false),
99            free_conn_map: Mutex::new(HashMap::new()),
100        };
101        <Pool as net_pool::pool::Pool>::set_keepalive(
102            &p,
103            Some(std::time::Duration::from_secs(60 * 1)),
104        );
105        p
106    }
107
108    pub fn set_id<I: Into<String>>(&mut self, id: I) {
109        self.id = id.into();
110    }
111}
112
113impl Default for Pool {
114    fn default() -> Self {
115        Pool::new(
116            "".to_string(),
117            Arc::new(net_pool::strategy::CHStrategy::default()),
118        )
119    }
120}
121
122impl<L: Strategy + 'static> From<L> for Pool {
123    fn from(value: L) -> Self {
124        Self::new("".to_string(), Arc::new(value))
125    }
126}
127
128impl net_pool::pool::Pool for Pool {
129    net_pool::macros::base_pool_impl! {state}
130
131    fn id(&self) -> &str {
132        &self.id
133    }
134
135    fn remove_backend(&self, addr: &Address) -> bool {
136        if self.state.lb_strategy.remove_backend(addr) {
137            // 清除缓存
138            self.clear_bs_sr(addr.hash_code());
139            true
140        } else {
141            false
142        }
143    }
144
145    fn use_tls(&self, tls: bool) {
146        self.use_tls.store(tls, Relaxed);
147    }
148
149    fn tls(&self) -> bool {
150        self.use_tls.load(Relaxed)
151    }
152}
153
154impl Pool {
155    fn clear_bs_sr(&self, hash_code: u64) {
156        let mut guard = self.free_conn_map.lock().unwrap();
157        if let Some(inner) = guard.remove(&hash_code) {
158            assert!(
159                self.state
160                    .cur_conn
161                    .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
162                    > 0
163            );
164            debug!(
165                "[http/1.1 pool] [desc] current connection count: {}",
166                self.state.cur_conn.load(Relaxed)
167            );
168        }
169    }
170
171    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
172        let mut guard = self.free_conn_map.lock().unwrap();
173        let inner = guard.get_mut(&bs.hash_code())?;
174
175        let sr = {
176            let (del_cnt, sr) = inner.get_sr();
177            if del_cnt > 0 {
178                assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
179                debug!(
180                    "[http/1.1 pool] [desc] current connection count: {}",
181                    self.state.cur_conn.load(Relaxed)
182                );
183            }
184            sr
185        };
186
187        let tls = <Pool as net_pool::pool::Pool>::tls(self);
188        sr.map(|s| Sender::new(s, crate::utils::base_url(tls, bs.get_address())))
189    }
190
191    fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) -> u32 {
192        let mut guard = self.free_conn_map.lock().unwrap();
193        if let Some(inner) = guard.get_mut(&hash_code) {
194            inner.add_work_sr(sr)
195        } else {
196            let mut inner = Inner::new();
197            let id = inner.add_work_sr(sr);
198            guard.insert(hash_code, inner);
199            id
200        }
201    }
202
203    fn remove_sr(&self, hash_code: u64, id: u32) {
204        let mut guard = self.free_conn_map.lock().unwrap();
205        if let Some(inner) = guard.get_mut(&hash_code) {
206            if inner.remove_sr(id) {
207                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
208                debug!(
209                    "[http/1.1 pool] [desc] current connection count: {}",
210                    self.state.cur_conn.load(Relaxed)
211                );
212            }
213        }
214    }
215
216    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
217        pool: Arc<Self>,
218        c: C,
219        sr: Arc<SendRequest>,
220        bs: &BackendState,
221    ) {
222        let hash_code = bs.hash_code();
223
224        // 加入work,且必须在下面执行super::run_conn之前加入
225        // 添加sr进去缓存时, 可能此时bs已经被删除了
226        // 由于缓存中持有sr, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
227        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
228        let id = pool.add_sr(hash_code, sr.clone());
229        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
230
231        tokio_spawn! {
232            instrument_current_span! {
233                async move {
234                    let _r = crate::utils::run_conn(sr, c, ka).await;
235                    debug!("[http/1.1 pool] connection closed: {:?}", _r);
236                    pool.remove_sr(hash_code, id);
237                }
238            }
239        };
240    }
241
242    async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
243        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
244        let addr = bs.get_address();
245
246        // 获取连接
247        let tcp = crate::utils::create_https_stream(addr).await?;
248        let tls_tcp =
249            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
250                .await?;
251
252        // 握手建立 sender 和 connection
253        let io = hyper_util::rt::TokioIo::new(tls_tcp);
254        let pair = http1::handshake(io)
255            .await
256            .map_err(|e| Error::from_other(e))?;
257
258        // 启动连接驱动
259        let sr = Arc::new(pair.0);
260        Pool::run_conn(self, pair.1, sr.clone(), &bs);
261
262        // 返回一份
263        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
264    }
265
266    async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
267        let addr = bs.get_address();
268        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
269
270        // 获取连接
271        let tcp = crate::utils::create_http_stream(addr).await?;
272
273        // 握手建立 sender 和 connection
274        let io = hyper_util::rt::TokioIo::new(tcp);
275        let pair = http1::handshake(io)
276            .await
277            .map_err(|e| Error::from_other(e))?;
278
279        // 启动连接驱动
280        let sr = Arc::new(pair.0);
281        Pool::run_conn(self, pair.1, sr.clone(), &bs);
282
283        // 返回一份
284        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
285    }
286
287    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
288        if <Pool as net_pool::pool::Pool>::tls(&self) {
289            self.create_tls_sender(bs).await
290        } else {
291            self.create_non_tls_sender(bs).await
292        }
293    }
294
295    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
296        if let Some(sender) = self.clone().get_sender(&bs) {
297            return Ok(sender);
298        }
299
300        let sender = {
301            // 预分配数量
302            net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
303            self.clone().create_sender(&bs).await.map(|s| {
304                debug!(
305                    "[http/1.1 pool] [incr] current connection count: {}",
306                    self.state.cur_conn.load(Relaxed)
307                );
308                s
309            })
310        };
311
312        if sender.is_err() {
313            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
314        }
315
316        sender
317    }
318}
319
320pub trait HttpPool {
321    /// 根据key做某种策略获取
322    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
323
324    /// 根据地址获取
325    fn target(self: Arc<Self>, _: &Address) -> impl Future<Output = Result<Sender, Error>> + Send;
326}
327
328impl HttpPool for Pool {
329    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
330        let bs = self
331            .state
332            .lb_strategy
333            .get_backend(key)
334            .ok_or(Error::NoBackend)?;
335
336        self.real(bs).await
337    }
338
339    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
340        if !self.state.lb_strategy.contain(addr) {
341            return Err(Error::NoBackend);
342        }
343
344        self.real(BackendState::new(None, addr.clone())).await
345    }
346}
347
348pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
349    HttpPool::get(pool, key).await
350}