use std::collections::VecDeque;
use std::{cell::Cell, cell::RefCell, fmt, marker::PhantomData, rc::Rc, time::Duration};
use nanorand::{Rng, WyRand};
use ntex_bytes::ByteString;
use ntex_error::Error;
use ntex_http::{HeaderMap, Method, uri::Scheme};
use ntex_io::IoBoxed;
use ntex_net::connect::{Address, Connect, ConnectError, Connector2 as DefaultConnector};
use ntex_service::cfg::{Cfg, SharedCfg};
use ntex_service::{IntoServiceFactory, Pipeline, ServiceFactory};
use ntex_util::time::{Millis, Seconds, timeout_checked};
use ntex_util::{channel::oneshot, channel::pool, future::BoxFuture};
use super::stream::{InflightStorage, RecvStream, SendStream};
use super::{ClientError, simple::SimpleClient};
use crate::ServiceConfig;
/// Future resolving to a freshly opened transport, or the connector's error.
type Fut = BoxFuture<'static, Result<IoBoxed, Error<ConnectError>>>;
/// Type-erased factory: every call starts one new connection attempt.
type Connector = Box<dyn Fn() -> Fut>;
/// Pooled HTTP/2 client.
///
/// Clones are cheap: both fields are `Rc`s, so all clones share one pool
/// and one waiter queue.
#[derive(Clone)]
pub struct Client {
/// Configuration, connector and pooled connections shared by all clones.
inner: Rc<Inner>,
/// Senders for tasks parked until a connection attempt finishes or a stream ends.
waiters: Rc<RefCell<VecDeque<pool::Sender<()>>>>,
}
/// State shared by every clone of a `Client`.
struct Inner {
/// Service configuration handed to each new `SimpleClient`.
cfg: Cfg<ServiceConfig>,
/// Pool settings and mutable pool state.
config: InnerConfig,
/// Opens one new transport connection per call.
connector: Connector,
}
/// Wakes at most one parked waiter.
///
/// Senders whose receiver has already gone away fail to deliver. They are
/// discarded, and the next one is tried until a delivery succeeds or the
/// queue is empty.
fn notify(waiters: &mut VecDeque<pool::Sender<()>>) {
    #[cfg(feature = "extra-trace")]
    log::debug!("Notify waiter, total {:?}", waiters.len());
    loop {
        let Some(next) = waiters.pop_front() else {
            return;
        };
        if next.send(()).is_ok() {
            return;
        }
    }
}
impl Client {
    /// Creates a [`ClientBuilder`] for `addr` that opens transport connections
    /// through `connector`.
    #[inline]
    pub fn builder<A, U, T, F>(addr: U, connector: F) -> ClientBuilder<A, T>
    where
        A: Address + Clone,
        F: IntoServiceFactory<T, Connect<A>, SharedCfg>,
        T: ServiceFactory<Connect<A>, SharedCfg, Error = Error<ConnectError>> + 'static,
        IoBoxed: From<T::Response>,
        Connect<A>: From<U>,
    {
        ClientBuilder::new(addr, connector)
    }

    /// Obtains a pooled connection (see [`Client::client`]) and starts a request on it.
    ///
    /// `method`, `path`, `headers` and `eof` are forwarded unchanged to
    /// `SimpleClient::send`. Presumably `eof` marks a request without a body
    /// (confirm against `SimpleClient`).
    ///
    /// # Errors
    /// Fails when no connection can be obtained, or when the chosen connection
    /// fails to start the request.
    pub async fn send(
        &self,
        method: Method,
        path: ByteString,
        headers: HeaderMap,
        eof: bool,
    ) -> Result<(SendStream, RecvStream), Error<ClientError>> {
        self.client()
            .await?
            .send(method, path, headers, eof)
            .await
            .map_err(|e| e.map(ClientError::from))
    }

