1use crate::http1::{SendRequest, Sender};
2use hyper::client::conn::http1;
3use net_pool::backend::{Address, BackendState};
4use net_pool::strategy::LbStrategy;
5use net_pool::{Error, debug, instrument_current_span, tokio_spawn};
6use std::collections::HashMap;
7use std::sync::atomic::Ordering::Relaxed;
8use std::sync::atomic::{AtomicBool, AtomicU32};
9use std::sync::{Arc, LazyLock, Mutex};
10
/// Process-wide counter used to assign ids to tracked connections; starts at 1.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
13
14struct Inner {
15 free: HashMap<u32, Arc<SendRequest>>,
17 work: HashMap<u32, Arc<SendRequest>>,
19}
20
21impl Inner {
22 fn new() -> Self {
23 Inner {
24 free: HashMap::new(),
25 work: HashMap::new(),
26 }
27 }
28
29 fn add_work_sr(&mut self, sr: Arc<SendRequest>) -> u32 {
30 let id = SENDER_ID.fetch_add(1, Relaxed);
31 self.work.insert(id, sr);
32 id
33 }
34
35 fn remove_sr(&mut self, id: u32) -> bool {
36 if self.work.remove(&id).is_some() {
38 true
39 } else if self.free.remove(&id).is_some() {
40 true
41 } else {
42 false
43 }
44 }
45
46 fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
48 if self.free.is_empty() {
49 self.work.retain(|id, sr| {
51 if Arc::strong_count(sr) == 2 {
53 self.free.insert(*id, sr.clone());
54 false
55 } else {
56 true
57 }
58 });
59 }
60
61 let mut del = 0;
63 let ks: Vec<u32> = self.free.keys().cloned().collect();
64 for k in ks {
65 if let Some(sr) = self.free.remove(&k) {
66 if sr.is_closed() {
67 del += 1;
68 } else {
69 self.work.insert(k, sr.clone());
71 return (del, Some(sr));
72 }
73 }
74 }
75
76 (del, None)
77 }
78}
79
80pub struct Pool {
88 id: String,
89 state: net_pool::pool::BaseState,
90 use_tls: AtomicBool,
91 free_conn_map: Mutex<HashMap<u64, Inner>>,
92}
93
94impl Pool {
95 pub fn new(id: String, strategy: Arc<dyn LbStrategy>) -> Self {
96 let p = Pool {
97 id,
98 state: net_pool::pool::BaseState::new(strategy),
99 use_tls: AtomicBool::new(false),
100 free_conn_map: Mutex::new(HashMap::new()),
101 };
102 <Pool as net_pool::pool::Pool>::set_keepalive(
103 &p,
104 Some(std::time::Duration::from_secs(60 * 1)),
105 );
106 p
107 }
108
109 pub fn set_id<I: Into<String>>(&mut self, id: I) {
110 self.id = id.into();
111 }
112}
113
114impl Default for Pool {
115 fn default() -> Self {
116 Pool::new(
117 "".to_string(),
118 Arc::new(net_pool::strategy::CHStrategy::default()),
119 )
120 }
121}
122
123impl<L: LbStrategy + 'static> From<L> for Pool {
124 fn from(value: L) -> Self {
125 Self::new("".to_string(), Arc::new(value))
126 }
127}
128
129impl net_pool::pool::Pool for Pool {
130 net_pool::macros::base_pool_impl! {state}
131
132 fn id(&self) -> &str {
133 &self.id
134 }
135
136 fn remove_backend(&self, addr: &Address) -> bool {
137 if self.state.lb_strategy.remove_backend(addr) {
138 self.clear_bs_sr(addr.hash_code());
140 true
141 } else {
142 false
143 }
144 }
145
146 fn use_tls(&self, tls: bool) {
147 self.use_tls.store(tls, Relaxed);
148 }
149
150 fn tls(&self) -> bool {
151 self.use_tls.load(Relaxed)
152 }
153}
154
155impl Pool {
156 fn clear_bs_sr(&self, hash_code: u64) {
157 let mut guard = self.free_conn_map.lock().unwrap();
158 if let Some(inner) = guard.remove(&hash_code) {
159 assert!(
160 self.state
161 .cur_conn
162 .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
163 > 0
164 );
165 debug!(
166 "[http/1.1 pool] [desc] current connection count: {}",
167 self.state.cur_conn.load(Relaxed)
168 );
169 }
170 }
171
172 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
173 let mut guard = self.free_conn_map.lock().unwrap();
174 let inner = guard.get_mut(&bs.hash_code())?;
175
176 let sr = {
177 let (del_cnt, sr) = inner.get_sr();
178 if del_cnt > 0 {
179 assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
180 debug!(
181 "[http/1.1 pool] [desc] current connection count: {}",
182 self.state.cur_conn.load(Relaxed)
183 );
184 }
185 sr
186 };
187
188 let tls = <Pool as net_pool::pool::Pool>::tls(self);
189 sr.map(|s| Sender::new(s, crate::utils::base_url(tls, bs.get_address())))
190 }
191
192 fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) -> u32 {
193 let mut guard = self.free_conn_map.lock().unwrap();
194 if let Some(inner) = guard.get_mut(&hash_code) {
195 inner.add_work_sr(sr)
196 } else {
197 let mut inner = Inner::new();
198 let id = inner.add_work_sr(sr);
199 guard.insert(hash_code, inner);
200 id
201 }
202 }
203
204 fn remove_sr(&self, hash_code: u64, id: u32) {
205 let mut guard = self.free_conn_map.lock().unwrap();
206 if let Some(inner) = guard.get_mut(&hash_code) {
207 if inner.remove_sr(id) {
208 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
209 debug!(
210 "[http/1.1 pool] [desc] current connection count: {}",
211 self.state.cur_conn.load(Relaxed)
212 );
213 }
214 }
215 }
216
217 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
218 pool: Arc<Self>,
219 c: C,
220 sr: Arc<SendRequest>,
221 bs: &BackendState,
222 ) {
223 let hash_code = bs.hash_code();
224
225 let id = pool.add_sr(hash_code, sr.clone());
230 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
231
232 tokio_spawn! {
233 instrument_current_span! {
234 async move {
235 let _r = crate::utils::run_conn(sr, c, ka).await;
236 debug!("[http/1.1 pool] connection closed: {:?}", _r);
237 pool.remove_sr(hash_code, id);
238 }
239 }
240 };
241 }
242
243 async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
244 let tls = <Pool as net_pool::pool::Pool>::tls(&self);
245 let addr = bs.get_address();
246
247 let tcp = crate::utils::create_https_stream(addr).await?;
249 let tls_tcp =
250 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
251 .await?;
252
253 let io = hyper_util::rt::TokioIo::new(tls_tcp);
255 let pair = http1::handshake(io)
256 .await
257 .map_err(|e| Error::from_other(e))?;
258
259 let sr = Arc::new(pair.0);
261 Pool::run_conn(self, pair.1, sr.clone(), &bs);
262
263 Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
265 }
266
267 async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
268 let addr = bs.get_address();
269 let tls = <Pool as net_pool::pool::Pool>::tls(&self);
270
271 let tcp = crate::utils::create_http_stream(addr).await?;
273
274 let io = hyper_util::rt::TokioIo::new(tcp);
276 let pair = http1::handshake(io)
277 .await
278 .map_err(|e| Error::from_other(e))?;
279
280 let sr = Arc::new(pair.0);
282 Pool::run_conn(self, pair.1, sr.clone(), &bs);
283
284 Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
286 }
287
288 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
289 if <Pool as net_pool::pool::Pool>::tls(&self) {
290 self.create_tls_sender(bs).await
291 } else {
292 self.create_non_tls_sender(bs).await
293 }
294 }
295
296 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
297 if let Some(sender) = self.clone().get_sender(&bs) {
298 return Ok(sender);
299 }
300
301 let sender = {
302 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
304 self.clone().create_sender(&bs).await.map(|s| {
305 debug!(
306 "[http/1.1 pool] [incr] current connection count: {}",
307 self.state.cur_conn.load(Relaxed)
308 );
309 s
310 })
311 };
312
313 if sender.is_err() {
314 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
315 }
316
317 sender
318 }
319}
320
321pub trait HttpPool {
322 fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
324
325 fn target(
327 self: Arc<Self>,
328 _: &Address,
329 ) -> impl Future<Output = Result<Sender, Error>> + Send;
330}
331
332impl HttpPool for Pool {
333 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
334 let bs = self
335 .state
336 .lb_strategy
337 .get_backend(key)
338 .ok_or(Error::NoBackend)?;
339
340 self.real(bs).await
341 }
342
343 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
344 if !self.state.lb_strategy.contain(addr) {
345 return Err(Error::NoBackend);
346 }
347
348 self.real(BackendState::new(None, addr.clone())).await
349 }
350}
351
352pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
353 HttpPool::get(pool, key).await
354}