1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use crate::{tcp, Error, Result};
use std::{
    convert::TryFrom,
    fmt,
    io::{self, IoSlice, IoSliceMut, Read, Write},
};

/// A TCP stream that may still be in the middle of a TLS handshake.
///
/// `Read` and `Write` calls are forwarded to the underlying `tcp::TcpStream`
/// in either state (while handshaking, to the stream held by the handshaker).
/// Use `is_handshaking()` to check the state and `handshake()` to run another
/// handshake step.
pub struct TcpStream(Inner);

/// Connection state backing `TcpStream`.
enum Inner {
    /// The stream reported `is_connected()` when wrapped, or a handshake step
    /// completed successfully.
    Connected(tcp::TcpStream),
    /// A handshake has been started but not yet completed.
    ///
    /// The `Option` is `Some` in normal operation. It is briefly `None` inside
    /// `TcpStream::handshake`, because the handshaker is moved out by value to
    /// run a step. It stays `None` if converting a failed step back into a
    /// `MidHandshakeTlsStream` returns an error. That leaves the stream
    /// unusable.
    Handshaking(Option<tcp::MidHandshakeTlsStream>),
}

impl TryFrom<tcp::HandshakeResult> for TcpStream {
    type Error = Error;

    fn try_from(result: tcp::HandshakeResult) -> Result<Self> {
        Ok(Self(match result {
            Ok(stream) if stream.is_connected() => Inner::Connected(stream),
            Ok(stream) => Inner::Handshaking(Some(stream.into())),
            Err(handshaker) => {
                Inner::Handshaking(Some(handshaker.into_mid_handshake_tls_stream()?))
            }
        }))
    }
}

impl TcpStream {
    pub(crate) fn inner(&self) -> &tcp::TcpStream {
        match self.0 {
            Inner::Connected(ref stream) => stream,
            Inner::Handshaking(ref handshaker) => handshaker.as_ref().unwrap().get_ref(),
        }
    }

    pub(crate) fn inner_mut(&mut self) -> &mut tcp::TcpStream {
        match self.0 {
            Inner::Connected(ref mut stream) => stream,
            Inner::Handshaking(ref mut handshaker) => handshaker.as_mut().unwrap().get_mut(),
        }
    }

    pub(crate) fn is_handshaking(&self) -> bool {
        if let Inner::Handshaking(_) = self.0 {
            true
        } else {
            false
        }
    }

    pub(crate) fn handshake(&mut self) -> Result<()> {
        if let Inner::Handshaking(ref mut handshaker) = self.0 {
            match handshaker.take().unwrap().handshake() {
                Ok(stream) => self.0 = Inner::Connected(stream),
                Err(error) => {
                    self.0 = Inner::Handshaking(Some(error.into_mid_handshake_tls_stream()?))
                }
            }
        }
        Ok(())
    }
}

// Forwards an I/O method call to the `tcp::TcpStream` underlying `$self`.
//
// Usage: `fwd_impl!(self, method, arg1, arg2)`. For a method without
// arguments, write a trailing comma, e.g. `fwd_impl!(self, flush,)`.
//
// While handshaking, the call goes to the stream held by the handshaker.
// If the handshaker is missing, the expansion returns an `io::Error` instead
// of panicking. That happens when a failed `TcpStream::handshake` step
// consumed it and could not be resumed. I/O on such a poisoned stream is
// therefore reported through the normal error channel.
//
// Because of that error arm, the invoking method must return `io::Result<_>`.
macro_rules! fwd_impl {
    ($self:ident, $method:ident, $($args:expr),*) => {
        match $self.0 {
            Inner::Connected(ref mut inner) => inner.$method($($args),*),
            Inner::Handshaking(Some(ref mut inner)) => inner.get_mut().$method($($args),*),
            Inner::Handshaking(None) => Err(io::Error::new(
                io::ErrorKind::Other,
                "TLS handshake failed; the stream is no longer usable",
            )),
        }
    };
}

// Every `Read` method is delegated to the underlying `tcp::TcpStream` (via
// the handshaker while a handshake is in progress). Any specialised versions
// the inner stream provides are therefore used instead of `Read`'s defaults.
impl Read for TcpStream {
    /// Reads into a single buffer.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        fwd_impl!(self, read, buf)
    }

    /// Fills `buf` completely, delegating to the inner stream's `read_exact`.
    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        fwd_impl!(self, read_exact, buf)
    }

    /// Scatter-read into several buffers.
    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        fwd_impl!(self, read_vectored, bufs)
    }

    /// Appends everything until EOF to `buf`.
    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
        fwd_impl!(self, read_to_end, buf)
    }

    /// Appends everything until EOF to `buf` as UTF-8 text.
    fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
        fwd_impl!(self, read_to_string, buf)
    }
}

// Every `Write` method is delegated to the underlying `tcp::TcpStream` (via
// the handshaker while a handshake is in progress). Any specialised versions
// the inner stream provides are therefore used instead of `Write`'s defaults.
impl Write for TcpStream {
    /// Writes from a single buffer.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        fwd_impl!(self, write, buf)
    }

    /// Writes all of `buf`, delegating to the inner stream's `write_all`.
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        fwd_impl!(self, write_all, buf)
    }

    /// Gather-write from several buffers.
    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        fwd_impl!(self, write_vectored, bufs)
    }

    /// Writes formatted output.
    fn write_fmt(&mut self, fmt: fmt::Arguments<'_>) -> io::Result<()> {
        fwd_impl!(self, write_fmt, fmt)
    }

    /// Flushes the inner stream. It takes no arguments, hence the trailing comma.
    fn flush(&mut self) -> io::Result<()> {
        fwd_impl!(self, flush,)
    }
}