    /// Returns a usable connection from the pool.
    ///
    /// When none is usable right now, it opens a new one or waits for capacity,
    /// then retries.
    ///
    /// # Errors
    /// Propagates connection-establishment failures: connector errors and
    /// `ClientError::HandshakeTimeout`.
    pub async fn client(&self) -> Result<SimpleClient, Error<ClientError>> {
        loop {
            let (client, num) = self.get_client();
            if let Some(client) = client {
                return Ok(client);
            }
            self.connect(num).await?;
        }
    }

    /// Makes progress towards a usable connection, given `num` live connections.
    ///
    /// If no attempt is in flight and the pool is below `maxconn` (or below
    /// `minconn`), a new connection is established. Otherwise the caller parks
    /// until it is notified.
    async fn connect(&self, num: usize) -> Result<(), Error<ClientError>> {
        let cfg = &self.inner.config;
        if !cfg.connecting.get() && (num < cfg.maxconn || (cfg.minconn > 0 && num < cfg.minconn)) {
            cfg.connecting.set(true);
            self.create_connection().await?;
        } else {
            #[cfg(feature = "extra-trace")]
            log::debug!(
                "New connection is being established {:?} or number of existing cons {num} greater than allowed {}",
                cfg.connecting.get(),
                cfg.maxconn
            );
            // Woken either when a connection attempt finishes or when a stream ends.
            let (tx, rx) = cfg.pool.channel();
            self.waiters.borrow_mut().push_back(tx);
            rx.await
                .map_err(|e| Error::new(e, self.inner.cfg.service()))?;
        }
        Ok(())
    }

    /// Prunes dead connections and picks one to use.
    ///
    /// Returns the chosen client (if any) and the number of live connections
    /// left after pruning.
    ///
    /// - Closed connections are removed.
    /// - Disconnecting connections are removed and handed to a background task
    ///   that waits for the disconnect, bounded by `disconnect_timeout`.
    /// - While fewer than `minconn` connections remain, no client is returned,
    ///   so the caller opens a new one.
    /// - Otherwise the first connection using at most half of its stream limit
    ///   is preferred. Failing that, a random ready connection is chosen.
    fn get_client(&self) -> (Option<SimpleClient>, usize) {
        let cfg = &self.inner.config;
        let mut connections = cfg.connections.borrow_mut();
        let mut idx = 0;
        while idx < connections.len() {
            if connections[idx].is_closed() {
                connections.remove(idx);
            } else if connections[idx].is_disconnecting() {
                let con = connections.remove(idx);
                let timeout = cfg.disconnect_timeout;
                ntex_util::spawn(async move {
                    let _ = con.disconnect().disconnect_timeout(timeout).await;
                });
            } else {
                idx += 1;
            }
        }
        let num = connections.len();
        if cfg.minconn > 0 && num < cfg.minconn {
            (None, num)
        } else {
            // Prefer a lightly loaded connection (<= 50% of its stream limit).
            let client = connections.iter().find(|item| {
                let cap = item.max_streams().unwrap_or(cfg.max_streams) >> 1;
                item.active_streams() <= cap
            });
            if let Some(client) = client {
                (Some(client.clone()), num)
            } else {
                // Spread load randomly across the remaining ready connections.
                let available = connections.iter().filter(|item| item.is_ready()).count();
                let client = if available > 0 {
                    let idx = WyRand::new().generate_range(0_usize..available);
                    connections
                        .iter()
                        .filter(|item| item.is_ready())
                        .nth(idx)
                        .cloned()
                } else {
                    None
                };
                (client, num)
            }
        }
    }

