mod tls_info;
mod verbose;
pub(super) mod connector;
pub(super) mod descriptor;
pub(super) mod http;
pub(super) mod net;
pub(super) mod proxy;
use std::{
fmt::{self, Debug, Formatter},
io,
io::IoSlice,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll},
};
use ::http::{Extensions, HeaderMap, HeaderValue};
#[cfg(any(feature = "tokio-rt", feature = "compio-rt"))]
use net::TcpConnector;
use pin_project_lite::pin_project;
use tls_info::TlsInfoFactory;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_btls::SslStream;
use tower::{
BoxError,
util::{BoxCloneSyncService, BoxCloneSyncServiceLayer},
};
use crate::{
dns::DynResolver,
proxy::matcher::Intercept,
tls::{AlpnProtocol, TlsInfo},
};
/// The default HTTP connector: hosts are resolved through [`DynResolver`] and TCP connections
/// are opened with [`TcpConnector`].
// NOTE(review): `TcpConnector` is only imported when the `tokio-rt` or `compio-rt` feature is
// enabled, but this alias is not feature-gated — confirm one of those features is always on.
pub type HttpConnector = http::HttpConnector<DynResolver, TcpConnector>;
/// Type-erased, clonable connector service: takes an [`Unnameable`] request and yields a
/// [`Conn`], failing with a [`BoxError`].
pub type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;
/// Type-erased layer that wraps a [`BoxedConnectorService`] and produces another service with
/// the same `Unnameable -> Conn` / [`BoxError`] shape.
pub type BoxedConnectorLayer =
    BoxCloneSyncServiceLayer<BoxedConnectorService, Unnameable, Conn, BoxError>;
