salvo_core 0.91.1

Salvo is a powerful web framework that can make your work easier.
Documentation
use std::fmt::{self, Debug, Formatter};
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::FutureExt;
use futures_util::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result};

use crate::fuse::{ArcFusewire, FuseEvent};

/// Progress of the handshake behind a [`HandshakeStream`].
enum State<S> {
    /// The handshake future has not resolved yet; on success it yields the stream.
    Handshaking(BoxFuture<'static, Result<S>>),
    /// The handshake succeeded and I/O is forwarded to the contained stream.
    Ready(S),
    /// The handshake future returned an error; later polls report an
    /// `InvalidData` error instead of touching any stream.
    Error,
}

/// A stream that completes a pending TLS handshake lazily, when it is polled
/// for I/O.
///
/// Until the handshake future resolves, read/write/flush/shutdown polls drive
/// that future; afterwards they are forwarded to the stream it produced.
pub struct HandshakeStream<S> {
    /// Current handshake progress, holding the stream once it is available.
    state: State<S>,
    /// Optional fusewire that is sent handshake and read-activity events.
    fusewire: Option<ArcFusewire>,
}

impl<S> Debug for HandshakeStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("HandshakeStream").finish()
    }
}

impl<S> HandshakeStream<S> {
    /// Wraps a not-yet-completed `handshake` future in a stream.
    ///
    /// If a fusewire is supplied it is sent [`FuseEvent::TlsHandshaking`]
    /// before the stream is returned. The future itself is not polled here.
    #[doc(hidden)]
    pub fn new<F>(handshake: F, fusewire: Option<ArcFusewire>) -> Self
    where
        F: Future<Output = Result<S>> + Send + 'static,
    {
        if let Some(wire) = fusewire.as_ref() {
            wire.event(FuseEvent::TlsHandshaking);
        }
        let state = State::Handshaking(handshake.boxed());
        Self { state, fusewire }
    }

    /// Stores the established `stream`, then sends
    /// [`FuseEvent::TlsHandshaked`] to the fusewire when one is attached.
    fn set_state_ready(&mut self, stream: S) {
        self.state = State::Ready(stream);
        if let Some(wire) = self.fusewire.as_ref() {
            wire.event(FuseEvent::TlsHandshaked);
        }
    }
}

impl<S> AsyncRead for HandshakeStream<S>
where
    S: AsyncRead + Unpin + Send + 'static,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<()>> {
        let this = &mut *self;

        loop {
            match &mut this.state {
                State::Handshaking(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(s)) => this.set_state_ready(s),
                    Poll::Ready(Err(err)) => {
                        this.state = State::Error;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => {
                        if let Some(fusewire) = &self.fusewire {
                            fusewire.event(FuseEvent::Alive);
                        }
                        return Poll::Pending;
                    }
                },
                State::Ready(stream) => {
                    let remaining = buf.remaining();
                    return match Pin::new(stream).poll_read(cx, buf) {
                        Poll::Ready(Ok(())) => {
                            if let Some(fusewire) = &self.fusewire {
                                fusewire.event(FuseEvent::ReadData(remaining - buf.remaining()));
                            }
                            Poll::Ready(Ok(()))
                        }
                        Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                        Poll::Pending => {
                            if let Some(fusewire) = &self.fusewire {
                                fusewire.event(FuseEvent::Alive);
                            }
                            Poll::Pending
                        }
                    };
                }
                State::Error => {
                    return Poll::Ready(Err(invalid_data_error("poll read invalid data")));
                }
            }
        }
    }
}

impl<S> AsyncWrite for HandshakeStream<S>
where
    S: AsyncWrite + Unpin + Send + 'static,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<IoResult<usize>> {
        let this = &mut *self;

        loop {
            match &mut this.state {
                State::Handshaking(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(s)) => this.set_state_ready(s),
                    Poll::Ready(Err(err)) => {
                        this.state = State::Error;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                State::Ready(stream) => return Pin::new(stream).poll_write(cx, buf),
                State::Error => {
                    return Poll::Ready(Err(invalid_data_error("poll write invalid data")));
                }
            }
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = &mut *self;

        loop {
            match &mut this.state {
                State::Handshaking(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(s)) => this.set_state_ready(s),
                    Poll::Ready(Err(err)) => {
                        this.state = State::Error;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                State::Ready(stream) => return Pin::new(stream).poll_flush(cx),
                State::Error => {
                    return Poll::Ready(Err(invalid_data_error("poll flush invalid data")));
                }
            }
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
        let this = &mut *self;

        loop {
            match &mut this.state {
                State::Handshaking(fut) => match fut.poll_unpin(cx) {
                    Poll::Ready(Ok(s)) => this.set_state_ready(s),
                    Poll::Ready(Err(err)) => {
                        this.state = State::Error;
                        return Poll::Ready(Err(err));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                State::Ready(stream) => return Pin::new(stream).poll_shutdown(cx),
                State::Error => {
                    return Poll::Ready(Err(invalid_data_error("poll shutdown invalid data")));
                }
            }
        }
    }
}

/// Builds an I/O error of kind [`ErrorKind::InvalidData`] whose message is `msg`.
fn invalid_data_error(msg: &'static str) -> IoError {
    let kind = ErrorKind::InvalidData;
    IoError::new(kind, msg)
}