use std::{error::Error, task::Context, time::Duration};
use crate::connect::{Connect as TcpConnect, Connector as TcpConnector};
use crate::service::{Service, ServiceCtx, ServiceFactory, apply_fn_factory, boxed};
use crate::{SharedCfg, http::Uri, io::IoBoxed, time::Seconds, util::join};
use super::{Connect, Connection, error::ConnectError, pool::ConnectionPool};
#[cfg(feature = "openssl")]
use tls_openssl::ssl::SslConnector as OpensslConnector;
#[cfg(feature = "rustls")]
use tls_rustls::ClientConfig;
/// Type-erased connector factory used for both the plain and the secure connector.
///
/// Configured with [`SharedCfg`], it accepts a [`Connect`] request and yields an [`IoBoxed`]
/// stream. Call errors are [`ConnectError`]; factory initialization errors are `Box<dyn Error>`.
type BoxedConnector =
    boxed::BoxServiceFactory<SharedCfg, Connect, IoBoxed, ConnectError, Box<dyn Error>>;
/// Service factory that produces a [`ConnectorService`].
///
/// Holds a connector for plain connections and, optionally, one for secure (`https`/`wss`)
/// connections. When the factory is created, each connector is wrapped in its own
/// `ConnectionPool` that shares the lifetime, keep-alive and limit settings below.
#[derive(Debug)]
pub struct Connector {
    // Lifetime setting handed to each `ConnectionPool`.
    conn_lifetime: Duration,
    // Keep-alive setting handed to each `ConnectionPool`.
    conn_keep_alive: Duration,
    // Connection limit handed to each `ConnectionPool` (applied per pool).
    limit: usize,
    // Connector factory for plain (non-`https`/`wss`) connections.
    svc: BoxedConnector,
    // Connector factory for `https`/`wss` connections; `None` means secure requests are rejected.
    secure_svc: Option<BoxedConnector>,
}
impl Default for Connector {
fn default() -> Self {
Connector::new()
}
}
impl Connector {
    /// Creates a connector with default settings.
    ///
    /// Defaults: connection lifetime 75s, keep-alive 15s, and a limit of 8 applied to each
    /// connection pool. Plain connections go through [`TcpConnector`].
    ///
    /// The secure connector is chosen by feature:
    /// * `openssl`: an OpenSSL connector advertising ALPN `h2` and `http/1.1`, with the
    ///   builder's default peer-certificate verification left enabled.
    /// * `rustls` (and not `openssl`): a rustls client config trusting the `webpki_roots`
    ///   bundle, advertising the same ALPN protocols.
    /// * neither: no secure connector, so `https`/`wss` requests fail with
    ///   `ConnectError::SslIsNotSupported`.
    ///
    /// # Panics
    ///
    /// With the `openssl` feature, panics if the OpenSSL connector builder cannot be created.
    pub fn new() -> Connector {
        let conn = Connector {
            svc: boxed::factory(
                apply_fn_factory(TcpConnector::new(), async move |msg: Connect, svc| {
                    svc.call(TcpConnect::new(msg.uri).set_addr(msg.addr)).await
                })
                .map(IoBoxed::from)
                .map_err(ConnectError::from)
                .map_init_err(|e| Box::new(e) as Box<dyn Error>),
            ),
            secure_svc: None,
            conn_lifetime: Duration::from_secs(75),
            conn_keep_alive: Duration::from_secs(15),
            limit: 8,
        };

        #[cfg(feature = "openssl")]
        {
            use tls_openssl::ssl::SslMethod;

            // `SslConnector::builder` enables peer verification against the default CA paths.
            // It is deliberately NOT overridden with `SslVerifyMode::NONE`: doing so would accept
            // any server certificate on default `https`/`wss` requests (MITM-able) and would
            // diverge from the rustls default below, which verifies against `webpki_roots`.
            let mut ssl = OpensslConnector::builder(SslMethod::tls())
                .expect("failed to create default OpenSSL connector builder");
            // Best effort: a failure to configure ALPN is logged and the connector is still built.
            if let Err(e) = ssl.set_alpn_protos(b"\x02h2\x08http/1.1") {
                log::error!("Cannot set ALPN protocol: {e:?}");
            }
            conn.openssl(ssl.build())
        }

        #[cfg(all(not(feature = "openssl"), feature = "rustls"))]
        {
            use tls_rustls::RootCertStore;

            let protos = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
            let cert_store =
                RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
            let mut config = ClientConfig::builder()
                .with_root_certificates(cert_store)
                .with_no_client_auth();
            config.alpn_protocols = protos;
            conn.rustls(config)
        }

        #[cfg(not(any(feature = "openssl", feature = "rustls")))]
        {
            conn
        }
    }
}
impl Connector {
#[must_use]
#[cfg(feature = "openssl")]
pub fn openssl(self, connector: OpensslConnector) -> Self {
use crate::connect::openssl::SslConnector;
self.secure_connector(SslConnector::new(connector))
}
#[must_use]
#[cfg(feature = "rustls")]
pub fn rustls(self, connector: ClientConfig) -> Self {
use crate::connect::rustls::TlsConnector;
self.secure_connector(TlsConnector::new(connector))
}
#[must_use]
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
#[must_use]
pub fn keep_alive<T: Into<Seconds>>(mut self, dur: T) -> Self {
self.conn_keep_alive = dur.into().into();
self
}
#[must_use]
pub fn lifetime<T: Into<Seconds>>(mut self, dur: T) -> Self {
self.conn_lifetime = dur.into().into();
self
}
#[must_use]
pub fn connector<T>(mut self, connector: T) -> Self
where
T: ServiceFactory<TcpConnect<Uri>, SharedCfg, Error = crate::connect::ConnectError>
+ 'static,
T::InitError: Error,
IoBoxed: From<T::Response>,
{
self.svc = boxed::factory(
apply_fn_factory(connector, async move |msg: Connect, svc| {
svc.call(TcpConnect::new(msg.uri).set_addr(msg.addr)).await
})
.map(IoBoxed::from)
.map_err(ConnectError::from)
.map_init_err(|e| Box::new(e) as Box<dyn Error>),
);
self
}
#[must_use]
pub fn secure_connector<T>(mut self, connector: T) -> Self
where
T: ServiceFactory<TcpConnect<Uri>, SharedCfg, Error = crate::connect::ConnectError>
+ 'static,
T::InitError: Error,
IoBoxed: From<T::Response>,
{
self.secure_svc = Some(boxed::factory(
apply_fn_factory(connector, async move |msg: Connect, svc| {
svc.call(TcpConnect::new(msg.uri).set_addr(msg.addr)).await
})
.map(IoBoxed::from)
.map_err(ConnectError::from)
.map_init_err(|e| Box::new(e) as Box<dyn Error>),
));
self
}
}
impl ServiceFactory<Connect, SharedCfg> for Connector {
    type Response = Connection;
    type Error = ConnectError;
    type Service = ConnectorService;
    type InitError = Box<dyn Error>;