/// Request type accepted by the connector service; wraps a `descriptor::ConnectionDescriptor`.
///
/// The inner field is `pub(super)`, so only the parent module (and its descendants) can build
/// or read it.
pub struct Unnameable(pub(super) descriptor::ConnectionDescriptor);
/// Everything a boxed connection stream must support: async reads/writes, the [`Connection`]
/// metadata hook, and `Send + Sync + Unpin + 'static` so the boxed stream is thread-safe and
/// can be polled without pinning gymnastics.
trait AsyncConn: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static {}
/// [`AsyncConn`] plus the ability to report [`TlsInfo`]; this is the trait object stored in
/// [`Conn`].
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
// Blanket impls: any type that meets the bounds automatically implements the marker traits.
impl<T> AsyncConn for T where T: AsyncRead + AsyncWrite + Connection + Send + Sync + Unpin + 'static {}
impl<T> AsyncConnWithInfo for T where T: AsyncConn + TlsInfoFactory {}
pin_project! {
    /// A type-erased established connection, together with the settings used to build its
    /// [`Connected`] metadata.
    pub struct Conn {
        // When true, `connected()` attaches the stream's `TlsInfo` (if any) as an extra.
        tls_info: bool,
        // Proxy intercept to record on `Connected`, if this connection uses one.
        proxy: Option<Intercept>,
        // The boxed transport that all I/O is delegated to.
        #[pin]
        stream: Box<dyn AsyncConnWithInfo>,
    }
}
pin_project! {
    /// A TLS stream (`tokio_btls::SslStream`) over an inner transport `T`; it reports the
    /// ALPN result through [`Connection`].
    pub struct TlsConn<T> {
        #[pin]
        stream: SslStream<T>,
    }
}
/// Implemented by transports that can describe themselves once established.
pub trait Connection {
    /// Returns metadata about this connection: negotiated protocol, proxy details, extras
    /// and poison state.
    fn connected(&self) -> Connected;
}
/// Application protocol recorded for a connection.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Alpn {
    /// HTTP/2 was negotiated via ALPN.
    H2,
    /// HTTP/2 was not negotiated (no ALPN result, or a different protocol was selected).
    None,
}
/// Shared "do not reuse" flag. Cloning shares the same underlying `AtomicBool`, so poisoning
/// through any clone is visible to all of them.
#[derive(Clone)]
struct PoisonPill(Arc<AtomicBool>);
/// A clonable, type-erased list of extra values that can be copied into an `Extensions` map.
#[derive(Debug)]
struct Extra(Box<dyn ExtraInner>);
/// Object-safe interface behind [`Extra`].
trait ExtraInner: Send + Sync + Debug {
    /// Clones this node, including any chain behind it, into a new box.
    fn clone_box(&self) -> Box<dyn ExtraInner>;
    /// Inserts the carried value(s) into `res`.
    fn set(&self, res: &mut Extensions);
}
/// A single extra value (the first one attached).
#[derive(Debug, Clone)]
struct ExtraEnvelope<T>(T);
/// Previously accumulated extras (`.0`) followed by one newly attached value (`.1`).
#[derive(Debug)]
struct ExtraChain<T>(Box<dyn ExtraInner>, T);
/// Proxy-related details recorded for a connection.
#[derive(Debug, Default, Clone)]
struct ProxyIdentity {
    /// Whether the connection goes through a proxy.
    is_proxied: bool,
    /// Basic-auth header value taken from the proxy intercept, if it had one.
    auth: Option<HeaderValue>,
    /// Custom headers taken from the proxy intercept, if it had any.
    headers: Option<HeaderMap>,
}
/// Metadata describing an established connection.
#[derive(Debug, Clone)]
pub struct Connected {
    /// Application protocol negotiated for this connection.
    alpn: Alpn,
    /// Proxy details. NOTE(review): boxed, presumably to keep `Connected` small — confirm.
    proxy: Box<ProxyIdentity>,
    /// Optional chain of extra values, copied out by `set_extras`.
    extra: Option<Extra>,
    /// Poison flag shared by every clone of this value.
    poisoned: PoisonPill,
}
impl Connection for Conn {
    /// Builds the [`Connected`] metadata for this connection.
    ///
    /// The result starts from the inner stream's own metadata. The proxy intercept is recorded
    /// if there is one. When `tls_info` reporting is enabled, the stream's [`TlsInfo`] is
    /// attached as an extra if the stream can produce it.
    fn connected(&self) -> Connected {
        let base = self.stream.connected();
        let base = match self.proxy.as_ref() {
            Some(intercept) => base.proxy(intercept.clone()),
            None => base,
        };
        // Only ask the stream for TLS details when reporting them was requested.
        let info = if self.tls_info {
            self.stream.tls_info()
        } else {
            None
        };
        match info {
            Some(info) => base.extra(info),
            None => base,
        }
    }
}
impl AsyncRead for Conn {
    /// Forwards reads to the boxed inner stream.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let inner = self.project().stream;
        inner.poll_read(cx, buf)
    }
}
impl AsyncWrite for Conn {
    /// Forwards a single-buffer write to the boxed inner stream.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().stream;
        inner.poll_write(cx, buf)
    }
    /// Forwards a vectored write to the boxed inner stream.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let inner = self.project().stream;
        inner.poll_write_vectored(cx, bufs)
    }
    /// Reports whether the inner stream has an efficient vectored-write implementation.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        AsyncWrite::is_write_vectored(&self.stream)
    }
    /// Forwards flushing to the boxed inner stream.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_flush(cx)
    }
    /// Forwards shutdown to the boxed inner stream.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let inner = self.project().stream;
        inner.poll_shutdown(cx)
    }
}
impl<T> Connection for TlsConn<T>
where
    T: Connection,
{
    /// Returns the inner transport's [`Connected`] metadata. It is marked as HTTP/2 when the
    /// TLS handshake selected the HTTP/2 ALPN protocol.
    fn connected(&self) -> Connected {
        let base = self.stream.get_ref().connected();
        let selected = self.stream.ssl().selected_alpn_protocol();
        match selected {
            Some(proto) if AlpnProtocol::HTTP2.eq(proto) => base.negotiated_h2(),
            _ => base,
        }
    }
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsConn<T> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
AsyncRead::poll_read(self.project().stream, cx, buf)
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsConn<T> {
    // All signatures use `io::Error` (identical to `tokio::io::Error`). This replaces a mix of
    // both spellings and matches the `Conn` impl. `Context<'_>` makes the elided lifetime
    // explicit.

    /// Encrypts and writes `buf` through the TLS stream.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        AsyncWrite::poll_write(self.project().stream, cx, buf)
    }
    /// Forwards a vectored write to the TLS stream.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        AsyncWrite::poll_write_vectored(self.project().stream, cx, bufs)
    }
    /// Reports whether the TLS stream supports efficient vectored writes.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }
    /// Flushes the TLS stream.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        AsyncWrite::poll_flush(self.project().stream, cx)
    }
    /// Shuts down the TLS stream.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        AsyncWrite::poll_shutdown(self.project().stream, cx)
    }
}
impl<T> TlsInfoFactory for TlsConn<T>
where
SslStream<T>: TlsInfoFactory,
{
#[inline]
fn tls_info(&self) -> Option<TlsInfo> {
self.stream.tls_info()
}
}
impl fmt::Debug for PoisonPill {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(
f,
"PoisonPill@{:p} {{ poisoned: {} }}",
self.0,
self.0.load(Ordering::Relaxed)
)
}
}
impl PoisonPill {
#[inline]
fn healthy() -> Self {
Self(Arc::new(AtomicBool::new(false)))
}
}
impl Connected {
pub fn new() -> Connected {
Connected {
alpn: Alpn::None,
proxy: Box::new(ProxyIdentity::default()),
extra: None,
poisoned: PoisonPill::healthy(),
}
}
pub fn extra<T: Clone + Send + Sync + Debug + 'static>(mut self, extra: T) -> Connected {
if let Some(prev) = self.extra {
self.extra = Some(Extra(Box::new(ExtraChain(prev.0, extra))));
} else {
self.extra = Some(Extra(Box::new(ExtraEnvelope(extra))));
}
self
}
#[inline]
pub fn set_extras(&self, extensions: &mut Extensions) {
if let Some(extra) = &self.extra {
extra.set(extensions);
}
}
pub fn proxy(mut self, proxy: Intercept) -> Connected {
self.proxy.is_proxied = true;
if let Some(auth) = proxy.basic_auth() {
self.proxy.auth.replace(auth.clone());
}
if let Some(headers) = proxy.custom_headers() {
self.proxy.headers.replace(headers.clone());
}
self
}
#[inline]
pub fn is_proxied(&self) -> bool {
self.proxy.is_proxied
}
#[inline]
pub fn proxy_auth(&self) -> Option<&HeaderValue> {
self.proxy.auth.as_ref()
}
#[inline]
pub fn proxy_headers(&self) -> Option<&HeaderMap> {
self.proxy.headers.as_ref()
}
#[inline]
pub fn negotiated_h2(mut self) -> Connected {
self.alpn = Alpn::H2;
self
}
#[inline]
pub fn is_negotiated_h2(&self) -> bool {
self.alpn == Alpn::H2
}
#[inline]
pub fn poisoned(&self) -> bool {
self.poisoned.0.load(Ordering::Relaxed)
}
#[allow(unused)]
#[inline]
pub fn poison(&self) {
self.poisoned.0.store(true, Ordering::Relaxed);
debug!(
"connection was poisoned. this connection will not be reused for subsequent requests"
);
}
}
impl Extra {
#[inline]
fn set(&self, res: &mut Extensions) {
self.0.set(res);
}
}
impl Clone for Extra {
fn clone(&self) -> Extra {
Extra(self.0.clone_box())
}
}
impl<T> ExtraInner for ExtraEnvelope<T>
where
T: Clone + Send + Sync + Debug + 'static,
{
fn clone_box(&self) -> Box<dyn ExtraInner> {
Box::new(self.clone())
}
fn set(&self, res: &mut Extensions) {
res.insert(self.0.clone());
}
}
impl<T: Clone> Clone for ExtraChain<T> {
fn clone(&self) -> Self {
ExtraChain(self.0.clone_box(), self.1.clone())
}
}
impl<T> ExtraInner for ExtraChain<T>
where
T: Clone + Send + Sync + Debug + 'static,
{
fn clone_box(&self) -> Box<dyn ExtraInner> {
Box::new(self.clone())
}
fn set(&self, res: &mut Extensions) {
self.0.set(res);
res.insert(self.1.clone());
}
}