http_pool/http1/
pool.rs

1use crate::http1::{SendRequest, Sender};
2use hyper::client::conn::http1;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::hash_map::Entry;
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::Ordering::Relaxed;
10use std::sync::{Arc, Mutex};
11
12#[derive(Clone)]
13struct SrData(Arc<SendRequest>);
14
15impl PartialEq for SrData {
16    fn eq(&self, other: &Self) -> bool {
17        self.0.as_ref() as *const _ == other.0.as_ref() as *const _
18    }
19}
20
21impl Eq for SrData {}
22
23impl Hash for SrData {
24    fn hash<H: Hasher>(&self, state: &mut H) {
25        state.write_usize(self.0.as_ref() as *const _ as usize)
26    }
27}
28
29struct Inner {
30    /// 空闲的sr
31    free: HashSet<SrData>,
32    /// 在用的sr
33    work: HashSet<SrData>,
34}
35
36impl Inner {
37    fn new() -> Self {
38        Inner {
39            free: HashSet::new(),
40            work: HashSet::new(),
41        }
42    }
43
44    fn add_work_sr(&mut self, sr: Arc<SendRequest>) {
45        self.work.insert(SrData(sr));
46    }
47
48    fn remove_sr(&mut self, sr: Arc<SendRequest>) -> bool {
49        let d = SrData(sr);
50        // 不应该同时处于work和free中
51        if self.work.remove(&d) {
52            true
53        } else if self.free.remove(&d) {
54            true
55        } else {
56            false
57        }
58    }
59
60    /// 返回值是被删除的元素个数, 及查找到的sr
61    fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
62        // 删除个数
63        let mut dels = 0;
64
65        // 空了的话从work里取, 这里效率不高
66        if self.free.is_empty() {
67            // 挪动
68            self.work.retain(|d| {
69                // 引用计数为2时,代表外界不存在引用, 队列里一份, run那里一份
70                if Arc::strong_count(&d.0) == 2 {
71                    if !d.0.is_closed() {
72                        self.free.insert(d.clone());
73                    } else {
74                        dels += 1;
75                    }
76                    false
77                } else {
78                    true
79                }
80            });
81        }
82
83        // 这里虽然删除了一些closed状态的, 但算法却有延迟性,它无法及时删除所有的closed
84        loop {
85            let d = match self.free.iter().next() {
86                None => return (dels, None),
87                Some(d) => d.clone(),
88            };
89
90            self.free.remove(&d);
91            if d.0.is_closed() {
92                dels += 1;
93            } else {
94                let sr = d.0.clone();
95                self.work.insert(d);
96                return (dels, Some(sr));
97            }
98        }
99    }
100}
101
102/// http1连接池
103/// 连接池中的连接受max_conn数据限制
104/// 默认的keepalive是一分钟
105/// 有三种行为会导致连接池中的连接被释放
106///     1.对方主动关闭连接
107///     2.remove_backend被调用清除后端地址, 这会导致该地址的所有连接失效
108///     3.在外界没有引用sender达到keepalive时间后, 注意如果没有设置keepalive则此条无效
109pub struct Pool {
110    state: net_pool::pool::BaseState,
111    use_tls: AtomicBool,
112    free_conn_map: Mutex<HashMap<u64, Inner>>,
113}
114
115impl Pool {
116    pub fn new(strategy: Arc<dyn Strategy>) -> Self {
117        let p = Pool {
118            state: net_pool::pool::BaseState::new(strategy),
119            use_tls: AtomicBool::new(false),
120            free_conn_map: Mutex::new(HashMap::new()),
121        };
122        <Pool as net_pool::pool::Pool>::set_keepalive(
123            &p,
124            Some(std::time::Duration::from_secs(60 * 1)),
125        );
126        p
127    }
128}
129
130impl Default for Pool {
131    fn default() -> Self {
132        Pool::new(Arc::new(net_pool::strategy::CHStrategy::default()))
133    }
134}
135
136impl<L: Strategy + 'static> From<L> for Pool {
137    fn from(value: L) -> Self {
138        Self::new(Arc::new(value))
139    }
140}
141
142impl net_pool::pool::Pool for Pool {
143    net_pool::macros::base_pool_impl! {state}
144
145    fn remove_backend(&self, addr: &Address) -> bool {
146        if self.state.lb_strategy.remove_backend(addr) {
147            // 清除缓存
148            self.clear_bs_sr(addr.hash_code());
149            true
150        } else {
151            false
152        }
153    }
154
155    fn use_tls(&self, tls: bool) {
156        self.use_tls.store(tls, Relaxed);
157    }
158
159    fn tls(&self) -> bool {
160        self.use_tls.load(Relaxed)
161    }
162}
163
164impl Pool {
165    fn clear_bs_sr(&self, hash_code: u64) {
166        let mut guard = self.free_conn_map.lock().unwrap();
167        if let Some(inner) = guard.remove(&hash_code) {
168            assert!(
169                self.state
170                    .cur_conn
171                    .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
172                    > 0
173            );
174            debug!(
175                "[desc] current connection count: {}",
176                self.state.cur_conn.load(Relaxed)
177            );
178        }
179    }
180
181    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
182        let mut guard = self.free_conn_map.lock().unwrap();
183        let inner = guard.get_mut(&bs.hash_code())?;
184
185        let sr = {
186            let (del_cnt, sr) = inner.get_sr();
187            if del_cnt > 0 {
188                assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
189                debug!(
190                    "[desc] current connection count: {}",
191                    self.state.cur_conn.load(Relaxed)
192                );
193            }
194            sr
195        };
196
197        if let Some(sr) = sr {
198            let tls = <Pool as net_pool::pool::Pool>::tls(self);
199            Some(Sender::new(
200                sr,
201                crate::utils::base_url(tls, bs.get_address()),
202            ))
203        } else {
204            None
205        }
206    }
207
208    fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) {
209        let mut guard = self.free_conn_map.lock().unwrap();
210        match guard.entry(hash_code) {
211            Entry::Occupied(mut o) => {
212                o.get_mut().add_work_sr(sr);
213            }
214            Entry::Vacant(e) => {
215                let mut inner = Inner::new();
216                inner.add_work_sr(sr);
217                e.insert(inner);
218            }
219        }
220    }
221
222    fn remove_sr(&self, hash_code: u64, sr: Arc<SendRequest>) {
223        let mut guard = self.free_conn_map.lock().unwrap();
224        if let Some(inner) = guard.get_mut(&hash_code) {
225            if inner.remove_sr(sr) {
226                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
227                debug!(
228                    "[desc] current connection count: {}",
229                    self.state.cur_conn.load(Relaxed)
230                );
231            }
232        }
233    }
234
235    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
236        pool: Arc<Self>,
237        c: C,
238        sr: Arc<SendRequest>,
239        bs: &BackendState,
240    ) {
241        let hash_code = bs.hash_code();
242
243        // 加入work,且必须在下面执行super::run_conn之前加入
244        // 添加sr进去缓存时, 可能此时bs已经被删除了
245        // 由于缓存中持有sr, 所以有可能会导致在connection没有设置超时的情况下一直在运行,从而一直占用最大连接数
246        // 由于缓存和bs是通过不同的锁持有的, 所以暂时无法对它们进行一致性限制
247        pool.add_sr(hash_code, sr.clone());
248        let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
249
250        tokio_spawn! {
251            instrument_current_span! {
252                async move {
253                    let _r = crate::utils::run_conn(sr.clone(), c, ka).await;
254                    debug!("connection closed: {:?}", _r);
255                    pool.remove_sr(hash_code, sr);
256                }
257            }
258        };
259    }
260
261    async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
262        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
263        let addr = bs.get_address();
264
265        // 获取连接
266        let tcp = crate::utils::create_https_stream(addr).await?;
267        let tls_tcp =
268            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
269                .await?;
270
271        // 握手建立 sender 和 connection
272        let io = hyper_util::rt::TokioIo::new(tls_tcp);
273        let pair = http1::handshake(io)
274            .await
275            .map_err(|e| Error::from_other(e))?;
276
277        // 启动连接驱动
278        let sr = Arc::new(pair.0);
279        Pool::run_conn(self, pair.1, sr.clone(), &bs);
280
281        // 返回一份
282        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
283    }
284
285    async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
286        let addr = bs.get_address();
287        let tls = <Pool as net_pool::pool::Pool>::tls(&self);
288
289        // 获取连接
290        let tcp = crate::utils::create_http_stream(addr).await?;
291
292        // 握手建立 sender 和 connection
293        let io = hyper_util::rt::TokioIo::new(tcp);
294        let pair = http1::handshake(io)
295            .await
296            .map_err(|e| Error::from_other(e))?;
297
298        // 启动连接驱动
299        let sr = Arc::new(pair.0);
300        Pool::run_conn(self, pair.1, sr.clone(), &bs);
301
302        // 返回一份
303        Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
304    }
305
306    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
307        if <Pool as net_pool::pool::Pool>::tls(&self) {
308            self.create_tls_sender(bs).await
309        } else {
310            self.create_non_tls_sender(bs).await
311        }
312    }
313
314    async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
315        if let Some(sender) = self.clone().get_sender(&bs) {
316            return Ok(sender);
317        }
318
319        let sender = {
320            // 预分配数量
321            net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
322            self.clone().create_sender(&bs).await.map(|s| {
323                debug!(
324                    "[incr] current connection count: {}",
325                    self.state.cur_conn.load(Relaxed)
326                );
327                s
328            })
329        };
330
331        if sender.is_err() {
332            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
333        }
334
335        sender
336    }
337}
338
339pub trait HttpPool {
340    /// 根据key做某种策略获取
341    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
342
343    /// 根据地址获取
344    fn target(self: Arc<Self>, _: &Address) -> impl Future<Output = Result<Sender, Error>> + Send;
345}
346
347impl HttpPool for Pool {
348    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
349        let bs = self
350            .state
351            .lb_strategy
352            .get_backend(key)
353            .ok_or(Error::NoBackend)?;
354
355        self.real(bs).await
356    }
357
358    async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
359        if !self.state.lb_strategy.contain(addr) {
360            return Err(Error::NoBackend);
361        }
362
363        self.real(BackendState::new(None, addr.clone())).await
364    }
365}
366
367pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
368    HttpPool::get(pool, key).await
369}