use std::fmt;
use pin_project_lite::pin_project;
use rama_boring::ssl::SslRef;
use rama_boring_tokio::SslStream;
use rama_net::stream::Stream;
use tokio::io::{AsyncRead, AsyncWrite};
pin_project! {
/// A stream that is either wrapped in TLS (an [`SslStream`] over `S`)
/// or the plain underlying stream `S` itself.
///
/// The two cases are stored in the private `AutoTlsStreamData` enum; the
/// `AsyncRead`/`AsyncWrite` implementations forward to whichever one is
/// active.
pub struct AutoTlsStream<S> {
// Structurally pinned so the wrapped stream can be polled through a
// pinned projection.
#[pin]
inner: AutoTlsStreamData<S>,
}
}
impl<S> AutoTlsStream<S> {
    /// Wraps an already established TLS stream.
    pub(super) fn secure(inner: SslStream<S>) -> Self {
        let data = AutoTlsStreamData::Secure { inner };
        Self { inner: data }
    }

    /// Wraps a plain stream, without any TLS layer.
    pub(super) fn plain(inner: S) -> Self {
        let data = AutoTlsStreamData::Plain { inner };
        Self { inner: data }
    }

    /// Returns the [`SslRef`] of the wrapped TLS stream, or `None` when
    /// this stream is plain.
    pub fn ssl_ref(&self) -> Option<&SslRef> {
        if let AutoTlsStreamData::Secure { inner: tls } = &self.inner {
            Some(tls.ssl())
        } else {
            None
        }
    }
}
impl<S: fmt::Debug> fmt::Debug for AutoTlsStream<S> {
    /// Formats as a struct named `AutoTlsStream` with its single `inner` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AutoTlsStream");
        builder.field("inner", &self.inner);
        builder.finish()
    }
}
pin_project! {
// Internal representation of `AutoTlsStream`: exactly one of the two
// stream flavours below. `AutoTlsStreamDataProj` is the pinned projection
// type generated for it.
#[project = AutoTlsStreamDataProj]
enum AutoTlsStreamData<S> {
// TLS stream layered over `S`.
Secure{ #[pin] inner: SslStream<S> },
// The bare stream `S`, with no TLS layer.
Plain { #[pin] inner: S },
}
}
impl<S: fmt::Debug> fmt::Debug for AutoTlsStreamData<S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AutoTlsStreamData::Secure { inner } => f.debug_tuple("Secure").field(inner).finish(),
AutoTlsStreamData::Plain { inner } => f.debug_tuple("Plain").field(inner).finish(),
}
}
}
impl<S> AsyncRead for AutoTlsStream<S>
where
    S: Stream + Unpin,
{
    /// Forwards the read to whichever stream (TLS or plain) is wrapped.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let stream = self.project().inner.project();
        match stream {
            AutoTlsStreamDataProj::Secure { inner: tls } => tls.poll_read(cx, buf),
            AutoTlsStreamDataProj::Plain { inner: raw } => raw.poll_read(cx, buf),
        }
    }
}
impl<S> AsyncWrite for AutoTlsStream<S>
where
S: Stream + Unpin,
{
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
match self.project().inner.project() {
AutoTlsStreamDataProj::Secure { inner } => inner.poll_write(cx, buf),
AutoTlsStreamDataProj::Plain { inner } => inner.poll_write(cx, buf),
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.project().inner.project() {
AutoTlsStreamDataProj::Secure { inner } => inner.poll_flush(cx),
AutoTlsStreamDataProj::Plain { inner } => inner.poll_flush(cx),
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
match self.project().inner.project() {
AutoTlsStreamDataProj::Secure { inner } => inner.poll_shutdown(cx),
AutoTlsStreamDataProj::Plain { inner } => inner.poll_shutdown(cx),
}
}
fn poll_write_vectored(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let buf = bufs
.iter()
.find(|b| !b.is_empty())
.map_or(&[][..], |b| &**b);
self.poll_write(cx, buf)
}
fn is_write_vectored(&self) -> bool {
false
}
}