1use crate::http1::{SendRequest, Sender};
2use hyper::client::conn::http1;
3use net_pool::backend::{Address, BackendState};
4use net_pool::{Error, Strategy, debug, instrument_current_span, tokio_spawn};
5use std::collections::hash_map::Entry;
6use std::collections::{HashMap, HashSet};
7use std::hash::{Hash, Hasher};
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::Ordering::Relaxed;
10use std::sync::{Arc, Mutex};
11
/// Set key wrapping a shared `SendRequest` handle.
///
/// The `PartialEq` and `Hash` impls in this file compare and hash the address
/// of the pointed-to value, so two keys match only when they share one allocation.
#[derive(Clone)]
struct SrData(Arc<SendRequest>);
14
15impl PartialEq for SrData {
16 fn eq(&self, other: &Self) -> bool {
17 self.0.as_ref() as *const _ == other.0.as_ref() as *const _
18 }
19}
20
// Marker impl: lets `SrData` be stored in a `HashSet`, on top of its `PartialEq`.
impl Eq for SrData {}
22
23impl Hash for SrData {
24 fn hash<H: Hasher>(&self, state: &mut H) {
25 state.write_usize(self.0.as_ref() as *const _ as usize)
26 }
27}
28
/// Tracked `SendRequest` handles, split into two sets.
struct Inner {
    // Handles `get_sr` may hand out next; filled from `work` by `get_sr`.
    free: HashSet<SrData>,
    // Handles added through `add_work_sr` or most recently returned by `get_sr`.
    work: HashSet<SrData>,
}
35
36impl Inner {
37 fn new() -> Self {
38 Inner {
39 free: HashSet::new(),
40 work: HashSet::new(),
41 }
42 }
43
44 fn add_work_sr(&mut self, sr: Arc<SendRequest>) {
45 self.work.insert(SrData(sr));
46 }
47
48 fn remove_sr(&mut self, sr: Arc<SendRequest>) -> bool {
49 let d = SrData(sr);
50 if self.work.remove(&d) {
52 true
53 } else if self.free.remove(&d) {
54 true
55 } else {
56 false
57 }
58 }
59
60 fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
62 let mut dels = 0;
64
65 if self.free.is_empty() {
67 self.work.retain(|d| {
69 if Arc::strong_count(&d.0) == 2 {
71 if !d.0.is_closed() {
72 self.free.insert(d.clone());
73 } else {
74 dels += 1;
75 }
76 false
77 } else {
78 true
79 }
80 });
81 }
82
83 loop {
85 let d = match self.free.iter().next() {
86 None => return (dels, None),
87 Some(d) => d.clone(),
88 };
89
90 self.free.remove(&d);
91 if d.0.is_closed() {
92 dels += 1;
93 } else {
94 let sr = d.0.clone();
95 self.work.insert(d);
96 return (dels, Some(sr));
97 }
98 }
99 }
100}
101
/// HTTP/1 connection pool.
pub struct Pool {
    // Shared pool state; this file reads its `lb_strategy`, `max_conn` and `cur_conn`.
    state: net_pool::pool::BaseState,
    // Whether new connections are created through the TLS path.
    use_tls: AtomicBool,
    // Tracked connection handles, keyed by the backend's `hash_code()`.
    free_conn_map: Mutex<HashMap<u64, Inner>>,
}
114
115impl Pool {
116 pub fn new(strategy: Arc<dyn Strategy>) -> Self {
117 let p = Pool {
118 state: net_pool::pool::BaseState::new(strategy),
119 use_tls: AtomicBool::new(false),
120 free_conn_map: Mutex::new(HashMap::new()),
121 };
122 <Pool as net_pool::pool::Pool>::set_keepalive(
123 &p,
124 Some(std::time::Duration::from_secs(60 * 1)),
125 );
126 p
127 }
128}
129
130impl Default for Pool {
131 fn default() -> Self {
132 Pool::new(Arc::new(net_pool::strategy::CHStrategy::default()))
133 }
134}
135
136impl<L: Strategy + 'static> From<L> for Pool {
137 fn from(value: L) -> Self {
138 Self::new(Arc::new(value))
139 }
140}
141
142impl net_pool::pool::Pool for Pool {
143 net_pool::macros::base_pool_impl! {state}
144
145 fn remove_backend(&self, addr: &Address) -> bool {
146 if self.state.lb_strategy.remove_backend(addr) {
147 self.clear_bs_sr(addr.hash_code());
149 true
150 } else {
151 false
152 }
153 }
154
155 fn use_tls(&self, tls: bool) {
156 self.use_tls.store(tls, Relaxed);
157 }
158
159 fn tls(&self) -> bool {
160 self.use_tls.load(Relaxed)
161 }
162}
163
164impl Pool {
165 fn clear_bs_sr(&self, hash_code: u64) {
166 let mut guard = self.free_conn_map.lock().unwrap();
167 if let Some(inner) = guard.remove(&hash_code) {
168 assert!(
169 self.state
170 .cur_conn
171 .fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
172 > 0
173 );
174 debug!(
175 "[desc] current connection count: {}",
176 self.state.cur_conn.load(Relaxed)
177 );
178 }
179 }
180
181 fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
182 let mut guard = self.free_conn_map.lock().unwrap();
183 let inner = guard.get_mut(&bs.hash_code())?;
184
185 let sr = {
186 let (del_cnt, sr) = inner.get_sr();
187 if del_cnt > 0 {
188 assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
189 debug!(
190 "[desc] current connection count: {}",
191 self.state.cur_conn.load(Relaxed)
192 );
193 }
194 sr
195 };
196
197 if let Some(sr) = sr {
198 let tls = <Pool as net_pool::pool::Pool>::tls(self);
199 Some(Sender::new(
200 sr,
201 crate::utils::base_url(tls, bs.get_address()),
202 ))
203 } else {
204 None
205 }
206 }
207
208 fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) {
209 let mut guard = self.free_conn_map.lock().unwrap();
210 match guard.entry(hash_code) {
211 Entry::Occupied(mut o) => {
212 o.get_mut().add_work_sr(sr);
213 }
214 Entry::Vacant(e) => {
215 let mut inner = Inner::new();
216 inner.add_work_sr(sr);
217 e.insert(inner);
218 }
219 }
220 }
221
222 fn remove_sr(&self, hash_code: u64, sr: Arc<SendRequest>) {
223 let mut guard = self.free_conn_map.lock().unwrap();
224 if let Some(inner) = guard.get_mut(&hash_code) {
225 if inner.remove_sr(sr) {
226 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
227 debug!(
228 "[desc] current connection count: {}",
229 self.state.cur_conn.load(Relaxed)
230 );
231 }
232 }
233 }
234
235 fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
236 pool: Arc<Self>,
237 c: C,
238 sr: Arc<SendRequest>,
239 bs: &BackendState,
240 ) {
241 let hash_code = bs.hash_code();
242
243 pool.add_sr(hash_code, sr.clone());
248 let ka = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
249
250 tokio_spawn! {
251 instrument_current_span! {
252 async move {
253 let _r = crate::utils::run_conn(sr.clone(), c, ka).await;
254 debug!("connection closed: {:?}", _r);
255 pool.remove_sr(hash_code, sr);
256 }
257 }
258 };
259 }
260
261 async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
262 let tls = <Pool as net_pool::pool::Pool>::tls(&self);
263 let addr = bs.get_address();
264
265 let tcp = crate::utils::create_https_stream(addr).await?;
267 let tls_tcp =
268 crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
269 .await?;
270
271 let io = hyper_util::rt::TokioIo::new(tls_tcp);
273 let pair = http1::handshake(io)
274 .await
275 .map_err(|e| Error::from_other(e))?;
276
277 let sr = Arc::new(pair.0);
279 Pool::run_conn(self, pair.1, sr.clone(), &bs);
280
281 Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
283 }
284
285 async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
286 let addr = bs.get_address();
287 let tls = <Pool as net_pool::pool::Pool>::tls(&self);
288
289 let tcp = crate::utils::create_http_stream(addr).await?;
291
292 let io = hyper_util::rt::TokioIo::new(tcp);
294 let pair = http1::handshake(io)
295 .await
296 .map_err(|e| Error::from_other(e))?;
297
298 let sr = Arc::new(pair.0);
300 Pool::run_conn(self, pair.1, sr.clone(), &bs);
301
302 Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
304 }
305
306 async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
307 if <Pool as net_pool::pool::Pool>::tls(&self) {
308 self.create_tls_sender(bs).await
309 } else {
310 self.create_non_tls_sender(bs).await
311 }
312 }
313
314 async fn real(self: Arc<Self>, bs: BackendState) -> Result<Sender, Error> {
315 if let Some(sender) = self.clone().get_sender(&bs) {
316 return Ok(sender);
317 }
318
319 let sender = {
320 net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
322 self.clone().create_sender(&bs).await.map(|s| {
323 debug!(
324 "[incr] current connection count: {}",
325 self.state.cur_conn.load(Relaxed)
326 );
327 s
328 })
329 };
330
331 if sender.is_err() {
332 assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
333 }
334
335 sender
336 }
337}
338
/// A pool that hands out HTTP [`Sender`]s.
pub trait HttpPool {
    /// Resolves to a [`Sender`] selected from `key`, or to an [`Error`].
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;

    /// Resolves to a [`Sender`] for the given backend address, or to an [`Error`].
    fn target(self: Arc<Self>, _: &Address) -> impl Future<Output = Result<Sender, Error>> + Send;
}
346
347impl HttpPool for Pool {
348 async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
349 let bs = self
350 .state
351 .lb_strategy
352 .get_backend(key)
353 .ok_or(Error::NoBackend)?;
354
355 self.real(bs).await
356 }
357
358 async fn target(self: Arc<Self>, addr: &Address) -> Result<Sender, Error> {
359 if !self.state.lb_strategy.contain(addr) {
360 return Err(Error::NoBackend);
361 }
362
363 self.real(BackendState::new(None, addr.clone())).await
364 }
365}
366
367pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
368 HttpPool::get(pool, key).await
369}