    /// Instantiates the configured connectors and wraps each in its own `ConnectionPool`.
    ///
    /// The secure connector (if any) is instantiated before the plain one. The first
    /// initialization error aborts creation and is propagated to the caller.
    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
        let ssl_pool = match self.secure_svc {
            Some(ref factory) => {
                let secure = factory.create(cfg.clone()).await?;
                Some(ConnectionPool::new(
                    secure.into(),
                    self.conn_lifetime,
                    self.conn_keep_alive,
                    self.limit,
                    cfg.clone(),
                ))
            }
            None => None,
        };

        let plain = self.svc.create(cfg.clone()).await?;
        let tcp_pool = ConnectionPool::new(
            plain.into(),
            self.conn_lifetime,
            self.conn_keep_alive,
            self.limit,
            cfg,
        );

        Ok(ConnectorService { tcp_pool, ssl_pool })
    }
}
/// Connection service produced by [`Connector`].
///
/// Routes each request to the plain or the secure connection pool based on the URI scheme.
#[derive(Clone, Debug)]
pub struct ConnectorService {
    // Pool used for every request whose scheme is not `https`/`wss`.
    tcp_pool: ConnectionPool,
    // Pool used for `https`/`wss` requests; `None` when no secure connector was configured.
    ssl_pool: Option<ConnectionPool>,
}
impl Service<Connect> for ConnectorService {
    type Response = Connection;
    type Error = ConnectError;

    /// Resolves once the plain pool and, when configured, the secure pool are both ready.
    ///
    /// The two pools are awaited concurrently. If both fail, the plain pool's error is reported.
    #[inline]
    async fn ready(&self, ctx: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
        match &self.ssl_pool {
            None => ctx.ready(&self.tcp_pool).await,
            Some(secure) => {
                let (plain_res, secure_res) =
                    join(ctx.ready(&self.tcp_pool), ctx.ready(secure)).await;
                plain_res.and(secure_res)
            }
        }
    }

    /// Polls the plain pool and then, if present, the secure pool; stops at the first error.
    #[inline]
    fn poll(&self, cx: &mut Context<'_>) -> Result<(), Self::Error> {
        self.tcp_pool.poll(cx)?;
        match &self.ssl_pool {
            Some(secure) => secure.poll(cx),
            None => Ok(()),
        }
    }

    /// Shuts down the plain pool, then the secure pool if one exists.
    async fn shutdown(&self) {
        for pool in std::iter::once(&self.tcp_pool).chain(self.ssl_pool.as_ref()) {
            pool.shutdown().await;
        }
    }

    /// Dispatches the request by scheme.
    ///
    /// `https`/`wss` requests go to the secure pool, or fail with
    /// `ConnectError::SslIsNotSupported` when none is configured. All other requests go to
    /// the plain pool.
    async fn call(
        &self,
        req: Connect,
        ctx: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        let wants_tls = matches!(req.uri.scheme_str(), Some("https" | "wss"));
        if !wants_tls {
            return ctx.call(&self.tcp_pool, req).await;
        }

        match &self.ssl_pool {
            Some(secure) => ctx.call(secure, req).await,
            None => Err(ConnectError::SslIsNotSupported),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{service::Pipeline, util::lazy};

    /// A connector service built from the default factory is immediately ready and
    /// completes shutdown without pending.
    #[crate::rt_test]
    async fn test_readiness() {
        let factory = Connector::default();
        let service = factory.create(SharedCfg::default()).await.unwrap();
        let pipeline = Pipeline::new(service).bind();

        let is_ready = lazy(|cx| pipeline.poll_ready(cx).is_ready()).await;
        assert!(is_ready);

        let is_shut_down = lazy(|cx| pipeline.poll_shutdown(cx).is_ready()).await;
        assert!(is_shut_down);
    }
}