use std::{
io::{self, IoSlice},
pin::Pin,
task::{Context, Poll},
};
use pin_project_lite::pin_project;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};
use super::{AsyncConnWithInfo, Connected, Connection, TlsInfoFactory};
use crate::{
proxy::matcher::Intercept,
tls::{
TlsInfo,
conn::{MaybeHttpsStream, TlsStream},
},
};
pin_project! {
    /// Type-erased connection: wraps any [`AsyncConnWithInfo`] stream behind a
    /// `Box` and forwards all `AsyncRead`/`AsyncWrite` calls to it.
    ///
    /// Its [`Connection::connected`] implementation starts from the inner
    /// stream's metadata, adds `proxy` when set, and, only when `tls_info` is
    /// `true`, attaches the stream's TLS info if the stream reports any.
    pub struct Conn {
        #[pin]
        // The underlying stream; reads, writes, flushes and shutdowns are forwarded here.
        pub(super) inner: Box<dyn AsyncConnWithInfo>,
        // Whether `connected()` should attach the inner stream's `TlsInfo` as an extra.
        pub(super) tls_info: bool,
        // Proxy intercept recorded on the `Connected` metadata, if a proxy was used.
        pub(super) proxy: Option<Intercept>,
    }
}
pin_project! {
    /// A transport wrapped in a [`TlsStream`].
    ///
    /// All `AsyncRead`/`AsyncWrite` calls are forwarded to the pinned inner
    /// stream. For the TCP and (on Unix) Unix-socket transports, its
    /// `Connection` impl marks the connection as HTTP/2 when the active TLS
    /// backend reports that ALPN selected `h2`.
    pub struct TlsConn<T> {
        #[pin]
        // The TLS stream layered over the transport `T`.
        inner: TlsStream<T>,
    }
}
impl Connection for Conn {
    /// Builds the [`Connected`] metadata for this connection.
    ///
    /// Begins with the inner stream's own metadata, records the proxy
    /// intercept when `proxy` is set, and, only if `tls_info` is enabled,
    /// attaches the stream's TLS info as an extra when the stream provides it.
    fn connected(&self) -> Connected {
        let base = self.inner.connected();
        let base = match &self.proxy {
            Some(intercept) => base.proxy(intercept.clone()),
            None => base,
        };

        // TLS info is opt-in; skip querying the stream entirely when disabled.
        if !self.tls_info {
            return base;
        }

        match self.inner.tls_info() {
            Some(info) => base.extra(info),
            None => base,
        }
    }
}
impl AsyncRead for Conn {
    /// Forwards the read to the pinned, boxed inner stream.
    #[inline]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = self.project().inner;
        AsyncRead::poll_read(stream, cx, buf)
    }
}
impl AsyncWrite for Conn {
    /// Forwards the write to the pinned, boxed inner stream.
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = self.project().inner;
        AsyncWrite::poll_write(stream, cx, buf)
    }

    /// Forwards the vectored write to the inner stream.
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let stream = self.project().inner;
        AsyncWrite::poll_write_vectored(stream, cx, bufs)
    }

    /// Reports whether the inner stream supports efficient vectored writes.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        AsyncWrite::is_write_vectored(&self.inner)
    }

    /// Forwards the flush to the inner stream.
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.project().inner;
        AsyncWrite::poll_flush(stream, cx)
    }

    /// Forwards the shutdown to the inner stream.
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let stream = self.project().inner;
        AsyncWrite::poll_shutdown(stream, cx)
    }
}
impl<T> TlsConn<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `inner` in a [`TlsConn`].
    // `#[inline]` rather than `#[inline(always)]`: this is a generic, trivial
    // constructor that is already inlinable across crates, so forcing
    // inlining only overrides the optimizer without any benefit.
    #[inline]
    pub fn new(inner: TlsStream<T>) -> Self {
        Self { inner }
    }
}
impl Connection for TlsConn<TcpStream> {
fn connected(&self) -> Connected {
#[cfg(feature = "boring")]
{
let connected = self.inner.get_ref().connected();
if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(all(feature = "rustls-tls", not(feature = "boring")))]
{
let (io, session) = self.inner.get_ref();
let connected = io.connected();
if session.alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(not(any(feature = "boring", feature = "rustls-tls")))]
{
self.inner.get_ref().connected()
}
}
}
impl Connection for TlsConn<MaybeHttpsStream<TcpStream>> {
fn connected(&self) -> Connected {
#[cfg(feature = "boring")]
{
let connected = self.inner.get_ref().connected();
if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(all(feature = "rustls-tls", not(feature = "boring")))]
{
let (io, session) = self.inner.get_ref();
let connected = io.connected();
if session.alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(not(any(feature = "boring", feature = "rustls-tls")))]
{
self.inner.get_ref().connected()
}
}
}
#[cfg(unix)]
impl Connection for TlsConn<UnixStream> {
fn connected(&self) -> Connected {
#[cfg(feature = "boring")]
{
let connected = self.inner.get_ref().connected();
if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(all(feature = "rustls-tls", not(feature = "boring")))]
{
let (io, session) = self.inner.get_ref();
let connected = io.connected();
if session.alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(not(any(feature = "boring", feature = "rustls-tls")))]
{
self.inner.get_ref().connected()
}
}
}
#[cfg(unix)]
impl Connection for TlsConn<MaybeHttpsStream<UnixStream>> {
fn connected(&self) -> Connected {
#[cfg(feature = "boring")]
{
let connected = self.inner.get_ref().connected();
if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(all(feature = "rustls-tls", not(feature = "boring")))]
{
let (io, session) = self.inner.get_ref();
let connected = io.connected();
if session.alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
#[cfg(not(any(feature = "boring", feature = "rustls-tls")))]
{
self.inner.get_ref().connected()
}
}
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsConn<T> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
AsyncRead::poll_read(self.project().inner, cx, buf)
}
}
impl<T> AsyncWrite for TlsConn<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Forwards the write to the pinned inner [`TlsStream`].
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let tls = self.project().inner;
        AsyncWrite::poll_write(tls, cx, buf)
    }

    /// Forwards the vectored write to the inner [`TlsStream`].
    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let tls = self.project().inner;
        AsyncWrite::poll_write_vectored(tls, cx, bufs)
    }

    /// Reports whether the inner [`TlsStream`] supports efficient vectored writes.
    #[inline]
    fn is_write_vectored(&self) -> bool {
        AsyncWrite::is_write_vectored(&self.inner)
    }

    /// Forwards the flush to the inner [`TlsStream`].
    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let tls = self.project().inner;
        AsyncWrite::poll_flush(tls, cx)
    }

    /// Forwards the shutdown to the inner [`TlsStream`].
    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let tls = self.project().inner;
        AsyncWrite::poll_shutdown(tls, cx)
    }
}
impl<T> TlsInfoFactory for TlsConn<T>
where
    TlsStream<T>: TlsInfoFactory,
{
    /// Returns whatever TLS info the wrapped [`TlsStream`] reports.
    fn tls_info(&self) -> Option<TlsInfo> {
        let stream = &self.inner;
        stream.tls_info()
    }
}