1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::HashMap;
6use std::sync::atomic::Ordering::Relaxed;
7use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
8use std::sync::{Arc, LazyLock, Mutex};
9
/// Source of process-wide connection ids handed out by `Inner::new`; the first id is 1.
///
/// NOTE(review): `AtomicU32::new` is `const`, so a plain `static AtomicU32` would
/// work without the `LazyLock` wrapper.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1u32));
12
/// One pooled HTTP/2 connection.
///
/// Fields, in order:
/// - `0`: the connection's `SendRequest` handle, cloned into every `Sender`;
/// - `1`: a reference counter shared with every `Sender` built from this entry
///   (its strong count is used to decide whether the connection is at its limit);
/// - `2`: an id drawn from `SENDER_ID`, used to find this entry again on removal.
struct Inner(SendRequest, Arc<()>, u32);
14
15impl Inner {
16 fn new(sender: SendRequest) -> Self {
17 Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
18 }
19
20 fn id(&self) -> u32 {
21 self.2
22 }
23
24 fn new_sender(&self, base_url: String) -> Sender {
25 Sender::new(self.0.clone(), base_url, self.1.clone())
26 }
27
28 fn ref_count(&self) -> usize {
30 Arc::strong_count(&self.1)
31 }
32
33 fn limited(&self, max_streams: Option<usize>) -> bool {
35 if let Some(max) = max_streams {
36 if self.ref_count() >= max + 1 {
37 return true;
38 }
39 }
40 return false;
41 }
42
43 fn is_closed(&self) -> bool {
44 self.0.is_closed()
45 }
46
47 fn reference(&self) -> Arc<()> {
48 self.1.clone()
49 }
50}
51
/// HTTP/2 connection pool that keeps open connections per backend.
///
/// Connections are grouped by the backend's hash code. A connection is reused by
/// many `Sender`s until its per-connection stream limit (`max_streams`) is reached.
pub struct Pool {
    /// Identifier returned by `net_pool::pool::Pool::id`.
    id: String,
    /// Shared pool state.
    ///
    /// This file reads `lb_strategy`, `cur_conn` and `max_conn` from it.
    state: net_pool::pool::BaseState,
    /// Per-connection stream limit; `usize::MAX` encodes "unlimited".
    max_streams: AtomicUsize,
    /// Open connections, keyed by backend hash code.
    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
    /// Whether new connections are established over TLS.
    use_tls: AtomicBool,
}
67
68impl Pool {
69 pub fn new(id: String, strategy: Arc<dyn Strategy>, mut max_streams: Option<usize>) -> Self {
70 if let Some(0) = max_streams {
71 max_streams = None;
72 }
73
74 let p = Pool {
75 id,
76 state: net_pool::pool::BaseState::new(strategy),
77 max_streams: AtomicUsize::new(usize::MAX),
78 free_conn_map: Mutex::new(HashMap::new()),
79 use_tls: AtomicBool::new(false),
80 };
81
82 p.set_max_streams(max_streams);
83 <Pool as net_pool::pool::Pool>::set_keepalive(
84 &p,
85 Some(std::time::Duration::from_secs(60 * 1)),
86 );
87 p
88 }
89
90 pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
91 if let Some(0) = max_streams {
92 max_streams = None;
93 }
94
95 match max_streams {
96 None => self.max_streams.store(usize::MAX, Relaxed),
97 Some(v) => self.max_streams.store(v, Relaxed),
98 }
99 }
100 pub fn get_max_streams(&self) -> Option<usize> {
101 match self.max_streams.load(Relaxed) {
102 usize::MAX => None,
103 v => Some(v),
104 }
105 }
106
107 pub fn set_id<I: Into<String>>(&mut self, id: I) {
108 self.id = id.into();
109 }
110}
111
112impl Default for Pool {
113 fn default() -> Self {
114 Pool::new(
115 "".to_string(),
116 Arc::new(net_pool::strategy::CHStrategy::default()),
117 Some(200),
118 )
119 }
120}
121
122impl<L: Strategy + 'static> From<L> for Pool {
123 fn from(value: L) -> Self {
124 Self::new("".to_string(), Arc::new(value), Some(200))
125 }
126}
127
128impl net_pool::pool::Pool for Pool {
129 net_pool::macros::base_pool_impl! {state}
130
131 fn id(&self) -> &str {
132 &self.id
133 }
134
135 fn remove_backend(&self, addr: &Address) -> bool {
136 if self.state.lb_strategy.remove_backend(addr) {
137 self.clear_bs_inner(addr);
139 true
140 } else {
141 false
142 }
143 }
144
145 fn use_tls(&self, tls: bool) {
146 self.use_tls.store(tls, Relaxed);
147 }
148
149 fn tls(&self) -> bool {
150 self.use_tls.load(Relaxed)
151 }
152}
153
154impl Pool {
155 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
156 let mut guard = self.free_conn_map.lock().unwrap();
157
158 let inners = guard.get_mut(&bs.hash_code())?;
160 for idx in (0..inners.len()).rev() {
161 let inner = &mut inners[idx];
162 if inner.is_closed() {
163 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
165 debug!(
166 "[http/2.0 pool] [desc] current connection count: {}",
167 self.state.cur_conn.load(Relaxed)
168 );
169 inners.remove(idx);
170 continue;
171 }
172
173 if inner.limited(self.get_max_streams()) {
174 continue;
175 }
176
177 let tls = <Pool as net_pool::pool::Pool>::tls(self);
178 return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
179 }
180
181 None
182 }
183
184 fn add_inner(&self, hash_code: u64, inner: Inner) {
185 let mut guard = self.free_conn_map.lock().unwrap();
186 if let Some(inners) = guard.get_mut(&hash_code) {
187 inners.push(inner);
188 } else {
189 guard.insert(hash_code, vec![inner]);
190 }
191 }
192
193 fn clear_bs_inner(&self, addr: &Address) {
194 let mut guard = self.free_conn_map.lock().unwrap();
195 if let Some(inners) = guard.remove(&addr.hash_code()) {
196 assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
197 debug!(
198 "[http/2.0 pool] [desc] current connection count: {}",
199 self.state.cur_conn.load(Relaxed)
200 );
201 }
202 }
203
204 fn remove_inner(&self, hash_code: u64, id: u32) {
205 let mut guard = self.free_conn_map.lock().unwrap();
206 if let Some(inners) = guard.get_mut(&hash_code) {
207 let mut del = false;
208 inners.retain(|inner| {
209 if inner.id() == id {
210 del = true;
211 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
212 false
213 } else {
214 true
215 }
216 });
217 if del {
218 debug!(
219 "[http/2.0 pool] [desc] current connection count: {}",
220 self.state.cur_conn.load(Relaxed)
221 );
222 }
223 }
224 }
225
226 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
227 pool: Arc<Self>,
228 c: C,
229 sr: SendRequest,
230 bs: &BackendState,
231 ) -> Sender {
232 let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
233 let inner = Inner::new(sr);
234 let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
235
236 let tuple = (
237 bs.hash_code(),
238 <Pool as net_pool::pool::Pool>::get_keepalive(&pool),
239 inner.reference(),
240 inner.id(),
241 );
242
243 pool.add_inner(tuple.0, inner);
248
249 tokio_spawn! {
251 instrument_current_span! {
252 async move {
253 let _r = crate::utils::run_conn(tuple.2, c, tuple.1).await;
254 debug!("[http/2.0 pool] connection closed: {:?}", _r);
255 pool.remove_inner(tuple.0, tuple.3);
256 }
257 }
258 };
259
260 sender
261 }
262
263 async fn create_tls_sender(
264 self: Arc<Self>,
265 bs: &BackendState,
266 exec: hyper_util::rt::TokioExecutor,
267 ) -> Result<Sender, Error> {
268 let addr = bs.get_address();
269
270 let tcp = crate::utils::create_https_stream(addr).await?;
272 let tls_tcp =
273 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
274 .await?;
275
276 let max_streams = self.get_max_streams();
278 let io = hyper_util::rt::TokioIo::new(tls_tcp);
279 let pair = http2::Builder::new(exec)
280 .max_concurrent_streams(max_streams.map(|m| m as u32))
281 .handshake(io)
282 .await
283 .map_err(|e| Error::from_other(e))?;
284
285 let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
287 Ok(sender)
288 }
289
290 async fn create_non_tls_sender(
291 self: Arc<Self>,
292 bs: &BackendState,
293 exec: hyper_util::rt::TokioExecutor,
294 ) -> Result<Sender, Error> {
295 let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
296
297 let io = hyper_util::rt::TokioIo::new(tcp);
299 let pair = http2::handshake(exec, io)
300 .await
301 .map_err(|e| Error::from_other(e))?;
302
303 let sender = Pool::run_conn(self, pair.1, pair.0, bs);
305 Ok(sender)
306 }
307
308 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
309 let exec = hyper_util::rt::tokio::TokioExecutor::new();
310 if <Pool as net_pool::pool::Pool>::tls(&self) {
311 self.create_tls_sender(bs, exec).await
312 } else {
313 self.create_non_tls_sender(bs, exec).await
314 }
315 }
316
317 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
318 let sender = match self.get_sender(&bs) {
319 Some(s) => Ok(s),
320 None => {
321 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
323 self.clone().create_sender(&bs).await.map(|s| {
324 debug!(
325 "[http/2.0 pool] [incr] current connection count: {}",
326 self.state.cur_conn.load(Relaxed)
327 );
328 s
329 })
330 }
331 };
332
333 if sender.is_err() {
334 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
335 }
336
337 sender
338 }
339}
340
/// Async access to `Sender`s from an HTTP pool shared behind an `Arc`.
pub trait HttpPool {
    /// Returns a `Sender` for the backend that the pool selects for `key`.
    ///
    /// `Pool`'s implementation returns `Error::NoBackend` when its strategy has
    /// no backend for `key`.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Returns a `Sender` for the specific backend `addr`.
    ///
    /// `Pool`'s implementation returns `Error::NoBackend` when `addr` is not
    /// registered with its strategy.
    fn target(
        self: Arc<Self>,
        addr: &Address,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
349
350impl HttpPool for Pool {
351 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
352 let bs = self
353 .state
354 .lb_strategy
355 .get_backend(key)
356 .ok_or(Error::NoBackend);
357
358 self.real(bs?).await
359 }
360
361 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
362 if !self.state.lb_strategy.contain(addr) {
363 Err(Error::NoBackend)
364 } else {
365 self.real(BackendState::new(None, addr.clone())).await
366 }
367 }
368}
369
370pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
371 HttpPool::get(pool, key).await
372}