use std::{collections::HashMap, sync::Arc};
use http::Uri;
use hyper::{
body::Incoming,
client::{self, conn::http1::SendRequest},
};
use once_cell::sync::OnceCell;
use tokio::{
select,
sync::{
mpsc::{self, Receiver, Sender},
Mutex, Semaphore,
},
};
use crate::{
cfg_logging,
config::{Config, Upstream},
tcp_connect,
};
/// Connection pools, one per configured upstream, keyed by the upstream's address.
///
/// Filled exactly once by `init_conn_pools`; a second initialization panics.
/// Each pool sits behind a `tokio::sync::Mutex` because `ConnPool::get_sender`
/// needs `&mut self` and awaits while it runs.
pub(crate) static CONN_POOLS: OnceCell<HashMap<Uri, Mutex<ConnPool>>> = OnceCell::new();
#[derive(Debug)]
pub(crate) struct ConnPool {
semaphore: Arc<Semaphore>,
receiver: Receiver<SendRequest<Incoming>>,
sender: Sender<SendRequest<Incoming>>,
}
impl ConnPool {
pub(crate) async fn get_sender(
&mut self,
upstream: &Upstream,
) -> Result<(Sender<SendRequest<Incoming>>, SendRequest<Incoming>), crate::Error> {
loop {
let mut sender = select! {
biased;
sender = self.receiver.recv() => {
cfg_logging! {trace!("Reusing connection to: {}", upstream.addr);}
Ok::<_, crate::Error>(sender.unwrap())
},
permit = Arc::clone(&self.semaphore).acquire_owned() => {
let permit = permit.unwrap();
cfg_logging! {info!("Opened new connection to: {}", upstream.addr);}
let stream = tcp_connect(upstream.addr.authority().unwrap()).await?;
let (sender, conn) = client::conn::http1::Builder::new()
.http1_preserve_header_case(true)
.http1_title_case_headers(true)
.handshake(stream)
.await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
cfg_logging! {error!("Connection failed: {:?}", err);}
}
drop(permit);
});
Ok(sender)
}
}?;
if let Ok(_) = sender.ready().await {
return Ok((self.sender.clone(), sender));
}
}
}
}
pub(crate) fn init_conn_pools(config: &Config) {
CONN_POOLS
.set(HashMap::from_iter(config.upstreams.values().map(|v| {
let (sender, receiver) = mpsc::channel::<SendRequest<Incoming>>(v.max_connections);
(
v.addr.clone(),
Mutex::new(ConnPool {
semaphore: Arc::new(Semaphore::new(v.max_connections)),
sender,
receiver,
}),
)
})))
.unwrap();
}