    /// Opens one new connection and adds it to the pool.
    ///
    /// The work runs in a spawned task. That task always clears `connecting`,
    /// wakes every waiter, updates the counters and reports the result through
    /// a oneshot channel. It keeps running even if this future is dropped
    /// (assuming `ntex_util::spawn` detaches the task).
    async fn create_connection(&self) -> Result<(), Error<ClientError>> {
        let (tx, rx) = oneshot::channel();
        let inner = self.inner.clone();
        let waiters = self.waiters.clone();
        ntex_util::spawn(async move {
            let res = match timeout_checked(inner.config.conn_timeout, (*inner.connector)()).await {
                Ok(Ok(io)) => {
                    // Each finished stream wakes one waiter.
                    let waiters2 = waiters.clone();
                    let storage = InflightStorage::new(move |_| {
                        notify(&mut waiters2.borrow_mut());
                    });
                    let client = SimpleClient::with_params(
                        io,
                        inner.cfg.clone(),
                        &inner.config.scheme,
                        inner.config.authority.clone(),
                        inner.config.skip_unknown_streams,
                        storage,
                        inner.config.pool.clone(),
                    );
                    inner.config.connections.borrow_mut().push(client);
                    inner
                        .config
                        .total_connections
                        .set(inner.config.total_connections.get() + 1);
                    Ok(())
                }
                Ok(Err(err)) => Err(err.map(ClientError::from)),
                Err(()) => Err(Error::from(ClientError::HandshakeTimeout)),
            };
            inner.config.connecting.set(false);
            // Wake everyone: the pool state changed, so each waiter should re-check.
            for waiter in waiters.borrow_mut().drain(..) {
                let _ = waiter.send(());
            }
            if res.is_err() {
                inner
                    .config
                    .connect_errors
                    .set(inner.config.connect_errors.get() + 1);
            }
            let _ = tx.send(res);
        });
        rx.await
            .map_err(|e| Error::new(e, self.inner.cfg.service()))?
    }

    /// Returns `true` if a request could be dispatched without waiting.
    ///
    /// That is the case when some pooled connection is ready, or when no
    /// connection attempt is in flight and the number of live connections is
    /// below `maxconn`.
    #[inline]
    pub fn is_ready(&self) -> bool {
        let cfg = &self.inner.config;
        let connections = cfg.connections.borrow();
        if connections.iter().any(|client| client.is_ready()) {
            return true;
        }
        // Closed and disconnecting connections are pruned lazily in `get_client`.
        // They must not occupy `maxconn` slots here. Otherwise a pool full of dead
        // connections with no in-flight streams would make `ready()` wait forever,
        // because nothing would ever notify it.
        let live = connections
            .iter()
            .filter(|client| !client.is_closed() && !client.is_disconnecting())
            .count();
        !cfg.connecting.get() && live < cfg.maxconn
    }

    /// Waits until [`Client::is_ready`] returns `true`.
    #[inline]
    pub async fn ready(&self) {
        loop {
            if self.is_ready() {
                break;
            }
            let (tx, rx) = self.inner.config.pool.channel();
            self.waiters.borrow_mut().push_back(tx);
            let _ = rx.await;
            // Pass the wake-up along so another parked task can re-check as well.
            notify(&mut self.waiters.borrow_mut());
        }
    }
}
/// Pool statistics, mainly for tests and diagnostics.
#[doc(hidden)]
impl Client {
    /// Number of connections currently stored in the pool.
    ///
    /// This may include closed connections that have not been pruned yet.
    pub fn stat_active_connections(&self) -> usize {
        let connections = self.inner.config.connections.borrow();
        connections.len()
    }

    /// Number of connections successfully established since the client was built.
    pub fn stat_total_connections(&self) -> usize {
        let counter = &self.inner.config.total_connections;
        counter.get()
    }

    /// Number of connection attempts that failed or timed out.
    pub fn stat_connect_errors(&self) -> usize {
        let counter = &self.inner.config.connect_errors;
        counter.get()
    }

