use crate::http2::{SendRequest, Sender};
use hyper::client::conn::http2;
use net_pool::backend::{Address, BackendState};
use net_pool::strategy::LbStrategy;
use net_pool::{Error, debug, instrument_current_span, tokio_spawn, trace};
use std::collections::HashMap;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize};
use std::sync::{Arc, LazyLock, Mutex};
/// Source of connection ids: starts at 1 and is bumped by one for every `Inner` created.
static SENDER_ID: LazyLock<AtomicU32> = LazyLock::new(|| AtomicU32::from(1));
/// One pooled HTTP/2 connection.
///
/// Positional fields:
/// - `.0`: the connection's request sender; a clone goes into every `Sender` created from it.
/// - `.1`: a shared token; a clone is passed to every `Sender` and to the background
///   connection task, and its strong count is what the stream limit is checked against.
/// - `.2`: id taken from `SENDER_ID`, used to find this entry again when the connection ends.
struct Inner(SendRequest, Arc<()>, u32);
impl Inner {
    /// Wraps a freshly handshaken request sender and assigns it the next id from `SENDER_ID`.
    fn new(sender: SendRequest) -> Self {
        Inner(sender, Arc::new(()), SENDER_ID.fetch_add(1, Relaxed))
    }

    /// Id assigned at construction; used to locate this entry when it must be removed.
    fn id(&self) -> u32 {
        self.2
    }

    /// Builds a `Sender` for this connection targeting `base_url`.
    ///
    /// The sender receives clones of both the request sender and the shared usage token.
    fn new_sender(&self, base_url: String) -> Sender {
        Sender::new(self.0.clone(), base_url, self.1.clone())
    }

    /// Current strong count of the shared usage token (this `Inner`'s own handle included).
    fn ref_count(&self) -> usize {
        Arc::strong_count(&self.1)
    }

    /// Returns `true` when a limit is set and the usage token already has more than `max`
    /// strong references. `None` means "no limit" and always yields `false`.
    fn limited(&self, max_streams: Option<usize>) -> bool {
        // `> max` is the same test as the former `>= max + 1`, but it cannot overflow when
        // `max == usize::MAX`.
        max_streams.is_some_and(|max| self.ref_count() > max)
    }

    /// Whether the underlying request sender reports the connection as closed.
    fn is_closed(&self) -> bool {
        self.0.is_closed()
    }

    /// Another strong handle to the shared usage token.
    fn reference(&self) -> Arc<()> {
        self.1.clone()
    }
}
/// HTTP/2 connection pool.
///
/// It picks a backend through a load-balancing strategy and hands out `Sender`s that share
/// pooled connections to that backend.
pub struct Pool {
    /// Pool identifier, returned by `HttpPool::get_id`.
    id: String,
    /// Shared `net_pool` state: load-balancing strategy plus current/maximum connection counters.
    state: net_pool::pool::BaseState,
    /// Per-connection stream limit; `usize::MAX` encodes "no limit" (see `get_max_streams`).
    max_streams: AtomicUsize,
    /// Pooled connections, grouped by the backend's hash code.
    free_conn_map: Mutex<HashMap<u64, Vec<Inner>>>,
    /// Whether new connections are opened over TLS.
    use_tls: AtomicBool,
}
impl Pool {
    /// Creates a pool named `id` that selects backends through `strategy`.
    ///
    /// `max_streams` is the per-connection stream limit; `None` and `Some(0)` both mean
    /// "unlimited". The keepalive is initialised to 60 seconds.
    pub fn new(id: String, strategy: Arc<dyn LbStrategy>, max_streams: Option<usize>) -> Self {
        let pool = Pool {
            id,
            state: net_pool::pool::BaseState::new(strategy),
            max_streams: AtomicUsize::new(usize::MAX),
            free_conn_map: Mutex::new(HashMap::new()),
            use_tls: AtomicBool::new(false),
        };
        // `set_max_streams` already turns `Some(0)` into "unlimited".
        pool.set_max_streams(max_streams);
        <Pool as net_pool::pool::Pool>::set_keepalive(
            &pool,
            Some(std::time::Duration::from_secs(60)),
        );
        pool
    }

    /// Updates the per-connection stream limit; `None` or `Some(0)` removes the limit.
    pub fn set_max_streams(&self, max_streams: Option<usize>) {
        // `usize::MAX` is the stored sentinel for "no limit".
        let stored = max_streams.filter(|&limit| limit != 0).unwrap_or(usize::MAX);
        self.max_streams.store(stored, Relaxed);
    }

