1use crate::http2::{SendRequest, Sender};
2use hyper::client::conn::http2;
3use net_pool::backend::{Address, BackendState};
4use net_pool::strategy::LbStrategy;
5use net_pool::{Error, debug, instrument_current_span, tokio_spawn};
6use std::collections::HashMap;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
9use std::sync::{Arc, LazyLock, Mutex};
10
/// Process-wide source of ids for pooled connections (`Inner`); the first id handed out is 1.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
13
/// One pooled HTTP/2 connection.
///
/// Fields: `.0` is the request handle cloned into every `Sender`; `.1` is a reference counter
/// whose strong count `Inner::limited` compares against the stream limit; `.2` is an id drawn
/// from `SENDER_ID`, used by `Pool::remove_inner` to locate this entry.
struct Inner(SendRequest, Arc<()>, u32);
15
16impl Inner {
17 fn new(sender: SendRequest) -> Self {
18 Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
19 }
20
21 fn id(&self) -> u32 {
22 self.2
23 }
24
25 fn new_sender(&self, base_url: String) -> Sender {
26 Sender::new(self.0.clone(), base_url, self.1.clone())
27 }
28
29 fn ref_count(&self) -> usize {
31 Arc::strong_count(&self.1)
32 }
33
34 fn limited(&self, max_streams: Option<usize>) -> bool {
36 if let Some(max) = max_streams {
37 if self.ref_count() >= max + 1 {
38 return true;
39 }
40 }
41 return false;
42 }
43
44 fn is_closed(&self) -> bool {
45 self.0.is_closed()
46 }
47
48 fn reference(&self) -> Arc<()> {
49 self.1.clone()
50 }
51}
52
/// HTTP/2 connection pool that hands out `Sender`s, reusing a backend's existing connections
/// while they have stream capacity and opening new ones otherwise.
pub struct Pool {
    /// Pool identifier, returned by `net_pool::pool::Pool::id`.
    id: String,
    /// Shared `net_pool` state; this file uses its `lb_strategy`, `max_conn` and `cur_conn`.
    state: net_pool::pool::BaseState,
    /// Per-connection stream limit; `usize::MAX` encodes "unlimited" (see `get_max_streams`).
    max_streams: AtomicUsize,
    /// Pooled connections keyed by backend hash code. Despite the name, this holds busy
    /// connections as well as idle ones.
    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
    /// Whether new connections are established over TLS.
    use_tls: AtomicBool,
}
68
69impl Pool {
70 pub fn new(id: String, strategy: Arc<dyn LbStrategy>, mut max_streams: Option<usize>) -> Self {
71 if let Some(0) = max_streams {
72 max_streams = None;
73 }
74
75 let p = Pool {
76 id,
77 state: net_pool::pool::BaseState::new(strategy),
78 max_streams: AtomicUsize::new(usize::MAX),
79 free_conn_map: Mutex::new(HashMap::new()),
80 use_tls: AtomicBool::new(false),
81 };
82
83 p.set_max_streams(max_streams);
84 <Pool as net_pool::pool::Pool>::set_keepalive(
85 &p,
86 Some(std::time::Duration::from_secs(60 * 1)),
87 );
88 p
89 }
90
91 pub fn set_max_streams(&self, mut max_streams: Option<usize>) {
92 if let Some(0) = max_streams {
93 max_streams = None;
94 }
95
96 match max_streams {
97 None => self.max_streams.store(usize::MAX, Relaxed),
98 Some(v) => self.max_streams.store(v, Relaxed),
99 }
100 }
101 pub fn get_max_streams(&self) -> Option<usize> {
102 match self.max_streams.load(Relaxed) {
103 usize::MAX => None,
104 v => Some(v),
105 }
106 }
107
108 pub fn set_id<I: Into<String>>(&mut self, id: I) {
109 self.id = id.into();
110 }
111}
112
113impl Default for Pool {
114 fn default() -> Self {
115 Pool::new(
116 "".to_string(),
117 Arc::new(net_pool::strategy::CHStrategy::default()),
118 Some(200),
119 )
120 }
121}
122
123impl<L: LbStrategy + 'static> From<L> for Pool {
124 fn from(value: L) -> Self {
125 Self::new("".to_string(), Arc::new(value), Some(200))
126 }
127}
128
129impl net_pool::pool::Pool for Pool {
130 net_pool::macros::base_pool_impl! {state}
131
132 fn id(&self) -> &str {
133 &self.id
134 }
135
136 fn remove_backend(&self, addr: &Address) -> bool {
137 if self.state.lb_strategy.remove_backend(addr) {
138 self.clear_bs_inner(addr);
140 true
141 } else {
142 false
143 }
144 }
145
146 fn use_tls(&self, tls: bool) {
147 self.use_tls.store(tls, Relaxed);
148 }
149
150 fn tls(&self) -> bool {
151 self.use_tls.load(Relaxed)
152 }
153}
154
155impl Pool {
156 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
157 let mut guard = self.free_conn_map.lock().unwrap();
158
159 let inners = guard.get_mut(&bs.hash_code())?;
161 for idx in (0..inners.len()).rev() {
162 let inner = &mut inners[idx];
163 if inner.is_closed() {
164 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
166 debug!(
167 "[http/2.0 pool] [desc] current connection count: {}",
168 self.state.cur_conn.load(Relaxed)
169 );
170 inners.remove(idx);
171 continue;
172 }
173
174 if inner.limited(self.get_max_streams()) {
175 continue;
176 }
177
178 let tls = <Pool as net_pool::pool::Pool>::tls(self);
179 return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
180 }
181
182 None
183 }
184
185 fn add_inner(&self, hash_code: u64, inner: Inner) {
186 let mut guard = self.free_conn_map.lock().unwrap();
187 if let Some(inners) = guard.get_mut(&hash_code) {
188 inners.push(inner);
189 } else {
190 guard.insert(hash_code, vec![inner]);
191 }
192 }
193
194 fn clear_bs_inner(&self, addr: &Address) {
195 let mut guard = self.free_conn_map.lock().unwrap();
196 if let Some(inners) = guard.remove(&addr.hash_code()) {
197 assert!(self.state.cur_conn.fetch_sub(inners.len(), Relaxed) > 0);
198 debug!(
199 "[http/2.0 pool] [desc] current connection count: {}",
200 self.state.cur_conn.load(Relaxed)
201 );
202 }
203 }
204
205 fn remove_inner(&self, hash_code: u64, id: u32) {
206 let mut guard = self.free_conn_map.lock().unwrap();
207 if let Some(inners) = guard.get_mut(&hash_code) {
208 let mut del = false;
209 inners.retain(|inner| {
210 if inner.id() == id {
211 del = true;
212 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
213 false
214 } else {
215 true
216 }
217 });
218 if del {
219 debug!(
220 "[http/2.0 pool] [desc] current connection count: {}",
221 self.state.cur_conn.load(Relaxed)
222 );
223 }
224 }
225 }
226
227 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
228 pool: Arc<Self>,
229 c: C,
230 sr: SendRequest,
231 bs: &BackendState,
232 ) -> Sender {
233 let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
234 let inner = Inner::new(sr);
235 let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
236
237 let tuple = (
238 bs.hash_code(),
239 <Pool as net_pool::pool::Pool>::get_keepalive(&pool),
240 inner.reference(),
241 inner.id(),
242 );
243
244 pool.add_inner(tuple.0, inner);
249
250 tokio_spawn! {
252 instrument_current_span! {
253 async move {
254 let _r = crate::utils::run_conn(tuple.2, c, tuple.1).await;
255 debug!("[http/2.0 pool] connection closed: {:?}", _r);
256 pool.remove_inner(tuple.0, tuple.3);
257 }
258 }
259 };
260
261 sender
262 }
263
264 async fn create_tls_sender(
265 self: Arc<Self>,
266 bs: &BackendState,
267 exec: hyper_util::rt::TokioExecutor,
268 ) -> Result<Sender, Error> {
269 let addr = bs.get_address();
270
271 let tcp = crate::utils::create_https_stream(addr).await?;
273 let tls_tcp =
274 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
275 .await?;
276
277 let max_streams = self.get_max_streams();
279 let io = hyper_util::rt::TokioIo::new(tls_tcp);
280 let pair = http2::Builder::new(exec)
281 .max_concurrent_streams(max_streams.map(|m| m as u32))
282 .handshake(io)
283 .await
284 .map_err(|e| Error::from_other(e))?;
285
286 let sender = Pool::run_conn(self, pair.1, pair.0, &bs);
288 Ok(sender)
289 }
290
291 async fn create_non_tls_sender(
292 self: Arc<Self>,
293 bs: &BackendState,
294 exec: hyper_util::rt::TokioExecutor,
295 ) -> Result<Sender, Error> {
296 let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
297
298 let io = hyper_util::rt::TokioIo::new(tcp);
300 let pair = http2::handshake(exec, io)
301 .await
302 .map_err(|e| Error::from_other(e))?;
303
304 let sender = Pool::run_conn(self, pair.1, pair.0, bs);
306 Ok(sender)
307 }
308
309 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
310 let exec = hyper_util::rt::tokio::TokioExecutor::new();
311 if <Pool as net_pool::pool::Pool>::tls(&self) {
312 self.create_tls_sender(bs, exec).await
313 } else {
314 self.create_non_tls_sender(bs, exec).await
315 }
316 }
317
318 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
319 let sender = match self.get_sender(&bs) {
320 Some(s) => Ok(s),
321 None => {
322 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
324 self.clone().create_sender(&bs).await.map(|s| {
325 debug!(
326 "[http/2.0 pool] [incr] current connection count: {}",
327 self.state.cur_conn.load(Relaxed)
328 );
329 s
330 })
331 }
332 };
333
334 if sender.is_err() {
335 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
336 }
337
338 sender
339 }
340}
341
/// Asynchronous access to `Sender`s from an HTTP/2 pool shared behind an `Arc`.
pub trait HttpPool {
    /// Returns a sender to a backend chosen for `key`. In `Pool`'s implementation, the backend
    /// comes from the load-balancing strategy, and `Error::NoBackend` is returned when the
    /// strategy yields none.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Returns a sender to the specific backend `addr`. In `Pool`'s implementation, `addr` must
    /// be known to the load-balancing strategy, otherwise `Error::NoBackend` is returned.
    fn target(
        self: Arc<Self>,
        addr: &Address,
    ) -> impl Future<Output = Result<Sender, Error>> + Send;
}
350
351impl HttpPool for Pool {
352 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
353 let bs = self
354 .state
355 .lb_strategy
356 .get_backend(key)
357 .ok_or(Error::NoBackend);
358
359 self.real(bs?).await
360 }
361
362 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
363 if !self.state.lb_strategy.contain(addr) {
364 Err(Error::NoBackend)
365 } else {
366 self.real(BackendState::new(addr.clone())).await
367 }
368 }
369}
370
371pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
372 HttpPool::get(pool, key).await
373}