use crate::http1::{SendRequest, Sender};
use cpool::backend::{Address, BackendState};
use cpool::strategy::LbStrategy;
use cpool::{Error, debug, instrument_current_span, tokio_spawn, trace};
use hyper::client::conn::http1;
use std::collections::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicU32};
use std::sync::{Arc, LazyLock, Mutex};
/// Process-wide source of connection ids, starting at 1.
///
/// `fetch_add` on an atomic wraps around on overflow, so ids are unique only
/// until 2^32 connections have been created.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
/// Connections to a single backend, keyed by the connection id from `SENDER_ID`.
struct Inner {
    /// Connections known to be idle and ready to be handed out.
    free: HashMap<u32, Arc<SendRequest>>,
    /// Connections that are checked out, or idle but not yet moved back to
    /// `free` (that only happens on demand in `get_sr`).
    work: HashMap<u32, Arc<SendRequest>>,
}
impl Inner {
fn new() -> Self {
Inner {
free: HashMap::new(),
work: HashMap::new(),
}
}
fn add_work_sr(&mut self, sr: Arc<SendRequest>) -> u32 {
let id = SENDER_ID.fetch_add(1, Relaxed);
self.work.insert(id, sr);
id
}
fn remove_sr(&mut self, id: u32) -> bool {
if self.work.remove(&id).is_some() {
true
} else if self.free.remove(&id).is_some() {
true
} else {
false
}
}
fn get_sr(&mut self) -> (usize, Option<Arc<SendRequest>>) {
if self.free.is_empty() {
self.work.retain(|id, sr| {
if Arc::strong_count(sr) == 2 {
self.free.insert(*id, sr.clone());
false
} else {
true
}
});
}
let mut del = 0;
let ks: Vec<u32> = self.free.keys().cloned().collect();
for k in ks {
if let Some(sr) = self.free.remove(&k) {
if sr.is_closed() {
del += 1;
} else {
self.work.insert(k, sr.clone());
return (del, Some(sr));
}
}
}
(del, None)
}
}
/// HTTP/1.1 connection pool that keeps reusable connections per backend.
pub struct Pool {
    /// Identifier returned by `Http1Pool::get_id`.
    id: String,
    /// Shared pool state: load-balancing strategy and the `max_conn` /
    /// `cur_conn` connection counters (among whatever else `BaseState` holds).
    state: cpool::pool::BaseState,
    /// Whether new connections are established over TLS.
    use_tls: AtomicBool,
    /// Per-backend connection sets keyed by the backend's `hash_code()`.
    /// Despite the name, each `Inner` holds both idle and in-use connections.
    free_conn_map: Mutex<HashMap<u64, Inner>>,
}
impl Pool {
    /// Creates an empty pool named `id` that picks backends with `strategy`.
    ///
    /// TLS starts disabled and the keep-alive is set to 60 seconds.
    pub fn new(id: String, strategy: Arc<dyn LbStrategy>) -> Self {
        let pool = Self {
            id,
            state: cpool::pool::BaseState::new(strategy),
            use_tls: AtomicBool::new(false),
            free_conn_map: Mutex::new(HashMap::new()),
        };
        let keepalive = std::time::Duration::from_secs(60);
        cpool::pool::Pool::set_keepalive(&pool, Some(keepalive));
        pool
    }
}
impl Default for Pool {
    /// An unnamed pool that selects backends with a default `CHStrategy`.
    fn default() -> Self {
        Self::new(
            String::new(),
            Arc::new(cpool::strategy::CHStrategy::default()),
        )
    }
}
impl cpool::pool::Pool for Pool {
    cpool::macros::base_pool_impl! {state}

    /// Removes `addr` from the load-balancing strategy.
    ///
    /// If it was present, also drops every pooled connection to it.
    /// Returns whether the strategy knew about `addr`.
    fn remove_backend(&self, addr: &Address) -> bool {
        let removed = self.state.lb_strategy.remove_backend(addr);
        if removed {
            self.clear_bs_sr(addr.hash_code());
        }
        removed
    }

    /// Sets the TLS flag that `tls()` reports.
    fn use_tls(&self, tls: bool) {
        self.use_tls.store(tls, Relaxed)
    }

