use std::task::{Context, Poll, ready};
use std::time::{Duration, Instant};
use std::{cell::Cell, cell::RefCell, collections::VecDeque, fmt, future, pin, rc::Rc};
use ntex_h2::{self as h2};
use crate::http::uri::{Authority, Scheme, Uri};
use crate::io::{IoBoxed, types::HttpProtocol};
use crate::service::cfg::SharedCfg;
use crate::service::{Pipeline, PipelineCall, Service, ServiceCtx, boxed};
use crate::util::{ByteString, Either, HashMap, HashSet, select};
use crate::{channel::inplace, channel::oneshot, channel::pool, rt::spawn, time::now};
use super::connection::{Connection, ConnectionType};
use super::{Connect, error::ConnectError, h2proto::H2Client};
/// Pool lookup key: connections are grouped and reused per URI authority.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(super) struct Key {
    /// Authority component of the request URI this key stands for.
    authority: Authority,
}
/// Boxed connector service: turns a [`Connect`] request into a transport stream
/// (`IoBoxed`) or a `ConnectError`.
type Connector = boxed::BoxService<Connect, IoBoxed, ConnectError>;
impl From<Authority> for Key {
fn from(authority: Authority) -> Key {
Key { authority }
}
}
/// Sending half used to deliver a connection (or an error) to a waiting caller.
type Waiter = pool::Sender<Result<Connection, ConnectError>>;
/// Receiving half a caller awaits while its connection is being provided.
type WaiterReceiver = pool::Receiver<Result<Connection, ConnectError>>;
/// Outcome of `Inner::acquire` for a key.
enum Acquire {
    /// An idle pooled connection was taken; carries it and its creation time.
    Acquired(ConnectionType, Instant),
    /// Nothing usable is pooled, but the caller may open a new connection.
    Available,
    /// The acquired-connection limit is reached, or a connection to this key is
    /// already being opened; the caller has to wait.
    NotAvailable,
}
/// An idle connection stored in the pool.
#[derive(Debug)]
struct AvailableConnection {
    io: ConnectionType,
    /// Last time the connection was returned or re-queued; compared against keep-alive.
    used: Instant,
    /// When the connection was established; compared against the max lifetime.
    created: Instant,
}
/// Cheaply clonable handle to a shared client connection pool (clones share one `Rc`).
pub(super) struct ConnectionPool(Rc<ConnectionPoolInner>);
/// State shared by all `ConnectionPool` handles.
struct ConnectionPoolInner {
    /// Connector used to open new transport connections.
    svc: Pipeline<Connector>,
    /// Mutable pool state, also shared with the background task and connections.
    inner: Rc<RefCell<Inner>>,
    /// Callers queued for a connection.
    waiters: Rc<RefCell<Waiters>>,
    /// Sender whose drop stops the background task; `None` once taken.
    stop: Rc<Cell<Option<oneshot::Sender<()>>>>,
    config: SharedCfg,
}
/// Mutable pool state.
#[derive(Debug)]
pub(super) struct Inner {
    /// Set on shutdown or when the last pool handle is dropped.
    stopped: bool,
    /// Max age of a pooled connection; older ones are discarded when acquiring.
    conn_lifetime: Duration,
    /// Max idle time of a pooled connection; staler ones are discarded when acquiring.
    conn_keep_alive: Duration,
    /// Max number of simultaneously acquired connections; `0` disables the limit.
    limit: usize,
    /// Number of outstanding `Acquired` slots (connections handed out).
    acquired: usize,
    /// Idle connections per key.
    available: HashMap<Key, VecDeque<AvailableConnection>>,
    /// Keys with a connection attempt in progress.
    connecting: HashSet<Key>,
    /// Wakes the background task (`run_connection_pool`).
    waker: inplace::Inplace<()>,
    /// Queue of waiting callers (same `Rc` as held by the pool handle).
    waiters: Rc<RefCell<Waiters>>,
}
impl ConnectionPool {
pub(super) fn new(
svc: Pipeline<Connector>,
conn_lifetime: Duration,
conn_keep_alive: Duration,
limit: usize,
config: SharedCfg,
) -> Self {
let waiters = Rc::new(RefCell::new(Waiters {
waiters: HashMap::default(),
pool: pool::new(),
}));
let inner = Rc::new(RefCell::new(Inner {
conn_lifetime,
conn_keep_alive,
limit,
stopped: false,
acquired: 0,
available: HashMap::default(),
connecting: HashSet::default(),
waker: inplace::channel(),
waiters: waiters.clone(),
}));
let (stop, stop_rx) = oneshot::channel();
crate::rt::spawn(run_connection_pool(
svc.clone(),
inner.clone(),
waiters.clone(),
config.clone(),
stop_rx,
));
ConnectionPool(Rc::new(ConnectionPoolInner {
svc,
inner,
waiters,
config,
stop: Rc::new(Cell::new(Some(stop))),
}))
}
}
impl Drop for ConnectionPool {
    /// When the last handle goes away: stop the background task, drop every queued
    /// waiter and mark the pool as stopped.
    fn drop(&mut self) {
        let shared = &self.0;
        if Rc::strong_count(shared) != 1 {
            return;
        }

        // dropping the sender signals the background task to exit
        shared.stop.take();
        shared.waiters.borrow_mut().waiters.clear();

        let mut state = shared.inner.borrow_mut();
        state.stopped = true;
        let _ = state.waker.send(());
    }
}
impl Clone for ConnectionPool {
fn clone(&self) -> Self {
ConnectionPool(self.0.clone())
}
}
impl fmt::Debug for ConnectionPool {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ConnectionPool")
.field("svc", &self.0.svc)
.field("inner", &self.0.inner)
.field("waiters", &self.0.waiters)
.field("config", &self.0.config)
.finish()
}
}
impl Service<Connect> for ConnectionPool {
    type Response = Connection;
    type Error = ConnectError;

