1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::collections::hash_map::Entry;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicUsize};
9use std::sync::{Arc, Mutex};
10use crate::body::VariantBody;
11
12struct Inner(Vec<SendRequest>);
13
14impl Inner {
15 fn new() -> Self {
16 Self(vec![])
17 }
18
19 fn add_sr(&mut self, sr: SendRequest) {
20 self.0.push(sr);
21 }
22
23 fn remove_sr(&mut self, other: SendRequest) -> bool {
24 if let Some(item) = self.0.iter().position(|sr| sr == &other) {
25 self.0.remove(item);
26 true
27 } else {
28 false
29 }
30 }
31
32 fn get_sr(&mut self, max_streams: Option<usize>) -> (usize, Option<SendRequest>) {
34 let mut dels = 0;
36
37 for idx in (0..self.0.len()).rev() {
38 let sr = &self.0[idx];
39 if sr.is_closed() {
40 dels += 1;
41 self.0.remove(idx);
42 } else if sr.limited(max_streams) {
43 continue;
44 } else {
45 return (dels, Some(sr.clone()));
46 }
47 }
48
49 (dels, None)
50 }
51
52 fn len(&self) -> usize {
53 self.0.len()
54 }
55}
56
57pub struct Pool {
66 state: net_pool::pool::BaseState,
67 max_streams: AtomicUsize,
68 free_conn_map: Mutex<HashMap<u64, Inner>>,
69 use_tls: AtomicBool,
70}
71
72impl Pool {
73 pub fn new(strategy: Arc<dyn Strategy>, mut max_streams: Option<usize>) -> Self {
74 if let Some(0) = max_streams {
75 max_streams = None;
76 }
77
78 let p = Pool {
79 state: net_pool::pool::BaseState::new(strategy),
80 max_streams: AtomicUsize::new(usize::MAX),
81 free_conn_map: Mutex::new(HashMap::new()),
82 use_tls: AtomicBool::new(false),
83 };
84
85 p.set_max_streams(max_streams);
86 <Pool as net_pool::pool::Pool>::set_keepalive(
87 &p,
88 Some(std::time::Duration::from_secs(60 * 1)),
89 );
90 p
91 }
92
93 pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
94 if let Some(0) = max_streams {
95 max_streams = None;
96 }
97
98 match max_streams {
99 None => self.max_streams.store(usize::MAX, Relaxed),
100 Some(v) => self.max_streams.store(v, Relaxed),
101 }
102 }
103
104 pub fn get_max_streams(&self) -> Option<usize> {
105 match self.max_streams.load(Relaxed) {
106 usize::MAX => None,
107 v => Some(v),
108 }
109 }
110}
111
112impl Default for Pool {
113 fn default() -> Self {
114 Pool::new(
115 Arc::new(net_pool::strategy::CHStrategy::default()),
116 Some(200),
117 )
118 }
119}
120
121impl<L: Strategy + 'static> From<L> for Pool {
122 fn from(value: L) -> Self {
123 Self::new(Arc::new(value), Some(200))
124 }
125}
126
127impl net_pool::pool::Pool for Pool {
128 net_pool::macros::base_pool_impl! {state}
129
130 fn remove_backend(&self, addr: &Address) -> bool {
131 if self.state.lb_strategy.remove_backend(addr) {
132 self.clear_bs_sr(addr);
134 true
135 } else {
136 false
137 }
138 }
139
140 fn use_tls(&self, tls: bool) {
141 self.use_tls.store(tls, Relaxed);
142 }
143
144 fn tls(&self) -> bool {
145 self.use_tls.load(Relaxed)
146 }
147}
148
149impl Pool {
150 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
151 let mut guard = self.free_conn_map.lock().unwrap();
152 let inners = guard.get_mut(&bs.hash_code())?;
153
154 let sr = {
155 let (del_cnt, sr) = inners.get_sr(self.get_max_streams());
156 if del_cnt > 0 {
157 assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
158 debug!(
159 "[desc] current connection count: {}",
160 self.state.cur_conn.load(Relaxed)
161 );
162 }
163 sr
164 };
165
166 if let Some(sr) = sr {
167 let tls = <Pool as net_pool::pool::Pool>::tls(self);
168 Some(Sender::new(
169 sr,
170 crate::utils::base_url(tls, bs.get_address()),
171 ))
172 } else {
173 None
174 }
175 }
176
177 fn add_sr(&self, hash_code: u64, sr: SendRequest) {
178 let mut guard = self.free_conn_map.lock().unwrap();
179 match guard.entry(hash_code) {
180 Entry::Occupied(mut o) => {
181 o.get_mut().add_sr(sr);
182 }
183 Entry::Vacant(v) => {
184 let mut inner = Inner::new();
185 inner.add_sr(sr);
186 v.insert(inner);
187 }
188 }
189 }
190
191 fn remove_sr(&self, hash_code: u64, sr: SendRequest) {
192 let mut guard = self.free_conn_map.lock().unwrap();
193 if let Some(i) = guard.get_mut(&hash_code) {
194 if i.remove_sr(sr) {
195 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
196 debug!(
197 "[desc] current connection count: {}",
198 self.state.cur_conn.load(Relaxed)
199 );
200 }
201 }
202 }
203
204 fn clear_bs_sr(&self, addr: &Address) {
205 let mut guard = self.free_conn_map.lock().unwrap();
206 if let Some(inners) = guard.remove(&addr.hash_code()) {
207 assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
208 debug!(
209 "[desc] current connection count: {}",
210 self.state.cur_conn.load(Relaxed)
211 );
212 }
213 }
214
215 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
216 pool: Arc<Self>,
217 c: C,
218 sr: http2::SendRequest<VariantBody>,
219 bs: &BackendState,
220 ) -> Sender {
221 let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
222 let sr = SendRequest::new(sr);
223 let sender = Sender::new(sr.clone(), crate::utils::base_url(tls, bs.get_address()));
224
225 let code = bs.hash_code();
226
227 pool.add_sr(code, sr.clone());
232 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
233
234 tokio_spawn! {
236 instrument_current_span! {
237 async move {
238 let _r = crate::utils::run_conn(sr.clone(), c, ka).await;
239 debug!("connection closed: {:?}", _r);
240 pool.remove_sr(code, sr);
241 }
242 }
243 };
244
245 sender
246 }
247
248 async fn create_tls_sender(
249 self: Arc<Self>,
250 bs: &BackendState,
251 exec: hyper_util::rt::TokioExecutor,
252 ) -> Result<Sender, Error> {
253 let addr = bs.get_address();
254
255 let tcp = crate::utils::create_https_stream(addr).await?;
257 let tls_tcp =
258 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
259 .await?;
260
261 let max_streams = self.get_max_streams();
263 let io = hyper_util::rt::TokioIo::new(tls_tcp);
264 let pair = http2::Builder::new(exec)
265 .max_concurrent_streams(max_streams.map(|m| m as u32))
266 .handshake(io)
267 .await
268 .map_err(|e| Error::from_other(e))?;
269
270 let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
272 Ok(sender)
273 }
274
275 async fn create_non_tls_sender(
276 self: Arc<Self>,
277 bs: &BackendState,
278 exec: hyper_util::rt::TokioExecutor,
279 ) -> Result<Sender, Error> {
280 let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
281
282 let io = hyper_util::rt::TokioIo::new(tcp);
284 let pair = http2::handshake(exec, io)
285 .await
286 .map_err(|e| Error::from_other(e))?;
287
288 let sender = Pool::run_conn(self, pair.1, pair.0, bs);
290 Ok(sender)
291 }
292
293 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
294 let exec = hyper_util::rt::tokio::TokioExecutor::new();
295 if <Pool as net_pool::pool::Pool>::tls(&self) {
296 self.create_tls_sender(bs, exec).await
297 } else {
298 self.create_non_tls_sender(bs, exec).await
299 }
300 }
301
302 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
303 let sender = match self.get_sender(&bs) {
304 Some(s) => Ok(s),
305 None => {
306 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
308 self.clone().create_sender(&bs).await.map(|s| {
309 debug!(
310 "[incr] current connection count: {}",
311 self.state.cur_conn.load(Relaxed)
312 );
313 s
314 })
315 }
316 };
317
318 if sender.is_err() {
319 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
320 }
321
322 sender
323 }
324}
325
326pub trait HttpPool {
327 fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
328
329 fn target(
330 self: Arc<Self>,
331 addr: &Address,
332 ) -> impl Future<Output = Result<Sender, Error>> + Send;
333}
334
335impl HttpPool for Pool {
336 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
337 let bs = self
338 .state
339 .lb_strategy
340 .get_backend(key)
341 .ok_or(Error::NoBackend);
342
343 self.real(bs?).await
344 }
345
346 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
347 if !self.state.lb_strategy.contain(addr) {
348 Err(Error::NoBackend)
349 } else {
350 self.real(BackendState::new(None, addr.clone())).await
351 }
352 }
353}
354
355pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
356 HttpPool::get(pool, key).await
357}