use std::fmt::{self, Debug, Formatter};
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_util::FutureExt;
use futures_util::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result};
use crate::fuse::{ArcFusewire, FuseEvent};
/// Lifecycle of a [`HandshakeStream`].
///
/// The stream starts in `Handshaking`. It moves to `Ready` when the handshake future
/// succeeds, or to `Error` when it fails.
enum State<S> {
    /// Handshake still in progress; polling this future yields the established stream.
    Handshaking(BoxFuture<'static, Result<S>>),
    /// Handshake succeeded; all I/O is forwarded to this stream.
    Ready(S),
    /// Handshake failed; every subsequent I/O call returns an `InvalidData` error.
    Error,
}
/// Stream wrapper that completes a handshake before forwarding I/O.
///
/// The handshake future passed to [`HandshakeStream::new`] is polled only from the
/// `AsyncRead`/`AsyncWrite` methods; whichever is called first drives it. Once the
/// handshake resolves, I/O goes straight to the resulting stream. If the handshake
/// fails, its error is returned once, and every later call fails with
/// `ErrorKind::InvalidData`.
pub struct HandshakeStream<S> {
    /// Current handshake / connection state.
    state: State<S>,
    /// Optional fuse wire that receives handshake-progress and activity events.
    fusewire: Option<ArcFusewire>,
}
impl<S> Debug for HandshakeStream<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("HandshakeStream").finish()
}
}
impl<S> HandshakeStream<S> {
#[doc(hidden)]
pub fn new<F>(handshake: F, fusewire: Option<ArcFusewire>) -> Self
where
F: Future<Output = Result<S>> + Send + 'static,
{
if let Some(fusewire) = &fusewire {
fusewire.event(FuseEvent::TlsHandshaking);
}
Self {
state: State::Handshaking(handshake.boxed()),
fusewire,
}
}
fn set_state_ready(&mut self, stream: S) {
self.state = State::Ready(stream);
if let Some(fusewire) = &self.fusewire {
fusewire.event(FuseEvent::TlsHandshaked);
}
}
}
impl<S> AsyncRead for HandshakeStream<S>
where
    S: AsyncRead + Unpin + Send + 'static,
{
    /// Reads into `buf` from the established stream, finishing the handshake first.
    ///
    /// Fuse-wire events, when a fuse wire is attached:
    /// - `FuseEvent::Alive` whenever the handshake or the inner read is pending;
    /// - `FuseEvent::ReadData(n)` after a successful read, where `n` is the number of
    ///   bytes appended to `buf`.
    ///
    /// A handshake error is returned once; after that the stream is in the error state
    /// and every call yields an `InvalidData` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        // All field access goes through `this`, never back through the `Pin`.
        let this = Pin::get_mut(self);
        loop {
            match &mut this.state {
                State::Handshaking(fut) => match fut.poll_unpin(cx) {
                    // Handshake done: switch to `Ready` and loop to perform the read.
                    Poll::Ready(Ok(s)) => this.set_state_ready(s),
                    Poll::Ready(Err(err)) => {
                        this.state = State::Error;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => {
                        if let Some(fusewire) = &this.fusewire {
                            fusewire.event(FuseEvent::Alive);
                        }
                        return Poll::Pending;
                    }
                },
                State::Ready(stream) => {
                    let filled_before = buf.filled().len();
                    let poll = Pin::new(stream).poll_read(cx, buf);
                    if let Some(fusewire) = &this.fusewire {
                        match &poll {
                            Poll::Ready(Ok(())) => {
                                // `saturating_sub` guards against an inner reader that
                                // shrinks the filled region, which would otherwise
                                // underflow the byte count.
                                let read = buf.filled().len().saturating_sub(filled_before);
                                fusewire.event(FuseEvent::ReadData(read));
                            }
                            Poll::Ready(Err(_)) => {}
                            Poll::Pending => {
                                fusewire.event(FuseEvent::Alive);
                            }
                        }
                    }
                    return poll;
                }
                State::Error => {
                    return Poll::Ready(Err(invalid_data_error("poll read invalid data")));
                }
            }
        }
    }
}
impl<S> AsyncWrite for HandshakeStream<S>
where
S: AsyncWrite + Unpin + Send + 'static,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<IoResult<usize>> {
let this = &mut *self;
loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.set_state_ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_write(cx, buf),
State::Error => {
return Poll::Ready(Err(invalid_data_error("poll write invalid data")));
}
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let this = &mut *self;
loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.set_state_ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_flush(cx),
State::Error => {
return Poll::Ready(Err(invalid_data_error("poll flush invalid data")));
}
}
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
let this = &mut *self;
loop {
match &mut this.state {
State::Handshaking(fut) => match fut.poll_unpin(cx) {
Poll::Ready(Ok(s)) => this.set_state_ready(s),
Poll::Ready(Err(err)) => {
this.state = State::Error;
return Poll::Ready(Err(err));
}
Poll::Pending => return Poll::Pending,
},
State::Ready(stream) => return Pin::new(stream).poll_shutdown(cx),
State::Error => {
return Poll::Ready(Err(invalid_data_error("poll shutdown invalid data")));
}
}
}
}
}
/// Builds an I/O error of kind [`ErrorKind::InvalidData`] whose message is `msg`.
fn invalid_data_error(msg: &'static str) -> IoError {
    let kind = ErrorKind::InvalidData;
    IoError::new(kind, msg)
}