    /// Readiness is delegated to the underlying connector.
    #[inline]
    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        self.0.svc.ready().await
    }

    /// Polling is delegated to the underlying connector.
    #[inline]
    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
        self.0.svc.poll(cx)
    }

    /// Stops the pool and shuts down the underlying connector.
    ///
    /// Callers still queued for a connection are resolved with
    /// `ConnectError::Disconnected` instead of waiting forever on a background task
    /// that has already exited.
    #[inline]
    async fn shutdown(&self) {
        // dropping the stop sender makes the background task exit
        self.0.stop.take();
        self.0.inner.borrow_mut().stopped = true;
        // nobody serves the queue once the background task is gone; dropping the
        // waiter senders wakes every pending caller with an error
        self.0.waiters.borrow_mut().waiters.clear();
        self.0.svc.shutdown().await;
    }

    /// Returns a connection to the authority of `req.uri`.
    ///
    /// An idle pooled connection is reused when possible. Otherwise a new one is
    /// opened, or the request is queued until the pool can serve it.
    ///
    /// # Errors
    /// * `ConnectError::Unresolved` if the URI has no authority.
    /// * `ConnectError::Disconnected` if the pool is stopped, or the connection was
    ///   never delivered (the sending side was dropped).
    /// * Any error produced by the underlying connector.
    async fn call(
        &self,
        req: Connect,
        _: ServiceCtx<'_, Self>,
    ) -> Result<Connection, ConnectError> {
        log::trace!("{}: Get connection for {:?}", self.0.config.tag(), req.uri);
        let inner = self.0.inner.clone();
        let waiters = self.0.waiters.clone();
        let key = if let Some(authority) = req.uri.authority() {
            authority.clone().into()
        } else {
            return Err(ConnectError::Unresolved);
        };

        // A stopped pool has no background task: a queued request would never be
        // woken and a freshly opened connection would be discarded anyway.
        if inner.borrow().stopped {
            return Err(ConnectError::Disconnected(None));
        }

        let result = inner.borrow_mut().acquire(&key);
        match result {
            Acquire::Acquired(io, created) => {
                log::trace!(
                    "{}: Use existing {:?} connection for {:?}",
                    self.0.config.tag(),
                    io,
                    req.uri
                );
                Ok(Connection::new(
                    io,
                    created,
                    Some(Acquired::new(key, inner)),
                ))
            }
            Acquire::Available => {
                log::trace!("{}: Connecting to {:?}", self.0.config.tag(), req.uri);
                let uri = req.uri.clone();
                let (tx, rx) = waiters.borrow_mut().pool.channel();
                OpenConnection::spawn(key, tx, uri, inner, &self.0.svc, req);
                match rx.await {
                    Err(_) => Err(ConnectError::Disconnected(None)),
                    Ok(res) => res,
                }
            }
            Acquire::NotAvailable => {
                log::trace!(
                    "{}: Pool is full, waiting for available connections for {:?}",
                    self.0.config.tag(),
                    req.uri
                );
                let rx = waiters.borrow_mut().wait_for(req);
                match rx.await {
                    Err(_) => Err(ConnectError::Disconnected(None)),
                    Ok(res) => res,
                }
            }
        }
    }
}
/// Callers waiting for a connection, queued first-in first-out per key.
#[derive(Debug)]
struct Waiters {
    /// Pending requests with the sender used to answer each one.
    waiters: HashMap<Key, VecDeque<(Connect, Waiter)>>,
    /// Channel pool that waiter channels are allocated from.
    pool: pool::Pool<Result<Connection, ConnectError>>,
}
impl Waiters {
fn wait_for(&mut self, connect: Connect) -> WaiterReceiver {
let (tx, rx) = self.pool.channel();
let key: Key = connect.uri.authority().unwrap().clone().into();
self.waiters
.entry(key)
.or_default()
.push_back((connect, tx));
rx
}
fn cleanup(&mut self) {
let mut keys = Vec::new();
for (key, waiters) in &mut self.waiters {
while !waiters.is_empty() {
let (req, tx) = waiters.front().unwrap();
if tx.is_canceled() {
log::trace!("Waiter for {:?} is gone, remove waiter", req.uri);
waiters.pop_front();
continue;
}
break;
}
if waiters.is_empty() {
keys.push(key.clone());
}
}
for key in keys {
self.waiters.remove(&key);
}
}
}
impl Inner {
    /// Tries to take a connection for `key` out of the pool.
    ///
    /// Returns `Acquire::NotAvailable` in two cases: the acquired-connection limit
    /// is reached (a `limit` of 0 disables it), or no usable idle connection exists
    /// and one to `key` is already being opened. Returns `Acquire::Available` when
    /// the caller should open a new connection itself.
    ///
    /// Idle connections are taken from the back of the queue. These are discarded:
    /// ones past `conn_keep_alive` or `conn_lifetime`, closed ones, and HTTP/1 ones
    /// with unexpected buffered input. A reused HTTP/2 connection is also re-queued
    /// at the front so other requests can share it.
    fn acquire(&mut self, key: &Key) -> Acquire {
        if self.limit > 0 && self.acquired >= self.limit {
            return Acquire::NotAvailable;
        }
        if let Some(ref mut connections) = self.available.get_mut(key) {
            let now = now();
            while let Some(conn) = connections.pop_back() {
                if (now - conn.used) > self.conn_keep_alive
                    || (now - conn.created) > self.conn_lifetime
                {
                    // expired; http/1 streams are shut down gracefully in the background
                    if let ConnectionType::H1(io) = conn.io {
                        spawn(async move {
                            let _ = io.shutdown().await;
                        });
                    }
                    continue;
                }
                let io = conn.io;
                match io {
                    ConnectionType::H1(ref s) => {
                        if s.is_closed() {
                            continue;
                        }
                        // an idle http/1 connection must have no pending input;
                        // a lone CRLF is tolerated and cleared
                        let is_valid = s.with_read_buf(|buf| {
                            if buf.is_empty() || (buf.len() == 2 && &buf[..] == b"\r\n") {
                                buf.clear();
                                true
                            } else {
                                false
                            }
                        });
                        if !is_valid {
                            continue;
                        }
                    }
                    ConnectionType::H2(ref s) => {
                        if s.is_closed() {
                            continue;
                        }
                        // keep a handle pooled so the connection can be shared
                        let conn = AvailableConnection {
                            io: ConnectionType::H2(s.clone()),
                            used: now,
                            created: conn.created,
                        };
                        connections.push_front(conn);
                    }
                }
                return Acquire::Acquired(io, conn.created);
            }
        }
        if self.connecting.contains(key) {
            Acquire::NotAvailable
        } else {
            Acquire::Available
        }
    }

