use std::{
ops::{Deref, DerefMut},
sync::Arc,
};
use http::Uri;
use hyper::{
body::Incoming,
client::{self, conn::http1::SendRequest},
};
use hyper_util::rt::TokioIo;
use tokio::{
select,
sync::{
mpsc::{self, Receiver, Sender},
Mutex, Semaphore,
},
};
use crate::{cfg_logging, tcp_connect};
/// Bounded pool of HTTP/1 client connections to a single URI.
///
/// Idle connections travel through an mpsc channel: a `PooledConn` pushes its connection
/// back into the channel when it is dropped, and `get_sender` pops it out again. The
/// semaphore caps how many connections may be open at the same time.
#[derive(Debug)]
pub(crate) struct ConnPool {
/// One permit per allowed connection. A permit is acquired before a new connection is
/// opened and is released when the task driving that connection finishes.
semaphore: Arc<Semaphore>,
/// Receiving end of the idle-connection channel. It sits behind an async mutex so that
/// `get_sender`, which only has `&self`, can get the mutable access needed to receive.
receiver: Mutex<Receiver<SendRequest<Incoming>>>,
/// Sending end of the idle-connection channel; cloned into every `PooledConn` so the
/// connection can be handed back.
sender: Sender<SendRequest<Incoming>>,
/// Target of the pool. Its authority is the address new connections are opened to.
uri: Uri,
}
/// A connection checked out of a `ConnPool`.
///
/// Dereferences to the underlying `SendRequest`. When the value is dropped, the connection
/// is sent back through `sender` so it can be reused.
#[derive(Debug)]
pub(crate) struct PooledConn {
/// Channel leading back to the pool this connection was checked out from.
sender: Sender<SendRequest<Incoming>>,
/// The checked-out connection. It is wrapped in `Option` so that `Drop` can move it out
/// with `take()`; `ConnPool::get_sender` always builds this value with `Some`.
conn: Option<SendRequest<Incoming>>,
}
impl ConnPool {
pub(crate) fn new(uri: Uri, max_connections: usize) -> Self {
let (sender, receiver) = mpsc::channel::<SendRequest<Incoming>>(max_connections);
ConnPool {
semaphore: Arc::new(Semaphore::new(max_connections)),
sender,
receiver: Mutex::new(receiver),
uri,
}
}
pub(crate) async fn get_sender(&self) -> Result<PooledConn, crate::Error> {
let mut receiver = self.receiver.lock().await;
loop {
let mut conn = select! {
biased;
sender = receiver.recv() => {
cfg_logging! {trace!("Reusing connection to: {}", self.uri);}
Ok::<_, crate::Error>(sender.unwrap())
},
permit = Arc::clone(&self.semaphore).acquire_owned() => {
let permit = permit.unwrap();
cfg_logging! {info!("Opened new connection to: {}", self.uri);}
let stream = tcp_connect(self.uri.authority().unwrap().as_str()).await?;
let (sender, conn) = client::conn::http1::Builder::new()
.preserve_header_case(true)
.title_case_headers(true)
.handshake(TokioIo::new(stream))
.await?;
tokio::task::spawn(async move {
if let Err(err) = conn.with_upgrades().await {
cfg_logging! {error!("Connection failed: {:?}", err);}
}
drop(permit);
});
Ok(sender)
}
}?;
if let Ok(_) = conn.ready().await {
return Ok(PooledConn {
sender: self.sender.clone(),
conn: Some(conn),
});
}
}
}
}
/// Lets a `PooledConn` be used wherever a shared reference to the underlying
/// `SendRequest` is expected.
impl Deref for PooledConn {
    type Target = SendRequest<Incoming>;

    /// Borrows the checked-out connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been moved out, which only `Drop` does.
    fn deref(&self) -> &Self::Target {
        self.conn
            .as_ref()
            .expect("PooledConn holds its connection until it is dropped")
    }
}
/// Gives mutable access to the underlying `SendRequest`.
impl DerefMut for PooledConn {
    /// Mutably borrows the checked-out connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been moved out, which only `Drop` does.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.conn
            .as_mut()
            .expect("PooledConn holds its connection until it is dropped")
    }
}
/// Returns the connection to the pool when the handle goes out of scope.
impl Drop for PooledConn {
    /// Sends the connection back through the pool channel without blocking.
    ///
    /// If the channel refuses it, the connection is simply dropped (and the failure is
    /// logged when logging is enabled).
    fn drop(&mut self) {
        // Never panic inside `drop`: a panic while already unwinding aborts the process.
        if let Some(conn) = self.conn.take() {
            if let Err(err) = self.sender.try_send(conn) {
                cfg_logging! {tracing::error!("Failed to send conn back to pool! {err:?}");}
            }
        }
    }
}