    /// Current per-connection stream limit, or `None` when unlimited.
    pub fn get_max_streams(&self) -> Option<usize> {
        let raw = self.max_streams.load(Relaxed);
        (raw != usize::MAX).then_some(raw)
    }
}
impl Default for Pool {
    /// Unnamed pool using `CHStrategy::default()` with a per-connection stream limit of 200.
    fn default() -> Self {
        let strategy = Arc::new(net_pool::strategy::CHStrategy::default());
        Self::new(String::new(), strategy, Some(200))
    }
}
impl net_pool::pool::Pool for Pool {
    net_pool::macros::base_pool_impl! {state}

    /// Removes `addr` from the load-balancing strategy.
    ///
    /// If it was present, this also drops the pooled connections to it. Returns whether the
    /// strategy removed the backend.
    fn remove_backend(&self, addr: &Address) -> bool {
        let removed = self.state.lb_strategy.remove_backend(addr);
        if removed {
            self.clear_bs_inner(addr);
        }
        removed
    }

    /// Chooses whether connections opened from now on use TLS.
    fn use_tls(&self, tls: bool) {
        self.use_tls.store(tls, Relaxed);
    }

    /// Whether new connections are opened over TLS.
    fn tls(&self) -> bool {
        self.use_tls.load(Relaxed)
    }
}
impl Pool {
    /// Tries to reuse a pooled connection to `bs`.
    ///
    /// The backend's connections are scanned from newest to oldest. Closed ones are dropped
    /// and uncounted, and ones at the stream limit are skipped. A `Sender` is returned for the
    /// first usable connection, or `None` if there is none.
    fn get_sender(&self, bs: &BackendState) -> Option<Sender> {
        // Loop-invariant: read the limit once rather than per candidate connection.
        let max_streams = self.get_max_streams();
        let mut guard = self.free_conn_map.lock().unwrap();
        let inners = guard.get_mut(&bs.hash_code())?;
        // Iterate in reverse so `remove(idx)` only shifts entries that were already visited.
        for idx in (0..inners.len()).rev() {
            let inner = &inners[idx];
            if inner.is_closed() {
                assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
                trace!(
                    "[http/2.0 pool] [desc] current connection count: {}",
                    self.state.cur_conn.load(Relaxed)
                );
                inners.remove(idx);
                continue;
            }
            if inner.limited(max_streams) {
                continue;
            }
            let tls = <Pool as net_pool::pool::Pool>::tls(self);
            return Some(inner.new_sender(crate::utils::base_url(tls, bs.get_address())));
        }
        None
    }

    /// Adds a freshly created connection to the pool under the backend's `hash_code`.
    fn add_inner(&self, hash_code: u64, inner: Inner) {
        self.free_conn_map
            .lock()
            .unwrap()
            .entry(hash_code)
            .or_default()
            .push(inner);
    }

    /// Drops every pooled connection to `addr` and subtracts them from the connection count.
    ///
    /// This does not stop the connections' background tasks. When those tasks finish, their
    /// `remove_inner` call no longer finds these entries, so nothing is subtracted twice.
    fn clear_bs_inner(&self, addr: &Address) {
        let mut guard = self.free_conn_map.lock().unwrap();
        let Some(inners) = guard.remove(&addr.hash_code()) else {
            return;
        };
        // An entry can be left behind as an empty `Vec` once its connections were removed one
        // by one. Subtracting 0 from a zero counter used to trip the assertion below.
        if inners.is_empty() {
            return;
        }
        let prev = self.state.cur_conn.fetch_sub(inners.len(), Relaxed);
        // Every pooled connection was counted, so the counter must cover all of them. Checking
        // against the full amount also catches an underflow that `prev > 0` would let through.
        assert!(prev >= inners.len());
        trace!(
            "[http/2.0 pool] [desc] current connection count: {}",
            self.state.cur_conn.load(Relaxed)
        );
    }

    /// Removes the connection with `id` from the `hash_code` bucket, if it is still pooled,
    /// and subtracts it from the connection count.
    fn remove_inner(&self, hash_code: u64, id: u32) {
        let mut guard = self.free_conn_map.lock().unwrap();
        let Some(inners) = guard.get_mut(&hash_code) else {
            return;
        };
        let mut removed = false;
        inners.retain(|inner| {
            if inner.id() != id {
                return true;
            }
            removed = true;
            assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
            false
        });
        if removed {
            trace!(
                "[http/2.0 pool] [desc] current connection count: {}",
                self.state.cur_conn.load(Relaxed)
            );
        }
    }