    /// Wakes the background task when callers are still queued and the pool has
    /// room for another acquired connection.
    ///
    /// Canceled waiters are pruned first, so abandoned queues do not cause wake-ups.
    fn check_availibility(&mut self) {
        let mut waiters = self.waiters.borrow_mut();
        waiters.cleanup();
        // `limit == 0` means "no limit" (as in `acquire`); comparing `acquired < 0`
        // would never wake waiters of an unlimited pool
        let has_capacity = self.limit == 0 || self.acquired < self.limit;
        if !waiters.waiters.is_empty() && has_capacity {
            let _ = self.waker.send(());
        }
    }
}
/// Background task that serves callers queued in `Waiters`.
///
/// On every wake-up it walks each per-key queue in order. Canceled waiters are
/// dropped. Live ones are handed an idle pooled connection, or a spawned
/// connection attempt, until `Inner::acquire` reports `Acquire::NotAvailable` for
/// that key. The task exits when `stop` resolves, or when it is woken and the pool
/// is marked as stopped.
async fn run_connection_pool(
    svc: Pipeline<Connector>,
    inner: Rc<RefCell<Inner>>,
    waiters: Rc<RefCell<Waiters>>,
    config: SharedCfg,
    mut stop: oneshot::Receiver<()>,
) {
    let tag = config.tag();
    log::trace!("{tag}: Starting connection pool support task");
    loop {
        {
            // set whenever a queue shrinks, so emptied queues get pruned afterwards
            let mut cleanup = false;
            let mut waiters = waiters.borrow_mut();
            for (key, waiters) in &mut waiters.waiters {
                while let Some((req, tx)) = waiters.front() {
                    if tx.is_canceled() {
                        log::trace!("{tag}: Waiter for {:?} is gone, cleanup", req.uri);
                        cleanup = true;
                        waiters.pop_front();
                        continue;
                    }
                    let result = inner.borrow_mut().acquire(key);
                    match result {
                        // limit reached or a connect for this key is in flight: stop here
                        Acquire::NotAvailable => break,
                        Acquire::Acquired(io, created) => {
                            log::trace!(
                                "{tag}: Use existing {:?} connection for {:?}, wake up waiter",
                                io,
                                req.uri
                            );
                            cleanup = true;
                            let (_, tx) = waiters.pop_front().unwrap();
                            let _ = tx.send(Ok(Connection::new(
                                io,
                                created,
                                Some(Acquired::new(key.clone(), inner.clone())),
                            )));
                        }
                        Acquire::Available => {
                            log::trace!(
                                "{tag}: Connecting to {:?} and wake up waiter",
                                req.uri
                            );
                            cleanup = true;
                            // the spawned attempt answers the waiter directly through `tx`
                            let (connect, tx) = waiters.pop_front().unwrap();
                            let uri = connect.uri.clone();
                            OpenConnection::spawn(
                                key.clone(),
                                tx,
                                uri,
                                inner.clone(),
                                &svc,
                                connect,
                            );
                        }
                    }
                }
            }
            if cleanup {
                waiters.cleanup();
            }
        }
        // sleep until stopped, or woken via `Inner::waker` (`check_availibility` / pool drop)
        let result = select(
            &mut stop,
            future::poll_fn(|cx| inner.borrow().waker.poll_recv(cx)),
        )
        .await;
        if matches!(result, Either::Left(_)) || inner.borrow().stopped {
            log::trace!("{tag}: Stopping connection pool support task");
            break;
        }
    }
}
// Future driving one connection attempt; created by `OpenConnection::spawn`.
pin_project_lite::pin_project! {
    struct OpenConnection {
        // key (authority) being connected to
        key: Key,
        // connector call in flight
        #[pin]
        fut: PipelineCall<Connector, Connect>,
        // request URI; scheme and authority feed the http/2 client setup
        uri: Uri,
        // caller waiting for the outcome; taken once it has been answered
        tx: Option<Waiter>,
        // "connecting" mark for `key`; taken when the attempt completes
        guard: Option<OpenGuard>,
        inner: Rc<RefCell<Inner>>,
    }
}
impl OpenConnection {
    /// Starts a connector call for `msg` and spawns a task that delivers the
    /// resulting connection (or error) to `tx`.
    ///
    /// The connector future is created right away. `key` is marked as connecting
    /// only when the spawned task first runs. NOTE(review): until then,
    /// `Inner::acquire` still reports `Available` for `key`, so several requests
    /// handled in the same tick may each start a connection.
    fn spawn(
        key: Key,
        tx: Waiter,
        uri: Uri,
        inner: Rc<RefCell<Inner>>,
        pipeline: &Pipeline<Connector>,
        msg: Connect,
    ) {
        let fut = pipeline.call_static(msg);
        spawn(async move {
            // the key is registered as connecting only once this task runs
            let guard = OpenGuard::new(key.clone(), inner.clone());
            let attempt = OpenConnection {
                key,
                fut,
                uri,
                tx: Some(tx),
                guard: Some(guard),
                inner,
            };
            attempt.await;
        });
    }
}
impl future::Future for OpenConnection {
    type Output = ();