    /// Reports whether the pool is configured to use TLS.
    fn tls(&self) -> bool {
        self.use_tls.load(Relaxed)
    }
}
impl Pool {
fn clear_bs_sr(&self, hash_code: u64) {
let mut guard = self.free_conn_map.lock().unwrap();
if let Some(inner) = guard.remove(&hash_code) {
assert!(
self.state
.cur_conn
.fetch_sub(inner.free.len() + inner.work.len(), Relaxed)
> 0
);
trace!(
"[http/1.1 pool] [desc] current connection count: {}",
self.state.cur_conn.load(Relaxed)
);
}
}
fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
let mut guard = self.free_conn_map.lock().unwrap();
let inner = guard.get_mut(&bs.hash_code())?;
let sr = {
let (del_cnt, sr) = inner.get_sr();
if del_cnt > 0 {
assert!(self.state.cur_conn.fetch_sub(del_cnt, Relaxed) > 0);
trace!(
"[http/1.1 pool] [desc] current connection count: {}",
self.state.cur_conn.load(Relaxed)
);
}
sr
};
let tls = <Pool as cpool::pool::Pool>::tls(self);
sr.map(|s| Sender::new(s, crate::utils::base_url(tls, bs.get_address())))
}
fn add_sr(&self, hash_code: u64, sr: Arc<SendRequest>) -> u32 {
let mut guard = self.free_conn_map.lock().unwrap();
if let Some(inner) = guard.get_mut(&hash_code) {
inner.add_work_sr(sr)
} else {
let mut inner = Inner::new();
let id = inner.add_work_sr(sr);
guard.insert(hash_code, inner);
id
}
}
fn remove_sr(&self, hash_code: u64, id: u32) {
let mut guard = self.free_conn_map.lock().unwrap();
if let Some(inner) = guard.get_mut(&hash_code) {
if inner.remove_sr(id) {
assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
trace!(
"[http/1.1 pool] [desc] current connection count: {}",
self.state.cur_conn.load(Relaxed)
);
}
}
}
fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
pool: Arc<Self>,
c: C,
sr: Arc<SendRequest>,
bs: &BackendState,
) {
let hash_code = bs.hash_code();
let id = pool.add_sr(hash_code, sr.clone());
let ka = <Pool as cpool::pool::Pool>::get_keepalive(&pool);
tokio_spawn! {
instrument_current_span! {
async move {
let _r = crate::utils::run_conn(sr, c, ka).await;
debug!("[http/1.1 pool] connection closed: {:?}", _r);
pool.remove_sr(hash_code, id);
}
}
};
}
async fn create_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
let tls = <Pool as cpool::pool::Pool>::tls(&self);
let addr = bs.get_address();
let tcp = crate::utils::create_https_stream(addr).await?;
let tls_tcp =
crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP1_TLS_CLIENT_CFG.clone())
.await?;
let io = hyper_util::rt::TokioIo::new(tls_tcp);
let pair = http1::handshake(io)
.await
.map_err(|e| Error::from_other(e))?;
let sr = Arc::new(pair.0);
Pool::run_conn(self, pair.1, sr.clone(), &bs);
Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
}
async fn create_non_tls_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
let addr = bs.get_address();
let tls = <Pool as cpool::pool::Pool>::tls(&self);
let tcp = crate::utils::create_http_stream(addr).await?;
let io = hyper_util::rt::TokioIo::new(tcp);
let pair = http1::handshake(io)
.await
.map_err(|e| Error::from_other(e))?;
let sr = Arc::new(pair.0);
Pool::run_conn(self, pair.1, sr.clone(), &bs);
Ok(Sender::new(sr, crate::utils::base_url(tls, addr)))
}
async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
if <Pool as cpool::pool::Pool>::tls(&self) {
self.create_tls_sender(bs).await
} else {
self.create_non_tls_sender(bs).await
}
}
}
/// A pool that hands out HTTP/1.1 [`Sender`]s.
pub trait Http1Pool {
    /// Returns a sender for `key`.
    ///
    /// The `Pool` implementation in this file uses `key` to select a backend.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
    /// Returns the pool's identifier.
    fn get_id(&self) -> &String;
}
impl Http1Pool for Pool {
    /// Picks a backend for `key` and returns a sender bound to it.
    ///
    /// An idle pooled connection is reused when available; otherwise a new
    /// connection is opened.
    ///
    /// # Errors
    ///
    /// - `Error::NoBackend` when the strategy yields no backend for `key`.
    /// - Any error from `cpool::pool::increase_current` (presumably raised when
    ///   `max_conn` is reached).
    /// - Any error raised while establishing the new connection.
    async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
        /// Gives one reserved `cur_conn` slot back when dropped, unless disarmed.
        ///
        /// This covers both a failed connection attempt and the `get` future
        /// being dropped (cancelled) while the connection is still being set up.
        /// In the cancelled case no code after the `.await` would ever run.
        struct SlotGuard<'a> {
            pool: &'a Pool,
            armed: bool,
        }

        impl Drop for SlotGuard<'_> {
            fn drop(&mut self) {
                if self.armed {
                    assert!(self.pool.state.cur_conn.fetch_sub(1, Relaxed) > 0);
                }
            }
        }

        let bs = self
            .state
            .lb_strategy
            .get_backend(key)
            .ok_or(Error::NoBackend)?;
        // Fast path: hand out an idle pooled connection to this backend.
        if let Some(sender) = self.get_sender(&bs) {
            return Ok(sender);
        }
        // Reserve a slot before dialing so that concurrent callers respect
        // `max_conn`.
        cpool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
        let mut slot = SlotGuard {
            pool: &*self,
            armed: true,
        };
        let sender = Arc::clone(&self).create_sender(&bs).await?;
        // The new connection is now registered with the pool and counted in
        // `cur_conn`, so keep the slot.
        slot.armed = false;
        trace!(
            "[http/1.1 pool] [incr] current connection count: {}",
            self.state.cur_conn.load(Relaxed)
        );
        Ok(sender)
    }

    fn get_id(&self) -> &String {
        &self.id
    }
}
/// Free-function form of [`Http1Pool::get`]: obtains a sender from `pool` for `key`.
pub async fn get<P: Http1Pool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
    <P as Http1Pool>::get(pool, key).await
}