//! A connector wrapper for hyper that can consult DNS SRV records.
//!
//! [`ServiceConnector`] wraps another connector. When it was given a resolver
//! and the requested URI has a host but no explicit port, it performs an SRV
//! lookup for that host and, if a record comes back, connects to the first
//! record's target and port instead. In every other case the URI is passed to
//! the wrapped connector unchanged.
#![warn(
absolute_paths_not_starting_with_crate,
meta_variable_misuse,
missing_debug_implementations,
missing_docs,
noop_method_call,
pointer_structural_match,
unreachable_pub,
unused_crate_dependencies,
unused_lifetimes,
clippy::cast_lossless,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_precision_loss,
clippy::cast_sign_loss,
clippy::checked_conversions,
clippy::cognitive_complexity,
clippy::exhaustive_enums,
clippy::exhaustive_structs,
clippy::future_not_send,
clippy::inconsistent_struct_constructor,
clippy::inefficient_to_string,
clippy::use_debug,
clippy::use_self
)]
use futures::{
future::BoxFuture,
ready,
task::{Context, Poll},
Future,
};
use hyper::{client::connect::Connection, service::Service, Uri};
use std::{error::Error, fmt, pin::Pin};
use tokio::io::{AsyncRead, AsyncWrite};
use trust_dns_resolver::{
error::{ResolveError, ResolveErrorKind},
lookup::SrvLookup,
TokioAsyncResolver,
};
/// A connector that optionally resolves SRV records before delegating to a
/// wrapped connector.
///
/// `C` is the wrapped connector type; the [`Service`] implementation requires
/// it to be a `Service<Uri>` itself.
#[derive(Debug, Clone)]
pub struct ServiceConnector<C> {
/// Resolver used for SRV lookups. With `None`, no lookup is ever attempted.
resolver: Option<TokioAsyncResolver>,
/// The wrapped connector that establishes the actual connection.
inner: C,
}
/// Connects to a [`Uri`], consulting SRV records first when a resolver is
/// configured and the URI names a host without an explicit port.
impl<C> Service<Uri> for ServiceConnector<C>
where
    C: Service<Uri> + Clone + Unpin,
    C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
    C::Error: Into<Box<dyn Error + Send + Sync>>,
    C::Future: Unpin + Send,
{
    type Response = C::Response;
    type Error = ServiceError;
    type Future = ServiceConnecting<C>;

    /// Reports the readiness of the wrapped connector, converting its error
    /// into a [`ServiceError`].
    fn poll_ready(&mut self, ctx: &mut Context) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(ctx).map_err(ServiceError::inner)
    }

    /// Starts connecting to `uri`.
    ///
    /// If a resolver is present and `uri` has a host but no port, the returned
    /// future first performs an SRV lookup and only then calls the wrapped
    /// connector. Otherwise the wrapped connector is called right away with
    /// the unmodified `uri`.
    fn call(&mut self, uri: Uri) -> Self::Future {
        let kind = match (&self.resolver, uri.host(), uri.port()) {
            (Some(resolver), Some(_), None) => {
                let resolver = resolver.clone();

                // The wrapped connector is only called after the lookup
                // finishes, from inside the future. `poll_ready` was driven on
                // `self.inner`, not on a clone, so the future must own that
                // very instance: a fresh clone is not guaranteed to be ready.
                // Leave the clone behind in `self` instead; the caller has to
                // poll it ready again before the next `call` anyway.
                let fresh = self.inner.clone();
                let ready = std::mem::replace(&mut self.inner, fresh);

                ServiceConnectingKind::Preresolve {
                    inner: ready,
                    fut: Box::pin(async move {
                        // The match arm above only runs when a host is present.
                        let host = uri.host().expect("host was right here, now it is gone");
                        let resolved = resolver.srv_lookup(host).await;
                        (resolved, uri)
                    }),
                }
            },
            _ => {
                ServiceConnectingKind::Inner {
                    fut: self.inner.call(uri),
                }
            },
        };
        ServiceConnecting(kind)
    }
}
impl<C> ServiceConnector<C> {
    /// Creates a connector that wraps `inner`.
    ///
    /// With `resolver` set to `Some`, URIs that have a host but no explicit
    /// port go through an SRV lookup before `inner` is asked to connect. With
    /// `None`, every URI is handed to `inner` unchanged.
    #[must_use]
    pub fn new(inner: C, resolver: Option<TokioAsyncResolver>) -> Self {
        Self {
            resolver,
            inner,
        }
    }
}
/// The cause stored inside a [`ServiceError`].
#[derive(Debug)]
enum ServiceErrorKind {
/// The SRV lookup failed with an error other than "no records found".
Resolve(Box<ResolveError>),
/// Any other boxed error: a failure reported by the wrapped connector, or a
/// failure to build the rewritten URI.
Inner(Box<dyn Error + Send + Sync>),
}
/// Error produced by [`ServiceConnector`] and the [`ServiceConnecting`] future.
///
/// It wraps either a resolver error or an error from the wrapped connector;
/// its `Display` output is that of the wrapped error.
#[derive(Debug)]
pub struct ServiceError(ServiceErrorKind);
/// Wraps a resolver failure so it can be returned as a [`ServiceError`].
impl From<ResolveError> for ServiceError {
    fn from(error: ResolveError) -> Self {
        let boxed = Box::new(error);
        Self(ServiceErrorKind::Resolve(boxed))
    }
}
/// Formats the error exactly as the wrapped error formats itself.
impl fmt::Display for ServiceError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Both variants hold a boxed error. Select it as a trait object so a
        // single call hands the formatter over to it untouched.
        let wrapped: &dyn fmt::Display = match &self.0 {
            ServiceErrorKind::Resolve(resolve_err) => resolve_err,
            ServiceErrorKind::Inner(inner_err) => inner_err,
        };
        fmt::Display::fmt(wrapped, f)
    }
}
impl Error for ServiceError {
    /// Returns the wrapped error for the `Inner` variant and `None` for the
    /// `Resolve` variant.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match &self.0 {
            ServiceErrorKind::Inner(boxed) => {
                // Borrow the boxed trait object and drop the `Send + Sync`
                // markers to match the return type.
                let cause: &(dyn Error + 'static) = &**boxed;
                Some(cause)
            },
            ServiceErrorKind::Resolve(_) => None,
        }
    }
}
impl ServiceError {
    /// Builds a [`ServiceError`] from anything convertible into a boxed
    /// `Error + Send + Sync`, storing it as the `Inner` cause.
    fn inner<E>(inner: E) -> Self
    where
        E: Into<Box<dyn Error + Send + Sync>>,
    {
        let boxed: Box<dyn Error + Send + Sync> = inner.into();
        Self(ServiceErrorKind::Inner(boxed))
    }
}
/// The two states a [`ServiceConnecting`] future moves through.
// NOTE(review): the lint is presumably silenced because the two variants can
// differ a lot in size and boxing `C::Future` was not wanted — confirm.
#[allow(clippy::large_enum_variant)]
enum ServiceConnectingKind<C>
where
C: Service<Uri> + Unpin,
{
/// Waiting for the SRV lookup to finish.
Preresolve {
/// Connector to call once the lookup has produced the final URI.
inner: C,
/// The pending lookup; it yields the lookup result together with the
/// original URI so the URI can be reused or rewritten afterwards.
fut: BoxFuture<'static, (Result<SrvLookup, ResolveError>, Uri)>,
},
/// Waiting for the wrapped connector to finish connecting.
Inner {
/// The wrapped connector's in-flight connection future.
fut: C::Future,
},
}
/// Opaque `Debug` output: only the type name is printed, for either variant,
/// so no `Debug` bound is needed on `C` or its future.
impl<C> fmt::Debug for ServiceConnectingKind<C>
where
    C: Service<Uri> + Unpin,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("ServiceConnectingKind")
    }
}
/// Future returned by [`ServiceConnector`]'s `call`.
///
/// It resolves to the wrapped connector's response, or to a [`ServiceError`]
/// if the SRV lookup, the URI rewrite, or the wrapped connector fails.
#[derive(Debug)]
pub struct ServiceConnecting<C>(ServiceConnectingKind<C>)
where
C: Service<Uri> + Unpin;
/// Drives the connection attempt.
///
/// In the `Preresolve` state the SRV lookup is awaited, the URI is rewritten
/// if a record was found, and the future switches itself to the `Inner`
/// state. In the `Inner` state the wrapped connector's future is polled.
impl<C> Future for ServiceConnecting<C>
where
C: Service<Uri> + Unpin,
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
C::Error: Into<Box<dyn Error + Send + Sync>>,
C::Future: Unpin + Send,
{
type Output = Result<C::Response, ServiceError>;
fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
match &mut self.0 {
ServiceConnectingKind::Preresolve {
inner,
fut,
} => {
// Returns `Poll::Pending` from `poll` until the lookup has finished.
let (res, uri) = ready!(Pin::new(fut).poll(ctx));
// "No records found" is not an error here: it becomes `None` and the
// original URI is used as-is. Any other resolver error ends the future.
let response = res.map(Some).or_else(|err| {
match err.kind() {
ResolveErrorKind::NoRecordsFound {
..
} => Ok(None),
_unexpected => Err(ServiceError(ServiceErrorKind::Resolve(Box::new(err)))),
}
})?;
// Only the first record yielded by the lookup's iterator is considered.
let uri = match response.as_ref().and_then(|response| response.iter().next()) {
Some(srv) => {
// New authority is the record's `target:port`; scheme and
// path/query are carried over from the original URI when present.
let authority = format!("{}:{}", srv.target(), srv.port());
let builder = Uri::builder().authority(authority.as_str());
let builder = match uri.scheme() {
Some(scheme) => builder.scheme(scheme.clone()),
None => builder,
};
let builder = match uri.path_and_query() {
Some(path_and_query) => builder.path_and_query(path_and_query.clone()),
None => builder,
};
// A URI that cannot be built is reported through the `Inner` cause.
builder.build().map_err(ServiceError::inner)?
},
None => uri,
};
{
// `inner.call(uri)` is evaluated first, then the whole state is
// replaced, dropping the finished lookup future and `inner`.
// No `poll_ready` is polled on `inner` here; this relies on whoever
// built the future handing over a connector that is ready to call.
*self = Self(ServiceConnectingKind::Inner {
fut: inner.call(uri),
});
}
// Poll the new state right away instead of returning `Pending`, so the
// connector's future gets its first poll with the current context.
self.poll(ctx)
},
ServiceConnectingKind::Inner {
fut,
} => Pin::new(fut).poll(ctx).map_err(ServiceError::inner),
}
}
}