    /// Drives the connector call and hands the outcome to the waiting caller.
    ///
    /// * Error: the "connecting" mark is cleared and the error is forwarded.
    /// * HTTP/2 negotiated: an HTTP/2 client is built. One handle is sent to the
    ///   caller and a second one is released (not closed) into the pool so later
    ///   requests can share the connection.
    /// * HTTP/1: the connection is sent to the caller. If the caller is gone, it is
    ///   released (not closed) instead.
    ///
    /// If the pool was stopped meanwhile, the new connection is dropped and `tx` is
    /// dropped with this future without sending anything.
    fn poll(self: pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match ready!(this.fut.poll(cx)) {
            Err(err) => {
                log::trace!(
                    "Failed to open client connection for {:?} with error {:?}",
                    &this.key.authority,
                    err
                );
                // dropping the guard clears the "connecting" mark and re-checks waiters
                let _ = this.guard.take();
                if let Some(rx) = this.tx.take() {
                    let _ = rx.send(Err(err));
                }
                Poll::Ready(())
            }
            Ok(io) => {
                if this.inner.borrow().stopped {
                    return Poll::Ready(());
                }
                if io.query::<HttpProtocol>().get() == Some(HttpProtocol::Http2) {
                    log::trace!(
                        "Connection for {:?} is established, start http2 handshake",
                        &this.key.authority
                    );
                    let auth = if let Some(auth) = this.uri.authority() {
                        format!("{auth}").into()
                    } else {
                        ByteString::new()
                    };
                    let client = h2::client::SimpleClient::new(
                        io,
                        this.uri.scheme().cloned().unwrap_or(Scheme::HTTPS),
                        auth,
                    );
                    let client = H2Client::new(client);
                    let guard = this.guard.take().unwrap().consume();
                    // handle for the caller (takes its own acquired slot)
                    let conn = Connection::new(
                        ConnectionType::H2(client.clone()),
                        now(),
                        Some(guard.clone()),
                    );
                    if this.tx.take().unwrap().send(Ok(conn)).is_err() {
                        log::trace!(
                            "Waiter for {:?} is gone while connecting to host",
                            &this.key.authority
                        );
                    }
                    // handle released into the pool so other requests can share it
                    Connection::new(ConnectionType::H2(client), now(), Some(guard))
                        .release(false);
                    Poll::Ready(())
                } else {
                    log::trace!(
                        "Connection for {:?} is established, init http1 connection",
                        &this.key.authority
                    );
                    let conn = Connection::new(
                        ConnectionType::H1(io),
                        now(),
                        Some(this.guard.take().unwrap().consume()),
                    );
                    // caller went away: hand the fresh connection back instead
                    if let Err(Ok(conn)) = this.tx.take().unwrap().send(Ok(conn)) {
                        conn.release(false);
                    }
                    // the "connecting" mark is gone; queued waiters may proceed now
                    this.inner.borrow_mut().check_availibility();
                    Poll::Ready(())
                }
            }
        }
    }
}
/// Keeps `key` marked as "connecting" while a connection attempt is in progress.
///
/// Dropping the guard clears the mark; `consume` turns it into an `Acquired` slot.
struct OpenGuard {
    key: Key,
    /// `None` once the guard has been consumed.
    inner: Option<Rc<RefCell<Inner>>>,
}
impl OpenGuard {
    /// Marks `key` as having a connection attempt in progress.
    fn new(key: Key, inner: Rc<RefCell<Inner>>) -> Self {
        let pending = key.clone();
        inner.borrow_mut().connecting.insert(pending);
        Self {
            inner: Some(inner),
            key,
        }
    }