    /// Registers a newly handshaken connection and returns a first `Sender` for it.
    ///
    /// The connection future `c` is driven on a spawned task through `crate::utils::run_conn`
    /// with the pool's keepalive. When that finishes, the entry is removed from the pool.
    fn run_conn<C: Future<Output = Result<(), hyper::Error>> + Send + 'static>(
        pool: Arc<Self>,
        c: C,
        sr: SendRequest,
        bs: &BackendState,
    ) -> Sender {
        let tls = <Pool as net_pool::pool::Pool>::tls(&pool);
        let inner = Inner::new(sr);
        let sender = inner.new_sender(crate::utils::base_url(tls, bs.get_address()));
        let hash_code = bs.hash_code();
        let keepalive = <Pool as net_pool::pool::Pool>::get_keepalive(&pool);
        let usage_token = inner.reference();
        let id = inner.id();
        pool.add_inner(hash_code, inner);
        tokio_spawn! {
            instrument_current_span! {
                async move {
                    let _r = crate::utils::run_conn(usage_token, c, keepalive).await;
                    debug!("[http/2.0 pool] connection closed: {:?}", _r);
                    pool.remove_inner(hash_code, id);
                }
            }
        };
        sender
    }

    /// HTTP/2 handshake builder shared by the TLS and plain-text paths.
    ///
    /// Both paths therefore apply the same configured stream limit.
    fn h2_builder(
        &self,
        exec: hyper_util::rt::TokioExecutor,
    ) -> http2::Builder<hyper_util::rt::TokioExecutor> {
        let mut builder = http2::Builder::new(exec);
        // hyper takes the limit as `u32`; saturate instead of silently truncating larger values.
        builder.max_concurrent_streams(
            self.get_max_streams()
                .map(|m| u32::try_from(m).unwrap_or(u32::MAX)),
        );
        builder
    }

    /// Opens a TLS connection to `bs`, performs the HTTP/2 handshake and pools the result.
    async fn create_tls_sender(
        self: Arc<Self>,
        bs: &BackendState,
        exec: hyper_util::rt::TokioExecutor,
    ) -> Result<Sender, Error> {
        let addr = bs.get_address();
        let tcp = crate::utils::create_https_stream(addr).await?;
        let tls_tcp =
            crate::utils::create_tls_tcp(tcp, addr, crate::utils::HTTP2_TLS_CLIENT_CFG.clone())
                .await?;
        let io = hyper_util::rt::TokioIo::new(tls_tcp);
        let (send_request, conn) = self
            .h2_builder(exec)
            .handshake(io)
            .await
            .map_err(|e| Error::from_other(e))?;
        Ok(Pool::run_conn(self, conn, send_request, bs))
    }

    /// Opens a plain-text connection to `bs`, performs the HTTP/2 handshake and pools the
    /// result.
    async fn create_non_tls_sender(
        self: Arc<Self>,
        bs: &BackendState,
        exec: hyper_util::rt::TokioExecutor,
    ) -> Result<Sender, Error> {
        let tcp = crate::utils::create_http_stream(bs.get_address()).await?;
        let io = hyper_util::rt::TokioIo::new(tcp);
        let (send_request, conn) = self
            .h2_builder(exec)
            .handshake(io)
            .await
            .map_err(|e| Error::from_other(e))?;
        Ok(Pool::run_conn(self, conn, send_request, bs))
    }

    /// Opens a new connection to `bs`, using TLS when the pool is configured for it.
    async fn create_sender(self: Arc<Self>, bs: &BackendState) -> Result<Sender, Error> {
        let exec = hyper_util::rt::tokio::TokioExecutor::new();
        if <Pool as net_pool::pool::Pool>::tls(&self) {
            self.create_tls_sender(bs, exec).await
        } else {
            self.create_non_tls_sender(bs, exec).await
        }
    }
}
/// A pool that hands out HTTP/2 `Sender`s.
pub trait HttpPool {
    /// Identifier of this pool.
    fn get_id(&self) -> &String;

    /// Resolves `key` to a backend and yields a `Sender` for it.
    fn get(self: Arc<Self>, key: &str) -> impl Future<Output = Result<Sender, Error>> + Send;
}
impl HttpPool for Pool {
async fn get(self: Arc<Self>, key: &str) -> Result<Sender, Error> {
let bs = self
.state
.lb_strategy
.get_backend(key)
.ok_or(Error::NoBackend)?;
let sender = match self.get_sender(&bs) {
Some(s) => Ok(s),
None => {
net_pool::pool::increase_current(&self.state.max_conn, &self.state.cur_conn)?;
self.clone().create_sender(&bs).await.map(|s| {
trace!(
"[http/2.0 pool] [incr] current connection count: {}",
self.state.cur_conn.load(Relaxed)
);
s
})
}
};
if sender.is_err() {
assert!(self.state.cur_conn.fetch_sub(1, Relaxed) > 0);
}
sender
}
fn get_id(&self) -> &String {
&self.id
}
}
pub async fn get<P: HttpPool>(pool: Arc<P>, key: &str) -> Result<Sender, Error> {
HttpPool::get(pool, key).await
}