    /// Runs `f` on the pooled connections while the pool is borrowed.
    ///
    /// `f` must not call back into pool methods that need a mutable borrow.
    pub fn stat_connections<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&[SimpleClient]) -> R,
    {
        let connections = self.inner.config.connections.borrow();
        f(connections.as_slice())
    }
}
/// Builder for [`Client`].
///
/// Created by `Client::builder` or `ClientBuilder::with_default`, configured
/// with the setter methods and finished with `build`.
pub struct ClientBuilder<A, T> {
/// Target address. Its host also provides the default `:authority`.
connect: Connect<A>,
/// Pool settings, moved into the built `Client`.
inner: InnerConfig,
/// Transport connector factory, instantiated in `build`.
connector: T,
_t: PhantomData<A>,
}
/// Pool settings plus the mutable pool state shared by all clones of a `Client`.
struct InnerConfig {
/// While fewer live connections exist, `get_client` returns no client, so a new one is opened.
minconn: usize,
/// Upper bound on pooled connections, checked by `connect` and `is_ready`.
maxconn: usize,
/// Time limit for establishing one connection; passed to `timeout_checked`.
conn_timeout: Millis,
/// Set by `ClientBuilder::lifetime`.
/// NOTE(review): not read anywhere in this file except the `Debug` impls; looks unenforced — confirm.
conn_lifetime: Duration,
/// Upper bound when waiting for a removed, disconnecting connection to finish.
disconnect_timeout: Millis,
/// Stream limit assumed when `SimpleClient::max_streams()` returns `None`.
max_streams: u32,
/// Forwarded to `SimpleClient::with_params` for every new connection.
skip_unknown_streams: bool,
/// URI scheme forwarded to every new connection.
scheme: Scheme,
/// `:authority` forwarded to every new connection; defaults to the target host.
authority: ByteString,
/// `true` while a connection attempt is in flight; at most one runs at a time.
connecting: Cell<bool>,
/// Pooled connections; dead ones are pruned lazily by `get_client`.
connections: RefCell<Vec<SimpleClient>>,
/// Count of successfully established connections.
total_connections: Cell<usize>,
/// Count of failed or timed-out connection attempts.
connect_errors: Cell<usize>,
/// Channel pool used for waiter notifications; also handed to each `SimpleClient`.
pool: pool::Pool<()>,
}
impl<A, T> ClientBuilder<A, T>
where
    A: Address + Clone,
    T: ServiceFactory<Connect<A>, SharedCfg, Error = Error<ConnectError>>,
    IoBoxed: From<T::Response>,
{
    /// Creates a builder with default pool settings.
    ///
    /// Defaults:
    /// - `minconn` 1, `maxconn` 16
    /// - 1_000 ms connect timeout, 15_000 ms disconnect timeout
    /// - zero lifetime
    /// - fallback stream limit 100
    /// - `http` scheme
    /// - `:authority` taken from the target host
    fn new<U, F>(addr: U, connector: F) -> Self
    where
        Connect<A>: From<U>,
        F: IntoServiceFactory<T, Connect<A>, SharedCfg>,
    {
        let target = Connect::from(addr);
        let host = ByteString::from(target.host());
        let factory = connector.into_factory();

        let settings = InnerConfig {
            minconn: 1,
            maxconn: 16,
            conn_timeout: Millis(1_000),
            conn_lifetime: Duration::from_secs(0),
            disconnect_timeout: Millis(15_000),
            max_streams: 100,
            skip_unknown_streams: false,
            scheme: Scheme::HTTP,
            authority: host,
            connecting: Cell::new(false),
            connections: RefCell::default(),
            total_connections: Cell::new(0),
            connect_errors: Cell::new(0),
            pool: pool::new(),
        };

        Self {
            connect: target,
            inner: settings,
            connector: factory,
            _t: PhantomData,
        }
    }
}
impl<A> ClientBuilder<A, DefaultConnector<A>>
where
    A: Address + Clone,
{
    /// Creates a builder that opens connections with the default `ntex_net` connector.
    pub fn with_default<U>(addr: U) -> Self
    where
        Connect<A>: From<U>,
    {
        let connector: DefaultConnector<A> = Default::default();
        Self::new(addr, connector)
    }
}
impl<A, T> ClientBuilder<A, T>
where
    A: Address + Clone,
{
    /// Sets the URI scheme passed to every new connection (default: `http`).
    #[must_use]
    pub fn scheme(mut self, scheme: Scheme) -> Self {
        self.inner.scheme = scheme;
        self
    }

    /// Sets the stream limit assumed for a connection whose
    /// `SimpleClient::max_streams()` is `None` (default: 100).
    ///
    /// The pool prefers connections using at most half of their limit.
    #[must_use]
    pub fn max_streams(mut self, limit: u32) -> Self {
        self.inner.max_streams = limit;
        self
    }

    /// Enables `skip_unknown_streams` on every new connection.
    #[must_use]
    pub fn skip_unknown_streams(mut self) -> Self {
        self.inner.skip_unknown_streams = true;
        self
    }

    /// Sets the connection lifetime (default: zero).
    ///
    /// NOTE(review): the pool logic in this file stores this value but never
    /// reads it (only `Debug` prints it). Confirm whether it is enforced elsewhere.
    #[must_use]
    pub fn lifetime(mut self, dur: Seconds) -> Self {
        self.inner.conn_lifetime = dur.into();
        self
    }

    /// Sets how many connections the pool keeps open (default: 1).
    ///
    /// While fewer are live, a request opens a new connection instead of
    /// reusing an existing one.
    #[must_use]
    pub fn minconn(mut self, num: usize) -> Self {
        self.inner.minconn = num;
        self
    }

    /// Sets the maximum number of pooled connections (default: 16).
    #[must_use]
    pub fn maxconn(mut self, num: usize) -> Self {
        self.inner.maxconn = num;
        self
    }

    /// Replaces the transport connector and keeps every other setting.
    ///
    /// Like the other builder methods, this returns a new builder. Dropping
    /// the result discards the whole configuration, hence `#[must_use]`.
    #[must_use]
    pub fn connector<U, F>(self, connector: F) -> ClientBuilder<A, U>
    where
        F: IntoServiceFactory<U, Connect<A>, SharedCfg>,
        U: ServiceFactory<Connect<A>, SharedCfg, Error = Error<ConnectError>> + 'static,
        IoBoxed: From<U::Response>,
    {
        ClientBuilder {
            connect: self.connect,
            connector: connector.into_factory(),
            inner: self.inner,
            _t: PhantomData,
        }
    }
}
impl<A, T> ClientBuilder<A, T>
where
A: Address + Clone,
T: ServiceFactory<Connect<A>, SharedCfg, Error = Error<ConnectError>> + 'static,
IoBoxed: From<T::Response>,
{
pub async fn build(self, cfg: SharedCfg) -> Result<Client, T::InitError> {
let connect = self.connect;
let tag = cfg.tag();
let client_cfg = cfg.get();
let svc = Pipeline::new(self.connector.create(cfg).await?);
let connector = Box::new(move || {
log::trace!("{tag}: Opening http/2 connection to {}", connect.host());
let fut = svc.call_static(connect.clone());
Box::pin(async move { fut.await.map(IoBoxed::from) }) as Fut
});
Ok(Client {
inner: Rc::new(Inner {
connector,
cfg: client_cfg,
config: self.inner,
}),
waiters: Rc::default(),
})
}
}
impl fmt::Debug for Client {
    /// Prints the pool configuration; live pool state is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let conf = &self.inner.config;
        let mut out = f.debug_struct("Client");
        out.field("scheme", &conf.scheme);
        out.field("authority", &conf.authority);
        out.field("conn_timeout", &conf.conn_timeout);
        out.field("conn_lifetime", &conf.conn_lifetime);
        out.field("disconnect_timeout", &conf.disconnect_timeout);
        out.field("minconn", &conf.minconn);
        out.field("maxconn", &conf.maxconn);
        out.field("max-streams", &conf.max_streams);
        out.finish()
    }
}
impl<A, T> fmt::Debug for ClientBuilder<A, T> {
    /// Prints the configured pool settings; the connector and address are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let conf = &self.inner;
        let mut out = f.debug_struct("ClientBuilder");
        out.field("scheme", &conf.scheme);
        out.field("authority", &conf.authority);
        out.field("conn_timeout", &conf.conn_timeout);
        out.field("conn_lifetime", &conf.conn_lifetime);
        out.field("disconnect_timeout", &conf.disconnect_timeout);
        out.field("minconn", &conf.minconn);
        out.field("maxconn", &conf.maxconn);
        out.field("max-streams", &conf.max_streams);
        out.finish()
    }
}