    /// Clears the in-progress mark and converts the guard into an `Acquired` slot
    /// (incrementing the acquired counter) for the established connection.
    fn consume(mut self) -> Acquired {
        let inner = self.inner.take().unwrap();
        {
            let mut state = inner.borrow_mut();
            state.connecting.remove(&self.key);
        }
        let key = self.key.clone();
        Acquired::new(key, inner)
    }
}
impl Drop for OpenGuard {
    /// For an unconsumed guard (failed or abandoned attempt): clears the in-progress
    /// mark and re-checks whether queued waiters can now be served.
    fn drop(&mut self) {
        let Some(shared) = self.inner.take() else {
            return;
        };
        let mut state = shared.borrow_mut();
        state.connecting.remove(&self.key);
        state.check_availibility();
    }
}
/// One unit of the pool's acquired-connection count for a key. It is freed on
/// `release` or on drop; the `Option` is `None` once freed.
pub(super) struct Acquired(Key, Option<Rc<RefCell<Inner>>>);
impl Acquired {
fn new(key: Key, inner: Rc<RefCell<Inner>>) -> Self {
inner.borrow_mut().acquired += 1;
Acquired(key, Some(inner))
}
fn clone(&self) -> Self {
Acquired::new(self.0.clone(), self.1.as_ref().unwrap().clone())
}
pub(super) fn release(&mut self, conn: Connection, close: bool) {
if let Some(inner) = self.1.take() {
let (io, created, _) = conn.into_inner();
let mut inner = inner.borrow_mut();
inner.acquired -= 1;
if close {
log::trace!(
"{:?}: Releasing and closing connection for {:?}",
io.tag(),
self.0.authority
);
match io {
ConnectionType::H1(io) => {
spawn(async move {
let _ = io.shutdown().await;
});
}
ConnectionType::H2(io) => io.close(),
}
} else {
log::trace!(
"{:?}: Releasing connection for {:?}",
io.tag(),
self.0.authority
);
inner
.available
.entry(self.0.clone())
.or_insert_with(VecDeque::new)
.push_back(AvailableConnection {
io,
created,
used: now(),
});
}
inner.check_availibility();
}
}
}
impl Drop for Acquired {
    /// Frees the slot if it was not released explicitly, then re-checks waiters.
    fn drop(&mut self) {
        let Some(shared) = self.1.take() else {
            return;
        };
        let mut state = shared.borrow_mut();
        state.acquired -= 1;
        state.check_availibility();
    }
}
#[cfg(test)]
mod tests {
    use std::future::Future;

