1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::sync::atomic::Ordering::Relaxed;
7use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
8use std::sync::{Arc, LazyLock, Mutex};
9
/// Counter handing out connection ids, starting at 1. `fetch_add` on an atomic
/// wraps on overflow, so ids are unique only until 2^32 connections have been created.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
12
13struct Inner(SendRequest, Arc<()>, u32);
14
15impl Inner {
16 fn new(sender: SendRequest) -> Self {
17 Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
18 }
19
20 fn id(&self) -> u32 {
21 self.2
22 }
23
24 fn new_sender(&self, base_url: String) -> Sender {
25 Sender::new(self.0.clone(), base_url, self.1.clone())
26 }
27
28 fn ref_count(&self) -> usize {
30 Arc::strong_count(&self.1)
31 }
32
33 fn limited(&self, max_streams: Option<usize>) -> bool {
35 if let Some(max) = max_streams {
36 if self.ref_count() >= max + 1 {
37 return true;
38 }
39 }
40 return false;
41 }
42
43 fn is_closed(&self) -> bool {
44 self.0.is_closed()
45 }
46
47 fn reference(&self) -> Arc<()> {
48 self.1.clone()
49 }
50}
51
52pub struct Pool {
61 state: net_pool::pool::BaseState,
62 max_streams: AtomicUsize,
63 free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
64 use_tls: AtomicBool,
65}
66
67impl Pool {
68 pub fn new(strategy: Arc<dyn Strategy>, mut max_streams: Option<usize>) -> Self {
69 if let Some(0) = max_streams {
70 max_streams = None;
71 }
72
73 let p = Pool {
74 state: net_pool::pool::BaseState::new(strategy),
75 max_streams: AtomicUsize::new(usize::MAX),
76 free_conn_map: Mutex::new(HashMap::new()),
77 use_tls: AtomicBool::new(false),
78 };
79
80 p.set_max_streams(max_streams);
81 <Pool as net_pool::pool::Pool>::set_keepalive(
82 &p,
83 Some(std::time::Duration::from_secs(60 * 1)),
84 );
85 p
86 }
87
88 pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
89 if let Some(0) = max_streams {
90 max_streams = None;
91 }
92
93 match max_streams {
94 None => self.max_streams.store(usize::MAX, Relaxed),
95 Some(v) => self.max_streams.store(v, Relaxed),
96 }
97 }
98 pub fn get_max_streams(&self) -> Option<usize> {
99 match self.max_streams.load(Relaxed) {
100 usize::MAX => None,
101 v => Some(v),
102 }
103 }
104}
105
106impl Default for Pool {
107 fn default() -> Self {
108 Pool::new(
109 Arc::new(net_pool::strategy::CHStrategy::default()),
110 Some(200),
111 )
112 }
113}
114
115impl<L: Strategy + 'static> From<L> for Pool {
116 fn from(value: L) -> Self {
117 Self::new(Arc::new(value), Some(200))
118 }
119}
120
121impl net_pool::pool::Pool for Pool {
122 net_pool::macros::base_pool_impl! {state}
123
124 fn remove_backend(&self, addr: &Address) -> bool {
125 if self.state.lb_strategy.remove_backend(addr) {
126 self.clear_bs_inner(addr);
128 true
129 } else {
130 false
131 }
132 }
133
134 fn use_tls(&self, tls: bool) {
135 self.use_tls.store(tls, Relaxed);
136 }
137
138 fn tls(&self) -> bool {
139 self.use_tls.load(Relaxed)
140 }
141}
142
143impl Pool {
144 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
145 let mut guard = self.free_conn_map.lock().unwrap();
146
147 let inners = guard.get_mut(&bs.hash_code())?;
149 for idx in (0..inners.len()).rev() {
150 let inner = &mut inners[idx];
151 if inner.is_closed() {
152 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
154 debug!(
155 "[http/2.0 pool] [desc] current connection count: {}",
156 self.state.cur_conn.load(Relaxed)
157 );
158 inners.remove(idx);
159 continue;
160 }
161
162 if inner.limited(self.get_max_streams()) {
163 continue;
164 }
165
166 let tls = <Pool as net_pool::pool::Pool>::tls(self);
167 return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
168 }
169
170 None
171 }
172
173 fn add_inner(&self, hash_code: u64, inner: Inner) {
174 let mut guard = self.free_conn_map.lock().unwrap();
175 if let Some(inners) = guard.get_mut(&hash_code) {
176 inners.push(inner);
177 } else {
178 guard.insert(hash_code, vec![inner]);
179 }
180 }
181
182 fn clear_bs_inner(&self, addr: &Address) {
183 let mut guard = self.free_conn_map.lock().unwrap();
184 if let Some(inners) = guard.remove(&addr.hash_code()) {
185 assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
186 debug!(
187 "[http/2.0 pool] [desc] current connection count: {}",
188 self.state.cur_conn.load(Relaxed)
189 );
190 }
191 }
192
193 fn remove_inner(&self, hash_code: u64, id: u32) {
194 let mut guard = self.free_conn_map.lock().unwrap();
195 if let Some(inners) = guard.get_mut(&hash_code) {
196 let mut del = false;
197 inners.retain(|inner| {
198 if inner.id() == id {
199 del = true;
200 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
201 false
202 } else {
203 true
204 }
205 });
206 if del {
207 debug!(
208 "[http/2.0 pool] [desc] current connection count: {}",
209 self.state.cur_conn.load(Relaxed)
210 );
211 }
212 }
213 }
214
215 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
216 pool: Arc<Self>,
217 c: C,
218 sr: SendRequest,
219 bs: &BackendState,
220 ) -> Sender {
221 let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
222 let inner = Inner::new(sr);
223 let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
224
225 let tuple = (
226 bs.hash_code(),
227 <Pool as net_pool::pool::Pool>::get_keepalive(&pool),
228 inner.reference(),
229 inner.id(),
230 );
231
232 pool.add_inner(tuple.0, inner);
237
238 tokio_spawn! {
240 instrument_current_span! {
241 async move {
242 let _r = crate::utils::run_conn(tuple.2, c, tuple.1).await;
243 debug!("[http/2.0 pool] connection closed: {:?}", _r);
244 pool.remove_inner(tuple.0, tuple.3);
245 }
246 }
247 };
248
249 sender
250 }
251
252 async fn create_tls_sender(
253 self: Arc<Self>,
254 bs: &BackendState,
255 exec: hyper_util::rt::TokioExecutor,
256 ) -> Result<Sender, Error> {
257 let addr = bs.get_address();
258
259 let tcp = crate::utils::create_https_stream(addr).await?;
261 let tls_tcp =
262 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
263 .await?;
264
265 let max_streams = self.get_max_streams();
267 let io = hyper_util::rt::TokioIo::new(tls_tcp);
268 let pair = http2::Builder::new(exec)
269 .max_concurrent_streams(max_streams.map(|m| m as u32))
270 .handshake(io)
271 .await
272 .map_err(|e| Error::from_other(e))?;
273
274 let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
276 Ok(sender)
277 }
278
279 async fn create_non_tls_sender(
280 self: Arc<Self>,
281 bs: &BackendState,
282 exec: hyper_util::rt::TokioExecutor,
283 ) -> Result<Sender, Error> {
284 let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
285
286 let io = hyper_util::rt::TokioIo::new(tcp);
288 let pair = http2::handshake(exec, io)
289 .await
290 .map_err(|e| Error::from_other(e))?;
291
292 let sender = Pool::run_conn(self, pair.1, pair.0, bs);
294 Ok(sender)
295 }
296
297 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
298 let exec = hyper_util::rt::tokio::TokioExecutor::new();
299 if <Pool as net_pool::pool::Pool>::tls(&self) {
300 self.create_tls_sender(bs, exec).await
301 } else {
302 self.create_non_tls_sender(bs, exec).await
303 }
304 }
305
306 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
307 let sender = match self.get_sender(&bs) {
308 Some(s) => Ok(s),
309 None => {
310 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
312 self.clone().create_sender(&bs).await.map(|s| {
313 debug!(
314 "[http/2.0 pool] [incr] current connection count: {}",
315 self.state.cur_conn.load(Relaxed)
316 );
317 s
318 })
319 }
320 };
321
322 if sender.is_err() {
323 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
324 }
325
326 sender
327 }
328}
329
330pub trait HttpPool {
331 fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
332
333 fn target(
334 self: Arc<Self>,
335 addr: &Address,
336 ) -> impl Future<Output = Result<Sender, Error>> + Send;
337}
338
339impl HttpPool for Pool {
340 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
341 let bs = self
342 .state
343 .lb_strategy
344 .get_backend(key)
345 .ok_or(Error::NoBackend);
346
347 self.real(bs?).await
348 }
349
350 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
351 if !self.state.lb_strategy.contain(addr) {
352 Err(Error::NoBackend)
353 } else {
354 self.real(BackendState::new(None, addr.clone())).await
355 }
356 }
357}
358
359pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
360 HttpPool::get(pool, key).await
361}