    use super::*;
    use crate::time::{Millis, sleep};
    use crate::{io as nio, service::fn_service, testing::IoTest, util::lazy};

    // Exercises a pool with `limit = 1`: connecting, queuing behind the limit,
    // reuse after release, closing on release, canceled waiters, and a second
    // authority.
    #[crate::rt_test]
    async fn test_basics() {
        // every connect attempt records its request and the server side of an in-memory pipe
        let store = Rc::new(RefCell::new(Vec::new()));
        let store2 = store.clone();
        let pool = Pipeline::new(
            ConnectionPool::new(
                Pipeline::new(boxed::service(fn_service(move |req| {
                    let (client, server) = IoTest::create();
                    store2.borrow_mut().push((req, server));
                    Box::pin(async move {
                        Ok(IoBoxed::from(nio::Io::new(client, SharedCfg::default())))
                    })
                }))),
                Duration::from_secs(10),
                Duration::from_secs(10),
                1,
                SharedCfg::default(),
            )
            .clone(),
        )
        .bind();

        // a URI without an authority cannot be pooled
        let req = Connect {
            uri: Uri::try_from("/test").unwrap(),
            addr: None,
        };
        assert!(matches!(
            pool.call(req).await,
            Err(ConnectError::Unresolved)
        ));

        // the first request opens a new connection
        let req = Connect {
            uri: Uri::try_from("http://localhost/test").unwrap(),
            addr: None,
        };
        let conn = pool.call(req.clone()).await.unwrap();
        assert_eq!(store.borrow().len(), 1);
        assert!(format!("{conn:?}").contains("Connection(h1)"));
        assert_eq!(conn.protocol(), HttpProtocol::Http1);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 1);
        assert!(pool.get_ref().0.inner.borrow().connecting.is_empty());

        // the limit is reached, so the next request is queued as a waiter
        let mut fut = std::pin::pin!(pool.call(req.clone()));
        assert!(lazy(|cx| fut.as_mut().poll(cx)).await.is_pending());
        assert_eq!(pool.get_ref().0.waiters.borrow().waiters.len(), 1);

        // releasing without closing hands the same connection to the waiter
        conn.release(false);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 0);
        let conn = fut.await.unwrap();
        assert_eq!(store.borrow().len(), 1);
        assert!(pool.get_ref().0.waiters.borrow().waiters.is_empty());

        // a dropped (not released) connection is not reused: a new connect happens
        drop(conn);
        let conn = pool.call(req.clone()).await.unwrap();
        assert_eq!(store.borrow().len(), 2);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 1);
        assert!(pool.get_ref().0.inner.borrow().connecting.is_empty());

        // releasing with `close = true` makes the waiter open a fresh connection
        let mut fut = std::pin::pin!(pool.call(req.clone()));
        assert!(lazy(|cx| fut.as_mut().poll(cx)).await.is_pending());
        assert_eq!(pool.get_ref().0.waiters.borrow().waiters.len(), 1);
        conn.release(true);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 0);
        assert!(pool.get_ref().0.inner.borrow().connecting.is_empty());
        let conn = fut.await.unwrap();
        assert_eq!(store.borrow().len(), 3);
        assert!(pool.get_ref().0.waiters.borrow().waiters.is_empty());
        assert!(pool.get_ref().0.inner.borrow().connecting.is_empty());
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 1);

        // a waiter whose future was dropped is pruned by `check_availibility`
        let mut fut = Box::pin(pool.call(req.clone()));
        assert!(
            lazy(|cx| pin::Pin::new(&mut fut).poll(cx))
                .await
                .is_pending()
        );
        drop(fut);
        sleep(Millis(50)).await;
        pool.get_ref().0.inner.borrow_mut().check_availibility();
        assert!(pool.get_ref().0.waiters.borrow().waiters.is_empty());

        // a different authority uses its own key but shares the global limit
        let req = Connect {
            uri: Uri::try_from("http://localhost2/test").unwrap(),
            addr: None,
        };
        let mut fut = std::pin::pin!(pool.call(req.clone()));
        assert!(lazy(|cx| fut.as_mut().poll(cx)).await.is_pending());
        assert_eq!(pool.get_ref().0.waiters.borrow().waiters.len(), 1);
        conn.release(false);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 0);
        assert_eq!(pool.get_ref().0.inner.borrow().available.len(), 1);
        let conn = fut.await.unwrap();
        assert_eq!(store.borrow().len(), 4);
        assert!(pool.get_ref().0.waiters.borrow().waiters.is_empty());
        assert!(pool.get_ref().0.inner.borrow().connecting.is_empty());
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 1);
        conn.release(false);
        assert_eq!(pool.get_ref().0.inner.borrow().acquired, 0);
        // idle connections are now kept for both authorities
        assert_eq!(pool.get_ref().0.inner.borrow().available.len(), 2);

        assert!(lazy(|cx| pool.poll_ready(cx)).await.is_ready());
        assert!(lazy(|cx| pool.poll_shutdown(cx)).await.is_ready());
    }
}